1use super::{InputMessage, Transport};
3use crate::{errors::Result, types::{ControlRequest, ControlResponse, Message}};
4use async_trait::async_trait;
5use futures::stream::{Stream, StreamExt};
6use std::pin::Pin;
7use std::sync::atomic::{AtomicBool, Ordering};
8use tokio::sync::{broadcast, mpsc};
9
10pub struct MockTransportHandle {
12 pub inbound_message_tx: broadcast::Sender<Message>,
14 pub sdk_control_tx: mpsc::Sender<serde_json::Value>,
16 pub outbound_control_rx: mpsc::Receiver<serde_json::Value>,
18 pub outbound_control_request_rx: mpsc::Receiver<serde_json::Value>,
20 pub sent_input_rx: mpsc::Receiver<InputMessage>,
22 pub end_input_rx: mpsc::Receiver<bool>,
24}
25
26pub struct MockTransport {
28 connected: AtomicBool,
29 message_tx: broadcast::Sender<Message>,
31 control_resp_rx: Option<mpsc::Receiver<ControlResponse>>,
33 sdk_control_rx: Option<mpsc::Receiver<serde_json::Value>>,
35 outbound_control_tx: mpsc::Sender<serde_json::Value>,
37 outbound_control_request_tx: mpsc::Sender<serde_json::Value>,
38 sent_input_tx: mpsc::Sender<InputMessage>,
39 end_input_tx: mpsc::Sender<bool>,
40}
41
42impl MockTransport {
43 pub fn pair() -> (Box<dyn Transport + Send>, MockTransportHandle) {
45 let (message_tx, _rx) = broadcast::channel(100);
46 let (sdk_control_tx, sdk_control_rx) = mpsc::channel(100);
47 let (outbound_control_tx, outbound_control_rx) = mpsc::channel(100);
48 let (outbound_control_request_tx, outbound_control_request_rx) = mpsc::channel(100);
49 let (sent_input_tx, sent_input_rx) = mpsc::channel(100);
50 let (end_input_tx, end_input_rx) = mpsc::channel(10);
51
52 let transport = MockTransport {
53 connected: AtomicBool::new(false),
54 message_tx: message_tx.clone(),
55 control_resp_rx: None,
56 sdk_control_rx: Some(sdk_control_rx),
57 outbound_control_tx: outbound_control_tx.clone(),
58 outbound_control_request_tx: outbound_control_request_tx.clone(),
59 sent_input_tx: sent_input_tx.clone(),
60 end_input_tx: end_input_tx.clone(),
61 };
62
63 let handle = MockTransportHandle {
64 inbound_message_tx: message_tx,
65 sdk_control_tx,
66 outbound_control_rx,
67 outbound_control_request_rx,
68 sent_input_rx,
69 end_input_rx,
70 };
71
72 (Box::new(transport), handle)
73 }
74}
75
76#[async_trait]
77impl Transport for MockTransport {
78 fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self }
79
80 async fn connect(&mut self) -> Result<()> {
81 self.connected.store(true, Ordering::SeqCst);
82 Ok(())
83 }
84
85 async fn send_message(&mut self, message: InputMessage) -> Result<()> {
86 let _ = self.sent_input_tx.send(message).await;
87 Ok(())
88 }
89
90 fn receive_messages(&mut self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + 'static>> {
91 let rx = self.message_tx.subscribe();
92 Box::pin(tokio_stream::wrappers::BroadcastStream::new(rx).filter_map(|r| async move {
93 match r {
94 Ok(m) => Some(Ok(m)),
95 Err(_) => None,
96 }
97 }))
98 }
99
100 async fn send_control_request(&mut self, request: ControlRequest) -> Result<()> {
101 let json = match request {
103 ControlRequest::Interrupt { request_id } => serde_json::json!({
104 "type": "control_request",
105 "request": {"type":"interrupt"},
106 "request_id": request_id,
107 }),
108 };
109 let _ = self.outbound_control_request_tx.send(json).await;
110 Ok(())
111 }
112
113 async fn receive_control_response(&mut self) -> Result<Option<ControlResponse>> {
114 if let Some(rx) = &mut self.control_resp_rx { Ok(rx.recv().await) } else { Ok(None) }
115 }
116
117 async fn send_sdk_control_request(&mut self, request: serde_json::Value) -> Result<()> {
118 let _ = self.outbound_control_request_tx.send(request).await;
120 Ok(())
121 }
122
123 async fn send_sdk_control_response(&mut self, response: serde_json::Value) -> Result<()> {
124 let wrapped = serde_json::json!({
126 "type": "control_response",
127 "response": response
128 });
129 let _ = self.outbound_control_tx.send(wrapped).await;
130 Ok(())
131 }
132
133 fn take_sdk_control_receiver(&mut self) -> Option<mpsc::Receiver<serde_json::Value>> {
134 self.sdk_control_rx.take()
135 }
136
137 fn is_connected(&self) -> bool { self.connected.load(Ordering::SeqCst) }
138
139 async fn disconnect(&mut self) -> Result<()> {
140 self.connected.store(false, Ordering::SeqCst);
141 Ok(())
142 }
143
144 async fn end_input(&mut self) -> Result<()> {
145 let _ = self.end_input_tx.send(true).await;
146 Ok(())
147 }
148}