hashtree_network/
channel.rs1use async_trait::async_trait;
9use std::time::Duration;
10use thiserror::Error;
11
12#[derive(Debug, Clone, Error)]
14pub enum ChannelError {
15 #[error("channel disconnected")]
16 Disconnected,
17 #[error("operation timed out")]
18 Timeout,
19 #[error("send failed: {0}")]
20 SendFailed(String),
21}
22
23#[async_trait]
29pub trait PeerChannel: Send + Sync {
30 fn peer_id(&self) -> &str;
32
33 async fn send(&self, data: Vec<u8>) -> Result<(), ChannelError>;
35
36 async fn recv(&self, timeout: Duration) -> Result<Vec<u8>, ChannelError>;
38
39 fn is_connected(&self) -> bool;
41
42 fn bytes_sent(&self) -> u64;
44
45 fn bytes_received(&self) -> u64;
47}
48
49pub struct MockChannel {
51 peer_id: String,
52 tx: tokio::sync::mpsc::Sender<Vec<u8>>,
53 rx: tokio::sync::Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
54 bytes_sent: std::sync::atomic::AtomicU64,
55 bytes_received: std::sync::atomic::AtomicU64,
56}
57
58impl MockChannel {
59 pub fn pair(id_a: impl Into<String>, id_b: impl Into<String>) -> (Self, Self) {
61 let id_a = id_a.into();
62 let id_b = id_b.into();
63 let (tx_a, rx_a) = tokio::sync::mpsc::channel(100);
64 let (tx_b, rx_b) = tokio::sync::mpsc::channel(100);
65
66 let chan_a = MockChannel {
67 peer_id: id_b,
68 tx: tx_b,
69 rx: tokio::sync::Mutex::new(rx_a),
70 bytes_sent: std::sync::atomic::AtomicU64::new(0),
71 bytes_received: std::sync::atomic::AtomicU64::new(0),
72 };
73
74 let chan_b = MockChannel {
75 peer_id: id_a,
76 tx: tx_a,
77 rx: tokio::sync::Mutex::new(rx_b),
78 bytes_sent: std::sync::atomic::AtomicU64::new(0),
79 bytes_received: std::sync::atomic::AtomicU64::new(0),
80 };
81
82 (chan_a, chan_b)
83 }
84}
85
86#[async_trait]
87impl PeerChannel for MockChannel {
88 fn peer_id(&self) -> &str {
89 &self.peer_id
90 }
91
92 async fn send(&self, data: Vec<u8>) -> Result<(), ChannelError> {
93 let len = data.len() as u64;
94 self.tx
95 .send(data)
96 .await
97 .map_err(|_| ChannelError::Disconnected)?;
98 self.bytes_sent
99 .fetch_add(len, std::sync::atomic::Ordering::Relaxed);
100 Ok(())
101 }
102
103 async fn recv(&self, timeout: Duration) -> Result<Vec<u8>, ChannelError> {
104 let mut rx = self.rx.lock().await;
105 match tokio::time::timeout(timeout, rx.recv()).await {
106 Ok(Some(data)) => {
107 self.bytes_received
108 .fetch_add(data.len() as u64, std::sync::atomic::Ordering::Relaxed);
109 Ok(data)
110 }
111 Ok(None) => Err(ChannelError::Disconnected),
112 Err(_) => Err(ChannelError::Timeout),
113 }
114 }
115
116 fn is_connected(&self) -> bool {
117 !self.tx.is_closed()
118 }
119
120 fn bytes_sent(&self) -> u64 {
121 self.bytes_sent.load(std::sync::atomic::Ordering::Relaxed)
122 }
123
124 fn bytes_received(&self) -> u64 {
125 self.bytes_received
126 .load(std::sync::atomic::Ordering::Relaxed)
127 }
128}
129
/// Wrapper around another channel that pairs it with a fixed latency value.
///
/// The type parameter is deliberately unbounded here; the `PeerChannel`
/// requirement lives on the `impl` blocks that actually need it, so the bound
/// does not have to be repeated wherever the type is merely named.
pub struct LatencyChannel<C> {
    /// The wrapped channel that performs the real work.
    inner: C,
    /// Delay associated with this wrapper.
    latency: Duration,
}
135
136impl<C: PeerChannel> LatencyChannel<C> {
137 pub fn new(inner: C, latency: Duration) -> Self {
138 Self { inner, latency }
139 }
140}
141
142#[async_trait]
143impl<C: PeerChannel> PeerChannel for LatencyChannel<C> {
144 fn peer_id(&self) -> &str {
145 self.inner.peer_id()
146 }
147
148 async fn send(&self, data: Vec<u8>) -> Result<(), ChannelError> {
149 tokio::time::sleep(self.latency).await;
151 self.inner.send(data).await
152 }
153
154 async fn recv(&self, timeout: Duration) -> Result<Vec<u8>, ChannelError> {
155 self.inner.recv(timeout).await
156 }
157
158 fn is_connected(&self) -> bool {
159 self.inner.is_connected()
160 }
161
162 fn bytes_sent(&self) -> u64 {
163 self.inner.bytes_sent()
164 }
165
166 fn bytes_received(&self) -> u64 {
167 self.inner.bytes_received()
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Messages flow in both directions and both endpoints count the bytes.
    #[tokio::test]
    async fn test_mock_channel_roundtrip() {
        let (chan_a, chan_b) = MockChannel::pair("1", "2");

        // Each endpoint reports the id of the side it talks to.
        assert_eq!(chan_a.peer_id(), "2");
        assert_eq!(chan_b.peer_id(), "1");

        chan_a.send(b"hello".to_vec()).await.unwrap();
        let received = chan_b.recv(Duration::from_secs(1)).await.unwrap();
        assert_eq!(received, b"hello");

        chan_b.send(b"world".to_vec()).await.unwrap();
        let received = chan_a.recv(Duration::from_secs(1)).await.unwrap();
        assert_eq!(received, b"world");

        assert_eq!(chan_a.bytes_sent(), 5);
        assert_eq!(chan_a.bytes_received(), 5);
        assert_eq!(chan_b.bytes_sent(), 5);
        assert_eq!(chan_b.bytes_received(), 5);
    }

    /// Receiving on an idle channel gives up once the timeout elapses.
    #[tokio::test]
    async fn test_mock_channel_timeout() {
        // `_chan_b` must stay alive, otherwise the result would be
        // `Disconnected` rather than `Timeout`.
        let (chan_a, _chan_b) = MockChannel::pair("1", "2");

        let result = chan_a.recv(Duration::from_millis(10)).await;
        assert!(matches!(result, Err(ChannelError::Timeout)));
    }

    /// Dropping one endpoint is observed as a disconnect by the other.
    #[tokio::test]
    async fn test_mock_channel_disconnect() {
        let (chan_a, chan_b) = MockChannel::pair("1", "2");
        assert!(chan_a.is_connected());

        drop(chan_b);
        assert!(!chan_a.is_connected());

        let sent = chan_a.send(b"late".to_vec()).await;
        assert!(matches!(sent, Err(ChannelError::Disconnected)));
        // A failed send must not be counted.
        assert_eq!(chan_a.bytes_sent(), 0);

        let received = chan_a.recv(Duration::from_millis(10)).await;
        assert!(matches!(received, Err(ChannelError::Disconnected)));
    }

    /// The latency wrapper forwards traffic and metadata to the inner channel.
    #[tokio::test]
    async fn test_latency_channel_delegates() {
        let (chan_a, chan_b) = MockChannel::pair("1", "2");
        let slow_a = LatencyChannel::new(chan_a, Duration::from_millis(1));

        assert_eq!(slow_a.peer_id(), "2");
        assert!(slow_a.is_connected());

        slow_a.send(b"ping".to_vec()).await.unwrap();
        let received = chan_b.recv(Duration::from_secs(1)).await.unwrap();
        assert_eq!(received, b"ping");

        assert_eq!(slow_a.bytes_sent(), 4);
        assert_eq!(slow_a.bytes_received(), 0);
    }
}