use core::borrow::Borrow;
use core::future::Future;
use core::mem::MaybeUninit;
use num_derive::FromPrimitive;
use crate::crypto::Crypto;
use crate::dm::AttrChangeNotifier;
use crate::error::{Error, ErrorCode};
use crate::respond::ExchangeHandler;
use crate::tlv::{FromTLV, ToTLV};
use crate::transport::exchange::{Exchange, MessageMeta};
use crate::utils::init::InitMaybeUninit;
use crate::utils::storage::{ReadBuf, WriteBuf};
use case::CaseResponder;
use pase::PaseResponder;
pub mod busy;
pub mod case;
pub mod pase;
/// Protocol ID of the Secure Channel protocol.
///
/// Every [`OpCode`] is sent with this protocol ID, including PASE, CASE,
/// MRP standalone acks and status reports; see [`OpCode::meta`].
pub const PROTO_ID_SECURE_CHANNEL: u16 = 0x0000;
/// Message opcodes of the Secure Channel protocol
/// (protocol ID [`PROTO_ID_SECURE_CHANNEL`]).
///
/// The discriminants are the on-wire opcode bytes. [`OpCode::meta`] casts a
/// variant to `u8` for sending. `FromPrimitive` is derived so that a received
/// opcode byte can be mapped back to a variant.
#[derive(FromPrimitive, Debug, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum OpCode {
// Message counter synchronization (payload is not TLV, see `is_tlv`).
MsgCounterSyncReq = 0x00,
MsgCounterSyncResp = 0x01,
// MRP standalone acknowledgement; the only opcode sent unreliably by `meta`.
MRPStandAloneAck = 0x10,
// PASE session establishment; `PBKDFParamRequest` starts a PASE exchange in
// `SecureChannel::handle`.
PBKDFParamRequest = 0x20,
PBKDFParamResponse = 0x21,
PASEPake1 = 0x22,
PASEPake2 = 0x23,
PASEPake3 = 0x24,
// CASE session establishment; `CASESigma1` starts a CASE exchange in
// `SecureChannel::handle`.
CASESigma1 = 0x30,
CASESigma2 = 0x31,
CASESigma3 = 0x32,
CASESigma2Resume = 0x33,
// Status report (see `StatusReport` for the payload layout).
StatusReport = 0x40,
}
impl OpCode {
    /// Builds the [`MessageMeta`] used to send a message with this opcode on
    /// the Secure Channel protocol.
    ///
    /// Every opcode is marked reliable except [`OpCode::MRPStandAloneAck`].
    pub fn meta(&self) -> MessageMeta {
        let reliable = *self != Self::MRPStandAloneAck;

        MessageMeta {
            proto_id: PROTO_ID_SECURE_CHANNEL,
            proto_opcode: *self as u8,
            reliable,
        }
    }

    /// Returns `true` if messages with this opcode carry a TLV payload.
    ///
    /// Standalone acks, status reports and the message-counter sync messages
    /// are reported as non-TLV. Every other opcode is TLV.
    pub fn is_tlv(&self) -> bool {
        match self {
            Self::MRPStandAloneAck
            | Self::StatusReport
            | Self::MsgCounterSyncReq
            | Self::MsgCounterSyncResp => false,
            _ => true,
        }
    }
}
impl From<OpCode> for MessageMeta {
fn from(op: OpCode) -> Self {
op.meta()
}
}
/// Protocol-specific status codes carried in Secure Channel status reports.
///
/// [`SCStatusCodes::as_report`] writes the discriminant as the report's
/// `proto_code` and maps each code to a [`GeneralCode`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum SCStatusCodes {
SessionEstablishmentSuccess = 0,
NoSharedTrustRoots = 1,
InvalidParameter = 2,
CloseSession = 3,
Busy = 4,
SessionNotFound = 5,
}
impl SCStatusCodes {
    /// Returns whether a status report carrying this code is sent with the
    /// MRP reliability flag set.
    ///
    /// `CloseSession`, `Busy` and `SessionNotFound` are sent unreliably.
    /// All other codes are sent reliably.
    pub fn reliable(&self) -> bool {
        match self {
            SCStatusCodes::CloseSession | SCStatusCodes::Busy | SCStatusCodes::SessionNotFound => {
                false
            }
            SCStatusCodes::SessionEstablishmentSuccess
            | SCStatusCodes::NoSharedTrustRoots
            | SCStatusCodes::InvalidParameter => true,
        }
    }

