use core::{mem::MaybeUninit, num::NonZeroU8};
use super::casep::{CaseP, CaseRandom, CaseResumptionId, CaseSessionKeys};
use super::CASE_LARGE_BUF_SIZE;
use crate::alloc;
use crate::cert::CertRef;
use crate::crypto::{CanonPkcSignature, CanonPkcSignatureRef, Crypto, Hash, AEAD_CANON_KEY_LEN};
use crate::error::Error;
use crate::sc::{
check_opcode, complete_with_status, sc_write, OpCode, SCStatusCodes, SessionParameters,
};
use crate::tlv::{get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVTag, TLVWrite, ToTLV};
use crate::transport::exchange::Exchange;
use crate::transport::session::{NocCatIds, ReservedSession, SessionMode};
use crate::utils::init::{init, Init, InitMaybeUninit};
/// Decoded CASE Sigma1 request received from the initiator.
///
/// Field order is wire-significant: with `start = 1` the `FromTLV` derive
/// presumably maps fields to consecutive context tags (1, 2, 3, ...) in
/// declaration order. Confirm against the derive before reordering or
/// inserting fields.
#[derive(FromTLV, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[tlvargs(start = 1, lifetime = "'a")]
struct Sigma1Req<'a> {
// Initiator's random value; combined with `dest_id` to look up the local fabric.
initiator_random: OctetStr<'a>,
// Session ID chosen by the initiator; handed to `CaseP::start` (presumably as
// the peer session ID).
initiator_sessid: u16,
// Destination identifier matched against the local fabrics.
dest_id: OctetStr<'a>,
// Initiator's public key (presumably the ephemeral key); converted with
// `try_into` before being passed to `CaseP::start`.
peer_pub_key: OctetStr<'a>,
// Optional peer session parameters; when present they are applied to the
// current session and to the reserved session being established.
session_parameters: Option<SessionParameters>,
// Resumption fields: the responder only checks that these two are present
// together or absent together.
resumption_id: Option<OctetStr<'a>>,
initiator_resume_mic: Option<OctetStr<'a>>,
}
/// Plaintext contents of the encrypted Sigma3 payload, after AEAD decryption.
///
/// As with `Sigma1Req`, field order presumably determines the context tags
/// (starting at 1). Do not reorder without checking the derive.
#[derive(FromTLV, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[tlvargs(start = 1, lifetime = "'a")]
struct Sigma3Decrypt<'a> {
// Initiator's operational certificate (TLV-encoded; parsed with `CertRef`).
initiator_noc: OctetStr<'a>,
// Optional intermediate CA certificate of the initiator (TLV-encoded).
initiator_icac: Option<OctetStr<'a>>,
// Initiator's signature, verified with `CaseP::validate_peer_tbs_signature`.
signature: OctetStr<'a>,
}
/// Responder side of the CASE session-establishment handshake.
///
/// Answers the initiator's Sigma1 with Sigma2, then validates Sigma3 and
/// completes the reserved session on success.
pub struct CaseResponder<'a, C: Crypto> {
// Crypto backend passed to session reservation and to every `CaseP` operation.
crypto: &'a C,
// Handshake state: started while handling Sigma1 and consulted for Sigma3.
casep: CaseP<'a, C>,
}
impl<'a, C: Crypto> CaseResponder<'a, C> {
#[inline(always)]
pub const fn new(crypto: &'a C) -> Self {
Self {
crypto,
casep: CaseP::new(),
}
}
pub fn init(crypto: &'a C) -> impl Init<Self> {
init!(Self {
crypto,
casep <- CaseP::init(),
})
}
pub async fn handle(&mut self, exchange: &mut Exchange<'_>) -> Result<(), Error> {
let mut session = ReservedSession::reserve(exchange.matter(), self.crypto).await?;
self.handle_casesigma1(exchange, &mut session).await?;
exchange.recv_fetch().await?;
self.handle_casesigma3(exchange, session).await?;
exchange.acknowledge().await?;
Ok(())
}
async fn handle_casesigma1(
&mut self,
exchange: &mut Exchange<'_>,
session: &mut ReservedSession<'_>,
) -> Result<(), Error> {
check_opcode(exchange, OpCode::CASESigma1)?;
let req = Sigma1Req::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?;
if req.resumption_id.is_some() != req.initiator_resume_mic.is_some() {
error!("Sigma1 has mismatched resumptionID/initiatorResumeMIC presence; rejecting");
complete_with_status(exchange, SCStatusCodes::InvalidParameter, &[]).await?;
return Ok(());
}
let local_fabric_idx = exchange.with_state(|state| {
Ok(state
.fabrics
.get_by_dest_id(self.crypto, req.initiator_random.0, req.dest_id.0)
.map(|fabric| fabric.fab_idx()))
})?;
if local_fabric_idx.is_none() {
error!("Fabric Index mismatch");
complete_with_status(exchange, SCStatusCodes::NoSharedTrustRoots, &[]).await?;
return Ok(());
}
let local_sessid = exchange.with_state(|state| Ok(state.sessions.get_next_sess_id()))?;
let mut our_random = MaybeUninit::<CaseRandom>::uninit(); let our_random = our_random.init_with(CaseRandom::init());
let mut resumption_id = MaybeUninit::<CaseResumptionId>::uninit(); let resumption_id = resumption_id.init_with(CaseResumptionId::init());
let mut tt_hash = MaybeUninit::<Hash>::uninit(); let tt_hash = tt_hash.init_with(Hash::init());
self.casep.start(
self.crypto,
req.initiator_sessid,
local_sessid,
unwrap!(local_fabric_idx).get(),
req.peer_pub_key.0.try_into()?,
exchange.rx()?.payload(),
our_random,
resumption_id,
tt_hash,
)?;
if let Some(params) = req.session_parameters.as_ref() {
exchange.with_state(|state| {
exchange
.id()
.session(&mut state.sessions)
.set_peer_session_params(params);
Ok(())
})?;
session.set_peer_session_params(params)?;
}
trace!(
"Destination ID matched to fabric index {}",
self.casep.local_fabric_idx()
);
let mut tt_updated = false;
exchange
.send_with(|exchange, tw| {
exchange.with_state(|state| {
let fabric = NonZeroU8::new(self.casep.local_fabric_idx())
.and_then(|fabric_idx| state.fabrics.get(fabric_idx));
let Some(fabric) = fabric else {
return sc_write(tw, SCStatusCodes::NoSharedTrustRoots, &[]);
};
let mut signature = MaybeUninit::<CanonPkcSignature>::uninit(); let signature = signature.init_with(CanonPkcSignature::init());
let sign_buf = tw.empty_as_mut_slice();
self.casep.compute_sigma2_signature(
self.crypto,
fabric,
sign_buf,
signature,
)?;
tw.start_struct(&TLVTag::Anonymous)?;
tw.str(&TLVTag::Context(1), our_random.access())?;
tw.u16(&TLVTag::Context(2), local_sessid)?;
tw.str(&TLVTag::Context(3), self.casep.our_pub_key().access())?;
tw.str_cb(&TLVTag::Context(4), |buf| {
self.casep.sigma2_encrypt(
self.crypto,
fabric,
our_random.reference(),
tt_hash.reference(),
signature.reference(),
resumption_id.reference(),
buf,
)
})?;
let session_params = crate::sc::SessionParameters {
max_paths_per_invoke: Some(
exchange.matter().dev_det().max_paths_per_invoke,
),
..Default::default()
};
session_params.to_tlv(&TLVTag::Context(5), &mut *tw)?;
tw.end_container()?;
if !tt_updated {
self.casep.update_tt(tw.as_slice())?;
tt_updated = true;
}
Ok(Some(OpCode::CASESigma2.into()))
})
})
.await
}
async fn handle_casesigma3(
&mut self,
exchange: &mut Exchange<'_>,
mut session: ReservedSession<'_>,
) -> Result<(), Error> {
check_opcode(exchange, OpCode::CASESigma3)?;
let status = exchange.with_state(|state| {
let sess = exchange.id().session(&mut state.sessions);
let fabric = NonZeroU8::new(self.casep.local_fabric_idx())
.and_then(|fabric_idx| state.fabrics.get(fabric_idx));
if let Some(fabric) = fabric {
let req = match get_root_node_struct(exchange.rx()?.payload()) {
Ok(req) => req,
Err(e) => {
error!("Sigma3 outer TLV parse failed: {}", e);
return Ok(SCStatusCodes::InvalidParameter);
}
};
let encrypted = match req.structure().and_then(|s| s.ctx(1)).and_then(|c| c.str()) {
Ok(s) => s,
Err(e) => {
error!("Sigma3 encrypted field parse failed: {}", e);
return Ok(SCStatusCodes::InvalidParameter);
}
};
let mut decrypted = alloc!([0; CASE_LARGE_BUF_SIZE]); if encrypted.len() > decrypted.len() {
error!(
"Encrypted Sigma3 data too large ({} bytes)",
encrypted.len()
);
return Ok(SCStatusCodes::InvalidParameter);
}
let decrypted = &mut decrypted[..encrypted.len()];
decrypted.copy_from_slice(encrypted);
let len =
match self
.casep
.sigma3_decrypt(self.crypto, fabric.ipk().op_key(), decrypted)
{
Ok(len) => len,
Err(e) => {
error!("Sigma3 AEAD decrypt failed: {}", e);
return Ok(SCStatusCodes::InvalidParameter);
}
};
let decrypted = &decrypted[..len];
let decrypted_req: Sigma3Decrypt<'_> = match get_root_node_struct(decrypted)
.and_then(|n| Sigma3Decrypt::from_tlv(&n))
{
Ok(req) => req,
Err(e) => {
error!("Sigma3 decrypted TLV parse failed: {}", e);
return Ok(SCStatusCodes::InvalidParameter);
}
};
let initiator_noc = CertRef::new(TLVElement::new(decrypted_req.initiator_noc.0));
let initiator_icac = decrypted_req
.initiator_icac
.map(|icac| CertRef::new(TLVElement::new(icac.0)));
let mut buf = alloc!([0; CASE_LARGE_BUF_SIZE]); let buf = &mut buf[..];
if let Err(e) = self.casep.validate_certs(
self.crypto,
state.rtc.utc_time(),
fabric,
&initiator_noc,
initiator_icac.as_ref(),
buf,
) {
error!("Certificate Chain doesn't match: {}", e);
Ok(SCStatusCodes::InvalidParameter)
} else if let Err(e) = self.casep.validate_peer_tbs_signature(
self.crypto,
decrypted_req.initiator_noc.0,
decrypted_req.initiator_icac.map(|a| a.0),
&initiator_noc,
CanonPkcSignatureRef::try_new(decrypted_req.signature.0)?,
buf,
) {
error!("Sigma3 Signature doesn't match: {}", e);
Ok(SCStatusCodes::InvalidParameter)
} else {
let mut peer_catids: NocCatIds = Default::default();
initiator_noc.get_cat_ids(&mut peer_catids)?;
self.casep.update_tt(exchange.rx()?.payload())?;
let mut session_keys = MaybeUninit::<CaseSessionKeys>::uninit(); let session_keys = session_keys.init_with(CaseSessionKeys::init());
self.casep.compute_session_keys(
self.crypto,
fabric.ipk().op_key(),
session_keys,
)?;
let peer_addr = sess.get_peer_addr();
let (dec_key, remaining) = session_keys
.reference()
.split::<AEAD_CANON_KEY_LEN, { AEAD_CANON_KEY_LEN * 2 }>();
let (enc_key, att_challenge) =
remaining.split::<AEAD_CANON_KEY_LEN, AEAD_CANON_KEY_LEN>();
session.update_with_state(
state,
fabric.node_id(),
initiator_noc.get_node_id()?,
self.casep.peer_sessid(),
self.casep.local_sessid(),
peer_addr,
SessionMode::Case {
fab_idx: unwrap!(NonZeroU8::new(self.casep.local_fabric_idx())),
cat_ids: peer_catids,
},
Some(dec_key),
Some(enc_key),
Some(att_challenge),
)?;
Ok(SCStatusCodes::SessionEstablishmentSuccess)
}
} else {
Ok(SCStatusCodes::NoSharedTrustRoots)
}
})?;
if matches!(status, SCStatusCodes::SessionEstablishmentSuccess) {
session.complete();
}
complete_with_status(exchange, status, &[]).await
}
}