use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::broadcast;
use tokio_util::sync::CancellationToken;
use crate::command::Command;
use crate::config::{Parity, SerialConfig, StopBits};
use crate::device::SerialDevice;
use crate::error::Result;
use crate::event::{Event, EventBus};
use crate::mapper::{LineEndingMapper, Mapper};
/// Returns the single-letter parity code used in the framing shorthand that
/// `Command::ShowConfig` prints (the `N` in `8N1`).
const fn parity_letter(parity: Parity) -> char {
    match parity {
        Parity::Odd => 'O',
        Parity::Even => 'E',
        Parity::None => 'N',
        Parity::Space => 'S',
        Parity::Mark => 'M',
    }
}
/// Returns the stop-bit count as the digit used in the framing shorthand that
/// `Command::ShowConfig` prints (the `1` in `8N1`).
const fn stop_bits_number(stop_bits: StopBits) -> u8 {
    match stop_bits {
        StopBits::Two => 2,
        StopBits::One => 1,
    }
}
/// Size in bytes of the scratch buffer that `Session::run` reads the device into.
const READ_BUFFER_BYTES: usize = 4 * 1024;

/// Duration handed to `SerialDevice::send_break` for `Command::SendBreak`; it is also
/// echoed, in milliseconds, in the confirmation message.
const SEND_BREAK_DURATION: Duration = Duration::from_millis(250);

