use crate::{
crypto::{crypto_digest, CryptoError},
device::{Device, DeviceError},
key::Tpm2shAlgId,
};
use std::{convert::TryFrom, fmt};
use thiserror::Error;
use tpm2_protocol::{
constant::TPM_PCR_SELECT_MAX,
data::{
TpmAlgId, TpmCap, TpmCc, TpmlPcrSelection, TpmsPcrSelect, TpmsPcrSelection,
TpmuCapabilities,
},
message::TpmPcrReadCommand,
TpmError,
};
/// Errors returned by the PCR helpers in this module.
#[derive(Debug, Error)]
pub enum PcrError {
/// Error propagated from the TPM device layer (e.g. `Device::execute`,
/// `Device::get_capability_page`, or a mismatched response).
#[error("device: {0}")]
Device(#[from] DeviceError),
/// An algorithm identifier was rejected as invalid.
#[error("invalid algorithm: {0:?}")]
InvalidAlgorithm(TpmAlgId),
/// A PCR selection could not be parsed, could not be resolved against the
/// TPM's banks, did not match a TPM response, or the TPM reported no banks.
/// The string carries the human-readable detail.
#[error("invalid PCR selection: {0}")]
InvalidPcrSelection(String),
/// Error from the `tpm2_protocol` crate (building/parsing protocol
/// structures). Converted by a hand-written `From<TpmError>` impl rather than
/// `#[from]` — presumably because `TpmError` does not implement
/// `std::error::Error`; confirm before switching to `#[from]`.
#[error("TPM: {0}")]
Tpm(TpmError),
/// Error propagated from the crate's crypto helpers (e.g. `crypto_digest`).
#[error("crypto: {0}")]
Crypto(#[from] CryptoError),
}
impl From<TpmError> for PcrError {
fn from(err: TpmError) -> Self {
Self::Tpm(err)
}
}
/// A single PCR value, as produced by [`pcr_read`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Pcr {
/// Hash algorithm of the bank the value was read from.
pub bank: TpmAlgId,
/// Index of the PCR within its bank.
pub index: u32,
/// Raw digest bytes held by the PCR.
pub value: Vec<u8>,
}
/// A PCR bank as reported by the TPM (see [`pcr_get_bank_list`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PcrBank {
/// Hash algorithm of the bank.
pub alg: TpmAlgId,
/// Number of addressable PCR indices in the bank. [`pcr_get_bank_list`]
/// derives it as 8 × the byte length of the reported select bitmap, so it may
/// be 0 if a bank is reported with an empty bitmap.
pub count: usize,
}
/// A user-level PCR selection: one hash bank plus a list of PCR indices.
///
/// Produced by [`pcr_selection_vec_from_str`] from `alg:i,j,...` text and
/// rendered by its `Display` impl as `<alg>:<i>,<j>,...`. Indices are kept in
/// the order given; duplicates are not removed here.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PcrSelection {
/// Hash algorithm identifying the PCR bank.
pub alg: TpmAlgId,
/// PCR indices selected within that bank.
pub indices: Vec<u32>,
}
/// Renders the selection as `<alg>:<index>,<index>,...`, i.e. one `+`-free
/// part of the syntax accepted by [`pcr_selection_vec_from_str`].
impl fmt::Display for PcrSelection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}:", crate::key::Tpm2shAlgId(self.alg))?;
        // Emit indices straight into the formatter, comma-separated, without
        // building an intermediate joined string.
        let mut separator = "";
        for index in &self.indices {
            write!(f, "{separator}{index}")?;
            separator = ",";
        }
        Ok(())
    }
}
/// Queries the TPM's PCR capability (`TpmCap::Pcrs`) and returns one
/// [`PcrBank`] per reported bank, sorted by algorithm id.
///
/// Each bank's `count` is 8 × the byte length of its reported select bitmap.
/// Every reported bank is included regardless of which bits its bitmap sets.
///
/// # Errors
///
/// * Device errors from `Device::get_capability_page`.
/// * [`PcrError::InvalidPcrSelection`] if the capability data is not PCR data
///   or lists no banks.
pub fn pcr_get_bank_list(device: &mut Device) -> Result<Vec<PcrBank>, PcrError> {
    let no_banks =
        || PcrError::InvalidPcrSelection("TPM reported no active PCR banks.".to_string());

    let (_, capability) = device.get_capability_page(TpmCap::Pcrs, 0, 1)?;
    let TpmuCapabilities::Pcrs(pcrs) = capability.data else {
        return Err(no_banks());
    };

    let mut banks: Vec<PcrBank> = pcrs
        .iter()
        .map(|selection| PcrBank {
            alg: selection.hash,
            count: selection.pcr_select.len() * 8,
        })
        .collect();
    if banks.is_empty() {
        return Err(no_banks());
    }

    banks.sort_by_key(|bank| bank.alg);
    Ok(banks)
}
/// Parses a selection string such as `sha256:0,1,7+sha1:16` into
/// [`PcrSelection`]s, one per `+`-separated part, preserving order.
///
/// Each part is `<alg>:<index>[,<index>...]`, where `<alg>` is anything
/// `Tpm2shAlgId::try_from(&str)` accepts and each index is a decimal `u32`.
/// No whitespace trimming is performed.
///
/// # Errors
///
/// Returns [`PcrError::InvalidPcrSelection`] on the first malformed part:
/// missing `:`, an unrecognised algorithm, or an index that is not a `u32`.
pub fn pcr_selection_vec_from_str(selection_str: &str) -> Result<Vec<PcrSelection>, PcrError> {
    let mut selections = Vec::new();
    for part in selection_str.split('+') {
        let Some((alg_name, index_list)) = part.split_once(':') else {
            return Err(PcrError::InvalidPcrSelection(part.to_string()));
        };

        let alg = match Tpm2shAlgId::try_from(alg_name) {
            Ok(parsed) => parsed.0,
            Err(err) => return Err(PcrError::InvalidPcrSelection(err.to_string())),
        };

        let mut indices = Vec::new();
        for token in index_list.split(',') {
            match token.parse::<u32>() {
                Ok(index) => indices.push(index),
                // The whole index list is reported, not just the bad token.
                Err(_) => return Err(PcrError::InvalidPcrSelection(index_list.to_string())),
            }
        }

        selections.push(PcrSelection { alg, indices });
    }
    Ok(selections)
}
/// Converts user-level [`PcrSelection`]s into a `TpmlPcrSelection`.
///
/// Each selection's algorithm must match one of `banks`. The select bitmap is
/// sized from that bank's `count` (8 PCRs per byte), and PCR `i` sets bit
/// `i % 8` of byte `i / 8`. Selections are emitted in input order.
///
/// # Errors
///
/// * [`PcrError::InvalidPcrSelection`] if a selection names a bank absent from
///   `banks`, if the bitmap would exceed `TPM_PCR_SELECT_MAX` bytes, or if an
///   index is out of range for its bank (including banks with zero PCRs).
/// * Errors from `TpmsPcrSelect::try_from` / `TpmlPcrSelection::try_push`,
///   propagated with `?`.
pub fn pcr_selection_vec_to_tpml(
    selections: &[PcrSelection],
    banks: &[PcrBank],
) -> Result<TpmlPcrSelection, PcrError> {
    let mut list = TpmlPcrSelection::new();
    for selection in selections {
        let bank = banks
            .iter()
            .find(|b| b.alg == selection.alg)
            .ok_or_else(|| {
                PcrError::InvalidPcrSelection(format!(
                    "PCR bank for algorithm {:?} not found or supported by TPM",
                    selection.alg
                ))
            })?;
        let pcr_select_size = bank.count.div_ceil(8);
        if pcr_select_size > TPM_PCR_SELECT_MAX {
            return Err(PcrError::InvalidPcrSelection(format!(
                "invalid select size {pcr_select_size} (> {TPM_PCR_SELECT_MAX})"
            )));
        }
        let mut pcr_select_bytes = vec![0u8; pcr_select_size];
        for &pcr_index in &selection.indices {
            let pcr_index = pcr_index as usize;
            if pcr_index >= bank.count {
                // `bank.count` can be 0 (a `PcrBank` with an empty bitmap). In
                // that case `count - 1` would underflow, panicking in debug
                // builds, so the empty bank gets its own message.
                let detail = match bank.count.checked_sub(1) {
                    Some(max) => format!(
                        "invalid index {pcr_index} for {:?} bank (max is {max})",
                        bank.alg
                    ),
                    None => format!(
                        "invalid index {pcr_index} for {:?} bank (bank has no PCRs)",
                        bank.alg
                    ),
                };
                return Err(PcrError::InvalidPcrSelection(detail));
            }
            pcr_select_bytes[pcr_index / 8] |= 1 << (pcr_index % 8);
        }
        list.try_push(TpmsPcrSelection {
            hash: selection.alg,
            pcr_select: TpmsPcrSelect::try_from(pcr_select_bytes.as_slice())?,
        })?;
    }
    Ok(list)
}
/// Executes `TPM2_PCR_Read` for `pcr_selection_in` and decodes the response.
///
/// The response carries an output selection (`pcr_selection_out`) and a flat
/// list of digests (`pcr_values`). Digests are paired with the set bits of
/// the output selection in order: bank by bank, byte by byte, and from the
/// least significant bit upwards. Bit `b` of byte `i` is PCR `i * 8 + b`.
/// Only PCRs present in the output selection are returned, so callers should
/// not assume every requested PCR appears in the result.
///
/// Returns the decoded PCRs together with the response's PCR update counter.
///
/// # Errors
///
/// * Device errors from `Device::execute`.
/// * `DeviceError::ResponseMismatch` (as [`PcrError::Device`]) if the response
///   is not a `PcrRead` response.
/// * [`PcrError::InvalidPcrSelection`] if the number of digests differs from
///   the number of selected PCRs in either direction, since the positional
///   pairing would then be unreliable.
pub fn pcr_read(
    device: &mut Device,
    pcr_selection_in: &TpmlPcrSelection,
) -> Result<(Vec<Pcr>, u32), PcrError> {
    let cmd = TpmPcrReadCommand {
        pcr_selection_in: *pcr_selection_in,
    };
    let (resp, _) = device.execute(&cmd, &[])?;
    let pcr_read_resp = resp
        .PcrRead()
        .map_err(|_| DeviceError::ResponseMismatch(TpmCc::PcrRead))?;
    let mut pcrs = Vec::new();
    let mut digest_iter = pcr_read_resp.pcr_values.iter();
    for selection in pcr_read_resp.pcr_selection_out.iter() {
        for (byte_idx, &byte) in selection.pcr_select.iter().enumerate() {
            if byte == 0 {
                continue;
            }
            for bit_idx in 0..8 {
                if (byte >> bit_idx) & 1 == 1 {
                    let pcr_index = u32::try_from(byte_idx * 8 + bit_idx)
                        .map_err(|_| PcrError::InvalidPcrSelection("PCR index overflow".into()))?;
                    let value = digest_iter.next().ok_or_else(|| {
                        PcrError::InvalidPcrSelection("PCR selection mismatch".to_string())
                    })?;
                    pcrs.push(Pcr {
                        bank: selection.hash,
                        index: pcr_index,
                        value: value.to_vec(),
                    });
                }
            }
        }
    }
    // Surplus digests mean the response's selection and values disagree. The
    // positional pairing above is then untrustworthy, so reject it the same
    // way as a shortfall.
    if digest_iter.next().is_some() {
        return Err(PcrError::InvalidPcrSelection(
            "PCR selection mismatch".to_string(),
        ));
    }
    Ok((pcrs, pcr_read_resp.pcr_update_counter))
}
pub fn pcr_composite_digest(pcrs: &[Pcr], alg: TpmAlgId) -> Result<Vec<u8>, PcrError> {
let digests: Vec<&[u8]> = pcrs.iter().map(|p| p.value.as_slice()).collect();
Ok(crypto_digest(alg, &digests)?)
}