    /// Builds a Secure Channel [`StatusReport`] for this code.
    ///
    /// `payload` becomes the report's protocol-specific data and is borrowed,
    /// not copied. The general code is `Success` for session establishment
    /// success and close-session, `Busy` for `Busy`, and `Failure` otherwise.
    pub fn as_report<'a>(&self, payload: &'a [u8]) -> StatusReport<'a> {
        let general_code = match self {
            Self::SessionEstablishmentSuccess | Self::CloseSession => GeneralCode::Success,
            Self::Busy => GeneralCode::Busy,
            Self::InvalidParameter | Self::NoSharedTrustRoots | Self::SessionNotFound => {
                GeneralCode::Failure
            }
        };

        StatusReport {
            general_code,
            proto_id: u32::from(PROTO_ID_SECURE_CHANNEL),
            proto_code: *self as u16,
            proto_data: payload,
        }
    }
}
/// Sends a Secure Channel status report with `status_code` and `payload` on
/// `exchange`.
///
/// The report body and its message metadata (including the reliability flag)
/// are produced by [`sc_write`].
///
/// # Errors
/// Propagates any error from serializing or sending the report.
pub async fn complete_with_status(
    exchange: &mut Exchange<'_>,
    status_code: SCStatusCodes,
    payload: &[u8],
) -> Result<(), Error> {
    let send = exchange.send_with(|_, buf| sc_write(buf, status_code, payload));
    send.await
}
/// Serializes a Secure Channel status report for `status_code` and `payload`
/// into `wb`. Returns the metadata to send it with.
///
/// The returned metadata uses [`OpCode::StatusReport`]. Its reliability flag
/// follows [`SCStatusCodes::reliable`]. The result is always `Some`.
///
/// # Errors
/// Propagates any error from writing the report into `wb`.
pub fn sc_write(
    wb: &mut WriteBuf,
    status_code: SCStatusCodes,
    payload: &[u8],
) -> Result<Option<MessageMeta>, Error> {
    let report = status_code.as_report(payload);
    report.write(wb)?;

    let meta = OpCode::StatusReport
        .meta()
        .reliable(status_code.reliable());

    Ok(Some(meta))
}
/// General status codes of a status report (the `general_code` field of
/// [`StatusReport`]). On the wire each code is a little-endian `u16`.
///
/// `FromPrimitive` is derived so that [`StatusReport::read`] can decode a
/// received value. Unknown values are rejected there.
// `dead_code` is allowed because most variants are never constructed in the
// code visible here (only `Success`, `Failure` and `Busy` are, in
// `SCStatusCodes::as_report`). The rest exist for decoding received reports.
#[allow(dead_code)]
#[derive(FromPrimitive, PartialEq, Eq, Debug, Copy, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum GeneralCode {
Success = 0,
Failure = 1,
BadPrecondition = 2,
OutOfRange = 3,
BadRequest = 4,
Unsupported = 5,
Unexpected = 6,
ResourceExhausted = 7,
Busy = 8,
Timeout = 9,
Continue = 10,
Aborted = 11,
InvalidArgument = 12,
NotFound = 13,
AlreadyExists = 14,
PermissionDenied = 15,
DataLoss = 16,
}
/// Optional session parameters, (de)serialized as a TLV structure via the
/// `FromTLV`/`ToTLV` derives. Every field is optional, so absent entries
/// decode as `None`.
///
/// NOTE(review): `#[tlvargs(start = 1)]` presumably numbers the context tags
/// sequentially from 1 in field order (`sii` = 1 ... `max_paths_per_invoke` = 7).
/// Confirm against the derive. These are presumably the parameters exchanged
/// during PASE/CASE; verify against the callers in `pase`/`case`.
#[derive(Default, FromTLV, ToTLV, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[tlvargs(start = 1)]
pub(crate) struct SessionParameters {
// Presumably the MRP session idle interval; confirm units.
pub(crate) sii: Option<u32>,
// Presumably the MRP session active interval; confirm units.
pub(crate) sai: Option<u32>,
// Presumably the session active threshold; confirm units.
pub(crate) sat: Option<u16>,
pub(crate) dm_revision: Option<u16>,
pub(crate) im_revision: Option<u16>,
pub(crate) spec_version: Option<u32>,
pub(crate) max_paths_per_invoke: Option<u16>,
}
/// A status report message, as read and written by the methods on this type.
///
/// Wire layout (all integers little-endian):
/// `general_code: u16`, `proto_id: u32`, `proto_code: u16`, then `proto_data`
/// as the remaining bytes.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct StatusReport<'a> {
/// Protocol-independent status.
pub general_code: GeneralCode,
/// Protocol the `proto_code` belongs to.
pub proto_id: u32,
/// Protocol-specific status code.
pub proto_code: u16,
/// Protocol-specific trailing data, borrowed from the message buffer.
pub proto_data: &'a [u8],
}
impl<'a> StatusReport<'a> {
pub fn read<T>(pb: &'a mut ReadBuf<T>) -> Result<Self, Error>
where
T: Borrow<[u8]>,
{
Ok(Self {
general_code: num::FromPrimitive::from_u16(pb.le_u16()?)
.ok_or(ErrorCode::InvalidOpcode)?,
proto_id: pb.le_u32()?,
proto_code: pb.le_u16()?,
proto_data: pb.as_slice(),
})
}
pub fn write(&self, wb: &mut WriteBuf) -> Result<(), Error> {
wb.le_u16(self.general_code as u16)?;
wb.le_u32(self.proto_id)?;
wb.le_u16(self.proto_code)?;
wb.copy_from_slice(self.proto_data)?;
Ok(())
}
}
/// Handler for the Secure Channel protocol.
///
/// Its `handle` method dispatches incoming PASE (`PBKDFParamRequest`) and
/// CASE (`CASESigma1`) session-establishment exchanges to their responders.
pub struct SecureChannel<'a, C> {
/// Crypto backend, passed by reference to the PASE and CASE responders.
crypto: C,
/// Attribute-change notifier, passed to the PASE responder.
notify: &'a dyn AttrChangeNotifier,
}
impl<'a, C: Crypto> SecureChannel<'a, C> {
    /// Creates a Secure Channel handler.
    ///
    /// `crypto` is used by both the PASE and CASE responders. `notify` is
    /// handed to the PASE responder.
    #[inline(always)]
    pub const fn new(crypto: C, notify: &'a dyn AttrChangeNotifier) -> Self {
        Self { crypto, notify }
    }

