/* Copyright (C) 2020-2022 by Jacob Alexander
*
* This file is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This file is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this file. If not, see <http://www.gnu.org/licenses/>.
*/
//! Mailbox
//! Handles message passing between devices, modules and api calls
//! Uses a broadcast channel to handle communication
// ----- Modules -----
use crate::api::Endpoint;
use hid_io_protocol::commands::CommandError;
use hid_io_protocol::{HidIoCommandId, HidIoPacketType};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::broadcast;
use tokio_stream::{wrappers::BroadcastStream, StreamExt};
// ----- Types -----
/// Size parameter handed to the protocol crate's packet buffer type below
/// (presumably the payload capacity in bytes, going by the name -- confirm against hid_io_protocol)
pub const HIDIO_PKT_BUF_DATA_SIZE: usize = 500;
/// Packet buffer type used by the mailbox: the protocol crate's generic buffer
/// with its size parameter fixed to `HIDIO_PKT_BUF_DATA_SIZE`
pub type HidIoPacketBuffer = hid_io_protocol::HidIoPacketBuffer<HIDIO_PKT_BUF_DATA_SIZE>;
// ----- Enumerations -----
/// Source/destination address carried by every mailbox `Message`
///
/// Besides real endpoints (api nodes, devices, modules) this enum also contains
/// control values that the mailbox uses to cancel/drop subscriptions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Address {
/// All/any addressed (used as a broadcast destination, not as a source)
/// Ack/Nak replies are matched against this destination by the send helpers
All,
/// Capnproto API address, with node uid
ApiCapnp {
uid: u64,
},
/// Cancel all subscriptions
/// Used as the destination of the message sent by `drop_all_subscribers()`
CancelAllSubscriptions,
/// Cancel subscription
/// Used to gracefully end message streams
/// Used as the destination of the message sent by `drop_subscriber()`
CancelSubscription {
/// Uid of endpoint of the subscription
uid: u64,
/// Subscription id
sid: u64,
},
/// HidIo address, with node uid
DeviceHidio {
uid: u64,
},
/// Generic HID address, with node uid
DeviceHid {
uid: u64,
},
/// Drop subscription
/// Used as the source of the cancellation messages sent by
/// `drop_subscriber()` and `drop_all_subscribers()`
DropSubscription,
/// Module address
Module,
}
// ----- Consts -----
/// Number of message slots for the mailbox broadcast channel
/// Must be at least as large as the deepest backlog the slowest receiver is expected to
/// accumulate; a receiver that falls further behind misses messages, which is reported
/// to it as a `Lagged` receive error (see the blocking send helpers in this file).
const CHANNEL_SLOTS: usize = 100;
// ----- Structs -----
/// HID-IO Mailbox
///
/// Handles passing messages to various components inside of HID-IO
/// Best thought of as a broadcast style packet switcher.
/// Each thread (usually async tokio) is given a receiver and can filter for
/// any desired packets.
/// This is not quite as efficient as direct channels; however, this greatly
/// simplifies message passing across HID-IO, making it easier to add new modules.
///
/// This struct can be safely cloned and passed around anywhere in the codebase.
/// All state fields are wrapped in an `Arc`, so every clone refers to the same data.
/// In most cases only the sender field is used (as it has the subscribe() function as well).
#[derive(Clone, Debug)]
pub struct Mailbox {
/// Endpoints currently registered (see `register_node()` / `unregister_node()`)
pub nodes: Arc<RwLock<Vec<Endpoint>>>,
/// Most recently generated uid; incremented by `assign_uid()` when a new uid is needed
pub last_uid: Arc<RwLock<u64>>,
/// Maps a device key to every uid that has been handed out for that key
pub lookup: Arc<RwLock<HashMap<String, Vec<u64>>>>,
/// Sending half of the broadcast channel; receivers are created with `sender.subscribe()`
pub sender: broadcast::Sender<Message>,
/// How long the send helpers wait for an Ack before giving up (`new()` sets 2 seconds)
pub ack_timeout: Arc<RwLock<std::time::Duration>>,
/// Tokio runtime handed in at construction
/// NOTE(review): not used by any code in this file, presumably kept for other modules -- confirm
pub rt: Arc<tokio::runtime::Runtime>,
}
impl Mailbox {
/// Builds a mailbox around a freshly created broadcast channel.
///
/// The endpoint list and the key lookup table start out empty, the uid counter
/// starts at 0 and the Ack timeout defaults to 2 seconds. `rt` is stored as-is.
pub fn new(rt: Arc<tokio::runtime::Runtime>) -> Mailbox {
    // Only the sending half is kept; receivers are created on demand via subscribe()
    let (sender, _) = broadcast::channel::<Message>(CHANNEL_SLOTS);

    Mailbox {
        // No endpoints registered yet
        nodes: Arc::new(RwLock::new(Vec::new())),
        // Last uid assigned (uids are reused when possible for devices)
        last_uid: Arc::new(RwLock::new(0)),
        // Key -> uids lookup table
        lookup: Arc::new(RwLock::new(HashMap::new())),
        sender,
        // Default Ack timeout of 2 seconds
        ack_timeout: Arc::new(RwLock::new(std::time::Duration::from_secs(2))),
        rt,
    }
}
/// Attempt to locate an unused id for the device key
///
/// Walks the uids previously handed out for `key` (creating an empty lookup entry
/// for the key if none exists yet) and compares them against the registered endpoints.
///
/// Returns:
/// * `Some(0)` - a registered endpoint already uses one of the key's uids *and* has the
///   same `path`, i.e. this exact device is already registered (0 is the invalid id)
/// * `Some(uid)` - the first uid of the key that no registered endpoint is using
/// * `None` - the key has no uids, or all of them are in use by other paths
///
/// # Panics
/// Panics if the `lookup` or `nodes` lock is poisoned.
pub fn get_uid(&mut self, key: String, path: String) -> Option<u64> {
    let mut lookup = self.lookup.write().unwrap();
    let lookup_entry = lookup.entry(key).or_default();

    // Nothing to search, avoid touching the node list at all
    if lookup_entry.is_empty() {
        return None;
    }

    // Snapshot the endpoint list once instead of cloning it for every candidate uid.
    // The read guard is released at the end of this statement.
    let mut nodes = (*self.nodes.read().unwrap()).clone();

    // Locate an id
    'outer: for uid in lookup_entry.iter() {
        for node in nodes.iter_mut() {
            if node.uid() == *uid {
                // Id is being used, and has the same path (i.e. this device)
                if node.path() == path {
                    // Return an invalid Id (0)
                    return Some(0);
                }
                // Id is being used, and is not available
                continue 'outer;
            }
        }
        // Id is not being used
        return Some(*uid);
    }

