Skip to main content

hashtree_network/
channel.rs

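//! Peer communication channel abstraction.
//!
//! Defines the [`PeerChannel`] trait for sending and receiving bytes to and from a peer.
//! Implementations:
//! - WebRTC data channels (production)
//! - [`MockChannel`] (testing/simulation)
//! - [`LatencyChannel`], a wrapper that adds send latency to another channel (simulation)
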
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;

use async_trait::async_trait;
use thiserror::Error;

12/// Channel errors
13#[derive(Debug, Clone, Error)]
14pub enum ChannelError {
15    #[error("channel disconnected")]
16    Disconnected,
17    #[error("operation timed out")]
18    Timeout,
19    #[error("send failed: {0}")]
20    SendFailed(String),
21}
22
23/// A channel to a single peer for sending/receiving bytes
24///
25/// This is the core abstraction that allows the same P2P logic to work with:
26/// - Real WebRTC data channels (production)
27/// - Mock in-memory channels (simulation/testing)
28#[async_trait]
29pub trait PeerChannel: Send + Sync {
30    /// Remote peer ID (string to support both Nostr pubkeys and numeric IDs)
31    fn peer_id(&self) -> &str;
32
33    /// Send bytes to peer
34    async fn send(&self, data: Vec<u8>) -> Result<(), ChannelError>;
35
36    /// Receive bytes from peer (with timeout)
37    async fn recv(&self, timeout: Duration) -> Result<Vec<u8>, ChannelError>;
38
39    /// Check if channel is still connected
40    fn is_connected(&self) -> bool;
41
42    /// Bytes sent through this channel
43    fn bytes_sent(&self) -> u64;
44
45    /// Bytes received through this channel
46    fn bytes_received(&self) -> u64;
47}
48
49/// Mock channel for testing - instant delivery via mpsc
50pub struct MockChannel {
51    peer_id: String,
52    tx: tokio::sync::mpsc::Sender<Vec<u8>>,
53    rx: tokio::sync::Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
54    bytes_sent: std::sync::atomic::AtomicU64,
55    bytes_received: std::sync::atomic::AtomicU64,
56}
57
58impl MockChannel {
59    /// Create a pair of connected mock channels
60    pub fn pair(id_a: impl Into<String>, id_b: impl Into<String>) -> (Self, Self) {
61        let id_a = id_a.into();
62        let id_b = id_b.into();
63        let (tx_a, rx_a) = tokio::sync::mpsc::channel(100);
64        let (tx_b, rx_b) = tokio::sync::mpsc::channel(100);
65
66        let chan_a = MockChannel {
67            peer_id: id_b,
68            tx: tx_b,
69            rx: tokio::sync::Mutex::new(rx_a),
70            bytes_sent: std::sync::atomic::AtomicU64::new(0),
71            bytes_received: std::sync::atomic::AtomicU64::new(0),
72        };
73
74        let chan_b = MockChannel {
75            peer_id: id_a,
76            tx: tx_a,
77            rx: tokio::sync::Mutex::new(rx_b),
78            bytes_sent: std::sync::atomic::AtomicU64::new(0),
79            bytes_received: std::sync::atomic::AtomicU64::new(0),
80        };
81
82        (chan_a, chan_b)
83    }
84}
85
86#[async_trait]
87impl PeerChannel for MockChannel {
88    fn peer_id(&self) -> &str {
89        &self.peer_id
90    }
91
92    async fn send(&self, data: Vec<u8>) -> Result<(), ChannelError> {
93        let len = data.len() as u64;
94        self.tx
95            .send(data)
96            .await
97            .map_err(|_| ChannelError::Disconnected)?;
98        self.bytes_sent
99            .fetch_add(len, std::sync::atomic::Ordering::Relaxed);
100        Ok(())
101    }
102
103    async fn recv(&self, timeout: Duration) -> Result<Vec<u8>, ChannelError> {
104        let mut rx = self.rx.lock().await;
105        match tokio::time::timeout(timeout, rx.recv()).await {
106            Ok(Some(data)) => {
107                self.bytes_received
108                    .fetch_add(data.len() as u64, std::sync::atomic::Ordering::Relaxed);
109                Ok(data)
110            }
111            Ok(None) => Err(ChannelError::Disconnected),
112            Err(_) => Err(ChannelError::Timeout),
113        }
114    }
115
116    fn is_connected(&self) -> bool {
117        !self.tx.is_closed()
118    }
119
120    fn bytes_sent(&self) -> u64 {
121        self.bytes_sent.load(std::sync::atomic::Ordering::Relaxed)
122    }
123
124    fn bytes_received(&self) -> u64 {
125        self.bytes_received
126            .load(std::sync::atomic::Ordering::Relaxed)
127    }
128}
129
/// [`PeerChannel`] decorator that delays every outgoing message (for simulation).
///
/// Only `send` is slowed down. Receiving, connection state and byte counters pass
/// straight through to the wrapped channel. The `PeerChannel` bound lives on the `impl`
/// blocks rather than here, so the type can be named and inspected generically.
#[derive(Debug)]
pub struct LatencyChannel<C> {
    /// The channel that actually carries the traffic.
    inner: C,
    /// Delay applied before each send.
    latency: Duration,
}

136impl<C: PeerChannel> LatencyChannel<C> {
137    pub fn new(inner: C, latency: Duration) -> Self {
138        Self { inner, latency }
139    }
140}
141
142#[async_trait]
143impl<C: PeerChannel> PeerChannel for LatencyChannel<C> {
144    fn peer_id(&self) -> &str {
145        self.inner.peer_id()
146    }
147
148    async fn send(&self, data: Vec<u8>) -> Result<(), ChannelError> {
149        // Simulate network latency on send (one-way delay)
150        tokio::time::sleep(self.latency).await;
151        self.inner.send(data).await
152    }
153
154    async fn recv(&self, timeout: Duration) -> Result<Vec<u8>, ChannelError> {
155        self.inner.recv(timeout).await
156    }
157
158    fn is_connected(&self) -> bool {
159        self.inner.is_connected()
160    }
161
162    fn bytes_sent(&self) -> u64 {
163        self.inner.bytes_sent()
164    }
165
166    fn bytes_received(&self) -> u64 {
167        self.inner.bytes_received()
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_channel_roundtrip() {
        let (chan_a, chan_b) = MockChannel::pair("1", "2");

        // Each endpoint reports the peer it is connected to.
        assert_eq!(chan_a.peer_id(), "2");
        assert_eq!(chan_b.peer_id(), "1");

        // A sends to B
        chan_a.send(b"hello".to_vec()).await.unwrap();
        let received = chan_b.recv(Duration::from_secs(1)).await.unwrap();
        assert_eq!(received, b"hello");

        // B sends to A. The payload length differs from A's so swapped counters would show.
        chan_b.send(b"world!".to_vec()).await.unwrap();
        let received = chan_a.recv(Duration::from_secs(1)).await.unwrap();
        assert_eq!(received, b"world!");

        // Check byte counts on both endpoints
        assert_eq!(chan_a.bytes_sent(), 5);
        assert_eq!(chan_a.bytes_received(), 6);
        assert_eq!(chan_b.bytes_sent(), 6);
        assert_eq!(chan_b.bytes_received(), 5);
    }

    #[tokio::test]
    async fn test_mock_channel_timeout() {
        let (chan_a, _chan_b) = MockChannel::pair("1", "2");

        // Should timeout since nothing sent
        let result = chan_a.recv(Duration::from_millis(10)).await;
        assert!(matches!(result, Err(ChannelError::Timeout)));
    }

    #[tokio::test]
    async fn test_mock_channel_disconnect() {
        let (chan_a, chan_b) = MockChannel::pair("1", "2");
        assert!(chan_a.is_connected());

        // A message queued before the peer goes away is still delivered.
        chan_b.send(b"bye".to_vec()).await.unwrap();
        drop(chan_b);

        assert!(!chan_a.is_connected());
        assert_eq!(chan_a.recv(Duration::from_secs(1)).await.unwrap(), b"bye");
        assert!(matches!(
            chan_a.recv(Duration::from_secs(1)).await,
            Err(ChannelError::Disconnected)
        ));
        assert!(matches!(
            chan_a.send(b"lost".to_vec()).await,
            Err(ChannelError::Disconnected)
        ));
        // A failed send must not be counted.
        assert_eq!(chan_a.bytes_sent(), 0);
    }

    #[tokio::test]
    async fn test_latency_channel_delays_send() {
        let latency = Duration::from_millis(20);
        let (chan_a, chan_b) = MockChannel::pair("1", "2");
        let slow_a = LatencyChannel::new(chan_a, latency);

        let started = std::time::Instant::now();
        slow_a.send(b"ping".to_vec()).await.unwrap();
        assert!(started.elapsed() >= latency);

        assert_eq!(chan_b.recv(Duration::from_secs(1)).await.unwrap(), b"ping");

        // Everything except the send delay is delegated to the wrapped channel.
        assert_eq!(slow_a.peer_id(), "2");
        assert_eq!(slow_a.bytes_sent(), 4);
        assert!(slow_a.is_connected());
    }
}