// hashtree_network/mock.rs

//! Mock implementations for testing and simulation
//!
//! Provides a mock relay transport and a mock peer connection factory that use
//! in-memory channels instead of real Nostr relays and WebRTC.

6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use tokio::sync::{broadcast, mpsc, RwLock};
11
12use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
13use crate::types::SignalingMessage;
14
15// Global registry for mock channels (shared between offer/answer sides)
16lazy_static::lazy_static! {
17    static ref CHANNEL_REGISTRY: RwLock<HashMap<String, Arc<MockDataChannel>>> = RwLock::new(HashMap::new());
18}
19
20/// Clear the channel registry (call between tests)
21pub async fn clear_channel_registry() {
22    CHANNEL_REGISTRY.write().await.clear();
23}
24
25// ============================================================================
26// Mock Relay Transport
27// ============================================================================
28
29/// Mock relay for in-memory signaling
30pub struct MockRelay {
31    /// Broadcast channel for all messages
32    tx: broadcast::Sender<SignalingMessage>,
33}
34
35impl MockRelay {
36    /// Create a new mock relay
37    pub fn new() -> Arc<Self> {
38        Self::new_with_capacity(1000)
39    }
40
41    /// Create a new mock relay with an explicit broadcast buffer capacity.
42    pub fn new_with_capacity(capacity: usize) -> Arc<Self> {
43        let (tx, _) = broadcast::channel(capacity.max(1));
44        Arc::new(Self { tx })
45    }
46
47    /// Create a transport connected to this relay
48    pub fn create_transport(&self, peer_id: String) -> MockRelayTransport {
49        MockRelayTransport {
50            peer_id,
51            tx: self.tx.clone(),
52            rx: tokio::sync::Mutex::new(self.tx.subscribe()),
53            buffer: tokio::sync::Mutex::new(Vec::new()),
54            connected: AtomicBool::new(false),
55        }
56    }
57}
58
59impl Default for MockRelay {
60    fn default() -> Self {
61        let (tx, _) = broadcast::channel(1000);
62        Self { tx }
63    }
64}
65
66/// Mock relay transport using broadcast channels
67pub struct MockRelayTransport {
68    peer_id: String,
69    tx: broadcast::Sender<SignalingMessage>,
70    rx: tokio::sync::Mutex<broadcast::Receiver<SignalingMessage>>,
71    buffer: tokio::sync::Mutex<Vec<SignalingMessage>>,
72    connected: AtomicBool,
73}
74
75impl MockRelayTransport {
76    /// Get our peer ID
77    pub fn peer_id_owned(&self) -> String {
78        self.peer_id.clone()
79    }
80}
81
82#[async_trait]
83impl SignalingTransport for MockRelayTransport {
84    async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
85        self.connected.store(true, Ordering::Relaxed);
86        Ok(())
87    }
88
89    async fn disconnect(&self) {
90        self.connected.store(false, Ordering::Relaxed);
91    }
92
93    async fn publish(&self, msg: SignalingMessage) -> Result<(), TransportError> {
94        if !self.connected.load(Ordering::Relaxed) {
95            return Err(TransportError::NotConnected);
96        }
97        self.tx
98            .send(msg)
99            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
100        Ok(())
101    }
102
103    async fn recv(&self) -> Option<SignalingMessage> {
104        // Check buffer first
105        {
106            let mut buffer = self.buffer.lock().await;
107            if !buffer.is_empty() {
108                return Some(buffer.remove(0));
109            }
110        }
111
112        // Wait for next message
113        let mut rx = self.rx.lock().await;
114        loop {
115            match rx.recv().await {
116                Ok(msg) => {
117                    // Filter: only return messages for us or broadcasts
118                    if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
119                        return Some(msg);
120                    }
121                    // Skip messages for other peers
122                }
123                Err(broadcast::error::RecvError::Closed) => return None,
124                Err(broadcast::error::RecvError::Lagged(_)) => continue,
125            }
126        }
127    }
128
129    fn try_recv(&self) -> Option<SignalingMessage> {
130        // Check buffer first
131        if let Ok(mut buffer) = self.buffer.try_lock() {
132            if !buffer.is_empty() {
133                return Some(buffer.remove(0));
134            }
135        }
136
137        // Try non-blocking receive
138        if let Ok(mut rx) = self.rx.try_lock() {
139            loop {
140                match rx.try_recv() {
141                    Ok(msg) => {
142                        if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
143                            return Some(msg);
144                        }
145                        // Skip messages for other peers
146                    }
147                    Err(_) => return None,
148                }
149            }
150        }
151        None
152    }
153
154    fn peer_id(&self) -> &str {
155        &self.peer_id
156    }
157}
158
// ============================================================================
// Mock Data Channel
// ============================================================================

