atomr_remote/transport/
test_transport.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6use dashmap::DashMap;
7use parking_lot::Mutex;
8use tokio::sync::mpsc;
9
10use atomr_core::actor::Address;
11
12use crate::pdu::AkkaPdu;
13
14use super::{InboundFrame, Transport, TransportError};
15
/// In-memory [`Transport`] for tests.
///
/// A frame is delivered by pushing it directly into the destination transport's
/// inbound channel, which is looked up by address in a shared [`TestRegistry`].
/// No sockets are involved.
#[derive(Clone)]
pub struct TestTransport {
    /// Address reported by `listen` and registered under in `registry`.
    pub local_address: Address,
    // Not read by any visible code after construction. Because this transport holds
    // its own sender, the inbound channel stays open even after the registry entry
    // is removed (e.g. by `shutdown`).
    // NOTE(review): presumably intentional — confirm before removing the field.
    #[allow(dead_code)]
    inbound_tx: mpsc::UnboundedSender<InboundFrame>,
    /// Receiving half of the inbound channel. `inbound` takes it at most once.
    /// It sits behind `Arc<Mutex<Option<..>>>` so every clone shares the same slot.
    inbound_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<InboundFrame>>>>,
    /// Registry shared with peer transports; used to resolve `associate`/`send` targets.
    pub registry: Arc<TestRegistry>,
}
26
/// Shared address book for [`TestTransport`]s.
///
/// It maps each registered address to the sender half of that transport's
/// inbound channel.
#[derive(Default)]
pub struct TestRegistry {
    /// Keyed by the `Address::to_string()` form of the registered address.
    peers: DashMap<String, mpsc::UnboundedSender<InboundFrame>>,
}
32
33impl TestRegistry {
34 pub fn new() -> Arc<Self> {
35 Arc::new(Self::default())
36 }
37
38 pub fn register(&self, address: &Address, sink: mpsc::UnboundedSender<InboundFrame>) {
39 self.peers.insert(address.to_string(), sink);
40 }
41
42 pub fn unregister(&self, address: &Address) {
43 self.peers.remove(&address.to_string());
44 }
45}
46
47impl TestTransport {
48 pub fn new(address: Address, registry: Arc<TestRegistry>) -> Self {
49 let (tx, rx) = mpsc::unbounded_channel();
50 registry.register(&address, tx.clone());
51 Self { local_address: address, inbound_tx: tx, inbound_rx: Arc::new(Mutex::new(Some(rx))), registry }
52 }
53}
54
55#[async_trait]
56impl Transport for TestTransport {
57 async fn listen(&self) -> Result<Address, TransportError> {
58 Ok(self.local_address.clone())
59 }
60
61 async fn associate(&self, target: &Address) -> Result<(), TransportError> {
62 if self.registry.peers.contains_key(&target.to_string()) {
63 Ok(())
64 } else {
65 Err(TransportError::NotAssociated(target.to_string()))
66 }
67 }
68
69 async fn send(&self, target: &Address, pdu: AkkaPdu) -> Result<(), TransportError> {
70 let sink = self
71 .registry
72 .peers
73 .get(&target.to_string())
74 .ok_or_else(|| TransportError::NotAssociated(target.to_string()))?
75 .clone();
76 sink.send(InboundFrame { from: self.local_address.clone(), pdu }).map_err(|_| TransportError::Closed)
77 }
78
79 fn inbound(&self) -> mpsc::UnboundedReceiver<InboundFrame> {
80 self.inbound_rx.lock().take().unwrap_or_else(|| {
81 let (_t, r) = mpsc::unbounded_channel();
82 r
83 })
84 }
85
86 async fn disassociate(&self, _target: &Address) -> Result<(), TransportError> {
87 Ok(())
88 }
89
90 async fn shutdown(&self) -> Result<(), TransportError> {
91 self.registry.unregister(&self.local_address);
92 Ok(())
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pdu::{AkkaPdu, AssociateInfo, PROTOCOL_VERSION};
    use std::time::Duration;

    /// Builds a minimal `Associate` PDU originating from `origin`.
    fn associate_pdu(origin: &Address) -> AkkaPdu {
        AkkaPdu::Associate(AssociateInfo {
            origin: origin.clone(),
            uid: 1,
            cookie: None,
            protocol_version: PROTOCOL_VERSION,
        })
    }

    #[tokio::test]
    async fn loopback_send() {
        let reg = TestRegistry::new();
        let a = TestTransport::new(Address::remote("test", "A", "h", 1), reg.clone());
        let b = TestTransport::new(Address::remote("test", "B", "h", 2), reg.clone());
        let mut inbound_a = a.inbound();
        let _addr_a = a.listen().await.unwrap();
        let _addr_b = b.listen().await.unwrap();
        b.associate(&a.local_address).await.unwrap();
        b.send(&a.local_address, associate_pdu(&b.local_address)).await.unwrap();
        let frame =
            tokio::time::timeout(Duration::from_millis(100), inbound_a.recv()).await.unwrap().unwrap();
        assert_eq!(frame.from, b.local_address);
    }

    #[tokio::test]
    async fn unknown_peer_is_not_associated() {
        let reg = TestRegistry::new();
        let a = TestTransport::new(Address::remote("test", "A", "h", 1), reg);
        let missing = Address::remote("test", "Missing", "h", 3);
        assert!(matches!(a.associate(&missing).await, Err(TransportError::NotAssociated(_))));
        assert!(matches!(
            a.send(&missing, associate_pdu(&a.local_address)).await,
            Err(TransportError::NotAssociated(_))
        ));
    }

    #[tokio::test]
    async fn shutdown_makes_transport_unreachable() {
        let reg = TestRegistry::new();
        let a = TestTransport::new(Address::remote("test", "A", "h", 1), reg.clone());
        let b = TestTransport::new(Address::remote("test", "B", "h", 2), reg.clone());
        b.associate(&a.local_address).await.unwrap();
        a.shutdown().await.unwrap();
        assert!(matches!(
            b.associate(&a.local_address).await,
            Err(TransportError::NotAssociated(_))
        ));
    }

    #[tokio::test]
    async fn inbound_is_handed_out_once() {
        let reg = TestRegistry::new();
        let a = TestTransport::new(Address::remote("test", "A", "h", 1), reg);
        let _first = a.inbound();
        let mut second = a.inbound();
        // The fallback receiver's sender is dropped immediately, so it is already closed.
        let next = tokio::time::timeout(Duration::from_millis(100), second.recv()).await.unwrap();
        assert!(next.is_none());
    }
}