use crate::buffer::CryptoBuffer;
use crate::TlsError;
use crate::parse_buffer::{ParseBuffer, ParseError};
use heapless::Vec;
/// Client-side `pre_shared_key` extension body (RFC 8446, section 4.2.11).
///
/// Parsed slices borrow from the input buffer. All binders are required to have the
/// same length, which is recorded in `hash_size`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct PreSharedKeyClientHello<'a, const N: usize> {
    /// Raw PSK identities in wire order (at most `N`).
    pub identities: Vec<&'a [u8], N>,
    /// Raw binder values in wire order, one per identity (at most `N`).
    /// Not consulted by [`PreSharedKeyClientHello::encode`].
    pub binders: Vec<&'a [u8], N>,
    /// Length in bytes of every binder.
    pub hash_size: usize,
}

impl<'a, const N: usize> PreSharedKeyClientHello<'a, N> {
    /// Smallest valid `identities` list: 2-byte identity length, 1-byte identity,
    /// 4-byte obfuscated ticket age.
    const MIN_IDENTITIES_LEN: usize = 7;
    /// Smallest valid single binder value.
    const MIN_BINDER_LEN: usize = 32;
    /// Smallest valid `binders` list: 1-byte binder length plus a minimum-size binder.
    const MIN_BINDERS_LEN: usize = 1 + Self::MIN_BINDER_LEN;

    /// Parses the extension body from `buf`.
    ///
    /// The obfuscated ticket age following each identity is read and discarded.
    ///
    /// # Errors
    ///
    /// - [`ParseError::InvalidData`] in any of these cases:
    ///   - a list or entry is shorter than its minimum length;
    ///   - no identity is present;
    ///   - the binders do not all have the same length;
    ///   - the number of binders differs from the number of identities.
    /// - [`ParseError::InsufficientSpace`] if more than `N` identities or binders are present.
    /// - Errors from the underlying `ParseBuffer` reads/slices are propagated unchanged
    ///   (presumably on truncated input).
    pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<Self, ParseError> {
        let identities_len = buf.read_u16()? as usize;
        if identities_len < Self::MIN_IDENTITIES_LEN {
            return Err(ParseError::InvalidData);
        }
        let mut id_buf = buf.slice(identities_len)?;
        let mut identities: Vec<&'a [u8], N> = Vec::new();
        while !id_buf.is_empty() {
            let id_len = id_buf.read_u16()? as usize;
            if id_len == 0 {
                return Err(ParseError::InvalidData);
            }
            let id_data = id_buf.slice(id_len)?;
            identities
                .push(id_data.as_slice())
                .map_err(|_| ParseError::InsufficientSpace)?;
            // obfuscated_ticket_age: not used, but must be consumed.
            let _ = id_buf.read_u32()?;
        }
        if identities.is_empty() {
            return Err(ParseError::InvalidData);
        }

        let binders_total_len = buf.read_u16()? as usize;
        if binders_total_len < Self::MIN_BINDERS_LEN {
            return Err(ParseError::InvalidData);
        }
        let mut bind_buf = buf.slice(binders_total_len)?;
        let mut binders: Vec<&'a [u8], N> = Vec::new();
        let mut hash_size = 0usize;
        while !bind_buf.is_empty() {
            let b_len = bind_buf.read_u8()? as usize;
            if b_len < Self::MIN_BINDER_LEN {
                return Err(ParseError::InvalidData);
            }
            let b_data = bind_buf.slice(b_len)?;
            binders
                .push(b_data.as_slice())
                .map_err(|_| ParseError::InsufficientSpace)?;
            // The first binder fixes the expected length for all following ones.
            if hash_size == 0 {
                hash_size = b_len;
            } else if hash_size != b_len {
                return Err(ParseError::InvalidData);
            }
        }
        if binders.len() != identities.len() {
            return Err(ParseError::InvalidData);
        }

        Ok(Self {
            identities,
            binders,
            hash_size,
        })
    }

