// atomr_remote/transport/test_transport.rs

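//! In-memory deterministic transport for tests.
//!
//! Participants exchange `AkkaPdu` frames through a shared `TestRegistry` of
//! unbounded channels instead of going through real sockets.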
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use dashmap::DashMap;
7use parking_lot::Mutex;
8use tokio::sync::mpsc;
9
10use atomr_core::actor::Address;
11
12use crate::pdu::AkkaPdu;
13
14use super::{InboundFrame, Transport, TransportError};
15
16/// A `TestTransport` lets multiple `Address` participants exchange
17/// `AkkaPdu` frames without going through the network.
18#[derive(Clone)]
19pub struct TestTransport {
20    pub local_address: Address,
21    #[allow(dead_code)]
22    inbound_tx: mpsc::UnboundedSender<InboundFrame>,
23    inbound_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<InboundFrame>>>>,
24    pub registry: Arc<TestRegistry>,
25}
26
27#[derive(Default)]
28pub struct TestRegistry {
29    /// Address → outbound channel that delivers to that peer's inbound.
30    peers: DashMap<String, mpsc::UnboundedSender<InboundFrame>>,
31}
32
33impl TestRegistry {
34    pub fn new() -> Arc<Self> {
35        Arc::new(Self::default())
36    }
37
38    pub fn register(&self, address: &Address, sink: mpsc::UnboundedSender<InboundFrame>) {
39        self.peers.insert(address.to_string(), sink);
40    }
41
42    pub fn unregister(&self, address: &Address) {
43        self.peers.remove(&address.to_string());
44    }
45}
46
47impl TestTransport {
48    pub fn new(address: Address, registry: Arc<TestRegistry>) -> Self {
49        let (tx, rx) = mpsc::unbounded_channel();
50        registry.register(&address, tx.clone());
51        Self { local_address: address, inbound_tx: tx, inbound_rx: Arc::new(Mutex::new(Some(rx))), registry }
52    }
53}
54
55#[async_trait]
56impl Transport for TestTransport {
57    async fn listen(&self) -> Result<Address, TransportError> {
58        Ok(self.local_address.clone())
59    }
60
61    async fn associate(&self, target: &Address) -> Result<(), TransportError> {
62        if self.registry.peers.contains_key(&target.to_string()) {
63            Ok(())
64        } else {
65            Err(TransportError::NotAssociated(target.to_string()))
66        }
67    }
68
69    async fn send(&self, target: &Address, pdu: AkkaPdu) -> Result<(), TransportError> {
70        let sink = self
71            .registry
72            .peers
73            .get(&target.to_string())
74            .ok_or_else(|| TransportError::NotAssociated(target.to_string()))?
75            .clone();
76        sink.send(InboundFrame { from: self.local_address.clone(), pdu }).map_err(|_| TransportError::Closed)
77    }
78
79    fn inbound(&self) -> mpsc::UnboundedReceiver<InboundFrame> {
80        self.inbound_rx.lock().take().unwrap_or_else(|| {
81            let (_t, r) = mpsc::unbounded_channel();
82            r
83        })
84    }
85
86    async fn disassociate(&self, _target: &Address) -> Result<(), TransportError> {
87        Ok(())
88    }
89
90    async fn shutdown(&self) -> Result<(), TransportError> {
91        self.registry.unregister(&self.local_address);
92        Ok(())
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pdu::{AkkaPdu, AssociateInfo, PROTOCOL_VERSION};
    use std::time::Duration;

    /// Builds an `Associate` PDU whose origin is `origin`.
    fn associate_pdu(origin: &Address) -> AkkaPdu {
        AkkaPdu::Associate(AssociateInfo {
            origin: origin.clone(),
            uid: 1,
            cookie: None,
            protocol_version: PROTOCOL_VERSION,
        })
    }

    #[tokio::test]
    async fn loopback_send() {
        let reg = TestRegistry::new();
        let a = TestTransport::new(Address::remote("test", "A", "h", 1), reg.clone());
        let b = TestTransport::new(Address::remote("test", "B", "h", 2), reg.clone());
        let mut inbound_a = a.inbound();
        assert_eq!(a.listen().await.unwrap(), a.local_address);
        assert_eq!(b.listen().await.unwrap(), b.local_address);
        b.associate(&a.local_address).await.unwrap();
        b.send(&a.local_address, associate_pdu(&b.local_address)).await.unwrap();
        let frame = tokio::time::timeout(Duration::from_millis(100), inbound_a.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(frame.from, b.local_address);
    }

    #[tokio::test]
    async fn unknown_target_is_not_associated() {
        let reg = TestRegistry::new();
        let a = TestTransport::new(Address::remote("test", "A", "h", 1), reg);
        let ghost = Address::remote("test", "Ghost", "h", 3);
        let err = a.associate(&ghost).await.unwrap_err();
        assert!(matches!(err, TransportError::NotAssociated(_)));
        let err = a.send(&ghost, associate_pdu(&a.local_address)).await.unwrap_err();
        assert!(matches!(err, TransportError::NotAssociated(_)));
    }

    #[tokio::test]
    async fn second_inbound_call_is_closed() {
        let reg = TestRegistry::new();
        let a = TestTransport::new(Address::remote("test", "A", "h", 1), reg);
        let _live = a.inbound();
        let mut closed = a.inbound();
        assert!(closed.recv().await.is_none());
    }

    #[tokio::test]
    async fn shutdown_unregisters() {
        let reg = TestRegistry::new();
        let a = TestTransport::new(Address::remote("test", "A", "h", 1), reg.clone());
        let b = TestTransport::new(Address::remote("test", "B", "h", 2), reg);
        a.shutdown().await.unwrap();
        assert!(b.associate(&a.local_address).await.is_err());
    }
}