use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};
use super::data_model::{
BasicInformationCluster, DataModelNode, GeneralCommissioningCluster,
NetworkCommissioningCluster, OperationalCredentialsCluster,
};
use super::discovery::CommissionableAdvertiser;
use super::fabric::FabricManager;
use super::interaction_model::read::ReportData;
use super::interaction_model::{ImOpcode, InteractionStatus, InvokeResponse, InvokeResponseItem};
use super::secure_channel::{
CaseResponder, EstablishedSession, PaseCommissionee, SECURE_CHANNEL_PROTOCOL_ID,
SecureChannelOpcode,
};
use super::transport::message::{MatterMessage, MessageHeader, SessionType};
use super::transport::{SessionKeys, UdpTransport};
use super::types::MatterDeviceConfig;
use crate::homeauto::error::{HomeAutoError, HomeAutoResult};
/// Application callback receiving the on/off state requested by an OnOff command.
pub type OnOffHandler = Arc<dyn Fn(bool) + Sync + Send>;
/// Application callback receiving the level requested by a LevelControl command.
pub type LevelHandler = Arc<dyn Fn(u8) + Sync + Send>;
/// Application callback receiving the colour temperature requested by a ColorControl command.
pub type ColorTempHandler = Arc<dyn Fn(u16) + Sync + Send>;
/// Application callback receiving the value derived from a Thermostat command.
pub type ThermostatHandler = Arc<dyn Fn(f32) + Sync + Send>;
// ---- Transport / discovery (reserved: not referenced anywhere yet) ----
/// Default Matter UDP port.
const _MATTER_PORT: u16 = 5540;
/// DNS-SD service type used for Matter operational discovery.
const _MATTER_MDNS_SERVICE_TYPE: &str = "_matter._tcp";

// ---- Interaction Model ----
/// Protocol ID carried in the payload header of Interaction Model messages.
const IM_PROTOCOL_ID: u16 = 0x0001;

// ---- Cluster IDs ----
const CLUSTER_ON_OFF: u32 = 0x0000_0006;
const CLUSTER_LEVEL_CONTROL: u32 = 0x0000_0008;
const CLUSTER_THERMOSTAT: u32 = 0x0000_0201;
const CLUSTER_COLOR_CONTROL: u32 = 0x0000_0300;

// ---- OnOff cluster commands ----
const CMD_OFF: u32 = 0x00;
const CMD_ON: u32 = 0x01;
const CMD_TOGGLE: u32 = 0x02;

// ---- LevelControl cluster commands ----
const CMD_MOVE_TO_LEVEL: u32 = 0x00;
const CMD_MOVE_TO_LEVEL_WITH_ON_OFF: u32 = 0x04;