    /// Encodes the extension body into `buf`.
    ///
    /// Identities are written with an obfuscated ticket age of 0. `self.binders` is
    /// ignored. Instead, the binder list length is written, followed by
    /// `(1 + hash_size) * identities.len()` zero bytes. These are placeholders,
    /// presumably overwritten by the caller once the binders can be computed
    /// (confirm against the caller).
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::EncodeError`] in any of these cases:
    /// - `hash_size` does not fit the one-byte binder length prefix;
    /// - the binder list length does not fit its two-byte length prefix;
    /// - any write to `buf` fails.
    ///
    /// The two size checks happen before anything is written.
    pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), TlsError> {
        // Each binder carries a u8 length prefix.
        if self.hash_size > usize::from(u8::MAX) {
            return Err(TlsError::EncodeError);
        }
        // The whole binder list carries a u16 length prefix; reject instead of truncating.
        let binders_len = (1 + self.hash_size)
            .checked_mul(self.identities.len())
            .filter(|&len| len <= usize::from(u16::MAX))
            .ok_or(TlsError::EncodeError)?;

        buf.with_u16_length(|buf| {
            for identity in &self.identities {
                buf.with_u16_length(|buf| buf.extend_from_slice(identity))
                    .map_err(|_| TlsError::EncodeError)?;
                // Ticket age is not tracked; always send 0.
                buf.push_u32(0).map_err(|_| TlsError::EncodeError)?;
            }
            Ok(())
        })
        .map_err(|_| TlsError::EncodeError)?;

        // Lossless: `binders_len` was bounded by `u16::MAX` above.
        buf.push_u16(binders_len as u16)
            .map_err(|_| TlsError::EncodeError)?;
        for _ in 0..binders_len {
            buf.push(0).map_err(|_| TlsError::EncodeError)?;
        }
        Ok(())
    }
}
/// Server-side `pre_shared_key` extension body: a single two-byte identity selector.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct PreSharedKeyServerHello {
    /// Identity chosen by the server (RFC 8446 defines this as an index into the
    /// client's offered identity list).
    pub selected_identity: u16,
}

impl PreSharedKeyServerHello {
    /// Reads the `selected_identity` field from `buf`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying `ParseBuffer` read.
    pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
        let selected_identity = buf.read_u16()?;
        Ok(Self { selected_identity })
    }

    /// Writes the `selected_identity` field to `buf`.
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::EncodeError`] if the write fails.
    pub fn encode(self, buf: &mut CryptoBuffer) -> Result<(), TlsError> {
        let Self { selected_identity } = self;
        match buf.push_u16(selected_identity) {
            Ok(done) => Ok(done),
            Err(_) => Err(TlsError::EncodeError),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_single_identity_and_binder() {
        let identity = b"vader";
        let binder = [0xabu8; 32];
        let mut wire = std::vec::Vec::new();
        wire.extend_from_slice(&11u16.to_be_bytes());
        wire.extend_from_slice(&5u16.to_be_bytes());
        wire.extend_from_slice(identity);
        wire.extend_from_slice(&0u32.to_be_bytes());
        wire.extend_from_slice(&33u16.to_be_bytes());
        wire.push(32);
        wire.extend_from_slice(&binder);
        let mut buf = ParseBuffer::new(&wire);
        let psk: PreSharedKeyClientHello<'_, 4> =
            PreSharedKeyClientHello::parse(&mut buf).expect("parse");
        assert_eq!(psk.identities.len(), 1);
        assert_eq!(psk.identities[0], identity);
        assert_eq!(psk.binders.len(), 1);
        assert_eq!(psk.binders[0], &binder);
        assert_eq!(psk.hash_size, 32);
        assert!(buf.is_empty());
    }

    /// Rejected by the minimum total length of the binders list (17 < 33).
    #[test]
    fn reject_undersized_binder() {
        let mut wire = std::vec::Vec::new();
        wire.extend_from_slice(&11u16.to_be_bytes());
        wire.extend_from_slice(&5u16.to_be_bytes());
        wire.extend_from_slice(b"vader");
        wire.extend_from_slice(&0u32.to_be_bytes());
        wire.extend_from_slice(&17u16.to_be_bytes());
        wire.push(16);
        wire.extend_from_slice(&[0u8; 16]);
        let mut buf = ParseBuffer::new(&wire);
        let result: Result<PreSharedKeyClientHello<'_, 4>, _> =
            PreSharedKeyClientHello::parse(&mut buf);
        assert!(matches!(result, Err(ParseError::InvalidData)));
    }

    /// The list length is valid (33), but the first entry is below the 32-byte minimum.
    #[test]
    fn reject_short_binder_entry() {
        let mut wire = std::vec::Vec::new();
        wire.extend_from_slice(&11u16.to_be_bytes());
        wire.extend_from_slice(&5u16.to_be_bytes());
        wire.extend_from_slice(b"vader");
        wire.extend_from_slice(&0u32.to_be_bytes());
        wire.extend_from_slice(&33u16.to_be_bytes());
        wire.push(31);
        wire.extend_from_slice(&[0u8; 31]);
        wire.push(0);
        let mut buf = ParseBuffer::new(&wire);
        let result: Result<PreSharedKeyClientHello<'_, 4>, _> =
            PreSharedKeyClientHello::parse(&mut buf);
        assert!(matches!(result, Err(ParseError::InvalidData)));
    }

    #[test]
    fn reject_binder_count_mismatch() {
        let mut wire = std::vec::Vec::new();
        wire.extend_from_slice(&22u16.to_be_bytes());
        wire.extend_from_slice(&5u16.to_be_bytes());
        wire.extend_from_slice(b"vader");
        wire.extend_from_slice(&0u32.to_be_bytes());
        wire.extend_from_slice(&5u16.to_be_bytes());
        wire.extend_from_slice(b"luke!");
        wire.extend_from_slice(&0u32.to_be_bytes());
        wire.extend_from_slice(&33u16.to_be_bytes());
        wire.push(32);
        wire.extend_from_slice(&[0u8; 32]);
        let mut buf = ParseBuffer::new(&wire);
        let result: Result<PreSharedKeyClientHello<'_, 4>, _> =
            PreSharedKeyClientHello::parse(&mut buf);
        assert!(matches!(result, Err(ParseError::InvalidData)));
    }

    /// The identity list length is valid (13), but its first identity is empty.
    #[test]
    fn reject_empty_identity() {
        let mut wire = std::vec::Vec::new();
        wire.extend_from_slice(&13u16.to_be_bytes());
        wire.extend_from_slice(&0u16.to_be_bytes());
        wire.extend_from_slice(&0u32.to_be_bytes());
        wire.extend_from_slice(&1u16.to_be_bytes());
        wire.extend_from_slice(b"x");
        wire.extend_from_slice(&0u32.to_be_bytes());
        wire.extend_from_slice(&33u16.to_be_bytes());
        wire.push(32);
        wire.extend_from_slice(&[0u8; 32]);
        let mut buf = ParseBuffer::new(&wire);
        let result: Result<PreSharedKeyClientHello<'_, 4>, _> =
            PreSharedKeyClientHello::parse(&mut buf);
        assert!(matches!(result, Err(ParseError::InvalidData)));
    }

    /// Two binders of different lengths (32 and 48) cannot share one `hash_size`.
    #[test]
    fn reject_mismatched_binder_lengths() {
        let mut wire = std::vec::Vec::new();
        wire.extend_from_slice(&22u16.to_be_bytes());
        wire.extend_from_slice(&5u16.to_be_bytes());
        wire.extend_from_slice(b"vader");
        wire.extend_from_slice(&0u32.to_be_bytes());
        wire.extend_from_slice(&5u16.to_be_bytes());
        wire.extend_from_slice(b"luke!");
        wire.extend_from_slice(&0u32.to_be_bytes());
        wire.extend_from_slice(&82u16.to_be_bytes());
        wire.push(32);
        wire.extend_from_slice(&[0u8; 32]);
        wire.push(48);
        wire.extend_from_slice(&[0u8; 48]);
        let mut buf = ParseBuffer::new(&wire);
        let result: Result<PreSharedKeyClientHello<'_, 4>, _> =
            PreSharedKeyClientHello::parse(&mut buf);
        assert!(matches!(result, Err(ParseError::InvalidData)));
    }

    /// With capacity `N = 1`, a second identity must be reported as `InsufficientSpace`.
    #[test]
    fn reject_too_many_identities() {
        let mut wire = std::vec::Vec::new();
        wire.extend_from_slice(&22u16.to_be_bytes());
        wire.extend_from_slice(&5u16.to_be_bytes());
        wire.extend_from_slice(b"vader");
        wire.extend_from_slice(&0u32.to_be_bytes());
        wire.extend_from_slice(&5u16.to_be_bytes());
        wire.extend_from_slice(b"luke!");
        wire.extend_from_slice(&0u32.to_be_bytes());
        wire.extend_from_slice(&66u16.to_be_bytes());
        wire.push(32);
        wire.extend_from_slice(&[0u8; 32]);
        wire.push(32);
        wire.extend_from_slice(&[0u8; 32]);
        let mut buf = ParseBuffer::new(&wire);
        let result: Result<PreSharedKeyClientHello<'_, 1>, _> =
            PreSharedKeyClientHello::parse(&mut buf);
        assert!(matches!(result, Err(ParseError::InsufficientSpace)));
    }

    #[test]
    fn parse_server_hello_selected_identity() {
        let wire = [0x00u8, 0x02];
        let mut buf = ParseBuffer::new(&wire);
        let hello = PreSharedKeyServerHello::parse(&mut buf).expect("parse");
        assert_eq!(hello.selected_identity, 2);
        assert!(buf.is_empty());
    }
}