use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use parking_lot::Mutex;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, Notify};
use tokio_serial::{SerialPortBuilderExt, SerialStream};
use atomr_core::actor::Address;
use atomr_remote::codec::{read_frame, write_frame};
use atomr_remote::pdu::{AkkaPdu, AssociateInfo, DisassociateReason};
use atomr_remote::transport::{InboundFrame, Transport, TransportError};
use crate::reconnect::ReconnectPolicy;
/// Baud rate used when the caller does not choose one (`SerialTransport::new` and
/// `SerialTransport::with_streams`).
const DEFAULT_BAUD: u32 = 115_200;
/// Frame-size limit used by `SerialTransport::new`: 4 * 1024 * 1024.
/// NOTE(review): presumably a byte count enforced by the frame codec — confirm there.
const DEFAULT_MAX_FRAME: usize = 4 * 1024 * 1024;
/// Point-to-point transport that exchanges framed `AkkaPdu`s over a serial device
/// (or over an arbitrary reader/writer pair, see `with_streams`).
pub struct SerialTransport {
// Actor-system name embedded in the `akka.serial` address this transport advertises.
system_name: String,
// Device path opened by the supervisor task; `<streams>` when built by `with_streams`.
device: PathBuf,
// Baud rate handed to `tokio_serial::new` when the device is opened.
baud: u32,
// Limit forwarded to `read_frame` / `write_frame` for every frame on the link.
max_frame_size: usize,
// State shared with the background link tasks (outbound sender, addresses).
state: Arc<SharedState>,
// Producer side of the inbound queue; cloned into each link's reader task.
inbound_tx: mpsc::UnboundedSender<InboundFrame>,
// Consumer side of the inbound queue; taken (left as `None`) by the first `inbound()` call.
inbound_rx: Mutex<Option<mpsc::UnboundedReceiver<InboundFrame>>>,
// Signalled by `shutdown()` so the background tasks stop.
shutdown: Arc<Notify>,
// Decides whether and how the supervisor retries after an open failure or a dropped link.
reconnect_policy: ReconnectPolicy,
}
/// Mutable state shared between a `SerialTransport` and its background tasks.
struct SharedState {
// Outbound queue of the currently running link; `None` while no link is up and after shutdown.
sender: Mutex<Option<mpsc::UnboundedSender<AkkaPdu>>>,
// Address advertised by this side; set by `listen()` or `with_streams`.
local_address: Mutex<Option<Address>>,
// Origin of the most recent `Associate` PDU received; cleared when the link ends.
peer_address: Mutex<Option<Address>>,
}
impl SerialTransport {
    /// Creates a transport for `device` using the default baud rate, the default
    /// frame-size limit and the default reconnect policy.
    ///
    /// Nothing is opened here; the device is only touched once `listen` is called.
    pub fn new(system_name: impl Into<String>, device: impl Into<PathBuf>) -> Self {
        Self::with_options(
            system_name,
            device,
            DEFAULT_BAUD,
            DEFAULT_MAX_FRAME,
            ReconnectPolicy::default(),
        )
    }

    /// Creates a transport with every tunable spelled out.
    ///
    /// Only the in-memory state (inbound queue, shared state, shutdown signal) is
    /// built; no task is spawned and the device is not opened.
    pub fn with_options(
        system_name: impl Into<String>,
        device: impl Into<PathBuf>,
        baud: u32,
        max_frame_size: usize,
        reconnect_policy: ReconnectPolicy,
    ) -> Self {
        let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
        let state = SharedState {
            sender: Mutex::new(None),
            local_address: Mutex::new(None),
            peer_address: Mutex::new(None),
        };
        Self {
            system_name: system_name.into(),
            device: device.into(),
            baud,
            max_frame_size,
            state: Arc::new(state),
            inbound_tx,
            inbound_rx: Mutex::new(Some(inbound_rx)),
            shutdown: Arc::new(Notify::new()),
            reconnect_policy,
        }
    }

