use crate::{
TpmCast, TpmCastMut, TpmMarshal, TpmProtocolError, TpmResult, TpmSized, TpmUnmarshal,
basic::TpmUint32,
};
use core::{
convert::TryFrom,
fmt::Debug,
mem::{MaybeUninit, size_of},
ops::Deref,
slice,
};
/// Byte length of the item-count prefix at the start of every `Tpml` buffer,
/// taken as the in-memory size of `TpmUint32`.
// NOTE(review): `Tpml::read_count` decodes exactly four bytes, so this assumes
// `size_of::<TpmUint32>() == 4` — confirm against the `TpmUint32` definition.
const TPML_COUNT_LEN: usize = size_of::<TpmUint32>();
/// Borrowed, unsized view over the bytes of a marshalled list.
///
/// The bytes begin with a big-endian `u32` item count, followed by the raw,
/// still-encoded items. The validated constructors (`Tpml::cast`,
/// `Tpml::cast_mut`) guarantee that the count prefix is present and that the
/// count does not exceed `CAPACITY`. The item bytes themselves are not checked.
// `repr(transparent)` gives `Tpml` the same layout as `[u8]`, which is what
// allows `&[u8]` to be reinterpreted as `&Tpml` by a pointer cast.
#[repr(transparent)]
pub struct Tpml<const CAPACITY: usize>([u8]);
impl<const CAPACITY: usize> Tpml<CAPACITY> {
pub fn cast(buf: &[u8]) -> TpmResult<&Self> {
Self::validate(buf)?;
Ok(unsafe { Self::cast_unchecked(buf) })
}
#[must_use]
pub unsafe fn cast_unchecked(buf: &[u8]) -> &Self {
let ptr = core::ptr::from_ref(buf) as *const Self;
unsafe { &*ptr }
}
pub fn cast_mut(buf: &mut [u8]) -> TpmResult<&mut Self> {
Self::validate(buf)?;
Ok(unsafe { Self::cast_mut_unchecked(buf) })
}
#[must_use]
pub unsafe fn cast_mut_unchecked(buf: &mut [u8]) -> &mut Self {
let ptr = core::ptr::from_mut(buf) as *mut Self;
unsafe { &mut *ptr }
}
#[must_use]
pub const fn as_bytes(&self) -> &[u8] {
&self.0
}
#[must_use]
pub fn as_bytes_mut(&mut self) -> &mut [u8] {
&mut self.0
}
#[must_use]
pub fn count(&self) -> usize {
Self::read_count(&self.0)
}
#[must_use]
pub fn items_bytes(&self) -> &[u8] {
&self.0[TPML_COUNT_LEN..]
}
#[must_use]
pub fn items_bytes_mut(&mut self) -> &mut [u8] {
&mut self.0[TPML_COUNT_LEN..]
}
#[must_use]
pub const fn len(&self) -> usize {
self.0.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.count() == 0
}
fn validate(buf: &[u8]) -> TpmResult<()> {
if buf.len() < TPML_COUNT_LEN {
return Err(TpmProtocolError::UnexpectedEnd);
}
let item_count = Self::read_count(buf);
if item_count > CAPACITY {
return Err(TpmProtocolError::TooManyItems);
}
Ok(())
}
fn read_count(buf: &[u8]) -> usize {
let raw = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
raw as usize
}
}
impl<const CAPACITY: usize> TpmCast for Tpml<CAPACITY> {
fn cast(buf: &[u8]) -> TpmResult<&Self> {
Self::cast(buf)
}
unsafe fn cast_unchecked(buf: &[u8]) -> &Self {
unsafe { Self::cast_unchecked(buf) }
}
}
impl<const CAPACITY: usize> TpmCastMut for Tpml<CAPACITY> {
fn cast_mut(buf: &mut [u8]) -> TpmResult<&mut Self> {
Self::cast_mut(buf)
}
unsafe fn cast_mut_unchecked(buf: &mut [u8]) -> &mut Self {
unsafe { Self::cast_mut_unchecked(buf) }
}
}
impl<const CAPACITY: usize> AsRef<[u8]> for Tpml<CAPACITY> {
    /// Exposes the full underlying buffer, count prefix included.
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}
impl<const CAPACITY: usize> AsMut<[u8]> for Tpml<CAPACITY> {
    /// Exposes the full underlying buffer mutably, count prefix included.
    fn as_mut(&mut self) -> &mut [u8] {
        &mut self.0
    }
}
/// Fixed-capacity, inline list holding up to `CAPACITY` `Copy` items.
///
/// Storage is a `MaybeUninit` array. Only the first `len` slots are
/// initialized, and the list dereferences to that initialized prefix as `&[T]`.
#[derive(Clone, Copy)]
pub struct TpmList<T: Copy, const CAPACITY: usize> {
    // Backing storage; slots at index >= `len` are uninitialized.
    items: [MaybeUninit<T>; CAPACITY],
    // Number of initialized items at the front of `items`. The push/extend
    // methods keep this at or below `CAPACITY`.
    len: usize,
}
impl<T: Copy, const CAPACITY: usize> TpmList<T, CAPACITY> {
    /// Creates an empty list with room for `CAPACITY` items.
    ///
    /// This is a `const fn`, so lists can also be built in `const`/`static` context.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            items: [const { MaybeUninit::uninit() }; CAPACITY],
            len: 0,
        }
    }

    /// Returns `true` if the list holds no items.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Appends `item` to the end of the list.
    ///
    /// # Errors
    ///
    /// Returns [`TpmProtocolError::TooManyItems`] if the list already holds
    /// `CAPACITY` items. The list is left unchanged in that case.
    pub fn try_push(&mut self, item: T) -> Result<(), TpmProtocolError> {
        if self.len >= CAPACITY {
            return Err(TpmProtocolError::TooManyItems);
        }
        self.items[self.len].write(item);
        // Only advance `len` after the slot is written, so `Deref` never
        // exposes an uninitialized element.
        self.len += 1;
        Ok(())
    }

    /// Appends every element of `slice`, in order.
    ///
    /// # Errors
    ///
    /// Returns [`TpmProtocolError::TooManyItems`] if the combined length would
    /// exceed `CAPACITY` (or overflow `usize`). The capacity check happens
    /// before any write, so the list is left unchanged in that case.
    pub fn try_extend_from_slice(&mut self, slice: &[T]) -> Result<(), TpmProtocolError> {
        let new_len = self
            .len
            .checked_add(slice.len())
            .ok_or(TpmProtocolError::TooManyItems)?;
        if new_len > CAPACITY {
            return Err(TpmProtocolError::TooManyItems);
        }
        for (dest, src) in self.items[self.len..new_len].iter_mut().zip(slice) {
            dest.write(*src);
        }
        self.len = new_len;
        Ok(())
    }
}
impl<T: Copy, const CAPACITY: usize> Deref for TpmList<T, CAPACITY> {
    type Target = [T];

