Skip to main content

atomr_remote_serial/
transport.rs

1//! [`SerialTransport`] — a `Transport` over a USB-attached serial port.
2//!
//! Symmetric across both ends of the cable: on either end, `listen()`
//! opens the configured device path and keeps it open, while
//! `associate()` is a no-op. There is exactly one peer per cable, so a
//! single bidirectional link carries every PDU. Re-attached cables and
//! gadget reboots are handled inside the transport via [`ReconnectPolicy`].
8
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11use std::time::Duration;
12
13use async_trait::async_trait;
14use parking_lot::Mutex;
15use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
16use tokio::sync::{mpsc, Notify};
17use tokio_serial::{SerialPortBuilderExt, SerialStream};
18
19use atomr_core::actor::Address;
20use atomr_remote::codec::{read_frame, write_frame};
21use atomr_remote::pdu::{AkkaPdu, AssociateInfo, DisassociateReason};
22use atomr_remote::transport::{InboundFrame, Transport, TransportError};
23
24use crate::reconnect::ReconnectPolicy;
25
/// Baud rate used by [`SerialTransport::new`] (115200). Also stored as a
/// placeholder by `with_streams`, where no serial device is opened.
const DEFAULT_BAUD: u32 = 115_200;
/// Maximum frame size used by [`SerialTransport::new`] (4 MiB).
const DEFAULT_MAX_FRAME: usize = 4 * 1024 * 1024;
28
/// Frame-oriented `Transport` over a USB-derived serial endpoint.
///
/// Build with [`SerialTransport::new`] / [`SerialTransport::with_options`]
/// for a real device, which is opened once `listen()` is called. Use
/// [`SerialTransport::with_streams`] to run over caller-supplied async
/// byte streams instead.
pub struct SerialTransport {
    /// Actor-system name embedded in the advertised `Address`.
    system_name: String,
    /// Device path opened by the supervisor; `"<streams>"` placeholder for
    /// transports built with `with_streams`.
    device: PathBuf,
    /// Baud rate passed to `tokio_serial::new` when opening `device`.
    baud: u32,
    /// Limit handed to `read_frame` / `write_frame` for a single frame.
    max_frame_size: usize,
    /// State shared with the supervisor and link tasks.
    state: Arc<SharedState>,
    /// Sending side of the inbound queue; cloned into each link's reader.
    inbound_tx: mpsc::UnboundedSender<InboundFrame>,
    /// Receiving side of the inbound queue; handed out once by `inbound()`.
    inbound_rx: Mutex<Option<mpsc::UnboundedReceiver<InboundFrame>>>,
    /// Notified by `shutdown()` to stop the supervisor and link readers.
    shutdown: Arc<Notify>,
    /// Reconnect behaviour for the supervisor; `ReconnectPolicy::never()`
    /// for transports built with `with_streams`.
    reconnect_policy: ReconnectPolicy,
}
41
/// State shared between the public API and the background tasks.
struct SharedState {
    /// Outbound mpsc to the active link's writer task. `None` while the
    /// link is down (during initial connect or between reconnect
    /// attempts) and after `shutdown()`. `send()` clones the sender out
    /// of the lock before using it.
    sender: Mutex<Option<mpsc::UnboundedSender<AkkaPdu>>>,
    /// Address this transport advertises (the value `listen()` returns).
    /// Set by `listen()`, or at construction by `with_streams`. The link
    /// reader also uses it as a fallback `from:` until the peer is known.
    local_address: Mutex<Option<Address>>,
    /// Address the peer claimed in its most recent `Associate` PDU. Used
    /// as `from:` for inbound frames once known; cleared when a link ends.
    peer_address: Mutex<Option<Address>>,
}
55
56impl SerialTransport {
57    /// Open `device` on a `system_name`-tagged transport with the
58    /// default baud rate (115200) and 4 MiB max frame size.
59    pub fn new(system_name: impl Into<String>, device: impl Into<PathBuf>) -> Self {
60        Self::with_options(system_name, device, DEFAULT_BAUD, DEFAULT_MAX_FRAME, ReconnectPolicy::default())
61    }
62
63    /// Construct with explicit baud, max frame size, and reconnect
64    /// policy. Baud is ignored for true USB CDC-ACM endpoints (the
65    /// rate is set by the link, not the user) but retained for true
66    /// UARTs over USB-to-serial dongles.
67    pub fn with_options(
68        system_name: impl Into<String>,
69        device: impl Into<PathBuf>,
70        baud: u32,
71        max_frame_size: usize,
72        reconnect_policy: ReconnectPolicy,
73    ) -> Self {
74        let (tx, rx) = mpsc::unbounded_channel();
75        Self {
76            system_name: system_name.into(),
77            device: device.into(),
78            baud,
79            max_frame_size,
80            state: Arc::new(SharedState {
81                sender: Mutex::new(None),
82                local_address: Mutex::new(None),
83                peer_address: Mutex::new(None),
84            }),
85            inbound_tx: tx,
86            inbound_rx: Mutex::new(Some(rx)),
87            shutdown: Arc::new(Notify::new()),
88            reconnect_policy,
89        }
90    }
91
92    /// The local `Address` returned by `listen()`. `None` until then.
93    pub fn local_address(&self) -> Option<Address> {
94        self.state.local_address.lock().clone()
95    }
96
97    /// Drive the transport over caller-supplied byte-stream halves
98    /// instead of opening a serial device. Useful for testing (with
99    /// [`tokio::io::duplex`]) and for layering the Akka protocol over
100    /// custom byte pipes (Unix sockets, SSH-tunneled streams, raw fds
101    /// from external tools). No reconnect is attempted; if the streams
102    /// close, the transport stays closed until shutdown.
103    pub fn with_streams<R, W>(
104        system_name: impl Into<String>,
105        reader: R,
106        writer: W,
107        max_frame_size: usize,
108    ) -> Self
109    where
110        R: AsyncRead + Unpin + Send + 'static,
111        W: AsyncWrite + Unpin + Send + 'static,
112    {
113        let (tx, rx) = mpsc::unbounded_channel();
114        let state = Arc::new(SharedState {
115            sender: Mutex::new(None),
116            local_address: Mutex::new(None),
117            peer_address: Mutex::new(None),
118        });
119        let shutdown = Arc::new(Notify::new());
120        let this = Self {
121            system_name: system_name.into(),
122            device: PathBuf::from("<streams>"),
123            baud: DEFAULT_BAUD,
124            max_frame_size,
125            state: state.clone(),
126            inbound_tx: tx.clone(),
127            inbound_rx: Mutex::new(Some(rx)),
128            shutdown: shutdown.clone(),
129            reconnect_policy: ReconnectPolicy::never(),
130        };
131        let address = Address::remote("akka.serial", &this.system_name, "<streams>", 0);
132        *state.local_address.lock() = Some(address);
133
134        // Pre-create the outbound channel so send() works immediately
135        // — the link runner takes the rx half.
136        let (out_tx, out_rx) = mpsc::unbounded_channel::<AkkaPdu>();
137        *state.sender.lock() = Some(out_tx);
138
139        tokio::spawn(run_link_halves_with_outbound(
140            reader,
141            writer,
142            out_rx,
143            max_frame_size,
144            state,
145            tx,
146            shutdown,
147        ));
148        this
149    }
150}
151
152#[async_trait]
153impl Transport for SerialTransport {
154    async fn listen(&self) -> Result<Address, TransportError> {
155        let device_str = self.device.to_string_lossy().into_owned();
156        let address = Address::remote("akka.serial", &self.system_name, device_str, 0);
157        *self.state.local_address.lock() = Some(address.clone());
158
159        // Spawn the supervisor that owns the open/reader/writer/reconnect
160        // lifecycle. It will retry until shutdown if the device isn't
161        // present yet — that's expected when the gadget side boots
162        // before the host side or vice versa.
163        spawn_supervisor(
164            self.device.clone(),
165            self.baud,
166            self.max_frame_size,
167            self.state.clone(),
168            self.inbound_tx.clone(),
169            self.shutdown.clone(),
170            self.reconnect_policy.clone(),
171        );
172        Ok(address)
173    }
174
175    async fn associate(&self, _target: &Address) -> Result<(), TransportError> {
176        // No-op: serial is one-peer-per-cable; the supervisor opens the
177        // device on `listen()` and keeps it open. The protocol layer
178        // will hand us frames to send via `send()`; if the link is
179        // currently down those return `Closed` and the protocol layer
180        // retries.
181        Ok(())
182    }
183
184    async fn send(&self, _target: &Address, pdu: AkkaPdu) -> Result<(), TransportError> {
185        let sender = self.state.sender.lock().clone();
186        match sender {
187            Some(tx) => tx.send(pdu).map_err(|_| TransportError::Closed),
188            None => Err(TransportError::Closed),
189        }
190    }
191
192    fn inbound(&self) -> mpsc::UnboundedReceiver<InboundFrame> {
193        self.inbound_rx.lock().take().unwrap_or_else(|| {
194            let (_tx, rx) = mpsc::unbounded_channel();
195            rx
196        })
197    }
198
199    async fn disassociate(&self, _target: &Address) -> Result<(), TransportError> {
200        if let Some(tx) = self.state.sender.lock().clone() {
201            let _ = tx.send(AkkaPdu::Disassociate(DisassociateReason::Normal));
202        }
203        Ok(())
204    }
205
206    async fn shutdown(&self) -> Result<(), TransportError> {
207        self.shutdown.notify_waiters();
208        *self.state.sender.lock() = None;
209        Ok(())
210    }
211}
212
213fn spawn_supervisor(
214    device: PathBuf,
215    baud: u32,
216    max_frame: usize,
217    state: Arc<SharedState>,
218    inbound: mpsc::UnboundedSender<InboundFrame>,
219    shutdown: Arc<Notify>,
220    policy: ReconnectPolicy,
221) {
222    tokio::spawn(async move {
223        let mut delay = policy.initial;
224        loop {
225            // Race the open against shutdown.
226            let opened = tokio::select! {
227                _ = shutdown.notified() => return,
228                result = open_device(&device, baud) => result,
229            };
230            match opened {
231                Ok(stream) => {
232                    delay = policy.initial;
233                    run_link(stream, max_frame, state.clone(), inbound.clone(), shutdown.clone()).await;
234                    if !policy.is_enabled() {
235                        return;
236                    }
237                    tracing::warn!(device = %device.display(), "serial link dropped, reconnecting");
238                }
239                Err(e) => {
240                    if !policy.is_enabled() {
241                        tracing::warn!(device = %device.display(), error = %e, "serial open failed, reconnect disabled");
242                        return;
243                    }
244                    tracing::debug!(device = %device.display(), error = %e, "serial open failed, will retry");
245                }
246            }
247
248            tokio::select! {
249                _ = shutdown.notified() => return,
250                _ = tokio::time::sleep(delay) => {}
251            }
252            delay = policy.next_delay(delay.max(Duration::from_millis(1)));
253        }
254    });
255}
256
257async fn open_device(device: &Path, baud: u32) -> Result<SerialStream, std::io::Error> {
258    tokio_serial::new(device.to_string_lossy(), baud).open_native_async().map_err(io_from_serial)
259}
260
261fn io_from_serial(e: tokio_serial::Error) -> std::io::Error {
262    match e.kind {
263        tokio_serial::ErrorKind::NoDevice => std::io::Error::new(std::io::ErrorKind::NotFound, e.description),
264        tokio_serial::ErrorKind::InvalidInput => {
265            std::io::Error::new(std::io::ErrorKind::InvalidInput, e.description)
266        }
267        _ => std::io::Error::other(e.description),
268    }
269}
270
271async fn run_link(
272    stream: SerialStream,
273    max_frame: usize,
274    state: Arc<SharedState>,
275    inbound: mpsc::UnboundedSender<InboundFrame>,
276    shutdown: Arc<Notify>,
277) {
278    let (reader, writer) = tokio::io::split(stream);
279    run_link_halves(reader, writer, max_frame, state, inbound, shutdown).await
280}
281
282async fn run_link_halves<R, W>(
283    reader: R,
284    writer: W,
285    max_frame: usize,
286    state: Arc<SharedState>,
287    inbound: mpsc::UnboundedSender<InboundFrame>,
288    shutdown: Arc<Notify>,
289) where
290    R: AsyncRead + Unpin + Send + 'static,
291    W: AsyncWrite + Unpin + Send + 'static,
292{
293    let (tx, rx) = mpsc::unbounded_channel::<AkkaPdu>();
294    *state.sender.lock() = Some(tx);
295    run_link_halves_with_outbound(reader, writer, rx, max_frame, state, inbound, shutdown).await
296}
297
298async fn run_link_halves_with_outbound<R, W>(
299    mut reader: R,
300    mut writer: W,
301    mut rx: mpsc::UnboundedReceiver<AkkaPdu>,
302    max_frame: usize,
303    state: Arc<SharedState>,
304    inbound: mpsc::UnboundedSender<InboundFrame>,
305    shutdown: Arc<Notify>,
306) where
307    R: AsyncRead + Unpin + Send + 'static,
308    W: AsyncWrite + Unpin + Send + 'static,
309{
310    // Writer task — drains the outbound mpsc onto the wire.
311    let writer_task = tokio::spawn(async move {
312        while let Some(pdu) = rx.recv().await {
313            if write_frame(&mut writer, &pdu, max_frame).await.is_err() {
314                break;
315            }
316            if matches!(pdu, AkkaPdu::Disassociate(_)) {
317                let _ = writer.shutdown().await;
318                break;
319            }
320        }
321    });
322
323    // Reader task — first frame must be an Associate so we learn the
324    // peer's Address; thereafter we attribute every frame to it.
325    let state_for_reader = state.clone();
326    let inbound_for_reader = inbound.clone();
327    let shutdown_for_reader = shutdown.clone();
328    let reader_task = tokio::spawn(async move {
329        loop {
330            let read = tokio::select! {
331                _ = shutdown_for_reader.notified() => break,
332                r = read_frame(&mut reader, max_frame) => r,
333            };
334            let pdu = match read {
335                Ok(p) => p,
336                Err(_) => break,
337            };
338
339            // Stamp `from:` based on the peer's advertised Address.
340            // Until the peer's Associate arrives, fall back to the
341            // local Address — the protocol layer will treat the
342            // frame as informational; the first Associate fixes
343            // attribution for subsequent frames.
344            let from = if let AkkaPdu::Associate(AssociateInfo { origin, .. }) = &pdu {
345                *state_for_reader.peer_address.lock() = Some(origin.clone());
346                origin.clone()
347            } else {
348                state_for_reader
349                    .peer_address
350                    .lock()
351                    .clone()
352                    .or_else(|| state_for_reader.local_address.lock().clone())
353                    .unwrap_or_else(|| Address::local("?"))
354            };
355
356            if inbound_for_reader.send(InboundFrame { from, pdu }).is_err() {
357                break;
358            }
359        }
360    });
361
362    let _ = tokio::join!(writer_task, reader_task);
363    *state.sender.lock() = None;
364    *state.peer_address.lock() = None;
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// The address `listen()` builds for a device path must survive a
    /// `Display` → `Address::parse` round trip, with the path kept as the
    /// host and port 0 preserved.
    #[test]
    fn local_address_fields_round_trip_through_parse() {
        const DEVICE: &str = "/dev/ttyACM0";
        let original = Address::remote("akka.serial", "Sys", DEVICE, 0);
        let reparsed = Address::parse(&original.to_string()).expect("parse");
        assert_eq!(reparsed, original);
        assert_eq!(reparsed.host.as_deref(), Some(DEVICE));
        assert_eq!(reparsed.port, Some(0));
    }
}