1use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use tokio::sync::{broadcast, mpsc, RwLock};
11
12use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
13use crate::types::SignalingMessage;
14
15lazy_static::lazy_static! {
17 static ref CHANNEL_REGISTRY: RwLock<HashMap<String, Arc<MockDataChannel>>> = RwLock::new(HashMap::new());
18}
19
20pub async fn clear_channel_registry() {
22 CHANNEL_REGISTRY.write().await.clear();
23}
24
25pub struct MockRelay {
31 tx: broadcast::Sender<SignalingMessage>,
33}
34
35impl MockRelay {
36 pub fn new() -> Arc<Self> {
38 Self::new_with_capacity(1000)
39 }
40
41 pub fn new_with_capacity(capacity: usize) -> Arc<Self> {
43 let (tx, _) = broadcast::channel(capacity.max(1));
44 Arc::new(Self { tx })
45 }
46
47 pub fn create_transport(&self, peer_id: String) -> MockRelayTransport {
49 MockRelayTransport {
50 peer_id,
51 tx: self.tx.clone(),
52 rx: tokio::sync::Mutex::new(self.tx.subscribe()),
53 buffer: tokio::sync::Mutex::new(Vec::new()),
54 connected: AtomicBool::new(false),
55 }
56 }
57}
58
59impl Default for MockRelay {
60 fn default() -> Self {
61 let (tx, _) = broadcast::channel(1000);
62 Self { tx }
63 }
64}
65
66pub struct MockRelayTransport {
68 peer_id: String,
69 tx: broadcast::Sender<SignalingMessage>,
70 rx: tokio::sync::Mutex<broadcast::Receiver<SignalingMessage>>,
71 buffer: tokio::sync::Mutex<Vec<SignalingMessage>>,
72 connected: AtomicBool,
73}
74
75impl MockRelayTransport {
76 pub fn peer_id_owned(&self) -> String {
78 self.peer_id.clone()
79 }
80}
81
82#[async_trait]
83impl SignalingTransport for MockRelayTransport {
84 async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
85 self.connected.store(true, Ordering::Relaxed);
86 Ok(())
87 }
88
89 async fn disconnect(&self) {
90 self.connected.store(false, Ordering::Relaxed);
91 }
92
93 async fn publish(&self, msg: SignalingMessage) -> Result<(), TransportError> {
94 if !self.connected.load(Ordering::Relaxed) {
95 return Err(TransportError::NotConnected);
96 }
97 self.tx
98 .send(msg)
99 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
100 Ok(())
101 }
102
103 async fn recv(&self) -> Option<SignalingMessage> {
104 {
106 let mut buffer = self.buffer.lock().await;
107 if !buffer.is_empty() {
108 return Some(buffer.remove(0));
109 }
110 }
111
112 let mut rx = self.rx.lock().await;
114 loop {
115 match rx.recv().await {
116 Ok(msg) => {
117 if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
119 return Some(msg);
120 }
121 }
123 Err(broadcast::error::RecvError::Closed) => return None,
124 Err(broadcast::error::RecvError::Lagged(_)) => continue,
125 }
126 }
127 }
128
129 fn try_recv(&self) -> Option<SignalingMessage> {
130 if let Ok(mut buffer) = self.buffer.try_lock() {
132 if !buffer.is_empty() {
133 return Some(buffer.remove(0));
134 }
135 }
136
137 if let Ok(mut rx) = self.rx.try_lock() {
139 loop {
140 match rx.try_recv() {
141 Ok(msg) => {
142 if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
143 return Some(msg);
144 }
145 }
147 Err(_) => return None,
148 }
149 }
150 }
151 None
152 }
153
154 fn peer_id(&self) -> &str {
155 &self.peer_id
156 }
157}
158
/// How [`MockDataChannel`] simulates latency on each send.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MockLatencyMode {
    /// Await a real timer of `latency_ms` milliseconds, then wait for space
    /// in the peer's queue.
    RealSleep,
    /// Yield to the async scheduler `latency_ms` times instead of sleeping,
    /// then enqueue without waiting for capacity.
    YieldOnly,
}
171
172pub struct MockDataChannel {
174 peer_id: u64,
175 tx: mpsc::Sender<Vec<u8>>,
176 rx: tokio::sync::Mutex<mpsc::Receiver<Vec<u8>>>,
177 open: AtomicBool,
178 latency_ms: u64,
180 latency_mode: MockLatencyMode,
182}
183
184impl MockDataChannel {
185 pub fn pair(id_a: u64, id_b: u64) -> (Self, Self) {
187 Self::pair_with_latency(id_a, id_b, 0)
188 }
189
190 pub fn pair_with_latency(id_a: u64, id_b: u64, latency_ms: u64) -> (Self, Self) {
192 Self::pair_with_latency_mode(id_a, id_b, latency_ms, MockLatencyMode::RealSleep)
193 }
194
195 pub fn pair_with_latency_mode(
197 id_a: u64,
198 id_b: u64,
199 latency_ms: u64,
200 latency_mode: MockLatencyMode,
201 ) -> (Self, Self) {
202 let (tx_a, rx_a) = mpsc::channel(100);
203 let (tx_b, rx_b) = mpsc::channel(100);
204
205 let chan_a = Self {
206 peer_id: id_a,
207 tx: tx_b, rx: tokio::sync::Mutex::new(rx_a),
209 open: AtomicBool::new(true),
210 latency_ms,
211 latency_mode,
212 };
213
214 let chan_b = Self {
215 peer_id: id_b,
216 tx: tx_a, rx: tokio::sync::Mutex::new(rx_b),
218 open: AtomicBool::new(true),
219 latency_ms,
220 latency_mode,
221 };
222
223 (chan_a, chan_b)
224 }
225
226 pub fn peer_id(&self) -> u64 {
228 self.peer_id
229 }
230}
231
232#[async_trait]
233impl PeerLink for MockDataChannel {
234 async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
235 if !self.open.load(Ordering::Relaxed) {
236 return Err(TransportError::Disconnected);
237 }
238
239 if self.latency_ms > 0 {
241 match self.latency_mode {
242 MockLatencyMode::RealSleep => {
243 tokio::time::sleep(std::time::Duration::from_millis(self.latency_ms)).await;
244 }
245 MockLatencyMode::YieldOnly => {
246 for _ in 0..self.latency_ms.max(1) {
247 tokio::task::yield_now().await;
248 }
249 }
250 }
251 }
252
253 if self.latency_mode == MockLatencyMode::YieldOnly {
254 return self
255 .tx
256 .try_send(data)
257 .map_err(|err| TransportError::SendFailed(err.to_string()));
258 }
259
260 self.tx
261 .send(data)
262 .await
263 .map_err(|_| TransportError::Disconnected)
264 }
265
266 async fn recv(&self) -> Option<Vec<u8>> {
267 let mut rx = self.rx.lock().await;
268 rx.recv().await
269 }
270
271 fn try_recv(&self) -> Option<Vec<u8>> {
272 let Ok(mut rx) = self.rx.try_lock() else {
273 return None;
274 };
275 rx.try_recv().ok()
276 }
277
278 fn is_open(&self) -> bool {
279 self.open.load(Ordering::Relaxed)
280 }
281
282 async fn close(&self) {
283 self.open.store(false, Ordering::Relaxed);
284 }
285}
286
287pub struct MockConnectionFactory {
296 our_peer_id: String,
297 our_node_id: u64,
298 latency_ms: u64,
300 latency_mode: MockLatencyMode,
302 pending: RwLock<HashMap<String, Arc<MockDataChannel>>>,
304}
305
306impl MockConnectionFactory {
307 pub fn new(peer_id: String, latency_ms: u64) -> Self {
309 Self::new_with_latency_mode(peer_id, latency_ms, MockLatencyMode::RealSleep)
310 }
311
312 pub fn new_with_latency_mode(
314 peer_id: String,
315 latency_ms: u64,
316 latency_mode: MockLatencyMode,
317 ) -> Self {
318 let node_id = peer_id.parse().unwrap_or(0);
319 Self {
320 our_peer_id: peer_id,
321 our_node_id: node_id,
322 latency_ms,
323 latency_mode,
324 pending: RwLock::new(HashMap::new()),
325 }
326 }
327}
328
329#[async_trait]
330impl PeerLinkFactory for MockConnectionFactory {
331 async fn create_offer(
332 &self,
333 target_peer_id: &str,
334 ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
335 let target_node_id: u64 = target_peer_id.parse().unwrap_or(0);
336
337 let (our_chan, their_chan) = MockDataChannel::pair_with_latency_mode(
339 self.our_node_id,
340 target_node_id,
341 self.latency_ms,
342 self.latency_mode,
343 );
344 let our_chan = Arc::new(our_chan);
345 let their_chan = Arc::new(their_chan);
346
347 let channel_id = format!("{}_{}", self.our_peer_id, target_peer_id);
349
350 self.pending
352 .write()
353 .await
354 .insert(target_peer_id.to_string(), our_chan.clone());
355
356 CHANNEL_REGISTRY
358 .write()
359 .await
360 .insert(channel_id.clone(), their_chan);
361
362 Ok((our_chan, channel_id))
363 }
364
365 async fn accept_offer(
366 &self,
367 _from_peer_id: &str,
368 offer_sdp: &str,
369 ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
370 let channel_id = offer_sdp;
372
373 let channel = CHANNEL_REGISTRY
375 .write()
376 .await
377 .remove(channel_id)
378 .ok_or_else(|| TransportError::ConnectionFailed("Channel not found".to_string()))?;
379
380 Ok((channel, channel_id.to_string()))
382 }
383
384 async fn handle_answer(
385 &self,
386 target_peer_id: &str,
387 _answer_sdp: &str,
388 ) -> Result<Arc<dyn PeerLink>, TransportError> {
389 let channel = self
391 .pending
392 .write()
393 .await
394 .remove(target_peer_id)
395 .ok_or_else(|| TransportError::ConnectionFailed("No pending connection".to_string()))?;
396
397 Ok(channel)
398 }
399}