1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde_json::Value;
6use tokio::sync::{Mutex, oneshot};
7
/// How long `BridgeDispatch::dispatch` waits for the extension's reply before
/// giving up (30 seconds).
const DISPATCH_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Upper bound on requests awaiting a reply at once (1024); further
/// `dispatch` calls are rejected until some complete.
const MAX_PENDING: usize = 1 << 10;
14
/// Correlates outgoing native-messaging requests with the replies that come
/// back for them.
pub struct BridgeDispatch {
    /// Requests still awaiting a reply, keyed by request id. Sending on an
    /// entry's sender wakes whoever holds the matching receiver.
    pending: Arc<Mutex<HashMap<String, oneshot::Sender<DispatchResult>>>>,
    /// Output stream for requests. The lock is held for each whole
    /// `write_message` call, so concurrent requests are written one at a time.
    writer: Arc<Mutex<tokio::io::Stdout>>,
}
23
/// Outcome of one request: either the extension's reply, or the synthesised
/// "extension disconnected" error produced by `cancel_all`.
#[derive(Debug)]
pub struct DispatchResult {
    /// Reply payload, if any. `dispatch` reports `None` as `Value::Null`.
    pub data: Option<Value>,
    /// Error message, if the request failed. `dispatch` treats `Some` here as
    /// a failure even when `data` is also set.
    pub error: Option<String>,
}
29
30impl BridgeDispatch {
31 #[must_use]
32 pub fn new(writer: tokio::io::Stdout) -> Self {
33 Self {
34 pending: Arc::new(Mutex::new(HashMap::new())),
35 writer: Arc::new(Mutex::new(writer)),
36 }
37 }
38
39 pub async fn dispatch(
46 &self,
47 tab_id: Option<u32>,
48 method: &str,
49 args: Value,
50 ) -> Result<Value, String> {
51 let id = uuid::Uuid::new_v4().to_string();
52
53 let (tx, rx) = oneshot::channel();
54 {
55 let mut pending = self.pending.lock().await;
56 if pending.len() >= MAX_PENDING {
57 return Err(format!(
58 "too many in-flight commands ({MAX_PENDING}); extension unresponsive"
59 ));
60 }
61 pending.insert(id.clone(), tx);
62 }
63
64 let msg = serde_json::json!({
65 "id": id,
66 "type": "execute",
67 "tab_id": tab_id,
68 "method": method,
69 "args": args,
70 });
71
72 {
73 let mut writer = self.writer.lock().await;
74 crate::native_messaging::write_message(&mut *writer, &msg)
75 .await
76 .map_err(|e| format!("native messaging write failed: {e}"))?;
77 }
78
79 match tokio::time::timeout(DISPATCH_TIMEOUT, rx).await {
80 Ok(Ok(result)) => {
81 if let Some(err) = result.error {
82 Err(err)
83 } else {
84 Ok(result.data.unwrap_or(Value::Null))
85 }
86 }
87 Ok(Err(_)) => {
88 self.cleanup_pending(&id).await;
89 Err("extension disconnected while waiting for response".to_string())
90 }
91 Err(_) => {
92 self.cleanup_pending(&id).await;
93 Err(format!(
94 "timeout ({DISPATCH_TIMEOUT:?}) waiting for {method}"
95 ))
96 }
97 }
98 }
99
100 pub async fn on_response(&self, id: &str, data: Option<Value>, error: Option<String>) {
102 let mut pending = self.pending.lock().await;
103 if let Some(tx) = pending.remove(id) {
104 let _ = tx.send(DispatchResult { data, error });
105 }
106 }
107
108 pub async fn cancel_all(&self) {
110 let mut pending = self.pending.lock().await;
111 for (_, tx) in pending.drain() {
112 let _ = tx.send(DispatchResult {
113 data: None,
114 error: Some("extension disconnected".to_string()),
115 });
116 }
117 }
118
119 #[must_use]
120 #[allow(dead_code)]
121 pub async fn pending_count(&self) -> usize {
122 self.pending.lock().await.len()
123 }
124
125 async fn cleanup_pending(&self, id: &str) {
126 let mut pending = self.pending.lock().await;
127 pending.remove(id);
128 }
129
130 pub async fn pending_ids(&self) -> Vec<String> {
132 self.pending.lock().await.keys().cloned().collect()
133 }
134
135 pub async fn register_test_pending(&self, id: &str) -> oneshot::Receiver<DispatchResult> {
137 let (tx, rx) = oneshot::channel();
138 self.pending.lock().await.insert(id.to_string(), tx);
139 rx
140 }
141}
142
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Builds a dispatcher over process stdout. Apart from the capacity test
    /// (which is rejected before any write), no test calls `dispatch`, so
    /// nothing is ever written.
    fn new_dispatch() -> BridgeDispatch {
        BridgeDispatch::new(tokio::io::stdout())
    }

    #[tokio::test]
    async fn on_response_resolves_pending() {
        let dispatch = new_dispatch();
        let rx = dispatch.register_test_pending("test-123").await;

        dispatch
            .on_response("test-123", Some(serde_json::json!({"ok": true})), None)
            .await;

        let result = rx.await.unwrap();
        assert!(result.error.is_none());
        assert_eq!(result.data.unwrap(), serde_json::json!({"ok": true}));
    }

    #[tokio::test]
    async fn on_response_with_error() {
        let dispatch = new_dispatch();
        let rx = dispatch.register_test_pending("test-456").await;

        dispatch
            .on_response("test-456", None, Some("bridge timeout".to_string()))
            .await;

        let result = rx.await.unwrap();
        assert_eq!(result.error.unwrap(), "bridge timeout");
    }

    #[tokio::test]
    async fn cancel_all_resolves_pending() {
        let dispatch = new_dispatch();
        let rx = dispatch.register_test_pending("test-789").await;

        dispatch.cancel_all().await;

        let result = rx.await.unwrap();
        assert!(result.error.is_some());
        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn unknown_response_id_ignored() {
        let dispatch = new_dispatch();

        dispatch
            .on_response("nonexistent", Some(serde_json::json!({})), None)
            .await;

        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn pending_count_tracks_insertions() {
        let dispatch = new_dispatch();
        assert_eq!(dispatch.pending_count().await, 0);

        // Keep the receivers alive for the duration of the test.
        let _rx1 = dispatch.register_test_pending("a").await;
        let _rx2 = dispatch.register_test_pending("b").await;
        assert_eq!(dispatch.pending_count().await, 2);

        dispatch
            .on_response("a", Some(serde_json::json!({"ok": true})), None)
            .await;
        assert_eq!(dispatch.pending_count().await, 1);
    }

    #[tokio::test]
    async fn on_response_with_null_data_and_no_error() {
        let dispatch = new_dispatch();
        let rx = dispatch.register_test_pending("test-null").await;

        dispatch.on_response("test-null", None, None).await;

        let result = rx.await.unwrap();
        assert!(result.data.is_none());
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn cancel_all_with_multiple_pending() {
        let dispatch = new_dispatch();
        let mut receivers = Vec::with_capacity(3);
        for id in ["a", "b", "c"] {
            receivers.push(dispatch.register_test_pending(id).await);
        }

        dispatch.cancel_all().await;
        assert_eq!(dispatch.pending_count().await, 0);

        for rx in receivers {
            let result = rx.await.unwrap();
            assert!(result.error.is_some());
            assert!(result.error.unwrap().contains("disconnected"));
        }
    }

    #[tokio::test]
    async fn cancel_all_on_empty_is_noop() {
        let dispatch = new_dispatch();
        dispatch.cancel_all().await;
        assert_eq!(dispatch.pending_count().await, 0);
    }

    /// 100 entries resolved from 100 concurrently spawned tasks; each
    /// receiver must get its own payload.
    #[tokio::test]
    async fn concurrent_100_pending_insertions_and_resolutions() {
        let dispatch = Arc::new(new_dispatch());

        let mut receivers = Vec::with_capacity(100);
        for i in 0..100 {
            let rx = dispatch.register_test_pending(&format!("stress-{i}")).await;
            receivers.push((i, rx));
        }
        assert_eq!(dispatch.pending_count().await, 100);

        let handles: Vec<_> = (0..100)
            .map(|i| {
                let d = Arc::clone(&dispatch);
                tokio::spawn(async move {
                    d.on_response(
                        &format!("stress-{i}"),
                        Some(serde_json::json!({"idx": i})),
                        None,
                    )
                    .await;
                })
            })
            .collect();
        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(dispatch.pending_count().await, 0);
        for (i, rx) in receivers {
            let result = rx.await.unwrap();
            assert_eq!(result.data.unwrap()["idx"], i);
        }
    }

    /// A reply that arrives after `cancel_all` has nothing to resolve.
    #[tokio::test]
    async fn resolve_after_cancel_all_is_noop() {
        let dispatch = new_dispatch();
        let _rx = dispatch.register_test_pending("doomed").await;

        dispatch.cancel_all().await;

        dispatch
            .on_response("doomed", Some(serde_json::json!({"late": true})), None)
            .await;
        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn duplicate_id_response_only_resolves_once() {
        let dispatch = new_dispatch();
        let rx = dispatch.register_test_pending("dup").await;

        dispatch
            .on_response("dup", Some(serde_json::json!({"first": true})), None)
            .await;
        // The second reply finds no entry and must not disturb the first.
        dispatch
            .on_response("dup", Some(serde_json::json!({"second": true})), None)
            .await;

        let result = rx.await.unwrap();
        assert_eq!(result.data.unwrap()["first"], true);
    }

    #[tokio::test]
    async fn cancel_all_then_insert_new() {
        let dispatch = new_dispatch();
        let rx1 = dispatch.register_test_pending("before").await;

        dispatch.cancel_all().await;
        let result1 = rx1.await.unwrap();
        assert!(result1.error.is_some());

        // The dispatcher stays usable after a cancel.
        let rx2 = dispatch.register_test_pending("after").await;
        assert_eq!(dispatch.pending_count().await, 1);

        dispatch
            .on_response("after", Some(serde_json::json!({"ok": true})), None)
            .await;
        let result2 = rx2.await.unwrap();
        assert_eq!(result2.data.unwrap()["ok"], true);
    }

    /// `cancel_all` racing with replies: every waiter must end up with exactly
    /// one outcome, either a reply or the disconnect error, and nothing may
    /// stay pending.
    #[tokio::test]
    async fn concurrent_cancel_and_resolve_race() {
        let dispatch = Arc::new(new_dispatch());

        let mut receivers = Vec::with_capacity(50);
        for i in 0..50 {
            receivers.push(dispatch.register_test_pending(&format!("race-{i}")).await);
        }

        let d1 = Arc::clone(&dispatch);
        let cancel_task = tokio::spawn(async move {
            d1.cancel_all().await;
        });

        let d2 = Arc::clone(&dispatch);
        let resolve_task = tokio::spawn(async move {
            for i in 0..50 {
                d2.on_response(&format!("race-{i}"), Some(serde_json::json!({})), None)
                    .await;
            }
        });

        cancel_task.await.unwrap();
        resolve_task.await.unwrap();

        assert_eq!(dispatch.pending_count().await, 0);
        for rx in receivers {
            let result = rx
                .await
                .expect("every pending entry must be resolved, not dropped");
            match (result.data, result.error) {
                (None, Some(err)) => assert!(err.contains("disconnected")),
                (Some(data), None) => assert_eq!(data, serde_json::json!({})),
                other => panic!("unexpected outcome: {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn on_response_with_both_data_and_error() {
        let dispatch = new_dispatch();
        let rx = dispatch.register_test_pending("both").await;

        dispatch
            .on_response(
                "both",
                Some(serde_json::json!({"partial": true})),
                Some("also an error".to_string()),
            )
            .await;

        let result = rx.await.unwrap();
        assert!(result.data.is_some());
        assert!(result.error.is_some());
    }

    /// At capacity, `dispatch` rejects the request before registering or
    /// writing anything.
    #[tokio::test]
    async fn dispatch_rejects_when_at_capacity() {
        let dispatch = new_dispatch();
        let mut receivers = Vec::with_capacity(MAX_PENDING);
        for i in 0..MAX_PENDING {
            receivers.push(dispatch.register_test_pending(&format!("full-{i}")).await);
        }

        let err = dispatch
            .dispatch(None, "noop", serde_json::Value::Null)
            .await
            .unwrap_err();
        assert!(err.contains("too many in-flight commands"));
        assert_eq!(dispatch.pending_count().await, MAX_PENDING);
    }
}