alopex_chirps_mock/
lib.rs

1use alopex_chirps_core::backend::MessageBackend;
2use alopex_chirps_core::error::TransportError;
3use alopex_chirps_wire::frame::Frame;
4use alopex_chirps_wire::node_id::NodeId;
5use async_trait::async_trait;
6use std::collections::HashMap;
7use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
8use std::sync::{
9    Arc,
10    atomic::{AtomicBool, Ordering},
11};
12use tokio::sync::{Mutex, RwLock, mpsc};
13
14type SharedPeers = Arc<RwLock<HashMap<NodeId, (SocketAddr, mpsc::Sender<(NodeId, Frame)>)>>>;
15
16/// 単一プロセス内でのメモリ内トランスポートを提供するネットワーク。
17pub struct MockNetwork {
18    peers: SharedPeers,
19}
20
21impl MockNetwork {
22    /// 空のネットワークを生成する。
23    pub fn new() -> Self {
24        Self {
25            peers: Arc::new(RwLock::new(HashMap::new())),
26        }
27    }
28
29    /// ノードをネットワークに登録し、対応するトランスポートを返す。
30    pub async fn add_node(&self, node_id: NodeId, addr: SocketAddr) -> MockBackend {
31        let (tx, rx) = mpsc::channel(1024);
32        {
33            let mut guard = self.peers.write().await;
34            guard.insert(node_id, (addr, tx));
35        }
36        MockBackend {
37            node_id,
38            peers: Arc::clone(&self.peers),
39            incoming_rx: Mutex::new(Some(rx)),
40            closed: AtomicBool::new(false),
41        }
42    }
43}
44
45impl Default for MockNetwork {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51/// インメモリで `MessageBackend` を提供する実装。
52pub struct MockBackend {
53    node_id: NodeId,
54    peers: SharedPeers,
55    incoming_rx: Mutex<Option<mpsc::Receiver<(NodeId, Frame)>>>,
56    closed: AtomicBool,
57}
58
59#[async_trait]
60impl MessageBackend for MockBackend {
61    async fn send(&self, target: NodeId, frame: Frame) -> Result<(), TransportError> {
62        if self.closed.load(Ordering::SeqCst) {
63            return Err(TransportError::Connection(
64                "モックバックエンドはクローズ済みです".into(),
65            ));
66        }
67        let sender = {
68            let guard = self.peers.read().await;
69            guard
70                .get(&target)
71                .cloned()
72                .ok_or_else(|| TransportError::Connection("ターゲットに接続されていません".into()))?
73                .1
74        };
75        sender
76            .send((self.node_id, frame))
77            .await
78            .map_err(|_| TransportError::Send("送信先が閉じています".into()))
79    }
80
81    async fn broadcast(&self, frame: Frame) -> Result<usize, TransportError> {
82        if self.closed.load(Ordering::SeqCst) {
83            return Err(TransportError::Connection(
84                "モックバックエンドはクローズ済みです".into(),
85            ));
86        }
87        let targets: Vec<mpsc::Sender<(NodeId, Frame)>> = {
88            let guard = self.peers.read().await;
89            guard
90                .iter()
91                .filter(|(id, _)| **id != self.node_id)
92                .map(|(_, (_, tx))| tx.clone())
93                .collect()
94        };
95
96        let mut sent = 0;
97        for tx in targets {
98            if tx.send((self.node_id, frame.clone())).await.is_ok() {
99                sent += 1;
100            }
101        }
102        Ok(sent)
103    }
104
105    async fn subscribe(&self) -> Result<mpsc::Receiver<(NodeId, Frame)>, TransportError> {
106        let mut guard = self.incoming_rx.lock().await;
107        guard
108            .take()
109            .ok_or_else(|| TransportError::Subscribe("すでに購読済みです".into()))
110    }
111
112    async fn close(&self) -> Result<(), TransportError> {
113        self.closed.store(true, Ordering::SeqCst);
114        let mut guard = self.peers.write().await;
115        guard.remove(&self.node_id);
116        Ok(())
117    }
118
119    fn connected_peers(&self) -> Vec<(NodeId, SocketAddr)> {
120        if let Ok(guard) = self.peers.try_read() {
121            guard
122                .iter()
123                .filter(|(id, _)| **id != self.node_id)
124                .map(|(id, (addr, _))| (*id, *addr))
125                .collect()
126        } else {
127            Vec::new()
128        }
129    }
130}
131
132impl MockBackend {
133    /// 任意の `SocketAddr` を簡単に得るためのヘルパ。
134    pub fn ephemeral_addr() -> SocketAddr {
135        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use alopex_chirps_wire::frame::UserMessage;

    /// Builds a `Frame::User` carrying a copy of `payload`.
    fn user_frame(payload: &[u8]) -> Frame {
        Frame::User(UserMessage {
            payload: payload.to_vec(),
        })
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn send_delivers_to_target() -> anyhow::Result<()> {
        let network = MockNetwork::new();
        let a = NodeId::new();
        let b = NodeId::new();
        let backend_a = network.add_node(a, MockBackend::ephemeral_addr()).await;
        let backend_b = network.add_node(b, MockBackend::ephemeral_addr()).await;

        let mut rx_b = backend_b.subscribe().await?;
        backend_a.send(b, user_frame(b"hello")).await?;

        let (from, frame) = rx_b.recv().await.expect("メッセージを受信できる");
        assert_eq!(from, a);
        match frame {
            Frame::User(msg) => assert_eq!(msg.payload, b"hello"),
            other => panic!("Userフレームを期待しましたが {:?} を受信", other),
        }
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn broadcast_reaches_all_peers_except_self() -> anyhow::Result<()> {
        let network = MockNetwork::new();
        let a = NodeId::new();
        let b = NodeId::new();
        let c = NodeId::new();
        let backend_a = network.add_node(a, MockBackend::ephemeral_addr()).await;
        let backend_b = network.add_node(b, MockBackend::ephemeral_addr()).await;
        let backend_c = network.add_node(c, MockBackend::ephemeral_addr()).await;

        let mut rx_a = backend_a.subscribe().await?;
        let mut rx_b = backend_b.subscribe().await?;
        let mut rx_c = backend_c.subscribe().await?;

        let sent = backend_a.broadcast(Frame::Ping { seq: 1, from: a }).await?;
        assert_eq!(sent, 2);

        let (from_b, frame_b) = rx_b.recv().await.expect("Bが受信できる");
        let (from_c, frame_c) = rx_c.recv().await.expect("Cが受信できる");
        assert_eq!(from_b, a);
        assert_eq!(from_c, a);
        // `matches!` only evaluates to a bool; it has to be asserted to check anything.
        assert!(
            matches!(frame_b, Frame::Ping { seq: 1, from } if from == a),
            "Pingフレームを期待しましたが {:?} を受信",
            frame_b
        );
        assert!(
            matches!(frame_c, Frame::Ping { seq: 1, from } if from == a),
            "Pingフレームを期待しましたが {:?} を受信",
            frame_c
        );

        // The broadcaster must not receive its own frame. `broadcast` has already
        // returned, so anything it queued would already be in the inbox.
        assert!(rx_a.try_recv().is_err());
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn close_removes_peer_and_blocks_send() -> anyhow::Result<()> {
        let network = MockNetwork::new();
        let a = NodeId::new();
        let b = NodeId::new();
        let backend_a = network.add_node(a, MockBackend::ephemeral_addr()).await;
        let backend_b = network.add_node(b, MockBackend::ephemeral_addr()).await;

        let mut rx_b = backend_b.subscribe().await?;
        backend_b.close().await?;

        // Peers can no longer reach the closed node, and its inbox is disconnected.
        let result = backend_a.send(b, user_frame(b"ping")).await;
        assert!(result.is_err());
        assert!(rx_b.recv().await.is_none());
        assert!(backend_a.connected_peers().is_empty());

        // The closed node itself refuses to send or broadcast.
        assert!(backend_b.send(a, user_frame(b"pong")).await.is_err());
        assert!(backend_b.broadcast(user_frame(b"pong")).await.is_err());
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn subscribe_succeeds_only_once() -> anyhow::Result<()> {
        let network = MockNetwork::new();
        let backend = network
            .add_node(NodeId::new(), MockBackend::ephemeral_addr())
            .await;

        let _rx = backend.subscribe().await?;
        assert!(backend.subscribe().await.is_err());
        Ok(())
    }
}