atomr_remote/transport/
test_transport.rs1use std::sync::Arc;
5
6use async_trait::async_trait;
7use dashmap::DashMap;
8use parking_lot::Mutex;
9use tokio::sync::mpsc;
10
11use atomr_core::actor::Address;
12
13use crate::pdu::AkkaPdu;
14
15use super::{InboundFrame, Transport, TransportError};
16
17#[derive(Clone)]
20pub struct TestTransport {
21 pub local_address: Address,
22 #[allow(dead_code)]
23 inbound_tx: mpsc::UnboundedSender<InboundFrame>,
24 inbound_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<InboundFrame>>>>,
25 pub registry: Arc<TestRegistry>,
26}
27
28#[derive(Default)]
29pub struct TestRegistry {
30 peers: DashMap<String, mpsc::UnboundedSender<InboundFrame>>,
32}
33
34impl TestRegistry {
35 pub fn new() -> Arc<Self> {
36 Arc::new(Self::default())
37 }
38
39 pub fn register(&self, address: &Address, sink: mpsc::UnboundedSender<InboundFrame>) {
40 self.peers.insert(address.to_string(), sink);
41 }
42
43 pub fn unregister(&self, address: &Address) {
44 self.peers.remove(&address.to_string());
45 }
46}
47
48impl TestTransport {
49 pub fn new(address: Address, registry: Arc<TestRegistry>) -> Self {
50 let (tx, rx) = mpsc::unbounded_channel();
51 registry.register(&address, tx.clone());
52 Self { local_address: address, inbound_tx: tx, inbound_rx: Arc::new(Mutex::new(Some(rx))), registry }
53 }
54}
55
56#[async_trait]
57impl Transport for TestTransport {
58 async fn listen(&self) -> Result<Address, TransportError> {
59 Ok(self.local_address.clone())
60 }
61
62 async fn associate(&self, target: &Address) -> Result<(), TransportError> {
63 if self.registry.peers.contains_key(&target.to_string()) {
64 Ok(())
65 } else {
66 Err(TransportError::NotAssociated(target.to_string()))
67 }
68 }
69
70 async fn send(&self, target: &Address, pdu: AkkaPdu) -> Result<(), TransportError> {
71 let sink = self
72 .registry
73 .peers
74 .get(&target.to_string())
75 .ok_or_else(|| TransportError::NotAssociated(target.to_string()))?
76 .clone();
77 sink.send(InboundFrame { from: self.local_address.clone(), pdu }).map_err(|_| TransportError::Closed)
78 }
79
80 fn inbound(&self) -> mpsc::UnboundedReceiver<InboundFrame> {
81 self.inbound_rx.lock().take().unwrap_or_else(|| {
82 let (_t, r) = mpsc::unbounded_channel();
83 r
84 })
85 }
86
87 async fn disassociate(&self, _target: &Address) -> Result<(), TransportError> {
88 Ok(())
89 }
90
91 async fn shutdown(&self) -> Result<(), TransportError> {
92 self.registry.unregister(&self.local_address);
93 Ok(())
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pdu::{AkkaPdu, AssociateInfo, PROTOCOL_VERSION};
    use std::time::Duration;

    /// Builds a minimal `Associate` PDU originating from `origin`.
    fn associate_pdu(origin: &Address) -> AkkaPdu {
        AkkaPdu::Associate(AssociateInfo {
            origin: origin.clone(),
            uid: 1,
            cookie: None,
            protocol_version: PROTOCOL_VERSION,
        })
    }

    #[tokio::test]
    async fn loopback_send() {
        let reg = TestRegistry::new();
        let a = TestTransport::new(Address::remote("test", "A", "h", 1), reg.clone());
        let b = TestTransport::new(Address::remote("test", "B", "h", 2), reg.clone());
        let mut inbound_a = a.inbound();
        assert_eq!(a.listen().await.unwrap(), a.local_address);
        assert_eq!(b.listen().await.unwrap(), b.local_address);
        b.associate(&a.local_address).await.unwrap();
        b.send(&a.local_address, associate_pdu(&b.local_address)).await.unwrap();
        let frame = tokio::time::timeout(Duration::from_millis(100), inbound_a.recv())
            .await
            .expect("frame should arrive before the timeout")
            .expect("inbound channel should still be open");
        assert_eq!(frame.from, b.local_address);
        assert!(matches!(frame.pdu, AkkaPdu::Associate(_)));
    }

    #[tokio::test]
    async fn unknown_peer_is_not_associated() {
        let reg = TestRegistry::new();
        let b = TestTransport::new(Address::remote("test", "B", "h", 2), reg);
        let ghost = Address::remote("test", "Ghost", "h", 3);
        assert!(matches!(b.associate(&ghost).await, Err(TransportError::NotAssociated(_))));
        assert!(matches!(
            b.send(&ghost, associate_pdu(&b.local_address)).await,
            Err(TransportError::NotAssociated(_))
        ));
    }

    #[tokio::test]
    async fn shutdown_makes_transport_unreachable() {
        let reg = TestRegistry::new();
        let a = TestTransport::new(Address::remote("test", "A", "h", 1), reg.clone());
        let b = TestTransport::new(Address::remote("test", "B", "h", 2), reg.clone());
        b.associate(&a.local_address).await.unwrap();
        a.shutdown().await.unwrap();
        assert!(matches!(
            b.associate(&a.local_address).await,
            Err(TransportError::NotAssociated(_))
        ));
        assert!(matches!(
            b.send(&a.local_address, associate_pdu(&b.local_address)).await,
            Err(TransportError::NotAssociated(_))
        ));
    }

    #[tokio::test]
    async fn second_inbound_call_gets_closed_receiver() {
        let a = TestTransport::new(Address::remote("test", "A", "h", 1), TestRegistry::new());
        let _first = a.inbound();
        let mut second = a.inbound();
        assert!(second.recv().await.is_none());
    }
}