/// One-line key-binding summary published as a system message for `Command::Help`.
const HELP_TEXT: &str = concat!(
    "commands: ?/h help, q/x quit, c show config, t toggle DTR, ",
    "g toggle RTS, b<rate><Enter> set baud, \\ send break",
);
/// A terminal session that connects a [`SerialDevice`] to an [`EventBus`].
///
/// While `Session::run` is active, device reads pass through the input mapper and are
/// published as [`Event::RxBytes`]. [`Event::TxBytes`] taken from the bus pass through
/// the output mapper and are written to the device. [`Event::Command`]s are executed
/// against the device.
pub struct Session<D>
where
    D: SerialDevice + 'static,
{
    /// The serial port being driven.
    device: D,
    /// Bus used both to publish session events and to receive TX bytes and commands.
    bus: EventBus,
    /// Stops the `run` loop when cancelled (e.g. by `Command::Quit`).
    cancel: CancellationToken,
    /// Mapper applied to outgoing bytes before they are written to the device.
    omap: Box<dyn Mapper>,
    /// Mapper applied to incoming bytes before they are published.
    imap: Box<dyn Mapper>,
    /// Last DTR level the session set (or was told to assume); used for toggling.
    dtr_asserted: bool,
    /// Last RTS level the session set (or was told to assume); used for toggling.
    rts_asserted: bool,
}
impl<D: SerialDevice + 'static> Session<D> {
    /// Creates a session around `device` with its own, freshly created [`EventBus`].
    ///
    /// Both the input and output mappers start as [`LineEndingMapper::default()`]. The
    /// session's tracked DTR/RTS state starts as asserted. Nothing is written to the
    /// device here. Use [`Self::with_initial_dtr`] / [`Self::with_initial_rts`] to change
    /// the tracked starting state.
    #[must_use]
    pub fn new(device: D) -> Self {
        Self::with_bus(device, EventBus::default())
    }

    /// Creates a session around `device` that publishes to, and listens on, `bus`.
    ///
    /// Defaults are otherwise identical to [`Self::new`]. This constructor only lets the
    /// caller supply the bus, for example one that other components also hold.
    #[must_use]
    pub fn with_bus(device: D, bus: EventBus) -> Self {
        Self {
            device,
            bus,
            cancel: CancellationToken::new(),
            omap: Box::new(LineEndingMapper::default()),
            imap: Box::new(LineEndingMapper::default()),
            dtr_asserted: true,
            rts_asserted: true,
        }
    }

    /// Replaces the mapper applied to [`Event::TxBytes`] payloads before they are written
    /// to the device.
    #[must_use]
    pub fn with_omap<M: Mapper + 'static>(mut self, mapper: M) -> Self {
        self.omap = Box::new(mapper);
        self
    }

    /// Replaces the mapper applied to bytes read from the device before they are
    /// published as [`Event::RxBytes`].
    #[must_use]
    pub fn with_imap<M: Mapper + 'static>(mut self, mapper: M) -> Self {
        self.imap = Box::new(mapper);
        self
    }

    /// Overrides the session's tracked starting DTR state (default: asserted).
    ///
    /// This only changes the bookkeeping used by `Command::ToggleDtr` and reported in
    /// [`Event::ModemLinesChanged`]. The device line itself is not driven here.
    #[must_use]
    pub const fn with_initial_dtr(mut self, asserted: bool) -> Self {
        self.dtr_asserted = asserted;
        self
    }

    /// Overrides the session's tracked starting RTS state (default: asserted).
    ///
    /// This only changes the bookkeeping used by `Command::ToggleRts` and reported in
    /// [`Event::ModemLinesChanged`]. The device line itself is not driven here.
    #[must_use]
    pub const fn with_initial_rts(mut self, asserted: bool) -> Self {
        self.rts_asserted = asserted;
        self
    }

    /// Returns the bus this session publishes to and listens on.
    #[must_use]
    pub const fn bus(&self) -> &EventBus {
        &self.bus
    }

    /// Returns a clone of the session's cancellation token.
    ///
    /// Cancelling it makes [`Self::run`] leave its loop.
    #[must_use]
    pub fn cancellation_token(&self) -> CancellationToken {
        self.cancel.clone()
    }

    /// Runs the session's event loop until it is cancelled, the device goes away, or the
    /// bus closes.
    ///
    /// Publishes [`Event::DeviceConnected`] on entry. Each iteration then waits on these
    /// sources, in priority order:
    ///
    /// 1. the cancellation token (see [`Self::cancellation_token`] and `Command::Quit`);
    /// 2. a device read. Data passes through the input mapper and is published as
    ///    [`Event::RxBytes`]. EOF or a read error publishes [`Event::DeviceDisconnected`]
    ///    and ends the loop;
    /// 3. the next bus event. [`Event::TxBytes`] pass through the output mapper and are
    ///    written to the device; a write error publishes `DeviceDisconnected` and ends the
    ///    loop. [`Event::Command`]s are executed in place.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok(())`. Device failures are reported
    /// on the bus, not through the return value.
    pub async fn run(mut self) -> crate::Result<()> {
        // Subscribe before announcing the connection, so events published from here on
        // (including the announcement itself) are delivered to `subscriber`.
        let mut subscriber = self.bus.subscribe();
        self.bus.publish(Event::DeviceConnected);
        let mut read_buf = vec![0_u8; READ_BUFFER_BYTES];
        loop {
            tokio::select! {
                // Poll branches in declaration order so a pending cancellation always wins.
                biased;
                () = self.cancel.cancelled() => break,
                res = self.device.read(&mut read_buf) => match res {
                    Ok(0) => {
                        self.bus.publish(Event::DeviceDisconnected {
                            reason: "EOF on serial read".into(),
                        });
                        break;
                    }
                    Ok(n) => {
                        let mapped = self.imap.map(&read_buf[..n]);
                        self.bus.publish(Event::RxBytes(mapped));
                    }
                    Err(err) => {
                        self.bus.publish(Event::DeviceDisconnected {
                            reason: format!("serial read failed: {err}"),
                        });
                        break;
                    }
                },
                msg = subscriber.recv() => match msg {
                    Ok(Event::TxBytes(bytes)) => {
                        let mapped = self.omap.map(&bytes);
                        if let Err(err) = self.device.write_all(&mapped).await {
                            self.bus.publish(Event::DeviceDisconnected {
                                reason: format!("serial write failed: {err}"),
                            });
                            break;
                        }
                    }
                    Ok(Event::Command(cmd)) => self.dispatch_command(cmd).await,
                    // The session also receives the events it publishes itself (RxBytes,
                    // DeviceConnected, ...); those need no action here. On `Lagged`, the
                    // skipped events (possibly TxBytes or Commands) are gone, so carry on.
                    Ok(_) | Err(broadcast::error::RecvError::Lagged(_)) => {}
                    Err(broadcast::error::RecvError::Closed) => break,
                },
            }
        }
        Ok(())
    }

    /// Executes one [`Command`] against the device and reports the outcome on the bus.
    ///
    /// A failing device operation is published as [`Event::Error`] and does not stop the
    /// session. `Command::Quit` only cancels the session token; [`Self::run`] checks that
    /// token first on its next iteration.
    pub(crate) async fn dispatch_command(&mut self, cmd: Command) {
        match cmd {
            Command::Quit => self.cancel.cancel(),
            Command::Help => {
                self.bus.publish(Event::SystemMessage(HELP_TEXT.into()));
            }
            Command::ShowConfig => {
                let cfg = self.device.config();
                // Framing is rendered as data bits, parity letter and stop bits, e.g. "8N1".
                self.bus.publish(Event::SystemMessage(format!(
                    "config: {} {}{}{} flow={:?}",
                    cfg.baud_rate,
                    cfg.data_bits.bits(),
                    parity_letter(cfg.parity),
                    stop_bits_number(cfg.stop_bits),
                    cfg.flow_control,
                )));
            }
            Command::SetBaud(rate) => match self.device.set_baud_rate(rate) {
                Ok(()) => {
                    self.bus
                        .publish(Event::ConfigChanged(*self.device.config()));
                }
                Err(err) => {
                    self.bus.publish(Event::Error(Arc::new(err)));
                }
            },
            Command::ApplyConfig(cfg) => {
                if let Err(err) = self.apply_config(cfg).await {
                    self.bus.publish(Event::Error(Arc::new(err)));
                }
            }
            Command::ToggleDtr => {
                let new_state = !self.dtr_asserted;
                self.apply_dtr(new_state);
            }
            Command::ToggleRts => {
                let new_state = !self.rts_asserted;
                self.apply_rts(new_state);
            }
            Command::SetDtrAbs(state) => self.apply_dtr(state),
            Command::SetRtsAbs(state) => self.apply_rts(state),
            Command::SendBreak => match self.device.send_break(SEND_BREAK_DURATION) {
                Ok(()) => {
                    self.bus.publish(Event::SystemMessage(format!(
                        "sent {} ms break",
                        SEND_BREAK_DURATION.as_millis()
                    )));
                }
                Err(err) => {
                    self.bus.publish(Event::Error(Arc::new(err)));
                }
            },
            Command::OpenMenu => {
                self.bus.publish(Event::MenuOpened);
            }
        }
    }

    /// Drives DTR to `new_state`. On success the new level is recorded and announced.
    ///
    /// On failure the tracked state is left unchanged and the error is published.
    fn apply_dtr(&mut self, new_state: bool) {
        match self.device.set_dtr(new_state) {
            Ok(()) => {
                self.dtr_asserted = new_state;
                self.announce_modem_line("DTR", new_state);
            }
            Err(err) => {
                self.bus.publish(Event::Error(Arc::new(err)));
            }
        }
    }

    /// Drives RTS to `new_state`. On success the new level is recorded and announced.
    ///
    /// On failure the tracked state is left unchanged and the error is published.
    fn apply_rts(&mut self, new_state: bool) {
        match self.device.set_rts(new_state) {
            Ok(()) => {
                self.rts_asserted = new_state;
                self.announce_modem_line("RTS", new_state);
            }
            Err(err) => {
                self.bus.publish(Event::Error(Arc::new(err)));
            }
        }
    }

    /// Publishes a human-readable note that `line` is now `asserted`/`deasserted`, then
    /// an [`Event::ModemLinesChanged`] carrying the session's current DTR and RTS state.
    fn announce_modem_line(&mut self, line: &str, asserted: bool) {
        let level = if asserted { "asserted" } else { "deasserted" };
        self.bus
            .publish(Event::SystemMessage(format!("{line}: {level}")));
        self.bus.publish(Event::ModemLinesChanged {
            dtr: self.dtr_asserted,
            rts: self.rts_asserted,
        });
    }

    /// Applies the line settings of `new` to the device as an all-or-nothing change
    /// (best effort).
    ///
    /// Settings are written in this order: baud rate, data bits, stop bits, parity, flow
    /// control. If any step fails, the configuration captured before the call is written
    /// back on a best-effort basis, and the failing step's error is returned. On success,
    /// a single [`Event::ConfigChanged`] with the device's resulting configuration is
    /// published.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by a device setter.
    // The body performs no `.await`. The lint is silenced so the method keeps the async
    // signature that callers (e.g. `dispatch_command`) already `.await`.
    #[allow(clippy::unused_async)]
    pub async fn apply_config(&mut self, new: SerialConfig) -> Result<()> {
        let snapshot = *self.device.config();
        if let Err(err) = self.write_line_settings(&new) {
            self.rollback(&snapshot);
            return Err(err);
        }
        self.bus
            .publish(Event::ConfigChanged(*self.device.config()));
        Ok(())
    }

    /// Writes each line setting of `cfg` to the device in order, stopping at the first
    /// failure.
    fn write_line_settings(&mut self, cfg: &SerialConfig) -> Result<()> {
        self.device.set_baud_rate(cfg.baud_rate)?;
        self.device.set_data_bits(cfg.data_bits)?;
        self.device.set_stop_bits(cfg.stop_bits)?;
        self.device.set_parity(cfg.parity)?;
        self.device.set_flow_control(cfg.flow_control)
    }

    /// Best-effort restore of the line settings captured in `snapshot`.
    ///
    /// Every setter is attempted even if an earlier one fails, so as much of the old
    /// configuration as possible is restored. Failures are published as [`Event::Error`]
    /// instead of being discarded. If the device still differs from `snapshot`
    /// afterwards, an [`Event::ConfigChanged`] is published so subscribers learn the
    /// configuration the port was actually left in.
    fn rollback(&mut self, snapshot: &SerialConfig) {
        // Array elements are evaluated in order, so the setters run sequentially.
        let outcomes = [
            self.device.set_baud_rate(snapshot.baud_rate),
            self.device.set_data_bits(snapshot.data_bits),
            self.device.set_stop_bits(snapshot.stop_bits),
            self.device.set_parity(snapshot.parity),
            self.device.set_flow_control(snapshot.flow_control),
        ];
        for outcome in outcomes {
            if let Err(err) = outcome {
                self.bus.publish(Event::Error(Arc::new(err)));
            }
        }
        if self.device.config() != snapshot {
            self.bus
                .publish(Event::ConfigChanged(*self.device.config()));
        }
    }
}
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Duration;

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio::sync::broadcast::error::TryRecvError;

    use super::{Event, Result, SerialDevice, Session};
    use crate::command::Command;
    use crate::config::{DataBits, FlowControl, ModemStatus, Parity, SerialConfig, StopBits};
    use crate::error::Error;

    /// In-memory [`SerialDevice`] whose line-setting setters can each be made to fail once.
    ///
    /// Setting a `fail_*` flag makes the next call to the matching setter return
    /// `Error::InvalidConfig` without touching `config`, then clears the flag, so a
    /// following rollback through the same setter succeeds. Reads never complete; writes,
    /// modem-line changes and breaks always succeed.
    // One independent flag per setter keeps each test's failure point explicit.
    #[allow(clippy::struct_excessive_bools)]
    struct MockDevice {
        config: SerialConfig,
        fail_baud: bool,
        fail_data_bits: bool,
        fail_stop_bits: bool,
        fail_parity: bool,
        fail_flow: bool,
    }

    impl MockDevice {
        const fn new(config: SerialConfig) -> Self {
            Self {
                config,
                fail_baud: false,
                fail_data_bits: false,
                fail_stop_bits: false,
                fail_parity: false,
                fail_flow: false,
            }
        }
    }

    impl AsyncRead for MockDevice {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }
    }

    impl AsyncWrite for MockDevice {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl SerialDevice for MockDevice {
        fn set_baud_rate(&mut self, baud: u32) -> Result<()> {
            if self.fail_baud {
                self.fail_baud = false;
                return Err(Error::InvalidConfig("mock: baud fail".into()));
            }
            self.config.baud_rate = baud;
            Ok(())
        }

        fn set_data_bits(&mut self, bits: DataBits) -> Result<()> {
            if self.fail_data_bits {
                self.fail_data_bits = false;
                return Err(Error::InvalidConfig("mock: data_bits fail".into()));
            }
            self.config.data_bits = bits;
            Ok(())
        }

        fn set_stop_bits(&mut self, bits: StopBits) -> Result<()> {
            if self.fail_stop_bits {
                self.fail_stop_bits = false;
                return Err(Error::InvalidConfig("mock: stop_bits fail".into()));
            }
            self.config.stop_bits = bits;
            Ok(())
        }

        fn set_parity(&mut self, parity: Parity) -> Result<()> {
            if self.fail_parity {
                self.fail_parity = false;
                return Err(Error::InvalidConfig("mock: parity fail".into()));
            }
            self.config.parity = parity;
            Ok(())
        }

        fn set_flow_control(&mut self, flow: FlowControl) -> Result<()> {
            if self.fail_flow {
                self.fail_flow = false;
                return Err(Error::InvalidConfig("mock: flow fail".into()));
            }
            self.config.flow_control = flow;
            Ok(())
        }

        fn set_dtr(&mut self, _level: bool) -> Result<()> {
            Ok(())
        }

        fn set_rts(&mut self, _level: bool) -> Result<()> {
            Ok(())
        }

        fn send_break(&mut self, _duration: Duration) -> Result<()> {
            Ok(())
        }

        fn modem_status(&mut self) -> Result<ModemStatus> {
            Ok(ModemStatus::default())
        }

        fn config(&self) -> &SerialConfig {
            &self.config
        }
    }

    /// A target configuration that sets every line setting `apply_config` writes.
    fn new_cfg() -> SerialConfig {
        SerialConfig {
            baud_rate: 9600,
            data_bits: DataBits::Seven,
            stop_bits: StopBits::Two,
            parity: Parity::Even,
            flow_control: FlowControl::Hardware,
            ..SerialConfig::default()
        }
    }

    #[tokio::test]
    async fn apply_config_success_publishes_config_changed() {
        let device = MockDevice::new(SerialConfig::default());
        let mut session = Session::new(device);
        let mut rx = session.bus().subscribe();
        let target = new_cfg();
        session
            .apply_config(target)
            .await
            .expect("apply_config should succeed");
        let got = session.device.config();
        assert_eq!(got.baud_rate, target.baud_rate);
        assert_eq!(got.data_bits, target.data_bits);
        assert_eq!(got.stop_bits, target.stop_bits);
        assert_eq!(got.parity, target.parity);
        assert_eq!(got.flow_control, target.flow_control);
        match rx.try_recv() {
            Ok(Event::ConfigChanged(cfg)) => {
                assert_eq!(cfg.baud_rate, target.baud_rate);
                assert_eq!(cfg.flow_control, target.flow_control);
            }
            other => panic!("expected ConfigChanged, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn apply_config_rolls_back_on_middle_failure() {
        // Stop bits is the third of five settings, so baud rate and data bits have already
        // been written when the failure happens and must be reverted.
        let mut device = MockDevice::new(SerialConfig::default());
        device.fail_stop_bits = true;
        let initial = *device.config();
        let mut session = Session::new(device);
        let mut rx = session.bus().subscribe();
        let err = session
            .apply_config(new_cfg())
            .await
            .expect_err("apply_config must fail when the stop_bits setter errors");
        assert!(matches!(err, Error::InvalidConfig(_)));
        assert_eq!(session.device.config(), &initial);
        match rx.try_recv() {
            Err(TryRecvError::Empty) => {}
            Ok(Event::ConfigChanged(_)) => panic!("unexpected ConfigChanged after rollback"),
            other => panic!("unexpected bus state: {other:?}"),
        }
    }

    #[tokio::test]
    async fn apply_config_rolls_back_on_last_step_failure() {
        // Flow control is written last, so every other setting must be reverted.
        let mut device = MockDevice::new(SerialConfig::default());
        device.fail_flow = true;
        let initial = *device.config();
        let mut session = Session::new(device);
        let mut rx = session.bus().subscribe();
        let target = new_cfg();
        let err = session
            .apply_config(target)
            .await
            .expect_err("apply_config must fail when flow setter errors");
        assert!(matches!(err, Error::InvalidConfig(_)));
        let got = session.device.config();
        assert_eq!(got.baud_rate, initial.baud_rate);
        assert_eq!(got.data_bits, initial.data_bits);
        assert_eq!(got.stop_bits, initial.stop_bits);
        assert_eq!(got.parity, initial.parity);
        assert_eq!(got.flow_control, initial.flow_control);
        match rx.try_recv() {
            Err(TryRecvError::Empty) => {}
            Ok(Event::ConfigChanged(_)) => panic!("unexpected ConfigChanged after rollback"),
            other => panic!("unexpected bus state: {other:?}"),
        }
    }

    #[tokio::test]
    async fn apply_config_rolls_back_on_first_step_failure() {
        let mut device = MockDevice::new(SerialConfig::default());
        device.fail_baud = true;
        let initial = *device.config();
        let mut session = Session::new(device);
        let mut rx = session.bus().subscribe();
        let target = new_cfg();
        let err = session
            .apply_config(target)
            .await
            .expect_err("apply_config must fail when baud setter errors");
        assert!(matches!(err, Error::InvalidConfig(_)));
        let got = session.device.config();
        assert_eq!(got, &initial);
        match rx.try_recv() {
            Err(TryRecvError::Empty) => {}
            Ok(Event::ConfigChanged(_)) => panic!("unexpected ConfigChanged after rollback"),
            other => panic!("unexpected bus state: {other:?}"),
        }
    }

    #[tokio::test]
    async fn apply_config_command_dispatches_through_session() {
        let device = MockDevice::new(SerialConfig::default());
        let mut session = Session::new(device);
        let mut rx = session.bus().subscribe();
        let target = SerialConfig {
            baud_rate: 9600,
            ..SerialConfig::default()
        };
        session.dispatch_command(Command::ApplyConfig(target)).await;
        let ev = rx.try_recv().expect("ConfigChanged should be on the bus");
        match ev {
            Event::ConfigChanged(cfg) => assert_eq!(cfg.baud_rate, 9600),
            other => panic!("expected ConfigChanged, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn apply_config_command_on_failure_publishes_error() {
        let mut device = MockDevice::new(SerialConfig::default());
        device.fail_baud = true;
        let mut session = Session::new(device);
        let mut rx = session.bus().subscribe();
        let target = SerialConfig {
            baud_rate: 9600,
            ..SerialConfig::default()
        };
        session.dispatch_command(Command::ApplyConfig(target)).await;
        match rx.try_recv() {
            Ok(Event::Error(_)) => {}
            other => panic!("expected Error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn set_dtr_abs_publishes_modem_lines_changed() {
        let device = MockDevice::new(SerialConfig::default());
        let mut session = Session::new(device);
        let mut rx = session.bus().subscribe();
        session.dispatch_command(Command::SetDtrAbs(true)).await;
        match rx.recv().await.expect("SystemMessage should be on the bus") {
            Event::SystemMessage(msg) => assert_eq!(msg, "DTR: asserted"),
            other => panic!("expected SystemMessage, got {other:?}"),
        }
        match rx.recv().await.expect("ModemLinesChanged should be on the bus") {
            Event::ModemLinesChanged { dtr, rts } => {
                assert!(dtr);
                assert!(rts);
            }
            other => panic!("expected ModemLinesChanged, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn set_rts_abs_publishes_modem_lines_changed() {
        let device = MockDevice::new(SerialConfig::default());
        let mut session = Session::new(device);
        let mut rx = session.bus().subscribe();
        session.dispatch_command(Command::SetRtsAbs(false)).await;
        match rx.recv().await.expect("SystemMessage should be on the bus") {
            Event::SystemMessage(msg) => assert_eq!(msg, "RTS: deasserted"),
            other => panic!("expected SystemMessage, got {other:?}"),
        }
        match rx.recv().await.expect("ModemLinesChanged should be on the bus") {
            Event::ModemLinesChanged { dtr, rts } => {
                assert!(dtr);
                assert!(!rts);
            }
            other => panic!("expected ModemLinesChanged, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn toggle_dtr_now_also_publishes_modem_lines_changed() {
        let device = MockDevice::new(SerialConfig::default());
        let mut session = Session::new(device);
        let mut rx = session.bus().subscribe();
        session.dispatch_command(Command::ToggleDtr).await;
        match rx.recv().await.expect("SystemMessage should be on the bus") {
            Event::SystemMessage(msg) => assert_eq!(msg, "DTR: deasserted"),
            other => panic!("expected SystemMessage, got {other:?}"),
        }
        match rx.recv().await.expect("ModemLinesChanged should be on the bus") {
            Event::ModemLinesChanged { dtr, rts } => {
                assert!(!dtr);
                assert!(rts);
            }
            other => panic!("expected ModemLinesChanged, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn toggle_rts_now_also_publishes_modem_lines_changed() {
        let device = MockDevice::new(SerialConfig::default());
        let mut session = Session::new(device);
        let mut rx = session.bus().subscribe();
        session.dispatch_command(Command::ToggleRts).await;
        match rx.recv().await.expect("SystemMessage should be on the bus") {
            Event::SystemMessage(msg) => assert_eq!(msg, "RTS: deasserted"),
            other => panic!("expected SystemMessage, got {other:?}"),
        }
        match rx.recv().await.expect("ModemLinesChanged should be on the bus") {
            Event::ModemLinesChanged { dtr, rts } => {
                assert!(dtr);
                assert!(!rts);
            }
            other => panic!("expected ModemLinesChanged, got {other:?}"),
        }
    }
}