    /// Views the initialized prefix of the backing storage as a slice.
    fn deref(&self) -> &Self::Target {
        let initialized = self.len;
        let first: *const T = self.items.as_ptr().cast();
        // SAFETY: `MaybeUninit<T>` has the same layout as `T`. The list's
        // mutating methods write a slot before counting it in `len` and keep
        // `len <= CAPACITY`, so the first `initialized` elements are
        // initialized and lie within `items`.
        unsafe { slice::from_raw_parts(first, initialized) }
    }
}
impl<T: Copy, const CAPACITY: usize> Default for TpmList<T, CAPACITY> {
fn default() -> Self {
Self::new()
}
}
impl<T: Copy + Debug, const CAPACITY: usize> Debug for TpmList<T, CAPACITY> {
    /// Formats only the initialized items, as a list (e.g. `[a, b, c]`).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let items: &[T] = self;
        f.debug_list().entries(items).finish()
    }
}
impl<T: Copy + PartialEq, const CAPACITY: usize> PartialEq for TpmList<T, CAPACITY> {
    /// Compares the initialized items element-wise. Unused capacity is never read.
    fn eq(&self, other: &Self) -> bool {
        let lhs: &[T] = self;
        let rhs: &[T] = other;
        lhs == rhs
    }
}
// Equality compares only the initialized items (via the slice), so it is a
// full equivalence relation whenever `T: Eq`.
impl<T: Copy + Eq, const CAPACITY: usize> Eq for TpmList<T, CAPACITY> {}
impl<T: TpmSized + Copy, const CAPACITY: usize> TpmSized for TpmList<T, CAPACITY> {
    /// Size of the count prefix plus `CAPACITY` items of `T::SIZE` bytes each.
    const SIZE: usize = size_of::<TpmUint32>() + CAPACITY * T::SIZE;

    /// Marshalled size of this particular list: the count prefix plus the
    /// marshalled size of every item currently present.
    fn len(&self) -> usize {
        let mut total = size_of::<TpmUint32>();
        for item in self.iter() {
            total += TpmSized::len(item);
        }
        total
    }
}
impl<T: TpmMarshal + Copy, const CAPACITY: usize> TpmMarshal for TpmList<T, CAPACITY> {
    /// Writes the item count as a `TpmUint32`, followed by each item in order.
    ///
    /// # Errors
    ///
    /// Returns [`TpmProtocolError::IntegerTooLarge`] if the count cannot be
    /// converted to `TpmUint32`. Otherwise, propagates the first error reported
    /// while marshalling the count or an item.
    fn marshal(&self, writer: &mut crate::TpmWriter) -> TpmResult<()> {
        let count = match TpmUint32::try_from(self.len) {
            Ok(count) => count,
            Err(_) => return Err(TpmProtocolError::IntegerTooLarge),
        };
        TpmMarshal::marshal(&count, writer)?;
        self.iter()
            .try_for_each(|item| TpmMarshal::marshal(item, writer))
    }
}
impl<T: TpmUnmarshal + Copy, const CAPACITY: usize> TpmUnmarshal for TpmList<T, CAPACITY> {
fn unmarshal(buf: &[u8]) -> TpmResult<(Self, &[u8])> {
let (count_u32, mut buf) = TpmUint32::unmarshal(buf)?;
let count = u32::from(count_u32) as usize;
if count > CAPACITY {
return Err(TpmProtocolError::TooManyItems);
}
let mut list = Self::new();
for _ in 0..count {
let (item, rest) = T::unmarshal(buf)?;
list.try_push(item)?;
buf = rest;
}
Ok((list, buf))
}
}