Skip to main content

hashtree_network/
mock.rs

1//! Mock implementations for testing and simulation
2//!
3//! Provides mock relay transport and peer connection factory that use
4//! in-memory channels instead of real Nostr relays and WebRTC.
5
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use tokio::sync::{broadcast, mpsc, RwLock};
11
12use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
13use crate::types::SignalingMessage;
14
15// Global registry for mock channels (shared between offer/answer sides)
16lazy_static::lazy_static! {
17    static ref CHANNEL_REGISTRY: RwLock<HashMap<String, Arc<MockDataChannel>>> = RwLock::new(HashMap::new());
18}
19
20/// Clear the channel registry (call between tests)
21pub async fn clear_channel_registry() {
22    CHANNEL_REGISTRY.write().await.clear();
23}
24
25// ============================================================================
26// Mock Relay Transport
27// ============================================================================
28
29/// Mock relay for in-memory signaling
30pub struct MockRelay {
31    /// Broadcast channel for all messages
32    tx: broadcast::Sender<SignalingMessage>,
33}
34
35impl MockRelay {
36    /// Create a new mock relay
37    pub fn new() -> Arc<Self> {
38        let (tx, _) = broadcast::channel(1000);
39        Arc::new(Self { tx })
40    }
41
42    /// Create a transport connected to this relay
43    pub fn create_transport(&self, peer_id: String, pubkey: String) -> MockRelayTransport {
44        MockRelayTransport {
45            peer_id,
46            pubkey,
47            tx: self.tx.clone(),
48            rx: tokio::sync::Mutex::new(self.tx.subscribe()),
49            buffer: tokio::sync::Mutex::new(Vec::new()),
50            connected: AtomicBool::new(false),
51        }
52    }
53}
54
55impl Default for MockRelay {
56    fn default() -> Self {
57        let (tx, _) = broadcast::channel(1000);
58        Self { tx }
59    }
60}
61
62/// Mock relay transport using broadcast channels
63pub struct MockRelayTransport {
64    peer_id: String,
65    pubkey: String,
66    tx: broadcast::Sender<SignalingMessage>,
67    rx: tokio::sync::Mutex<broadcast::Receiver<SignalingMessage>>,
68    buffer: tokio::sync::Mutex<Vec<SignalingMessage>>,
69    connected: AtomicBool,
70}
71
72impl MockRelayTransport {
73    /// Get our peer ID
74    pub fn peer_id_owned(&self) -> String {
75        self.peer_id.clone()
76    }
77
78    /// Get our pubkey
79    pub fn pubkey_owned(&self) -> String {
80        self.pubkey.clone()
81    }
82}
83
84#[async_trait]
85impl SignalingTransport for MockRelayTransport {
86    async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
87        self.connected.store(true, Ordering::Relaxed);
88        Ok(())
89    }
90
91    async fn disconnect(&self) {
92        self.connected.store(false, Ordering::Relaxed);
93    }
94
95    async fn publish(&self, msg: SignalingMessage) -> Result<(), TransportError> {
96        if !self.connected.load(Ordering::Relaxed) {
97            return Err(TransportError::NotConnected);
98        }
99        self.tx
100            .send(msg)
101            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
102        Ok(())
103    }
104
105    async fn recv(&self) -> Option<SignalingMessage> {
106        // Check buffer first
107        {
108            let mut buffer = self.buffer.lock().await;
109            if !buffer.is_empty() {
110                return Some(buffer.remove(0));
111            }
112        }
113
114        // Wait for next message
115        let mut rx = self.rx.lock().await;
116        loop {
117            match rx.recv().await {
118                Ok(msg) => {
119                    // Filter: only return messages for us or broadcasts
120                    if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
121                        return Some(msg);
122                    }
123                    // Skip messages for other peers
124                }
125                Err(broadcast::error::RecvError::Closed) => return None,
126                Err(broadcast::error::RecvError::Lagged(_)) => continue,
127            }
128        }
129    }
130
131    fn try_recv(&self) -> Option<SignalingMessage> {
132        // Check buffer first
133        if let Ok(mut buffer) = self.buffer.try_lock() {
134            if !buffer.is_empty() {
135                return Some(buffer.remove(0));
136            }
137        }
138
139        // Try non-blocking receive
140        if let Ok(mut rx) = self.rx.try_lock() {
141            loop {
142                match rx.try_recv() {
143                    Ok(msg) => {
144                        if msg.is_for(&self.peer_id) || msg.target_peer_id().is_none() {
145                            return Some(msg);
146                        }
147                        // Skip messages for other peers
148                    }
149                    Err(_) => return None,
150                }
151            }
152        }
153        None
154    }
155
156    fn peer_id(&self) -> &str {
157        &self.peer_id
158    }
159
160    fn pubkey(&self) -> &str {
161        &self.pubkey
162    }
163}
164
165// ============================================================================
166// Mock Data Channel
167// ============================================================================
168
/// How a mock data channel realizes its configured per-message latency.
///
/// The default is [`MockLatencyMode::RealSleep`]. This matches the mode used by the
/// constructors in this module that do not take an explicit latency mode.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MockLatencyMode {
    /// Use real `tokio::time::sleep` for latency simulation.
    #[default]
    RealSleep,
    /// Avoid real-time sleeps for faster simulation loops.
    YieldOnly,
}
176
177/// Mock data channel using mpsc channels
178pub struct MockDataChannel {
179    peer_id: u64,
180    tx: mpsc::Sender<Vec<u8>>,
181    rx: tokio::sync::Mutex<mpsc::Receiver<Vec<u8>>>,
182    open: AtomicBool,
183    /// Simulated latency per message (ms)
184    latency_ms: u64,
185    /// How latency is realized in async execution.
186    latency_mode: MockLatencyMode,
187}
188
189impl MockDataChannel {
190    /// Create a connected pair of mock channels
191    pub fn pair(id_a: u64, id_b: u64) -> (Self, Self) {
192        Self::pair_with_latency(id_a, id_b, 0)
193    }
194
195    /// Create a connected pair with simulated latency
196    pub fn pair_with_latency(id_a: u64, id_b: u64, latency_ms: u64) -> (Self, Self) {
197        Self::pair_with_latency_mode(id_a, id_b, latency_ms, MockLatencyMode::RealSleep)
198    }
199
200    /// Create a connected pair with explicit latency mode.
201    pub fn pair_with_latency_mode(
202        id_a: u64,
203        id_b: u64,
204        latency_ms: u64,
205        latency_mode: MockLatencyMode,
206    ) -> (Self, Self) {
207        let (tx_a, rx_a) = mpsc::channel(100);
208        let (tx_b, rx_b) = mpsc::channel(100);
209
210        let chan_a = Self {
211            peer_id: id_a,
212            tx: tx_b, // A sends to B's receiver
213            rx: tokio::sync::Mutex::new(rx_a),
214            open: AtomicBool::new(true),
215            latency_ms,
216            latency_mode,
217        };
218
219        let chan_b = Self {
220            peer_id: id_b,
221            tx: tx_a, // B sends to A's receiver
222            rx: tokio::sync::Mutex::new(rx_b),
223            open: AtomicBool::new(true),
224            latency_ms,
225            latency_mode,
226        };
227
228        (chan_a, chan_b)
229    }
230
231    /// Get peer ID
232    pub fn peer_id(&self) -> u64 {
233        self.peer_id
234    }
235}
236
237#[async_trait]
238impl PeerLink for MockDataChannel {
239    async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
240        if !self.open.load(Ordering::Relaxed) {
241            return Err(TransportError::Disconnected);
242        }
243
244        // Simulate latency
245        if self.latency_ms > 0 {
246            match self.latency_mode {
247                MockLatencyMode::RealSleep => {
248                    tokio::time::sleep(std::time::Duration::from_millis(self.latency_ms)).await;
249                }
250                MockLatencyMode::YieldOnly => {
251                    tokio::task::yield_now().await;
252                }
253            }
254        }
255
256        self.tx
257            .send(data)
258            .await
259            .map_err(|_| TransportError::Disconnected)
260    }
261
262    async fn recv(&self) -> Option<Vec<u8>> {
263        let mut rx = self.rx.lock().await;
264        rx.recv().await
265    }
266
267    fn try_recv(&self) -> Option<Vec<u8>> {
268        let Ok(mut rx) = self.rx.try_lock() else {
269            return None;
270        };
271        rx.try_recv().ok()
272    }
273
274    fn is_open(&self) -> bool {
275        self.open.load(Ordering::Relaxed)
276    }
277
278    async fn close(&self) {
279        self.open.store(false, Ordering::Relaxed);
280    }
281}
282
283// ============================================================================
284// Mock Peer Connection Factory
285// ============================================================================
286
287/// Mock peer connection factory
288///
289/// Creates mock data channels instead of real WebRTC connections.
290/// Uses a global registry to connect offer/answer sides.
291pub struct MockConnectionFactory {
292    our_peer_id: String,
293    our_node_id: u64,
294    /// Simulated latency per link (ms)
295    latency_ms: u64,
296    /// How link latency is realized.
297    latency_mode: MockLatencyMode,
298    /// Pending outbound channels (we sent offer, waiting for answer)
299    pending: RwLock<HashMap<String, Arc<MockDataChannel>>>,
300}
301
302impl MockConnectionFactory {
303    /// Create a new mock connection factory
304    pub fn new(peer_id: String, latency_ms: u64) -> Self {
305        Self::new_with_latency_mode(peer_id, latency_ms, MockLatencyMode::RealSleep)
306    }
307
308    /// Create a mock connection factory with explicit latency mode.
309    pub fn new_with_latency_mode(
310        peer_id: String,
311        latency_ms: u64,
312        latency_mode: MockLatencyMode,
313    ) -> Self {
314        let node_id = peer_id.parse().unwrap_or(0);
315        Self {
316            our_peer_id: peer_id,
317            our_node_id: node_id,
318            latency_ms,
319            latency_mode,
320            pending: RwLock::new(HashMap::new()),
321        }
322    }
323}
324
325#[async_trait]
326impl PeerLinkFactory for MockConnectionFactory {
327    async fn create_offer(
328        &self,
329        target_peer_id: &str,
330    ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
331        let target_node_id: u64 = target_peer_id.parse().unwrap_or(0);
332
333        // Create channel pair
334        let (our_chan, their_chan) = MockDataChannel::pair_with_latency_mode(
335            self.our_node_id,
336            target_node_id,
337            self.latency_ms,
338            self.latency_mode,
339        );
340        let our_chan = Arc::new(our_chan);
341        let their_chan = Arc::new(their_chan);
342
343        // Channel ID is used to link offer/answer
344        let channel_id = format!("{}_{}", self.our_peer_id, target_peer_id);
345
346        // Store our channel for when answer comes back
347        self.pending
348            .write()
349            .await
350            .insert(target_peer_id.to_string(), our_chan.clone());
351
352        // Store their channel in global registry for answerer to find
353        CHANNEL_REGISTRY
354            .write()
355            .await
356            .insert(channel_id.clone(), their_chan);
357
358        Ok((our_chan, channel_id))
359    }
360
361    async fn accept_offer(
362        &self,
363        _from_peer_id: &str,
364        offer_sdp: &str,
365    ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
366        // offer_sdp is the channel_id
367        let channel_id = offer_sdp;
368
369        // Get our channel from the registry
370        let channel = CHANNEL_REGISTRY
371            .write()
372            .await
373            .remove(channel_id)
374            .ok_or_else(|| TransportError::ConnectionFailed("Channel not found".to_string()))?;
375
376        // Answer SDP is just the channel ID (for mock, we don't need real SDP)
377        Ok((channel, channel_id.to_string()))
378    }
379
380    async fn handle_answer(
381        &self,
382        target_peer_id: &str,
383        _answer_sdp: &str,
384    ) -> Result<Arc<dyn PeerLink>, TransportError> {
385        // Get our pending channel
386        let channel = self
387            .pending
388            .write()
389            .await
390            .remove(target_peer_id)
391            .ok_or_else(|| TransportError::ConnectionFailed("No pending connection".to_string()))?;
392
393        Ok(channel)
394    }
395}
396
397pub type MockSignalingTransport = MockRelayTransport;