use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::{Mutex, mpsc, oneshot};
use tokio::time::timeout;
use tokio_serial::SerialStream;
use tracing::{debug, error, warn};
use super::ZWaveController;
use super::command_class::CommandClass;
use super::types::{NodeId, ZWaveNode};
use crate::homeauto::error::{HomeAutoError, HomeAutoResult};
use crate::homeauto::types::{HomeAutoEvent, Protocol};
// ---- Single-byte framing markers -------------------------------------------

/// Start-of-frame marker that opens every encoded data frame.
pub const SOF: u8 = 0x01;
/// Acknowledge byte; written back by the reader after a frame decodes cleanly.
pub const ACK: u8 = 0x06;
/// Negative-acknowledge byte.
pub const NAK: u8 = 0x15;
/// Cancel byte.
pub const CAN: u8 = 0x18;

// ---- Frame type byte ---------------------------------------------------------

/// Frame type used for requests (and for unsolicited frames from the controller).
pub const FRAME_TYPE_REQ: u8 = 0x00;
/// Frame type used for responses to a pending request.
pub const FRAME_TYPE_RES: u8 = 0x01;

// ---- Serial API function identifiers (the `cmd_id` byte of a frame) -----------

pub const GET_CAPABILITIES: u8 = 0x07;
pub const SERIAL_API_STARTED: u8 = 0x0A;
pub const GET_VERSION: u8 = 0x15;
pub const SEND_DATA: u8 = 0x13;
pub const SEND_DATA_MULTI: u8 = 0x14;
pub const GET_INIT_DATA: u8 = 0x02;
pub const APPLICATION_COMMAND_HANDLER: u8 = 0x04;
pub const ADD_NODE_TO_NETWORK: u8 = 0x4A;
pub const REMOVE_NODE_FROM_NETWORK: u8 = 0x4B;
pub const SET_DEFAULT: u8 = 0x42;
pub const GET_NODE_PROTOCOL_INFO: u8 = 0x41;
pub const REQUEST_NODE_INFO: u8 = 0x60;
pub const APPLICATION_SLAVE_COMMAND_HANDLER: u8 = 0xA1;

// ---- Mode bytes sent as the first parameter of the add/remove commands --------

/// Start inclusion, accepting any node type.
pub const ADD_NODE_ANY: u8 = 0x01;
/// Stop an inclusion that is in progress.
pub const ADD_NODE_STOP: u8 = 0x05;
/// Start exclusion, accepting any node type.
pub const REMOVE_NODE_ANY: u8 = 0x01;
/// Stop an exclusion that is in progress.
pub const REMOVE_NODE_STOP: u8 = 0x05;

// ---- Transmit option bits, OR-ed together into the SEND_DATA options byte -----

pub const TRANSMIT_OPTION_ACK: u8 = 0x01;
pub const TRANSMIT_OPTION_AUTO_ROUTE: u8 = 0x04;
pub const TRANSMIT_OPTION_EXPLORE: u8 = 0x20;
/// Computes the frame checksum: `0xFF` XOR-ed with every byte of `data`.
///
/// An empty slice therefore yields `0xFF`.
pub fn checksum(data: &[u8]) -> u8 {
    let mut acc = 0xFF_u8;
    for byte in data {
        acc ^= *byte;
    }
    acc
}
/// One serial API data frame, without the leading SOF byte, the length byte
/// and the trailing checksum.
#[derive(Debug, Clone)]
pub struct ZApiFrame {
    /// Frame type byte (`FRAME_TYPE_REQ` or `FRAME_TYPE_RES`).
    pub frame_type: u8,
    /// Serial API function identifier.
    pub cmd_id: u8,
    /// Parameter bytes that follow the function identifier.
    pub data: Vec<u8>,
}

impl ZApiFrame {
    /// Builds a request frame (`frame_type == FRAME_TYPE_REQ`) for `cmd_id`.
    pub fn new_request(cmd_id: u8, data: Vec<u8>) -> Self {
        ZApiFrame {
            cmd_id,
            data,
            frame_type: FRAME_TYPE_REQ,
        }
    }

