use std::{convert::TryFrom, str::FromStr};
use thiserror::Error;
use tpm2_protocol::data::TpmHt;
#[derive(Debug, Error)]
pub enum HandleError {
#[error("invalid handle")]
InvalidHandle,
}
/// Namespace a [`Handle`] belongs to.
///
/// Rendered as the `tpm:` / `vtpm:` scheme prefix by `Handle`'s `Display` and `FromStr` impls.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandleClass {
    /// Scheme `tpm`.
    Tpm,
    /// Scheme `vtpm` (presumably a virtual TPM — confirm).
    Vtpm,
}
/// A 32-bit handle value tagged with the [`HandleClass`] it belongs to.
///
/// The tuple is `(class, value)`. It is displayed and parsed as `<scheme>:<hex>`,
/// e.g. `tpm:81000001`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Handle(pub (HandleClass, u32));
impl Handle {
#[must_use]
pub fn class(&self) -> HandleClass {
self.0 .0
}
#[must_use]
pub fn value(&self) -> u32 {
self.0 .1
}
}
impl std::fmt::Display for Handle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let value = self.value();
match self.class() {
HandleClass::Tpm => write!(f, "tpm:{value:08x}"),
HandleClass::Vtpm => write!(f, "vtpm:{value:08x}"),
}
}
}
impl FromStr for Handle {
    type Err = HandleError;

    /// Parses a handle written as `<scheme>:<hex>`, e.g. `tpm:81000001` or `vtpm:1`.
    ///
    /// `scheme` must be `tpm` or `vtpm`. `hex` must be a non-empty run of ASCII hex digits
    /// (either case) whose value fits in a `u32`. Leading zeros are allowed, so the
    /// eight-digit form produced by `Display` round-trips.
    ///
    /// # Errors
    ///
    /// Returns [`HandleError::InvalidHandle`] in any of these cases:
    /// - the `:` separator is missing;
    /// - the scheme is unknown;
    /// - the value is empty or contains anything other than hex digits;
    /// - the value overflows `u32`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (scheme_str, value_str) = s.split_once(':').ok_or(HandleError::InvalidHandle)?;
        let class = match scheme_str {
            "tpm" => HandleClass::Tpm,
            "vtpm" => HandleClass::Vtpm,
            _ => return Err(HandleError::InvalidHandle),
        };
        // `from_str_radix` silently accepts a leading `+`. Reject it, and any other non-hex
        // character, so only plain hex digits are accepted. The empty string is still
        // rejected by `from_str_radix` below.
        if !value_str.bytes().all(|b| b.is_ascii_hexdigit()) {
            return Err(HandleError::InvalidHandle);
        }
        let value = u32::from_str_radix(value_str, 16).map_err(|_| HandleError::InvalidHandle)?;
        Ok(Handle((class, value)))
    }
}
impl TryFrom<Handle> for TpmHt {
    type Error = HandleError;