    /// Handles one incoming Secure Channel exchange.
    ///
    /// If the exchange holds no received message yet, one is fetched first.
    /// Then the handler dispatches on the opcode: `PBKDFParamRequest` runs the
    /// PASE responder and `CASESigma1` runs the CASE responder.
    ///
    /// # Errors
    /// - `InvalidProto` if the message is not on the Secure Channel protocol.
    /// - `InvalidOpcode` for any other opcode.
    /// - Errors from receiving, from decoding the opcode, and from the
    ///   PASE/CASE responders are propagated.
    pub async fn handle(&self, mut exchange: Exchange<'_>) -> Result<(), Error> {
        if exchange.rx().is_err() {
            exchange.recv_fetch().await?;
        }

        let meta = exchange.rx()?.meta();
        if meta.proto_id != PROTO_ID_SECURE_CHANNEL {
            return Err(ErrorCode::InvalidProto.into());
        }

        // The responders are initialised in place inside a `MaybeUninit` slot
        // via `init_with`, presumably to avoid building them by value and moving
        // them. NOTE: `MaybeUninit` never runs the contained value's `Drop`.
        match meta.opcode()? {
            OpCode::PBKDFParamRequest => {
                let mut slot = MaybeUninit::uninit();
                let responder = slot.init_with(PaseResponder::init(&self.crypto, self.notify));
                responder.handle(&mut exchange).await
            }
            OpCode::CASESigma1 => {
                let mut slot = MaybeUninit::uninit();
                let responder = slot.init_with(CaseResponder::init(&self.crypto));
                responder.handle(&mut exchange).await
            }
            opcode => {
                error!("Invalid opcode: {:?}", opcode);
                Err(ErrorCode::InvalidOpcode.into())
            }
        }
    }
}
/// Lets a [`SecureChannel`] serve as an [`ExchangeHandler`] by forwarding
/// each exchange to the inherent `SecureChannel::handle`.
impl<C: Crypto> ExchangeHandler for SecureChannel<'_, C> {
    fn handle(&self, exchange: Exchange<'_>) -> impl Future<Output = Result<(), Error>> {
        // Inherent associated functions take precedence over trait items in
        // type-relative paths, so this resolves to the inherent async `handle`.
        Self::handle(self, exchange)
    }
}
fn check_opcode(exchange: &Exchange<'_>, opcode: OpCode) -> Result<(), Error> {
let meta = exchange.rx()?.meta();
let their_opcode = meta.opcode::<OpCode>()?;
if their_opcode == opcode {
Ok(())
} else {
error!("Invalid opcode: {:?}, expected: {:?}", their_opcode, opcode);
if matches!(their_opcode, OpCode::StatusReport) {
let mut rb = ReadBuf::new(exchange.rx()?.payload());
match StatusReport::read(&mut rb) {
Ok(status_report) => error!("Status Report: {:?}", status_report),
Err(e) => error!("Failed to parse Status Report: {:?}", e),
}
}
Err(ErrorCode::Invalid.into())
}
}