    /// Serialises the frame as `SOF, LEN, type, cmd_id, data…, checksum`.
    ///
    /// `LEN` is `data.len() + 2`, truncated to one byte, and the checksum
    /// covers everything after `SOF`.
    ///
    /// NOTE(review): the published Z-Wave serial API appears to count the
    /// length byte itself (`LEN = data.len() + 3`); this codec and its tests
    /// use `data.len() + 2` — confirm against real hardware.
    pub fn encode(&self) -> Vec<u8> {
        let body_len = self.data.len() + 2;
        let mut wire = Vec::with_capacity(body_len + 3);
        wire.push(SOF);
        wire.push(body_len as u8);
        wire.push(self.frame_type);
        wire.push(self.cmd_id);
        wire.extend(self.data.iter().copied());
        let cs = checksum(&wire[1..]);
        wire.push(cs);
        wire
    }

    /// Parses a frame from the bytes that follow `SOF`.
    ///
    /// `data` must start at the length byte. On success returns the frame and
    /// the number of bytes consumed (`LEN + 2`); trailing bytes are ignored.
    ///
    /// # Errors
    ///
    /// Returns a static message when the input is shorter than four bytes,
    /// when `LEN < 2`, when fewer than `LEN + 2` bytes are available, or when
    /// the checksum does not match.
    pub fn decode_after_sof(data: &[u8]) -> Result<(Self, usize), &'static str> {
        if data.len() < 4 {
            return Err("ZAPI frame too short");
        }
        let len = usize::from(data[0]);
        if len < 2 {
            return Err("ZAPI LEN < 2");
        }
        let total = len + 2;
        if data.len() < total {
            return Err("ZAPI frame incomplete");
        }

        // `covered` is LEN..=last payload byte; `rest` starts at the checksum.
        let (covered, rest) = data.split_at(len + 1);
        if rest[0] != checksum(covered) {
            return Err("ZAPI checksum mismatch");
        }

