use crate::error::PqRascvError;
/// Width in bytes of a single platform configuration register: 256 bits, which is the
/// size of the SHA3-256 digests that `SoftwareRoT` stores in each register.
pub const PCR_SIZE: usize = 256 / 8;
/// Number of platform configuration registers held by a `PcrBank`.
pub const PCR_COUNT: usize = 8;
/// Bank of `PCR_COUNT` platform configuration registers, each `PCR_SIZE` bytes wide.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PcrBank(pub [[u8; PCR_SIZE]; PCR_COUNT]);
impl Default for PcrBank {
    /// Returns a bank in which every register holds only zero bytes.
    fn default() -> Self {
        const ZERO_REGISTER: [u8; PCR_SIZE] = [0; PCR_SIZE];
        PcrBank([ZERO_REGISTER; PCR_COUNT])
    }
}
/// Snapshot of platform measurements reported by a `RoT` implementation.
// Keep the field order below: serde derives serialize struct fields in declaration
// order, so reordering would change the encoded form.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct Measurements {
    /// Platform configuration register bank.
    pub pcrs: PcrBank,
    /// 32-byte digest of the firmware image (SHA3-256 when produced by `SoftwareRoT`).
    pub firmware_hash: [u8; 32],
    /// 32-byte digest of the AI model; `SoftwareRoT` leaves it all zeros when no model
    /// was supplied.
    pub ai_model_hash: [u8; 32],
    /// Event counter value reported by the root of trust.
    pub event_counter: u64,
}
impl Measurements {
    /// Builds an all-zero measurement set: a default (zeroed) PCR bank, zeroed digests,
    /// and an event counter of 0.
    #[must_use]
    pub fn zeroed() -> Self {
        let blank_digest: [u8; 32] = [0; 32];
        Measurements {
            event_counter: 0,
            ai_model_hash: blank_digest,
            firmware_hash: blank_digest,
            pcrs: PcrBank::default(),
        }
    }
}
/// A root of trust that can report a set of platform [`Measurements`].
pub trait RoT {
/// Produces the current set of measurements.
///
/// # Errors
/// Implementations report measurement failures as [`PqRascvError`]; which variants can
/// occur depends on the implementation.
fn measure(&self) -> Result<Measurements, PqRascvError>;
}
/// Software-only root of trust that measures caller-supplied byte slices.
///
/// The struct only borrows its inputs; hashing happens when `measure` is called.
pub struct SoftwareRoT<'a> {
    /// Firmware image bytes to be measured.
    firmware: &'a [u8],
    /// Optional AI model bytes; `None` means no model is measured.
    ai_model: Option<&'a [u8]>,
    /// Counter value copied verbatim into the produced measurements.
    event_counter: u64,
    /// Regions hashed into the PCR bank, one region per register.
    pcr_regions: &'a [&'a [u8]],
}
impl<'a> SoftwareRoT<'a> {
    /// Creates a root of trust over `firmware`, an optional `ai_model`, and a fixed
    /// `event_counter`, with no PCR regions configured.
    #[must_use]
    pub fn new(firmware: &'a [u8], ai_model: Option<&'a [u8]>, event_counter: u64) -> Self {
        let no_regions: &'a [&'a [u8]] = &[];
        Self {
            pcr_regions: no_regions,
            event_counter,
            ai_model,
            firmware,
        }
    }
    /// Replaces the PCR regions to be measured and returns the updated value.
    ///
    /// Region `i` is measured into PCR `i`. `measure` hashes at most `PCR_COUNT` regions,
    /// and any extra regions are ignored.
    #[must_use]
    pub fn with_pcr_regions(self, regions: &'a [&'a [u8]]) -> Self {
        Self {
            pcr_regions: regions,
            ..self
        }
    }
}
/// Returns the SHA3-256 digest of `data` as a fixed-size byte array.
fn sha3_256_digest(data: &[u8]) -> [u8; 32] {
    use sha3::{Digest, Sha3_256};
    let mut hasher = Sha3_256::new();
    hasher.update(data);
    hasher.finalize().into()
}
impl RoT for SoftwareRoT<'_> {
    /// Hashes the borrowed inputs with SHA3-256 and packages the digests.
    ///
    /// - `firmware_hash` is the digest of the firmware image.
    /// - `ai_model_hash` is the digest of the AI model, or all zeros when no model was
    ///   supplied. An empty model (`Some(&[])`) is still hashed, so it yields the digest
    ///   of the empty input rather than zeros.
    /// - `pcrs.0[i]` is the digest of the `i`-th PCR region. Registers without a region
    ///   keep their zero value.
    /// - `event_counter` is copied through unchanged.
    ///
    /// NOTE(review): only the first `PCR_COUNT` regions are measured. Further regions
    /// passed to `with_pcr_regions` are silently ignored and do not affect the result.
    ///
    /// # Errors
    /// Never fails; the `Result` exists to satisfy the `RoT` contract.
    fn measure(&self) -> Result<Measurements, PqRascvError> {
        let mut pcrs = PcrBank::default();
        // `zip` stops at the shorter side, so at most PCR_COUNT regions are hashed.
        for (slot, region) in pcrs.0.iter_mut().zip(self.pcr_regions) {
            *slot = sha3_256_digest(region);
        }
        Ok(Measurements {
            pcrs,
            firmware_hash: sha3_256_digest(self.firmware),
            ai_model_hash: self.ai_model.map_or([0u8; 32], sha3_256_digest),
            event_counter: self.event_counter,
        })
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the software root of trust and the measurement types.
    use super::*;
    #[test]
    fn firmware_hash_changes_with_content() {
        let rot_a = SoftwareRoT::new(b"firmware_a", None, 0);
        let rot_b = SoftwareRoT::new(b"firmware_b", None, 0);
        let m_a = rot_a.measure().unwrap();
        let m_b = rot_b.measure().unwrap();
        assert_ne!(m_a.firmware_hash, m_b.firmware_hash);
    }
    #[test]
    fn firmware_hash_is_deterministic() {
        let rot = SoftwareRoT::new(b"stable-firmware", None, 42);
        assert_eq!(rot.measure().unwrap(), rot.measure().unwrap());
    }
    #[test]
    fn ai_model_hash_zero_when_absent() {
        let rot = SoftwareRoT::new(b"fw", None, 0);
        assert_eq!(rot.measure().unwrap().ai_model_hash, [0u8; 32]);
    }
    #[test]
    fn ai_model_hash_set_when_present() {
        let absent = SoftwareRoT::new(b"fw", None, 0).measure().unwrap();
        let empty = SoftwareRoT::new(b"fw", Some(&b""[..]), 0)
            .measure()
            .unwrap();
        let model = SoftwareRoT::new(b"fw", Some(&b"model"[..]), 0)
            .measure()
            .unwrap();
        assert_ne!(model.ai_model_hash, [0u8; 32]);
        assert_ne!(model.ai_model_hash, empty.ai_model_hash);
        // An empty model is still measured, so it must not collide with "no model".
        assert_ne!(empty.ai_model_hash, absent.ai_model_hash);
        // The model must not influence the firmware digest.
        assert_eq!(model.firmware_hash, absent.firmware_hash);
    }
    #[test]
    fn event_counter_is_passed_through() {
        let rot = SoftwareRoT::new(b"fw", None, 7);
        assert_eq!(rot.measure().unwrap().event_counter, 7);
    }
    #[test]
    fn pcr_regions_are_hashed() {
        let regions: &[&[u8]] = &[b"pcr0", b"pcr1"];
        let rot = SoftwareRoT::new(b"fw", None, 0).with_pcr_regions(regions);
        let m = rot.measure().unwrap();
        assert_ne!(m.pcrs.0[0], [0u8; 32]);
        assert_ne!(m.pcrs.0[1], [0u8; 32]);
        assert_eq!(m.pcrs.0[2], [0u8; 32]);
    }
    #[test]
    fn pcr_regions_beyond_count_are_ignored() {
        let owned: Vec<Vec<u8>> = (0..=PCR_COUNT)
            .map(|i| format!("region{i}").into_bytes())
            .collect();
        let regions: Vec<&[u8]> = owned.iter().map(Vec::as_slice).collect();
        let full = SoftwareRoT::new(b"fw", None, 0)
            .with_pcr_regions(&regions)
            .measure()
            .unwrap();
        let capped = SoftwareRoT::new(b"fw", None, 0)
            .with_pcr_regions(&regions[..PCR_COUNT])
            .measure()
            .unwrap();
        assert_eq!(full, capped);
        assert!(full.pcrs.0.iter().all(|pcr| *pcr != [0u8; 32]));
    }
    #[test]
    fn zeroed_measurements_are_all_zero() {
        let z = Measurements::zeroed();
        assert_eq!(z.pcrs, PcrBank::default());
        assert_eq!(z.firmware_hash, [0u8; 32]);
        assert_eq!(z.ai_model_hash, [0u8; 32]);
        assert_eq!(z.event_counter, 0);
    }
}