1use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use tokio::sync::{broadcast, mpsc, RwLock};
11
12use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
13use crate::types::SignalingMessage;
14
15lazy_static::lazy_static! {
17 static ref CHANNEL_REGISTRY: RwLock<HashMap<String, Arc<MockDataChannel>>> = RwLock::new(HashMap::new());
18}
19
20pub async fn clear_channel_registry() {
22 CHANNEL_REGISTRY.write().await.clear();
23}
24
25pub struct MockRelay {
31 tx: broadcast::Sender<SignalingMessage>,
33}
34
35impl MockRelay {
36 pub fn new() -> Arc<Self> {
38 let (tx, _) = broadcast::channel(1000);
39 Arc::new(Self { tx })
40 }
41
42 pub fn create_transport(&self, peer_id: String) -> MockRelayTransport {
44 MockRelayTransport {
45 peer_id,
46 tx: self.tx.clone(),
47 rx: tokio::sync::Mutex::new(self.tx.subscribe()),
48 buffer: tokio::sync::Mutex::new(Vec::new()),
49 connected: AtomicBool::new(false),
50 }
51 }
52}
53
54impl Default for MockRelay {
55 fn default() -> Self {
56 let (tx, _) = broadcast::channel(1000);
57 Self { tx }
58 }
59}
60
61pub struct MockRelayTransport {
63 peer_id: String,
64 tx: broadcast::Sender<SignalingMessage>,
65 rx: tokio::sync::Mutex<broadcast::Receiver<SignalingMessage>>,
66 buffer: tokio::sync::Mutex<Vec<SignalingMessage>>,
67 connected: AtomicBool,
68}
69
70impl MockRelayTransport {
71 pub fn peer_id_owned(&self) -> String {
73 self.peer_id.clone()
74 }
75}
76
77#[async_trait]
78impl SignalingTransport for MockRelayTransport {
79 async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
80 self.connected.store(true, Ordering::Relaxed);
81 Ok(())
82 }
83
84 async fn disconnect(&self) {
85 self.connected.store(false, Ordering::Relaxed);
86 }
87
88 async fn publish(&self, msg: SignalingMessage) -> Result<(), TransportError> {
89 if !self.connected.load(Ordering::Relaxed) {
90 return Err(TransportError::NotConnected);
91 }
92 self.tx
93 .send(msg)
94 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
95 Ok(())
96 }
97
98 async fn recv(&self) -> Option<SignalingMessage> {
99 {
101 let mut buffer = self.buffer.lock().await;
102 if !buffer.is_empty() {
103 return Some(buffer.remove(0));
104 }
105 }
106
107 let mut rx = self.rx.lock().await;
109 loop {
110 match rx.recv().await {
111 Ok(msg) => {
112 if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
114 return Some(msg);
115 }
116 }
118 Err(broadcast::error::RecvError::Closed) => return None,
119 Err(broadcast::error::RecvError::Lagged(_)) => continue,
120 }
121 }
122 }
123
124 fn try_recv(&self) -> Option<SignalingMessage> {
125 if let Ok(mut buffer) = self.buffer.try_lock() {
127 if !buffer.is_empty() {
128 return Some(buffer.remove(0));
129 }
130 }
131
132 if let Ok(mut rx) = self.rx.try_lock() {
134 loop {
135 match rx.try_recv() {
136 Ok(msg) => {
137 if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
138 return Some(msg);
139 }
140 }
142 Err(_) => return None,
143 }
144 }
145 }
146 None
147 }
148
149 fn peer_id(&self) -> &str {
150 &self.peer_id
151 }
152}
153
/// How a `MockDataChannel` applies its configured latency before sending.
///
/// The mode is only consulted when the channel's latency is non-zero.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MockLatencyMode {
    /// Await `tokio::time::sleep` for the configured number of milliseconds.
    RealSleep,
    /// Await a single `tokio::task::yield_now` instead of sleeping.
    YieldOnly,
}
165
166pub struct MockDataChannel {
168 peer_id: u64,
169 tx: mpsc::Sender<Vec<u8>>,
170 rx: tokio::sync::Mutex<mpsc::Receiver<Vec<u8>>>,
171 open: AtomicBool,
172 latency_ms: u64,
174 latency_mode: MockLatencyMode,
176}
177
178impl MockDataChannel {
179 pub fn pair(id_a: u64, id_b: u64) -> (Self, Self) {
181 Self::pair_with_latency(id_a, id_b, 0)
182 }
183
184 pub fn pair_with_latency(id_a: u64, id_b: u64, latency_ms: u64) -> (Self, Self) {
186 Self::pair_with_latency_mode(id_a, id_b, latency_ms, MockLatencyMode::RealSleep)
187 }
188
189 pub fn pair_with_latency_mode(
191 id_a: u64,
192 id_b: u64,
193 latency_ms: u64,
194 latency_mode: MockLatencyMode,
195 ) -> (Self, Self) {
196 let (tx_a, rx_a) = mpsc::channel(100);
197 let (tx_b, rx_b) = mpsc::channel(100);
198
199 let chan_a = Self {
200 peer_id: id_a,
201 tx: tx_b, rx: tokio::sync::Mutex::new(rx_a),
203 open: AtomicBool::new(true),
204 latency_ms,
205 latency_mode,
206 };
207
208 let chan_b = Self {
209 peer_id: id_b,
210 tx: tx_a, rx: tokio::sync::Mutex::new(rx_b),
212 open: AtomicBool::new(true),
213 latency_ms,
214 latency_mode,
215 };
216
217 (chan_a, chan_b)
218 }
219
220 pub fn peer_id(&self) -> u64 {
222 self.peer_id
223 }
224}
225
226#[async_trait]
227impl PeerLink for MockDataChannel {
228 async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
229 if !self.open.load(Ordering::Relaxed) {
230 return Err(TransportError::Disconnected);
231 }
232
233 if self.latency_ms > 0 {
235 match self.latency_mode {
236 MockLatencyMode::RealSleep => {
237 tokio::time::sleep(std::time::Duration::from_millis(self.latency_ms)).await;
238 }
239 MockLatencyMode::YieldOnly => {
240 tokio::task::yield_now().await;
241 }
242 }
243 }
244
245 self.tx
246 .send(data)
247 .await
248 .map_err(|_| TransportError::Disconnected)
249 }
250
251 async fn recv(&self) -> Option<Vec<u8>> {
252 let mut rx = self.rx.lock().await;
253 rx.recv().await
254 }
255
256 fn try_recv(&self) -> Option<Vec<u8>> {
257 let Ok(mut rx) = self.rx.try_lock() else {
258 return None;
259 };
260 rx.try_recv().ok()
261 }
262
263 fn is_open(&self) -> bool {
264 self.open.load(Ordering::Relaxed)
265 }
266
267 async fn close(&self) {
268 self.open.store(false, Ordering::Relaxed);
269 }
270}
271
272pub struct MockConnectionFactory {
281 our_peer_id: String,
282 our_node_id: u64,
283 latency_ms: u64,
285 latency_mode: MockLatencyMode,
287 pending: RwLock<HashMap<String, Arc<MockDataChannel>>>,
289}
290
291impl MockConnectionFactory {
292 pub fn new(peer_id: String, latency_ms: u64) -> Self {
294 Self::new_with_latency_mode(peer_id, latency_ms, MockLatencyMode::RealSleep)
295 }
296
297 pub fn new_with_latency_mode(
299 peer_id: String,
300 latency_ms: u64,
301 latency_mode: MockLatencyMode,
302 ) -> Self {
303 let node_id = peer_id.parse().unwrap_or(0);
304 Self {
305 our_peer_id: peer_id,
306 our_node_id: node_id,
307 latency_ms,
308 latency_mode,
309 pending: RwLock::new(HashMap::new()),
310 }
311 }
312}
313
314#[async_trait]
315impl PeerLinkFactory for MockConnectionFactory {
316 async fn create_offer(
317 &self,
318 target_peer_id: &str,
319 ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
320 let target_node_id: u64 = target_peer_id.parse().unwrap_or(0);
321
322 let (our_chan, their_chan) = MockDataChannel::pair_with_latency_mode(
324 self.our_node_id,
325 target_node_id,
326 self.latency_ms,
327 self.latency_mode,
328 );
329 let our_chan = Arc::new(our_chan);
330 let their_chan = Arc::new(their_chan);
331
332 let channel_id = format!("{}_{}", self.our_peer_id, target_peer_id);
334
335 self.pending
337 .write()
338 .await
339 .insert(target_peer_id.to_string(), our_chan.clone());
340
341 CHANNEL_REGISTRY
343 .write()
344 .await
345 .insert(channel_id.clone(), their_chan);
346
347 Ok((our_chan, channel_id))
348 }
349
350 async fn accept_offer(
351 &self,
352 _from_peer_id: &str,
353 offer_sdp: &str,
354 ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
355 let channel_id = offer_sdp;
357
358 let channel = CHANNEL_REGISTRY
360 .write()
361 .await
362 .remove(channel_id)
363 .ok_or_else(|| TransportError::ConnectionFailed("Channel not found".to_string()))?;
364
365 Ok((channel, channel_id.to_string()))
367 }
368
369 async fn handle_answer(
370 &self,
371 target_peer_id: &str,
372 _answer_sdp: &str,
373 ) -> Result<Arc<dyn PeerLink>, TransportError> {
374 let channel = self
376 .pending
377 .write()
378 .await
379 .remove(target_peer_id)
380 .ok_or_else(|| TransportError::ConnectionFailed("No pending connection".to_string()))?;
381
382 Ok(channel)
383 }
384}