// ---- Thermostat cluster commands ----
const CMD_SETPOINT_RAISE_LOWER: u32 = 0x00;
/// Mutable state shared between the public API and the receive loop.
struct ServerInner {
    /// True while `start` is running its receive loop; clearing it asks the loop to exit.
    running: bool,
    /// Commissioning flag; initialised but not read anywhere yet.
    _commissioned: bool,
    /// Callback for OnOff cluster commands.
    on_off: Option<OnOffHandler>,
    /// Callback for LevelControl cluster commands.
    level: Option<LevelHandler>,
    /// Callback for ColorControl colour-temperature commands.
    color_temp: Option<ColorTempHandler>,
    /// Callback for Thermostat cluster commands.
    thermostat: Option<ThermostatHandler>,
}
/// A Matter device endpoint: commissioning over PASE plus Interaction Model
/// handling on UDP, with application callbacks for common lighting/HVAC commands.
pub struct MatterDeviceServer {
    /// Static device configuration (identity, port, onboarding secrets, storage).
    config: MatterDeviceConfig,
    /// Onboarding QR-code string, computed once at construction.
    qr_code: String,
    /// Manual pairing code, computed once at construction.
    pairing_code: String,
    /// State shared with the receive loop and the handler setters.
    inner: Arc<Mutex<ServerInner>>,
}
impl MatterDeviceServer {
pub async fn new(config: MatterDeviceConfig) -> HomeAutoResult<Self> {
let qr_code = generate_qr_code_string(&config);
let pairing_code = generate_pairing_code(&config);
Ok(Self {
config,
inner: Arc::new(Mutex::new(ServerInner {
on_off: None,
level: None,
color_temp: None,
thermostat: None,
running: false,
_commissioned: false,
})),
qr_code,
pairing_code,
})
}
pub async fn start(&self) -> HomeAutoResult<()> {
{
let mut inner = self.inner.lock().await;
if inner.running {
return Err(HomeAutoError::Matter("server already running".into()));
}
inner.running = true;
}
info!(
"Matter device '{}' starting on UDP port {}",
self.config.device_name, self.config.port
);
info!("QR code: {}", self.qr_code);
info!("Manual pairing code: {}", self.pairing_code);
info!("Discriminator: {}", self.config.discriminator);
let _mdns_handle = self.start_mdns_advertisement()?;
let _fabric_manager = FabricManager::load(&self.config.storage_path)
.await
.map_err(|e| HomeAutoError::Matter(format!("FabricManager load: {e}")))?;
let transport = Arc::new(
UdpTransport::new(self.config.port)
.await
.map_err(|e| HomeAutoError::Matter(format!("UDP bind: {e}")))?,
);
info!("Matter UDP transport bound on port {}", self.config.port);
let mut data_model = DataModelNode::new();
data_model.add_cluster(0, Box::new(BasicInformationCluster::new(&self.config)));
data_model.add_cluster(0, Box::new(GeneralCommissioningCluster::new()));
data_model.add_cluster(0, Box::new(OperationalCredentialsCluster::new()));
data_model.add_cluster(0, Box::new(NetworkCommissioningCluster::new()));
let data_model = Arc::new(data_model);
let pase_state: Arc<Mutex<Option<PaseCommissionee>>> = Arc::new(Mutex::new(None));
let case_sessions: Arc<Mutex<HashMap<u16, CaseResponder>>> =
Arc::new(Mutex::new(HashMap::new()));
let established: Arc<Mutex<HashMap<u16, EstablishedSession>>> =
Arc::new(Mutex::new(HashMap::new()));
let passcode = self.config.passcode;
loop {
let running = self.inner.lock().await.running;
if !running {
break;
}
match tokio::time::timeout(std::time::Duration::from_millis(100), transport.recv())
.await
{
Ok(Ok((msg, peer))) => {
let session_id = msg.header.session_id;
debug!(
"Matter UDP from {peer}: session={session_id} payload_len={}",
msg.payload.len()
);
if session_id == 0 {
handle_commissioning_message(
msg,
peer,
&transport,
&pase_state,
&established,
passcode,
)
.await;
} else {
handle_operational_message(
msg,
peer,
&transport,
&case_sessions,
&established,
&data_model,
&self.inner,
)
.await;
}
}
Ok(Err(e)) => {
error!("Matter UDP recv error: {e}");
break;
}
Err(_) => {} }
}
drop(_mdns_handle);
self.inner.lock().await.running = false;
Ok(())
}
pub async fn stop(&self) -> HomeAutoResult<()> {
self.inner.lock().await.running = false;
Ok(())
}
pub fn set_on_off_handler(&self, f: impl Fn(bool) + Send + Sync + 'static) {
let inner = self.inner.clone();
tokio::spawn(async move {
inner.lock().await.on_off = Some(Arc::new(f));
});
}
pub fn set_level_handler(&self, f: impl Fn(u8) + Send + Sync + 'static) {
let inner = self.inner.clone();
tokio::spawn(async move {
inner.lock().await.level = Some(Arc::new(f));
});
}
pub fn set_color_temp_handler(&self, f: impl Fn(u16) + Send + Sync + 'static) {
let inner = self.inner.clone();
tokio::spawn(async move {
inner.lock().await.color_temp = Some(Arc::new(f));
});
}
pub fn set_thermostat_handler(&self, f: impl Fn(f32) + Send + Sync + 'static) {
let inner = self.inner.clone();
tokio::spawn(async move {
inner.lock().await.thermostat = Some(Arc::new(f));
});
}
pub fn qr_code(&self) -> &str {
&self.qr_code
}
pub fn pairing_code(&self) -> &str {
&self.pairing_code
}
fn start_mdns_advertisement(&self) -> HomeAutoResult<Option<CommissionableAdvertiser>> {
CommissionableAdvertiser::start(&self.config)
.map_err(|e| HomeAutoError::Matter(e.to_string()))
.map(Some)
}
}
/// Parses the Matter protocol (payload) header at the start of `payload`.
///
/// Returns `(exchange_flags, opcode, exchange_id, protocol_id, app_payload)`, where
/// `app_payload` is everything after the header.
///
/// Header layout (all multi-byte fields little-endian):
/// exchange flags (1), opcode (1), exchange ID (2), optional vendor ID (2, if the V
/// flag 0x10 is set), protocol ID (2), optional acknowledged message counter (4, if the
/// A flag 0x02 is set), optional secured extensions (2-byte length + data, if the SX
/// flag 0x08 is set).
///
/// Returns `None` if the buffer is too short for the fields its flags announce. It
/// also returns `None` for a non-zero vendor ID: the returned tuple cannot carry it,
/// and reporting only the protocol ID would let a vendor protocol be mistaken for a
/// standard one.
pub fn parse_payload_header(payload: &[u8]) -> Option<(u8, u8, u16, u16, &[u8])> {
    const EXCHANGE_FLAG_ACK: u8 = 0x02;
    const EXCHANGE_FLAG_SECURED_EXTENSIONS: u8 = 0x08;
    const EXCHANGE_FLAG_VENDOR: u8 = 0x10;

    /// Reads a little-endian `u16` at `at`, or `None` if it would run past the end.
    fn read_u16_le(buf: &[u8], at: usize) -> Option<u16> {
        let bytes = buf.get(at..at + 2)?;
        Some(u16::from_le_bytes([bytes[0], bytes[1]]))
    }

    let exchange_flags = *payload.first()?;
    let opcode = *payload.get(1)?;
    let exchange_id = read_u16_le(payload, 2)?;
    let mut pos = 4;

    if exchange_flags & EXCHANGE_FLAG_VENDOR != 0 {
        let vendor_id = read_u16_le(payload, pos)?;
        if vendor_id != 0 {
            return None;
        }
        pos += 2;
    }

    let protocol_id = read_u16_le(payload, pos)?;
    pos += 2;

    if exchange_flags & EXCHANGE_FLAG_ACK != 0 {
        // Acknowledged message counter; its value is not used by callers.
        pos += 4;
    }

    if exchange_flags & EXCHANGE_FLAG_SECURED_EXTENSIONS != 0 {
        let ext_len = usize::from(read_u16_le(payload, pos)?);
        pos += 2 + ext_len;
    }

    // `get` yields `None` when `pos` is past the end and an empty slice when it is exactly at it.
    let app_payload = payload.get(pos..)?;
    Some((exchange_flags, opcode, exchange_id, protocol_id, app_payload))
}
/// Prepends a 6-byte protocol header to `app_payload`.
///
/// The header is: exchange flags (always `0x00`), `opcode`, `exchange_id` (LE),
/// `protocol_id` (LE). No optional header fields are ever emitted.
pub fn build_payload(
    opcode: u8,
    exchange_id: u16,
    protocol_id: u16,
    app_payload: &[u8],
) -> Vec<u8> {
    const NO_EXCHANGE_FLAGS: u8 = 0x00;
    let [exchange_lo, exchange_hi] = exchange_id.to_le_bytes();
    let [protocol_lo, protocol_hi] = protocol_id.to_le_bytes();
    let header = [
        NO_EXCHANGE_FLAGS,
        opcode,
        exchange_lo,
        exchange_hi,
        protocol_lo,
        protocol_hi,
    ];
    [&header[..], app_payload].concat()
}
/// Wraps `payload` in an unencrypted-looking unicast message frame
/// (version 0, security flags 0, no source/destination node IDs).
fn make_response_message(session_id: u16, message_counter: u32, payload: Vec<u8>) -> MatterMessage {
    let header = MessageHeader {
        session_id,
        message_counter,
        version: 0,
        session_type: SessionType::Unicast,
        security_flags: 0x00,
        source_node_id: None,
        dest_node_id: None,
    };
    MatterMessage { header, payload }
}
/// Handles one Secure Channel message for PASE commissioning.
///
/// Messages whose payload header cannot be parsed, or whose protocol ID is not
/// `SECURE_CHANNEL_PROTOCOL_ID`, are logged and dropped. Supported opcodes:
/// - `PbkdfParamRequest`: creates a fresh `PaseCommissionee` from `passcode`, stores it
///   in `pase_state` (replacing any in-progress attempt) and replies with
///   `PbkdfParamResponse`.
/// - `Pake1`: requires stored PASE state; replies with `Pake2`. On a processing error
///   the stored state is discarded.
/// - `Pake3`: consumes the stored PASE state whether or not it succeeds. On success it
///   installs the session keys in `transport.sessions`, records the session in
///   `established`, and replies with a success StatusReport.
/// Other opcodes are logged at debug level and ignored. Send failures are logged only.
async fn handle_commissioning_message(
msg: MatterMessage,
peer: SocketAddr,
transport: &Arc<UdpTransport>,
pase_state: &Arc<Mutex<Option<PaseCommissionee>>>,
established: &Arc<Mutex<HashMap<u16, EstablishedSession>>>,
passcode: u32,
) {
// NOTE(review): responses reuse the peer's message counter + 1 rather than a
// locally maintained counter — confirm this is what the peer expects.
let counter = msg.header.message_counter;
let (exchange_flags, opcode, exchange_id, protocol_id, app_payload) =
match parse_payload_header(&msg.payload) {
Some(v) => v,
None => {
warn!("PASE: malformed payload header from {peer}");
return;
}
};
debug!(
"PASE proto={protocol_id:#06x} opcode={opcode:#04x} exch={exchange_id} flags={exchange_flags:#04x} from {peer}"
);
if protocol_id != SECURE_CHANNEL_PROTOCOL_ID {
debug!("PASE: ignoring non-SecureChannel protocol {protocol_id:#06x}");
return;
}
match opcode {
x if x == SecureChannelOpcode::PbkdfParamRequest as u8 => {
// Start a new PASE attempt; any previous in-progress state is overwritten below.
let mut commissionee = PaseCommissionee::new(passcode);
let resp_payload = match commissionee.handle_param_request(app_payload) {
Ok(p) => p,
Err(e) => {
error!("PASE: PBKDFParamRequest error: {e}");
return;
}
};
*pase_state.lock().await = Some(commissionee);
let wire_payload = build_payload(
SecureChannelOpcode::PbkdfParamResponse as u8,
exchange_id,
SECURE_CHANNEL_PROTOCOL_ID,
&resp_payload,
);
let resp = make_response_message(0, counter.wrapping_add(1), wire_payload);
if let Err(e) = transport.send(&resp, peer).await {
error!("PASE: send PBKDFParamResponse error: {e}");
} else {
debug!("PASE: sent PBKDFParamResponse to {peer}");
}
}
x if x == SecureChannelOpcode::Pake1 as u8 => {
// The PASE state lock stays held until the end of this arm, including the send.
let mut guard = pase_state.lock().await;
let commissionee = match guard.as_mut() {
Some(c) => c,
None => {
warn!("PASE: received Pake1 but no PASE state, ignoring");
return;
}
};
let pake2_payload = match commissionee.handle_pake1(app_payload) {
Ok(p) => p,
Err(e) => {
error!("PASE: Pake1 error: {e}");
// Abandon this attempt; the peer must restart with PbkdfParamRequest.
*guard = None;
return;
}
};
let wire_payload = build_payload(
SecureChannelOpcode::Pake2 as u8,
exchange_id,
SECURE_CHANNEL_PROTOCOL_ID,
&pake2_payload,
);
let resp = make_response_message(0, counter.wrapping_add(1), wire_payload);
if let Err(e) = transport.send(&resp, peer).await {
error!("PASE: send Pake2 error: {e}");
} else {
debug!("PASE: sent Pake2 to {peer}");
}
}
x if x == SecureChannelOpcode::Pake3 as u8 => {
let mut guard = pase_state.lock().await;
// `take()` empties the PASE state even if Pake3 verification fails below.
let commissionee = match guard.take() {
Some(mut c) => match c.handle_pake3(app_payload) {
Ok(session) => session,
Err(e) => {
error!("PASE: Pake3 error: {e}");
return;
}
},
None => {
warn!("PASE: received Pake3 but no PASE state, ignoring");
return;
}
};
drop(guard);
let session_id = commissionee.session_id;
info!("PASE: session {session_id} established with {peer}");
let keys = SessionKeys {
encrypt_key: commissionee.encrypt_key,
decrypt_key: commissionee.decrypt_key,
};
// Keys are installed before the StatusReport below is sent; keep this order
// unless it is confirmed that `transport.send` does not consult `sessions`.
transport.sessions.lock().await.insert(session_id, keys);
established.lock().await.insert(session_id, commissionee);
let status_tlv = build_status_report_success();
let wire_payload = build_payload(
SecureChannelOpcode::StatusReport as u8,
exchange_id,
SECURE_CHANNEL_PROTOCOL_ID,
&status_tlv,
);
// NOTE(review): this StatusReport is addressed to the new `session_id`, but the
// PASE exchange itself runs on the unsecured session (0). Confirm against the
// Matter spec and UdpTransport which session the peer expects it on.
let resp = make_response_message(session_id, counter.wrapping_add(1), wire_payload);
if let Err(e) = transport.send(&resp, peer).await {
error!("PASE: send StatusReport error: {e}");
} else {
debug!("PASE: sent StatusReport success to {peer}");
}
}
other => {
debug!("PASE: unhandled SecureChannel opcode {other:#04x} from {peer}");
}
}
}
/// Builds the 8-byte body of a successful StatusReport: three little-endian zero
/// fields of 2, 4 and 2 bytes. In the Matter StatusReport layout these are the
/// general code, protocol ID and protocol-specific code.
fn build_status_report_success() -> Vec<u8> {
    let general_code: u16 = 0;
    let protocol_id: u32 = 0;
    let protocol_code: u16 = 0;
    [
        &general_code.to_le_bytes()[..],
        &protocol_id.to_le_bytes()[..],
        &protocol_code.to_le_bytes()[..],
    ]
    .concat()
}
async fn handle_operational_message(
msg: MatterMessage,
peer: SocketAddr,
transport: &Arc<UdpTransport>,
_case_sessions: &Arc<Mutex<HashMap<u16, CaseResponder>>>,
_established: &Arc<Mutex<HashMap<u16, EstablishedSession>>>,
data_model: &Arc<DataModelNode>,
inner: &Arc<Mutex<ServerInner>>,
) {
let session_id = msg.header.session_id;
let counter = msg.header.message_counter;
let (exchange_flags, opcode, exchange_id, protocol_id, app_payload) =
match parse_payload_header(&msg.payload) {
Some(v) => v,
None => {
warn!("OP: malformed payload header session={session_id} from {peer}");
return;
}
};
debug!(
"OP proto={protocol_id:#06x} opcode={opcode:#04x} exch={exchange_id} \
flags={exchange_flags:#04x} session={session_id} from {peer}"
);
if protocol_id != IM_PROTOCOL_ID {
debug!("OP: ignoring non-IM protocol {protocol_id:#06x}");
return;
}
match opcode {
x if x == ImOpcode::InvokeRequest as u8 => {
use super::interaction_model::InvokeRequest;
let req = match InvokeRequest::decode(app_payload) {
Ok(r) => r,
Err(e) => {
error!("OP: InvokeRequest decode error: {e}");
return;
}
};
let mut resp_items = Vec::new();
for (cmd_path, args) in &req.invoke_requests {
let ep = cmd_path.endpoint_id;
let cluster = cmd_path.cluster_id;
let cmd = cmd_path.command_id;
debug!(
"OP invoke: ep={ep} cluster={cluster:#010x} cmd={cmd:#010x} args_len={}",
args.len()
);
dispatch_handler_callbacks(cluster, cmd, args, inner).await;
let result = data_model.dispatch_invoke(ep, cluster, cmd, args).await;
let item = match result {
Ok(response_data) => InvokeResponseItem::Command {
path: cmd_path.clone(),
data: response_data,
},
Err(e) => {
warn!("OP invoke error ep={ep} cluster={cluster:#010x}: {e}");
InvokeResponseItem::Status {
path: cmd_path.clone(),
status: InteractionStatus::Failure,
}
}
};
resp_items.push(item);
}
if !req.suppress_response {
let invoke_resp = InvokeResponse {
suppress_response: false,
invoke_responses: resp_items,
};
let wire_payload = build_payload(
ImOpcode::InvokeResponse as u8,
exchange_id,
IM_PROTOCOL_ID,
&invoke_resp.encode(),
);
let resp = make_response_message(session_id, counter.wrapping_add(1), wire_payload);
if let Err(e) = transport.send(&resp, peer).await {
error!("OP: send InvokeResponse error: {e}");
}
}
}
x if x == ImOpcode::ReadRequest as u8 => {
use super::interaction_model::ReadRequest;
let req = match ReadRequest::decode(app_payload) {
Ok(r) => r,
Err(e) => {
error!("OP: ReadRequest decode error: {e}");
return;
}
};
let mut all_attrs = Vec::new();
for path in &req.attribute_requests {
let mut attrs = data_model.dispatch_read(path).await;
all_attrs.append(&mut attrs);
}
let report = ReportData {
subscription_id: None,
attribute_reports: all_attrs,
suppress_response: false,
};
let wire_payload = build_payload(
ImOpcode::ReportData as u8,
exchange_id,
IM_PROTOCOL_ID,
&report.encode(),
);
let resp = make_response_message(session_id, counter.wrapping_add(1), wire_payload);
if let Err(e) = transport.send(&resp, peer).await {
error!("OP: send ReportData error: {e}");
}
}
other => {
debug!("OP: unhandled IM opcode {other:#04x} from {peer}");
}
}
}
async fn dispatch_handler_callbacks(
cluster: u32,
cmd: u32,
args: &[u8],
inner: &Arc<Mutex<ServerInner>>,
) {
match cluster {
CLUSTER_ON_OFF => {
let handler = inner.lock().await.on_off.clone();
if let Some(h) = handler {
match cmd {
CMD_OFF => h(false),
CMD_ON => h(true),
CMD_TOGGLE => {
h(true);
}
_ => {}
}
}
}
CLUSTER_LEVEL_CONTROL => {
let handler = inner.lock().await.level.clone();
if let Some(h) = handler
&& (cmd == CMD_MOVE_TO_LEVEL || cmd == CMD_MOVE_TO_LEVEL_WITH_ON_OFF)
{
let level = decode_first_uint8(args).unwrap_or(0);
h(level);
}
}
CLUSTER_COLOR_CONTROL => {
let handler = inner.lock().await.color_temp.clone();
if let Some(h) = handler {
if cmd == 0x0A {
let mireds = decode_first_uint16(args).unwrap_or(0);
h(mireds);
}
}
}
CLUSTER_THERMOSTAT => {
let handler = inner.lock().await.thermostat.clone();
if let Some(h) = handler
&& cmd == CMD_SETPOINT_RAISE_LOWER
{
let amount = decode_signed_int8_tag1(args).unwrap_or(0);
h(amount as f32 * 0.1);
}
}
_ => {}
}
}
/// Returns the value of the first context-tag-0, 1-byte unsigned integer element
/// (control byte `0x24`) found in `data`.
///
/// A leading anonymous-structure control byte (`0x15`) is skipped. Other elements are
/// stepped over with `skip_tlv_element`. Returns `None` when no match is found.
fn decode_first_uint8(data: &[u8]) -> Option<u8> {
    // Context tag form (0x20) combined with the 1-byte unsigned integer type (0x04).
    const CONTEXT_UINT8: u8 = 0x24;
    let fields = match data.split_first() {
        Some((&0x15, rest)) => rest,
        _ => data,
    };
    let mut cursor = 0usize;
    while cursor + 2 < fields.len() {
        let (control, tag) = (fields[cursor], fields[cursor + 1]);
        if control == CONTEXT_UINT8 && tag == 0 {
            return Some(fields[cursor + 2]);
        }
        cursor += skip_tlv_element(fields, cursor);
    }
    None
}
/// Returns the first context-tag-0 unsigned integer in `data` that fits in a `u16`.
///
/// Matter TLV allows an unsigned integer to be encoded in any width, and encoders
/// commonly pick the narrowest one. A 16-bit field such as `ColorTemperatureMireds`
/// can therefore arrive as a 1-byte (type `0x04`) or 2-byte (type `0x05`) unsigned
/// integer, and both are accepted. A leading anonymous-structure control byte (`0x15`)
/// is skipped. Other elements are stepped over with `skip_tlv_element`. Returns `None`
/// when no such element is found.
fn decode_first_uint16(data: &[u8]) -> Option<u16> {
    let data = if data.first() == Some(&0x15) {
        &data[1..]
    } else {
        data
    };
    let mut i = 0;
    // Need at least a control byte and a tag byte; value bytes are checked per type.
    while i + 1 < data.len() {
        let ctrl = data[i];
        let tag = data[i + 1];
        let context_tagged = (ctrl & 0xE0) == 0x20;
        if context_tagged && tag == 0 {
            match ctrl & 0x1F {
                0x04 if i + 2 < data.len() => return Some(u16::from(data[i + 2])),
                0x05 if i + 3 < data.len() => {
                    return Some(u16::from_le_bytes([data[i + 2], data[i + 3]]));
                }
                _ => {}
            }
        }
        i += skip_tlv_element(data, i);
    }
    None
}
/// Returns the value of the first context-tag-1, 1-byte signed integer element
/// (control byte `0x20`) found in `data`.
///
/// A leading anonymous-structure control byte (`0x15`) is skipped. Other elements are
/// stepped over with `skip_tlv_element`. Returns `None` when no match is found.
fn decode_signed_int8_tag1(data: &[u8]) -> Option<i8> {
    const ANONYMOUS_STRUCT: u8 = 0x15;
    // Context tag form (0x20) combined with the 1-byte signed integer type (0x00).
    const CONTEXT_INT8: u8 = 0x20;
    let fields = match data.split_first() {
        Some((&ANONYMOUS_STRUCT, rest)) => rest,
        _ => data,
    };
    let mut cursor = 0usize;
    loop {
        if cursor + 2 >= fields.len() {
            return None;
        }
        let (control, tag) = (fields[cursor], fields[cursor + 1]);
        if control == CONTEXT_INT8 && tag == 1 {
            return Some(i8::from_le_bytes([fields[cursor + 2]]));
        }
        cursor += skip_tlv_element(fields, cursor);
    }
}
/// Returns how many bytes to advance from `pos` to reach the next TLV element.
///
/// Scalars and strings are skipped entirely: control byte, tag, optional length field
/// and value. Container starts (structure, array, list) and the end-of-container
/// marker are skipped by their header only, so a linear scan descends into containers
/// and visits their members.
///
/// Supported tag forms:
/// - anonymous: 0 tag bytes;
/// - context: 1 byte;
/// - common or implicit profile: 2 or 4 bytes;
/// - fully qualified: 6 or 8 bytes.
///
/// Supported element types:
/// - integers of 1, 2, 4 or 8 bytes, signed and unsigned;
/// - booleans and null;
/// - 4- and 8-byte floats;
/// - UTF-8 and octet strings with 1-, 2-, 4- or 8-byte length prefixes;
/// - container start and end markers.
///
/// The result is always at least 1 and never runs past the end of `data`. 1 is
/// returned when `pos` is out of range, the element type is reserved, or a string's
/// length field is truncated, so a scan still makes progress.
fn skip_tlv_element(data: &[u8], pos: usize) -> usize {
    let ctrl = match data.get(pos) {
        Some(&c) => c,
        None => return 1,
    };
    let remaining = data.len() - pos;

    let tag_bytes: usize = match ctrl >> 5 {
        0 => 0,     // anonymous
        1 => 1,     // context-specific
        2 | 4 => 2, // common / implicit profile, 2-byte tag
        3 | 5 => 4, // common / implicit profile, 4-byte tag
        6 => 6,     // fully qualified, 6-byte tag
        _ => 8,     // fully qualified, 8-byte tag
    };
    let header = 1 + tag_bytes;

    let value_bytes: usize = match ctrl & 0x1F {
        0x00 | 0x04 => 1,
        0x01 | 0x05 => 2,
        0x02 | 0x06 | 0x0A => 4, // 32-bit ints and float32
        0x03 | 0x07 | 0x0B => 8, // 64-bit ints and float64
        // false, true, null, structure/array/list start, end-of-container.
        0x08 | 0x09 | 0x14..=0x18 => 0,
        kind @ 0x0C..=0x13 => {
            // UTF-8 (0x0C..=0x0F) and octet (0x10..=0x13) strings: the low two bits select
            // a 1/2/4/8-byte little-endian length prefix.
            let len_bytes = 1usize << (kind & 0x03);
            let start = pos + header;
            let raw_len = match data.get(start..start + len_bytes) {
                Some(bytes) => bytes,
                None => return 1,
            };
            let len = raw_len
                .iter()
                .rev()
                .fold(0u64, |acc, &b| (acc << 8) | u64::from(b));
            len_bytes.saturating_add(usize::try_from(len).unwrap_or(usize::MAX))
        }
        _ => return 1, // reserved element types
    };

    // Clamp so callers doing `i + k < len` arithmetic can never overflow.
    header.saturating_add(value_bytes).min(remaining)
}
/// Builds the onboarding QR-code string (`MT:` followed by the base-38 payload).
///
/// The 88-bit payload packs these fields, least-significant bit first:
/// - version (3 bits);
/// - vendor ID (16);
/// - product ID (16);
/// - commissioning flow (2);
/// - discovery-capabilities bitmask (8);
/// - discriminator (12);
/// - passcode (27);
/// - zero padding (4).
///
/// Each value is truncated to its field width, and the low 11 bytes
/// (little-endian) are base-38 encoded.
fn generate_qr_code_string(config: &MatterDeviceConfig) -> String {
    /// Payload format version.
    const VERSION: u64 = 0;
    /// Standard commissioning flow.
    const FLOW_STANDARD: u64 = 0;
    /// Discovery-capabilities bit 2: commissionable on the IP network. Bits 0 and 1
    /// would announce SoftAP and BLE; only the on-network bit is set, matching this
    /// server's mDNS/UDP commissioning path.
    const DISCOVERY_ON_NETWORK: u64 = 1 << 2;
    const PADDING: u64 = 0;

    let fields: [(u64, u32); 8] = [
        (VERSION, 3),
        (config.vendor_id as u64, 16),
        (config.product_id as u64, 16),
        (FLOW_STANDARD, 2),
        (DISCOVERY_ON_NETWORK, 8),
        (config.discriminator as u64, 12),
        (config.passcode as u64, 27),
        (PADDING, 4),
    ];

    let mut bits: u128 = 0;
    let mut offset: u32 = 0;
    for (value, width) in fields {
        let mask = (1u128 << width) - 1;
        bits |= (u128::from(value) & mask) << offset;
        offset += width;
    }

    // 88 payload bits occupy exactly the low 11 bytes.
    let bytes = bits.to_le_bytes();
    format!("MT:{}", base38_encode(&bytes[..11]))
}
/// Matter base-38 alphabet: digits, upper-case letters, `-` and `.`.
const BASE38_CHARS: &[u8; 38] = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ-.";

