1use std::path::{Path, PathBuf};
10use std::sync::Arc;
11use std::time::Duration;
12
13use async_trait::async_trait;
14use parking_lot::Mutex;
15use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
16use tokio::sync::{mpsc, Notify};
17use tokio_serial::{SerialPortBuilderExt, SerialStream};
18
19use atomr_core::actor::Address;
20use atomr_remote::codec::{read_frame, write_frame};
21use atomr_remote::pdu::{AkkaPdu, AssociateInfo, DisassociateReason};
22use atomr_remote::transport::{InboundFrame, Transport, TransportError};
23
24use crate::reconnect::ReconnectPolicy;
25
/// Baud rate used by `SerialTransport::new` when the caller does not choose one.
const DEFAULT_BAUD: u32 = 115_200;
/// Frame-size limit (4 MiB) used by `SerialTransport::new` when the caller does not choose one.
const DEFAULT_MAX_FRAME: usize = 4 << 20;
28
/// Akka-style remote transport that frames `AkkaPdu`s over a serial device, or over an
/// arbitrary reader/writer pair when built with `with_streams`.
pub struct SerialTransport {
    /// Actor-system name embedded in the addresses this transport reports.
    system_name: String,
    /// Serial device path; the placeholder `"<streams>"` for `with_streams` transports.
    device: PathBuf,
    /// Baud rate used when the supervisor opens `device`.
    baud: u32,
    /// Frame-size limit passed to the codec's `read_frame` / `write_frame`.
    max_frame_size: usize,
    /// Link state shared with the background supervisor and link tasks.
    state: Arc<SharedState>,
    /// Sending side of the inbound frame channel; cloned into each link's reader task.
    inbound_tx: mpsc::UnboundedSender<InboundFrame>,
    /// Receiving side of the inbound channel, handed out once by `Transport::inbound`.
    inbound_rx: Mutex<Option<mpsc::UnboundedReceiver<InboundFrame>>>,
    /// Signalled by `Transport::shutdown` to stop the supervisor and the link reader.
    shutdown: Arc<Notify>,
    /// Decides whether (and after what delay) the supervisor reopens the device.
    reconnect_policy: ReconnectPolicy,
}
41
/// Link state shared between the transport handle and its background link tasks.
struct SharedState {
    /// Outbound queue feeding the current link's writer task; `None` when no link is up.
    sender: Mutex<Option<mpsc::UnboundedSender<AkkaPdu>>>,
    /// Address recorded by `listen` / `with_streams`; also the fallback `from` address for
    /// inbound frames when no peer address is known.
    local_address: Mutex<Option<Address>>,
    /// Origin taken from the most recent inbound `Associate` on the current link; cleared
    /// when that link ends.
    peer_address: Mutex<Option<Address>>,
}
55
56impl SerialTransport {
57 pub fn new(system_name: impl Into<String>, device: impl Into<PathBuf>) -> Self {
60 Self::with_options(system_name, device, DEFAULT_BAUD, DEFAULT_MAX_FRAME, ReconnectPolicy::default())
61 }
62
63 pub fn with_options(
68 system_name: impl Into<String>,
69 device: impl Into<PathBuf>,
70 baud: u32,
71 max_frame_size: usize,
72 reconnect_policy: ReconnectPolicy,
73 ) -> Self {
74 let (tx, rx) = mpsc::unbounded_channel();
75 Self {
76 system_name: system_name.into(),
77 device: device.into(),
78 baud,
79 max_frame_size,
80 state: Arc::new(SharedState {
81 sender: Mutex::new(None),
82 local_address: Mutex::new(None),
83 peer_address: Mutex::new(None),
84 }),
85 inbound_tx: tx,
86 inbound_rx: Mutex::new(Some(rx)),
87 shutdown: Arc::new(Notify::new()),
88 reconnect_policy,
89 }
90 }
91
92 pub fn local_address(&self) -> Option<Address> {
94 self.state.local_address.lock().clone()
95 }
96
97 pub fn with_streams<R, W>(
104 system_name: impl Into<String>,
105 reader: R,
106 writer: W,
107 max_frame_size: usize,
108 ) -> Self
109 where
110 R: AsyncRead + Unpin + Send + 'static,
111 W: AsyncWrite + Unpin + Send + 'static,
112 {
113 let (tx, rx) = mpsc::unbounded_channel();
114 let state = Arc::new(SharedState {
115 sender: Mutex::new(None),
116 local_address: Mutex::new(None),
117 peer_address: Mutex::new(None),
118 });
119 let shutdown = Arc::new(Notify::new());
120 let this = Self {
121 system_name: system_name.into(),
122 device: PathBuf::from("<streams>"),
123 baud: DEFAULT_BAUD,
124 max_frame_size,
125 state: state.clone(),
126 inbound_tx: tx.clone(),
127 inbound_rx: Mutex::new(Some(rx)),
128 shutdown: shutdown.clone(),
129 reconnect_policy: ReconnectPolicy::never(),
130 };
131 let address = Address::remote("akka.serial", &this.system_name, "<streams>", 0);
132 *state.local_address.lock() = Some(address);
133
134 let (out_tx, out_rx) = mpsc::unbounded_channel::<AkkaPdu>();
137 *state.sender.lock() = Some(out_tx);
138
139 tokio::spawn(run_link_halves_with_outbound(
140 reader,
141 writer,
142 out_rx,
143 max_frame_size,
144 state,
145 tx,
146 shutdown,
147 ));
148 this
149 }
150}
151
152#[async_trait]
153impl Transport for SerialTransport {
154 async fn listen(&self) -> Result<Address, TransportError> {
155 let device_str = self.device.to_string_lossy().into_owned();
156 let address = Address::remote("akka.serial", &self.system_name, device_str, 0);
157 *self.state.local_address.lock() = Some(address.clone());
158
159 spawn_supervisor(
164 self.device.clone(),
165 self.baud,
166 self.max_frame_size,
167 self.state.clone(),
168 self.inbound_tx.clone(),
169 self.shutdown.clone(),
170 self.reconnect_policy.clone(),
171 );
172 Ok(address)
173 }
174
175 async fn associate(&self, _target: &Address) -> Result<(), TransportError> {
176 Ok(())
182 }
183
184 async fn send(&self, _target: &Address, pdu: AkkaPdu) -> Result<(), TransportError> {
185 let sender = self.state.sender.lock().clone();
186 match sender {
187 Some(tx) => tx.send(pdu).map_err(|_| TransportError::Closed),
188 None => Err(TransportError::Closed),
189 }
190 }
191
192 fn inbound(&self) -> mpsc::UnboundedReceiver<InboundFrame> {
193 self.inbound_rx.lock().take().unwrap_or_else(|| {
194 let (_tx, rx) = mpsc::unbounded_channel();
195 rx
196 })
197 }
198
199 async fn disassociate(&self, _target: &Address) -> Result<(), TransportError> {
200 if let Some(tx) = self.state.sender.lock().clone() {
201 let _ = tx.send(AkkaPdu::Disassociate(DisassociateReason::Normal));
202 }
203 Ok(())
204 }
205
206 async fn shutdown(&self) -> Result<(), TransportError> {
207 self.shutdown.notify_waiters();
208 *self.state.sender.lock() = None;
209 Ok(())
210 }
211}
212
213fn spawn_supervisor(
214 device: PathBuf,
215 baud: u32,
216 max_frame: usize,
217 state: Arc<SharedState>,
218 inbound: mpsc::UnboundedSender<InboundFrame>,
219 shutdown: Arc<Notify>,
220 policy: ReconnectPolicy,
221) {
222 tokio::spawn(async move {
223 let mut delay = policy.initial;
224 loop {
225 let opened = tokio::select! {
227 _ = shutdown.notified() => return,
228 result = open_device(&device, baud) => result,
229 };
230 match opened {
231 Ok(stream) => {
232 delay = policy.initial;
233 run_link(stream, max_frame, state.clone(), inbound.clone(), shutdown.clone()).await;
234 if !policy.is_enabled() {
235 return;
236 }
237 tracing::warn!(device = %device.display(), "serial link dropped, reconnecting");
238 }
239 Err(e) => {
240 if !policy.is_enabled() {
241 tracing::warn!(device = %device.display(), error = %e, "serial open failed, reconnect disabled");
242 return;
243 }
244 tracing::debug!(device = %device.display(), error = %e, "serial open failed, will retry");
245 }
246 }
247
248 tokio::select! {
249 _ = shutdown.notified() => return,
250 _ = tokio::time::sleep(delay) => {}
251 }
252 delay = policy.next_delay(delay.max(Duration::from_millis(1)));
253 }
254 });
255}
256
257async fn open_device(device: &Path, baud: u32) -> Result<SerialStream, std::io::Error> {
258 tokio_serial::new(device.to_string_lossy(), baud).open_native_async().map_err(io_from_serial)
259}
260
261fn io_from_serial(e: tokio_serial::Error) -> std::io::Error {
262 match e.kind {
263 tokio_serial::ErrorKind::NoDevice => std::io::Error::new(std::io::ErrorKind::NotFound, e.description),
264 tokio_serial::ErrorKind::InvalidInput => {
265 std::io::Error::new(std::io::ErrorKind::InvalidInput, e.description)
266 }
267 _ => std::io::Error::other(e.description),
268 }
269}
270
271async fn run_link(
272 stream: SerialStream,
273 max_frame: usize,
274 state: Arc<SharedState>,
275 inbound: mpsc::UnboundedSender<InboundFrame>,
276 shutdown: Arc<Notify>,
277) {
278 let (reader, writer) = tokio::io::split(stream);
279 run_link_halves(reader, writer, max_frame, state, inbound, shutdown).await
280}
281
282async fn run_link_halves<R, W>(
283 reader: R,
284 writer: W,
285 max_frame: usize,
286 state: Arc<SharedState>,
287 inbound: mpsc::UnboundedSender<InboundFrame>,
288 shutdown: Arc<Notify>,
289) where
290 R: AsyncRead + Unpin + Send + 'static,
291 W: AsyncWrite + Unpin + Send + 'static,
292{
293 let (tx, rx) = mpsc::unbounded_channel::<AkkaPdu>();
294 *state.sender.lock() = Some(tx);
295 run_link_halves_with_outbound(reader, writer, rx, max_frame, state, inbound, shutdown).await
296}
297
298async fn run_link_halves_with_outbound<R, W>(
299 mut reader: R,
300 mut writer: W,
301 mut rx: mpsc::UnboundedReceiver<AkkaPdu>,
302 max_frame: usize,
303 state: Arc<SharedState>,
304 inbound: mpsc::UnboundedSender<InboundFrame>,
305 shutdown: Arc<Notify>,
306) where
307 R: AsyncRead + Unpin + Send + 'static,
308 W: AsyncWrite + Unpin + Send + 'static,
309{
310 let writer_task = tokio::spawn(async move {
312 while let Some(pdu) = rx.recv().await {
313 if write_frame(&mut writer, &pdu, max_frame).await.is_err() {
314 break;
315 }
316 if matches!(pdu, AkkaPdu::Disassociate(_)) {
317 let _ = writer.shutdown().await;
318 break;
319 }
320 }
321 });
322
323 let state_for_reader = state.clone();
326 let inbound_for_reader = inbound.clone();
327 let shutdown_for_reader = shutdown.clone();
328 let reader_task = tokio::spawn(async move {
329 loop {
330 let read = tokio::select! {
331 _ = shutdown_for_reader.notified() => break,
332 r = read_frame(&mut reader, max_frame) => r,
333 };
334 let pdu = match read {
335 Ok(p) => p,
336 Err(_) => break,
337 };
338
339 let from = if let AkkaPdu::Associate(AssociateInfo { origin, .. }) = &pdu {
345 *state_for_reader.peer_address.lock() = Some(origin.clone());
346 origin.clone()
347 } else {
348 state_for_reader
349 .peer_address
350 .lock()
351 .clone()
352 .or_else(|| state_for_reader.local_address.lock().clone())
353 .unwrap_or_else(|| Address::local("?"))
354 };
355
356 if inbound_for_reader.send(InboundFrame { from, pdu }).is_err() {
357 break;
358 }
359 }
360 });
361
362 let _ = tokio::join!(writer_task, reader_task);
363 *state.sender.lock() = None;
364 *state.peer_address.lock() = None;
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// A serial address whose host is a device path must survive `Display` followed by
    /// `Address::parse` unchanged.
    #[test]
    fn local_address_fields_round_trip_through_parse() {
        const DEVICE: &str = "/dev/ttyACM0";
        let original = Address::remote("akka.serial", "Sys", DEVICE, 0);
        let reparsed = Address::parse(&original.to_string()).expect("parse");

        assert_eq!(reparsed, original);
        assert_eq!(reparsed.host.as_deref(), Some(DEVICE));
        assert_eq!(reparsed.port, Some(0));
    }
}