use rand_core::{OsRng, RngCore};
use super::EstablishedSession;
use crate::homeauto::matter::crypto::{
kdf::{derive_passcode_verifier, hkdf_expand_label},
spake2plus::{Spake2PlusProver, Spake2PlusVerifier},
};
use crate::homeauto::matter::error::{MatterError, MatterResult};
/// Tag-control bits (upper three bits of the control byte) selecting a
/// one-byte context-specific tag.
const CTX_TAG: u8 = 0b001 << 5;
/// Element type: unsigned integer, 1 byte.
const UINT1_TYPE: u8 = 0x04;
/// Element type: unsigned integer, 2 bytes little-endian.
const UINT2_TYPE: u8 = 0x05;
/// Element type: unsigned integer, 4 bytes little-endian.
const UINT4_TYPE: u8 = 0x06;
/// Element type: octet string with a 1-byte length prefix.
const BYTES1_TYPE: u8 = 0x10;
/// Element type: boolean `false` (no value bytes follow).
const BOOL_FALSE_TYPE: u8 = 0x08;
/// Element type: start of a structure.
const STRUCT_TYPE: u8 = 0x15;
/// Element type: end of the innermost open container.
const END_TYPE: u8 = 0x18;

/// Starts a context-tagged element: the control byte (`CTX_TAG | elem_type`)
/// followed by the one-byte tag number.
fn tlv_ctx_header(elem_type: u8, tag: u8) -> Vec<u8> {
    vec![CTX_TAG | elem_type, tag]
}

/// Encodes `val` as a context-tagged 1-byte unsigned integer.
fn tlv_ctx_uint1(tag: u8, val: u8) -> Vec<u8> {
    let mut out = tlv_ctx_header(UINT1_TYPE, tag);
    out.push(val);
    out
}

/// Encodes `val` as a context-tagged 2-byte little-endian unsigned integer.
fn tlv_ctx_uint2(tag: u8, val: u16) -> Vec<u8> {
    let mut out = tlv_ctx_header(UINT2_TYPE, tag);
    out.extend(val.to_le_bytes());
    out
}

/// Encodes `val` as a context-tagged 4-byte little-endian unsigned integer.
fn tlv_ctx_uint4(tag: u8, val: u32) -> Vec<u8> {
    let mut out = tlv_ctx_header(UINT4_TYPE, tag);
    out.extend(val.to_le_bytes());
    out
}

/// Encodes `data` as a context-tagged octet string with a 1-byte length.
///
/// # Panics
/// If `data` is longer than 255 bytes (the 1-byte length cannot represent it).
fn tlv_ctx_bytes(tag: u8, data: &[u8]) -> Vec<u8> {
    assert!(data.len() <= 255);
    let mut out = tlv_ctx_header(BYTES1_TYPE, tag);
    out.push(data.len() as u8);
    out.extend_from_slice(data);
    out
}

/// Encodes a context-tagged boolean `false` (control byte and tag only).
fn tlv_ctx_bool_false(tag: u8) -> Vec<u8> {
    tlv_ctx_header(BOOL_FALSE_TYPE, tag)
}

/// Wraps already-encoded members `inner` in a context-tagged structure.
fn tlv_ctx_struct(tag: u8, inner: &[u8]) -> Vec<u8> {
    let mut out = tlv_ctx_header(STRUCT_TYPE, tag);
    out.extend_from_slice(inner);
    out.push(END_TYPE);
    out
}

