Skip to main content

atomr_remote/transport/
akka_protocol.rs

1//! Akka-protocol layer atop a raw [`Transport`].
2//!
3//! This wrapper handles handshake (Associate / Associate reply),
4//! validates the protocol version + cookie, attributes inbound frames to
//! peer UIDs, and exposes the `send_pdu` / `disassociate` helpers that the
//! Endpoint pair calls directly.
7
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use dashmap::DashMap;
12use parking_lot::Mutex;
13use tokio::sync::mpsc;
14
15use atomr_core::actor::Address;
16
17use crate::address_uid::AddressUid;
18use crate::pdu::{AkkaPdu, AssociateInfo, DisassociateReason, PROTOCOL_VERSION};
19use crate::settings::RemoteSettings;
20
21use super::{InboundFrame, Transport, TransportError};
22
23/// Outcome of a peer's `Associate` PDU.
24#[derive(Debug, Clone)]
25pub struct PeerAssociation {
26    pub address: Address,
27    pub uid: u64,
28}
29
30/// Wraps an inner `Transport` to enforce the Akka handshake.
31pub struct AkkaProtocolTransport {
32    inner: Arc<dyn Transport>,
33    settings: RemoteSettings,
34    local_uid: AddressUid,
35    /// Local address — captured at `start()`. Used to populate the
36    /// `origin` field in inbound-handshake replies.
37    local_address: Mutex<Option<Address>>,
38    /// Outbound peer state: `target Address -> last UID we observed`. We
39    /// use this to detect a peer restart.
40    peer_uids: DashMap<String, u64>,
41    /// Set of peers we have already finished handshake with.
42    associated: DashMap<String, ()>,
43    /// Set of peers we have already replied to with our Associate.
44    associate_replied: DashMap<String, ()>,
45    inbound_tx: mpsc::UnboundedSender<ProtocolEvent>,
46    inbound_rx: Mutex<Option<mpsc::UnboundedReceiver<ProtocolEvent>>>,
47    pump_started: Mutex<bool>,
48}
49
50#[derive(Debug)]
51pub enum ProtocolEvent {
52    /// Handshake completed with this peer.
53    Associated(PeerAssociation),
54    /// Peer disassociated (graceful or quarantine).
55    Disassociated { peer: Address, reason: DisassociateReason },
56    /// Inbound payload PDU.
57    Payload { from: Address, pdu: AkkaPdu },
58}
59
60impl AkkaProtocolTransport {
61    pub fn new(inner: Arc<dyn Transport>, settings: RemoteSettings, local_uid: AddressUid) -> Arc<Self> {
62        let (tx, rx) = mpsc::unbounded_channel();
63        Arc::new(Self {
64            inner,
65            settings,
66            local_uid,
67            local_address: Mutex::new(None),
68            peer_uids: DashMap::new(),
69            associated: DashMap::new(),
70            associate_replied: DashMap::new(),
71            inbound_tx: tx,
72            inbound_rx: Mutex::new(Some(rx)),
73            pump_started: Mutex::new(false),
74        })
75    }
76
77    pub fn local_address(&self) -> Option<Address> {
78        self.local_address.lock().clone()
79    }
80
81    pub fn settings(&self) -> &RemoteSettings {
82        &self.settings
83    }
84
85    pub fn local_uid(&self) -> u64 {
86        self.local_uid.get()
87    }
88
89    pub fn raw_transport(&self) -> Arc<dyn Transport> {
90        self.inner.clone()
91    }
92
93    /// Start listening on the underlying transport and begin pumping
94    /// inbound PDUs. Returns the local `Address`.
95    pub async fn start(self: &Arc<Self>) -> Result<Address, TransportError> {
96        let address = self.inner.listen().await?;
97        *self.local_address.lock() = Some(address.clone());
98        self.start_pump();
99        Ok(address)
100    }
101
102    fn start_pump(self: &Arc<Self>) {
103        let mut started = self.pump_started.lock();
104        if *started {
105            return;
106        }
107        *started = true;
108        drop(started);
109
110        let this = self.clone();
111        let mut inbound = self.inner.inbound();
112        tokio::spawn(async move {
113            while let Some(frame) = inbound.recv().await {
114                this.dispatch_frame(frame).await;
115            }
116        });
117    }
118
119    async fn dispatch_frame(&self, frame: InboundFrame) {
120        match frame.pdu {
121            AkkaPdu::Associate(info) => {
122                if info.protocol_version != PROTOCOL_VERSION {
123                    let _ = self
124                        .inner
125                        .send(
126                            &info.origin,
127                            AkkaPdu::Disassociate(DisassociateReason::HandshakeFailure(format!(
128                                "protocol version mismatch: peer={}, local={}",
129                                info.protocol_version, PROTOCOL_VERSION
130                            ))),
131                        )
132                        .await;
133                    return;
134                }
135                if self.settings.require_cookie.is_some() && self.settings.require_cookie != info.cookie {
136                    let _ = self
137                        .inner
138                        .send(
139                            &info.origin,
140                            AkkaPdu::Disassociate(DisassociateReason::HandshakeFailure(
141                                "cookie mismatch".into(),
142                            )),
143                        )
144                        .await;
145                    return;
146                }
147                let key = info.origin.to_string();
148                if let Some(prev) = self.peer_uids.insert(key.clone(), info.uid) {
149                    if prev != info.uid && info.uid != 0 {
150                        let _ = self.inbound_tx.send(ProtocolEvent::Disassociated {
151                            peer: info.origin.clone(),
152                            reason: DisassociateReason::Quarantined,
153                        });
154                    }
155                }
156                self.associated.insert(key.clone(), ());
157
158                // Reply with our own Associate so the initiator's pump
159                // can also flip to Connected. The reply travels back
160                // over the same TCP socket pair (the underlying
161                // transport keys peers by Address).
162                if self.associate_replied.insert(key.clone(), ()).is_none() {
163                    let local = self.local_address.lock().clone();
164                    if let Some(local) = local {
165                        let reply = AkkaPdu::Associate(AssociateInfo {
166                            origin: local,
167                            uid: self.local_uid.get(),
168                            cookie: self.settings.require_cookie.clone(),
169                            protocol_version: PROTOCOL_VERSION,
170                        });
171                        let _ = self.inner.send(&info.origin, reply).await;
172                    }
173                }
174
175                let _ = self.inbound_tx.send(ProtocolEvent::Associated(PeerAssociation {
176                    address: info.origin.clone(),
177                    uid: info.uid,
178                }));
179            }
180            AkkaPdu::Disassociate(reason) => {
181                let key = frame.from.to_string();
182                self.associated.remove(&key);
183                self.peer_uids.remove(&key);
184                let _ = self.inbound_tx.send(ProtocolEvent::Disassociated { peer: frame.from, reason });
185            }
186            AkkaPdu::Heartbeat => {
187                // Liveness only; nothing to do at protocol layer.
188            }
189            other => {
190                let _ = self.inbound_tx.send(ProtocolEvent::Payload { from: frame.from, pdu: other });
191            }
192        }
193    }
194
195    /// Initiate an outbound association: open the underlying transport,
196    /// send our `Associate` PDU, and let the inbound pump complete the
197    /// handshake.
198    pub async fn associate(
199        self: &Arc<Self>,
200        target: &Address,
201        local_address: &Address,
202    ) -> Result<(), TransportError> {
203        self.start_pump();
204        self.inner.associate(target).await?;
205        let pdu = AkkaPdu::Associate(AssociateInfo {
206            origin: local_address.clone(),
207            uid: self.local_uid.get(),
208            cookie: self.settings.require_cookie.clone(),
209            protocol_version: PROTOCOL_VERSION,
210        });
211        self.inner.send(target, pdu).await?;
212        Ok(())
213    }
214
215    pub async fn send_pdu(&self, target: &Address, pdu: AkkaPdu) -> Result<(), TransportError> {
216        self.inner.send(target, pdu).await
217    }
218
219    pub async fn disassociate(
220        &self,
221        target: &Address,
222        reason: DisassociateReason,
223    ) -> Result<(), TransportError> {
224        let _ = self.inner.send(target, AkkaPdu::Disassociate(reason)).await;
225        let _ = self.inner.disassociate(target).await;
226        self.associated.remove(&target.to_string());
227        self.peer_uids.remove(&target.to_string());
228        Ok(())
229    }
230
231    pub fn events(&self) -> mpsc::UnboundedReceiver<ProtocolEvent> {
232        self.inbound_rx.lock().take().unwrap_or_else(|| {
233            let (_t, r) = mpsc::unbounded_channel();
234            r
235        })
236    }
237
238    pub fn is_associated(&self, address: &Address) -> bool {
239        self.associated.contains_key(&address.to_string())
240    }
241}
242
243#[async_trait]
244impl Transport for AkkaProtocolTransport {
245    async fn listen(&self) -> Result<Address, TransportError> {
246        self.inner.listen().await
247    }
248
249    async fn associate(&self, target: &Address) -> Result<(), TransportError> {
250        self.inner.associate(target).await
251    }
252
253    async fn send(&self, target: &Address, pdu: AkkaPdu) -> Result<(), TransportError> {
254        self.inner.send(target, pdu).await
255    }
256
257    fn inbound(&self) -> mpsc::UnboundedReceiver<InboundFrame> {
258        self.inner.inbound()
259    }
260
261    async fn disassociate(&self, target: &Address) -> Result<(), TransportError> {
262        self.inner.disassociate(target).await
263    }
264
265    async fn shutdown(&self) -> Result<(), TransportError> {
266        self.inner.shutdown().await
267    }
268}