1use std::collections::HashSet;
2use std::time::Duration;
3
4use crate::cdp::{CdpError, CdpEvent, CdpSession};
5use crate::chrome::{TargetInfo, discover_chrome, query_targets, query_version};
6use crate::error::AppError;
7use crate::session;
8
/// Default Chrome remote-debugging port.
///
/// Used by [`resolve_connection`] when no port can be taken from a WebSocket
/// URL or the caller, and as the starting port for Chrome discovery.
pub const DEFAULT_CDP_PORT: u16 = 9222;
11
/// Endpoint details produced by [`resolve_connection`].
#[derive(Debug)]
pub struct ResolvedConnection {
    /// WebSocket URL the CDP client should connect to.
    pub ws_url: String,
    /// Host the connection was resolved against (the `host` argument given to
    /// [`resolve_connection`]).
    pub host: String,
    /// Port associated with the connection: taken from the WebSocket URL, the
    /// explicit port, the saved session, or discovery.
    pub port: u16,
}
19
20pub async fn health_check(host: &str, port: u16) -> Result<(), AppError> {
28 query_version(host, port)
29 .await
30 .map(|_| ())
31 .map_err(|_| AppError::stale_session())
32}
33
34pub async fn resolve_connection(
46 host: &str,
47 port: Option<u16>,
48 ws_url: Option<&str>,
49) -> Result<ResolvedConnection, AppError> {
50 let default_port = DEFAULT_CDP_PORT;
51
52 if let Some(ws_url) = ws_url {
54 let resolved_port =
55 extract_port_from_ws_url(ws_url).unwrap_or(port.unwrap_or(default_port));
56 return Ok(ResolvedConnection {
57 ws_url: ws_url.to_string(),
58 host: host.to_string(),
59 port: resolved_port,
60 });
61 }
62
63 if let Some(explicit_port) = port {
65 match query_version(host, explicit_port).await {
66 Ok(version) => {
67 return Ok(ResolvedConnection {
68 ws_url: version.ws_debugger_url,
69 host: host.to_string(),
70 port: explicit_port,
71 });
72 }
73 Err(_) => return Err(AppError::no_chrome_found()),
74 }
75 }
76
77 if let Some(session_data) = session::read_session()? {
79 health_check(host, session_data.port).await?;
80 return Ok(ResolvedConnection {
81 ws_url: session_data.ws_url,
82 host: host.to_string(),
83 port: session_data.port,
84 });
85 }
86
87 match discover_chrome(host, default_port).await {
89 Ok((ws_url, p)) => Ok(ResolvedConnection {
90 ws_url,
91 host: host.to_string(),
92 port: p,
93 }),
94 Err(_) => Err(AppError::no_chrome_found()),
95 }
96}
97
/// Extracts the explicit port from a `ws://` or `wss://` URL.
///
/// Returns `None` in these cases:
/// - the scheme is not `ws`/`wss`;
/// - the authority has no `:port` part (scheme default ports are not
///   filled in);
/// - the port is not a valid `u16`.
///
/// The authority is taken up to the first `/`, `?`, or `#`. Any
/// `user:password@` prefix is ignored. Bracketed IPv6 hosts such as
/// `[::1]:9222` are handled.
#[must_use]
pub fn extract_port_from_ws_url(url: &str) -> Option<u16> {
    let without_scheme = url
        .strip_prefix("ws://")
        .or_else(|| url.strip_prefix("wss://"))?;
    // The authority ends at the first path, query, or fragment delimiter.
    let authority = without_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()?;
    // Drop any userinfo so its `:` is not mistaken for the port separator.
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, rest)| rest);
    // Require an explicit `:`. Without one, an all-digit host such as
    // `ws://80/...` would otherwise be misread as a port. For a portless
    // IPv6 host like `[::1]`, the tail after the last colon is `1]`, which
    // fails to parse.
    let (_, port_str) = host_port.rsplit_once(':')?;
    port_str.parse().ok()
}
108
109pub fn select_target<'a>(
121 targets: &'a [TargetInfo],
122 tab: Option<&str>,
123) -> Result<&'a TargetInfo, AppError> {
124 match tab {
125 None => targets
126 .iter()
127 .find(|t| t.target_type == "page")
128 .ok_or_else(AppError::no_page_targets),
129 Some(value) => {
130 if let Ok(index) = value.parse::<usize>() {
132 return targets
133 .get(index)
134 .ok_or_else(|| AppError::target_not_found(value));
135 }
136 targets
138 .iter()
139 .find(|t| t.id == value)
140 .ok_or_else(|| AppError::target_not_found(value))
141 }
142 }
143}
144
145pub async fn resolve_target(
151 host: &str,
152 port: u16,
153 tab: Option<&str>,
154) -> Result<TargetInfo, AppError> {
155 let targets = query_targets(host, port).await?;
156 select_target(&targets, tab).cloned()
157}
158
/// Maximum time, in milliseconds, that `ManagedSession::spawn_auto_dismiss`
/// waits for `Page.enable` to answer before giving up on it.
// NOTE(review): presumably kept short so that a page which does not answer
// Page.enable promptly does not stall setup. Confirm the intended rationale.
const PAGE_ENABLE_TIMEOUT_MS: u64 = 300;
165
/// A [`CdpSession`] wrapper that remembers which CDP domains it has enabled.
///
/// Domains are recorded only after their `.enable` command succeeds, so
/// repeated `ensure_domain` calls do not resend the command.
#[derive(Debug)]
pub struct ManagedSession {
    /// The underlying CDP session that all commands are sent through.
    session: CdpSession,
    /// Names of domains (e.g. `"Page"`) whose `.enable` command succeeded.
    enabled_domains: HashSet<String>,
}
175
176impl ManagedSession {
177 #[must_use]
179 pub fn new(session: CdpSession) -> Self {
180 Self {
181 session,
182 enabled_domains: HashSet::new(),
183 }
184 }
185
186 pub async fn ensure_domain(&mut self, domain: &str) -> Result<(), CdpError> {
193 if self.enabled_domains.contains(domain) {
194 return Ok(());
195 }
196 let method = format!("{domain}.enable");
197 self.session.send_command(&method, None).await?;
198 self.enabled_domains.insert(domain.to_string());
199 Ok(())
200 }
201
202 pub async fn send_command(
208 &self,
209 method: &str,
210 params: Option<serde_json::Value>,
211 ) -> Result<serde_json::Value, CdpError> {
212 self.session.send_command(method, params).await
213 }
214
215 #[must_use]
217 pub fn session_id(&self) -> &str {
218 self.session.session_id()
219 }
220
221 pub async fn subscribe(
227 &self,
228 method: &str,
229 ) -> Result<tokio::sync::mpsc::Receiver<CdpEvent>, CdpError> {
230 self.session.subscribe(method).await
231 }
232
233 #[must_use]
235 pub fn enabled_domains(&self) -> &HashSet<String> {
236 &self.enabled_domains
237 }
238
239 pub async fn spawn_auto_dismiss(&mut self) -> Result<tokio::task::JoinHandle<()>, CdpError> {
252 let mut dialog_rx = self
254 .session
255 .subscribe("Page.javascriptDialogOpening")
256 .await?;
257
258 let page_enable = self.session.send_command("Page.enable", None);
262 let enable_result =
263 tokio::time::timeout(Duration::from_millis(PAGE_ENABLE_TIMEOUT_MS), page_enable).await;
264 if matches!(enable_result, Ok(Ok(_))) {
265 self.enabled_domains.insert("Page".to_string());
266 }
267
268 let session = self.session.clone();
269
270 Ok(tokio::spawn(async move {
271 while let Some(_event) = dialog_rx.recv().await {
272 let params = serde_json::json!({ "accept": false });
273 let _ = session
275 .send_command("Page.handleJavaScriptDialog", Some(params))
276 .await;
277 }
278 }))
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TargetInfo` fixture with the given id and type. The title,
    /// URL and WebSocket URL are derived from `id`.
    fn make_target(id: &str, target_type: &str) -> TargetInfo {
        TargetInfo {
            id: id.to_string(),
            target_type: target_type.to_string(),
            title: format!("Title {id}"),
            url: format!("https://example.com/{id}"),
            ws_debugger_url: Some(format!("ws://127.0.0.1:9222/devtools/page/{id}")),
        }
    }

    // ---- extract_port_from_ws_url ----

    #[test]
    fn extract_port_ws() {
        assert_eq!(
            extract_port_from_ws_url("ws://127.0.0.1:9222/devtools/browser/abc"),
            Some(9222)
        );
    }

    #[test]
    fn extract_port_wss() {
        assert_eq!(
            extract_port_from_ws_url("wss://localhost:9333/devtools/browser/abc"),
            Some(9333)
        );
    }

    #[test]
    fn extract_port_no_scheme() {
        // Only ws:// and wss:// URLs are accepted; other schemes yield None.
        assert_eq!(extract_port_from_ws_url("http://localhost:9222"), None);
    }

    // ---- select_target ----

    #[test]
    fn select_target_default_picks_first_page() {
        let targets = vec![
            make_target("bg1", "background_page"),
            make_target("page1", "page"),
            make_target("page2", "page"),
        ];
        let result = select_target(&targets, None).unwrap();
        assert_eq!(result.id, "page1");
    }

    #[test]
    fn select_target_default_skips_non_page() {
        let targets = vec![
            make_target("sw1", "service_worker"),
            make_target("p1", "page"),
        ];
        let result = select_target(&targets, None).unwrap();
        assert_eq!(result.id, "p1");
    }

    #[test]
    fn select_target_by_index() {
        let targets = vec![
            make_target("a", "page"),
            make_target("b", "page"),
            make_target("c", "page"),
        ];
        // A numeric tab value is a zero-based index into the target list.
        let result = select_target(&targets, Some("1")).unwrap();
        assert_eq!(result.id, "b");
    }

    #[test]
    fn select_target_by_id() {
        let targets = vec![make_target("ABCDEF", "page"), make_target("GHIJKL", "page")];
        let result = select_target(&targets, Some("GHIJKL")).unwrap();
        assert_eq!(result.id, "GHIJKL");
    }

    #[test]
    fn select_target_invalid_tab() {
        let targets = vec![make_target("a", "page")];
        let result = select_target(&targets, Some("nonexistent"));
        assert!(result.is_err());
        assert!(result.unwrap_err().message.contains("not found"));
    }

    #[test]
    fn select_target_index_out_of_bounds() {
        let targets = vec![make_target("a", "page")];
        let result = select_target(&targets, Some("5"));
        assert!(result.is_err());
    }

    #[test]
    fn select_target_empty_list_no_tab() {
        let targets: Vec<TargetInfo> = vec![];
        let result = select_target(&targets, None);
        assert!(result.is_err());
        assert!(result.unwrap_err().message.contains("No page targets"));
    }

    #[test]
    fn select_target_no_page_targets() {
        let targets = vec![
            make_target("sw1", "service_worker"),
            make_target("bg1", "background_page"),
        ];
        let result = select_target(&targets, None);
        assert!(result.is_err());
    }

    // ---- ManagedSession ----

    /// Runs `ManagedSession::ensure_domain` against an in-process mock CDP
    /// WebSocket server. Checks that each domain's `.enable` command goes out
    /// exactly once and that enabled domains are tracked.
    #[tokio::test]
    async fn managed_session_enables_domain_once() {
        use crate::cdp::{CdpClient, CdpConfig, ReconnectConfig};
        use futures_util::{SinkExt, StreamExt};
        use std::time::Duration;
        use tokio::net::TcpListener;
        use tokio::sync::mpsc;
        use tokio_tungstenite::tungstenite::Message;

        // Bind the mock server to an OS-assigned local port.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        // Every command the mock server receives is forwarded here for inspection.
        let (record_tx, mut record_rx) = mpsc::channel::<serde_json::Value>(32);

        // Mock server: accepts one connection and answers each text frame.
        // - Target.attachToTarget: replies with a sessionId equal to the
        //   requested targetId.
        // - Any other command: replies with an empty result, echoing
        //   sessionId when present.
        // The loop stops at the first non-text frame, error, or close.
        tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await {
                let ws = tokio_tungstenite::accept_async(stream).await.unwrap();
                let (mut sink, mut source) = ws.split();
                while let Some(Ok(Message::Text(text))) = source.next().await {
                    let cmd: serde_json::Value = serde_json::from_str(&text).unwrap();
                    let _ = record_tx.send(cmd.clone()).await;

                    if cmd["method"] == "Target.attachToTarget" {
                        let tid = cmd["params"]["targetId"].as_str().unwrap_or("test");
                        let resp = serde_json::json!({
                            "id": cmd["id"],
                            "result": {"sessionId": tid}
                        });
                        let _ = sink.send(Message::Text(resp.to_string().into())).await;
                    } else {
                        let mut resp = serde_json::json!({"id": cmd["id"], "result": {}});
                        if let Some(sid) = cmd.get("sessionId") {
                            resp["sessionId"] = sid.clone();
                        }
                        let _ = sink.send(Message::Text(resp.to_string().into())).await;
                    }
                }
            }
        });

        let url = format!("ws://{addr}");
        // Short timeouts keep the test fast. max_retries: 0 presumably
        // disables reconnection; confirm against ReconnectConfig.
        let config = CdpConfig {
            connect_timeout: Duration::from_secs(5),
            command_timeout: Duration::from_secs(5),
            channel_capacity: 256,
            reconnect: ReconnectConfig {
                max_retries: 0,
                ..ReconnectConfig::default()
            },
        };
        let client = CdpClient::connect(&url, config).await.unwrap();
        let session = client.create_session("test-target").await.unwrap();
        // Drain the command recorded while creating the session (presumably
        // Target.attachToTarget), so later receives only see enable commands.
        let _ = tokio::time::timeout(Duration::from_millis(200), record_rx.recv()).await;

        let mut managed = ManagedSession::new(session);
        assert!(managed.enabled_domains().is_empty());

        // First request for "Page" must send Page.enable and record the domain.
        managed.ensure_domain("Page").await.unwrap();
        let msg = tokio::time::timeout(Duration::from_millis(200), record_rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(msg["method"], "Page.enable");
        assert!(managed.enabled_domains().contains("Page"));

        // Second request for "Page" must not reach the server. The receive
        // timing out is the success condition.
        managed.ensure_domain("Page").await.unwrap();
        let no_msg = tokio::time::timeout(Duration::from_millis(100), record_rx.recv()).await;
        assert!(
            no_msg.is_err(),
            "No message should be sent for already-enabled domain"
        );

        // A different domain is enabled independently.
        managed.ensure_domain("Runtime").await.unwrap();
        let msg2 = tokio::time::timeout(Duration::from_millis(200), record_rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(msg2["method"], "Runtime.enable");

        // Both domains are tracked, and nothing else is.
        let domains = managed.enabled_domains();
        assert!(domains.contains("Page"));
        assert!(domains.contains("Runtime"));
        assert_eq!(domains.len(), 2);
    }
}