        let frame = Self {
            frame_type: covered[1],
            cmd_id: covered[2],
            data: covered[3..].to_vec(),
        };
        Ok((frame, total))
    }
}
/// Bookkeeping for one request that is still waiting for its response frame.
struct PendingCmd {
/// One-shot sender on which the reader task delivers the matching response.
tx: oneshot::Sender<ZApiFrame>,
}
/// Mutable controller state shared between the public API and the reader task.
struct ZApiInner {
/// Requests awaiting a response, keyed by the command id they were sent with.
pending: HashMap<u8, PendingCmd>,
/// Write half of the serial port; `None` before `start()` and after `stop()`.
writer: Option<tokio::io::WriteHalf<SerialStream>>,
/// Sender for outgoing events; replaced each time `events()` is called.
event_tx: mpsc::Sender<HomeAutoEvent>,
/// Nodes currently known to the controller, keyed by node id.
nodes: HashMap<NodeId, ZWaveNode>,
/// Next callback id to hand out; starts at 1 and skips 0 when it wraps.
callback_id: u8,
}
/// Z-Wave controller driven over a serial port.
pub struct ZWaveSerialController {
/// Path of the serial device opened by `start()`.
port_path: String,
/// Baud rate used when opening the serial device.
baud_rate: u32,
/// State shared with the background reader task.
inner: Arc<Mutex<ZApiInner>>,
}
impl ZWaveSerialController {
/// Creates a controller for the serial device at `port`.
///
/// Nothing is opened here; the port is only touched by `start()`. The event
/// channel created here has no receiver, so events are discarded until
/// `events()` installs a live channel.
pub fn new(port: impl Into<String>, baud_rate: u32) -> Self {
    let (event_tx, _) = mpsc::channel(64);
    let state = ZApiInner {
        pending: HashMap::new(),
        writer: None,
        event_tx,
        nodes: HashMap::new(),
        // Callback ids start at 1; 0 is never handed out.
        callback_id: 1,
    };
    Self {
        port_path: port.into(),
        baud_rate,
        inner: Arc::new(Mutex::new(state)),
    }
}
/// Sends a request frame and waits for the response with the same `cmd_id`.
///
/// The frame is written up to `MAX_RETRIES` times, waiting two seconds for a
/// response after each write.
///
/// # Errors
///
/// * `ZWaveController` if the port has not been opened by `start()`.
/// * `Io` if writing to the serial port fails.
/// * `ChannelClosed` if the pending slot was dropped before a response
///   arrived (for example, replaced by another request with the same id).
/// * `Timeout` if no response arrived after the last attempt.
async fn send_request(&self, cmd_id: u8, data: Vec<u8>) -> HomeAutoResult<ZApiFrame> {
    const MAX_RETRIES: u8 = 3;
    const RESPONSE_TIMEOUT: Duration = Duration::from_secs(2);

    // The wire image is identical for every attempt, so build it once.
    let wire = ZApiFrame::new_request(cmd_id, data).encode();

    for attempt in 0..MAX_RETRIES {
        let rx = {
            let mut guard = self.inner.lock().await;
            let inner = &mut *guard;
            let writer = match inner.writer.as_mut() {
                Some(w) => w,
                None => {
                    return Err(HomeAutoError::ZWaveController(
                        "controller not started — call start() first".into(),
                    ));
                }
            };
            writer.write_all(&wire).await.map_err(HomeAutoError::Io)?;

            // Register the waiter only after a successful write, so a failed
            // write leaves no stale entry behind. The lock is still held, so
            // the reader task cannot dispatch the response before this insert.
            let (tx, rx) = oneshot::channel();
            inner.pending.insert(cmd_id, PendingCmd { tx });
            rx
        };

        // Bind the result first so the receiver is dropped before the
        // `is_closed` check below.
        let outcome = timeout(RESPONSE_TIMEOUT, rx).await;
        match outcome {
            Ok(Ok(resp)) => return Ok(resp),
            Ok(Err(_)) => return Err(HomeAutoError::ChannelClosed),
            Err(_) => {
                {
                    // Remove our abandoned waiter. A closed sender identifies
                    // it; a live entry belongs to a different caller.
                    let mut inner = self.inner.lock().await;
                    let abandoned = inner
                        .pending
                        .get(&cmd_id)
                        .map_or(false, |p| p.tx.is_closed());
                    if abandoned {
                        inner.pending.remove(&cmd_id);
                    }
                }
                if attempt + 1 == MAX_RETRIES {
                    return Err(HomeAutoError::Timeout);
                }
                warn!("ZAPI timeout for cmd {cmd_id:#04x}, retry {}", attempt + 1);
            }
        }
    }

    // Not reached while MAX_RETRIES > 0: the last attempt always returns.
    Err(HomeAutoError::ZWaveNak {
        retries: MAX_RETRIES,
    })
}
/// Hands out the current callback id and advances the counter.
///
/// The counter wraps from 255 back to 1, so 0 is never produced by the
/// increment.
async fn next_callback_id(&self) -> u8 {
    let mut state = self.inner.lock().await;
    let issued = state.callback_id;
    state.callback_id = issued.checked_add(1).unwrap_or(1);
    issued
}
async fn spawn_reader(&self, mut reader: tokio::io::ReadHalf<SerialStream>) {
let inner = Arc::clone(&self.inner);
tokio::spawn(async move {
let mut buf: Vec<u8> = Vec::with_capacity(256);
let mut byte = [0u8; 1];
loop {
match reader.read_exact(&mut byte).await {
Err(e) => {
error!("ZAPI serial read error: {e}");
break;
}
Ok(_) => {
let b = byte[0];
match b {
ACK => debug!("ZAPI ACK received"),
NAK => warn!("ZAPI NAK received"),
CAN => warn!("ZAPI CAN received"),
SOF => {
buf.clear();
buf.push(b);
}
_ => {
buf.push(b);
if buf.len() >= 5 && buf.first() == Some(&SOF) {
let payload = &buf[1..]; if payload.len() >= 1 {
let expected = payload[0] as usize + 2; if payload.len() >= expected {
match ZApiFrame::decode_after_sof(payload) {
Ok((frame, _)) => {
let mut g = inner.lock().await;
if let Some(w) = g.writer.as_mut() {
let _ = w.write_all(&[ACK]).await;
}
match frame.frame_type {
FRAME_TYPE_RES => {
if let Some(p) =
g.pending.remove(&frame.cmd_id)
{
let _ = p.tx.send(frame);
}
}
FRAME_TYPE_REQ => {
Self::dispatch_req(&mut g, &frame);
}
_ => {}
}
}
Err(e) => warn!("ZAPI decode: {e}"),
}
buf.clear();
}
}
}
}
}
}
}
}
});
}
/// Handles an unsolicited request frame received from the controller.
///
/// * `APPLICATION_COMMAND_HANDLER`: logs the source node and the first two
///   bytes of the embedded command (byte 0 of the parameters is skipped,
///   byte 1 is the source node, byte 2 the command length).
/// * `ADD_NODE_TO_NETWORK`: on status `0x05` with a non-zero node id, records
///   the node.
/// * `REMOVE_NODE_FROM_NETWORK`: on status `0x06` with a non-zero node id,
///   forgets the node and emits `DeviceLeft` (dropped if the event channel is
///   full or closed).
///
/// All other commands are ignored.
fn dispatch_req(inner: &mut ZApiInner, frame: &ZApiFrame) {
    let params = frame.data.as_slice();
    match frame.cmd_id {
        APPLICATION_COMMAND_HANDLER => {
            if let [_, src_node_id, cmd_len, rest @ ..] = params {
                if let Some(cmd_data) = rest.get(..usize::from(*cmd_len)) {
                    debug!(
                        "ZAPI app command from node {src_node_id}: cc={:#04x} cmd={:#04x}",
                        cmd_data.first().copied().unwrap_or(0),
                        cmd_data.get(1).copied().unwrap_or(0)
                    );
                }
            }
        }
        ADD_NODE_TO_NETWORK => {
            if let &[status, node_id, ..] = params {
                if status == 0x05 && node_id > 0 {
                    inner.nodes.insert(node_id, ZWaveNode::new(node_id));
                    debug!("ZAPI: node {node_id} added to network");
                }
            }
        }
        REMOVE_NODE_FROM_NETWORK => {
            if let &[status, node_id, ..] = params {
                if status == 0x06 && node_id > 0 {
                    inner.nodes.remove(&node_id);
                    let left = HomeAutoEvent::DeviceLeft {
                        id: node_id.to_string(),
                        protocol: Protocol::ZWave,
                    };
                    let _ = inner.event_tx.try_send(left);
                    debug!("ZAPI: node {node_id} removed from network");
                }
            }
        }
        _ => {}
    }
}
}
#[async_trait]
impl ZWaveController for ZWaveSerialController {
async fn start(&self) -> HomeAutoResult<()> {
let port = SerialStream::open(&tokio_serial::new(&self.port_path, self.baud_rate))
.map_err(HomeAutoError::Serial)?;
let (reader, writer) = tokio::io::split(port);
self.inner.lock().await.writer = Some(writer);
self.spawn_reader(reader).await;
let resp = self.send_request(GET_VERSION, vec![]).await?;
debug!(
"ZAPI connected: version={}",
String::from_utf8_lossy(&resp.data)
);
let init = self.send_request(GET_INIT_DATA, vec![]).await?;
if init.data.len() >= 3 {
let node_list_len = init.data[2] as usize;
let node_bytes = init.data.get(3..3 + node_list_len).unwrap_or(&[]);
let mut inner = self.inner.lock().await;
for (byte_idx, &byte) in node_bytes.iter().enumerate() {
for bit in 0..8 {
if byte & (1 << bit) != 0 {
let node_id = (byte_idx * 8 + bit + 1) as NodeId;
inner
.nodes
.entry(node_id)
.or_insert_with(|| ZWaveNode::new(node_id));
}
}
}
debug!("ZAPI: {} nodes discovered", inner.nodes.len());
}
Ok(())
}
/// Drops the write half of the serial port so no further frames can be sent.
///
/// Always returns `Ok(())`. Pending requests and the node table are untouched.
async fn stop(&self) -> HomeAutoResult<()> {
    let mut state = self.inner.lock().await;
    drop(state.writer.take());
    Ok(())
}
async fn include_node(&self, timeout_secs: u8) -> HomeAutoResult<ZWaveNode> {
let cb_id = self.next_callback_id().await;
self.send_request(ADD_NODE_TO_NETWORK, vec![ADD_NODE_ANY, cb_id])
.await?;
tokio::time::sleep(Duration::from_secs(timeout_secs as u64)).await;
self.send_request(ADD_NODE_TO_NETWORK, vec![ADD_NODE_STOP, 0])
.await?;
let inner = self.inner.lock().await;
inner
.nodes
.values()
.max_by_key(|n| n.node_id)
.cloned()
.ok_or_else(|| HomeAutoError::ZWaveController("no node was added".into()))
}
/// Runs exclusion mode for `timeout_secs` seconds, then stops it.
///
/// Node removal itself is handled by the reader task when the controller
/// reports it; this method only brackets the exclusion window.
///
/// # Errors
///
/// Returns whatever `send_request` returns for the start or stop command.
async fn exclude_node(&self, timeout_secs: u8) -> HomeAutoResult<()> {
    let callback = self.next_callback_id().await;
    let begin = vec![REMOVE_NODE_ANY, callback];
    self.send_request(REMOVE_NODE_FROM_NETWORK, begin).await?;

    let window = Duration::from_secs(u64::from(timeout_secs));
    tokio::time::sleep(window).await;

    let end = vec![REMOVE_NODE_STOP, 0];
    self.send_request(REMOVE_NODE_FROM_NETWORK, end)
        .await
        .map(|_| ())
}
/// Returns a snapshot (clones) of every node currently known, in no
/// particular order.
async fn nodes(&self) -> HomeAutoResult<Vec<ZWaveNode>> {
    let guard = self.inner.lock().await;
    let snapshot: Vec<ZWaveNode> = guard.nodes.values().cloned().collect();
    Ok(snapshot)
}
/// Sends a command-class payload to `node_id` via `SEND_DATA`.
///
/// The request parameters are: node id, command length (`1 + data.len()`),
/// command class id, `data`, transmit options (ACK | AUTO_ROUTE | EXPLORE)
/// and a fresh callback id.
///
/// # Errors
///
/// * `ZWaveTransmit` if `data` is too long to fit a single frame, or if the
///   first byte of the response is 0 (or missing).
/// * Whatever `send_request` returns.
async fn send_cc(&self, node_id: NodeId, cc: CommandClass, data: &[u8]) -> HomeAutoResult<()> {
    // The request carries 5 bytes besides `data`, and `ZApiFrame::encode`
    // adds 2 more for its one-byte LEN field. Anything longer would silently
    // truncate both length bytes and put a malformed frame on the wire.
    const MAX_CC_DATA_LEN: usize = u8::MAX as usize - 7;
    if data.len() > MAX_CC_DATA_LEN {
        return Err(HomeAutoError::ZWaveTransmit {
            node_id,
            msg: "command class payload too long for a single frame".into(),
        });
    }

    let cb_id = self.next_callback_id().await;
    let tx_opts = TRANSMIT_OPTION_ACK | TRANSMIT_OPTION_AUTO_ROUTE | TRANSMIT_OPTION_EXPLORE;

    let mut payload = Vec::with_capacity(data.len() + 5);
    payload.push(node_id);
    // Cannot truncate: data.len() was bounded above.
    payload.push((1 + data.len()) as u8);
    payload.push(cc.id());
    payload.extend_from_slice(data);
    payload.push(tx_opts);
    payload.push(cb_id);

    let resp = self.send_request(SEND_DATA, payload).await?;
    let send_ok = resp.data.first().copied().unwrap_or(0);
    if send_ok == 0 {
        return Err(HomeAutoError::ZWaveTransmit {
            node_id,
            msg: "SEND_DATA rejected".into(),
        });
    }
    Ok(())
}
/// Returns a stream of events emitted by this controller.
///
/// Each call installs a fresh channel as the event sink, so a previously
/// returned stream ends once the replacement takes effect.
///
/// The new sender is installed synchronously when the state lock is free, so
/// events raised right after this call are not lost. If the lock is busy the
/// swap is deferred to a spawned task, which requires a running Tokio runtime.
fn events(&self) -> crate::homeauto::BoxStream<'static, HomeAutoEvent> {
    let (new_tx, mut rx) = mpsc::channel::<HomeAutoEvent>(64);

