alopex_chirps_mock/
lib.rs1use alopex_chirps_core::backend::MessageBackend;
2use alopex_chirps_core::error::TransportError;
3use alopex_chirps_wire::frame::Frame;
4use alopex_chirps_wire::node_id::NodeId;
5use async_trait::async_trait;
6use std::collections::HashMap;
7use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
8use std::sync::{
9 Arc,
10 atomic::{AtomicBool, Ordering},
11};
12use tokio::sync::{Mutex, RwLock, mpsc};
13
14type SharedPeers = Arc<RwLock<HashMap<NodeId, (SocketAddr, mpsc::Sender<(NodeId, Frame)>)>>>;
15
16pub struct MockNetwork {
18 peers: SharedPeers,
19}
20
21impl MockNetwork {
22 pub fn new() -> Self {
24 Self {
25 peers: Arc::new(RwLock::new(HashMap::new())),
26 }
27 }
28
29 pub async fn add_node(&self, node_id: NodeId, addr: SocketAddr) -> MockBackend {
31 let (tx, rx) = mpsc::channel(1024);
32 {
33 let mut guard = self.peers.write().await;
34 guard.insert(node_id, (addr, tx));
35 }
36 MockBackend {
37 node_id,
38 peers: Arc::clone(&self.peers),
39 incoming_rx: Mutex::new(Some(rx)),
40 closed: AtomicBool::new(false),
41 }
42 }
43}
44
45impl Default for MockNetwork {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51pub struct MockBackend {
53 node_id: NodeId,
54 peers: SharedPeers,
55 incoming_rx: Mutex<Option<mpsc::Receiver<(NodeId, Frame)>>>,
56 closed: AtomicBool,
57}
58
59#[async_trait]
60impl MessageBackend for MockBackend {
61 async fn send(&self, target: NodeId, frame: Frame) -> Result<(), TransportError> {
62 if self.closed.load(Ordering::SeqCst) {
63 return Err(TransportError::Connection(
64 "モックバックエンドはクローズ済みです".into(),
65 ));
66 }
67 let sender = {
68 let guard = self.peers.read().await;
69 guard
70 .get(&target)
71 .cloned()
72 .ok_or_else(|| TransportError::Connection("ターゲットに接続されていません".into()))?
73 .1
74 };
75 sender
76 .send((self.node_id, frame))
77 .await
78 .map_err(|_| TransportError::Send("送信先が閉じています".into()))
79 }
80
81 async fn broadcast(&self, frame: Frame) -> Result<usize, TransportError> {
82 if self.closed.load(Ordering::SeqCst) {
83 return Err(TransportError::Connection(
84 "モックバックエンドはクローズ済みです".into(),
85 ));
86 }
87 let targets: Vec<mpsc::Sender<(NodeId, Frame)>> = {
88 let guard = self.peers.read().await;
89 guard
90 .iter()
91 .filter(|(id, _)| **id != self.node_id)
92 .map(|(_, (_, tx))| tx.clone())
93 .collect()
94 };
95
96 let mut sent = 0;
97 for tx in targets {
98 if tx.send((self.node_id, frame.clone())).await.is_ok() {
99 sent += 1;
100 }
101 }
102 Ok(sent)
103 }
104
105 async fn subscribe(&self) -> Result<mpsc::Receiver<(NodeId, Frame)>, TransportError> {
106 let mut guard = self.incoming_rx.lock().await;
107 guard
108 .take()
109 .ok_or_else(|| TransportError::Subscribe("すでに購読済みです".into()))
110 }
111
112 async fn close(&self) -> Result<(), TransportError> {
113 self.closed.store(true, Ordering::SeqCst);
114 let mut guard = self.peers.write().await;
115 guard.remove(&self.node_id);
116 Ok(())
117 }
118
119 fn connected_peers(&self) -> Vec<(NodeId, SocketAddr)> {
120 if let Ok(guard) = self.peers.try_read() {
121 guard
122 .iter()
123 .filter(|(id, _)| **id != self.node_id)
124 .map(|(id, (addr, _))| (*id, *addr))
125 .collect()
126 } else {
127 Vec::new()
128 }
129 }
130}
131
132impl MockBackend {
133 pub fn ephemeral_addr() -> SocketAddr {
135 SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use alopex_chirps_wire::frame::UserMessage;

    /// Wraps `payload` in a user frame.
    fn user_frame(payload: &[u8]) -> Frame {
        Frame::User(UserMessage {
            payload: payload.to_vec(),
        })
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn send_delivers_to_target() -> anyhow::Result<()> {
        let network = MockNetwork::new();
        let a = NodeId::new();
        let b = NodeId::new();
        let backend_a = network.add_node(a, MockBackend::ephemeral_addr()).await;
        let backend_b = network.add_node(b, MockBackend::ephemeral_addr()).await;

        let mut rx_b = backend_b.subscribe().await?;
        backend_a.send(b, user_frame(b"hello")).await?;

        let (from, frame) = rx_b.recv().await.expect("メッセージを受信できる");
        assert_eq!(from, a);
        match frame {
            Frame::User(msg) => assert_eq!(msg.payload, b"hello"),
            other => panic!("Userフレームを期待しましたが {:?} を受信", other),
        }
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn broadcast_reaches_all_peers_except_self() -> anyhow::Result<()> {
        let network = MockNetwork::new();
        let a = NodeId::new();
        let b = NodeId::new();
        let c = NodeId::new();
        let backend_a = network.add_node(a, MockBackend::ephemeral_addr()).await;
        let backend_b = network.add_node(b, MockBackend::ephemeral_addr()).await;
        let backend_c = network.add_node(c, MockBackend::ephemeral_addr()).await;

        let mut rx_a = backend_a.subscribe().await?;
        let mut rx_b = backend_b.subscribe().await?;
        let mut rx_c = backend_c.subscribe().await?;

        let sent = backend_a.broadcast(Frame::Ping { seq: 1, from: a }).await?;
        assert_eq!(sent, 2);

        let (from_b, frame_b) = rx_b.recv().await.expect("Bが受信できる");
        let (from_c, frame_c) = rx_c.recv().await.expect("Cが受信できる");
        assert_eq!(from_b, a);
        assert_eq!(from_c, a);
        // `matches!` only evaluates to a bool; it has to be asserted to check anything.
        assert!(
            matches!(frame_b, Frame::Ping { seq: 1, from } if from == a),
            "expected Ping(seq = 1) from A, got {:?}",
            frame_b
        );
        assert!(
            matches!(frame_c, Frame::Ping { seq: 1, from } if from == a),
            "expected Ping(seq = 1) from A, got {:?}",
            frame_c
        );
        // broadcast awaited every send, so a self-delivery would already be queued.
        assert!(
            rx_a.try_recv().is_err(),
            "the broadcaster must not receive its own frame"
        );
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn subscribe_succeeds_only_once() -> anyhow::Result<()> {
        let network = MockNetwork::new();
        let backend = network
            .add_node(NodeId::new(), MockBackend::ephemeral_addr())
            .await;

        let _rx = backend.subscribe().await?;
        assert!(backend.subscribe().await.is_err());
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn connected_peers_lists_other_nodes_only() -> anyhow::Result<()> {
        let network = MockNetwork::new();
        let a = NodeId::new();
        let b = NodeId::new();
        let backend_a = network.add_node(a, MockBackend::ephemeral_addr()).await;
        let _backend_b = network.add_node(b, MockBackend::ephemeral_addr()).await;

        assert_eq!(
            backend_a.connected_peers(),
            vec![(b, MockBackend::ephemeral_addr())]
        );
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn close_removes_peer_and_blocks_send() -> anyhow::Result<()> {
        let network = MockNetwork::new();
        let a = NodeId::new();
        let b = NodeId::new();
        let backend_a = network.add_node(a, MockBackend::ephemeral_addr()).await;
        let backend_b = network.add_node(b, MockBackend::ephemeral_addr()).await;

        let mut rx_b = backend_b.subscribe().await?;
        backend_b.close().await?;

        // B is no longer registered, so A cannot reach it...
        assert!(backend_a.send(b, user_frame(b"ping")).await.is_err());
        // ...and B's inbox ends because the table's sender was dropped.
        assert!(rx_b.recv().await.is_none());
        // A closed backend refuses to send or broadcast.
        assert!(backend_b.send(a, user_frame(b"late")).await.is_err());
        assert!(
            backend_b
                .broadcast(Frame::Ping { seq: 2, from: b })
                .await
                .is_err()
        );
        Ok(())
    }
}