1use std::collections::HashSet;
2use std::time::Duration;
3
4use crate::cdp::{CdpError, CdpEvent, CdpSession};
5use crate::chrome::{TargetInfo, discover_chrome, query_targets, query_version};
6use crate::error::AppError;
7use crate::session;
8
/// Remote-debugging port used to reach the Chrome DevTools Protocol endpoint
/// when the caller does not supply one.
pub const DEFAULT_CDP_PORT: u16 = 9_222;
11
/// Endpoint details resolved for a Chrome DevTools Protocol connection.
///
/// Plain data: cheap to compare, and cloneable so callers can keep a copy
/// (e.g. for persisting alongside a session).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedConnection {
    /// WebSocket debugger URL to connect to.
    pub ws_url: String,
    /// Host the connection was resolved for.
    pub host: String,
    /// Port of the DevTools endpoint.
    pub port: u16,
}
19
20pub async fn health_check(host: &str, port: u16) -> Result<(), AppError> {
28 query_version(host, port)
29 .await
30 .map(|_| ())
31 .map_err(|_| AppError::stale_session())
32}
33
34pub async fn resolve_connection(
46 host: &str,
47 port: Option<u16>,
48 ws_url: Option<&str>,
49) -> Result<ResolvedConnection, AppError> {
50 let default_port = DEFAULT_CDP_PORT;
51
52 if let Some(ws_url) = ws_url {
54 let resolved_port =
55 extract_port_from_ws_url(ws_url).unwrap_or(port.unwrap_or(default_port));
56 return Ok(ResolvedConnection {
57 ws_url: ws_url.to_string(),
58 host: host.to_string(),
59 port: resolved_port,
60 });
61 }
62
63 if let Some(explicit_port) = port {
65 match query_version(host, explicit_port).await {
66 Ok(version) => {
67 return Ok(ResolvedConnection {
68 ws_url: version.ws_debugger_url,
69 host: host.to_string(),
70 port: explicit_port,
71 });
72 }
73 Err(_) => return Err(AppError::no_chrome_found()),
74 }
75 }
76
77 if let Some(session_data) = session::read_session()? {
79 health_check(host, session_data.port).await?;
80 return Ok(ResolvedConnection {
81 ws_url: session_data.ws_url,
82 host: host.to_string(),
83 port: session_data.port,
84 });
85 }
86
87 match discover_chrome(host, default_port).await {
89 Ok((ws_url, p)) => Ok(ResolvedConnection {
90 ws_url,
91 host: host.to_string(),
92 port: p,
93 }),
94 Err(_) => Err(AppError::no_chrome_found()),
95 }
96}
97
/// Extracts the explicit port from a `ws://` or `wss://` URL.
///
/// The authority is the text between the scheme and the first `/`, `?` or `#`.
/// It must end in `:<digits>`, and those digits are returned if they fit in a
/// `u16`.
///
/// Returns `None` when:
/// * the scheme is not `ws`/`wss`;
/// * the authority has no explicit port (a bare host such as `ws://9222/` is a
///   host name, not a port);
/// * the port is empty, contains non-digits, or is out of range.
///
/// Bracketed IPv6 hosts without a port are rejected naturally, because the text
/// after their last `:` contains `]`.
#[must_use]
pub fn extract_port_from_ws_url(url: &str) -> Option<u16> {
    let without_scheme = url
        .strip_prefix("ws://")
        .or_else(|| url.strip_prefix("wss://"))?;
    // The authority ends at the first path, query, or fragment delimiter.
    let authority = without_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()?;
    // Require a `host:port` separator; without one there is no explicit port.
    let (_, port) = authority.rsplit_once(':')?;
    // `u16::from_str` also accepts a leading `+`; a URL port is plain ASCII digits.
    if port.is_empty() || !port.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    port.parse().ok()
}
108
109pub fn select_target<'a>(
121 targets: &'a [TargetInfo],
122 tab: Option<&str>,
123) -> Result<&'a TargetInfo, AppError> {
124 match tab {
125 None => targets
126 .iter()
127 .find(|t| t.target_type == "page")
128 .ok_or_else(AppError::no_page_targets),
129 Some(value) => {
130 if let Ok(index) = value.parse::<usize>() {
132 return targets
133 .get(index)
134 .ok_or_else(|| AppError::target_not_found(value));
135 }
136 targets
138 .iter()
139 .find(|t| t.id == value)
140 .ok_or_else(|| AppError::target_not_found(value))
141 }
142 }
143}
144
145pub async fn resolve_target(
151 host: &str,
152 port: u16,
153 tab: Option<&str>,
154) -> Result<TargetInfo, AppError> {
155 let targets = query_targets(host, port).await?;
156
157 if tab.is_none() {
159 if let Some(active_id) = session::read_session()
160 .ok()
161 .flatten()
162 .and_then(|s| s.active_tab_id)
163 {
164 if let Ok(target) = select_target(&targets, Some(&active_id)) {
165 return Ok(target.clone());
166 }
167 }
169 }
170
171 select_target(&targets, tab).cloned()
172}
173
/// Upper bound, in milliseconds, on waiting for a `Page.enable` reply before
/// carrying on without it.
const PAGE_ENABLE_TIMEOUT_MS: u64 = 300;
180
/// A [`CdpSession`] wrapper that also tracks which CDP domains it has enabled.
#[derive(Debug)]
pub struct ManagedSession {
    // Underlying CDP session that all commands are sent through.
    session: CdpSession,
    // Names of domains (e.g. "Page") this wrapper has recorded as enabled.
    enabled_domains: HashSet<String>,
}
190
191impl ManagedSession {
192 #[must_use]
194 pub fn new(session: CdpSession) -> Self {
195 Self {
196 session,
197 enabled_domains: HashSet::new(),
198 }
199 }
200
201 pub async fn ensure_domain(&mut self, domain: &str) -> Result<(), CdpError> {
208 if self.enabled_domains.contains(domain) {
209 return Ok(());
210 }
211 let method = format!("{domain}.enable");
212 self.session.send_command(&method, None).await?;
213 self.enabled_domains.insert(domain.to_string());
214 Ok(())
215 }
216
217 pub async fn send_command(
223 &self,
224 method: &str,
225 params: Option<serde_json::Value>,
226 ) -> Result<serde_json::Value, CdpError> {
227 self.session.send_command(method, params).await
228 }
229
230 #[must_use]
232 pub fn session_id(&self) -> &str {
233 self.session.session_id()
234 }
235
236 pub async fn subscribe(
242 &self,
243 method: &str,
244 ) -> Result<tokio::sync::mpsc::Receiver<CdpEvent>, CdpError> {
245 self.session.subscribe(method).await
246 }
247
248 #[must_use]
250 pub fn enabled_domains(&self) -> &HashSet<String> {
251 &self.enabled_domains
252 }
253
254 pub async fn install_dialog_interceptors(&self) {
265 let script = r"(function(){
266if(window.__chrome_cli_intercepted)return;
267window.__chrome_cli_intercepted=true;
268var oA=window.alert,oC=window.confirm,oP=window.prompt;
269function s(t,m,d){try{document.cookie='__chrome_cli_dialog='+
270encodeURIComponent(JSON.stringify({type:t,message:String(m||''),
271defaultValue:String(d||''),timestamp:Date.now()}))+
272'; path=/; max-age=300';}catch(e){}}
273window.alert=function(m){s('alert',m);return oA.apply(this,arguments);};
274window.confirm=function(m){s('confirm',m);return oC.apply(this,arguments);};
275window.prompt=function(m,d){s('prompt',m,d);return oP.apply(this,arguments);};
276})();";
277
278 let _ = self
280 .session
281 .send_command(
282 "Runtime.evaluate",
283 Some(serde_json::json!({ "expression": script })),
284 )
285 .await;
286
287 let _ = self
289 .session
290 .send_command(
291 "Page.addScriptToEvaluateOnNewDocument",
292 Some(serde_json::json!({ "source": script })),
293 )
294 .await;
295 }
296
297 pub async fn spawn_auto_dismiss(&mut self) -> Result<tokio::task::JoinHandle<()>, CdpError> {
310 let mut dialog_rx = self
312 .session
313 .subscribe("Page.javascriptDialogOpening")
314 .await?;
315
316 let page_enable = self.session.send_command("Page.enable", None);
320 let enable_result =
321 tokio::time::timeout(Duration::from_millis(PAGE_ENABLE_TIMEOUT_MS), page_enable).await;
322 if matches!(enable_result, Ok(Ok(_))) {
323 self.enabled_domains.insert("Page".to_string());
324 }
325
326 let session = self.session.clone();
327
328 Ok(tokio::spawn(async move {
329 while let Some(_event) = dialog_rx.recv().await {
330 let params = serde_json::json!({ "accept": false });
331 let _ = session
333 .send_command("Page.handleJavaScriptDialog", Some(params))
334 .await;
335 }
336 }))
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TargetInfo` with the given id and type. Title, URL and
    /// WebSocket URL are derived from the id.
    fn make_target(id: &str, target_type: &str) -> TargetInfo {
        TargetInfo {
            id: id.to_string(),
            target_type: target_type.to_string(),
            title: format!("Title {id}"),
            url: format!("https://example.com/{id}"),
            ws_debugger_url: Some(format!("ws://127.0.0.1:9222/devtools/page/{id}")),
        }
    }

    // ---- extract_port_from_ws_url ----

    #[test]
    fn extract_port_ws() {
        assert_eq!(
            extract_port_from_ws_url("ws://127.0.0.1:9222/devtools/browser/abc"),
            Some(9222)
        );
    }

    #[test]
    fn extract_port_wss() {
        assert_eq!(
            extract_port_from_ws_url("wss://localhost:9333/devtools/browser/abc"),
            Some(9333)
        );
    }

    // Non-WebSocket schemes are rejected even when a port is present.
    #[test]
    fn extract_port_no_scheme() {
        assert_eq!(extract_port_from_ws_url("http://localhost:9222"), None);
    }

    // ---- select_target ----

    #[test]
    fn select_target_default_picks_first_page() {
        let targets = vec![
            make_target("bg1", "background_page"),
            make_target("page1", "page"),
            make_target("page2", "page"),
        ];
        let result = select_target(&targets, None).unwrap();
        assert_eq!(result.id, "page1");
    }

    #[test]
    fn select_target_default_skips_non_page() {
        let targets = vec![
            make_target("sw1", "service_worker"),
            make_target("p1", "page"),
        ];
        let result = select_target(&targets, None).unwrap();
        assert_eq!(result.id, "p1");
    }

    #[test]
    fn select_target_by_index() {
        let targets = vec![
            make_target("a", "page"),
            make_target("b", "page"),
            make_target("c", "page"),
        ];
        let result = select_target(&targets, Some("1")).unwrap();
        assert_eq!(result.id, "b");
    }

    #[test]
    fn select_target_by_id() {
        let targets = vec![make_target("ABCDEF", "page"), make_target("GHIJKL", "page")];
        let result = select_target(&targets, Some("GHIJKL")).unwrap();
        assert_eq!(result.id, "GHIJKL");
    }

    // Unknown IDs produce an error whose message mentions "not found".
    #[test]
    fn select_target_invalid_tab() {
        let targets = vec![make_target("a", "page")];
        let result = select_target(&targets, Some("nonexistent"));
        assert!(result.is_err());
        assert!(result.unwrap_err().message.contains("not found"));
    }

    #[test]
    fn select_target_index_out_of_bounds() {
        let targets = vec![make_target("a", "page")];
        let result = select_target(&targets, Some("5"));
        assert!(result.is_err());
    }

    // With no targets at all, default selection reports missing page targets.
    #[test]
    fn select_target_empty_list_no_tab() {
        let targets: Vec<TargetInfo> = vec![];
        let result = select_target(&targets, None);
        assert!(result.is_err());
        assert!(result.unwrap_err().message.contains("No page targets"));
    }

    #[test]
    fn select_target_no_page_targets() {
        let targets = vec![
            make_target("sw1", "service_worker"),
            make_target("bg1", "background_page"),
        ];
        let result = select_target(&targets, None);
        assert!(result.is_err());
    }

    // ---- ManagedSession ----

    /// Checks that `ensure_domain` sends `<Domain>.enable` exactly once per
    /// domain. Uses an in-process WebSocket server that records every command
    /// it receives.
    #[tokio::test]
    async fn managed_session_enables_domain_once() {
        use crate::cdp::{CdpClient, CdpConfig, ReconnectConfig};
        use futures_util::{SinkExt, StreamExt};
        use std::time::Duration;
        use tokio::net::TcpListener;
        use tokio::sync::mpsc;
        use tokio_tungstenite::tungstenite::Message;

        // Mock CDP server bound to an OS-assigned local port.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        // Every command the mock server receives is forwarded here for assertions.
        let (record_tx, mut record_rx) = mpsc::channel::<serde_json::Value>(32);

        tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await {
                let ws = tokio_tungstenite::accept_async(stream).await.unwrap();
                let (mut sink, mut source) = ws.split();
                // Serve text frames until the stream ends, errors, or yields a non-text frame.
                while let Some(Ok(Message::Text(text))) = source.next().await {
                    let cmd: serde_json::Value = serde_json::from_str(&text).unwrap();
                    let _ = record_tx.send(cmd.clone()).await;

                    if cmd["method"] == "Target.attachToTarget" {
                        // Reply with a sessionId equal to the requested targetId.
                        let tid = cmd["params"]["targetId"].as_str().unwrap_or("test");
                        let resp = serde_json::json!({
                            "id": cmd["id"],
                            "result": {"sessionId": tid}
                        });
                        let _ = sink.send(Message::Text(resp.to_string().into())).await;
                    } else {
                        // Acknowledge any other command with an empty result and echo its
                        // sessionId (presumably needed for the client to route the reply).
                        let mut resp = serde_json::json!({"id": cmd["id"], "result": {}});
                        if let Some(sid) = cmd.get("sessionId") {
                            resp["sessionId"] = sid.clone();
                        }
                        let _ = sink.send(Message::Text(resp.to_string().into())).await;
                    }
                }
            }
        });

        let url = format!("ws://{addr}");
        // Short timeouts and max_retries: 0 for reconnects.
        let config = CdpConfig {
            connect_timeout: Duration::from_secs(5),
            command_timeout: Duration::from_secs(5),
            channel_capacity: 256,
            reconnect: ReconnectConfig {
                max_retries: 0,
                ..ReconnectConfig::default()
            },
        };
        let client = CdpClient::connect(&url, config).await.unwrap();
        let session = client.create_session("test-target").await.unwrap();
        // Consume the command recorded during session creation (presumably
        // Target.attachToTarget), so later receives only see ensure_domain traffic.
        let _ = tokio::time::timeout(Duration::from_millis(200), record_rx.recv()).await;

        let mut managed = ManagedSession::new(session);
        assert!(managed.enabled_domains().is_empty());

        // First request for a domain sends the enable command and records the domain.
        managed.ensure_domain("Page").await.unwrap();
        let msg = tokio::time::timeout(Duration::from_millis(200), record_rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(msg["method"], "Page.enable");
        assert!(managed.enabled_domains().contains("Page"));

        // Repeat request for the same domain must not reach the server.
        managed.ensure_domain("Page").await.unwrap();
        let no_msg = tokio::time::timeout(Duration::from_millis(100), record_rx.recv()).await;
        assert!(
            no_msg.is_err(),
            "No message should be sent for already-enabled domain"
        );

        // A different domain is enabled independently.
        managed.ensure_domain("Runtime").await.unwrap();
        let msg2 = tokio::time::timeout(Duration::from_millis(200), record_rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(msg2["method"], "Runtime.enable");

        // Both domains are tracked, and nothing else.
        let domains = managed.enabled_domains();
        assert!(domains.contains("Page"));
        assert!(domains.contains("Runtime"));
        assert_eq!(domains.len(), 2);
    }
}
537}