    // Could not locate an id
    None
}
/// Records `uid` as belonging to `key` in the lookup table
/// (the key's entry is created on first use)
pub fn add_uid(&mut self, key: String, uid: u64) {
    self.lookup
        .write()
        .unwrap()
        .entry(key)
        .or_default()
        .push(uid);
}
/// Assign uid
/// This function will attempt to lookup an existing id first
/// And generate a new uid if necessary
///
/// # Errors
/// An error is returned if this lookup already has a uid (string+path),
/// i.e. the device is already registered.
///
/// # Panics
/// Panics if one of the mailbox locks is poisoned.
pub fn assign_uid(&mut self, key: String, path: String) -> Result<u64, std::io::Error> {
    match self.get_uid(key.clone(), path) {
        Some(0) => Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            "uid has already been registered!",
        )),
        Some(uid) => Ok(uid),
        None => {
            // Increment and read back under a single write guard.
            // Taking the lock twice (write, then read) lets another clone of the
            // mailbox increment in between, so both callers would get the same uid.
            let uid = {
                let mut last_uid = self.last_uid.write().unwrap();
                *last_uid += 1;
                *last_uid
            };

            // Add id to lookup
            self.add_uid(key, uid);
            Ok(uid)
        }
    }
}
/// Appends `endpoint` (device or api) to the list of registered nodes
/// Logs the endpoint uid at info level before taking the write lock.
pub fn register_node(&mut self, mut endpoint: Endpoint) {
    info!("Registering endpoint: {}", endpoint.uid());
    self.nodes.write().unwrap().push(endpoint);
}
/// Unregister node as an endpoint (device or api)
///
/// Removes every registered endpoint whose uid equals `uid`; the order of the
/// remaining endpoints is preserved. Does nothing if no endpoint matches.
///
/// # Panics
/// Panics if the `nodes` lock is poisoned.
pub fn unregister_node(&mut self, uid: u64) {
    info!("Unregistering endpoint: {}", uid);
    // Filter in place with the stable retain_mut (keeps elements for which the
    // closure returns true) instead of the unstable drain_filter + collect round trip.
    self.nodes
        .write()
        .unwrap()
        .retain_mut(|dev| dev.uid() != uid);
}
/// Convenience function to send a HidIo Command to device using the mailbox
/// Returns the Ack message if enabled.
///
/// A `Data` packet is built from `id` and `data` and broadcast from `src` to `dst`.
/// When `ack` is true the function then waits for a reply that originates from `dst`,
/// is addressed to `Address::All` and carries the same command `id`.
///
/// # Returns
/// * `Ok(None)` - `ack` was false, nothing was waited for
/// * `Ok(Some(msg))` - the Ack message
///
/// # Errors
/// * `AckWaitError::NoActiveReceivers` - nobody is subscribed, or the send failed
/// * `AckWaitError::NakReceived` - a Nak was received (the message is included)
/// * `AckWaitError::Invalid` - the message stream ended
/// * `AckWaitError::Timeout` - no Ack/Nak arrived within `self.ack_timeout`
///   (measured once from the start of the wait, not per received message)
///
/// # Panics
/// Panics if `data` does not fit into the packet buffer or if the
/// `ack_timeout` lock is poisoned.
pub async fn send_command(
    &self,
    src: Address,
    dst: Address,
    id: HidIoCommandId,
    data: Vec<u8>,
    ack: bool,
) -> Result<Option<Message>, AckWaitError> {
    // Select packet type
    // TODO Add firmware support for NAData (to be used when `ack` is false)
    let ptype = HidIoPacketType::Data;

    // Construct command packet
    let data = HidIoPacketBuffer {
        ptype,
        id,
        max_len: 64, //..Defaults
        data: heapless::Vec::from_slice(&data).unwrap(),
        done: true,
    };

    // Check receiver count
    if self.sender.receiver_count() == 0 {
        error!("send_command (no active receivers)");
        return Err(AckWaitError::NoActiveReceivers);
    }

    // Subscribe to messages before sending message, but this means we have to check the
    // receiver count earlier
    let receiver = self.sender.subscribe();

    // Construct command message and broadcast
    // (the packet is cloned as it is still needed for the timeout log message)
    let result = self.sender.send(Message {
        src,
        dst,
        data: data.clone(),
    });
    if let Err(e) = result {
        error!(
            "send_command failed, something is odd, should not get here... {:?}",
            e
        );
        return Err(AckWaitError::NoActiveReceivers);
    }

    // No Ack data packet command, no Ack to wait for
    if !ack {
        return Ok(None);
    }

    // Construct stream filter
    // Receive errors (lagging) are dropped.
    // The reply must have the same address as was sent, except reversed.
    // The HIDIO device does not keep track of senders, so the destination will be All.
    tokio::pin! {
        let stream = BroadcastStream::new(receiver)
            .filter_map(Result::ok)
            .filter(|msg| msg.src == dst && msg.dst == Address::All && msg.data.id == id);
    }

    // Wait on filtered messages
    // A single deadline is used so that unrelated packet types (which are skipped below)
    // cannot keep extending the wait beyond ack_timeout.
    let ack_timeout = *self.ack_timeout.read().unwrap();
    let deadline = tokio::time::Instant::now() + ack_timeout;
    loop {
        match tokio::time::timeout_at(deadline, stream.next()).await {
            Ok(Some(msg)) => {
                match msg.data.ptype {
                    HidIoPacketType::Ack => {
                        return Ok(Some(msg));
                    }
                    // We may still want the message data from a Nak
                    HidIoPacketType::Nak => {
                        let msg = Box::new(msg);
                        return Err(AckWaitError::NakReceived { msg });
                    }
                    // Not a reply, keep waiting
                    _ => {}
                }
            }
            // Stream has ended
            Ok(None) => {
                return Err(AckWaitError::Invalid);
            }
            Err(_) => {
                warn!("Timeout ({:?}) receiving Ack for: {}", ack_timeout, data);
                return Err(AckWaitError::Timeout);
            }
        }
    }
}
/// Convenience function to send a HidIoPacketBuffer using the mailbox
/// Returns the Ack message if available and applicable
pub fn try_send_message(&self, msg: Message) -> Result<Option<Message>, CommandError> {
// Check receiver count
if self.sender.receiver_count() == 0 {
error!("send_command (no active receivers)");
return Err(CommandError::TxNoActiveReceivers);
}
// Subscribe to messages before sending message, but this means we have to check the
// receiver count earlier
let mut receiver = self.sender.subscribe();
// Construct command message and broadcast
let result = self.sender.send(msg.clone());
if let Err(e) = result {
error!(
"send_command failed, something is odd, should not get here... {:?}",
e
);
return Err(CommandError::TxNoActiveReceivers);
}
// Only wait for a response if this is a Data packet
if msg.data.ptype != HidIoPacketType::Data {
return Ok(None);
}
// Loop until we find the message we want
let start_time = std::time::Instant::now();
loop {
// Check for timeout
if start_time.elapsed() >= *self.ack_timeout.read().unwrap() {
warn!(
"Timeout ({:?}) receiving Ack for command: src:{:?} dst:{:?}",
*self.ack_timeout.read().unwrap(),
msg.src,
msg.dst
);
return Err(CommandError::RxTimeout);
}
// Attempt to receive message
match receiver.try_recv() {
Ok(rcvmsg) => {
// Packet must have the same address as was sent, except reversed
if rcvmsg.dst == Address::All
&& rcvmsg.src == msg.dst
&& rcvmsg.data.id == msg.data.id
{
match rcvmsg.data.ptype {
HidIoPacketType::Ack | HidIoPacketType::Nak => {
return Ok(Some(rcvmsg));
}
_ => {}
}
}
}
Err(broadcast::error::TryRecvError::Empty) => {
// Sleep while queue is empty
std::thread::yield_now();
std::thread::sleep(std::time::Duration::from_millis(1));
}
Err(broadcast::error::TryRecvError::Lagged(_skipped)) => {} // TODO (HaaTa): Should probably warn if lagging
Err(broadcast::error::TryRecvError::Closed) => {
// Channel has closed, this is very bad
return Err(CommandError::TxBufferSendFailed);
}
}
}
}
/// Convenience function to send a HidIo Command to device using the mailbox
/// Returns the Ack message if enabled.
/// This is the blocking version of send_command().
/// Ack will timeout if it exceeds self.ack_timeout
pub fn try_send_command(
&self,
src: Address,
dst: Address,
id: HidIoCommandId,
data: Vec<u8>,
ack: bool,
) -> Result<Option<Message>, AckWaitError> {
// Select packet type
/* TODO Add firmware support for NAData
let ptype = if ack {
HidIoPacketType::Data
} else {
HidIoPacketType::NAData
};
*/
let ptype = HidIoPacketType::Data;
// Construct command packet
let data = HidIoPacketBuffer {
ptype,
id,
max_len: 64, //..Defaults
data: heapless::Vec::from_slice(&data).unwrap(),
done: true,
};
// Check receiver count
if self.sender.receiver_count() == 0 {
error!("send_command (no active receivers)");
return Err(AckWaitError::NoActiveReceivers);
}
// Subscribe to messages before sending message, but this means we have to check the
// receiver count earlier
let mut receiver = self.sender.subscribe();
// Construct command message and broadcast
let result = self.sender.send(Message { src, dst, data });
if let Err(e) = result {
error!(
"send_command failed, something is odd, should not get here... {:?}",
e
);
return Err(AckWaitError::NoActiveReceivers);
}
// No Ack data packet command, no Ack to wait for
if !ack {
return Ok(None);
}
// Loop until we find the message we want
let start_time = std::time::Instant::now();
loop {
// Check for timeout
if start_time.elapsed() >= *self.ack_timeout.read().unwrap() {
warn!(
"Timeout ({:?}) receiving Ack for command: src:{:?} dst:{:?}",
*self.ack_timeout.read().unwrap(),
src,
dst
);
return Err(AckWaitError::Timeout);
}
// Attempt to receive message
match receiver.try_recv() {
Ok(msg) => {
// Packet must have the same address as was sent, except reversed
// The HIDIO device does not keep track of senders, so it will be all
if msg.dst == Address::All && msg.src == dst && msg.data.id == id {
match msg.data.ptype {
HidIoPacketType::Ack => {
return Ok(Some(msg));
}
// We may still want the message data from a Nak
HidIoPacketType::Nak => {
let msg = Box::new(msg);
return Err(AckWaitError::NakReceived { msg });
}
_ => {}
}
}
}
Err(broadcast::error::TryRecvError::Empty) => {
// Sleep while queue is empty
std::thread::yield_now();
std::thread::sleep(std::time::Duration::from_millis(1));
}
Err(broadcast::error::TryRecvError::Lagged(_skipped)) => {} // TODO (HaaTa): Should probably warn if lagging
Err(broadcast::error::TryRecvError::Closed) => {
// Channel has closed, this is very bad
return Err(AckWaitError::ChannelClosed);
}
}
}
}
/// Broadcasts a cancellation for subscription `sid` of endpoint `uid`
///
/// The message carries a default (dummy) packet; only its addresses matter:
/// source `DropSubscription`, destination `CancelSubscription { uid, sid }`.
/// A failed send is logged and otherwise ignored.
pub fn drop_subscriber(&self, uid: u64, sid: u64) {
    let cancel = Message {
        src: Address::DropSubscription,
        dst: Address::CancelSubscription { uid, sid },
        // Dummy payload
        data: HidIoPacketBuffer::default(),
    };

    if let Err(e) = self.sender.send(cancel) {
        error!("drop_subscriber {:?}", e);
    }
}
/// Broadcasts a cancellation for every subscription
///
/// The message carries a default (dummy) packet; only its addresses matter:
/// source `DropSubscription`, destination `CancelAllSubscriptions`.
/// A failed send is logged and otherwise ignored.
pub fn drop_all_subscribers(&self) {
    let cancel = Message {
        src: Address::DropSubscription,
        dst: Address::CancelAllSubscriptions,
        // Dummy payload
        data: HidIoPacketBuffer::default(),
    };

    if let Err(e) = self.sender.send(cancel) {
        error!("drop_all_subscribers {:?}", e);
    }
}
}
impl Default for Mailbox {
    /// Creates a mailbox backed by its own current-thread tokio runtime
    /// (all runtime drivers enabled).
    ///
    /// Panics if the runtime cannot be built.
    fn default() -> Self {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        Self::new(Arc::new(runtime))
    }
}
/// Container for HidIoPacketBuffer
/// Used to indicate the source and destinations inside of hid-io-core.
/// Also contains a variety of convenience functions using the src and dst information.
#[derive(PartialEq, Clone, Debug)]
// NOTE(review): Eq is not derived, presumably because the packet buffer does not
// implement it -- confirm against hid_io_protocol
#[allow(clippy::derive_partial_eq_without_eq)]
pub struct Message {
/// Address of the sender
pub src: Address,
/// Address of the intended receiver (`Address::All` for broadcasts)
pub dst: Address,
/// The packet being transported
pub data: HidIoPacketBuffer,
}
impl Message {
pub fn new(src: Address, dst: Address, data: HidIoPacketBuffer) -> Message {
Message { src, dst, data }
}
/// Acknowledgement of a HidIo packet
pub fn send_ack(&self, sender: broadcast::Sender<Message>, data: Vec<u8>) {
let src = self.dst;
let dst = self.src;
// Construct ack packet
let data = HidIoPacketBuffer {
ptype: HidIoPacketType::Ack,
id: self.data.id, // id,
max_len: 64, // Default
data: heapless::Vec::from_slice(&data).unwrap(),
done: true,
};
// Construct ack message and broadcast
let result = sender.send(Message { src, dst, data });
if let Err(e) = result {
error!("send_ack {:?}", e);
}
}
/// Rejection/Nak of a HidIo packet
pub fn send_nak(&self, sender: broadcast::Sender<Message>, data: Vec<u8>) {
let src = self.dst;
let dst = self.src;
// Construct ack packet
let data = HidIoPacketBuffer {
ptype: HidIoPacketType::Nak,
id: self.data.id, // id,
max_len: 64, // Default
data: heapless::Vec::from_slice(&data).unwrap(),
done: true,
};
// Construct ack message and broadcast
let result = sender.send(Message { src, dst, data });
if let Err(e) = result {
error!("send_ack {:?}", e);
}
}
}
/// Errors returned by the mailbox send helpers while sending a command
/// and waiting for its Ack
#[derive(Debug)]
pub enum AckWaitError {
/// NOTE(review): not produced by any code in this file -- confirm where it is used
TooManySyncs,
/// A Nak was received instead of an Ack; the Nak message is included as it may
/// still contain useful data
NakReceived { msg: Box<Message> },
/// The message stream ended while waiting for the Ack (async `send_command()`)
Invalid,
/// Nobody is subscribed to the mailbox, or broadcasting the command failed
NoActiveReceivers,
/// No Ack/Nak arrived within the mailbox `ack_timeout`
Timeout,
/// The broadcast channel closed while waiting for the Ack (blocking `try_send_command()`)
ChannelClosed,
}