/// Wraps already-encoded members `inner` in an anonymous (untagged) structure.
fn tlv_anon_struct(inner: &[u8]) -> Vec<u8> {
    [&[STRUCT_TYPE][..], inner, &[END_TYPE][..]].concat()
}
/// Cursor over a byte buffer for decoding the TLV subset used by this module.
struct TlvReader<'a> {
/// Buffer being decoded; byte strings returned by the reader borrow from it.
buf: &'a [u8],
/// Index of the next unread byte in `buf`.
pos: usize,
}
impl<'a> TlvReader<'a> {
fn new(buf: &'a [u8]) -> Self {
Self { buf, pos: 0 }
}
fn read_byte(&mut self) -> MatterResult<u8> {
if self.pos >= self.buf.len() {
return Err(MatterError::Protocol {
opcode: 0,
msg: "TLV: unexpected end of buffer".into(),
});
}
let b = self.buf[self.pos];
self.pos += 1;
Ok(b)
}
fn read_bytes_slice(&mut self, n: usize) -> MatterResult<&'a [u8]> {
if self.pos + n > self.buf.len() {
return Err(MatterError::Protocol {
opcode: 0,
msg: format!("TLV: need {} bytes, have {}", n, self.buf.len() - self.pos),
});
}
let s = &self.buf[self.pos..self.pos + n];
self.pos += n;
Ok(s)
}
fn remaining(&self) -> usize {
self.buf.len() - self.pos
}
fn read_element(&mut self) -> MatterResult<TlvElem<'a>> {
if self.remaining() == 0 {
return Err(MatterError::Protocol {
opcode: 0,
msg: "TLV: buffer exhausted".into(),
});
}
let ctrl = self.read_byte()?;
let tag_type = (ctrl >> 5) & 0x07;
let val_type = ctrl & 0x1f;
let tag: Option<u8> = match tag_type {
0 => None,
1 => Some(self.read_byte()?),
_ => {
return Err(MatterError::Protocol {
opcode: 0,
msg: format!("TLV: unsupported tag type {tag_type}"),
});
}
};
match val_type {
0x08 => Ok(TlvElem {
tag,
value: TlvVal::Bool(false),
}),
0x09 => Ok(TlvElem {
tag,
value: TlvVal::Bool(true),
}),
0x04 => {
let b = self.read_byte()?;
Ok(TlvElem {
tag,
value: TlvVal::Uint(b as u64),
})
}
0x05 => {
let b = self.read_bytes_slice(2)?;
Ok(TlvElem {
tag,
value: TlvVal::Uint(u16::from_le_bytes([b[0], b[1]]) as u64),
})
}
0x06 => {
let b = self.read_bytes_slice(4)?;
Ok(TlvElem {
tag,
value: TlvVal::Uint(u32::from_le_bytes([b[0], b[1], b[2], b[3]]) as u64),
})
}
0x07 => {
let b = self.read_bytes_slice(8)?;
Ok(TlvElem {
tag,
value: TlvVal::Uint(u64::from_le_bytes(b.try_into().unwrap())),
})
}
0x10 => {
let len = self.read_byte()? as usize;
let data = self.read_bytes_slice(len)?;
Ok(TlvElem {
tag,
value: TlvVal::Bytes(data),
})
}
0x15 => Ok(TlvElem {
tag,
value: TlvVal::StructStart,
}),
0x18 => Ok(TlvElem {
tag,
value: TlvVal::End,
}),
_ => Err(MatterError::Protocol {
opcode: 0,
msg: format!("TLV: unsupported value type {val_type:#04x}"),
}),
}
}
}
/// One decoded TLV element.
struct TlvElem<'a> {
/// Context tag number, or `None` for an anonymous element.
tag: Option<u8>,
/// Decoded value; byte strings borrow from the reader's buffer.
value: TlvVal<'a>,
}
/// Value of a decoded TLV element (only the subset `TlvReader` understands).
enum TlvVal<'a> {
/// Unsigned integer of 1, 2, 4 or 8 bytes, widened to `u64`.
Uint(u64),
/// Octet string with a 1-byte length prefix, borrowed from the input.
Bytes(&'a [u8]),
/// Boolean. The payload is never read by the decoders in this file, hence
/// the `dead_code` allowance.
Bool(#[allow(dead_code)] bool),
/// Start of a structure; its members follow until the matching `End`.
StructStart,
/// End of the innermost open container.
End,
}
/// Expands the SPAKE2+ `ke` secret into `(I2R key, R2I key, attestation challenge)`.
///
/// Requests 64 bytes of output labelled `"SessionKeys"` with an empty context
/// and splits them 16 / 16 / 32.
///
/// NOTE(review): the Matter core spec's PASE session-key derivation produces
/// 48 bytes (16-byte attestation challenge); confirm the 64/32 split here is
/// intended for interop.
fn derive_session_keys(ke: &[u8]) -> ([u8; 16], [u8; 16], [u8; 32]) {
    let okm = hkdf_expand_label(ke, b"", "SessionKeys", 64);
    let (i2r_src, rest) = okm[..64].split_at(16);
    let (r2i_src, challenge_src) = rest.split_at(16);

    let mut i2r = [0u8; 16];
    let mut r2i = [0u8; 16];
    let mut challenge = [0u8; 32];
    i2r.copy_from_slice(i2r_src);
    r2i.copy_from_slice(r2i_src);
    challenge.copy_from_slice(challenge_src);
    (i2r, r2i, challenge)
}
/// Protocol state of a [`PaseCommissioner`] (PASE initiator).
pub enum PaseCommissionerState {
/// Nothing sent yet.
Idle,
/// PBKDFParamRequest sent; waiting for PBKDFParamResponse.
SentParamRequest {
/// Random bytes generated for this exchange.
/// NOTE(review): not currently encoded into the request payload.
initiator_random: [u8; 32],
/// Local session id advertised at tag 3 of the request.
session_id: u16,
/// Exact request bytes, later hashed into the SPAKE2+ context.
req_bytes: Vec<u8>,
},
/// Pake1 sent; waiting for Pake2.
SentPake1 {
/// SPAKE2+ prover created from the passcode-derived `w0`/`w1`.
prover: Spake2PlusProver,
/// Request and response payloads hashed (in that order) into the context.
req_bytes: Vec<u8>,
resp_bytes: Vec<u8>,
/// Local session id carried over from the request.
session_id: u16,
},
/// Handshake finished; session keys are available.
Established(EstablishedSession),
/// Handshake aborted, with a human-readable reason.
Failed(String),
}
/// PASE initiator (commissioner side), driven one message at a time.
pub struct PaseCommissioner {
/// Setup passcode, passed to `derive_passcode_verifier` together with the
/// salt and iteration count from the PBKDFParamResponse.
passcode: u32,
/// Current protocol state.
state: PaseCommissionerState,
}
impl PaseCommissioner {
    /// Creates a commissioner (PASE initiator) for `passcode`, starting in `Idle`.
    pub fn new(passcode: u32) -> Self {
        Self {
            passcode,
            state: PaseCommissionerState::Idle,
        }
    }

    /// Builds the PBKDFParamRequest payload and enters `SentParamRequest`.
    ///
    /// Returns the randomly chosen local session id and the encoded request. The
    /// exact request bytes are retained because they are hashed into the SPAKE2+
    /// context when Pake2 arrives.
    ///
    /// # Errors
    /// Currently never fails.
    pub fn build_param_request(&mut self) -> MatterResult<(u16, Vec<u8>)> {
        let mut rng = OsRng;
        let mut init_random = [0u8; 32];
        rng.fill_bytes(&mut init_random);
        let mut sid_bytes = [0u8; 2];
        rng.fill_bytes(&mut sid_bytes);
        let session_id = u16::from_le_bytes(sid_bytes);

        // Anonymous struct: tag 1 = passcode id 0, tag 2 = false, tag 3 = session id.
        // NOTE(review): `init_random` is stored in the state but not encoded here.
        let mut inner = Vec::new();
        inner.extend_from_slice(&tlv_ctx_uint1(1, 0));
        inner.extend_from_slice(&tlv_ctx_bool_false(2));
        inner.extend_from_slice(&tlv_ctx_uint2(3, session_id));
        let payload = tlv_anon_struct(&inner);

        self.state = PaseCommissionerState::SentParamRequest {
            initiator_random: init_random,
            session_id,
            req_bytes: payload.clone(),
        };
        Ok((session_id, payload))
    }

    /// Handles PBKDFParamResponse and returns the Pake1 payload (`pA` at tag 1).
    ///
    /// Derives the passcode verifier from the response's salt and iteration count
    /// and starts the SPAKE2+ prover.
    ///
    /// # Errors
    /// `Protocol` if not in `SentParamRequest` or if the response cannot be
    /// decoded; `Spake2` if verifier derivation or prover setup fails. On any
    /// error the state is left unchanged.
    pub fn handle_param_response(&mut self, payload: &[u8]) -> MatterResult<Vec<u8>> {
        let (req_bytes, session_id) = match &self.state {
            PaseCommissionerState::SentParamRequest {
                req_bytes,
                session_id,
                ..
            } => (req_bytes.clone(), *session_id),
            _ => {
                return Err(MatterError::Protocol {
                    opcode: 0x21,
                    msg: "unexpected state for PBKDFParamResponse".into(),
                });
            }
        };
        // Only the PBKDF parameters are used. NOTE(review): the echoed initiator
        // random and the responder's session id are decoded but neither verified
        // nor propagated.
        let (_resp_init_random, _resp_random, _resp_session_id, iterations, salt) =
            decode_param_response(payload)?;
        let (w0s, w1s) = derive_passcode_verifier(self.passcode, &salt, iterations)
            .map_err(|e| MatterError::Spake2(e.to_string()))?;
        let prover =
            Spake2PlusProver::new(&w0s, &w1s).map_err(|e| MatterError::Spake2(e.to_string()))?;
        let pa = prover.pake_message();
        let pake1 = tlv_anon_struct(&tlv_ctx_bytes(1, &pa));
        self.state = PaseCommissionerState::SentPake1 {
            prover,
            req_bytes,
            resp_bytes: payload.to_vec(),
            session_id,
        };
        Ok(pake1)
    }

    /// Handles Pake2 (`pB`, `cB`), verifies the responder's confirmation and
    /// returns the Pake3 payload (`cA` at tag 1).
    ///
    /// On success the commissioner enters `Established` with the I2R key for
    /// encryption and the R2I key for decryption.
    ///
    /// # Errors
    /// `Protocol` if not in `SentPake1`; the state is then left unchanged. Once
    /// processing starts the prover is consumed, so decode failures, SPAKE2+
    /// failures and a `cB` mismatch all leave the state `Failed`.
    pub fn handle_pake2(&mut self, payload: &[u8]) -> MatterResult<Vec<u8>> {
        use sha2::{Digest, Sha256};

        // Move the state out by value, leaving a `Failed` placeholder, so the
        // prover has exactly one owner. Bit-copying it out of a state that still
        // owns it and then overwriting the state would drop it twice.
        let previous = std::mem::replace(
            &mut self.state,
            PaseCommissionerState::Failed("pake2 in progress".into()),
        );
        let (prover, req_bytes, resp_bytes, session_id) = match previous {
            PaseCommissionerState::SentPake1 {
                prover,
                req_bytes,
                resp_bytes,
                session_id,
            } => (prover, req_bytes, resp_bytes, session_id),
            other => {
                // Out-of-order message: nothing was consumed, so restore the state.
                self.state = other;
                return Err(MatterError::Protocol {
                    opcode: 0x23,
                    msg: "unexpected state for Pake2".into(),
                });
            }
        };

        let (pb, cb) = decode_pake2(payload)?;
        // SPAKE2+ context = SHA-256(PBKDFParamRequest || PBKDFParamResponse).
        let context: [u8; 32] = {
            let mut h = Sha256::new();
            h.update(&req_bytes);
            h.update(&resp_bytes);
            h.finalize().into()
        };
        let keys = prover
            .finish(&pb, &context)
            .map_err(|e| MatterError::Spake2(e.to_string()))?;
        if !Self::confirmation_matches(&keys.cb[..], &cb) {
            self.state = PaseCommissionerState::Failed("Pake2 cB verification failed".into());
            return Err(MatterError::Spake2("Pake2 cB confirmation mismatch".into()));
        }

        let pake3 = tlv_anon_struct(&tlv_ctx_bytes(1, &keys.ca));
        let (i2r, r2i, challenge) = derive_session_keys(&keys.ke);
        let session = EstablishedSession {
            session_id,
            // NOTE(review): left 0; the responder's session id is at tag 3 of
            // `resp_bytes` if the caller needs to fill this in.
            peer_session_id: 0,
            // Initiator sends with I2R and receives with R2I.
            encrypt_key: i2r,
            decrypt_key: r2i,
            attestation_challenge: challenge,
            peer_node_id: None,
        };
        self.state = PaseCommissionerState::Established(session);
        Ok(pake3)
    }

    /// Returns the established session once Pake2 has been accepted.
    pub fn established_session(&self) -> Option<&EstablishedSession> {
        match &self.state {
            PaseCommissionerState::Established(s) => Some(s),
            _ => None,
        }
    }

    /// Compares a locally computed confirmation value with the peer's without
    /// exiting at the first differing byte. This is best-effort: the optimizer
    /// does not guarantee constant time. Only the length check short-circuits.
    fn confirmation_matches(expected: &[u8], received: &[u8]) -> bool {
        expected.len() == received.len()
            && expected
                .iter()
                .zip(received)
                .fold(0u8, |diff, (a, b)| diff | (a ^ b))
                == 0
    }
}
/// Protocol state of a [`PaseCommissionee`] (PASE responder).
pub enum PaseCommissioneeState {
/// No message handled yet; only PBKDFParamRequest is accepted.
Idle,
/// PBKDFParamResponse sent; waiting for Pake1.
SentParamResponse {
/// Request and response payloads hashed (in that order) into the SPAKE2+ context.
req_bytes: Vec<u8>,
resp_bytes: Vec<u8>,
/// PBKDF parameters advertised in the response.
salt: Vec<u8>,
iterations: u32,
},
/// Pake2 sent; waiting for Pake3 carrying the initiator's confirmation.
SentPake2 {
/// Verifier used to produce Pake2. NOTE(review): not read again in this file.
verifier: Box<Spake2PlusVerifier>,
/// SPAKE2+ output: `ca`, `cb` and `ke` are consumed by the handlers.
keys: crate::homeauto::matter::crypto::spake2plus::Spake2PlusKeys,
/// Local (responder) session id, as encoded at tag 3 of the response.
session_id: u16,
},
/// Handshake finished; session keys are available.
Established(EstablishedSession),
/// Handshake aborted, with a human-readable reason.
Failed(String),
}
/// PASE responder (commissionee side), driven one message at a time.
pub struct PaseCommissionee {
/// Setup passcode the initiator must prove knowledge of.
passcode: u32,
/// Salt advertised in PBKDFParamResponse.
salt: Vec<u8>,
/// PBKDF iteration count advertised in PBKDFParamResponse.
iterations: u32,
/// Current protocol state.
state: PaseCommissioneeState,
}
impl PaseCommissionee {
    /// Creates a responder with a fresh random 32-byte salt and 10 000 iterations.
    pub fn new(passcode: u32) -> Self {
        let mut salt = vec![0u8; 32];
        OsRng.fill_bytes(&mut salt);
        Self::new_with_params(passcode, salt, 10000)
    }

    /// Creates a responder with explicit PBKDF parameters, starting in `Idle`.
    ///
    /// Note that `salt` is later encoded with a 1-byte length, so it must not
    /// exceed 255 bytes.
    pub fn new_with_params(passcode: u32, salt: Vec<u8>, iterations: u32) -> Self {
        Self {
            passcode,
            salt,
            iterations,
            state: PaseCommissioneeState::Idle,
        }
    }

    /// Handles PBKDFParamRequest and returns the PBKDFParamResponse payload.
    ///
    /// The response carries the following fields:
    /// - tag 1: the first value returned by `decode_param_request`;
    /// - tag 2: a fresh 32-byte random;
    /// - tag 3: a fresh responder session id;
    /// - tag 4: a struct holding iterations (tag 1) and salt (tag 2).
    ///
    /// # Errors
    /// `Protocol` if not `Idle` or if the request cannot be decoded; the state
    /// is then unchanged.
    pub fn handle_param_request(&mut self, payload: &[u8]) -> MatterResult<Vec<u8>> {
        if !matches!(self.state, PaseCommissioneeState::Idle) {
            return Err(MatterError::Protocol {
                opcode: 0x20,
                msg: "unexpected state for PBKDFParamRequest".into(),
            });
        }
        // NOTE(review): the initiator's session id is not retained, so
        // `peer_session_id` of the resulting session stays 0.
        let (init_random, _passcode_id, _init_session_id) = decode_param_request(payload)?;
        let mut resp_random = [0u8; 32];
        OsRng.fill_bytes(&mut resp_random);
        let mut sid_bytes = [0u8; 2];
        OsRng.fill_bytes(&mut sid_bytes);
        let resp_session_id = u16::from_le_bytes(sid_bytes);

        let mut pbkdf_params_inner = tlv_ctx_uint4(1, self.iterations);
        pbkdf_params_inner.extend_from_slice(&tlv_ctx_bytes(2, &self.salt));

        let mut inner = Vec::new();
        inner.extend_from_slice(&tlv_ctx_bytes(1, &init_random));
        inner.extend_from_slice(&tlv_ctx_bytes(2, &resp_random));
        inner.extend_from_slice(&tlv_ctx_uint2(3, resp_session_id));
        inner.extend_from_slice(&tlv_ctx_struct(4, &pbkdf_params_inner));
        let resp_payload = tlv_anon_struct(&inner);

        self.state = PaseCommissioneeState::SentParamResponse {
            req_bytes: payload.to_vec(),
            resp_bytes: resp_payload.clone(),
            salt: self.salt.clone(),
            iterations: self.iterations,
        };
        Ok(resp_payload)
    }

    /// Handles Pake1 (`pA`) and returns the Pake2 payload (`pB` at tag 1, `cB` at tag 2).
    ///
    /// # Errors
    /// `Protocol` if not in `SentParamResponse` or if Pake1 cannot be decoded;
    /// `Spake2` if verifier derivation or the SPAKE2+ computation fails. On any
    /// error the state is left unchanged.
    pub fn handle_pake1(&mut self, payload: &[u8]) -> MatterResult<Vec<u8>> {
        use sha2::{Digest, Sha256};

        let (req_bytes, resp_bytes, salt, iterations) = match &self.state {
            PaseCommissioneeState::SentParamResponse {
                req_bytes,
                resp_bytes,
                salt,
                iterations,
            } => (
                req_bytes.clone(),
                resp_bytes.clone(),
                salt.clone(),
                *iterations,
            ),
            _ => {
                return Err(MatterError::Protocol {
                    opcode: 0x22,
                    msg: "unexpected state for Pake1".into(),
                });
            }
        };
        let pa = decode_pake1(payload)?;
        let (w0s, w1s) = derive_passcode_verifier(self.passcode, &salt, iterations)
            .map_err(|e| MatterError::Spake2(e.to_string()))?;
        let verifier = Spake2PlusVerifier::new_from_w1s(&w0s, &w1s)
            .map_err(|e| MatterError::Spake2(e.to_string()))?;
        let pb = verifier.pake_message();
        // SPAKE2+ context = SHA-256(PBKDFParamRequest || PBKDFParamResponse).
        let context: [u8; 32] = {
            let mut h = Sha256::new();
            h.update(&req_bytes);
            h.update(&resp_bytes);
            h.finalize().into()
        };
        let keys = verifier
            .finish(&pa, &context)
            .map_err(|e| MatterError::Spake2(e.to_string()))?;

        let mut inner = tlv_ctx_bytes(1, &pb);
        inner.extend_from_slice(&tlv_ctx_bytes(2, &keys.cb));
        let pake2 = tlv_anon_struct(&inner);
        // Our own session id, as advertised at tag 3 of our response.
        let session_id = extract_session_id_from_resp(&resp_bytes);
        self.state = PaseCommissioneeState::SentPake2 {
            verifier: Box::new(verifier),
            keys,
            session_id,
        };
        Ok(pake2)
    }

    /// Handles Pake3 (`cA`), verifies the initiator's confirmation and returns
    /// the established session (also kept in the state).
    ///
    /// # Errors
    /// `Protocol` if not in `SentPake2`; the state is then left unchanged. Once
    /// processing starts the SPAKE2+ keys are consumed, so a decode failure or a
    /// `cA` mismatch leaves the state `Failed`.
    pub fn handle_pake3(&mut self, payload: &[u8]) -> MatterResult<EstablishedSession> {
        // Move the state out by value, leaving a `Failed` placeholder, so the
        // keys have exactly one owner. Bit-copying them out of a state that still
        // owns them and then overwriting the state would drop them twice.
        let previous = std::mem::replace(
            &mut self.state,
            PaseCommissioneeState::Failed("pake3 in progress".into()),
        );
        let (keys, session_id) = match previous {
            PaseCommissioneeState::SentPake2 {
                keys, session_id, ..
            } => (keys, session_id),
            other => {
                // Out-of-order message: nothing was consumed, so restore the state.
                self.state = other;
                return Err(MatterError::Protocol {
                    opcode: 0x24,
                    msg: "unexpected state for Pake3".into(),
                });
            }
        };

        let ca = decode_pake3(payload)?;
        if !Self::confirmation_matches(&keys.ca[..], &ca) {
            self.state = PaseCommissioneeState::Failed("Pake3 cA verification failed".into());
            return Err(MatterError::Spake2("Pake3 cA confirmation mismatch".into()));
        }

        let (i2r, r2i, challenge) = derive_session_keys(&keys.ke);
        let session = EstablishedSession {
            session_id,
            peer_session_id: 0,
            // Responder sends with R2I and receives with I2R (mirror of the initiator).
            encrypt_key: r2i,
            decrypt_key: i2r,
            attestation_challenge: challenge,
            peer_node_id: None,
        };
        self.state = PaseCommissioneeState::Established(session.clone());
        Ok(session)
    }

    /// Compares a locally computed confirmation value with the peer's without
    /// exiting at the first differing byte. This is best-effort: the optimizer
    /// does not guarantee constant time. Only the length check short-circuits.
    fn confirmation_matches(expected: &[u8], received: &[u8]) -> bool {
        expected.len() == received.len()
            && expected
                .iter()
                .zip(received)
                .fold(0u8, |diff, (a, b)| diff | (a ^ b))
                == 0
    }
}
/// Decodes a PBKDFParamRequest payload.
///
/// Returns `(sha256(buf), passcode_id, initiator_session_id)`. The first value
/// is the SHA-256 digest of the whole request, not a field from it. Only
/// top-level fields are read: tag 1 (uint) is the passcode id and tag 3 (uint)
/// is the session id. Missing fields default to 0. Members of nested containers
/// are skipped by depth, so they cannot overwrite top-level fields, and their
/// `End` markers cannot end the parse early.
///
/// # Errors
/// `MatterError::Protocol` if the payload is not an anonymous struct, is
/// truncated, uses unsupported TLV, or carries an integer too large for its field.
fn decode_param_request(buf: &[u8]) -> MatterResult<([u8; 32], u8, u16)> {
    let mut r = TlvReader::new(buf);
    let el = r.read_element()?;
    if !matches!(el.value, TlvVal::StructStart) || el.tag.is_some() {
        return Err(MatterError::Protocol {
            opcode: 0x20,
            msg: "PBKDFParamRequest: expected anon struct".into(),
        });
    }
    let out_of_range = |field: &str| MatterError::Protocol {
        opcode: 0x20,
        msg: format!("PBKDFParamRequest: {field} out of range"),
    };
    let mut passcode_id: u8 = 0;
    let mut session_id: u16 = 0;
    // Nesting depth below the top-level struct; only depth 0 holds request fields.
    let mut depth = 0usize;
    loop {
        let el = r.read_element()?;
        match el.value {
            TlvVal::End if depth == 0 => break,
            TlvVal::End => depth -= 1,
            TlvVal::StructStart => depth += 1,
            TlvVal::Uint(v) if depth == 0 => match el.tag {
                Some(1) => {
                    passcode_id = u8::try_from(v).map_err(|_| out_of_range("passcode id"))?;
                }
                Some(3) => {
                    session_id = u16::try_from(v).map_err(|_| out_of_range("session id"))?;
                }
                _ => {}
            },
            _ => {}
        }
    }
    use sha2::{Digest, Sha256};
    let hash: [u8; 32] = Sha256::digest(buf).into();
    Ok((hash, passcode_id, session_id))
}
/// `(initiator_random, responder_random, responder_session_id, iterations, salt)`.
type PbkdfParamResponseFields = ([u8; 32], [u8; 32], u16, u32, Vec<u8>);

/// Decodes a PBKDFParamResponse payload.
///
/// Top-level fields:
/// - tags 1 and 2: 32-byte randoms (other lengths leave the field zeroed);
/// - tag 3: session id;
/// - tag 4: a pbkdf_params struct holding iterations (tag 1) and salt (tag 2).
///
/// Nesting depth is tracked, so members and `End` markers of other or deeper
/// containers never desynchronise the parser or get misattributed.
///
/// # Errors
/// `MatterError::Protocol` if the payload is not an anonymous struct, is
/// truncated, uses unsupported TLV, carries an integer too large for its field,
/// or lacks a non-zero iteration count and a non-empty salt.
fn decode_param_response(buf: &[u8]) -> MatterResult<PbkdfParamResponseFields> {
    let mut r = TlvReader::new(buf);
    let el = r.read_element()?;
    if !matches!(el.value, TlvVal::StructStart) || el.tag.is_some() {
        return Err(MatterError::Protocol {
            opcode: 0x21,
            msg: "PBKDFParamResponse: expected anon struct".into(),
        });
    }
    let out_of_range = |field: &str| MatterError::Protocol {
        opcode: 0x21,
        msg: format!("PBKDFParamResponse: {field} out of range"),
    };
    let mut init_random = [0u8; 32];
    let mut resp_random = [0u8; 32];
    let mut session_id: u16 = 0;
    let mut iterations: u32 = 0;
    let mut salt: Vec<u8> = Vec::new();
    // Nesting depth below the top-level struct (0 = top level).
    let mut depth = 0usize;
    // True while inside the top-level tag-4 pbkdf_params struct.
    let mut in_params = false;
    loop {
        let el = r.read_element()?;
        let tag = el.tag;
        match (depth, el.value) {
            (0, TlvVal::End) => break,
            (_, TlvVal::End) => {
                depth -= 1;
                if depth == 0 {
                    in_params = false;
                }
            }
            (_, TlvVal::StructStart) => {
                if depth == 0 && tag == Some(4) {
                    in_params = true;
                }
                depth += 1;
            }
            (0, TlvVal::Bytes(b)) if b.len() == 32 => match tag {
                Some(1) => init_random.copy_from_slice(b),
                Some(2) => resp_random.copy_from_slice(b),
                _ => {}
            },
            (0, TlvVal::Uint(v)) if tag == Some(3) => {
                session_id = u16::try_from(v).map_err(|_| out_of_range("session id"))?;
            }
            (1, TlvVal::Uint(v)) if in_params && tag == Some(1) => {
                iterations = u32::try_from(v).map_err(|_| out_of_range("iterations"))?;
            }
            (1, TlvVal::Bytes(b)) if in_params && tag == Some(2) => {
                salt = b.to_vec();
            }
            _ => {}
        }
    }
    if iterations == 0 || salt.is_empty() {
        return Err(MatterError::Protocol {
            opcode: 0x21,
            msg: "PBKDFParamResponse: missing pbkdf_params".into(),
        });
    }
    Ok((init_random, resp_random, session_id, iterations, salt))
}
/// Decodes a Pake1 payload and returns `pA` (top-level tag-1 octet string).
///
/// Members of nested containers are skipped by depth, so neither they nor their
/// `End` markers are mistaken for top-level fields. If tag 1 repeats, the last
/// occurrence wins.
///
/// # Errors
/// `MatterError::Protocol` if the payload is not an anonymous struct, is
/// truncated, uses unsupported TLV, or lacks `pA`.
fn decode_pake1(buf: &[u8]) -> MatterResult<Vec<u8>> {
    let mut r = TlvReader::new(buf);
    let el = r.read_element()?;
    if !matches!(el.value, TlvVal::StructStart) || el.tag.is_some() {
        return Err(MatterError::Protocol {
            opcode: 0x22,
            msg: "Pake1: expected anon struct".into(),
        });
    }
    let mut pa: Option<Vec<u8>> = None;
    // Nesting depth below the top-level struct; only depth 0 holds Pake1 fields.
    let mut depth = 0usize;
    loop {
        let el = r.read_element()?;
        match el.value {
            TlvVal::End if depth == 0 => break,
            TlvVal::End => depth -= 1,
            TlvVal::StructStart => depth += 1,
            TlvVal::Bytes(b) if depth == 0 && el.tag == Some(1) => {
                pa = Some(b.to_vec());
            }
            _ => {}
        }
    }
    pa.ok_or_else(|| MatterError::Protocol {
        opcode: 0x22,
        msg: "Pake1: missing pA".into(),
    })
}
/// Decodes a Pake2 payload into `(pB, cB)`: top-level tag-1 and tag-2 octet strings.
///
/// Members of nested containers are skipped by depth, so neither they nor their
/// `End` markers are mistaken for top-level fields. Repeated tags: last one wins.
///
/// # Errors
/// `MatterError::Protocol` if the payload is not an anonymous struct, is
/// truncated, uses unsupported TLV, or lacks `pB` or `cB`.
fn decode_pake2(buf: &[u8]) -> MatterResult<(Vec<u8>, Vec<u8>)> {
    let mut r = TlvReader::new(buf);
    let el = r.read_element()?;
    if !matches!(el.value, TlvVal::StructStart) || el.tag.is_some() {
        return Err(MatterError::Protocol {
            opcode: 0x23,
            msg: "Pake2: expected anon struct".into(),
        });
    }
    let mut pb: Option<Vec<u8>> = None;
    let mut cb: Option<Vec<u8>> = None;
    // Nesting depth below the top-level struct; only depth 0 holds Pake2 fields.
    let mut depth = 0usize;
    loop {
        let el = r.read_element()?;
        match el.value {
            TlvVal::End if depth == 0 => break,
            TlvVal::End => depth -= 1,
            TlvVal::StructStart => depth += 1,
            TlvVal::Bytes(b) if depth == 0 => match el.tag {
                Some(1) => pb = Some(b.to_vec()),
                Some(2) => cb = Some(b.to_vec()),
                _ => {}
            },
            _ => {}
        }
    }
    let pb = pb.ok_or_else(|| MatterError::Protocol {
        opcode: 0x23,
        msg: "Pake2: missing pB".into(),
    })?;
    let cb = cb.ok_or_else(|| MatterError::Protocol {
        opcode: 0x23,
        msg: "Pake2: missing cB".into(),
    })?;
    Ok((pb, cb))
}
/// Decodes a Pake3 payload and returns `cA` (top-level tag-1 octet string).
///
/// Members of nested containers are skipped by depth, so neither they nor their
/// `End` markers are mistaken for top-level fields. If tag 1 repeats, the last
/// occurrence wins.
///
/// # Errors
/// `MatterError::Protocol` if the payload is not an anonymous struct, is
/// truncated, uses unsupported TLV, or lacks `cA`.
fn decode_pake3(buf: &[u8]) -> MatterResult<Vec<u8>> {
    let mut r = TlvReader::new(buf);
    let el = r.read_element()?;
    if !matches!(el.value, TlvVal::StructStart) || el.tag.is_some() {
        return Err(MatterError::Protocol {
            opcode: 0x24,
            msg: "Pake3: expected anon struct".into(),
        });
    }
    let mut ca: Option<Vec<u8>> = None;
    // Nesting depth below the top-level struct; only depth 0 holds Pake3 fields.
    let mut depth = 0usize;
    loop {
        let el = r.read_element()?;
        match el.value {
            TlvVal::End if depth == 0 => break,
            TlvVal::End => depth -= 1,
            TlvVal::StructStart => depth += 1,
            TlvVal::Bytes(b) if depth == 0 && el.tag == Some(1) => {
                ca = Some(b.to_vec());
            }
            _ => {}
        }
    }
    ca.ok_or_else(|| MatterError::Protocol {
        opcode: 0x24,
        msg: "Pake3: missing cA".into(),
    })
}
/// Returns the top-level tag-3 unsigned integer (responder session id) from an
/// encoded PBKDFParamResponse, or 0 if it is absent or the payload is malformed.
///
/// Nested containers of any depth are skipped, so neither their members nor
/// their `End` markers can be mistaken for top-level elements. Values wider than
/// 16 bits are truncated with `as u16`.
fn extract_session_id_from_resp(resp_bytes: &[u8]) -> u16 {
    let mut r = TlvReader::new(resp_bytes);
    // Skip the outer struct header; a malformed payload yields 0.
    if r.read_element().is_err() {
        return 0;
    }
    // Nesting depth below the top-level struct (0 = top level).
    let mut depth = 0usize;
    while let Ok(el) = r.read_element() {
        match el.value {
            TlvVal::End if depth == 0 => return 0,
            TlvVal::End => depth -= 1,
            TlvVal::StructStart => depth += 1,
            TlvVal::Uint(v) if depth == 0 && el.tag == Some(3) => return v as u16,
            _ => {}
        }
    }
    0
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Passcode configured on the commissionee in every test.
    const TEST_PASSCODE: u32 = 20202021;

    /// Full PBKDFParamRequest → Pake3 exchange with matching passcodes: both
    /// sides must derive mirrored keys and the same attestation challenge.
    #[test]
    fn pase_full_handshake_correct_passcode() {
        let salt = b"matter-pase-test-salt-32bytes!!!".to_vec();
        let iterations = 1000u32;
        let mut commissioner = PaseCommissioner::new(TEST_PASSCODE);
        let mut commissionee = PaseCommissionee::new_with_params(TEST_PASSCODE, salt, iterations);
        let (_session_id, req) = commissioner
            .build_param_request()
            .expect("build_param_request failed");
        let resp = commissionee
            .handle_param_request(&req)
            .expect("handle_param_request failed");
        let pake1 = commissioner
            .handle_param_response(&resp)
            .expect("handle_param_response failed");
        let pake2 = commissionee
            .handle_pake1(&pake1)
            .expect("handle_pake1 failed");
        let pake3 = commissioner
            .handle_pake2(&pake2)
            .expect("handle_pake2 failed");
        let comm_sess = commissioner
            .established_session()
            .expect("commissioner should be established");
        let comm_ee_sess = commissionee
            .handle_pake3(&pake3)
            .expect("handle_pake3 failed");
        assert_eq!(
            comm_sess.encrypt_key, comm_ee_sess.decrypt_key,
            "commissioner encrypt_key must equal commissionee decrypt_key (I2R)"
        );
        assert_eq!(
            comm_sess.decrypt_key, comm_ee_sess.encrypt_key,
            "commissioner decrypt_key must equal commissionee encrypt_key (R2I)"
        );
        assert_eq!(
            comm_sess.attestation_challenge, comm_ee_sess.attestation_challenge,
            "attestation challenge must match"
        );
    }

    /// With a mismatched passcode the commissioner rejects the responder's
    /// confirmation (`cB`) while handling Pake2, so Pake3 is never produced.
    #[test]
    fn pase_handshake_wrong_passcode_fails_at_pake2() {
        let salt = b"matter-pase-test-salt-32bytes!!!".to_vec();
        let iterations = 1000u32;
        // The commissioner is given a passcode that differs from the device's.
        let mut commissioner = PaseCommissioner::new(11111111);
        let mut commissionee = PaseCommissionee::new_with_params(TEST_PASSCODE, salt, iterations);
        let (_sid, req) = commissioner.build_param_request().unwrap();
        let resp = commissionee.handle_param_request(&req).unwrap();
        let pake1 = commissioner.handle_param_response(&resp).unwrap();
        let pake2 = commissionee.handle_pake1(&pake1).unwrap();
        let result = commissioner.handle_pake2(&pake2);
        assert!(result.is_err(), "wrong passcode should fail at Pake2");
        assert!(
            commissioner.established_session().is_none(),
            "commissioner must not be established after a failed Pake2"
        );
    }

    /// Messages arriving in the wrong state are rejected without disturbing the
    /// state machine: a handshake can still complete afterwards, and a repeated
    /// Pake2 does not tear down an established session.
    #[test]
    fn pase_out_of_order_messages_are_rejected() {
        // Minimal well-formed payload: an empty anonymous struct.
        let empty_struct = [0x15u8, 0x18];
        let mut commissioner = PaseCommissioner::new(TEST_PASSCODE);
        let mut commissionee = PaseCommissionee::new_with_params(
            TEST_PASSCODE,
            b"matter-pase-test-salt-32bytes!!!".to_vec(),
            1000,
        );
        assert!(commissioner.handle_pake2(&empty_struct).is_err());
        assert!(commissioner.handle_param_response(&empty_struct).is_err());
        assert!(commissionee.handle_pake3(&empty_struct).is_err());
        assert!(commissionee.handle_pake1(&empty_struct).is_err());

        let (_sid, req) = commissioner.build_param_request().unwrap();
        let resp = commissionee.handle_param_request(&req).unwrap();
        assert!(
            commissionee.handle_param_request(&req).is_err(),
            "a second PBKDFParamRequest must be rejected"
        );
        let pake1 = commissioner.handle_param_response(&resp).unwrap();
        let pake2 = commissionee.handle_pake1(&pake1).unwrap();
        let pake3 = commissioner.handle_pake2(&pake2).unwrap();
        commissionee.handle_pake3(&pake3).unwrap();

        assert!(commissioner.handle_pake2(&pake2).is_err());
        assert!(
            commissioner.established_session().is_some(),
            "a rejected late Pake2 must leave the established session intact"
        );
    }

    /// The PBKDFParamRequest carries the returned session id at context tag 3.
    #[test]
    fn pase_param_request_encodes_session_id() {
        let mut commissioner = PaseCommissioner::new(TEST_PASSCODE);
        let (session_id, req_bytes) = commissioner.build_param_request().unwrap();
        let mut r = TlvReader::new(&req_bytes);
        // Skip the anonymous struct header.
        let _ = r.read_element().unwrap();
        let mut found_sid: Option<u16> = None;
        loop {
            let el = r.read_element().unwrap();
            match el.value {
                TlvVal::End => break,
                TlvVal::Uint(v) if el.tag == Some(3) => {
                    found_sid = Some(v as u16);
                }
                _ => {}
            }
        }
        assert_eq!(
            found_sid,
            Some(session_id),
            "session_id must be encoded at tag 3"
        );
    }
}