/// Strategy a [`MockDataChannel`] uses to simulate per-message latency.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MockLatencyMode {
    /// Wait with `tokio::time::sleep` for the configured number of milliseconds.
    RealSleep,
    /// Skip wall-clock waiting: yield to the scheduler once per configured
    /// millisecond, so relative delays survive while simulations run fast.
    /// Sends in this mode use a non-blocking `try_send`.
    YieldOnly,
}

172/// Mock data channel using mpsc channels
173pub struct MockDataChannel {
174    peer_id: u64,
175    tx: mpsc::Sender<Vec<u8>>,
176    rx: tokio::sync::Mutex<mpsc::Receiver<Vec<u8>>>,
177    open: AtomicBool,
178    /// Simulated latency per message (ms)
179    latency_ms: u64,
180    /// How latency is realized in async execution.
181    latency_mode: MockLatencyMode,
182}
183
184impl MockDataChannel {
185    /// Create a connected pair of mock channels
186    pub fn pair(id_a: u64, id_b: u64) -> (Self, Self) {
187        Self::pair_with_latency(id_a, id_b, 0)
188    }
189
190    /// Create a connected pair with simulated latency
191    pub fn pair_with_latency(id_a: u64, id_b: u64, latency_ms: u64) -> (Self, Self) {
192        Self::pair_with_latency_mode(id_a, id_b, latency_ms, MockLatencyMode::RealSleep)
193    }
194
195    /// Create a connected pair with explicit latency mode.
196    pub fn pair_with_latency_mode(
197        id_a: u64,
198        id_b: u64,
199        latency_ms: u64,
200        latency_mode: MockLatencyMode,
201    ) -> (Self, Self) {
202        let (tx_a, rx_a) = mpsc::channel(100);
203        let (tx_b, rx_b) = mpsc::channel(100);
204
205        let chan_a = Self {
206            peer_id: id_a,
207            tx: tx_b, // A sends to B's receiver
208            rx: tokio::sync::Mutex::new(rx_a),
209            open: AtomicBool::new(true),
210            latency_ms,
211            latency_mode,
212        };
213
214        let chan_b = Self {
215            peer_id: id_b,
216            tx: tx_a, // B sends to A's receiver
217            rx: tokio::sync::Mutex::new(rx_b),
218            open: AtomicBool::new(true),
219            latency_ms,
220            latency_mode,
221        };
222
223        (chan_a, chan_b)
224    }
225
226    /// Get peer ID
227    pub fn peer_id(&self) -> u64 {
228        self.peer_id
229    }
230}
231
232#[async_trait]
233impl PeerLink for MockDataChannel {
234    async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
235        if !self.open.load(Ordering::Relaxed) {
236            return Err(TransportError::Disconnected);
237        }
238
239        // Simulate latency
240        if self.latency_ms > 0 {
241            match self.latency_mode {
242                MockLatencyMode::RealSleep => {
243                    tokio::time::sleep(std::time::Duration::from_millis(self.latency_ms)).await;
244                }
245                MockLatencyMode::YieldOnly => {
246                    for _ in 0..self.latency_ms.max(1) {
247                        tokio::task::yield_now().await;
248                    }
249                }
250            }
251        }
252
253        if self.latency_mode == MockLatencyMode::YieldOnly {
254            return self
255                .tx
256                .try_send(data)
257                .map_err(|err| TransportError::SendFailed(err.to_string()));
258        }
259
260        self.tx
261            .send(data)
262            .await
263            .map_err(|_| TransportError::Disconnected)
264    }
265
266    async fn recv(&self) -> Option<Vec<u8>> {
267        let mut rx = self.rx.lock().await;
268        rx.recv().await
269    }
270
271    fn try_recv(&self) -> Option<Vec<u8>> {
272        let Ok(mut rx) = self.rx.try_lock() else {
273            return None;
274        };
275        rx.try_recv().ok()
276    }
277
278    fn is_open(&self) -> bool {
279        self.open.load(Ordering::Relaxed)
280    }
281
282    async fn close(&self) {
283        self.open.store(false, Ordering::Relaxed);
284    }
285}
286
287// ============================================================================
288// Mock Peer Connection Factory
289// ============================================================================
290
291/// Mock peer connection factory
292///
293/// Creates mock data channels instead of real WebRTC connections.
294/// Uses a global registry to connect offer/answer sides.
295pub struct MockConnectionFactory {
296    our_peer_id: String,
297    our_node_id: u64,
298    /// Simulated latency per link (ms)
299    latency_ms: u64,
300    /// How link latency is realized.
301    latency_mode: MockLatencyMode,
302    /// Pending outbound channels (we sent offer, waiting for answer)
303    pending: RwLock<HashMap<String, Arc<MockDataChannel>>>,
304}
305
306impl MockConnectionFactory {
307    /// Create a new mock connection factory
308    pub fn new(peer_id: String, latency_ms: u64) -> Self {
309        Self::new_with_latency_mode(peer_id, latency_ms, MockLatencyMode::RealSleep)
310    }
311
312    /// Create a mock connection factory with explicit latency mode.
313    pub fn new_with_latency_mode(
314        peer_id: String,
315        latency_ms: u64,
316        latency_mode: MockLatencyMode,
317    ) -> Self {
318        let node_id = peer_id.parse().unwrap_or(0);
319        Self {
320            our_peer_id: peer_id,
321            our_node_id: node_id,
322            latency_ms,
323            latency_mode,
324            pending: RwLock::new(HashMap::new()),
325        }
326    }
327}
328
329#[async_trait]
330impl PeerLinkFactory for MockConnectionFactory {
331    async fn create_offer(
332        &self,
333        target_peer_id: &str,
334    ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
335        let target_node_id: u64 = target_peer_id.parse().unwrap_or(0);
336
337        // Create channel pair
338        let (our_chan, their_chan) = MockDataChannel::pair_with_latency_mode(
339            self.our_node_id,
340            target_node_id,
341            self.latency_ms,
342            self.latency_mode,
343        );
344        let our_chan = Arc::new(our_chan);
345        let their_chan = Arc::new(their_chan);
346
347        // Channel ID is used to link offer/answer
348        let channel_id = format!("{}_{}", self.our_peer_id, target_peer_id);
349
350        // Store our channel for when answer comes back
351        self.pending
352            .write()
353            .await
354            .insert(target_peer_id.to_string(), our_chan.clone());
355
356        // Store their channel in global registry for answerer to find
357        CHANNEL_REGISTRY
358            .write()
359            .await
360            .insert(channel_id.clone(), their_chan);
361
362        Ok((our_chan, channel_id))
363    }
364
365    async fn accept_offer(
366        &self,
367        _from_peer_id: &str,
368        offer_sdp: &str,
369    ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
370        // offer_sdp is the channel_id
371        let channel_id = offer_sdp;
372
373        // Get our channel from the registry
374        let channel = CHANNEL_REGISTRY
375            .write()
376            .await
377            .remove(channel_id)
378            .ok_or_else(|| TransportError::ConnectionFailed("Channel not found".to_string()))?;
379
380        // Answer SDP is just the channel ID (for mock, we don't need real SDP)
381        Ok((channel, channel_id.to_string()))
382    }
383
384    async fn handle_answer(
385        &self,
386        target_peer_id: &str,
387        _answer_sdp: &str,
388    ) -> Result<Arc<dyn PeerLink>, TransportError> {
389        // Get our pending channel
390        let channel = self
391            .pending
392            .write()
393            .await
394            .remove(target_peer_id)
395            .ok_or_else(|| TransportError::ConnectionFailed("No pending connection".to_string()))?;
396
397        Ok(channel)
398    }
399}