1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde_json::Value;
6use tokio::sync::{Mutex, oneshot};
7
/// Upper bound on how long a dispatched request waits for the extension's reply.
const DISPATCH_TIMEOUT: Duration = Duration::from_millis(30_000);
9
10pub struct BridgeDispatch {
15 pending: Arc<Mutex<HashMap<String, oneshot::Sender<DispatchResult>>>>,
16 writer: Arc<Mutex<tokio::io::Stdout>>,
17}
18
19#[derive(Debug)]
20pub struct DispatchResult {
21 pub data: Option<Value>,
22 pub error: Option<String>,
23}
24
25impl BridgeDispatch {
26 #[must_use]
27 pub fn new(writer: tokio::io::Stdout) -> Self {
28 Self {
29 pending: Arc::new(Mutex::new(HashMap::new())),
30 writer: Arc::new(Mutex::new(writer)),
31 }
32 }
33
34 pub async fn dispatch(
41 &self,
42 tab_id: Option<u32>,
43 method: &str,
44 args: Value,
45 ) -> Result<Value, String> {
46 let id = uuid::Uuid::new_v4().to_string();
47
48 let (tx, rx) = oneshot::channel();
49 {
50 let mut pending = self.pending.lock().await;
51 pending.insert(id.clone(), tx);
52 }
53
54 let msg = serde_json::json!({
55 "id": id,
56 "type": "execute",
57 "tab_id": tab_id,
58 "method": method,
59 "args": args,
60 });
61
62 {
63 let mut writer = self.writer.lock().await;
64 crate::native_messaging::write_message(&mut *writer, &msg)
65 .await
66 .map_err(|e| format!("native messaging write failed: {e}"))?;
67 }
68
69 match tokio::time::timeout(DISPATCH_TIMEOUT, rx).await {
70 Ok(Ok(result)) => {
71 if let Some(err) = result.error {
72 Err(err)
73 } else {
74 Ok(result.data.unwrap_or(Value::Null))
75 }
76 }
77 Ok(Err(_)) => {
78 self.cleanup_pending(&id).await;
79 Err("extension disconnected while waiting for response".to_string())
80 }
81 Err(_) => {
82 self.cleanup_pending(&id).await;
83 Err(format!(
84 "timeout ({DISPATCH_TIMEOUT:?}) waiting for {method}"
85 ))
86 }
87 }
88 }
89
90 #[allow(dead_code)]
96 pub async fn dispatch_cdp(
97 &self,
98 tab_id: u32,
99 domain_method: &str,
100 params: Option<Value>,
101 ) -> Result<Value, String> {
102 let id = uuid::Uuid::new_v4().to_string();
103
104 let (tx, rx) = oneshot::channel();
105 {
106 let mut pending = self.pending.lock().await;
107 pending.insert(id.clone(), tx);
108 }
109
110 let msg = serde_json::json!({
111 "id": id,
112 "type": "cdp",
113 "tab_id": tab_id,
114 "domain_method": domain_method,
115 "params": params.unwrap_or(Value::Null),
116 });
117
118 {
119 let mut writer = self.writer.lock().await;
120 crate::native_messaging::write_message(&mut *writer, &msg)
121 .await
122 .map_err(|e| format!("native messaging write failed: {e}"))?;
123 }
124
125 match tokio::time::timeout(DISPATCH_TIMEOUT, rx).await {
126 Ok(Ok(result)) => {
127 if let Some(err) = result.error {
128 Err(err)
129 } else {
130 Ok(result.data.unwrap_or(Value::Null))
131 }
132 }
133 Ok(Err(_)) => {
134 self.cleanup_pending(&id).await;
135 Err("extension disconnected during CDP call".to_string())
136 }
137 Err(_) => {
138 self.cleanup_pending(&id).await;
139 Err(format!(
140 "timeout ({DISPATCH_TIMEOUT:?}) waiting for CDP {domain_method}"
141 ))
142 }
143 }
144 }
145
146 pub async fn on_response(&self, id: &str, data: Option<Value>, error: Option<String>) {
148 let mut pending = self.pending.lock().await;
149 if let Some(tx) = pending.remove(id) {
150 let _ = tx.send(DispatchResult { data, error });
151 }
152 }
153
154 pub async fn cancel_all(&self) {
156 let mut pending = self.pending.lock().await;
157 for (_, tx) in pending.drain() {
158 let _ = tx.send(DispatchResult {
159 data: None,
160 error: Some("extension disconnected".to_string()),
161 });
162 }
163 }
164
165 #[must_use]
166 #[allow(dead_code)]
167 pub async fn pending_count(&self) -> usize {
168 self.pending.lock().await.len()
169 }
170
171 async fn cleanup_pending(&self, id: &str) {
172 let mut pending = self.pending.lock().await;
173 pending.remove(id);
174 }
175
176 pub async fn pending_ids(&self) -> Vec<String> {
178 self.pending.lock().await.keys().cloned().collect()
179 }
180
181 pub async fn register_test_pending(&self, id: &str) -> oneshot::Receiver<DispatchResult> {
183 let (tx, rx) = oneshot::channel();
184 self.pending.lock().await.insert(id.to_string(), tx);
185 rx
186 }
187}
188
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's `Arc`, `oneshot`, and `Value`.
    use super::*;

    /// Builds a dispatcher backed by the process stdout. None of these tests
    /// write to it, because none of them call `dispatch`.
    fn dispatcher() -> BridgeDispatch {
        BridgeDispatch::new(tokio::io::stdout())
    }

    #[tokio::test]
    async fn on_response_resolves_pending() {
        let dispatch = dispatcher();
        let rx = dispatch.register_test_pending("test-123").await;

        dispatch
            .on_response("test-123", Some(serde_json::json!({"ok": true})), None)
            .await;

        let result = rx.await.unwrap();
        assert!(result.error.is_none());
        assert_eq!(result.data.unwrap(), serde_json::json!({"ok": true}));
    }

    #[tokio::test]
    async fn on_response_with_error() {
        let dispatch = dispatcher();
        let rx = dispatch.register_test_pending("test-456").await;

        dispatch
            .on_response("test-456", None, Some("bridge timeout".to_string()))
            .await;

        let result = rx.await.unwrap();
        assert_eq!(result.error.unwrap(), "bridge timeout");
    }

    #[tokio::test]
    async fn cancel_all_resolves_pending() {
        let dispatch = dispatcher();
        let rx = dispatch.register_test_pending("test-789").await;

        dispatch.cancel_all().await;

        let result = rx.await.unwrap();
        assert!(result.error.is_some());
        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn unknown_response_id_ignored() {
        let dispatch = dispatcher();

        dispatch
            .on_response("nonexistent", Some(serde_json::json!({})), None)
            .await;

        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn pending_count_tracks_insertions() {
        let dispatch = dispatcher();
        assert_eq!(dispatch.pending_count().await, 0);

        let _rx1 = dispatch.register_test_pending("a").await;
        let _rx2 = dispatch.register_test_pending("b").await;
        assert_eq!(dispatch.pending_count().await, 2);

        dispatch
            .on_response("a", Some(serde_json::json!({"ok": true})), None)
            .await;
        assert_eq!(dispatch.pending_count().await, 1);
    }

    #[tokio::test]
    async fn pending_ids_lists_unresolved_requests() {
        let dispatch = dispatcher();
        let _rx1 = dispatch.register_test_pending("a").await;
        let _rx2 = dispatch.register_test_pending("b").await;

        // HashMap iteration order is arbitrary, so sort before comparing.
        let mut ids = dispatch.pending_ids().await;
        ids.sort_unstable();
        assert_eq!(ids, ["a", "b"]);

        dispatch.on_response("a", None, None).await;
        assert_eq!(dispatch.pending_ids().await, ["b"]);
    }

    #[tokio::test]
    async fn on_response_with_null_data_and_no_error() {
        let dispatch = dispatcher();
        let rx = dispatch.register_test_pending("test-null").await;

        dispatch.on_response("test-null", None, None).await;

        let result = rx.await.unwrap();
        assert!(result.data.is_none());
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn cancel_all_with_multiple_pending() {
        let dispatch = dispatcher();
        let mut receivers = Vec::with_capacity(3);
        for id in ["a", "b", "c"] {
            receivers.push(dispatch.register_test_pending(id).await);
        }

        dispatch.cancel_all().await;
        assert_eq!(dispatch.pending_count().await, 0);

        for rx in receivers {
            let error = rx
                .await
                .unwrap()
                .error
                .expect("cancel_all must resolve every request with an error");
            assert!(error.contains("disconnected"));
        }
    }

    #[tokio::test]
    async fn cancel_all_on_empty_is_noop() {
        let dispatch = dispatcher();
        dispatch.cancel_all().await;
        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn concurrent_100_pending_insertions_and_resolutions() {
        let dispatch = Arc::new(dispatcher());

        let mut receivers = Vec::with_capacity(100);
        for i in 0..100 {
            let rx = dispatch.register_test_pending(&format!("stress-{i}")).await;
            receivers.push((i, rx));
        }
        assert_eq!(dispatch.pending_count().await, 100);

        // Spawn every resolver before awaiting any, so they run concurrently.
        let mut handles = Vec::with_capacity(100);
        for i in 0..100 {
            let d = Arc::clone(&dispatch);
            handles.push(tokio::spawn(async move {
                d.on_response(
                    &format!("stress-{i}"),
                    Some(serde_json::json!({"idx": i})),
                    None,
                )
                .await;
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(dispatch.pending_count().await, 0);
        for (i, rx) in receivers {
            let result = rx.await.unwrap();
            assert_eq!(result.data.unwrap()["idx"], i);
        }
    }

    #[tokio::test]
    async fn resolve_after_cancel_all_is_noop() {
        let dispatch = dispatcher();
        let _rx = dispatch.register_test_pending("doomed").await;

        dispatch.cancel_all().await;

        // The drained entry must stay gone after a late reply.
        dispatch
            .on_response("doomed", Some(serde_json::json!({"late": true})), None)
            .await;
        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn duplicate_id_response_only_resolves_once() {
        let dispatch = dispatcher();
        let rx = dispatch.register_test_pending("dup").await;

        dispatch
            .on_response("dup", Some(serde_json::json!({"first": true})), None)
            .await;
        // The first reply removed the entry, so this one is ignored.
        dispatch
            .on_response("dup", Some(serde_json::json!({"second": true})), None)
            .await;

        let result = rx.await.unwrap();
        assert_eq!(result.data.unwrap()["first"], true);
    }

    #[tokio::test]
    async fn cancel_all_then_insert_new() {
        let dispatch = dispatcher();

        let rx1 = dispatch.register_test_pending("before").await;
        dispatch.cancel_all().await;
        let result1 = rx1.await.unwrap();
        assert!(result1.error.is_some());

        // The dispatcher remains usable after a cancel_all.
        let rx2 = dispatch.register_test_pending("after").await;
        assert_eq!(dispatch.pending_count().await, 1);

        dispatch
            .on_response("after", Some(serde_json::json!({"ok": true})), None)
            .await;
        let result2 = rx2.await.unwrap();
        assert_eq!(result2.data.unwrap()["ok"], true);
    }

    #[tokio::test]
    async fn concurrent_cancel_and_resolve_race() {
        let dispatch = Arc::new(dispatcher());

        for i in 0..50 {
            // Each receiver is dropped at the end of the iteration, so any
            // send to it fails silently.
            let _rx = dispatch.register_test_pending(&format!("race-{i}")).await;
        }

        let d1 = Arc::clone(&dispatch);
        let cancel_task = tokio::spawn(async move {
            d1.cancel_all().await;
        });

        let d2 = Arc::clone(&dispatch);
        let resolve_task = tokio::spawn(async move {
            for i in 0..50 {
                d2.on_response(&format!("race-{i}"), Some(serde_json::json!({})), None)
                    .await;
            }
        });

        cancel_task.await.unwrap();
        resolve_task.await.unwrap();

        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn on_response_with_both_data_and_error() {
        let dispatch = dispatcher();
        let rx = dispatch.register_test_pending("both").await;

        dispatch
            .on_response(
                "both",
                Some(serde_json::json!({"partial": true})),
                Some("also an error".to_string()),
            )
            .await;

        let result = rx.await.unwrap();
        assert!(result.data.is_some());
        assert!(result.error.is_some());
    }
}