use std::io::{Error, ErrorKind, Read, Write};
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::Relaxed;
use std::thread::{JoinHandle, spawn};
use std::time::Instant;
use channels::Channels;
use constants::{T_RSTACK_MAX, TX_K};
use log::{debug, error, info, trace, warn};
use state::State;
use tokio::sync::mpsc::{Receiver, Sender, channel};
use transmission::Transmission;
use crate::frame::{Ack, Data, Frame, Nak, RST, RstAck};
use crate::frame_buffer::FrameBuffer;
use crate::protocol::Mask;
use crate::status::Status;
use crate::types::Payload;
use crate::utils::WrappingU3;
use crate::validate::Validate;
mod channels;
mod constants;
mod state;
mod transmission;
/// Drives an ASHv2 link over the serial port `T`.
///
/// Owns the framed serial connection, the channels used to exchange payloads
/// with the host side, the link state, and the queue of data frames that were
/// sent but not yet acknowledged.
#[derive(Debug)]
pub struct Transceiver<T> {
/// Frame-level reader/writer wrapping the serial port.
frame_buffer: FrameBuffer<T>,
/// Incoming request receiver and outgoing response sender.
channels: Channels,
/// Connection status, frame/ack numbering, reject flag and `t_rx_ack` timing.
state: State,
/// Sent-but-unacknowledged data frames; capacity is bounded by `TX_K`.
transmissions: heapless::Vec<Transmission, TX_K>,
}
impl<T> Transceiver<T> {
#[must_use]
pub const fn new(
serial_port: T,
requests: Receiver<Payload>,
response: Sender<std::io::Result<Payload>>,
) -> Self {
Self {
frame_buffer: FrameBuffer::new(serial_port),
channels: Channels::new(requests, response),
state: State::new(),
transmissions: heapless::Vec::new(),
}
}
fn reset(&mut self) {
self.state.reset(Status::Failed);
self.transmissions.clear();
}
fn handle_io_error(&mut self, error: Error) {
debug!("Handling I/O error: {error}");
self.channels.respond(Err(error));
self.reset();
}
fn leave_reject(&mut self) {
if self.state.reject() {
trace!("Leaving rejection state.");
self.state.set_reject(false);
}
}
fn handle_payload(&self, mut payload: Payload) {
payload.mask();
self.channels.respond(Ok(payload));
}
fn handle_rst_ack(rst_ack: &RstAck) -> Error {
debug!("Received RSTACK: {rst_ack}");
if !rst_ack.is_ash_v2() {
error!("{rst_ack} is not ASHv2: {:#04X}", rst_ack.version());
}
rst_ack.code().map_or_else(
|code| {
warn!("NCP sent RSTACK with unknown code: {code}");
},
|code| {
trace!("NCP sent RSTACK condition: {code}");
},
);
Error::new(ErrorKind::ConnectionReset, "NCP sent RSTACK.")
}
fn handle_error(error: &crate::frame::Error) -> Error {
if !error.is_ash_v2() {
error!("{error} is not ASHv2: {:#04X}", error.version());
}
error.code().map_or_else(
|code| {
error!("NCP sent ERROR with invalid code: {code}");
},
|code| {
warn!("NCP sent ERROR condition: {code}");
},
);
Error::new(ErrorKind::ConnectionReset, "NCP entered ERROR state.")
}
}
impl<T> Transceiver<T>
where
    T: Read,
{
    /// Reads the next frame from the serial port.
    ///
    /// A read that fails with [`ErrorKind::TimedOut`] is not an error here. It
    /// yields `Ok(None)`, meaning "no frame available yet". Every other I/O
    /// error is passed through unchanged.
    fn receive(&mut self) -> std::io::Result<Option<Frame>> {
        self.frame_buffer
            .read_frame()
            .map(Some)
            .or_else(|err| match err.kind() {
                ErrorKind::TimedOut => Ok(None),
                _ => Err(err),
            })
    }
}
impl<T> Transceiver<T>
where
    T: Write,
{
    /// Releases the pending transmission(s) acknowledged by `ack_num`.
    ///
    /// A pending transmission matches when its frame number plus one equals
    /// `ack_num`. Each match is removed from the queue, and its elapsed time is
    /// passed to `State::update_t_rx_ack`.
    // NOTE(review): only transmissions with `frame_num() + 1 == ack_num` are
    // released here. Confirm whether an ACK should also release older
    // outstanding frames; as written they stay queued until their retransmit
    // timer expires.
    fn ack_sent_frames(&mut self, ack_num: WrappingU3) {
        while let Some(transmission) = self
            .transmissions
            .iter()
            .position(|transmission| transmission.frame_num() + 1 == ack_num)
            .map(|index| self.transmissions.remove(index))
        {
            let duration = transmission.elapsed();
            trace!("ACKed frame {transmission} after {duration:?}");
            self.state.update_t_rx_ack(Some(duration));
        }
    }

    /// Retransmits the pending frame whose frame number equals `nak_num`, if any.
    ///
    /// # Errors
    /// Propagates any error from `transmit`.
    fn nak_sent_frames(&mut self, nak_num: WrappingU3) -> std::io::Result<()> {
        trace!("Handling NAK: {nak_num}");

        if let Some(transmission) = self
            .transmissions
            .iter()
            .position(|transmission| transmission.frame_num() == nak_num)
            .map(|index| self.transmissions.remove(index))
        {
            debug!("Retransmitting NAK'ed frame #{}", transmission.frame_num());
            self.transmit(transmission)?;
        }

        Ok(())
    }

    /// Retransmits every pending frame whose `t_rx_ack` timeout has expired.
    ///
    /// Each retransmission is preceded by `update_t_rx_ack(None)`.
    ///
    /// # Errors
    /// Propagates any error from `transmit`.
    // NOTE(review): `transmit` puts the frame back into the queue, so this loop
    // only terminates if `data_for_transmit` restarts the transmission's timer.
    // Confirm in `Transmission`.
    fn retransmit_timed_out_data(&mut self) -> std::io::Result<()> {
        while let Some(transmission) = self
            .transmissions
            .iter()
            .position(|transmission| transmission.is_timed_out(self.state.t_rx_ack()))
            .map(|index| self.transmissions.remove(index))
        {
            debug!(
                "Retransmitting timed-out frame #{}",
                transmission.frame_num()
            );
            self.state.update_t_rx_ack(None);
            self.transmit(transmission)?;
        }

        Ok(())
    }

    /// Sends an ACK carrying the current ack number.
    fn ack(&mut self) -> std::io::Result<()> {
        self.send_ack(Ack::new(self.state.ack_number(), false))
    }

    /// Sends a NAK carrying the current ack number.
    fn nak(&mut self) -> std::io::Result<()> {
        self.send_nak(Nak::new(self.state.ack_number(), false))
    }

    /// Sends an RST frame.
    fn rst(&mut self) -> std::io::Result<()> {
        self.frame_buffer.write_frame(RST)
    }

    /// Writes the data frame of `transmission` and re-queues it at the front of
    /// the pending transmissions.
    ///
    /// # Errors
    /// Propagates errors from `data_for_transmit` and from writing the frame.
    /// Returns [`ErrorKind::OutOfMemory`] if the queue is already full.
    fn transmit(&mut self, mut transmission: Transmission) -> std::io::Result<()> {
        let data = transmission.data_for_transmit()?;
        trace!("Unmasked {:#04X}", data.unmasked());
        self.frame_buffer.write_frame(data)?;
        self.transmissions
            .insert(0, transmission)
            .map_err(|_| Error::new(ErrorKind::OutOfMemory, "Failed to enqueue retransmit"))
    }

    /// Wraps `chunk` in a DATA frame and transmits it.
    ///
    /// The frame uses the next frame number and an ack number of
    /// `ack_number() + offset`.
    fn send_chunk(&mut self, chunk: Payload, offset: WrappingU3) -> std::io::Result<()> {
        let data = Data::new(
            self.state.next_frame_number(),
            chunk,
            self.state.ack_number() + offset,
        );
        self.transmit(data.into())
    }

    /// Sends queued request chunks until the transmission queue is full.
    ///
    /// Returns `Ok(true)` if the queue filled up, and `Ok(false)` as soon as no
    /// further chunk is available.
    fn send_chunks(&mut self) -> std::io::Result<bool> {
        let mut offset = WrappingU3::default();

        while !self.transmissions.is_full() {
            if let Some(chunk) = self.channels.receive()? {
                self.send_chunk(chunk, offset)?;
                offset += 1;
            } else {
                return Ok(false);
            }
        }

        Ok(true)
    }

    /// Logs and writes `ack`.
    fn send_ack(&mut self, ack: Ack) -> std::io::Result<()> {
        debug!("Sending ACK: {ack}");
        self.frame_buffer.write_frame(ack)
    }

    /// Logs and writes `nak`.
    fn send_nak(&mut self, nak: Nak) -> std::io::Result<()> {
        debug!("Sending NAK: {nak}");
        self.frame_buffer.write_frame(nak)
    }

    /// Sets the reject flag and sends a single NAK.
    ///
    /// Does nothing if the reject flag is already set.
    fn enter_reject(&mut self) -> std::io::Result<()> {
        if self.state.reject() {
            Ok(())
        } else {
            trace!("Entering rejection state.");
            self.state.set_reject(true);
            self.nak()
        }
    }

    /// Processes a received DATA frame.
    ///
    /// - Invalid CRC: enter the reject state.
    /// - Frame number equals the expected ack number: leave the reject state,
    ///   record the frame number, ACK, release frames covered by the frame's
    ///   `ack_num()`, and forward the payload.
    /// - Retransmission flag set: ACK, release covered frames, and forward the
    ///   payload.
    /// - Anything else (out of sequence): enter the reject state.
    // NOTE(review): in the retransmission branch the frame number differs from
    // the expected one, so its payload may already have been delivered and is
    // forwarded again. Confirm that the host side tolerates duplicates.
    fn handle_data(&mut self, data: Data) -> std::io::Result<()> {
        trace!("Unmasked data: {:#04X}", data.unmasked());

        if !data.is_crc_valid() {
            warn!("Received data frame with invalid CRC.");
            self.enter_reject()?;
        } else if data.frame_num() == self.state.ack_number() {
            self.leave_reject();
            self.state.set_last_received_frame_num(data.frame_num());
            self.ack()?;
            self.ack_sent_frames(data.ack_num());
            self.handle_payload(data.into_payload());
        } else if data.is_retransmission() {
            info!("Received retransmission of frame: {data}");
            self.ack()?;
            self.ack_sent_frames(data.ack_num());
            self.handle_payload(data.into_payload());
        } else {
            warn!("Received out-of-sequence data frame: {data}");
            self.enter_reject()?;
        }

        Ok(())
    }

    /// Retransmits the frame requested by `nak`.
    ///
    /// An invalid CRC is logged but does not stop the NAK from being processed.
    fn handle_nak(&mut self, nak: &Nak) -> std::io::Result<()> {
        if !nak.is_crc_valid() {
            // Previously mislabelled as "ACK", which made NAK and ACK CRC
            // failures indistinguishable in the logs.
            warn!("Received NAK with invalid CRC.");
        }

        self.nak_sent_frames(nak.ack_num())
    }

    /// Releases the pending frames acknowledged by `ack`.
    ///
    /// An invalid CRC is logged but does not stop the ACK from being processed.
    fn handle_ack(&mut self, ack: &Ack) {
        if !ack.is_crc_valid() {
            warn!("Received ACK with invalid CRC.");
        }

        self.ack_sent_frames(ack.ack_num());
    }
}
impl<T> Transceiver<T>
where
    T: Read + Write,
{
    /// Spawns a thread that runs a new transceiver on `serial_port`.
    ///
    /// `channel_size` is the capacity of both the request and the response
    /// channel. The thread keeps running while `running` is `true`.
    ///
    /// Returns the request sender, the response receiver, and the thread's
    /// join handle. The join handle yields the serial port once `run` returns.
    ///
    /// # Panics
    /// tokio's `mpsc::channel` panics if `channel_size` is zero.
    pub fn spawn(
        serial_port: T,
        running: Arc<AtomicBool>,
        channel_size: usize,
    ) -> (
        Sender<Payload>,
        Receiver<std::io::Result<Payload>>,
        JoinHandle<T>,
    )
    where
        T: Send + 'static,
    {
        let (request_tx, request_rx) = channel(channel_size);
        let (response_tx, response_rx) = channel(channel_size);
        let transceiver = Self::new(serial_port, request_rx, response_tx);
        (request_tx, response_rx, spawn(|| transceiver.run(running)))
    }

    /// Runs the link until `running` reads `false`, then returns the serial port.
    ///
    /// An error from an iteration is reported on the response channel and
    /// resets the link to `Status::Failed`. The next iteration then reconnects.
    // NOTE(review): `running` is only checked between iterations, and `connect`
    // keeps retrying until it gets an RSTACK or an error, so a shutdown request
    // may go unnoticed while the NCP is unresponsive.
    // The `Arc` is taken by value to keep the public signature; hence the allow.
    #[allow(clippy::needless_pass_by_value)]
    pub fn run(mut self, running: Arc<AtomicBool>) -> T {
        while running.load(Relaxed) {
            if let Err(error) = self.main() {
                self.handle_io_error(error);
            }
        }

        self.frame_buffer.into_inner()
    }

    /// Performs one iteration: connect if not connected, otherwise exchange data.
    fn main(&mut self) -> std::io::Result<()> {
        match self.state.status() {
            Status::Disconnected | Status::Failed => self.connect(),
            Status::Connected => self.communicate(),
        }
    }

    /// Sends pending request data, then handles any frames the NCP sent on its own.
    fn communicate(&mut self) -> std::io::Result<()> {
        self.send_data()?;
        self.handle_callbacks()
    }

    /// Establishes the connection by sending RST until an ASHv2 RSTACK arrives.
    ///
    /// Each attempt sends an RST and waits up to `T_RSTACK_MAX` for a reply.
    /// A new attempt starts if the wait expires or a frame other than RSTACK
    /// arrives.
    ///
    /// # Errors
    /// Returns [`ErrorKind::Unsupported`] if the RSTACK is not ASHv2. Propagates
    /// I/O errors from sending or receiving frames.
    fn connect(&mut self) -> std::io::Result<()> {
        debug!("Connecting to NCP...");
        let start = Instant::now();
        let mut attempts: usize = 0;

        'attempts: loop {
            attempts += 1;
            self.rst()?;
            // Time each attempt separately. Measuring against `start` would make
            // every attempt after the first give up after one timed-out read.
            let attempt_start = Instant::now();
            debug!("Waiting for RSTACK...");

            let frame = loop {
                if let Some(frame) = self.receive()? {
                    break frame;
                } else if attempt_start.elapsed() > T_RSTACK_MAX {
                    debug!("No RSTACK within {:?}; resending RST.", T_RSTACK_MAX);
                    continue 'attempts;
                }
            };

            match frame {
                Frame::RstAck(rst_ack) => {
                    if !rst_ack.is_ash_v2() {
                        return Err(Error::new(
                            ErrorKind::Unsupported,
                            "Received RSTACK is not ASHv2.",
                        ));
                    }

                    self.state.set_status(Status::Connected);
                    info!(
                        "ASHv2 connection established after {attempts} attempt{}.",
                        if attempts > 1 { "s" } else { "" }
                    );
                    debug!("Establishing connection took {:?}", start.elapsed());

                    match rst_ack.code() {
                        Ok(code) => trace!("Received RST_ACK with code: {code}"),
                        Err(code) => warn!("Received RST_ACK with unknown code: {code}"),
                    }

                    return Ok(());
                }
                other => {
                    warn!("Expected RSTACK but got: {other}");
                }
            }
        }
    }

    /// Sends all available request chunks and waits until every one is acknowledged.
    ///
    /// While the transmission queue is full, and again after the last chunk,
    /// incoming frames are handled and timed-out frames are retransmitted.
    fn send_data(&mut self) -> std::io::Result<()> {
        while self.send_chunks()? {
            while self.transmissions.is_full() {
                self.handle_callbacks()?;
                self.retransmit_timed_out_data()?;
            }
        }

        while !self.transmissions.is_empty() {
            self.handle_callbacks()?;
            self.retransmit_timed_out_data()?;
        }

        Ok(())
    }

    /// Dispatches a received frame to its handler.
    ///
    /// ERROR and RSTACK frames turn into an error that ends the current
    /// connection. Frames received while not connected are dropped.
    fn handle_frame(&mut self, frame: Frame) -> std::io::Result<()> {
        debug!("Handling: {frame}");
        trace!("{frame:#04X}");

        if self.state.status() == Status::Connected {
            match frame {
                Frame::Ack(ref ack) => self.handle_ack(ack),
                Frame::Data(data) => self.handle_data(*data)?,
                Frame::Error(ref error) => return Err(Self::handle_error(error)),
                Frame::Nak(ref nak) => self.handle_nak(nak)?,
                Frame::RstAck(ref rst_ack) => return Err(Self::handle_rst_ack(rst_ack)),
                Frame::Rst(_) => warn!("Received unexpected RST from NCP."),
            }
        } else {
            warn!("Not connected. Dropping frame: {frame}");
        }

        Ok(())
    }

    /// Handles received frames until a read times out.
    fn handle_callbacks(&mut self) -> std::io::Result<()> {
        while let Some(callback) = self.receive()? {
            self.handle_frame(callback)?;
        }

        Ok(())
    }
}