use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use heapless::Vec as HVec;
use tokio::sync::{mpsc, watch};
use mbus_core::{
data_unit::common::MAX_ADU_FRAME_LEN, errors::MbusError, transport::AsyncTransport,
};
use crate::client::command::{ClientRequest, ResponseSender, TaskCommand};
use crate::client::decode::decode_response;
use crate::client::encode::encode_request;
#[cfg(feature = "traffic")]
use crate::client::notifier::NotifierStore;
/// Boxed, sendable future that resolves to a freshly opened transport or the
/// error that prevented connecting.
type ConnectFuture<T> = Pin<Box<dyn Future<Output = Result<T, MbusError>> + Send>>;

/// Factory the client task calls each time it has to open a transport.
pub(crate) type ConnectFactory<T> = Box<dyn Fn() -> ConnectFuture<T> + Send + 'static>;

/// Receiving half of the watch channel on which the task publishes how many
/// requests are awaiting a response.
pub(crate) type PendingCountReceiver = watch::Receiver<usize>;
/// Book-keeping for one request that has been written to the transport and is
/// still awaiting its response.
struct PendingEntry {
    /// Channel on which the caller receives the final result for this request.
    resp_tx: ResponseSender,
    /// The request as sent; `fix_up_response` reads it to rebuild the decoded
    /// response (start address / quantity), and traffic notifications read its unit.
    request: ClientRequest,
}
/// Background task that owns a Modbus transport and multiplexes client
/// requests over it.
///
/// At most `N` requests await a response at any time; requests arriving while
/// all `N` slots are busy wait in `queued` until a slot frees up.
pub(crate) struct ClientTask<T: AsyncTransport + Send + 'static, const N: usize = 9> {
    /// Open connection, or `None` while disconnected.
    transport: Option<T>,
    /// Factory invoked to (re)open `transport`.
    connect_fn: ConnectFactory<T>,
    /// Commands sent by the client handle(s).
    cmd_rx: mpsc::Receiver<TaskCommand>,
    /// Sent requests awaiting a response, keyed by transaction id.
    pending: HashMap<u16, PendingEntry>,
    /// Requests held back because `N` requests were already in flight.
    queued: VecDeque<TaskCommand>,
    /// Transaction id to hand out next.
    next_txn_id: u16,
    /// Number of sent-but-unanswered requests.
    in_flight: usize,
    /// Publishes `pending.len()` to observers.
    pending_count_tx: watch::Sender<usize>,
    /// Traffic observer (only with the `traffic` feature).
    #[cfg(feature = "traffic")]
    notifier: NotifierStore,
}
impl<T: AsyncTransport + Send + 'static, const N: usize> ClientTask<T, N> {
pub(crate) fn new(
connect_fn: ConnectFactory<T>,
cmd_rx: mpsc::Receiver<TaskCommand>,
pending_count_tx: watch::Sender<usize>,
#[cfg(feature = "traffic")] notifier: NotifierStore,
) -> Self {
Self {
transport: None,
connect_fn,
cmd_rx,
pending: HashMap::new(),
queued: VecDeque::new(),
next_txn_id: 1,
in_flight: 0,
pending_count_tx,
#[cfg(feature = "traffic")]
notifier,
}
}
fn advance_txn_id(&mut self) -> u16 {
let id = self.next_txn_id;
self.next_txn_id = match self.next_txn_id.wrapping_add(1) {
0 => 1,
n => n,
};
id
}
fn update_pending_count(&self) {
let _ = self.pending_count_tx.send(self.pending.len());
}
async fn do_connect(&mut self) -> Result<(), MbusError> {
self.transport = None;
let transport = (self.connect_fn)().await?;
self.transport = Some(transport);
Ok(())
}
async fn dispatch_request(&mut self, cmd: TaskCommand) {
let (params, resp_tx) = match cmd {
TaskCommand::Request { params, resp_tx } => (params, resp_tx),
TaskCommand::Connect { resp_tx } => {
let _ = resp_tx.send(Err(MbusError::Unexpected));
return;
}
TaskCommand::Disconnect => return,
};
if self.transport.is_none() {
let _ = resp_tx.send(Err(MbusError::ConnectionClosed));
return;
}
let ttype = T::TRANSPORT_TYPE;
let txn_id = self.advance_txn_id();
let frame = match encode_request(txn_id, ¶ms, ttype) {
Ok(f) => f,
Err(e) => {
let _ = resp_tx.send(Err(e));
return;
}
};
#[cfg(feature = "traffic")]
let unit = params.unit();
let send_result = match self.transport.as_mut() {
Some(t) => t.send(&frame).await,
None => Err(MbusError::ConnectionClosed),
};
match send_result {
Ok(()) => {
#[cfg(feature = "traffic")]
self.fire_tx_frame(txn_id, unit, &frame);
self.pending.insert(
txn_id,
PendingEntry {
resp_tx,
request: params,
},
);
self.in_flight += 1;
self.update_pending_count();
}
Err(e) => {
#[cfg(feature = "traffic")]
self.fire_tx_error(txn_id, unit, e);
let _ = resp_tx.send(Err(e));
}
}
}
fn process_frame(&mut self, frame: &HVec<u8, MAX_ADU_FRAME_LEN>) {
if self.transport.is_none() {
return;
}
let ttype = T::TRANSPORT_TYPE;
let (decoded_txn_id, _unit, inner) = match decode_response(frame, ttype) {
Ok(v) => v,
Err(e) => {
self.fail_entry(0, e);
return;
}
};
let key = self.resolve_key(decoded_txn_id);
if let Some(k) = key
&& let Some(entry) = self.pending.remove(&k)
{
self.in_flight = self.in_flight.saturating_sub(1);
self.update_pending_count();
#[cfg(feature = "traffic")]
self.fire_rx_frame(decoded_txn_id, entry.request.unit(), frame);
let result = inner.map(|response| fix_up_response(response, &entry.request));
let _ = entry.resp_tx.send(result);
}
}
fn resolve_key(&self, decoded_txn_id: u16) -> Option<u16> {
if decoded_txn_id != 0 {
self.pending
.contains_key(&decoded_txn_id)
.then_some(decoded_txn_id)
} else {
self.pending.keys().next().copied()
}
}
fn fail_entry(&mut self, raw_txn_id: u16, error: MbusError) {
if let Some(k) = self.resolve_key(raw_txn_id)
&& let Some(entry) = self.pending.remove(&k)
{
self.in_flight = self.in_flight.saturating_sub(1);
self.update_pending_count();
let _ = entry.resp_tx.send(Err(error));
}
}
pub(crate) fn drain_all(&mut self) {
for (_, entry) in self.pending.drain() {
let _ = entry.resp_tx.send(Err(MbusError::ConnectionClosed));
}
for cmd in self.queued.drain(..) {
if let TaskCommand::Request { resp_tx, .. } = cmd {
let _ = resp_tx.send(Err(MbusError::ConnectionClosed));
}
}
self.in_flight = 0;
self.update_pending_count();
}
async fn handle_command(&mut self, cmd: TaskCommand) {
match cmd {
TaskCommand::Connect { resp_tx } => {
let result = self.do_connect().await;
let _ = resp_tx.send(result);
}
TaskCommand::Disconnect => {
self.drain_all();
self.transport = None;
}
req_cmd => {
if self.in_flight < N {
self.dispatch_request(req_cmd).await;
} else {
self.queued.push_back(req_cmd);
}
}
}
}
#[cfg(feature = "traffic")]
fn fire_tx_frame(
&self,
txn_id: u16,
unit: mbus_core::transport::UnitIdOrSlaveAddr,
frame: &[u8],
) {
if let Ok(mut g) = self.notifier.try_lock()
&& let Some(n) = g.as_mut()
{
n.on_tx_frame(txn_id, unit, frame);
}
}
#[cfg(feature = "traffic")]
fn fire_tx_error(
&self,
txn_id: u16,
unit: mbus_core::transport::UnitIdOrSlaveAddr,
err: MbusError,
) {
if let Ok(mut g) = self.notifier.try_lock()
&& let Some(n) = g.as_mut()
{
n.on_tx_error(txn_id, unit, err, &[]);
}
}
#[cfg(feature = "traffic")]
fn fire_rx_frame(
&self,
txn_id: u16,
unit: mbus_core::transport::UnitIdOrSlaveAddr,
frame: &[u8],
) {
if let Ok(mut g) = self.notifier.try_lock()
&& let Some(n) = g.as_mut()
{
n.on_rx_frame(txn_id, unit, frame);
}
}
pub(crate) async fn run(mut self) {
loop {
while self.in_flight < N {
match self.queued.pop_front() {
Some(cmd) => self.dispatch_request(cmd).await,
None => break,
}
}
tokio::select! {
recv_result = recv_if_active(&mut self.transport, self.in_flight) => {
match recv_result {
Ok(frame) => self.process_frame(&frame),
Err(_) => {
self.transport = None;
self.drain_all();
}
}
}
maybe_cmd = self.cmd_rx.recv() => {
match maybe_cmd {
None => return,
Some(cmd) => self.handle_command(cmd).await,
}
}
}
}
}
}
/// Awaits the next frame from `transport`, but only while a response is expected.
///
/// When there is no transport, or nothing is in flight, the returned future
/// never resolves. That keeps the receive branch of the `select!` in
/// `ClientTask::run` dormant instead of polling an idle or absent connection.
async fn recv_if_active<T: AsyncTransport>(
    transport: &mut Option<T>,
    in_flight: usize,
) -> Result<HVec<u8, MAX_ADU_FRAME_LEN>, MbusError> {
    let Some(link) = transport.as_mut().filter(|_| in_flight != 0) else {
        return std::future::pending().await;
    };
    link.recv().await
}
/// Rebuilds a decoded response using details that only the original request carries.
///
/// `decode_response` sees only the frame, so the read models it produces lack
/// request context. This function re-creates the coil, discrete-input, register
/// or FIFO model from the requested start address (and quantity).
///
/// If a model cannot be rebuilt, the decoded response is returned unchanged.
/// That happens when:
/// - the model constructor rejects the parameters, or
/// - the decoded data is shorter than the requested quantity, or
/// - the FIFO length is inconsistent.
///
/// These cases used to panic on out-of-range slicing, taking the whole client
/// task down. Response/request pairs without a matching arm pass through untouched.
fn fix_up_response(
    response: crate::client::response::ClientResponse,
    original: &crate::client::command::ClientRequest,
) -> crate::client::response::ClientResponse {
    use crate::client::command::ClientRequest as Q;
    use crate::client::response::ClientResponse as R;
    match (response, original) {
        #[cfg(feature = "coils")]
        (
            R::Coils(raw),
            Q::ReadMultipleCoils {
                address, quantity, ..
            },
        ) => {
            use mbus_core::models::coil::Coils;
            Coils::new(*address, *quantity)
                .and_then(|c| c.with_values(raw.values(), *quantity))
                .map(R::Coils)
                .unwrap_or_else(|_| R::Coils(raw))
        }
        #[cfg(feature = "discrete-inputs")]
        (
            R::DiscreteInputs(raw),
            Q::ReadDiscreteInputs {
                address, quantity, ..
            },
        ) => {
            use mbus_core::models::discrete_input::DiscreteInputs;
            DiscreteInputs::new(*address, *quantity)
                .and_then(|d| d.with_values(raw.values(), *quantity))
                .map(R::DiscreteInputs)
                .unwrap_or_else(|_| R::DiscreteInputs(raw))
        }
        #[cfg(feature = "holding-registers")]
        (
            R::HoldingRegisters(raw),
            Q::ReadHoldingRegisters {
                address, quantity, ..
            },
        ) => {
            use mbus_core::models::register::HoldingRegisters;
            // Checked slice: a short decoded payload must not panic the task.
            match raw.values().get(..*quantity as usize) {
                Some(values) => HoldingRegisters::new(*address, *quantity)
                    .and_then(|r| r.with_values(values, *quantity))
                    .map(R::HoldingRegisters)
                    .unwrap_or_else(|_| R::HoldingRegisters(raw)),
                None => R::HoldingRegisters(raw),
            }
        }
        #[cfg(feature = "input-registers")]
        (
            R::InputRegisters(raw),
            Q::ReadInputRegisters {
                address, quantity, ..
            },
        ) => {
            use mbus_core::models::register::InputRegisters;
            // Checked slice: a short decoded payload must not panic the task.
            match raw.values().get(..*quantity as usize) {
                Some(values) => InputRegisters::new(*address, *quantity)
                    .and_then(|r| r.with_values(values, *quantity))
                    .map(R::InputRegisters)
                    .unwrap_or_else(|_| R::InputRegisters(raw)),
                None => R::InputRegisters(raw),
            }
        }
        #[cfg(feature = "holding-registers")]
        (
            R::HoldingRegisters(raw),
            Q::ReadWriteMultipleRegisters {
                read_address,
                read_quantity,
                ..
            },
        ) => {
            use mbus_core::models::register::HoldingRegisters;
            // Checked slice: a short decoded payload must not panic the task.
            match raw.values().get(..*read_quantity as usize) {
                Some(values) => HoldingRegisters::new(*read_address, *read_quantity)
                    .and_then(|r| r.with_values(values, *read_quantity))
                    .map(R::HoldingRegisters)
                    .unwrap_or_else(|_| R::HoldingRegisters(raw)),
                None => R::HoldingRegisters(raw),
            }
        }
        #[cfg(feature = "fifo")]
        (R::FifoQueue(raw), Q::ReadFifoQueue { address, .. }) => {
            use mbus_core::models::fifo_queue::{FifoQueue, MAX_FIFO_QUEUE_COUNT_PER_PDU};
            let length = raw.length();
            let queue = raw.queue();
            // Both checks guard `copy_from_slice`, which panics on a length mismatch
            // or an out-of-range destination slice.
            if length <= MAX_FIFO_QUEUE_COUNT_PER_PDU && queue.len() == length {
                let mut arr = [0u16; MAX_FIFO_QUEUE_COUNT_PER_PDU];
                arr[..length].copy_from_slice(queue);
                R::FifoQueue(FifoQueue::new(*address).with_values(arr, length))
            } else {
                R::FifoQueue(raw)
            }
        }
        (r, _) => r,
    }
}