Skip to main content

hashtree_network/
mock.rs

1//! Mock implementations for testing and simulation
2//!
3//! Provides mock relay transport and peer connection factory that use
4//! in-memory channels instead of real Nostr relays and WebRTC.
5
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use tokio::sync::{broadcast, mpsc, RwLock};
11
12use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
13use crate::types::SignalingMessage;
14
15// Global registry for mock channels (shared between offer/answer sides)
16lazy_static::lazy_static! {
17    static ref CHANNEL_REGISTRY: RwLock<HashMap<String, Arc<MockDataChannel>>> = RwLock::new(HashMap::new());
18}
19
20/// Clear the channel registry (call between tests)
21pub async fn clear_channel_registry() {
22    CHANNEL_REGISTRY.write().await.clear();
23}
24
25// ============================================================================
26// Mock Relay Transport
27// ============================================================================
28
29/// Mock relay for in-memory signaling
30pub struct MockRelay {
31    /// Broadcast channel for all messages
32    tx: broadcast::Sender<SignalingMessage>,
33}
34
35impl MockRelay {
36    /// Create a new mock relay
37    pub fn new() -> Arc<Self> {
38        let (tx, _) = broadcast::channel(1000);
39        Arc::new(Self { tx })
40    }
41
42    /// Create a transport connected to this relay
43    pub fn create_transport(&self, peer_id: String) -> MockRelayTransport {
44        MockRelayTransport {
45            peer_id,
46            tx: self.tx.clone(),
47            rx: tokio::sync::Mutex::new(self.tx.subscribe()),
48            buffer: tokio::sync::Mutex::new(Vec::new()),
49            connected: AtomicBool::new(false),
50        }
51    }
52}
53
54impl Default for MockRelay {
55    fn default() -> Self {
56        let (tx, _) = broadcast::channel(1000);
57        Self { tx }
58    }
59}
60
61/// Mock relay transport using broadcast channels
62pub struct MockRelayTransport {
63    peer_id: String,
64    tx: broadcast::Sender<SignalingMessage>,
65    rx: tokio::sync::Mutex<broadcast::Receiver<SignalingMessage>>,
66    buffer: tokio::sync::Mutex<Vec<SignalingMessage>>,
67    connected: AtomicBool,
68}
69
70impl MockRelayTransport {
71    /// Get our peer ID
72    pub fn peer_id_owned(&self) -> String {
73        self.peer_id.clone()
74    }
75}
76
77#[async_trait]
78impl SignalingTransport for MockRelayTransport {
79    async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
80        self.connected.store(true, Ordering::Relaxed);
81        Ok(())
82    }
83
84    async fn disconnect(&self) {
85        self.connected.store(false, Ordering::Relaxed);
86    }
87
88    async fn publish(&self, msg: SignalingMessage) -> Result<(), TransportError> {
89        if !self.connected.load(Ordering::Relaxed) {
90            return Err(TransportError::NotConnected);
91        }
92        self.tx
93            .send(msg)
94            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
95        Ok(())
96    }
97
98    async fn recv(&self) -> Option<SignalingMessage> {
99        // Check buffer first
100        {
101            let mut buffer = self.buffer.lock().await;
102            if !buffer.is_empty() {
103                return Some(buffer.remove(0));
104            }
105        }
106
107        // Wait for next message
108        let mut rx = self.rx.lock().await;
109        loop {
110            match rx.recv().await {
111                Ok(msg) => {
112                    // Filter: only return messages for us or broadcasts
113                    if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
114                        return Some(msg);
115                    }
116                    // Skip messages for other peers
117                }
118                Err(broadcast::error::RecvError::Closed) => return None,
119                Err(broadcast::error::RecvError::Lagged(_)) => continue,
120            }
121        }
122    }
123
124    fn try_recv(&self) -> Option<SignalingMessage> {
125        // Check buffer first
126        if let Ok(mut buffer) = self.buffer.try_lock() {
127            if !buffer.is_empty() {
128                return Some(buffer.remove(0));
129            }
130        }
131
132        // Try non-blocking receive
133        if let Ok(mut rx) = self.rx.try_lock() {
134            loop {
135                match rx.try_recv() {
136                    Ok(msg) => {
137                        if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
138                            return Some(msg);
139                        }
140                        // Skip messages for other peers
141                    }
142                    Err(_) => return None,
143                }
144            }
145        }
146        None
147    }
148
149    fn peer_id(&self) -> &str {
150        &self.peer_id
151    }
152}
153
154// ============================================================================
155// Mock Data Channel
156// ============================================================================
157
/// How a [`MockDataChannel`] realises its configured per-message latency.
///
/// Defaults to [`MockLatencyMode::RealSleep`]. That is the mode the
/// constructors without an explicit mode argument use.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum MockLatencyMode {
    /// Use real `tokio::time::sleep` for latency simulation.
    #[default]
    RealSleep,
    /// Avoid real-time sleeps for faster simulation loops: yield to the
    /// scheduler once instead of waiting out the configured latency.
    YieldOnly,
}
165
166/// Mock data channel using mpsc channels
167pub struct MockDataChannel {
168    peer_id: u64,
169    tx: mpsc::Sender<Vec<u8>>,
170    rx: tokio::sync::Mutex<mpsc::Receiver<Vec<u8>>>,
171    open: AtomicBool,
172    /// Simulated latency per message (ms)
173    latency_ms: u64,
174    /// How latency is realized in async execution.
175    latency_mode: MockLatencyMode,
176}
177
178impl MockDataChannel {
179    /// Create a connected pair of mock channels
180    pub fn pair(id_a: u64, id_b: u64) -> (Self, Self) {
181        Self::pair_with_latency(id_a, id_b, 0)
182    }
183
184    /// Create a connected pair with simulated latency
185    pub fn pair_with_latency(id_a: u64, id_b: u64, latency_ms: u64) -> (Self, Self) {
186        Self::pair_with_latency_mode(id_a, id_b, latency_ms, MockLatencyMode::RealSleep)
187    }
188
189    /// Create a connected pair with explicit latency mode.
190    pub fn pair_with_latency_mode(
191        id_a: u64,
192        id_b: u64,
193        latency_ms: u64,
194        latency_mode: MockLatencyMode,
195    ) -> (Self, Self) {
196        let (tx_a, rx_a) = mpsc::channel(100);
197        let (tx_b, rx_b) = mpsc::channel(100);
198
199        let chan_a = Self {
200            peer_id: id_a,
201            tx: tx_b, // A sends to B's receiver
202            rx: tokio::sync::Mutex::new(rx_a),
203            open: AtomicBool::new(true),
204            latency_ms,
205            latency_mode,
206        };
207
208        let chan_b = Self {
209            peer_id: id_b,
210            tx: tx_a, // B sends to A's receiver
211            rx: tokio::sync::Mutex::new(rx_b),
212            open: AtomicBool::new(true),
213            latency_ms,
214            latency_mode,
215        };
216
217        (chan_a, chan_b)
218    }
219
220    /// Get peer ID
221    pub fn peer_id(&self) -> u64 {
222        self.peer_id
223    }
224}
225
226#[async_trait]
227impl PeerLink for MockDataChannel {
228    async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
229        if !self.open.load(Ordering::Relaxed) {
230            return Err(TransportError::Disconnected);
231        }
232
233        // Simulate latency
234        if self.latency_ms > 0 {
235            match self.latency_mode {
236                MockLatencyMode::RealSleep => {
237                    tokio::time::sleep(std::time::Duration::from_millis(self.latency_ms)).await;
238                }
239                MockLatencyMode::YieldOnly => {
240                    tokio::task::yield_now().await;
241                }
242            }
243        }
244
245        self.tx
246            .send(data)
247            .await
248            .map_err(|_| TransportError::Disconnected)
249    }
250
251    async fn recv(&self) -> Option<Vec<u8>> {
252        let mut rx = self.rx.lock().await;
253        rx.recv().await
254    }
255
256    fn try_recv(&self) -> Option<Vec<u8>> {
257        let Ok(mut rx) = self.rx.try_lock() else {
258            return None;
259        };
260        rx.try_recv().ok()
261    }
262
263    fn is_open(&self) -> bool {
264        self.open.load(Ordering::Relaxed)
265    }
266
267    async fn close(&self) {
268        self.open.store(false, Ordering::Relaxed);
269    }
270}
271
272// ============================================================================
273// Mock Peer Connection Factory
274// ============================================================================
275
276/// Mock peer connection factory
277///
278/// Creates mock data channels instead of real WebRTC connections.
279/// Uses a global registry to connect offer/answer sides.
280pub struct MockConnectionFactory {
281    our_peer_id: String,
282    our_node_id: u64,
283    /// Simulated latency per link (ms)
284    latency_ms: u64,
285    /// How link latency is realized.
286    latency_mode: MockLatencyMode,
287    /// Pending outbound channels (we sent offer, waiting for answer)
288    pending: RwLock<HashMap<String, Arc<MockDataChannel>>>,
289}
290
291impl MockConnectionFactory {
292    /// Create a new mock connection factory
293    pub fn new(peer_id: String, latency_ms: u64) -> Self {
294        Self::new_with_latency_mode(peer_id, latency_ms, MockLatencyMode::RealSleep)
295    }
296
297    /// Create a mock connection factory with explicit latency mode.
298    pub fn new_with_latency_mode(
299        peer_id: String,
300        latency_ms: u64,
301        latency_mode: MockLatencyMode,
302    ) -> Self {
303        let node_id = peer_id.parse().unwrap_or(0);
304        Self {
305            our_peer_id: peer_id,
306            our_node_id: node_id,
307            latency_ms,
308            latency_mode,
309            pending: RwLock::new(HashMap::new()),
310        }
311    }
312}
313
314#[async_trait]
315impl PeerLinkFactory for MockConnectionFactory {
316    async fn create_offer(
317        &self,
318        target_peer_id: &str,
319    ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
320        let target_node_id: u64 = target_peer_id.parse().unwrap_or(0);
321
322        // Create channel pair
323        let (our_chan, their_chan) = MockDataChannel::pair_with_latency_mode(
324            self.our_node_id,
325            target_node_id,
326            self.latency_ms,
327            self.latency_mode,
328        );
329        let our_chan = Arc::new(our_chan);
330        let their_chan = Arc::new(their_chan);
331
332        // Channel ID is used to link offer/answer
333        let channel_id = format!("{}_{}", self.our_peer_id, target_peer_id);
334
335        // Store our channel for when answer comes back
336        self.pending
337            .write()
338            .await
339            .insert(target_peer_id.to_string(), our_chan.clone());
340
341        // Store their channel in global registry for answerer to find
342        CHANNEL_REGISTRY
343            .write()
344            .await
345            .insert(channel_id.clone(), their_chan);
346
347        Ok((our_chan, channel_id))
348    }
349
350    async fn accept_offer(
351        &self,
352        _from_peer_id: &str,
353        offer_sdp: &str,
354    ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
355        // offer_sdp is the channel_id
356        let channel_id = offer_sdp;
357
358        // Get our channel from the registry
359        let channel = CHANNEL_REGISTRY
360            .write()
361            .await
362            .remove(channel_id)
363            .ok_or_else(|| TransportError::ConnectionFailed("Channel not found".to_string()))?;
364
365        // Answer SDP is just the channel ID (for mock, we don't need real SDP)
366        Ok((channel, channel_id.to_string()))
367    }
368
369    async fn handle_answer(
370        &self,
371        target_peer_id: &str,
372        _answer_sdp: &str,
373    ) -> Result<Arc<dyn PeerLink>, TransportError> {
374        // Get our pending channel
375        let channel = self
376            .pending
377            .write()
378            .await
379            .remove(target_peer_id)
380            .ok_or_else(|| TransportError::ConnectionFailed("No pending connection".to_string()))?;
381
382        Ok(channel)
383    }
384}