use rand_core::RngCore;
use crate::crypto::{
CanonEcPointRef, Crypto, HmacHash, HmacHashRef, Kdf, AEAD_CANON_KEY_LEN, EC_POINT_ZEROED,
HMAC_HASH_ZEROED,
};
use crate::error::{Error, ErrorCode};
use crate::sc::pase::spake2p::{
ProverContext, Spake2P, Spake2pRandom, Spake2pSessionKeys, Spake2pVerifierPasswordRef,
SPAKE2P_VERIFIER_SALT_LEN, SPAKE2P_VERIFIER_SALT_MIN_LEN,
};
use crate::sc::{complete_with_status, GeneralCode, OpCode, SCStatusCodes, StatusReport};
use crate::tlv::{FromTLV, OctetStr, TLVElement, TagType, ToTLV};
use crate::transport::exchange::Exchange;
use crate::transport::session::{ReservedSession, SessionMode};
use crate::utils::storage::ReadBuf;
use super::{PBKDFParamReq, PBKDFParamResp, Pake1, Pake2, Pake3, SPAKE2_SESSION_KEYS_INFO};
/// State carried by the initiating side of a PASE handshake.
///
/// An instance is created internally by [`PaseInitiator::initiate`] and is filled in step by
/// step as the individual message exchanges complete.
pub struct PaseInitiator<C: Crypto> {
/// Crypto backend: source of randomness (`rand()`), of the KDF (`kdf()`), and the
/// backend handed to every SPAKE2+ operation.
crypto: C,
/// SPAKE2+ protocol state. It is fed the PBKDF request/response payloads, performs the
/// prover computations and finally yields the key material returned by `ke()`.
spake2p: Spake2P,
/// Random value generated locally and sent in the `PBKDFParamRequest`; the response must
/// echo it back unchanged.
initiator_random: Spake2pRandom,
/// Session id obtained from `get_next_sess_id()` and sent as `initiator_ssid`.
local_sessid: u16,
/// Session id announced by the peer in `responder_ssid` of the `PBKDFParamResponse`.
peer_sessid: u16,
/// Prover context returned by `setup_prover`; stored once Pake2 has been processed
/// successfully. It is not read again anywhere in this file.
prover_context: Option<ProverContext>,
/// Value written by `complete_prover` while processing Pake2 and sent to the peer in Pake3.
ca: HmacHash,
}
impl<C: Crypto> PaseInitiator<C> {
/// Create an initiator in its pre-handshake state.
///
/// Both session ids start at `0`, the confirmation value is zeroed and no prover context
/// is present; the handshake steps overwrite these as they succeed.
const fn new(crypto: C) -> Self {
    Self {
        // Handshake results, not yet known.
        local_sessid: 0,
        peer_sessid: 0,
        prover_context: None,
        ca: HMAC_HASH_ZEROED,
        // Protocol machinery.
        initiator_random: Spake2pRandom::new(),
        spake2p: Spake2P::new(),
        crypto,
    }
}
/// Run the complete PASE handshake as the initiator on `exchange`.
///
/// The steps are executed in this order:
/// 1. reserve a session slot,
/// 2. exchange the PBKDF parameters,
/// 3. exchange Pake1 / Pake2,
/// 4. send Pake3 and wait for the final status report,
/// 5. derive the session keys and complete the reserved session.
///
/// `password` is the numeric passcode; it is fed into SPAKE2+ as its little-endian bytes.
///
/// # Errors
///
/// Returns the error of the first step that fails. When step 2 or step 3 fails, a
/// `SCStatusCodes::InvalidParameter` status is additionally reported through
/// `complete_with_status` on a best-effort basis (the outcome of that report is ignored).
/// Failures of the later steps are returned without reporting a status.
pub async fn initiate(
    exchange: &mut Exchange<'_>,
    crypto: C,
    password: u32,
) -> Result<(), Error> {
    let session = ReservedSession::reserve(exchange.matter(), &crypto).await?;
    let mut initiator = Self::new(crypto);

    // Step 2: PBKDF parameters (salt + iteration count) come from the peer.
    let pbkdf_outcome = initiator.exchange_pbkdf_params(exchange).await;
    let (salt, salt_len, iterations) = match pbkdf_outcome {
        Ok(params) => params,
        Err(err) => {
            let _ = complete_with_status(exchange, SCStatusCodes::InvalidParameter, &[]).await;
            return Err(err);
        }
    };

    // Step 3: only the first `salt_len` bytes of the fixed-size buffer are meaningful.
    let pake_outcome = initiator
        .exchange_pake1_pake2(exchange, password, &salt[..salt_len], iterations)
        .await;
    if let Err(err) = pake_outcome {
        let _ = complete_with_status(exchange, SCStatusCodes::InvalidParameter, &[]).await;
        return Err(err);
    }

    // Steps 4 and 5.
    initiator.exchange_pake3_status(exchange).await?;
    initiator.complete_session(exchange, session).await
}
async fn exchange_pbkdf_params(
&mut self,
exchange: &mut Exchange<'_>,
) -> Result<([u8; SPAKE2P_VERIFIER_SALT_LEN], usize, u32), Error> {
let mut rand = self.crypto.rand()?;
rand.fill_bytes(self.initiator_random.access_mut());
self.local_sessid = exchange.with_state(|state| Ok(state.sessions.get_next_sess_id()))?;
let req = PBKDFParamReq {
initiator_random: OctetStr::new(self.initiator_random.access()),
initiator_ssid: self.local_sessid,
passcode_id: 0,
has_params: false,
session_parameters: None,
};
let context = {
let mut context = None;
exchange
.send_with(|_, wb| {
req.to_tlv(&TagType::Anonymous, &mut *wb)?;
context = Some(self.spake2p.start_context(
&self.crypto,
self.local_sessid,
0, wb.as_slice(),
)?);
Ok(Some(OpCode::PBKDFParamRequest.into()))
})
.await?;
context.ok_or(ErrorCode::Invalid)?
};
exchange.recv_fetch().await?;
let rx = exchange.rx()?;
let meta = rx.meta();
if meta.proto_opcode == OpCode::StatusReport as u8 {
let mut rb = ReadBuf::new(rx.payload());
let status = StatusReport::read(&mut rb)?;
error!(
"PASE failed: general={:?}, proto_code={}",
status.general_code, status.proto_code
);
return Err(ErrorCode::Invalid.into());
}
if meta.proto_opcode != OpCode::PBKDFParamResponse as u8 {
error!(
"Unexpected opcode: expected PBKDFParamResponse, got {}",
meta.proto_opcode
);
return Err(ErrorCode::InvalidOpcode.into());
}
let resp = PBKDFParamResp::from_tlv(&TLVElement::new(rx.payload()))?;
if resp.initiator_random.0 != self.initiator_random.access() {
error!("PBKDFParamResponse: initiator_random mismatch");
return Err(ErrorCode::Invalid.into());
}
self.peer_sessid = resp.responder_ssid;
let params = resp.params.ok_or_else(|| {
error!("PBKDFParamResponse: missing PBKDF params");
ErrorCode::Invalid
})?;
let mut salt = [0u8; SPAKE2P_VERIFIER_SALT_LEN];
let salt_len = params.salt.0.len();
if !(SPAKE2P_VERIFIER_SALT_MIN_LEN..=SPAKE2P_VERIFIER_SALT_LEN).contains(&salt_len) {
error!("PBKDFParamResponse: invalid salt length {}", salt_len);
return Err(ErrorCode::Invalid.into());
}
salt[..salt_len].copy_from_slice(params.salt.0);
self.spake2p.finish_context::<&C>(context, rx.payload())?;
Ok((salt, salt_len, params.iterations))
}
async fn exchange_pake1_pake2(
&mut self,
exchange: &mut Exchange<'_>,
password: u32,
salt: &[u8],
iterations: u32,
) -> Result<(), Error> {
let mut pa = EC_POINT_ZEROED;
let password_bytes = password.to_le_bytes();
let password_ref = Spake2pVerifierPasswordRef::new(&password_bytes);
let prover_ctx =
self.spake2p
.setup_prover(&self.crypto, password_ref, salt, iterations, &mut pa)?;
let pake1 = Pake1 {
pa: OctetStr::new(pa.access()),
};
exchange
.send_with(|_, wb| {
pake1.to_tlv(&TagType::Anonymous, wb)?;
Ok(Some(OpCode::PASEPake1.into()))
})
.await?;
exchange.recv_fetch().await?;
let rx = exchange.rx()?;
let meta = rx.meta();
if meta.proto_opcode == OpCode::StatusReport as u8 {
let mut rb = ReadBuf::new(rx.payload());
let status = StatusReport::read(&mut rb)?;
error!(
"PASE Pake1 failed: general={:?}, proto_code={}",
status.general_code, status.proto_code
);
return Err(ErrorCode::Invalid.into());
}
if meta.proto_opcode != OpCode::PASEPake2 as u8 {
error!(
"Unexpected opcode: expected PASEPake2, got {}",
meta.proto_opcode
);
return Err(ErrorCode::InvalidOpcode.into());
}
let pake2 = Pake2::from_tlv(&TLVElement::new(rx.payload()))?;
let pb: CanonEcPointRef<'_> = pake2.pb.0.try_into()?;
let cb: HmacHashRef<'_> = pake2.cb.0.try_into()?;
let mut ca = HMAC_HASH_ZEROED;
self.spake2p
.complete_prover(&self.crypto, &prover_ctx, pa.reference(), pb, cb, &mut ca)
.map_err(|_| {
error!("PASE: cB verification failed (wrong password?)");
ErrorCode::Invalid
})?;
self.ca = ca;
self.prover_context = Some(prover_ctx);
Ok(())
}
async fn exchange_pake3_status(&mut self, exchange: &mut Exchange<'_>) -> Result<(), Error> {
let pake3 = Pake3 {
ca: OctetStr::new(self.ca.access()),
};
exchange
.send_with(|_, wb| {
pake3.to_tlv(&TagType::Anonymous, wb)?;
Ok(Some(OpCode::PASEPake3.into()))
})
.await?;
exchange.recv_fetch().await?;
let rx = exchange.rx()?;
let meta = rx.meta();
if meta.proto_opcode != OpCode::StatusReport as u8 {
error!(
"Unexpected opcode: expected StatusReport, got {}",
meta.proto_opcode
);
return Err(ErrorCode::InvalidOpcode.into());
}
let mut rb = ReadBuf::new(rx.payload());
let status = StatusReport::read(&mut rb)?;
if status.general_code != GeneralCode::Success
|| status.proto_code != SCStatusCodes::SessionEstablishmentSuccess as u16
{
error!(
"PASE failed: general={:?}, proto_code={}",
status.general_code, status.proto_code
);
return Err(ErrorCode::Invalid.into());
}
Ok(())
}
/// Derive the session keys, install them into the reserved session and acknowledge the
/// peer's last message.
///
/// The key material returned by `Spake2P::ke()` is expanded with
/// `SPAKE2_SESSION_KEYS_INFO` into one buffer which is then cut into three parts of
/// `AEAD_CANON_KEY_LEN` bytes each: first the key this side encrypts with, then the key it
/// decrypts with, then the attestation challenge.
///
/// # Errors
///
/// * `ErrorCode::InvalidData` if the key expansion fails.
/// * Any error raised while obtaining the KDF, reading the exchange state, updating the
///   session or sending the acknowledgement.
async fn complete_session(
    &mut self,
    exchange: &mut Exchange<'_>,
    mut session: ReservedSession<'_>,
) -> Result<(), Error> {
    let shared_key = self.spake2p.ke();

    let mut derived = Spake2pSessionKeys::new();
    self.crypto
        .kdf()?
        .expand(&[], shared_key, SPAKE2_SESSION_KEYS_INFO, &mut derived)
        .map_err(|_| ErrorCode::InvalidData)?;

    // The new session talks to the same address as the session carrying this exchange.
    let peer_addr = exchange
        .with_state(|state| Ok(exchange.id().session(&mut state.sessions).get_peer_addr()))?;

    // Layout of `derived`: [ encrypt key | decrypt key | attestation challenge ].
    let (encrypt_key, tail) = derived
        .reference()
        .split::<AEAD_CANON_KEY_LEN, { AEAD_CANON_KEY_LEN * 2 }>();
    let (decrypt_key, challenge) = tail.split::<AEAD_CANON_KEY_LEN, AEAD_CANON_KEY_LEN>();

    // NOTE(review): the two leading zeros are presumably the local and peer node ids,
    // which are not assigned for a PASE session — confirm against `ReservedSession::update`.
    session.update(
        0,
        0,
        self.peer_sessid,
        self.local_sessid,
        peer_addr,
        SessionMode::Pase { fab_idx: 0 },
        Some(decrypt_key),
        Some(encrypt_key),
        Some(challenge),
    )?;
    session.complete();

    exchange.acknowledge().await?;

    info!(
        "PASE session established: local_sessid={}, peer_sessid={}",
        self.local_sessid, self.peer_sessid
    );
    Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `value` as an anonymous TLV element into a 128-byte scratch buffer and report
    /// whether any bytes were produced. Panics if encoding fails.
    fn produces_output<T: ToTLV>(value: &T) -> bool {
        let mut scratch = [0u8; 128];
        let mut wb = crate::utils::storage::WriteBuf::new(&mut scratch);
        value.to_tlv(&TagType::Anonymous, &mut wb).unwrap();
        !wb.as_slice().is_empty()
    }

    /// A minimal `PBKDFParamReq` (no optional parameters) must encode successfully.
    #[test]
    fn test_pbkdf_param_req_encoding() {
        let random = [0u8; 32];
        let request = PBKDFParamReq {
            initiator_random: OctetStr::new(&random),
            initiator_ssid: 1234,
            passcode_id: 0,
            has_params: false,
            session_parameters: None,
        };

        assert!(produces_output(&request));
    }

    /// A `Pake1` carrying a 65-byte value must encode successfully.
    #[test]
    fn test_pake1_encoding() {
        let share = [0u8; 65];
        let message = Pake1 {
            pa: OctetStr::new(&share),
        };

        assert!(produces_output(&message));
    }

    /// A `Pake3` carrying a 32-byte value must encode successfully.
    #[test]
    fn test_pake3_encoding() {
        let confirmation = [0u8; 32];
        let message = Pake3 {
            ca: OctetStr::new(&confirmation),
        };

        assert!(produces_output(&message));
    }
}