    /// Interprets the most significant byte of the handle value as a `TpmHt`.
    ///
    /// The [`HandleClass`] is ignored; only the raw value is inspected.
    ///
    /// # Errors
    ///
    /// Returns [`HandleError::InvalidHandle`] if that byte is not accepted by
    /// `TpmHt::try_from(u8)`.
    fn try_from(handle: Handle) -> Result<Self, Self::Error> {
        // Big-endian byte 0 is the top octet, i.e. `value >> 24`.
        let [ht_byte, _, _, _] = handle.value().to_be_bytes();
        TpmHt::try_from(ht_byte).map_err(|()| HandleError::InvalidHandle)
    }
}
/// Error returned by [`HandlePattern::new`] for a malformed pattern.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum HandlePatternError {
    /// A character is neither a hex digit nor the `?` wildcard.
    #[error("handle pattern is not a valid hex string")]
    InvalidHexString,
    /// A pattern without `*` is shorter than eight digit positions.
    #[error("handle pattern has less than eight characters")]
    TooFewDigits,
    /// The pattern contains more than one `*`.
    #[error("handle pattern has more than one '*'")]
    TooManyAsterisks,
    /// The pattern has more than eight digit positions, not counting a single `*`.
    #[error("handle pattern has more than eight characters")]
    TooManyDigits,
}
/// A compiled handle pattern stored as `(mask, value)`.
///
/// Field 0 is a bit mask covering the nibbles the pattern pins. Field 1 holds the
/// expected bits under that mask. A handle matches when `handle & mask == value`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HandlePattern(u32, u32);
impl HandlePattern {
    /// Compiles a textual handle pattern.
    ///
    /// Accepted forms:
    ///
    /// - `*` alone matches every handle.
    /// - Exactly eight characters, each a hex digit (either case) or `?`. These pin the
    ///   handle's nibbles from most to least significant; `?` matches any nibble.
    /// - `<prefix>*<suffix>`. The prefix pins the most significant nibbles and the suffix
    ///   pins the least significant nibbles. `*` leaves everything in between unconstrained
    ///   and may cover zero nibbles. Prefix and suffix together hold at most eight
    ///   characters and may also use `?`.
    ///
    /// # Errors
    ///
    /// - [`HandlePatternError::TooManyAsterisks`] if the pattern has more than one `*`.
    /// - [`HandlePatternError::TooManyDigits`] if the pattern has more than eight
    ///   characters besides the `*`.
    /// - [`HandlePatternError::TooFewDigits`] if a pattern without `*` has fewer than
    ///   eight characters.
    /// - [`HandlePatternError::InvalidHexString`] if any character is neither a hex digit
    ///   nor `?`.
    pub fn new(query: &str) -> Result<Self, HandlePatternError> {
        // Number of hex nibbles in a 32-bit handle.
        const NIBBLES: usize = 8;

        // Lengths are counted in chars, not bytes, so a non-ASCII character is reported
        // as an invalid digit below rather than inflating the length.
        let (prefix, suffix) = match query.split_once('*') {
            Some((_, suffix)) if suffix.contains('*') => {
                return Err(HandlePatternError::TooManyAsterisks);
            }
            Some((prefix, suffix)) => {
                if prefix.chars().count() + suffix.chars().count() > NIBBLES {
                    return Err(HandlePatternError::TooManyDigits);
                }
                (prefix, suffix)
            }
            None => {
                let len = query.chars().count();
                if len < NIBBLES {
                    return Err(HandlePatternError::TooFewDigits);
                }
                if len > NIBBLES {
                    return Err(HandlePatternError::TooManyDigits);
                }
                (query, "")
            }
        };

        // A lone `*` arrives here with an empty prefix and suffix. That yields an empty
        // mask, which matches every handle.
        let mut pattern = Self(0, 0);
        for (i, c) in prefix.chars().enumerate() {
            pattern.pin_nibble(NIBBLES - 1 - i, c)?;
        }
        for (i, c) in suffix.chars().rev().enumerate() {
            pattern.pin_nibble(i, c)?;
        }
        Ok(pattern)
    }

    /// Pins nibble `index` (0 = least significant) to hex digit `c`; `?` leaves it free.
    fn pin_nibble(&mut self, index: usize, c: char) -> Result<(), HandlePatternError> {
        let shift = index * 4;
        match c.to_digit(16) {
            Some(digit) => {
                self.0 |= 0xF << shift;
                self.1 |= digit << shift;
                Ok(())
            }
            None if c == '?' => Ok(()),
            None => Err(HandlePatternError::InvalidHexString),
        }
    }

    /// Returns `true` if `handle` has the pattern's digit in every pinned nibble.
    ///
    /// Unpinned nibbles (`?` and the span covered by `*`) are ignored. An empty mask, as
    /// produced by the pattern `*`, therefore matches every handle.
    #[must_use]
    pub fn matches(&self, handle: u32) -> bool {
        (handle & self.0) == self.1
    }
}