    match self.inner.try_lock() {
        Ok(mut guard) => guard.event_tx = new_tx,
        Err(_) => {
            // Lock held elsewhere: install the sender once it is released.
            let inner = Arc::clone(&self.inner);
            tokio::spawn(async move {
                inner.lock().await.event_tx = new_tx;
            });
        }
    }

    Box::pin(async_stream::stream! {
        while let Some(event) = rx.recv().await {
            yield event;
        }
    })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The checksum is 0xFF XOR-ed with each covered byte.
    #[test]
    fn zapi_xor_checksum_known_vector() {
        let covered = [0x03_u8, 0x00, 0x07];
        let expected = 0xFF ^ 0x03 ^ 0x00 ^ 0x07;
        assert_eq!(checksum(&covered), expected);
    }

    /// Pins the values of the single-byte framing markers.
    #[test]
    fn zapi_single_byte_ack_nak_can() {
        assert_eq!([ACK, NAK, CAN], [0x06, 0x15, 0x18]);
    }

    /// Checks the header bytes and checksum of an encoded parameterless request.
    #[test]
    fn zapi_frame_encode_send_data() {
        let wire = ZApiFrame::new_request(GET_VERSION, Vec::new()).encode();
        assert_eq!(&wire[..4], &[SOF, 2, FRAME_TYPE_REQ, GET_VERSION][..]);
        assert_eq!(wire[4], checksum(&wire[1..4]));
    }

