1use std::sync::Arc;
9
10use async_trait::async_trait;
11use dashmap::DashMap;
12use parking_lot::Mutex;
13use tokio::sync::mpsc;
14
15use atomr_core::actor::Address;
16
17use crate::address_uid::AddressUid;
18use crate::pdu::{AkkaPdu, AssociateInfo, DisassociateReason, PROTOCOL_VERSION};
19use crate::settings::RemoteSettings;
20
21use super::{InboundFrame, Transport, TransportError};
22
23#[derive(Debug, Clone)]
25pub struct PeerAssociation {
26 pub address: Address,
27 pub uid: u64,
28}
29
30pub struct AkkaProtocolTransport {
32 inner: Arc<dyn Transport>,
33 settings: RemoteSettings,
34 local_uid: AddressUid,
35 local_address: Mutex<Option<Address>>,
38 peer_uids: DashMap<String, u64>,
41 associated: DashMap<String, ()>,
43 associate_replied: DashMap<String, ()>,
45 inbound_tx: mpsc::UnboundedSender<ProtocolEvent>,
46 inbound_rx: Mutex<Option<mpsc::UnboundedReceiver<ProtocolEvent>>>,
47 pump_started: Mutex<bool>,
48}
49
50#[derive(Debug)]
51pub enum ProtocolEvent {
52 Associated(PeerAssociation),
54 Disassociated { peer: Address, reason: DisassociateReason },
56 Payload { from: Address, pdu: AkkaPdu },
58}
59
60impl AkkaProtocolTransport {
61 pub fn new(inner: Arc<dyn Transport>, settings: RemoteSettings, local_uid: AddressUid) -> Arc<Self> {
62 let (tx, rx) = mpsc::unbounded_channel();
63 Arc::new(Self {
64 inner,
65 settings,
66 local_uid,
67 local_address: Mutex::new(None),
68 peer_uids: DashMap::new(),
69 associated: DashMap::new(),
70 associate_replied: DashMap::new(),
71 inbound_tx: tx,
72 inbound_rx: Mutex::new(Some(rx)),
73 pump_started: Mutex::new(false),
74 })
75 }
76
77 pub fn local_address(&self) -> Option<Address> {
78 self.local_address.lock().clone()
79 }
80
81 pub fn settings(&self) -> &RemoteSettings {
82 &self.settings
83 }
84
85 pub fn local_uid(&self) -> u64 {
86 self.local_uid.get()
87 }
88
89 pub fn raw_transport(&self) -> Arc<dyn Transport> {
90 self.inner.clone()
91 }
92
93 pub async fn start(self: &Arc<Self>) -> Result<Address, TransportError> {
96 let address = self.inner.listen().await?;
97 *self.local_address.lock() = Some(address.clone());
98 self.start_pump();
99 Ok(address)
100 }
101
102 fn start_pump(self: &Arc<Self>) {
103 let mut started = self.pump_started.lock();
104 if *started {
105 return;
106 }
107 *started = true;
108 drop(started);
109
110 let this = self.clone();
111 let mut inbound = self.inner.inbound();
112 tokio::spawn(async move {
113 while let Some(frame) = inbound.recv().await {
114 this.dispatch_frame(frame).await;
115 }
116 });
117 }
118
119 async fn dispatch_frame(&self, frame: InboundFrame) {
120 match frame.pdu {
121 AkkaPdu::Associate(info) => {
122 if info.protocol_version != PROTOCOL_VERSION {
123 let _ = self
124 .inner
125 .send(
126 &info.origin,
127 AkkaPdu::Disassociate(DisassociateReason::HandshakeFailure(format!(
128 "protocol version mismatch: peer={}, local={}",
129 info.protocol_version, PROTOCOL_VERSION
130 ))),
131 )
132 .await;
133 return;
134 }
135 if self.settings.require_cookie.is_some() && self.settings.require_cookie != info.cookie {
136 let _ = self
137 .inner
138 .send(
139 &info.origin,
140 AkkaPdu::Disassociate(DisassociateReason::HandshakeFailure(
141 "cookie mismatch".into(),
142 )),
143 )
144 .await;
145 return;
146 }
147 let key = info.origin.to_string();
148 if let Some(prev) = self.peer_uids.insert(key.clone(), info.uid) {
149 if prev != info.uid && info.uid != 0 {
150 let _ = self.inbound_tx.send(ProtocolEvent::Disassociated {
151 peer: info.origin.clone(),
152 reason: DisassociateReason::Quarantined,
153 });
154 }
155 }
156 self.associated.insert(key.clone(), ());
157
158 if self.associate_replied.insert(key.clone(), ()).is_none() {
163 let local = self.local_address.lock().clone();
164 if let Some(local) = local {
165 let reply = AkkaPdu::Associate(AssociateInfo {
166 origin: local,
167 uid: self.local_uid.get(),
168 cookie: self.settings.require_cookie.clone(),
169 protocol_version: PROTOCOL_VERSION,
170 });
171 let _ = self.inner.send(&info.origin, reply).await;
172 }
173 }
174
175 let _ = self.inbound_tx.send(ProtocolEvent::Associated(PeerAssociation {
176 address: info.origin.clone(),
177 uid: info.uid,
178 }));
179 }
180 AkkaPdu::Disassociate(reason) => {
181 let key = frame.from.to_string();
182 self.associated.remove(&key);
183 self.peer_uids.remove(&key);
184 let _ = self.inbound_tx.send(ProtocolEvent::Disassociated { peer: frame.from, reason });
185 }
186 AkkaPdu::Heartbeat => {
187 }
189 other => {
190 let _ = self.inbound_tx.send(ProtocolEvent::Payload { from: frame.from, pdu: other });
191 }
192 }
193 }
194
195 pub async fn associate(
199 self: &Arc<Self>,
200 target: &Address,
201 local_address: &Address,
202 ) -> Result<(), TransportError> {
203 self.start_pump();
204 self.inner.associate(target).await?;
205 let pdu = AkkaPdu::Associate(AssociateInfo {
206 origin: local_address.clone(),
207 uid: self.local_uid.get(),
208 cookie: self.settings.require_cookie.clone(),
209 protocol_version: PROTOCOL_VERSION,
210 });
211 self.inner.send(target, pdu).await?;
212 Ok(())
213 }
214
215 pub async fn send_pdu(&self, target: &Address, pdu: AkkaPdu) -> Result<(), TransportError> {
216 self.inner.send(target, pdu).await
217 }
218
219 pub async fn disassociate(
220 &self,
221 target: &Address,
222 reason: DisassociateReason,
223 ) -> Result<(), TransportError> {
224 let _ = self.inner.send(target, AkkaPdu::Disassociate(reason)).await;
225 let _ = self.inner.disassociate(target).await;
226 self.associated.remove(&target.to_string());
227 self.peer_uids.remove(&target.to_string());
228 Ok(())
229 }
230
231 pub fn events(&self) -> mpsc::UnboundedReceiver<ProtocolEvent> {
232 self.inbound_rx.lock().take().unwrap_or_else(|| {
233 let (_t, r) = mpsc::unbounded_channel();
234 r
235 })
236 }
237
238 pub fn is_associated(&self, address: &Address) -> bool {
239 self.associated.contains_key(&address.to_string())
240 }
241}
242
243#[async_trait]
244impl Transport for AkkaProtocolTransport {
245 async fn listen(&self) -> Result<Address, TransportError> {
246 self.inner.listen().await
247 }
248
249 async fn associate(&self, target: &Address) -> Result<(), TransportError> {
250 self.inner.associate(target).await
251 }
252
253 async fn send(&self, target: &Address, pdu: AkkaPdu) -> Result<(), TransportError> {
254 self.inner.send(target, pdu).await
255 }
256
257 fn inbound(&self) -> mpsc::UnboundedReceiver<InboundFrame> {
258 self.inner.inbound()
259 }
260
261 async fn disassociate(&self, target: &Address) -> Result<(), TransportError> {
262 self.inner.disassociate(target).await
263 }
264
265 async fn shutdown(&self) -> Result<(), TransportError> {
266 self.inner.shutdown().await
267 }
268}