use rand_core::RngCore;
use crate::crypto::{
CanonEcPointRef, Crypto, HmacHashRef, Kdf, AEAD_CANON_KEY_LEN, EC_POINT_ZEROED,
HMAC_HASH_ZEROED,
};
use crate::dm::AttrChangeNotifier;
use crate::error::{Error, ErrorCode};
use crate::sc::pase::spake2p::{Spake2P, Spake2pRandom, Spake2pRandomRef, Spake2pSessionKeys};
use crate::sc::{check_opcode, complete_with_status, OpCode, SCStatusCodes};
use crate::tlv::{get_root_node_struct, FromTLV, OctetStr, TLVElement, TagType, ToTLV};
use crate::transport::exchange::Exchange;
use crate::transport::session::{ReservedSession, SessionMode};
use crate::utils::init::{init, Init};
use super::{
PBKDFParamReq, PBKDFParamResp, PBKDFParamRespParams, Pake1, Pake2, Pake3,
SPAKE2_SESSION_KEYS_INFO,
};
/// Responder side of the PASE session-establishment handshake.
///
/// Processes PBKDFParamRequest, Pake1 and Pake3 messages on an [`Exchange`] and, on
/// success, completes a [`ReservedSession`] as a PASE session.
pub struct PaseResponder<'a, C: Crypto> {
    /// Crypto backend used for randomness, key derivation and session reservation.
    crypto: C,
    /// Cluster-change notifier handed to the PASE state's commissioning-window timeout
    /// check and to its PAKE-failure accounting.
    notify: &'a dyn AttrChangeNotifier,
    /// SPAKE2+ state carried across the handshake messages.
    spake2p: Spake2P,
}
impl<'a, C: Crypto> PaseResponder<'a, C> {
pub const fn new(crypto: C, notify: &'a dyn AttrChangeNotifier) -> Self {
Self {
crypto,
notify,
spake2p: Spake2P::new(),
}
}
pub fn init(crypto: C, notify: &'a dyn AttrChangeNotifier) -> impl Init<Self> {
init!(Self {
crypto,
notify,
spake2p <- Spake2P::init(),
})
}
pub async fn handle(&mut self, exchange: &mut Exchange<'_>) -> Result<(), Error> {
let result = self.handle_inner(exchange).await;
let pake_failed = matches!(result, Ok(false) | Err(_));
if pake_failed {
let notify_mdns = || exchange.matter().transport().notify_mdns_changed();
let notify_change =
|endpt_id, cluster_id| self.notify.notify_cluster_changed(endpt_id, cluster_id);
let _ = exchange
.with_state(|state| state.pase.record_pake_failure(notify_mdns, notify_change));
}
result.map(|_| ())
}
async fn handle_inner(&mut self, exchange: &mut Exchange<'_>) -> Result<bool, Error> {
let mut session = ReservedSession::reserve(exchange.matter(), &self.crypto).await?;
if !self.update_session_timeout(exchange, true).await? {
return Ok(true);
}
if !self
.handle_pbkdfparamrequest(exchange, &mut session)
.await?
{
self.clear_session_timeout(exchange)?;
return Ok(true);
}
exchange.recv_fetch().await?;
if !self.update_session_timeout(exchange, false).await? {
return Ok(true);
}
if !self.handle_pasepake1(exchange).await? {
self.clear_session_timeout(exchange)?;
return Ok(true);
}
exchange.recv_fetch().await?;
if !self.update_session_timeout(exchange, false).await? {
return Ok(true);
}
let success = self.handle_pasepake3(exchange, session).await?;
exchange.acknowledge().await?;
self.clear_session_timeout(exchange)?;
Ok(success)
}
async fn handle_pbkdfparamrequest(
&mut self,
exchange: &mut Exchange<'_>,
session: &mut ReservedSession<'_>,
) -> Result<bool, Error> {
check_opcode(exchange, OpCode::PBKDFParamRequest)?;
let rx = exchange.rx()?;
let mut salt = [0u8; super::spake2p::SPAKE2P_VERIFIER_SALT_LEN];
let mut salt_len = 0usize;
let mut count = 0;
let notify_mdns = || exchange.matter().transport().notify_mdns_changed();
let notify_change =
|endpt_id, cluster_id| self.notify.notify_cluster_changed(endpt_id, cluster_id);
let has_comm_window = {
exchange.with_state(|state| {
state
.pase
.check_comm_window_timeout(notify_mdns, notify_change)?;
if let Some(comm_window) = state.pase.comm_window() {
let src = comm_window.verifier.salt_bytes();
salt[..src.len()].copy_from_slice(src);
salt_len = src.len();
count = comm_window.verifier.count;
Ok(true)
} else {
Ok(false)
}
})?
};
if has_comm_window {
let mut our_random = Spake2pRandom::new();
let mut initiator_random = Spake2pRandom::new();
let (local_sessid, peer_sessid, resp) = {
let req = PBKDFParamReq::from_tlv(&TLVElement::new(rx.payload()))?;
if req.passcode_id != 0 {
error!("Can't yet handle passcode_id != 0");
Err(ErrorCode::Invalid)?;
}
if let Some(params) = req.session_parameters.as_ref() {
exchange.with_state(|state| {
exchange
.id()
.session(&mut state.sessions)
.set_peer_session_params(params);
Ok(())
})?;
session.set_peer_session_params(params)?;
}
let mut rand = self.crypto.rand()?;
rand.fill_bytes(our_random.access_mut());
let local_sessid =
exchange.with_state(|state| Ok(state.sessions.get_next_sess_id()))?;
initiator_random.load(Spake2pRandomRef::try_new(req.initiator_random.0)?);
let max_paths = exchange.matter().dev_det().max_paths_per_invoke;
let resp = PBKDFParamResp {
initiator_random: OctetStr::new(initiator_random.access()),
responder_random: OctetStr::new(our_random.access()),
responder_ssid: local_sessid,
params: (!req.has_params).then(|| PBKDFParamRespParams {
iterations: count,
salt: OctetStr::new(&salt[..salt_len]),
}),
session_parameters: Some(crate::sc::SessionParameters {
max_paths_per_invoke: Some(max_paths),
..Default::default()
}),
};
(local_sessid, req.initiator_ssid, resp)
};
let mut context = Some(self.spake2p.start_context(
&self.crypto,
local_sessid,
peer_sessid,
rx.payload(),
)?);
exchange
.send_with(|_, wb| {
resp.to_tlv(&TagType::Anonymous, &mut *wb)?;
if let Some(context) = context.take() {
self.spake2p.finish_context::<&C>(context, wb.as_slice())?;
}
Ok(Some(OpCode::PBKDFParamResponse.into()))
})
.await?;
Ok(true)
} else {
debug!("Dropping PBKDFParamRequest: no commissioning window open");
Ok(false)
}
}
async fn handle_pasepake1(&mut self, exchange: &mut Exchange<'_>) -> Result<bool, Error> {
check_opcode(exchange, OpCode::PASEPake1)?;
let req = get_root_node_struct(exchange.rx()?.payload())?;
let pake1 = Pake1::from_tlv(&req)?;
let a_pt: CanonEcPointRef<'_> = pake1.pa.0.try_into()?;
let mut b_pt = EC_POINT_ZEROED;
let mut cb = HMAC_HASH_ZEROED;
let notify_mdns = || exchange.matter().transport().notify_mdns_changed();
let notify_change =
|endpt_id, cluster_id| self.notify.notify_cluster_changed(endpt_id, cluster_id);
let has_comm_window = {
exchange.with_state(|state| {
state
.pase
.check_comm_window_timeout(notify_mdns, notify_change)?;
if let Some(comm_window) = state.pase.comm_window() {
self.spake2p.setup_verifier(
&self.crypto,
&comm_window.verifier,
a_pt,
&mut b_pt,
&mut cb,
)?;
Ok(true)
} else {
Ok(false)
}
})?
};
if has_comm_window {
exchange
.send_with(|_, wb| {
let resp = Pake2 {
pb: OctetStr::new(b_pt.access()),
cb: OctetStr::new(cb.access()),
};
resp.to_tlv(&TagType::Anonymous, wb)?;
Ok(Some(OpCode::PASEPake2.into()))
})
.await?;
Ok(true)
} else {
debug!("Dropping PASEPake1: no commissioning window open");
Ok(false)
}
}
async fn handle_pasepake3(
&mut self,
exchange: &mut Exchange<'_>,
mut session: ReservedSession<'_>,
) -> Result<bool, Error> {
check_opcode(exchange, OpCode::PASEPake3)?;
let req = get_root_node_struct(exchange.rx()?.payload())?;
let pake3 = Pake3::from_tlv(&req)?;
let ca: HmacHashRef<'_> = pake3.ca.0.try_into()?;
let verify_result = self.spake2p.verify(ca);
let success = verify_result.is_ok();
let status = match verify_result {
Ok((local_sessid, peer_sessid, ke)) => {
let mut session_keys = Spake2pSessionKeys::new(); self.crypto
.kdf()?
.expand(&[], ke, SPAKE2_SESSION_KEYS_INFO, &mut session_keys)
.map_err(|_x| ErrorCode::InvalidData)?;
let peer_addr = exchange.with_state(|state| {
let sess = exchange.id().session(&mut state.sessions);
Ok(sess.get_peer_addr())
})?;
let (dec_key, remaining) = session_keys
.reference()
.split::<AEAD_CANON_KEY_LEN, { AEAD_CANON_KEY_LEN * 2 }>();
let (enc_key, att_challenge) =
remaining.split::<AEAD_CANON_KEY_LEN, AEAD_CANON_KEY_LEN>();
session.update(
0,
0,
peer_sessid,
local_sessid,
peer_addr,
SessionMode::Pase { fab_idx: 0 },
Some(dec_key),
Some(enc_key),
Some(att_challenge),
)?;
session.complete();
exchange.with_state(|state| {
if !state.failsafe.is_armed() {
let session_mode = SessionMode::Pase { fab_idx: 0 };
state.failsafe.arm(
crate::failsafe::DEFAULT_FAILSAFE_EXPIRY_SECS,
0,
&session_mode,
&mut state.pase,
)?;
}
Ok(())
})?;
SCStatusCodes::SessionEstablishmentSuccess
}
Err(status) => status,
};
complete_with_status(exchange, status, &[]).await?;
Ok(success)
}
async fn update_session_timeout(
&mut self,
exchange: &mut Exchange<'_>,
new: bool,
) -> Result<bool, Error> {
let status = exchange.with_state(|state| {
if state
.pase
.session_timeout
.as_ref()
.map(|sd| sd.is_sess_expired())
.unwrap_or(false)
{
state.pase.session_timeout = None;
}
if let Some(sd) = state.pase.session_timeout.as_mut() {
if sd.exch_id != exchange.id() {
debug!("Another PAKE session in progress");
Ok(Some(SCStatusCodes::Busy))
} else {
state.pase.session_timeout = Some(super::SessionEstTimeout::new(exchange));
Ok(None)
}
} else if new {
state.pase.session_timeout = Some(super::SessionEstTimeout::new(exchange));
Ok(None)
} else {
error!("PAKE session not found or expired");
Ok(Some(SCStatusCodes::SessionNotFound))
}
})?;
if let Some(status) = status {
complete_with_status(exchange, status, &[]).await?;
Ok(false)
} else {
Ok(true)
}
}
fn clear_session_timeout(&mut self, exchange: &Exchange) -> Result<(), Error> {
exchange.with_state(|state| {
state.pase.session_timeout = None;
Ok(())
})
}
}