    /// An application-command frame survives an encode/decode round trip.
    #[test]
    fn zapi_frame_decode_app_command_handler() {
        let params: Vec<u8> = vec![0x00, 0x05, 0x03, 0x25, 0x03, 0xFF];
        let wire = ZApiFrame {
            frame_type: FRAME_TYPE_REQ,
            cmd_id: APPLICATION_COMMAND_HANDLER,
            data: params.clone(),
        }
        .encode();

        let (decoded, _) = ZApiFrame::decode_after_sof(&wire[1..]).unwrap();
        assert_eq!(decoded.cmd_id, APPLICATION_COMMAND_HANDLER);
        assert_eq!(decoded.data, params);
    }

    /// Decoding consumes everything after SOF and restores the original fields.
    #[test]
    fn zapi_frame_roundtrip() {
        let params = vec![0x05, 0xAA, 0xBB];
        let wire = ZApiFrame::new_request(SEND_DATA, params.clone()).encode();

        let (decoded, consumed) = ZApiFrame::decode_after_sof(&wire[1..]).unwrap();
        assert_eq!(decoded.cmd_id, SEND_DATA);
        assert_eq!(decoded.data, params);
        assert_eq!(consumed, wire.len() - 1);
    }

    /// A corrupted checksum byte makes decoding fail.
    #[test]
    fn zapi_checksum_mismatch() {
        let mut wire = ZApiFrame::new_request(GET_VERSION, vec![]).encode();
        if let Some(cs) = wire.last_mut() {
            *cs ^= 0xFF;
        }
        assert!(ZApiFrame::decode_after_sof(&wire[1..]).is_err());
    }
}