Skip to main content

hashtree_network/
mock.rs

1//! Mock implementations for testing and simulation
2//!
3//! Provides mock relay transport and peer connection factory that use
4//! in-memory channels instead of real Nostr relays and WebRTC.
5
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use tokio::sync::{broadcast, mpsc, RwLock};
11
12use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
13use crate::types::SignalingMessage;
14
15// Global registry for mock channels (shared between offer/answer sides)
16lazy_static::lazy_static! {
17    static ref CHANNEL_REGISTRY: RwLock<HashMap<String, Arc<MockDataChannel>>> = RwLock::new(HashMap::new());
18}
19
20/// Clear the channel registry (call between tests)
21pub async fn clear_channel_registry() {
22    CHANNEL_REGISTRY.write().await.clear();
23}
24
25// ============================================================================
26// Mock Relay Transport
27// ============================================================================
28
29/// Mock relay for in-memory signaling
30pub struct MockRelay {
31    /// Broadcast channel for all messages
32    tx: broadcast::Sender<SignalingMessage>,
33}
34
35impl MockRelay {
36    /// Create a new mock relay
37    pub fn new() -> Arc<Self> {
38        let (tx, _) = broadcast::channel(1000);
39        Arc::new(Self { tx })
40    }
41
42    /// Create a transport connected to this relay
43    pub fn create_transport(&self, peer_id: String) -> MockRelayTransport {
44        MockRelayTransport {
45            peer_id,
46            tx: self.tx.clone(),
47            rx: tokio::sync::Mutex::new(self.tx.subscribe()),
48            buffer: tokio::sync::Mutex::new(Vec::new()),
49            connected: AtomicBool::new(false),
50        }
51    }
52}
53
54impl Default for MockRelay {
55    fn default() -> Self {
56        let (tx, _) = broadcast::channel(1000);
57        Self { tx }
58    }
59}
60
61/// Mock relay transport using broadcast channels
62pub struct MockRelayTransport {
63    peer_id: String,
64    tx: broadcast::Sender<SignalingMessage>,
65    rx: tokio::sync::Mutex<broadcast::Receiver<SignalingMessage>>,
66    buffer: tokio::sync::Mutex<Vec<SignalingMessage>>,
67    connected: AtomicBool,
68}
69
70impl MockRelayTransport {
71    /// Get our peer ID
72    pub fn peer_id_owned(&self) -> String {
73        self.peer_id.clone()
74    }
75}
76
77#[async_trait]
78impl SignalingTransport for MockRelayTransport {
79    async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
80        self.connected.store(true, Ordering::Relaxed);
81        Ok(())
82    }
83
84    async fn disconnect(&self) {
85        self.connected.store(false, Ordering::Relaxed);
86    }
87
88    async fn publish(&self, msg: SignalingMessage) -> Result<(), TransportError> {
89        if !self.connected.load(Ordering::Relaxed) {
90            return Err(TransportError::NotConnected);
91        }
92        self.tx
93            .send(msg)
94            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
95        Ok(())
96    }
97
98    async fn recv(&self) -> Option<SignalingMessage> {
99        // Check buffer first
100        {
101            let mut buffer = self.buffer.lock().await;
102            if !buffer.is_empty() {
103                return Some(buffer.remove(0));
104            }
105        }
106
107        // Wait for next message
108        let mut rx = self.rx.lock().await;
109        loop {
110            match rx.recv().await {
111                Ok(msg) => {
112                    // Filter: only return messages for us or broadcasts
113                    if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
114                        return Some(msg);
115                    }
116                    // Skip messages for other peers
117                }
118                Err(broadcast::error::RecvError::Closed) => return None,
119                Err(broadcast::error::RecvError::Lagged(_)) => continue,
120            }
121        }
122    }
123
124    fn try_recv(&self) -> Option<SignalingMessage> {
125        // Check buffer first
126        if let Ok(mut buffer) = self.buffer.try_lock() {
127            if !buffer.is_empty() {
128                return Some(buffer.remove(0));
129            }
130        }
131
132        // Try non-blocking receive
133        if let Ok(mut rx) = self.rx.try_lock() {
134            loop {
135                match rx.try_recv() {
136                    Ok(msg) => {
137                        if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
138                            return Some(msg);
139                        }
140                        // Skip messages for other peers
141                    }
142                    Err(_) => return None,
143                }
144            }
145        }
146        None
147    }
148
149    fn peer_id(&self) -> &str {
150        &self.peer_id
151    }
152}
153
154// ============================================================================
155// Mock Data Channel
156// ============================================================================
157
158#[derive(Debug, Clone, Copy, PartialEq, Eq)]
159pub enum MockLatencyMode {
160    /// Use real `tokio::time::sleep` for latency simulation.
161    RealSleep,
162    /// Avoid real-time sleeps for faster simulation loops while still preserving
163    /// relative delay through repeated scheduler yields.
164    YieldOnly,
165}
166
167/// Mock data channel using mpsc channels
168pub struct MockDataChannel {
169    peer_id: u64,
170    tx: mpsc::Sender<Vec<u8>>,
171    rx: tokio::sync::Mutex<mpsc::Receiver<Vec<u8>>>,
172    open: AtomicBool,
173    /// Simulated latency per message (ms)
174    latency_ms: u64,
175    /// How latency is realized in async execution.
176    latency_mode: MockLatencyMode,
177}
178
179impl MockDataChannel {
180    /// Create a connected pair of mock channels
181    pub fn pair(id_a: u64, id_b: u64) -> (Self, Self) {
182        Self::pair_with_latency(id_a, id_b, 0)
183    }
184
185    /// Create a connected pair with simulated latency
186    pub fn pair_with_latency(id_a: u64, id_b: u64, latency_ms: u64) -> (Self, Self) {
187        Self::pair_with_latency_mode(id_a, id_b, latency_ms, MockLatencyMode::RealSleep)
188    }
189
190    /// Create a connected pair with explicit latency mode.
191    pub fn pair_with_latency_mode(
192        id_a: u64,
193        id_b: u64,
194        latency_ms: u64,
195        latency_mode: MockLatencyMode,
196    ) -> (Self, Self) {
197        let (tx_a, rx_a) = mpsc::channel(100);
198        let (tx_b, rx_b) = mpsc::channel(100);
199
200        let chan_a = Self {
201            peer_id: id_a,
202            tx: tx_b, // A sends to B's receiver
203            rx: tokio::sync::Mutex::new(rx_a),
204            open: AtomicBool::new(true),
205            latency_ms,
206            latency_mode,
207        };
208
209        let chan_b = Self {
210            peer_id: id_b,
211            tx: tx_a, // B sends to A's receiver
212            rx: tokio::sync::Mutex::new(rx_b),
213            open: AtomicBool::new(true),
214            latency_ms,
215            latency_mode,
216        };
217
218        (chan_a, chan_b)
219    }
220
221    /// Get peer ID
222    pub fn peer_id(&self) -> u64 {
223        self.peer_id
224    }
225}
226
227#[async_trait]
228impl PeerLink for MockDataChannel {
229    async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
230        if !self.open.load(Ordering::Relaxed) {
231            return Err(TransportError::Disconnected);
232        }
233
234        // Simulate latency
235        if self.latency_ms > 0 {
236            match self.latency_mode {
237                MockLatencyMode::RealSleep => {
238                    tokio::time::sleep(std::time::Duration::from_millis(self.latency_ms)).await;
239                }
240                MockLatencyMode::YieldOnly => {
241                    for _ in 0..self.latency_ms.max(1) {
242                        tokio::task::yield_now().await;
243                    }
244                }
245            }
246        }
247
248        self.tx
249            .send(data)
250            .await
251            .map_err(|_| TransportError::Disconnected)
252    }
253
254    async fn recv(&self) -> Option<Vec<u8>> {
255        let mut rx = self.rx.lock().await;
256        rx.recv().await
257    }
258
259    fn try_recv(&self) -> Option<Vec<u8>> {
260        let Ok(mut rx) = self.rx.try_lock() else {
261            return None;
262        };
263        rx.try_recv().ok()
264    }
265
266    fn is_open(&self) -> bool {
267        self.open.load(Ordering::Relaxed)
268    }
269
270    async fn close(&self) {
271        self.open.store(false, Ordering::Relaxed);
272    }
273}
274
275// ============================================================================
276// Mock Peer Connection Factory
277// ============================================================================
278
279/// Mock peer connection factory
280///
281/// Creates mock data channels instead of real WebRTC connections.
282/// Uses a global registry to connect offer/answer sides.
283pub struct MockConnectionFactory {
284    our_peer_id: String,
285    our_node_id: u64,
286    /// Simulated latency per link (ms)
287    latency_ms: u64,
288    /// How link latency is realized.
289    latency_mode: MockLatencyMode,
290    /// Pending outbound channels (we sent offer, waiting for answer)
291    pending: RwLock<HashMap<String, Arc<MockDataChannel>>>,
292}
293
294impl MockConnectionFactory {
295    /// Create a new mock connection factory
296    pub fn new(peer_id: String, latency_ms: u64) -> Self {
297        Self::new_with_latency_mode(peer_id, latency_ms, MockLatencyMode::RealSleep)
298    }
299
300    /// Create a mock connection factory with explicit latency mode.
301    pub fn new_with_latency_mode(
302        peer_id: String,
303        latency_ms: u64,
304        latency_mode: MockLatencyMode,
305    ) -> Self {
306        let node_id = peer_id.parse().unwrap_or(0);
307        Self {
308            our_peer_id: peer_id,
309            our_node_id: node_id,
310            latency_ms,
311            latency_mode,
312            pending: RwLock::new(HashMap::new()),
313        }
314    }
315}
316
317#[async_trait]
318impl PeerLinkFactory for MockConnectionFactory {
319    async fn create_offer(
320        &self,
321        target_peer_id: &str,
322    ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
323        let target_node_id: u64 = target_peer_id.parse().unwrap_or(0);
324
325        // Create channel pair
326        let (our_chan, their_chan) = MockDataChannel::pair_with_latency_mode(
327            self.our_node_id,
328            target_node_id,
329            self.latency_ms,
330            self.latency_mode,
331        );
332        let our_chan = Arc::new(our_chan);
333        let their_chan = Arc::new(their_chan);
334
335        // Channel ID is used to link offer/answer
336        let channel_id = format!("{}_{}", self.our_peer_id, target_peer_id);
337
338        // Store our channel for when answer comes back
339        self.pending
340            .write()
341            .await
342            .insert(target_peer_id.to_string(), our_chan.clone());
343
344        // Store their channel in global registry for answerer to find
345        CHANNEL_REGISTRY
346            .write()
347            .await
348            .insert(channel_id.clone(), their_chan);
349
350        Ok((our_chan, channel_id))
351    }
352
353    async fn accept_offer(
354        &self,
355        _from_peer_id: &str,
356        offer_sdp: &str,
357    ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
358        // offer_sdp is the channel_id
359        let channel_id = offer_sdp;
360
361        // Get our channel from the registry
362        let channel = CHANNEL_REGISTRY
363            .write()
364            .await
365            .remove(channel_id)
366            .ok_or_else(|| TransportError::ConnectionFailed("Channel not found".to_string()))?;
367
368        // Answer SDP is just the channel ID (for mock, we don't need real SDP)
369        Ok((channel, channel_id.to_string()))
370    }
371
372    async fn handle_answer(
373        &self,
374        target_peer_id: &str,
375        _answer_sdp: &str,
376    ) -> Result<Arc<dyn PeerLink>, TransportError> {
377        // Get our pending channel
378        let channel = self
379            .pending
380            .write()
381            .await
382            .remove(target_peer_id)
383            .ok_or_else(|| TransportError::ConnectionFailed("No pending connection".to_string()))?;
384
385        Ok(channel)
386    }
387}