/// Encodes `data` with the Matter base-38 scheme used by onboarding QR codes.
///
/// Input is consumed in 3-byte chunks, each read as a little-endian integer. A full
/// chunk yields 5 characters, and a trailing 2- or 1-byte chunk yields 4 or 2
/// characters. Digits are emitted least-significant first. Empty input yields an empty
/// string.
fn base38_encode(data: &[u8]) -> String {
    let mut out = String::with_capacity((data.len() + 2) / 3 * 5);
    for chunk in data.chunks(3) {
        let mut value = chunk
            .iter()
            .rev()
            .fold(0u32, |acc, &b| (acc << 8) | u32::from(b));
        let digits = match chunk.len() {
            3 => 5,
            2 => 4,
            _ => 2,
        };
        for _ in 0..digits {
            out.push(char::from(BASE38_CHARS[(value % 38) as usize]));
            value /= 38;
        }
    }
    out
}
/// Builds the 11-digit manual pairing code for `config`.
///
/// Layout for the standard commissioning flow (VID/PID-present bit clear):
/// - digit 1: upper 2 bits of the 4-bit short discriminator;
/// - digits 2–6: the short discriminator's lower 2 bits (as bits 14–15) above
///   passcode bits 0–13;
/// - digits 7–10: passcode bits 14–26;
/// - digit 11: Verhoeff check digit over the first ten digits.
///
/// The short discriminator is the top 4 bits of the 12-bit discriminator. For
/// discriminator 3840 and passcode 20202021 this yields `34970112332`.
fn generate_pairing_code(config: &MatterDeviceConfig) -> String {
    let short_discriminator = ((config.discriminator as u32) >> 8) & 0x0F;
    let passcode = config.passcode & 0x07FF_FFFF; // 27-bit field
    let chunk1 = short_discriminator >> 2;
    let chunk2 = ((short_discriminator & 0x03) << 14) | (passcode & 0x3FFF);
    let chunk3 = passcode >> 14;
    let digits = format!("{chunk1}{chunk2:05}{chunk3:04}");
    let check = verhoeff_check_digit(&digits);
    format!("{digits}{check}")
}

