use core::mem::MaybeUninit;
use core::num::NonZeroU8;
use crate::alloc;
use crate::cert::CertRef;
use crate::crypto::{
CanonPkcPublicKeyRef, CanonPkcSignature, CanonPkcSignatureRef, Crypto, Hash, AEAD_CANON_KEY_LEN,
};
use crate::error::{Error, ErrorCode};
use crate::sc::{complete_with_status, GeneralCode, OpCode, SCStatusCodes, StatusReport};
use crate::tlv::{get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVTag, TLVWrite};
use crate::transport::exchange::Exchange;
use crate::transport::session::{NocCatIds, ReservedSession, SessionMode};
use crate::utils::init::InitMaybeUninit;
use crate::utils::storage::ReadBuf;
use super::casep::{CaseP, CaseRandom, CaseRandomRef, CaseSessionKeys, CASE_RESUMPTION_ID_ZEROED};
use super::CASE_LARGE_BUF_SIZE;
/// Sigma2 message received from the CASE responder, parsed from the root
/// TLV structure of the Sigma2 payload.
///
/// `tlvargs(start = 1)` presumably assigns context tags sequentially from 1
/// in declaration order, so the field order must not change.
#[derive(FromTLV, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[tlvargs(start = 1, lifetime = "'a")]
struct Sigma2Resp<'a> {
/// Responder random; the initiator validates it with `CaseRandomRef::try_new`.
responder_random: OctetStr<'a>,
/// Responder's session ID, which becomes the peer session ID.
responder_sessid: u16,
/// Responder ephemeral public key; validated with `CanonPkcPublicKeyRef::try_new`.
responder_eph_pub_key: OctetStr<'a>,
/// Encrypted TBE payload; after decryption it parses as `TBEData2Decrypt`.
encrypted2: OctetStr<'a>,
}
/// Decrypted contents of the Sigma2 `encrypted2` field.
///
/// `tlvargs(start = 1)` presumably assigns context tags sequentially from 1
/// in declaration order, so the field order must not change.
#[derive(FromTLV)]
#[tlvargs(start = 1, lifetime = "'a")]
struct TBEData2Decrypt<'a> {
/// Responder NOC as raw TLV bytes (wrapped in a `CertRef` for validation).
responder_noc: OctetStr<'a>,
/// Optional responder ICAC as raw TLV bytes.
responder_icac: Option<OctetStr<'a>>,
/// Responder signature over its TBS data.
signature: OctetStr<'a>,
/// Resumption ID chosen by the responder; its length is not enforced by parsing.
resumption_id: OctetStr<'a>,
}
/// State for the initiator side of a CASE handshake.
pub struct CaseInitiator<'a, C: Crypto + 'a> {
/// Shared CASE protocol state (transcript, session IDs, ephemeral key, ...).
casep: CaseP<'a, C>,
/// Node ID the responder's NOC is required to carry.
peer_node_id: u64,
/// Ephemeral secret key from `start_initiator`; cleared once Sigma2 has been decrypted.
secret_key: Option<C::SecretKey<'a>>,
}
impl<'a, C: Crypto + 'a> CaseInitiator<'a, C> {
const fn new(peer_node_id: u64) -> Self {
Self {
casep: CaseP::new(),
peer_node_id,
secret_key: None,
}
}
pub async fn initiate(
exchange: &mut Exchange<'_>,
crypto: &'a C,
fab_idx: NonZeroU8,
peer_node_id: u64,
) -> Result<(), Error> {
let mut session = ReservedSession::reserve(exchange.matter(), crypto).await?;
let mut initiator = Self::new(peer_node_id);
let mut random = MaybeUninit::<CaseRandom>::uninit();
let random = random.init_with(CaseRandom::init());
let mut dest_id = MaybeUninit::<Hash>::uninit();
let dest_id = dest_id.init_with(Hash::init());
let local_sessid = exchange.with_state(|state| {
let local_sessid = state.sessions.get_next_sess_id();
let fabric = state.fabrics.fabric(fab_idx)?;
let secret_key = initiator.casep.start_initiator(
crypto,
fabric,
peer_node_id,
local_sessid,
random,
dest_id,
)?;
initiator.secret_key = Some(secret_key);
Ok(local_sessid)
})?;
let mut tt_updated = false;
exchange
.send_with(|_, tw| {
tw.start_struct(&TLVTag::Anonymous)?;
tw.str(&TLVTag::Context(1), random.access())?;
tw.u16(&TLVTag::Context(2), local_sessid)?;
tw.str(&TLVTag::Context(3), dest_id.access())?;
tw.str(&TLVTag::Context(4), initiator.casep.our_pub_key().access())?;
tw.end_container()?;
if !tt_updated {
initiator.casep.update_tt(tw.as_slice())?;
tt_updated = true;
}
Ok(Some(OpCode::CASESigma1.into()))
})
.await?;
exchange.recv_fetch().await?;
{
let rx = exchange.rx()?;
let meta = rx.meta();
if meta.proto_opcode == OpCode::StatusReport as u8 {
let mut rb = ReadBuf::new(rx.payload());
let status = StatusReport::read(&mut rb)?;
error!(
"CASE Sigma1 failed: general={:?}, proto_code={}",
status.general_code, status.proto_code
);
return Err(ErrorCode::Invalid.into());
}
if meta.proto_opcode != OpCode::CASESigma2 as u8 {
error!(
"Unexpected opcode: expected CASESigma2, got {}",
meta.proto_opcode
);
return Err(ErrorCode::InvalidOpcode.into());
}
}
let (peer_catids, _resumption_id) = {
let rx = exchange.rx()?;
let raw_sigma2_payload = rx.payload();
let sigma2 = Sigma2Resp::from_tlv(&get_root_node_struct(raw_sigma2_payload)?)?;
let result = exchange.with_state(|state| {
let mut encrypted2_buf = alloc!([0u8; CASE_LARGE_BUF_SIZE]);
if sigma2.encrypted2.0.len() > encrypted2_buf.len() {
error!("Sigma2 encrypted data too large");
return Err(ErrorCode::BufferTooSmall.into());
}
let encrypted2 = &mut encrypted2_buf[..sigma2.encrypted2.0.len()];
encrypted2.copy_from_slice(sigma2.encrypted2.0);
let peer_random = CaseRandomRef::try_new(sigma2.responder_random.0)?;
let peer_sessid = sigma2.responder_sessid;
let peer_eph_pub_key =
CanonPkcPublicKeyRef::try_new(sigma2.responder_eph_pub_key.0)?;
let fabric = state.fabrics.fabric(fab_idx)?;
let secret_key = initiator
.secret_key
.as_ref()
.ok_or(ErrorCode::InvalidState)?;
let len = initiator
.casep
.sigma2_decrypt(
crypto,
fabric,
secret_key,
raw_sigma2_payload,
peer_random,
peer_sessid,
peer_eph_pub_key,
encrypted2,
)
.inspect_err(|e| {
error!("Failed to decrypt Sigma2 TBE: {}", e);
})?;
initiator.secret_key = None;
let decrypted = &encrypted2[..len];
let decrypted_data = TBEData2Decrypt::from_tlv(&get_root_node_struct(decrypted)?)?;
let responder_noc = CertRef::new(TLVElement::new(decrypted_data.responder_noc.0));
let icac_cert = decrypted_data
.responder_icac
.as_ref()
.map(|icac| CertRef::new(TLVElement::new(icac.0)));
let mut tmp_buf = alloc!([0u8; CASE_LARGE_BUF_SIZE]); initiator
.casep
.validate_certs(
crypto,
state.rtc.utc_time(),
fabric,
&responder_noc,
icac_cert.as_ref(),
&mut tmp_buf[..],
)
.inspect_err(|e| {
error!("Certificate chain doesn't match: {}", e);
})?;
if responder_noc.get_node_id()? != initiator.peer_node_id {
error!(
"Responder node ID doesn't match expected peer: expected {}, got {}",
initiator.peer_node_id,
responder_noc.get_node_id()?
);
Err(ErrorCode::Invalid)?;
}
initiator
.casep
.validate_peer_tbs_signature(
crypto,
decrypted_data.responder_noc.0,
decrypted_data.responder_icac.map(|a| a.0),
&responder_noc,
CanonPkcSignatureRef::try_new(decrypted_data.signature.0)?,
&mut tmp_buf[..],
)
.inspect_err(|e| {
error!("Sigma2 signature doesn't match: {}", e);
})?;
let mut peer_catids: NocCatIds = Default::default();
responder_noc.get_cat_ids(&mut peer_catids)?;
let mut resumption_id = CASE_RESUMPTION_ID_ZEROED;
resumption_id
.access_mut()
.copy_from_slice(decrypted_data.resumption_id.0);
Ok((peer_catids, resumption_id))
});
if result.is_err() {
complete_with_status(exchange, SCStatusCodes::InvalidParameter, &[]).await?;
}
result
}?;
let mut signature = MaybeUninit::<CanonPkcSignature>::uninit();
let signature = signature.init_with(CanonPkcSignature::init());
exchange.with_state(|state| {
let fabric = state.fabrics.fabric(fab_idx)?;
let mut tmp_buf = alloc!([0u8; CASE_LARGE_BUF_SIZE]);
initiator
.casep
.compute_sigma3_signature(crypto, fabric, &mut tmp_buf[..], signature)
})?;
let mut tt_updated = false;
exchange
.send_with(|exchange_ref, tw| {
exchange_ref.with_state(|state| {
let fabric = state.fabrics.fabric(fab_idx)?;
tw.start_struct(&TLVTag::Anonymous)?;
tw.str_cb(&TLVTag::Context(1), |buf| {
initiator
.casep
.sigma3_encrypt(crypto, fabric, signature.reference(), buf)
})?;
tw.end_container()?;
if !tt_updated {
initiator.casep.update_tt(tw.as_slice())?;
tt_updated = true;
}
Ok(Some(OpCode::CASESigma3.into()))
})
})
.await?;
exchange.recv_fetch().await?;
{
let rx = exchange.rx()?;
let meta = rx.meta();
if meta.proto_opcode != OpCode::StatusReport as u8 {
error!(
"Unexpected opcode: expected StatusReport, got {}",
meta.proto_opcode
);
return Err(ErrorCode::InvalidOpcode.into());
}
let mut rb = ReadBuf::new(rx.payload());
let status = StatusReport::read(&mut rb)?;
if status.general_code != GeneralCode::Success
|| status.proto_code != SCStatusCodes::SessionEstablishmentSuccess as u16
{
error!(
"CASE failed: general={:?}, proto_code={}",
status.general_code, status.proto_code
);
return Err(ErrorCode::Invalid.into());
}
}
{
let mut session_keys = MaybeUninit::<CaseSessionKeys>::uninit();
let session_keys = session_keys.init_with(CaseSessionKeys::init());
let (peer_addr, local_node_id) = exchange.with_state(|state| {
let sess = exchange.id().session(&mut state.sessions);
let fabric = state.fabrics.fabric(fab_idx)?;
initiator.casep.compute_session_keys(
crypto,
fabric.ipk().op_key(),
session_keys,
)?;
Ok((sess.get_peer_addr(), fabric.node_id()))
})?;
let (enc_key, remaining) = session_keys
.reference()
.split::<AEAD_CANON_KEY_LEN, { AEAD_CANON_KEY_LEN * 2 }>();
let (dec_key, att_challenge) =
remaining.split::<AEAD_CANON_KEY_LEN, AEAD_CANON_KEY_LEN>();
session.update(
local_node_id,
peer_node_id,
initiator.casep.peer_sessid(),
initiator.casep.local_sessid(),
peer_addr,
SessionMode::Case {
fab_idx,
cat_ids: peer_catids,
},
Some(dec_key),
Some(enc_key),
Some(att_challenge),
)?;
}
session.complete();
exchange.acknowledge().await?;
info!(
"CASE session established: local_sessid={}, peer_sessid={}",
initiator.casep.local_sessid(),
initiator.casep.peer_sessid()
);
Ok(())
}
}