    /// Returns the address recorded by `listen` or `with_streams`, if any.
    pub fn local_address(&self) -> Option<Address> {
        self.state.local_address.lock().clone()
    }

    /// Builds a transport that runs over an already-open reader/writer pair
    /// instead of a serial device.
    ///
    /// The link starts immediately: the local address and the outbound sender are
    /// installed and the link tasks are spawned before this function returns.
    /// Reconnection is disabled (`ReconnectPolicy::never()`) and the device path
    /// is the placeholder `<streams>`.
    ///
    /// # Panics
    ///
    /// Panics when called outside a Tokio runtime, because it calls `tokio::spawn`.
    pub fn with_streams<R, W>(
        system_name: impl Into<String>,
        reader: R,
        writer: W,
        max_frame_size: usize,
    ) -> Self
    where
        R: AsyncRead + Unpin + Send + 'static,
        W: AsyncWrite + Unpin + Send + 'static,
    {
        // Same in-memory setup as a device-backed transport, minus reconnection.
        let this = Self::with_options(
            system_name,
            "<streams>",
            DEFAULT_BAUD,
            max_frame_size,
            ReconnectPolicy::never(),
        );

        let local = Address::remote("akka.serial", &this.system_name, "<streams>", 0);
        *this.state.local_address.lock() = Some(local);

        // The outbound queue is installed before spawning so `send` works right away.
        let (outbound_tx, outbound_rx) = mpsc::unbounded_channel::<AkkaPdu>();
        *this.state.sender.lock() = Some(outbound_tx);

        tokio::spawn(run_link_halves_with_outbound(
            reader,
            writer,
            outbound_rx,
            max_frame_size,
            Arc::clone(&this.state),
            this.inbound_tx.clone(),
            Arc::clone(&this.shutdown),
        ));
        this
    }
}
#[async_trait]
impl Transport for SerialTransport {
async fn listen(&self) -> Result<Address, TransportError> {
let device_str = self.device.to_string_lossy().into_owned();
let address = Address::remote("akka.serial", &self.system_name, device_str, 0);
*self.state.local_address.lock() = Some(address.clone());
spawn_supervisor(
self.device.clone(),
self.baud,
self.max_frame_size,
self.state.clone(),
self.inbound_tx.clone(),
self.shutdown.clone(),
self.reconnect_policy.clone(),
);
Ok(address)
}
async fn associate(&self, _target: &Address) -> Result<(), TransportError> {
Ok(())
}
async fn send(&self, _target: &Address, pdu: AkkaPdu) -> Result<(), TransportError> {
let sender = self.state.sender.lock().clone();
match sender {
Some(tx) => tx.send(pdu).map_err(|_| TransportError::Closed),
None => Err(TransportError::Closed),
}
}
fn inbound(&self) -> mpsc::UnboundedReceiver<InboundFrame> {
self.inbound_rx.lock().take().unwrap_or_else(|| {
let (_tx, rx) = mpsc::unbounded_channel();
rx
})
}
async fn disassociate(&self, _target: &Address) -> Result<(), TransportError> {
if let Some(tx) = self.state.sender.lock().clone() {
let _ = tx.send(AkkaPdu::Disassociate(DisassociateReason::Normal));
}
Ok(())
}
async fn shutdown(&self) -> Result<(), TransportError> {
self.shutdown.notify_waiters();
*self.state.sender.lock() = None;
Ok(())
}
}
/// Spawns the task that owns the serial device: it opens the port, runs the link
/// until it ends, and — when `policy` allows — waits and tries again.
///
/// The task stops when `shutdown` is signalled, when an open fails or a link ends
/// while reconnection is disabled.
///
/// Shutdown is delivered with `Notify::notify_waiters`, which stores no permit: a
/// `notified()` future created *after* the call never completes. The task therefore
/// creates a single `Notified` future up front and polls that same future at every
/// wait point, so a shutdown issued while the link is running is still observed
/// afterwards instead of being lost (which previously let the task reconnect after
/// shutdown).
fn spawn_supervisor(
    device: PathBuf,
    baud: u32,
    max_frame: usize,
    state: Arc<SharedState>,
    inbound: mpsc::UnboundedSender<InboundFrame>,
    shutdown: Arc<Notify>,
    policy: ReconnectPolicy,
) {
    tokio::spawn(async move {
        // One long-lived future: it records any `notify_waiters` call made after this line.
        let stop = shutdown.notified();
        tokio::pin!(stop);

        let mut delay = policy.initial;
        loop {
            let opened = tokio::select! {
                _ = &mut stop => return,
                result = open_device(&device, baud) => result,
            };
            match opened {
                Ok(stream) => {
                    // A successful open resets the back-off.
                    delay = policy.initial;
                    run_link(stream, max_frame, state.clone(), inbound.clone(), shutdown.clone())
                        .await;
                    if !policy.is_enabled() {
                        return;
                    }
                    // If the link ended because of shutdown, stop quietly instead of
                    // announcing a reconnect that will never happen.
                    tokio::select! {
                        biased;
                        _ = &mut stop => return,
                        _ = std::future::ready(()) => {}
                    }
                    tracing::warn!(device = %device.display(), "serial link dropped, reconnecting");
                }
                Err(e) => {
                    if !policy.is_enabled() {
                        tracing::warn!(device = %device.display(), error = %e, "serial open failed, reconnect disabled");
                        return;
                    }
                    tracing::debug!(device = %device.display(), error = %e, "serial open failed, will retry");
                }
            }
            tokio::select! {
                _ = &mut stop => return,
                _ = tokio::time::sleep(delay) => {}
            }
            // Clamp to at least 1 ms so a zero delay can still grow.
            delay = policy.next_delay(delay.max(Duration::from_millis(1)));
        }
    });
}
/// Opens `device` as an asynchronous serial stream at `baud`.
///
/// The path is converted lossily to a string for `tokio_serial`; serial errors are
/// mapped to `std::io::Error` by `io_from_serial`.
async fn open_device(device: &Path, baud: u32) -> Result<SerialStream, std::io::Error> {
    let port_name = device.to_string_lossy();
    let builder = tokio_serial::new(port_name, baud);
    builder.open_native_async().map_err(io_from_serial)
}
/// Converts a `tokio_serial::Error` into a `std::io::Error`, keeping the
/// description as the message.
///
/// `NoDevice` becomes `NotFound`, `InvalidInput` stays `InvalidInput`, and every
/// other kind becomes `Other`.
fn io_from_serial(e: tokio_serial::Error) -> std::io::Error {
    use std::io::ErrorKind;

    let kind = match e.kind {
        tokio_serial::ErrorKind::NoDevice => ErrorKind::NotFound,
        tokio_serial::ErrorKind::InvalidInput => ErrorKind::InvalidInput,
        _ => ErrorKind::Other,
    };
    std::io::Error::new(kind, e.description)
}
/// Runs one link over an open serial stream: splits it into read and write halves
/// and drives them until the link ends.
async fn run_link(
    stream: SerialStream,
    max_frame: usize,
    state: Arc<SharedState>,
    inbound: mpsc::UnboundedSender<InboundFrame>,
    shutdown: Arc<Notify>,
) {
    let (read_half, write_half) = tokio::io::split(stream);
    run_link_halves(read_half, write_half, max_frame, state, inbound, shutdown).await;
}
/// Installs a fresh outbound queue in the shared state and runs the link over the
/// given halves until it ends.
///
/// Replacing `state.sender` here is what makes `Transport::send` target the new link.
async fn run_link_halves<R, W>(
    reader: R,
    writer: W,
    max_frame: usize,
    state: Arc<SharedState>,
    inbound: mpsc::UnboundedSender<InboundFrame>,
    shutdown: Arc<Notify>,
) where
    R: AsyncRead + Unpin + Send + 'static,
    W: AsyncWrite + Unpin + Send + 'static,
{
    let (outbound_tx, outbound_rx) = mpsc::unbounded_channel::<AkkaPdu>();
    *state.sender.lock() = Some(outbound_tx);
    run_link_halves_with_outbound(
        reader,
        writer,
        outbound_rx,
        max_frame,
        state,
        inbound,
        shutdown,
    )
    .await;
}
/// Drives one link: a writer task drains `rx` onto `writer`, a reader task decodes
/// frames from `reader` into `inbound`. Returns once both tasks have finished, then
/// clears the shared outbound sender and peer address.
///
/// Writer: stops on a write error, after sending a `Disassociate` (the write half is
/// shut down first), or when every outbound sender has been dropped.
///
/// Reader: stops on shutdown, on a read/decode error, or when the inbound consumer is
/// gone. Each frame is attributed to the origin of the latest `Associate` PDU, falling
/// back to the local address and finally to `Address::local("?")`.
///
/// Two liveness details:
/// * Shutdown uses `Notify::notify_waiters`, which stores no permit, so the reader
///   keeps a single `Notified` future for its whole lifetime. A signal that arrives
///   while a frame is being processed is then seen on the next loop iteration instead
///   of being lost.
/// * When the reader stops it drops the shared outbound sender. Otherwise the writer
///   would stay parked on an open queue and this function would not return until
///   some later `send` happened to fail.
async fn run_link_halves_with_outbound<R, W>(
    mut reader: R,
    mut writer: W,
    mut rx: mpsc::UnboundedReceiver<AkkaPdu>,
    max_frame: usize,
    state: Arc<SharedState>,
    inbound: mpsc::UnboundedSender<InboundFrame>,
    shutdown: Arc<Notify>,
) where
    R: AsyncRead + Unpin + Send + 'static,
    W: AsyncWrite + Unpin + Send + 'static,
{
    let writer_task = tokio::spawn(async move {
        while let Some(pdu) = rx.recv().await {
            if write_frame(&mut writer, &pdu, max_frame).await.is_err() {
                break;
            }
            // A Disassociate is the last frame this side sends on the link.
            if matches!(pdu, AkkaPdu::Disassociate(_)) {
                let _ = writer.shutdown().await;
                break;
            }
        }
    });

    let state_for_reader = state.clone();
    let reader_task = tokio::spawn(async move {
        // Created once so a `notify_waiters` call between iterations is not missed.
        let stop = shutdown.notified();
        tokio::pin!(stop);

        loop {
            let read = tokio::select! {
                _ = &mut stop => break,
                r = read_frame(&mut reader, max_frame) => r,
            };
            let pdu = match read {
                Ok(p) => p,
                Err(_) => break,
            };
            let from = if let AkkaPdu::Associate(AssociateInfo { origin, .. }) = &pdu {
                // Remember who is on the other end for subsequent frames.
                *state_for_reader.peer_address.lock() = Some(origin.clone());
                origin.clone()
            } else {
                state_for_reader
                    .peer_address
                    .lock()
                    .clone()
                    .or_else(|| state_for_reader.local_address.lock().clone())
                    .unwrap_or_else(|| Address::local("?"))
            };
            if inbound.send(InboundFrame { from, pdu }).is_err() {
                break;
            }
        }

        // The link is unusable once reading stops: close the outbound queue so the
        // writer drains what is left and exits.
        *state_for_reader.sender.lock() = None;
    });

    let _ = tokio::join!(writer_task, reader_task);
    *state.sender.lock() = None;
    *state.peer_address.lock() = None;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An `akka.serial` address whose host is a device path must survive a
    /// render-then-parse round trip with host and port intact.
    #[test]
    fn local_address_fields_round_trip_through_parse() {
        const DEVICE: &str = "/dev/ttyACM0";

        let original = Address::remote("akka.serial", "Sys", DEVICE, 0);
        let reparsed = Address::parse(&original.to_string()).expect("parse");

        assert_eq!(reparsed, original);
        assert_eq!(reparsed.host.as_deref(), Some(DEVICE));
        assert_eq!(reparsed.port, Some(0));
    }
}