Skip to main content

chrome_cli/
connection.rs

1use std::collections::HashSet;
2use std::time::Duration;
3
4use crate::cdp::{CdpError, CdpEvent, CdpSession};
5use crate::chrome::{TargetInfo, discover_chrome, query_targets, query_version};
6use crate::error::AppError;
7use crate::session;
8
/// Chrome `DevTools` Protocol port used when nothing more specific is known:
/// the auto-discovery fallback and the last resort for `--ws-url` without a port.
pub const DEFAULT_CDP_PORT: u16 = 9_222;
11
/// Resolved connection info ready for use by a command.
///
/// Plain data: cheap to clone and compare, so callers can keep copies or
/// check whether two resolutions point at the same Chrome endpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedConnection {
    /// WebSocket debugger URL used to open the CDP connection.
    pub ws_url: String,
    /// Host of the Chrome instance, used for HTTP endpoint queries.
    pub host: String,
    /// HTTP debugging port of the Chrome instance.
    pub port: u16,
}
19
20/// Health-check a connection by querying `/json/version`.
21///
22/// Returns `Ok(())` if Chrome responds, or `Err(AppError::stale_session())` if not.
23///
24/// # Errors
25///
26/// Returns `AppError` with `ConnectionError` exit code if Chrome is unreachable.
27pub async fn health_check(host: &str, port: u16) -> Result<(), AppError> {
28    query_version(host, port)
29        .await
30        .map(|_| ())
31        .map_err(|_| AppError::stale_session())
32}
33
34/// Resolve a Chrome connection using the priority chain:
35///
36/// 1. Explicit `--ws-url`
37/// 2. Explicit `--port` (user provided, not the default)
38/// 3. Session file (with health check)
39/// 4. Auto-discover (default host:port 9222)
40/// 5. Error with suggestion
41///
42/// # Errors
43///
44/// Returns `AppError` if no Chrome connection can be resolved.
45pub async fn resolve_connection(
46    host: &str,
47    port: Option<u16>,
48    ws_url: Option<&str>,
49) -> Result<ResolvedConnection, AppError> {
50    let default_port = DEFAULT_CDP_PORT;
51
52    // 1. Explicit --ws-url
53    if let Some(ws_url) = ws_url {
54        let resolved_port =
55            extract_port_from_ws_url(ws_url).unwrap_or(port.unwrap_or(default_port));
56        return Ok(ResolvedConnection {
57            ws_url: ws_url.to_string(),
58            host: host.to_string(),
59            port: resolved_port,
60        });
61    }
62
63    // 2. Explicit --port (user provided) — try only this port, no DevToolsActivePort fallback
64    if let Some(explicit_port) = port {
65        match query_version(host, explicit_port).await {
66            Ok(version) => {
67                return Ok(ResolvedConnection {
68                    ws_url: version.ws_debugger_url,
69                    host: host.to_string(),
70                    port: explicit_port,
71                });
72            }
73            Err(_) => return Err(AppError::no_chrome_found()),
74        }
75    }
76
77    // 3. Session file
78    if let Some(session_data) = session::read_session()? {
79        health_check(host, session_data.port).await?;
80        return Ok(ResolvedConnection {
81            ws_url: session_data.ws_url,
82            host: host.to_string(),
83            port: session_data.port,
84        });
85    }
86
87    // 4. Auto-discover on default port
88    match discover_chrome(host, default_port).await {
89        Ok((ws_url, p)) => Ok(ResolvedConnection {
90            ws_url,
91            host: host.to_string(),
92            port: p,
93        }),
94        Err(_) => Err(AppError::no_chrome_found()),
95    }
96}
97
/// Extract the explicit port from a WebSocket URL like `ws://host:port/path`.
///
/// Only `ws://` and `wss://` URLs are recognised. Returns `None` in three cases:
/// the scheme is different, the authority carries no explicit `:port`, or the
/// port is not a valid `u16`. The authority ends at the first `/`, `?` or `#`,
/// so paths, query strings and fragments never affect the result.
#[must_use]
pub fn extract_port_from_ws_url(url: &str) -> Option<u16> {
    let without_scheme = url
        .strip_prefix("ws://")
        .or_else(|| url.strip_prefix("wss://"))?;
    let authority = without_scheme.split(['/', '?', '#']).next()?;
    // Require an actual `:` separator. Without one there is no port, and a
    // purely numeric host (e.g. `ws://1234/`) must not be mistaken for a port.
    // IPv6 hosts without a port (`[::1]`) yield a `1]` tail that fails to parse.
    let (_, port_str) = authority.rsplit_once(':')?;
    port_str.parse().ok()
}
108
109/// Select a target from a list based on the `--tab` option.
110///
111/// - `None` → first target with `target_type == "page"`
112/// - `Some(value)` → try as numeric index, then as target ID
113///
114/// This is a pure function for testability.
115///
116/// # Errors
117///
118/// Returns `AppError::no_page_targets()` if no page-type target exists,
119/// or `AppError::target_not_found()` if the specified tab cannot be matched.
120pub fn select_target<'a>(
121    targets: &'a [TargetInfo],
122    tab: Option<&str>,
123) -> Result<&'a TargetInfo, AppError> {
124    match tab {
125        None => targets
126            .iter()
127            .find(|t| t.target_type == "page")
128            .ok_or_else(AppError::no_page_targets),
129        Some(value) => {
130            // Try as numeric index first
131            if let Ok(index) = value.parse::<usize>() {
132                return targets
133                    .get(index)
134                    .ok_or_else(|| AppError::target_not_found(value));
135            }
136            // Try as target ID
137            targets
138                .iter()
139                .find(|t| t.id == value)
140                .ok_or_else(|| AppError::target_not_found(value))
141        }
142    }
143}
144
145/// Resolve the target tab from the `--tab` option by querying Chrome for targets.
146///
147/// # Errors
148///
149/// Returns `AppError` if targets cannot be queried or the specified tab is not found.
150pub async fn resolve_target(
151    host: &str,
152    port: u16,
153    tab: Option<&str>,
154) -> Result<TargetInfo, AppError> {
155    let targets = query_targets(host, port).await?;
156    select_target(&targets, tab).cloned()
157}
158
/// Upper bound, in milliseconds, on how long auto-dismiss setup waits for `Page.enable`.
///
/// While a JavaScript dialog is open, `Page.enable` blocks. Chrome still
/// re-emits `Page.javascriptDialogOpening` to a newly attached session when it
/// receives that command. Keeping the wait short lets auto-dismiss continue
/// instead of hanging on the already-open dialog.
const PAGE_ENABLE_TIMEOUT_MS: u64 = 300;
165
166/// A CDP session wrapper that tracks which domains have been enabled,
167/// ensuring each domain is only enabled once (lazy domain enabling).
168///
169/// This fulfills AC13: "only the required domains are enabled" per command.
170#[derive(Debug)]
171pub struct ManagedSession {
172    session: CdpSession,
173    enabled_domains: HashSet<String>,
174}
175
176impl ManagedSession {
177    /// Wrap a [`CdpSession`] with domain tracking.
178    #[must_use]
179    pub fn new(session: CdpSession) -> Self {
180        Self {
181            session,
182            enabled_domains: HashSet::new(),
183        }
184    }
185
186    /// Ensure a CDP domain is enabled. Sends `{domain}.enable` only if
187    /// the domain has not already been enabled in this session.
188    ///
189    /// # Errors
190    ///
191    /// Returns `CdpError` if the enable command fails.
192    pub async fn ensure_domain(&mut self, domain: &str) -> Result<(), CdpError> {
193        if self.enabled_domains.contains(domain) {
194            return Ok(());
195        }
196        let method = format!("{domain}.enable");
197        self.session.send_command(&method, None).await?;
198        self.enabled_domains.insert(domain.to_string());
199        Ok(())
200    }
201
202    /// Send a command within this session.
203    ///
204    /// # Errors
205    ///
206    /// Returns `CdpError` if the command fails.
207    pub async fn send_command(
208        &self,
209        method: &str,
210        params: Option<serde_json::Value>,
211    ) -> Result<serde_json::Value, CdpError> {
212        self.session.send_command(method, params).await
213    }
214
215    /// Get the underlying session ID.
216    #[must_use]
217    pub fn session_id(&self) -> &str {
218        self.session.session_id()
219    }
220
221    /// Subscribe to CDP events matching a method name within this session.
222    ///
223    /// # Errors
224    ///
225    /// Returns `CdpError` if the transport task has exited.
226    pub async fn subscribe(
227        &self,
228        method: &str,
229    ) -> Result<tokio::sync::mpsc::Receiver<CdpEvent>, CdpError> {
230        self.session.subscribe(method).await
231    }
232
233    /// Returns the set of currently enabled domains.
234    #[must_use]
235    pub fn enabled_domains(&self) -> &HashSet<String> {
236        &self.enabled_domains
237    }
238
239    /// Spawn a background task that automatically dismisses JavaScript dialogs.
240    ///
241    /// Subscribes to dialog events and sends `Page.enable` with a short
242    /// timeout. If a dialog is already open, `Page.enable` will block, but
243    /// Chrome re-emits the `Page.javascriptDialogOpening` event before
244    /// blocking, so the pre-existing dialog is captured and dismissed.
245    /// Returns a `JoinHandle` whose `abort()` method can be called to stop
246    /// the task (or it stops naturally when the session is dropped).
247    ///
248    /// # Errors
249    ///
250    /// Returns `CdpError` if the event subscription fails.
251    pub async fn spawn_auto_dismiss(&mut self) -> Result<tokio::task::JoinHandle<()>, CdpError> {
252        // Subscribe BEFORE Page.enable so we capture re-emitted dialog events.
253        let mut dialog_rx = self
254            .session
255            .subscribe("Page.javascriptDialogOpening")
256            .await?;
257
258        // Send Page.enable with a timeout. If a dialog is already open,
259        // Page.enable blocks but the dialog event is delivered before the
260        // block. We accept the timeout and proceed.
261        let page_enable = self.session.send_command("Page.enable", None);
262        let enable_result =
263            tokio::time::timeout(Duration::from_millis(PAGE_ENABLE_TIMEOUT_MS), page_enable).await;
264        if matches!(enable_result, Ok(Ok(_))) {
265            self.enabled_domains.insert("Page".to_string());
266        }
267
268        let session = self.session.clone();
269
270        Ok(tokio::spawn(async move {
271            while let Some(_event) = dialog_rx.recv().await {
272                let params = serde_json::json!({ "accept": false });
273                // Best-effort dismiss; ignore errors (session may have closed).
274                let _ = session
275                    .send_command("Page.handleJavaScriptDialog", Some(params))
276                    .await;
277            }
278        }))
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `TargetInfo` fixture with the given ID and target type.
    ///
    /// Title, URL and `ws_debugger_url` are all derived from `id`, so each
    /// fixture is distinguishable in assertion failures.
    fn make_target(id: &str, target_type: &str) -> TargetInfo {
        TargetInfo {
            id: id.to_string(),
            target_type: target_type.to_string(),
            title: format!("Title {id}"),
            url: format!("https://example.com/{id}"),
            ws_debugger_url: Some(format!("ws://127.0.0.1:9222/devtools/page/{id}")),
        }
    }

    // --- extract_port_from_ws_url ---

    #[test]
    fn extract_port_ws() {
        assert_eq!(
            extract_port_from_ws_url("ws://127.0.0.1:9222/devtools/browser/abc"),
            Some(9222)
        );
    }

    #[test]
    fn extract_port_wss() {
        assert_eq!(
            extract_port_from_ws_url("wss://localhost:9333/devtools/browser/abc"),
            Some(9333)
        );
    }

    // A non-WebSocket scheme (here `http://`) is rejected even if a port is present.
    #[test]
    fn extract_port_no_scheme() {
        assert_eq!(extract_port_from_ws_url("http://localhost:9222"), None);
    }

    // --- select_target ---

    // Without `--tab`, non-page targets are skipped and the first page wins.
    #[test]
    fn select_target_default_picks_first_page() {
        let targets = vec![
            make_target("bg1", "background_page"),
            make_target("page1", "page"),
            make_target("page2", "page"),
        ];
        let result = select_target(&targets, None).unwrap();
        assert_eq!(result.id, "page1");
    }

    #[test]
    fn select_target_default_skips_non_page() {
        let targets = vec![
            make_target("sw1", "service_worker"),
            make_target("p1", "page"),
        ];
        let result = select_target(&targets, None).unwrap();
        assert_eq!(result.id, "p1");
    }

    // A numeric `--tab` is a zero-based index into the target list.
    #[test]
    fn select_target_by_index() {
        let targets = vec![
            make_target("a", "page"),
            make_target("b", "page"),
            make_target("c", "page"),
        ];
        let result = select_target(&targets, Some("1")).unwrap();
        assert_eq!(result.id, "b");
    }

    // A non-numeric `--tab` is matched against target IDs.
    #[test]
    fn select_target_by_id() {
        let targets = vec![make_target("ABCDEF", "page"), make_target("GHIJKL", "page")];
        let result = select_target(&targets, Some("GHIJKL")).unwrap();
        assert_eq!(result.id, "GHIJKL");
    }

    // The error message is checked to confirm the `target_not_found` variant.
    #[test]
    fn select_target_invalid_tab() {
        let targets = vec![make_target("a", "page")];
        let result = select_target(&targets, Some("nonexistent"));
        assert!(result.is_err());
        assert!(result.unwrap_err().message.contains("not found"));
    }

    #[test]
    fn select_target_index_out_of_bounds() {
        let targets = vec![make_target("a", "page")];
        let result = select_target(&targets, Some("5"));
        assert!(result.is_err());
    }

    // The error message is checked to confirm the `no_page_targets` variant.
    #[test]
    fn select_target_empty_list_no_tab() {
        let targets: Vec<TargetInfo> = vec![];
        let result = select_target(&targets, None);
        assert!(result.is_err());
        assert!(result.unwrap_err().message.contains("No page targets"));
    }

    #[test]
    fn select_target_no_page_targets() {
        let targets = vec![
            make_target("sw1", "service_worker"),
            make_target("bg1", "background_page"),
        ];
        let result = select_target(&targets, None);
        assert!(result.is_err());
    }

    // --- ManagedSession ---

    /// End-to-end check of lazy domain enabling against an in-process mock
    /// CDP server. The first `ensure_domain` for a domain sends
    /// `{domain}.enable`, a repeat sends nothing, and a new domain sends its
    /// own enable.
    #[tokio::test]
    async fn managed_session_enables_domain_once() {
        use crate::cdp::{CdpClient, CdpConfig, ReconnectConfig};
        use futures_util::{SinkExt, StreamExt};
        use std::time::Duration;
        use tokio::net::TcpListener;
        use tokio::sync::mpsc;
        use tokio_tungstenite::tungstenite::Message;

        // Start mock CDP server that echoes responses and records messages
        // Behavior of the mock:
        // - accepts one WebSocket connection on an ephemeral port;
        // - forwards every parsed command to `record_tx` for inspection;
        // - answers `Target.attachToTarget` with `sessionId` = the requested `targetId`;
        // - answers every other command with an empty `result`, echoing its `sessionId`.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (record_tx, mut record_rx) = mpsc::channel::<serde_json::Value>(32);

        tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await {
                let ws = tokio_tungstenite::accept_async(stream).await.unwrap();
                let (mut sink, mut source) = ws.split();
                while let Some(Ok(Message::Text(text))) = source.next().await {
                    let cmd: serde_json::Value = serde_json::from_str(&text).unwrap();
                    let _ = record_tx.send(cmd.clone()).await;

                    if cmd["method"] == "Target.attachToTarget" {
                        let tid = cmd["params"]["targetId"].as_str().unwrap_or("test");
                        let resp = serde_json::json!({
                            "id": cmd["id"],
                            "result": {"sessionId": tid}
                        });
                        let _ = sink.send(Message::Text(resp.to_string().into())).await;
                    } else {
                        let mut resp = serde_json::json!({"id": cmd["id"], "result": {}});
                        if let Some(sid) = cmd.get("sessionId") {
                            resp["sessionId"] = sid.clone();
                        }
                        let _ = sink.send(Message::Text(resp.to_string().into())).await;
                    }
                }
            }
        });

        // Connect and create session
        // Reconnects are disabled (`max_retries: 0`) so a server failure
        // surfaces immediately instead of being retried.
        let url = format!("ws://{addr}");
        let config = CdpConfig {
            connect_timeout: Duration::from_secs(5),
            command_timeout: Duration::from_secs(5),
            channel_capacity: 256,
            reconnect: ReconnectConfig {
                max_retries: 0,
                ..ReconnectConfig::default()
            },
        };
        let client = CdpClient::connect(&url, config).await.unwrap();
        let session = client.create_session("test-target").await.unwrap();
        // Drain the attachToTarget message
        // (so the next recorded message is the first one ensure_domain sends).
        let _ = tokio::time::timeout(Duration::from_millis(200), record_rx.recv()).await;

        let mut managed = ManagedSession::new(session);
        assert!(managed.enabled_domains().is_empty());

        // First enable: should send Page.enable
        managed.ensure_domain("Page").await.unwrap();
        let msg = tokio::time::timeout(Duration::from_millis(200), record_rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(msg["method"], "Page.enable");
        assert!(managed.enabled_domains().contains("Page"));

        // Second enable of same domain: should NOT send anything
        // (a timeout on `recv` is the expected outcome here).
        managed.ensure_domain("Page").await.unwrap();
        let no_msg = tokio::time::timeout(Duration::from_millis(100), record_rx.recv()).await;
        assert!(
            no_msg.is_err(),
            "No message should be sent for already-enabled domain"
        );

        // Enable a different domain
        managed.ensure_domain("Runtime").await.unwrap();
        let msg2 = tokio::time::timeout(Duration::from_millis(200), record_rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(msg2["method"], "Runtime.enable");

        // Verify final state
        let domains = managed.enabled_domains();
        assert!(domains.contains("Page"));
        assert!(domains.contains("Runtime"));
        assert_eq!(domains.len(), 2);
    }
}