/// Computes the Verhoeff check digit for `digits`, which must be ASCII `0`–`9`.
fn verhoeff_check_digit(digits: &str) -> char {
    // Multiplication table of the dihedral group D5.
    const D: [[u8; 10]; 10] = [
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 2, 3, 4, 0, 6, 7, 8, 9, 5],
        [2, 3, 4, 0, 1, 7, 8, 9, 5, 6],
        [3, 4, 0, 1, 2, 8, 9, 5, 6, 7],
        [4, 0, 1, 2, 3, 9, 5, 6, 7, 8],
        [5, 9, 8, 7, 6, 0, 4, 3, 2, 1],
        [6, 5, 9, 8, 7, 1, 0, 4, 3, 2],
        [7, 6, 5, 9, 8, 2, 1, 0, 4, 3],
        [8, 7, 6, 5, 9, 3, 2, 1, 0, 4],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
    ];
    // Position-dependent permutation table.
    const P: [[u8; 10]; 8] = [
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 5, 7, 6, 2, 8, 3, 0, 9, 4],
        [5, 8, 0, 3, 7, 9, 6, 1, 4, 2],
        [8, 9, 1, 6, 0, 4, 3, 5, 2, 7],
        [9, 4, 5, 3, 1, 2, 6, 8, 7, 0],
        [4, 2, 8, 6, 5, 7, 3, 9, 0, 1],
        [2, 7, 9, 3, 8, 0, 6, 4, 1, 5],
        [7, 0, 4, 6, 9, 1, 3, 2, 5, 8],
    ];
    const INV: [u8; 10] = [0, 4, 3, 2, 1, 5, 6, 7, 8, 9];

    debug_assert!(digits.bytes().all(|b| b.is_ascii_digit()));
    // Digits are processed right to left; position i uses permutation (i + 1) mod 8.
    let checksum = digits
        .bytes()
        .rev()
        .enumerate()
        .fold(0u8, |c, (i, b)| {
            let permuted = P[(i + 1) % 8][usize::from(b - b'0')];
            D[usize::from(c)][usize::from(permuted)]
        });
    char::from(b'0' + INV[usize::from(checksum)])
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration shared by every test: the Matter test vendor/product IDs with
    /// discriminator 3840 and passcode 20202021.
    fn test_config() -> MatterDeviceConfig {
        MatterDeviceConfig::builder()
            .device_name("Test Device")
            .vendor_id(0xFFF1)
            .product_id(0x8001)
            .discriminator(3840)
            .passcode(20202021)
            .build()
    }

    #[test]
    fn qr_code_starts_with_mt() {
        let qr = generate_qr_code_string(&test_config());
        assert!(qr.starts_with("MT:"), "QR code must start with MT:, got: {qr}");
    }

    #[test]
    fn pairing_code_is_numeric() {
        let code = generate_pairing_code(&test_config());
        assert!(!code.is_empty(), "pairing code must not be empty");
        let all_digits = code.bytes().all(|b| b.is_ascii_digit());
        assert!(all_digits, "pairing code must be all digits, got: {code}");
    }

    #[test]
    fn parse_payload_header_roundtrip() {
        let app = [0xDEu8, 0xAD, 0xBE, 0xEF];
        let built = build_payload(0x20, 0x1234, 0x0000, &app);
        let (_, opcode, exchange_id, protocol_id, body) =
            parse_payload_header(&built).expect("parse failed");
        assert_eq!(opcode, 0x20);
        assert_eq!(exchange_id, 0x1234);
        assert_eq!(protocol_id, 0x0000);
        assert_eq!(body, &app[..]);
    }

    #[test]
    fn decode_first_uint8_in_struct() {
        // Anonymous struct { 0: uint8 = 42 } followed by end-of-container.
        let encoded = [0x15u8, 0x24, 0x00, 42, 0x18];
        assert_eq!(decode_first_uint8(&encoded), Some(42));
    }
}