// atomr_cluster/transport.rs

1//! Concrete `GossipTransport` implementations and a unified cluster
2//! frame that carries both gossip PDUs and Python-level remote-tells
3//! over the same underlying connection.
4//!
5//! Two transports ship here:
6//!
7//! * [`InProcessClusterTransport`] — channel-based, deterministic;
8//!   suitable for unit tests that need multiple `ActorSystem`s in one
9//!   process.
10//! * [`TcpClusterTransport`] — opens a TCP listener and connects out to
11//!   peers on demand. Frames are length-prefixed bincode-encoded
12//!   [`ClusterFrame`]s.
13//!
14//! The transports are not Python-aware. The [`RemoteMessageSink`] trait
15//! is implemented by the caller (e.g. the pycore binding) to receive
16//! `RemoteTell` frames and route them to the right local actor after
17//! decoding the payload through the codec registry.
18
19use std::collections::HashMap;
20use std::net::SocketAddr;
21use std::sync::Arc;
22
23use bincode::config::standard as bincode_cfg;
24use dashmap::DashMap;
25use parking_lot::Mutex;
26use serde::{Deserialize, Serialize};
27use tokio::io::{AsyncReadExt, AsyncWriteExt};
28use tokio::net::{TcpListener, TcpStream};
29use tokio::sync::{mpsc, Notify};
30
31use atomr_core::actor::Address;
32
33use crate::cluster_daemon::GossipTransport;
34use crate::gossip_pdu::GossipPdu;
35
36/// Wire-level frame used by both [`InProcessClusterTransport`] and
37/// [`TcpClusterTransport`]. The two variants are multiplexed over the
38/// same connection so that gossip and Python-level remote-tells share
39/// the same association.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub enum ClusterFrame {
42    /// Cluster-membership gossip PDU.
43    Gossip(GossipPdu),
44    /// Type-erased Python actor-message envelope. The receiving side
45    /// looks up `manifest` in the codec registry to decode `payload`.
46    RemoteTell { target_path: String, manifest: String, payload: Vec<u8>, sender_path: Option<String> },
47}
48
/// Sink for inbound `RemoteTell` frames. The pycore binding implements
/// this — typically by decoding the payload via the codec registry and
/// invoking `tell` on the matching local actor.
pub trait RemoteMessageSink: Send + Sync + 'static {
    /// Deliver a `RemoteTell` frame.
    ///
    /// * `target_path` — path of the destination actor.
    /// * `manifest` — codec-registry key used to decode `payload`.
    /// * `payload` — still-encoded, type-erased message bytes.
    /// * `sender_path` — path of the sending actor, if one was given.
    ///
    /// Errors must not crash the
    /// transport — the implementor is responsible for logging or
    /// dead-lettering.
    fn deliver(&self, target_path: &str, manifest: &str, payload: &[u8], sender_path: Option<&str>);
}
58
59// ---------------------------------------------------------------------------
60// In-process transport. Useful for deterministic multi-node tests.
61// ---------------------------------------------------------------------------
62
/// Shared registry that wires up [`InProcessClusterTransport`] siblings
/// in the same process. A single registry is created once per "logical
/// network" and handed to every transport that should be able to reach
/// every other.
#[derive(Default)]
pub struct InProcessRegistry {
    // Keyed by the stringified peer `Address`; values are the peers'
    // inbound frame senders.
    peers: DashMap<String, mpsc::UnboundedSender<ClusterFrame>>,
}
71
72impl InProcessRegistry {
73    pub fn new() -> Arc<Self> {
74        Arc::new(Self::default())
75    }
76
77    fn register(&self, addr: &Address, tx: mpsc::UnboundedSender<ClusterFrame>) {
78        self.peers.insert(addr.to_string(), tx);
79    }
80
81    fn unregister(&self, addr: &Address) {
82        self.peers.remove(&addr.to_string());
83    }
84
85    fn send(&self, target: &Address, frame: ClusterFrame) {
86        if let Some(tx) = self.peers.get(&target.to_string()) {
87            let _ = tx.send(frame);
88        }
89    }
90}
91
/// Channel-backed cluster transport. Discovers peers through a shared
/// [`InProcessRegistry`]. Construct one per node, register the daemon's
/// gossip inbox + a [`RemoteMessageSink`] via [`Self::start`], and call
/// [`Self::send_remote`] to push remote-tells.
pub struct InProcessClusterTransport {
    // Address this node registered under in the shared registry.
    self_addr: Address,
    registry: Arc<InProcessRegistry>,
    // Retained clone of the sender whose twin was handed to the
    // registry in `new`; keeps the channel constructible even if
    // `start` is never called.
    #[allow(dead_code)]
    inbound_tx: mpsc::UnboundedSender<ClusterFrame>,
    // Taken exactly once by `start`; `None` afterwards.
    inbound_rx: Mutex<Option<mpsc::UnboundedReceiver<ClusterFrame>>>,
}
103
104impl InProcessClusterTransport {
105    pub fn new(self_addr: Address, registry: Arc<InProcessRegistry>) -> Self {
106        let (tx, rx) = mpsc::unbounded_channel();
107        registry.register(&self_addr, tx.clone());
108        Self { self_addr, registry, inbound_tx: tx, inbound_rx: Mutex::new(Some(rx)) }
109    }
110
111    pub fn self_address(&self) -> &Address {
112        &self.self_addr
113    }
114
115    /// Send a `RemoteTell` frame to `target`. Drops silently if the
116    /// peer is not registered — matches the best-effort semantics of
117    /// [`GossipTransport::send`].
118    pub fn send_remote(
119        &self,
120        target: &Address,
121        target_path: String,
122        manifest: String,
123        payload: Vec<u8>,
124        sender_path: Option<String>,
125    ) {
126        self.registry.send(target, ClusterFrame::RemoteTell { target_path, manifest, payload, sender_path });
127    }
128
129    /// Spawn the inbound demultiplex task. `gossip_inbox` is the
130    /// daemon's [`crate::ClusterDaemonHandle::gossip_inbox`] sender;
131    /// `sink` receives `RemoteTell` frames.
132    pub fn start(&self, gossip_inbox: mpsc::UnboundedSender<GossipPdu>, sink: Arc<dyn RemoteMessageSink>) {
133        let mut rx = match self.inbound_rx.lock().take() {
134            Some(rx) => rx,
135            None => return,
136        };
137        tokio::spawn(async move {
138            while let Some(frame) = rx.recv().await {
139                match frame {
140                    ClusterFrame::Gossip(p) => {
141                        let _ = gossip_inbox.send(p);
142                    }
143                    ClusterFrame::RemoteTell { target_path, manifest, payload, sender_path } => {
144                        sink.deliver(&target_path, &manifest, &payload, sender_path.as_deref());
145                    }
146                }
147            }
148        });
149    }
150}
151
152impl GossipTransport for InProcessClusterTransport {
153    fn send(&self, target: &Address, pdu: GossipPdu) {
154        // Self-send is a no-op (consistent with the existing in-mem test
155        // network). The daemon never picks itself as a gossip target, but
156        // be defensive.
157        if target == &self.self_addr {
158            return;
159        }
160        self.registry.send(target, ClusterFrame::Gossip(pdu));
161    }
162}
163
impl Drop for InProcessClusterTransport {
    fn drop(&mut self) {
        // Leave the logical network so siblings stop routing frames
        // to this (now dead) transport.
        self.registry.unregister(&self.self_addr);
    }
}
169
170// ---------------------------------------------------------------------------
171// TCP transport.
172// ---------------------------------------------------------------------------
173
/// TCP-based cluster transport. One listener per node accepts inbound
/// connections; outbound connections are opened on demand and reused
/// per peer address. Frames are length-prefixed (4-byte big-endian)
/// bincode-encoded [`ClusterFrame`]s.
pub struct TcpClusterTransport {
    // Logical address of this node; its `system` names the cluster node.
    self_addr: Address,
    // Socket to bind the listener to; may use port 0 for auto-alloc.
    bind: SocketAddr,
    // Hostname advertised instead of the bound IP, when set.
    advertised_host: Option<String>,
    // Cache of outbound connections, keyed by stringified peer address.
    peers: Arc<DashMap<String, mpsc::UnboundedSender<ClusterFrame>>>,
    // Sender cloned into every accepted-socket task.
    inbound_tx: mpsc::UnboundedSender<ClusterFrame>,
    // Taken once by `take_inbound`/`start`; `None` afterwards.
    inbound_rx: Mutex<Option<mpsc::UnboundedReceiver<ClusterFrame>>>,
    // Signals the listener task to exit.
    shutdown: Arc<Notify>,
    // Actually-bound listener socket, recorded by `listen`.
    listen_addr: Mutex<Option<SocketAddr>>,
}
188
189impl TcpClusterTransport {
190    /// Build a new TCP transport. The system name is taken from
191    /// `self_addr.system`; the bind socket is given separately because
192    /// `Address` doesn't carry a port until `listen` resolves it.
193    pub fn new(self_addr: Address, bind: SocketAddr) -> Self {
194        Self::with_advertised(self_addr, bind, None)
195    }
196
197    pub fn with_advertised(self_addr: Address, bind: SocketAddr, advertised_host: Option<String>) -> Self {
198        let (tx, rx) = mpsc::unbounded_channel();
199        Self {
200            self_addr,
201            bind,
202            advertised_host,
203            peers: Arc::new(DashMap::new()),
204            inbound_tx: tx,
205            inbound_rx: Mutex::new(Some(rx)),
206            shutdown: Arc::new(Notify::new()),
207            listen_addr: Mutex::new(None),
208        }
209    }
210
211    /// Listen on the configured bind address. The returned `Address`
212    /// reflects the actually-bound socket (so callers that pass
213    /// `0.0.0.0:0` learn the auto-allocated port). The protocol
214    /// scheme is forced to `akka.tcp` since the resolved address
215    /// represents a real TCP listener.
216    pub async fn listen(&self) -> std::io::Result<Address> {
217        let listener = TcpListener::bind(self.bind).await?;
218        let bound = listener.local_addr()?;
219        *self.listen_addr.lock() = Some(bound);
220        let host = self.advertised_host.clone().unwrap_or_else(|| bound.ip().to_string());
221        let resolved = Address::remote("akka.tcp", self.self_addr.system.clone(), host, bound.port());
222
223        let inbound = self.inbound_tx.clone();
224        let shutdown = self.shutdown.clone();
225        tokio::spawn(async move {
226            loop {
227                tokio::select! {
228                    _ = shutdown.notified() => break,
229                    accept = listener.accept() => {
230                        let Ok((sock, _)) = accept else { continue };
231                        let _ = sock.set_nodelay(true);
232                        let inbound = inbound.clone();
233                        tokio::spawn(handle_inbound_socket(sock, inbound));
234                    }
235                }
236            }
237        });
238        Ok(resolved)
239    }
240
241    pub fn local_address(&self) -> Option<SocketAddr> {
242        *self.listen_addr.lock()
243    }
244
245    /// Hand the inbound receiver out (call once). Subsequent calls
246    /// return an empty channel.
247    pub fn take_inbound(&self) -> mpsc::UnboundedReceiver<ClusterFrame> {
248        self.inbound_rx.lock().take().unwrap_or_else(|| mpsc::unbounded_channel().1)
249    }
250
251    /// Spawn the inbound demultiplex task. Mirrors
252    /// [`InProcessClusterTransport::start`].
253    pub fn start(&self, gossip_inbox: mpsc::UnboundedSender<GossipPdu>, sink: Arc<dyn RemoteMessageSink>) {
254        let mut rx = self.take_inbound();
255        tokio::spawn(async move {
256            while let Some(frame) = rx.recv().await {
257                match frame {
258                    ClusterFrame::Gossip(p) => {
259                        let _ = gossip_inbox.send(p);
260                    }
261                    ClusterFrame::RemoteTell { target_path, manifest, payload, sender_path } => {
262                        sink.deliver(&target_path, &manifest, &payload, sender_path.as_deref());
263                    }
264                }
265            }
266        });
267    }
268
269    /// Send a `RemoteTell` frame to `target`. Best-effort.
270    pub fn send_remote(
271        &self,
272        target: &Address,
273        target_path: String,
274        manifest: String,
275        payload: Vec<u8>,
276        sender_path: Option<String>,
277    ) {
278        let frame = ClusterFrame::RemoteTell { target_path, manifest, payload, sender_path };
279        let target = target.clone();
280        let peers = self.peers.clone();
281        tokio::spawn(async move {
282            send_via_tcp(target, frame, peers).await;
283        });
284    }
285
286    pub async fn shutdown(&self) {
287        self.shutdown.notify_waiters();
288        self.peers.clear();
289    }
290}
291
292impl GossipTransport for TcpClusterTransport {
293    fn send(&self, target: &Address, pdu: GossipPdu) {
294        if target == &self.self_addr {
295            return;
296        }
297        let frame = ClusterFrame::Gossip(pdu);
298        let target = target.clone();
299        let peers = self.peers.clone();
300        tokio::spawn(async move {
301            send_via_tcp(target, frame, peers).await;
302        });
303    }
304}
305
306async fn send_via_tcp(
307    target: Address,
308    frame: ClusterFrame,
309    peers: Arc<DashMap<String, mpsc::UnboundedSender<ClusterFrame>>>,
310) {
311    let key = target.to_string();
312    if let Some(tx) = peers.get(&key) {
313        let _ = tx.send(frame);
314        return;
315    }
316    // Otherwise open a new connection and remember it.
317    let host = match target.host.as_deref() {
318        Some(h) => h.to_string(),
319        None => return,
320    };
321    let port = match target.port {
322        Some(p) => p,
323        None => return,
324    };
325    let stream = match TcpStream::connect((host.as_str(), port)).await {
326        Ok(s) => s,
327        Err(_) => return,
328    };
329    let _ = stream.set_nodelay(true);
330    let (mut reader, mut writer) = stream.into_split();
331    let (tx, mut rx) = mpsc::unbounded_channel::<ClusterFrame>();
332    peers.insert(key.clone(), tx.clone());
333
334    let key_w = key.clone();
335    let peers_w = peers.clone();
336    tokio::spawn(async move {
337        while let Some(f) = rx.recv().await {
338            if write_frame(&mut writer, &f).await.is_err() {
339                break;
340            }
341        }
342        peers_w.remove(&key_w);
343    });
344
345    // Reader: outbound TCP also receives any reply frames the peer
346    // might choose to send back over the same socket. (We don't use
347    // this in practice — the peer's listener accepts a separate
348    // connection — but draining the half-open socket prevents weird
349    // EOF artefacts.)
350    tokio::spawn(async move {
351        let mut buf = Vec::new();
352        loop {
353            buf.clear();
354            if read_frame_into(&mut reader, &mut buf).await.is_err() {
355                break;
356            }
357        }
358    });
359
360    let _ = tx.send(frame);
361}
362
363async fn handle_inbound_socket(sock: TcpStream, inbound: mpsc::UnboundedSender<ClusterFrame>) {
364    let (mut reader, mut _writer) = sock.into_split();
365    let mut buf = Vec::new();
366    loop {
367        buf.clear();
368        if read_frame_into(&mut reader, &mut buf).await.is_err() {
369            break;
370        }
371        match bincode::serde::decode_from_slice::<ClusterFrame, _>(&buf, bincode_cfg()) {
372            Ok((frame, _)) => {
373                if inbound.send(frame).is_err() {
374                    break;
375                }
376            }
377            Err(_) => break,
378        }
379    }
380}
381
382async fn write_frame<W: AsyncWriteExt + Unpin>(writer: &mut W, frame: &ClusterFrame) -> std::io::Result<()> {
383    let bytes = bincode::serde::encode_to_vec(frame, bincode_cfg())
384        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
385    let len = (bytes.len() as u32).to_be_bytes();
386    writer.write_all(&len).await?;
387    writer.write_all(&bytes).await?;
388    writer.flush().await?;
389    Ok(())
390}
391
392async fn read_frame_into<R: AsyncReadExt + Unpin>(reader: &mut R, buf: &mut Vec<u8>) -> std::io::Result<()> {
393    let mut len_buf = [0u8; 4];
394    reader.read_exact(&mut len_buf).await?;
395    let len = u32::from_be_bytes(len_buf) as usize;
396    if len > 16 * 1024 * 1024 {
397        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "frame too large"));
398    }
399    buf.resize(len, 0);
400    reader.read_exact(buf).await?;
401    Ok(())
402}
403
404// ---------------------------------------------------------------------------
405// Convenience: in-memory sink that records frames for assertions.
406// ---------------------------------------------------------------------------
407
/// Test-only sink that buffers every received `RemoteTell` for later
/// inspection. Public so binding-side tests can use it.
#[derive(Default)]
pub struct RecordingSink {
    // Appended to by `deliver`; lock (and clone) to inspect in tests.
    pub records: Mutex<Vec<RemoteTellRecord>>,
}
414
/// Owned copy of one delivered `RemoteTell` frame, as captured by
/// [`RecordingSink`].
#[derive(Debug, Clone)]
pub struct RemoteTellRecord {
    pub target_path: String,
    pub manifest: String,
    pub payload: Vec<u8>,
    pub sender_path: Option<String>,
}
422
423impl RemoteMessageSink for RecordingSink {
424    fn deliver(&self, target_path: &str, manifest: &str, payload: &[u8], sender_path: Option<&str>) {
425        self.records.lock().push(RemoteTellRecord {
426            target_path: target_path.to_string(),
427            manifest: manifest.to_string(),
428            payload: payload.to_vec(),
429            sender_path: sender_path.map(|s| s.to_string()),
430        });
431    }
432}
433
// `HashMap` is currently referenced nowhere else in this file; this
// alias keeps the unused-import lint quiet without removing an import
// that other parts of the crate's tooling may expect here.
#[allow(dead_code)]
type _Hm = HashMap<(), ()>;
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector_clock::VectorClock;
    use std::time::Duration;

    fn local(name: &str) -> Address {
        Address::local(name)
    }

    /// Poll until `sink` has recorded at least one frame, or give up
    /// after ~1s. Replaces fixed sleeps, which are flaky on loaded CI
    /// machines; shared by the in-process and TCP delivery tests.
    async fn wait_for_record(sink: &RecordingSink) {
        for _ in 0..50 {
            if !sink.records.lock().is_empty() {
                return;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
    }

    #[tokio::test]
    async fn in_process_gossip_round_trip() {
        let net = InProcessRegistry::new();
        let a = Arc::new(InProcessClusterTransport::new(local("A"), net.clone()));
        let b = Arc::new(InProcessClusterTransport::new(local("B"), net.clone()));

        let (gossip_tx_b, mut gossip_rx_b) = mpsc::unbounded_channel();
        let sink: Arc<dyn RemoteMessageSink> = Arc::new(RecordingSink::default());
        b.start(gossip_tx_b, sink);

        a.send(&local("B"), GossipPdu::Status { from: "A".into(), version: VectorClock::new() });
        let pdu =
            tokio::time::timeout(Duration::from_millis(200), gossip_rx_b.recv()).await.unwrap().unwrap();
        assert!(matches!(pdu, GossipPdu::Status { .. }));
    }

    #[tokio::test]
    async fn in_process_remote_tell_delivered_to_sink() {
        let net = InProcessRegistry::new();
        let a = Arc::new(InProcessClusterTransport::new(local("A"), net.clone()));
        let b = Arc::new(InProcessClusterTransport::new(local("B"), net.clone()));

        let (gossip_tx, _gossip_rx) = mpsc::unbounded_channel();
        let sink = Arc::new(RecordingSink::default());
        let sink_dyn: Arc<dyn RemoteMessageSink> = sink.clone();
        b.start(gossip_tx, sink_dyn);

        a.send_remote(
            &local("B"),
            "akka://B/user/echo".into(),
            "json:Hello".into(),
            b"{\"name\":\"Ada\"}".to_vec(),
            None,
        );
        // Poll instead of a single fixed sleep: delivery happens on a
        // spawned task whose scheduling we don't control.
        wait_for_record(&sink).await;
        let recs = sink.records.lock().clone();
        assert_eq!(recs.len(), 1);
        assert_eq!(recs[0].target_path, "akka://B/user/echo");
        assert_eq!(recs[0].manifest, "json:Hello");
        assert_eq!(recs[0].payload, b"{\"name\":\"Ada\"}");
    }

    #[tokio::test]
    async fn tcp_round_trip_remote_tell() {
        let bind: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let a_addr = Address::remote("akka.tcp", "A", "127.0.0.1", 0);
        let b_addr_seed = Address::remote("akka.tcp", "B", "127.0.0.1", 0);
        let a = Arc::new(TcpClusterTransport::new(a_addr, bind));
        let b = Arc::new(TcpClusterTransport::new(b_addr_seed, bind));

        let resolved_b = b.listen().await.unwrap();
        let _resolved_a = a.listen().await.unwrap();

        let (gossip_tx, _gossip_rx) = mpsc::unbounded_channel();
        let sink = Arc::new(RecordingSink::default());
        let sink_dyn: Arc<dyn RemoteMessageSink> = sink.clone();
        b.start(gossip_tx, sink_dyn);

        a.send_remote(
            &resolved_b,
            format!("{}/user/echo", resolved_b),
            "json:Hello".into(),
            b"hi".to_vec(),
            None,
        );

        wait_for_record(&sink).await;
        let recs = sink.records.lock().clone();
        assert_eq!(recs.len(), 1, "expected one frame, got {recs:?}");
        assert_eq!(recs[0].manifest, "json:Hello");
        assert_eq!(recs[0].payload, b"hi");

        a.shutdown().await;
        b.shutdown().await;
    }
}
530}