1use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use tokio::sync::{broadcast, mpsc, RwLock};
11
12use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
13use crate::types::SignalingMessage;
14
lazy_static::lazy_static! {
    /// Process-wide hand-off table for mock data channels, keyed by channel id.
    ///
    /// `MockConnectionFactory::create_offer` parks the remote end of a freshly created
    /// channel pair here under the id it returns as the "offer SDP", and
    /// `MockConnectionFactory::accept_offer` removes it by that same id. In the visible
    /// code nothing else removes entries, so offers that are never accepted stay here
    /// until `clear_channel_registry` is called.
    static ref CHANNEL_REGISTRY: RwLock<HashMap<String, Arc<MockDataChannel>>> = RwLock::new(HashMap::new());
}
19
20pub async fn clear_channel_registry() {
22 CHANNEL_REGISTRY.write().await.clear();
23}
24
25pub struct MockRelay {
31 tx: broadcast::Sender<SignalingMessage>,
33}
34
35impl MockRelay {
36 pub fn new() -> Arc<Self> {
38 let (tx, _) = broadcast::channel(1000);
39 Arc::new(Self { tx })
40 }
41
42 pub fn create_transport(&self, peer_id: String) -> MockRelayTransport {
44 MockRelayTransport {
45 peer_id,
46 tx: self.tx.clone(),
47 rx: tokio::sync::Mutex::new(self.tx.subscribe()),
48 buffer: tokio::sync::Mutex::new(Vec::new()),
49 connected: AtomicBool::new(false),
50 }
51 }
52}
53
54impl Default for MockRelay {
55 fn default() -> Self {
56 let (tx, _) = broadcast::channel(1000);
57 Self { tx }
58 }
59}
60
/// A per-peer signaling endpoint attached to a [`MockRelay`] bus.
pub struct MockRelayTransport {
    /// Identifier that incoming messages are filtered against.
    peer_id: String,
    /// Clone of the relay's broadcast sender, used by `publish`.
    tx: broadcast::Sender<SignalingMessage>,
    /// This transport's own subscription to the relay bus. It sits behind an async
    /// mutex because receiving needs `&mut` access while the trait methods take `&self`.
    rx: tokio::sync::Mutex<broadcast::Receiver<SignalingMessage>>,
    /// Messages handed out before anything is read from `rx`.
    // NOTE(review): nothing in the visible code pushes into this buffer; confirm
    // whether it is still needed.
    buffer: tokio::sync::Mutex<Vec<SignalingMessage>>,
    /// Set by `connect` and cleared by `disconnect`; `publish` fails while false.
    connected: AtomicBool,
}
69
70impl MockRelayTransport {
71 pub fn peer_id_owned(&self) -> String {
73 self.peer_id.clone()
74 }
75}
76
77#[async_trait]
78impl SignalingTransport for MockRelayTransport {
79 async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
80 self.connected.store(true, Ordering::Relaxed);
81 Ok(())
82 }
83
84 async fn disconnect(&self) {
85 self.connected.store(false, Ordering::Relaxed);
86 }
87
88 async fn publish(&self, msg: SignalingMessage) -> Result<(), TransportError> {
89 if !self.connected.load(Ordering::Relaxed) {
90 return Err(TransportError::NotConnected);
91 }
92 self.tx
93 .send(msg)
94 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
95 Ok(())
96 }
97
98 async fn recv(&self) -> Option<SignalingMessage> {
99 {
101 let mut buffer = self.buffer.lock().await;
102 if !buffer.is_empty() {
103 return Some(buffer.remove(0));
104 }
105 }
106
107 let mut rx = self.rx.lock().await;
109 loop {
110 match rx.recv().await {
111 Ok(msg) => {
112 if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
114 return Some(msg);
115 }
116 }
118 Err(broadcast::error::RecvError::Closed) => return None,
119 Err(broadcast::error::RecvError::Lagged(_)) => continue,
120 }
121 }
122 }
123
124 fn try_recv(&self) -> Option<SignalingMessage> {
125 if let Ok(mut buffer) = self.buffer.try_lock() {
127 if !buffer.is_empty() {
128 return Some(buffer.remove(0));
129 }
130 }
131
132 if let Ok(mut rx) = self.rx.try_lock() {
134 loop {
135 match rx.try_recv() {
136 Ok(msg) => {
137 if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
138 return Some(msg);
139 }
140 }
142 Err(_) => return None,
143 }
144 }
145 }
146 None
147 }
148
149 fn peer_id(&self) -> &str {
150 &self.peer_id
151 }
152}
153
/// How a [`MockDataChannel`] applies its configured per-send latency.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MockLatencyMode {
    /// Each `send` awaits `tokio::time::sleep` for `latency_ms` milliseconds.
    RealSleep,
    /// Each `send` yields to the scheduler `latency_ms` times instead of sleeping,
    /// so no real timer wait is involved.
    YieldOnly,
}
166
/// One end of an in-memory, bidirectional byte channel.
///
/// Ends are created in cross-wired pairs by `MockDataChannel::pair` and its variants.
pub struct MockDataChannel {
    /// Id assigned to this end at construction (`id_a` for the first end of a pair,
    /// `id_b` for the second).
    peer_id: u64,
    /// Sender feeding the *other* end's receive queue.
    tx: mpsc::Sender<Vec<u8>>,
    /// This end's receive queue, fed by the other end's `tx`.
    rx: tokio::sync::Mutex<mpsc::Receiver<Vec<u8>>>,
    /// Cleared by `close`; `send` fails with `Disconnected` once this is false.
    open: AtomicBool,
    /// Artificial delay applied on every `send`; 0 disables it.
    latency_ms: u64,
    /// How `latency_ms` is applied (real sleep vs. scheduler yields).
    latency_mode: MockLatencyMode,
}
178
179impl MockDataChannel {
180 pub fn pair(id_a: u64, id_b: u64) -> (Self, Self) {
182 Self::pair_with_latency(id_a, id_b, 0)
183 }
184
185 pub fn pair_with_latency(id_a: u64, id_b: u64, latency_ms: u64) -> (Self, Self) {
187 Self::pair_with_latency_mode(id_a, id_b, latency_ms, MockLatencyMode::RealSleep)
188 }
189
190 pub fn pair_with_latency_mode(
192 id_a: u64,
193 id_b: u64,
194 latency_ms: u64,
195 latency_mode: MockLatencyMode,
196 ) -> (Self, Self) {
197 let (tx_a, rx_a) = mpsc::channel(100);
198 let (tx_b, rx_b) = mpsc::channel(100);
199
200 let chan_a = Self {
201 peer_id: id_a,
202 tx: tx_b, rx: tokio::sync::Mutex::new(rx_a),
204 open: AtomicBool::new(true),
205 latency_ms,
206 latency_mode,
207 };
208
209 let chan_b = Self {
210 peer_id: id_b,
211 tx: tx_a, rx: tokio::sync::Mutex::new(rx_b),
213 open: AtomicBool::new(true),
214 latency_ms,
215 latency_mode,
216 };
217
218 (chan_a, chan_b)
219 }
220
221 pub fn peer_id(&self) -> u64 {
223 self.peer_id
224 }
225}
226
227#[async_trait]
228impl PeerLink for MockDataChannel {
229 async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
230 if !self.open.load(Ordering::Relaxed) {
231 return Err(TransportError::Disconnected);
232 }
233
234 if self.latency_ms > 0 {
236 match self.latency_mode {
237 MockLatencyMode::RealSleep => {
238 tokio::time::sleep(std::time::Duration::from_millis(self.latency_ms)).await;
239 }
240 MockLatencyMode::YieldOnly => {
241 for _ in 0..self.latency_ms.max(1) {
242 tokio::task::yield_now().await;
243 }
244 }
245 }
246 }
247
248 self.tx
249 .send(data)
250 .await
251 .map_err(|_| TransportError::Disconnected)
252 }
253
254 async fn recv(&self) -> Option<Vec<u8>> {
255 let mut rx = self.rx.lock().await;
256 rx.recv().await
257 }
258
259 fn try_recv(&self) -> Option<Vec<u8>> {
260 let Ok(mut rx) = self.rx.try_lock() else {
261 return None;
262 };
263 rx.try_recv().ok()
264 }
265
266 fn is_open(&self) -> bool {
267 self.open.load(Ordering::Relaxed)
268 }
269
270 async fn close(&self) {
271 self.open.store(false, Ordering::Relaxed);
272 }
273}
274
/// Mock `PeerLinkFactory` that connects peers with [`MockDataChannel`] pairs instead
/// of real connections.
///
/// The "SDP" strings exchanged during the offer/answer handshake are channel ids that
/// key into a process-global registry of parked channel ends.
pub struct MockConnectionFactory {
    /// This factory's peer id; used as the prefix of the channel ids it creates.
    our_peer_id: String,
    /// `our_peer_id` parsed as `u64`, or 0 when it is not numeric.
    our_node_id: u64,
    /// Artificial per-send latency applied to every channel this factory creates.
    latency_ms: u64,
    /// How `latency_ms` is applied on created channels.
    latency_mode: MockLatencyMode,
    /// Our ends of outstanding offers awaiting `handle_answer`, keyed by target peer id.
    pending: RwLock<HashMap<String, Arc<MockDataChannel>>>,
}
293
294impl MockConnectionFactory {
295 pub fn new(peer_id: String, latency_ms: u64) -> Self {
297 Self::new_with_latency_mode(peer_id, latency_ms, MockLatencyMode::RealSleep)
298 }
299
300 pub fn new_with_latency_mode(
302 peer_id: String,
303 latency_ms: u64,
304 latency_mode: MockLatencyMode,
305 ) -> Self {
306 let node_id = peer_id.parse().unwrap_or(0);
307 Self {
308 our_peer_id: peer_id,
309 our_node_id: node_id,
310 latency_ms,
311 latency_mode,
312 pending: RwLock::new(HashMap::new()),
313 }
314 }
315}
316
317#[async_trait]
318impl PeerLinkFactory for MockConnectionFactory {
319 async fn create_offer(
320 &self,
321 target_peer_id: &str,
322 ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
323 let target_node_id: u64 = target_peer_id.parse().unwrap_or(0);
324
325 let (our_chan, their_chan) = MockDataChannel::pair_with_latency_mode(
327 self.our_node_id,
328 target_node_id,
329 self.latency_ms,
330 self.latency_mode,
331 );
332 let our_chan = Arc::new(our_chan);
333 let their_chan = Arc::new(their_chan);
334
335 let channel_id = format!("{}_{}", self.our_peer_id, target_peer_id);
337
338 self.pending
340 .write()
341 .await
342 .insert(target_peer_id.to_string(), our_chan.clone());
343
344 CHANNEL_REGISTRY
346 .write()
347 .await
348 .insert(channel_id.clone(), their_chan);
349
350 Ok((our_chan, channel_id))
351 }
352
353 async fn accept_offer(
354 &self,
355 _from_peer_id: &str,
356 offer_sdp: &str,
357 ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
358 let channel_id = offer_sdp;
360
361 let channel = CHANNEL_REGISTRY
363 .write()
364 .await
365 .remove(channel_id)
366 .ok_or_else(|| TransportError::ConnectionFailed("Channel not found".to_string()))?;
367
368 Ok((channel, channel_id.to_string()))
370 }
371
372 async fn handle_answer(
373 &self,
374 target_peer_id: &str,
375 _answer_sdp: &str,
376 ) -> Result<Arc<dyn PeerLink>, TransportError> {
377 let channel = self
379 .pending
380 .write()
381 .await
382 .remove(target_peer_id)
383 .ok_or_else(|| TransportError::ConnectionFailed("No pending connection".to_string()))?;
384
385 Ok(channel)
386 }
387}