use core::num::NonZeroU8;
use embassy_time::{Duration, Instant};
use crate::dm::clusters::adm_comm::{self};
use crate::dm::endpoints::ROOT_ENDPOINT_ID;
use crate::error::{Error, ErrorCode};
use crate::im::{ClusterId, EndptId};
use crate::sc::pase::spake2p::{
Spake2pVerifierData, Spake2pVerifierStrRef, SPAKE2P_VERIFIER_SALT_LEN,
SPAKE2P_VERIFIER_SALT_MIN_LEN,
};
use crate::sc::SessionParameters;
use crate::tlv::{FromTLV, OctetStr, ToTLV};
use crate::transport::exchange::{Exchange, ExchangeId};
use crate::utils::init::{init, Init};
use crate::utils::maybe::Maybe;
use crate::MatterLocalService;
pub use initiator::PaseInitiator;
pub use responder::PaseResponder;
pub use spake2p::{
Spake2pVerifierPassword, Spake2pVerifierPasswordRef, SPAKE2P_VERIFIER_PASSWORD_LEN,
SPAKE2P_VERIFIER_PASSWORD_ZEROED,
};
mod initiator;
mod responder;
pub(crate) mod spake2p;
/// Shortest commissioning window timeout, in seconds (3 minutes), accepted by
/// `Pase::open_basic_comm_window` and `Pase::open_comm_window`.
pub const MIN_COMM_WINDOW_TIMEOUT_SECS: u16 = 3 * 60;
/// Longest commissioning window timeout, in seconds (15 minutes), accepted by
/// `Pase::open_basic_comm_window` and `Pase::open_comm_window`.
pub const MAX_COMM_WINDOW_TIMEOUT_SECS: u16 = 15 * 60;
/// Invokes `notify_change` for the `adm_comm` cluster on the root endpoint.
///
/// Called whenever a commissioning window is opened or closed so that the
/// caller-supplied callback learns that this cluster's state has changed.
fn notify_adm_comm_window_attrs_changed(notify_change: &mut impl FnMut(EndptId, ClusterId)) {
    let cluster_id = adm_comm::FULL_CLUSTER.id;

    notify_change(ROOT_ENDPOINT_ID, cluster_id);
}
/// The kind of commissioning window that is currently open.
///
/// `CommWindow::comm_window_type` reports `Basic` when the window's verifier
/// data carries a password and `Enhanced` otherwise.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum CommWindowType {
/// Window opened from a password (see `Pase::open_basic_comm_window`).
Basic,
/// Window opened from a pre-computed verifier (see `Pase::open_comm_window`).
Enhanced,
}
/// Identifies who opened a commissioning window.
///
/// Stored as an optional value on `CommWindow` and returned by
/// `CommWindow::opener`.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct CommWindowOpener {
/// NOTE(review): presumably the (non-zero) fabric index of the opener — confirm with callers.
pub fab_idx: NonZeroU8,
/// NOTE(review): presumably the vendor ID of the opener — confirm with callers.
pub vendor_id: u16,
}
/// State of an open PASE commissioning window.
///
/// Instances are created in place by `Pase::open_basic_comm_window` and
/// `Pase::open_comm_window`, and dropped by `Pase::close_comm_window`.
pub struct CommWindow {
/// Identifier reported as `id` in the mDNS service built by `mdns_service`.
mdns_id: u64,
/// Discriminator reported in the mDNS service built by `mdns_service`.
discriminator: u16,
/// SPAKE2+ verifier data; whether it holds a password decides the window type.
pub(crate) verifier: Spake2pVerifierData,
/// Who opened the window, if known.
opener: Option<CommWindowOpener>,
/// Instant after which `Pase::check_comm_window_timeout` closes the window.
window_expiry: Instant,
/// Number of failed PAKE attempts recorded via `Pase::record_pake_failure`.
pake_failures: u8,
}
impl CommWindow {
    /// Returns an initializer for a window whose verifier data is built from
    /// a password and a salt.
    ///
    /// The failure counter of the new window starts at zero.
    fn init_with_pw<'a>(
        mdns_id: u64,
        password: Spake2pVerifierPasswordRef<'a>,
        salt: &'a [u8],
        discriminator: u16,
        opener: Option<CommWindowOpener>,
        window_expiry: Instant,
    ) -> impl Init<Self> + 'a {
        init!(Self {
            mdns_id,
            discriminator,
            verifier <- Spake2pVerifierData::init_with_pw(password, salt),
            opener,
            window_expiry,
            pake_failures: 0,
        })
    }

    /// Returns an initializer for a window whose verifier data is built from
    /// an already computed verifier, its salt and its iteration `count`.
    ///
    /// The failure counter of the new window starts at zero.
    fn init<'a>(
        mdns_id: u64,
        verifier: Spake2pVerifierStrRef<'a>,
        salt: &'a [u8],
        count: u32,
        discriminator: u16,
        opener: Option<CommWindowOpener>,
        window_expiry: Instant,
    ) -> impl Init<Self> + 'a {
        init!(Self {
            mdns_id,
            discriminator,
            verifier <- Spake2pVerifierData::init(verifier, salt, count),
            opener,
            window_expiry,
            pake_failures: 0,
        })
    }

    /// Reports the window type: `Basic` when the verifier data holds a
    /// password, `Enhanced` when it does not.
    pub fn comm_window_type(&self) -> CommWindowType {
        let has_password = self.verifier.password.is_some();

        if has_password {
            CommWindowType::Basic
        } else {
            CommWindowType::Enhanced
        }
    }

    /// Returns a copy of the opener recorded for this window, if any.
    pub fn opener(&self) -> Option<CommWindowOpener> {
        self.opener
    }

    /// Builds the commissionable mDNS service description for this window.
    pub fn mdns_service(&self) -> MatterLocalService {
        let enhanced = self.comm_window_type() == CommWindowType::Enhanced;

        MatterLocalService::Commissionable {
            id: self.mdns_id,
            discriminator: self.discriminator,
            enhanced,
        }
    }
}
/// PASE state: the (optional) open commissioning window plus the timeout of
/// the session establishment currently in progress, if any.
pub struct Pase {
/// The open commissioning window; empty while no window is open.
comm_window: Maybe<CommWindow>,
/// Timeout tracking for an in-progress session establishment; reset to `None`
/// by `record_pake_failure`.
pub(crate) session_timeout: Option<SessionEstTimeout>,
}
impl Pase {
#[inline(always)]
pub const fn new() -> Self {
Self {
comm_window: Maybe::none(),
session_timeout: None,
}
}
pub fn init() -> impl Init<Self> {
init!(Self {
comm_window <- Maybe::init_none(),
session_timeout: None,
})
}
pub fn check_comm_window_timeout(
&mut self,
notify_mdns: impl FnMut(),
notify_change: impl FnMut(EndptId, ClusterId),
) -> Result<bool, Error> {
let expired = self
.comm_window
.as_opt_ref()
.map(|comm_window| Instant::now() > comm_window.window_expiry)
.unwrap_or(false);
if expired {
warn!("PASE Commissioning Window expired, closing");
self.close_comm_window(notify_mdns, notify_change)?;
Ok(true)
} else {
Ok(false)
}
}
pub fn comm_window(&self) -> Option<&CommWindow> {
self.comm_window.as_opt_ref()
}
fn validate_salt_len(salt: &[u8]) -> Result<(), Error> {
if !(SPAKE2P_VERIFIER_SALT_MIN_LEN..=SPAKE2P_VERIFIER_SALT_LEN).contains(&salt.len()) {
Err(ErrorCode::ConstraintError)?;
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn open_basic_comm_window(
&mut self,
mdns_id: u64,
salt: &[u8],
password: Spake2pVerifierPasswordRef<'_>,
discriminator: u16,
timeout_secs: u16,
opener: Option<CommWindowOpener>,
mut notify_mdns: impl FnMut(),
mut notify_change: impl FnMut(EndptId, ClusterId),
) -> Result<(), Error> {
if self.comm_window.is_some() {
Err(ErrorCode::Busy)?;
}
if !(MIN_COMM_WINDOW_TIMEOUT_SECS..=MAX_COMM_WINDOW_TIMEOUT_SECS).contains(&timeout_secs) {
Err(ErrorCode::InvalidCommand)?;
}
Self::validate_salt_len(salt)?;
let window_expiry = Instant::now().saturating_add(Duration::from_secs(timeout_secs as _));
self.comm_window
.reinit(Maybe::init_some(CommWindow::init_with_pw(
mdns_id,
password,
salt,
discriminator,
opener,
window_expiry,
)));
notify_mdns();
notify_adm_comm_window_attrs_changed(&mut notify_change);
info!("PASE Basic Commissioning Window opened");
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn open_comm_window(
&mut self,
mdns_id: u64,
verifier: Spake2pVerifierStrRef<'_>,
salt: &[u8],
count: u32,
discriminator: u16,
timeout_secs: u16,
opener: Option<CommWindowOpener>,
mut notify_mdns: impl FnMut(),
mut notify_change: impl FnMut(EndptId, ClusterId),
) -> Result<(), Error> {
if self.comm_window.is_some() {
Err(ErrorCode::Busy)?;
}
if !(MIN_COMM_WINDOW_TIMEOUT_SECS..=MAX_COMM_WINDOW_TIMEOUT_SECS).contains(&timeout_secs) {
Err(ErrorCode::InvalidCommand)?;
}
Self::validate_salt_len(salt)?;
let window_expiry = Instant::now().saturating_add(Duration::from_secs(timeout_secs as _));
self.comm_window.reinit(Maybe::init_some(CommWindow::init(
mdns_id,
verifier,
salt,
count,
discriminator,
opener,
window_expiry,
)));
notify_mdns();
notify_adm_comm_window_attrs_changed(&mut notify_change);
info!("PASE Commissioning Window opened");
Ok(())
}
pub fn record_pake_failure(
&mut self,
notify_mdns: impl FnMut(),
notify_change: impl FnMut(EndptId, ClusterId),
) -> Result<(), Error> {
const MAX_PAKE_FAILURES: u8 = 20;
self.session_timeout = None;
let revoke = if let Some(window) = self.comm_window.as_opt_mut() {
window.pake_failures = window.pake_failures.saturating_add(1);
warn!(
"PASE Commissioning Window: PAKE failure {} of {}",
window.pake_failures, MAX_PAKE_FAILURES
);
window.pake_failures >= MAX_PAKE_FAILURES
} else {
false
};
if revoke {
warn!("PASE Commissioning Window revoked after too many failed PAKE attempts");
self.close_comm_window(notify_mdns, notify_change)?;
}
Ok(())
}
pub fn close_comm_window(
&mut self,
mut notify_mdns: impl FnMut(),
mut notify_change: impl FnMut(EndptId, ClusterId),
) -> Result<bool, Error> {
if self.comm_window.is_some() {
self.comm_window.clear();
notify_mdns();
notify_adm_comm_window_attrs_changed(&mut notify_change);
info!("PASE Commissioning Window closed");
Ok(true)
} else {
warn!("No PASE Commissioning Window to close");
Ok(false)
}
}
}
impl Default for Pase {
fn default() -> Self {
Self::new()
}
}
/// Time budget (60 seconds) that `SessionEstTimeout::new` adds to the current
/// instant to compute the session establishment expiry.
const PASE_SESSION_EST_TIMEOUT_SECS: Duration = Duration::from_secs(60);
/// The byte string `SessionKeys`.
///
/// NOTE(review): presumably the info/label input used when deriving the
/// SPAKE2+ session keys — confirm at the use sites.
pub(crate) const SPAKE2_SESSION_KEYS_INFO: &[u8] = b"SessionKeys";
/// Deadline for a session establishment, tied to the exchange it was created for.
pub(crate) struct SessionEstTimeout {
/// Instant after which `is_sess_expired` reports `true`.
session_est_expiry: Instant,
/// ID of the exchange passed to `SessionEstTimeout::new`.
pub(crate) exch_id: ExchangeId,
}
impl SessionEstTimeout {
    /// Starts a session establishment deadline for `exchange`.
    ///
    /// The deadline is the current instant plus
    /// `PASE_SESSION_EST_TIMEOUT_SECS` (saturating), and the exchange ID is
    /// remembered alongside it.
    pub(crate) fn new(exchange: &Exchange) -> Self {
        let session_est_expiry = Instant::now().saturating_add(PASE_SESSION_EST_TIMEOUT_SECS);
        let exch_id = exchange.id();

        Self {
            session_est_expiry,
            exch_id,
        }
    }

    /// Returns `true` once the current instant is strictly past the deadline.
    pub(crate) fn is_sess_expired(&self) -> bool {
        let now = Instant::now();

        now > self.session_est_expiry
    }
}
/// TLV-encodable payload of a PBKDF parameter request.
///
/// NOTE(review): presumably the first PASE message, sent by the initiator,
/// with fields tagged from 1 in declaration order (`start = 1`) — confirm
/// against the `tlvargs` macro and the Matter specification. Do not reorder
/// the fields without checking how tags are assigned.
#[derive(FromTLV, ToTLV, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[tlvargs(lifetime = "'a", start = 1)]
pub(crate) struct PBKDFParamReq<'a> {
/// Random octets supplied by the initiator.
pub initiator_random: OctetStr<'a>,
/// NOTE(review): presumably the initiator's session ID — confirm.
pub initiator_ssid: u16,
/// Passcode identifier carried in the request.
pub passcode_id: u16,
/// NOTE(review): presumably whether the initiator already has the PBKDF
/// parameters (salt/iterations) — confirm against the responder logic.
pub has_params: bool,
/// Optional session parameters carried in the request.
pub session_parameters: Option<SessionParameters>,
}
/// TLV-encodable payload of a PBKDF parameter response.
///
/// NOTE(review): presumably the responder's reply to `PBKDFParamReq`, with
/// fields tagged from 1 in declaration order (`start = 1`) — confirm against
/// the `tlvargs` macro and the Matter specification. Do not reorder the
/// fields without checking how tags are assigned.
#[derive(FromTLV, ToTLV, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[tlvargs(lifetime = "'a", start = 1)]
pub(crate) struct PBKDFParamResp<'a> {
/// Random octets attributed to the initiator.
pub initiator_random: OctetStr<'a>,
/// Random octets attributed to the responder.
pub responder_random: OctetStr<'a>,
/// NOTE(review): presumably the responder's session ID — confirm.
pub responder_ssid: u16,
/// Optional PBKDF parameters (iterations and salt).
pub params: Option<PBKDFParamRespParams<'a>>,
/// Optional session parameters carried in the response.
pub session_parameters: Option<crate::sc::SessionParameters>,
}
/// TLV-encodable PBKDF parameters, embedded as the optional `params` field of
/// `PBKDFParamResp`.
#[derive(FromTLV, ToTLV, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[tlvargs(lifetime = "'a", start = 1)]
pub(crate) struct PBKDFParamRespParams<'a> {
/// PBKDF iteration count.
pub iterations: u32,
/// PBKDF salt octets.
pub salt: OctetStr<'a>,
}
/// TLV-encodable payload of the `Pake1` message.
///
/// NOTE(review): presumably sent by the initiator as the first SPAKE2+
/// message — confirm against the initiator/responder modules.
#[derive(FromTLV, ToTLV, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[tlvargs(lifetime = "'a", start = 1)]
pub(crate) struct Pake1<'a> {
/// NOTE(review): presumably the initiator's SPAKE2+ public share (pA) — confirm.
pub pa: OctetStr<'a>,
}
/// TLV-encodable payload of the `Pake2` message.
///
/// NOTE(review): presumably the responder's reply to `Pake1` — confirm
/// against the initiator/responder modules.
#[derive(FromTLV, ToTLV, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[tlvargs(lifetime = "'a", start = 1)]
pub(crate) struct Pake2<'a> {
/// NOTE(review): presumably the responder's SPAKE2+ public share (pB) — confirm.
pub pb: OctetStr<'a>,
/// NOTE(review): presumably the responder's confirmation value (cB) — confirm.
pub cb: OctetStr<'a>,
}
/// TLV-encodable payload of the `Pake3` message.
///
/// NOTE(review): presumably the initiator's final message, answering
/// `Pake2` — confirm against the initiator/responder modules.
#[derive(FromTLV, ToTLV, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[tlvargs(lifetime = "'a", start = 1)]
pub(crate) struct Pake3<'a> {
/// NOTE(review): presumably the initiator's confirmation value (cA) — confirm.
pub ca: OctetStr<'a>,
}