Skip to main content

victauri_browser/
bridge_dispatch.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde_json::Value;
6use tokio::sync::{Mutex, oneshot};
7
/// Upper bound on how long a dispatched command waits for the extension's reply.
const DISPATCH_TIMEOUT: Duration = Duration::from_millis(30 * 1_000);
9
10/// Manages in-flight commands sent to the Chrome extension via native messaging.
11///
12/// Each command gets a UUID, is written to the native messaging stdout, and
13/// a oneshot receiver awaits the response from the extension.
14pub struct BridgeDispatch {
15    pending: Arc<Mutex<HashMap<String, oneshot::Sender<DispatchResult>>>>,
16    writer: Arc<Mutex<tokio::io::Stdout>>,
17}
18
19#[derive(Debug)]
20pub struct DispatchResult {
21    pub data: Option<Value>,
22    pub error: Option<String>,
23}
24
25impl BridgeDispatch {
26    #[must_use]
27    pub fn new(writer: tokio::io::Stdout) -> Self {
28        Self {
29            pending: Arc::new(Mutex::new(HashMap::new())),
30            writer: Arc::new(Mutex::new(writer)),
31        }
32    }
33
34    /// Send a command to the extension and await the response.
35    ///
36    /// # Errors
37    ///
38    /// Returns an error if the write fails, the extension disconnects,
39    /// or the command times out (30s).
40    pub async fn dispatch(
41        &self,
42        tab_id: Option<u32>,
43        method: &str,
44        args: Value,
45    ) -> Result<Value, String> {
46        let id = uuid::Uuid::new_v4().to_string();
47
48        let (tx, rx) = oneshot::channel();
49        {
50            let mut pending = self.pending.lock().await;
51            pending.insert(id.clone(), tx);
52        }
53
54        let msg = serde_json::json!({
55            "id": id,
56            "type": "execute",
57            "tab_id": tab_id,
58            "method": method,
59            "args": args,
60        });
61
62        {
63            let mut writer = self.writer.lock().await;
64            crate::native_messaging::write_message(&mut *writer, &msg)
65                .await
66                .map_err(|e| format!("native messaging write failed: {e}"))?;
67        }
68
69        match tokio::time::timeout(DISPATCH_TIMEOUT, rx).await {
70            Ok(Ok(result)) => {
71                if let Some(err) = result.error {
72                    Err(err)
73                } else {
74                    Ok(result.data.unwrap_or(Value::Null))
75                }
76            }
77            Ok(Err(_)) => {
78                self.cleanup_pending(&id).await;
79                Err("extension disconnected while waiting for response".to_string())
80            }
81            Err(_) => {
82                self.cleanup_pending(&id).await;
83                Err(format!(
84                    "timeout ({DISPATCH_TIMEOUT:?}) waiting for {method}"
85                ))
86            }
87        }
88    }
89
90    /// Send a CDP command to the extension.
91    ///
92    /// # Errors
93    ///
94    /// Returns an error on write failure, disconnect, or timeout.
95    #[allow(dead_code)]
96    pub async fn dispatch_cdp(
97        &self,
98        tab_id: u32,
99        domain_method: &str,
100        params: Option<Value>,
101    ) -> Result<Value, String> {
102        let id = uuid::Uuid::new_v4().to_string();
103
104        let (tx, rx) = oneshot::channel();
105        {
106            let mut pending = self.pending.lock().await;
107            pending.insert(id.clone(), tx);
108        }
109
110        let msg = serde_json::json!({
111            "id": id,
112            "type": "cdp",
113            "tab_id": tab_id,
114            "domain_method": domain_method,
115            "params": params.unwrap_or(Value::Null),
116        });
117
118        {
119            let mut writer = self.writer.lock().await;
120            crate::native_messaging::write_message(&mut *writer, &msg)
121                .await
122                .map_err(|e| format!("native messaging write failed: {e}"))?;
123        }
124
125        match tokio::time::timeout(DISPATCH_TIMEOUT, rx).await {
126            Ok(Ok(result)) => {
127                if let Some(err) = result.error {
128                    Err(err)
129                } else {
130                    Ok(result.data.unwrap_or(Value::Null))
131                }
132            }
133            Ok(Err(_)) => {
134                self.cleanup_pending(&id).await;
135                Err("extension disconnected during CDP call".to_string())
136            }
137            Err(_) => {
138                self.cleanup_pending(&id).await;
139                Err(format!(
140                    "timeout ({DISPATCH_TIMEOUT:?}) waiting for CDP {domain_method}"
141                ))
142            }
143        }
144    }
145
146    /// Called by the native messaging read loop when a response arrives.
147    pub async fn on_response(&self, id: &str, data: Option<Value>, error: Option<String>) {
148        let mut pending = self.pending.lock().await;
149        if let Some(tx) = pending.remove(id) {
150            let _ = tx.send(DispatchResult { data, error });
151        }
152    }
153
154    /// Drop all pending commands (e.g. on disconnect).
155    pub async fn cancel_all(&self) {
156        let mut pending = self.pending.lock().await;
157        for (_, tx) in pending.drain() {
158            let _ = tx.send(DispatchResult {
159                data: None,
160                error: Some("extension disconnected".to_string()),
161            });
162        }
163    }
164
165    #[must_use]
166    #[allow(dead_code)]
167    pub async fn pending_count(&self) -> usize {
168        self.pending.lock().await.len()
169    }
170
171    async fn cleanup_pending(&self, id: &str) {
172        let mut pending = self.pending.lock().await;
173        pending.remove(id);
174    }
175
176    /// Return the IDs of all currently pending commands (for testing).
177    pub async fn pending_ids(&self) -> Vec<String> {
178        self.pending.lock().await.keys().cloned().collect()
179    }
180
181    /// Insert a pending command directly and return the receiver (for testing).
182    pub async fn register_test_pending(&self, id: &str) -> oneshot::Receiver<DispatchResult> {
183        let (tx, rx) = oneshot::channel();
184        self.pending.lock().await.insert(id.to_string(), tx);
185        rx
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Fresh dispatcher for a test. None of these tests write through it.
    fn new_dispatch() -> BridgeDispatch {
        BridgeDispatch::new(tokio::io::stdout())
    }

    #[tokio::test]
    async fn on_response_resolves_pending() {
        let dispatch = new_dispatch();
        let rx = dispatch.register_test_pending("test-123").await;

        dispatch
            .on_response("test-123", Some(serde_json::json!({"ok": true})), None)
            .await;

        let result = rx.await.unwrap();
        assert!(result.error.is_none());
        assert_eq!(result.data.unwrap(), serde_json::json!({"ok": true}));
    }

    #[tokio::test]
    async fn on_response_with_error() {
        let dispatch = new_dispatch();
        let rx = dispatch.register_test_pending("test-456").await;

        dispatch
            .on_response("test-456", None, Some("bridge timeout".to_string()))
            .await;

        let result = rx.await.unwrap();
        assert_eq!(result.error.unwrap(), "bridge timeout");
    }

    #[tokio::test]
    async fn cancel_all_resolves_pending() {
        let dispatch = new_dispatch();
        let rx = dispatch.register_test_pending("test-789").await;

        dispatch.cancel_all().await;

        let result = rx.await.unwrap();
        assert!(result.error.is_some());
        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn unknown_response_id_ignored() {
        let dispatch = new_dispatch();

        dispatch
            .on_response("nonexistent", Some(serde_json::json!({})), None)
            .await;

        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn pending_count_tracks_insertions() {
        let dispatch = new_dispatch();
        assert_eq!(dispatch.pending_count().await, 0);

        let _rx1 = dispatch.register_test_pending("a").await;
        let _rx2 = dispatch.register_test_pending("b").await;
        assert_eq!(dispatch.pending_count().await, 2);

        dispatch
            .on_response("a", Some(serde_json::json!({"ok": true})), None)
            .await;
        assert_eq!(dispatch.pending_count().await, 1);
    }

    #[tokio::test]
    async fn pending_ids_lists_registered_commands() {
        let dispatch = new_dispatch();
        let _rx1 = dispatch.register_test_pending("x").await;
        let _rx2 = dispatch.register_test_pending("y").await;

        let mut ids = dispatch.pending_ids().await;
        ids.sort_unstable();
        assert_eq!(ids, ["x", "y"]);
    }

    #[tokio::test]
    async fn on_response_with_null_data_and_no_error() {
        let dispatch = new_dispatch();
        let rx = dispatch.register_test_pending("test-null").await;

        dispatch.on_response("test-null", None, None).await;

        let result = rx.await.unwrap();
        assert!(result.data.is_none());
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn cancel_all_with_multiple_pending() {
        let dispatch = new_dispatch();
        let rx1 = dispatch.register_test_pending("a").await;
        let rx2 = dispatch.register_test_pending("b").await;
        let rx3 = dispatch.register_test_pending("c").await;

        dispatch.cancel_all().await;
        assert_eq!(dispatch.pending_count().await, 0);

        for rx in [rx1, rx2, rx3] {
            let result = rx.await.unwrap();
            assert!(result.error.is_some());
            assert!(result.error.unwrap().contains("disconnected"));
        }
    }

    #[tokio::test]
    async fn cancel_all_on_empty_is_noop() {
        let dispatch = new_dispatch();
        dispatch.cancel_all().await;
        assert_eq!(dispatch.pending_count().await, 0);
    }

    // --- Adversarial stress tests ---

    #[tokio::test]
    async fn concurrent_100_pending_insertions_and_resolutions() {
        let dispatch = Arc::new(new_dispatch());

        let mut receivers = Vec::with_capacity(100);
        for i in 0..100 {
            let rx = dispatch.register_test_pending(&format!("stress-{i}")).await;
            receivers.push((i, rx));
        }
        assert_eq!(dispatch.pending_count().await, 100);

        let mut handles = Vec::with_capacity(100);
        for i in 0..100 {
            let d = Arc::clone(&dispatch);
            handles.push(tokio::spawn(async move {
                d.on_response(
                    &format!("stress-{i}"),
                    Some(serde_json::json!({"idx": i})),
                    None,
                )
                .await;
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(dispatch.pending_count().await, 0);
        for (i, rx) in receivers {
            let result = rx.await.unwrap();
            assert_eq!(result.data.unwrap()["idx"], i);
        }
    }

    #[tokio::test]
    async fn resolve_after_cancel_all_is_noop() {
        let dispatch = new_dispatch();
        let rx = dispatch.register_test_pending("doomed").await;

        dispatch.cancel_all().await;

        // The key is already gone, so the late response must not be delivered.
        dispatch
            .on_response("doomed", Some(serde_json::json!({"late": true})), None)
            .await;
        assert_eq!(dispatch.pending_count().await, 0);

        let result = rx.await.unwrap();
        assert!(result.data.is_none());
        assert!(result.error.is_some());
    }

    #[tokio::test]
    async fn duplicate_id_response_only_resolves_once() {
        let dispatch = new_dispatch();
        let rx = dispatch.register_test_pending("dup").await;

        dispatch
            .on_response("dup", Some(serde_json::json!({"first": true})), None)
            .await;
        // Second response with same ID should be silently ignored
        dispatch
            .on_response("dup", Some(serde_json::json!({"second": true})), None)
            .await;

        let result = rx.await.unwrap();
        assert_eq!(result.data.unwrap()["first"], true);
    }

    #[tokio::test]
    async fn cancel_all_then_insert_new() {
        let dispatch = new_dispatch();
        let rx1 = dispatch.register_test_pending("before").await;

        dispatch.cancel_all().await;
        let result1 = rx1.await.unwrap();
        assert!(result1.error.is_some());

        // New insertions after cancel should work normally
        let rx2 = dispatch.register_test_pending("after").await;
        assert_eq!(dispatch.pending_count().await, 1);

        dispatch
            .on_response("after", Some(serde_json::json!({"ok": true})), None)
            .await;
        let result2 = rx2.await.unwrap();
        assert_eq!(result2.data.unwrap()["ok"], true);
    }

    #[tokio::test]
    async fn concurrent_cancel_and_resolve_race() {
        let dispatch = Arc::new(new_dispatch());

        for i in 0..50 {
            // Receivers are dropped right away; both cancel and resolve must
            // tolerate sending into a closed channel.
            drop(dispatch.register_test_pending(&format!("race-{i}")).await);
        }

        let d1 = Arc::clone(&dispatch);
        let cancel_task = tokio::spawn(async move {
            d1.cancel_all().await;
        });

        let d2 = Arc::clone(&dispatch);
        let resolve_task = tokio::spawn(async move {
            for i in 0..50 {
                d2.on_response(&format!("race-{i}"), Some(serde_json::json!({})), None)
                    .await;
            }
        });

        cancel_task.await.unwrap();
        resolve_task.await.unwrap();

        // Regardless of ordering, pending should be empty
        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn on_response_with_both_data_and_error() {
        let dispatch = new_dispatch();
        let rx = dispatch.register_test_pending("both").await;

        dispatch
            .on_response(
                "both",
                Some(serde_json::json!({"partial": true})),
                Some("also an error".to_string()),
            )
            .await;

        let result = rx.await.unwrap();
        assert!(result.data.is_some());
        assert!(result.error.is_some());
    }
}