Skip to main content

atomr_remote/transport/
test_transport.rs

//! In-memory deterministic transport for tests.
//! Ported from akka.net: `Remote/Transport/TestTransport.cs`.
3
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use dashmap::DashMap;
8use parking_lot::Mutex;
9use tokio::sync::mpsc;
10
11use atomr_core::actor::Address;
12
13use crate::pdu::AkkaPdu;
14
15use super::{InboundFrame, Transport, TransportError};
16
17/// A `TestTransport` lets multiple `Address` participants exchange
18/// `AkkaPdu` frames without going through the network.
19#[derive(Clone)]
20pub struct TestTransport {
21    pub local_address: Address,
22    #[allow(dead_code)]
23    inbound_tx: mpsc::UnboundedSender<InboundFrame>,
24    inbound_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<InboundFrame>>>>,
25    pub registry: Arc<TestRegistry>,
26}
27
28#[derive(Default)]
29pub struct TestRegistry {
30    /// Address → outbound channel that delivers to that peer's inbound.
31    peers: DashMap<String, mpsc::UnboundedSender<InboundFrame>>,
32}
33
34impl TestRegistry {
35    pub fn new() -> Arc<Self> {
36        Arc::new(Self::default())
37    }
38
39    pub fn register(&self, address: &Address, sink: mpsc::UnboundedSender<InboundFrame>) {
40        self.peers.insert(address.to_string(), sink);
41    }
42
43    pub fn unregister(&self, address: &Address) {
44        self.peers.remove(&address.to_string());
45    }
46}
47
48impl TestTransport {
49    pub fn new(address: Address, registry: Arc<TestRegistry>) -> Self {
50        let (tx, rx) = mpsc::unbounded_channel();
51        registry.register(&address, tx.clone());
52        Self { local_address: address, inbound_tx: tx, inbound_rx: Arc::new(Mutex::new(Some(rx))), registry }
53    }
54}
55
56#[async_trait]
57impl Transport for TestTransport {
58    async fn listen(&self) -> Result<Address, TransportError> {
59        Ok(self.local_address.clone())
60    }
61
62    async fn associate(&self, target: &Address) -> Result<(), TransportError> {
63        if self.registry.peers.contains_key(&target.to_string()) {
64            Ok(())
65        } else {
66            Err(TransportError::NotAssociated(target.to_string()))
67        }
68    }
69
70    async fn send(&self, target: &Address, pdu: AkkaPdu) -> Result<(), TransportError> {
71        let sink = self
72            .registry
73            .peers
74            .get(&target.to_string())
75            .ok_or_else(|| TransportError::NotAssociated(target.to_string()))?
76            .clone();
77        sink.send(InboundFrame { from: self.local_address.clone(), pdu }).map_err(|_| TransportError::Closed)
78    }
79
80    fn inbound(&self) -> mpsc::UnboundedReceiver<InboundFrame> {
81        self.inbound_rx.lock().take().unwrap_or_else(|| {
82            let (_t, r) = mpsc::unbounded_channel();
83            r
84        })
85    }
86
87    async fn disassociate(&self, _target: &Address) -> Result<(), TransportError> {
88        Ok(())
89    }
90
91    async fn shutdown(&self) -> Result<(), TransportError> {
92        self.registry.unregister(&self.local_address);
93        Ok(())
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pdu::{AkkaPdu, AssociateInfo, PROTOCOL_VERSION};
    use std::time::Duration;

    /// Two transports, `A` and `B`, sharing one fresh registry.
    fn pair() -> (TestTransport, TestTransport) {
        let reg = TestRegistry::new();
        let a = TestTransport::new(Address::remote("test", "A", "h", 1), reg.clone());
        let b = TestTransport::new(Address::remote("test", "B", "h", 2), reg);
        (a, b)
    }

    /// A minimal `Associate` PDU originating from `origin`.
    fn associate_pdu(origin: &Address) -> AkkaPdu {
        AkkaPdu::Associate(AssociateInfo {
            origin: origin.clone(),
            uid: 1,
            cookie: None,
            protocol_version: PROTOCOL_VERSION,
        })
    }

    #[tokio::test]
    async fn loopback_send() {
        let (a, b) = pair();
        let mut inbound_a = a.inbound();
        let _addr_a = a.listen().await.unwrap();
        let _addr_b = b.listen().await.unwrap();
        b.associate(&a.local_address).await.unwrap();
        b.send(&a.local_address, associate_pdu(&b.local_address)).await.unwrap();
        let frame =
            tokio::time::timeout(Duration::from_millis(100), inbound_a.recv()).await.unwrap().unwrap();
        assert_eq!(frame.from, b.local_address);
    }

    #[tokio::test]
    async fn unknown_peer_is_not_associated() {
        let (a, _b) = pair();
        let ghost = Address::remote("test", "Ghost", "h", 3);

        let err = a.associate(&ghost).await.unwrap_err();
        assert!(matches!(err, TransportError::NotAssociated(_)));

        let err = a.send(&ghost, associate_pdu(&a.local_address)).await.unwrap_err();
        assert!(matches!(err, TransportError::NotAssociated(_)));
    }

    #[tokio::test]
    async fn shutdown_unregisters_transport() {
        let (a, b) = pair();
        a.shutdown().await.unwrap();

        let err = b.associate(&a.local_address).await.unwrap_err();
        assert!(matches!(err, TransportError::NotAssociated(_)));
    }
}