use std::{num::ParseIntError, str::FromStr};
use thiserror::Error;
use tpm2_protocol::{data::TpmHt, TpmBuild, TpmError, TpmHandle, TpmParse, TpmSized, TpmWriter};
/// Capacity, in bytes, of the inline buffer that stores an auth value.
///
/// Hex-encoded password/policy values that decode to more bytes than this are
/// rejected with `AuthError::ValueTooLarge`.
const MAX_AUTH_SIZE: usize = 64;
#[derive(Debug, Error)]
pub enum AuthError {
#[error("invalid auth string")]
InvalidAuth,
#[error("malformed auth value")]
MalformedAuth,
#[error("auth value too large")]
ValueTooLarge,
#[error("hex decode: {0}")]
HexDecode(#[from] hex::FromHexError),
#[error("handle decode: {0}")]
IntDecode(#[from] ParseIntError),
}
impl From<TpmError> for AuthError {
fn from(_: TpmError) -> Self {
Self::MalformedAuth
}
}
/// The kind of credential held by an [`Auth`] value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthClass {
    /// Raw password bytes, parsed from `password:<hex>`.
    Password,
    /// Policy bytes, parsed from `policy:<hex>`.
    Policy,
    /// An encoded session handle, parsed from `vtpm:<hex handle>`.
    Session,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Auth {
pub class: AuthClass,
pub data: [u8; MAX_AUTH_SIZE],
pub len: usize,
}
/// Encodes `handle_val` as a `TpmHandle` at the start of a zeroed buffer.
///
/// Returns the buffer together with the number of bytes written
/// (`TpmHandle::SIZE`).
///
/// # Errors
///
/// Propagates any `TpmError` raised by the handle encoder.
fn build_handle_array(handle_val: u32) -> Result<([u8; MAX_AUTH_SIZE], usize), TpmError> {
    let used = TpmHandle::SIZE;
    let mut buf = [0u8; MAX_AUTH_SIZE];
    {
        let mut writer = TpmWriter::new(&mut buf[..used]);
        TpmHandle(handle_val).build(&mut writer)?;
    }
    Ok((buf, used))
}
/// Decodes a hex string into a zero-padded fixed-size buffer.
///
/// Returns the buffer and the number of decoded bytes stored at its start.
///
/// # Errors
///
/// - `AuthError::HexDecode` if `s` is not valid hexadecimal.
/// - `AuthError::ValueTooLarge` if it decodes to more than `MAX_AUTH_SIZE` bytes.
fn parse_auth_hex(s: &str) -> Result<([u8; MAX_AUTH_SIZE], usize), AuthError> {
    let decoded = hex::decode(s)?;
    let count = decoded.len();
    let mut out = [0u8; MAX_AUTH_SIZE];
    // `get_mut` yields `None` exactly when the decoded value does not fit.
    out.get_mut(..count)
        .ok_or(AuthError::ValueTooLarge)?
        .copy_from_slice(&decoded);
    Ok((out, count))
}
impl Auth {
    /// Returns the kind of credential this value holds.
    #[must_use]
    pub fn class(&self) -> AuthClass {
        self.class
    }

    /// Returns the meaningful bytes of the auth value: the first `len` bytes
    /// of the inline buffer.
    ///
    /// # Panics
    ///
    /// Panics if `len` exceeds the buffer capacity. The `FromStr` and
    /// `Default` implementations in this module keep `len` in bounds.
    #[must_use]
    pub fn value(&self) -> &[u8] {
        &self.data[0..self.len]
    }

    /// Decodes the session handle stored in a `Session`-class value.
    ///
    /// # Errors
    ///
    /// - `AuthError::InvalidAuth` if the value is not of class `Session`.
    /// - `AuthError::MalformedAuth` if `len` is out of bounds or the stored
    ///   bytes are not exactly one encoded `TpmHandle`.
    pub fn session(&self) -> Result<u32, AuthError> {
        if self.class() != AuthClass::Session {
            return Err(AuthError::InvalidAuth);
        }
        // Parse only the `len` bytes actually stored. Parsing a fixed
        // `TpmHandle::SIZE` prefix ignored `len` and made the remainder check
        // below unreachable. `get` avoids a panic when `len` is corrupt.
        let stored = self.data.get(..self.len).ok_or(AuthError::MalformedAuth)?;
        let (handle, remainder) = TpmHandle::parse(stored)?;
        if !remainder.is_empty() {
            return Err(AuthError::MalformedAuth);
        }
        Ok(handle.0)
    }
}
impl Default for Auth {
fn default() -> Self {
Self {
class: AuthClass::Password,
data: [0u8; MAX_AUTH_SIZE],
len: 0,
}
}
}
/// Renders the value using the prefixes understood by `FromStr`
/// (`empty`, `password:`, `policy:`, `vtpm:`).
///
/// Password bytes are replaced with `<sensitive>`, so password output cannot
/// be parsed back. A session value whose stored bytes do not decode as a
/// handle is rendered as `vtpm:<malformed>`.
impl std::fmt::Display for Auth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if *self == Self::default() {
            return write!(f, "empty");
        }
        match self.class {
            AuthClass::Password => write!(f, "password:<sensitive>"),
            AuthClass::Policy => write!(f, "policy:{}", hex::encode(self.value())),
            AuthClass::Session => match self.session() {
                Ok(handle_val) => write!(f, "vtpm:{handle_val:08x}"),
                // A Display impl must only fail when the Formatter fails.
                // Returning fmt::Error here would make `to_string()` panic.
                Err(_) => write!(f, "vtpm:<malformed>"),
            },
        }
    }
}
/// Parses `empty`, `password:<hex>`, `policy:<hex>`, or `vtpm:<hex handle>`.
///
/// For `vtpm:`, the handle's top byte must be the policy-session or
/// HMAC-session handle type; any other handle type is `InvalidAuth`.
impl FromStr for Auth {
    type Err = AuthError;

    fn from_str(auth_str: &str) -> Result<Self, Self::Err> {
        if auth_str == "empty" {
            return Ok(Self::default());
        }
        let (prefix, value) = auth_str.split_once(':').ok_or(AuthError::InvalidAuth)?;

        let (class, (data, len)) = match prefix {
            "password" => (AuthClass::Password, parse_auth_hex(value)?),
            "policy" => (AuthClass::Policy, parse_auth_hex(value)?),
            "vtpm" => {
                let handle_val = u32::from_str_radix(value, 16)?;
                // The handle type lives in the most significant byte.
                let handle_type = (handle_val >> 24) as u8;
                let is_session = handle_type == TpmHt::PolicySession as u8
                    || handle_type == TpmHt::HmacSession as u8;
                if !is_session {
                    return Err(AuthError::InvalidAuth);
                }
                let encoded =
                    build_handle_array(handle_val).map_err(|_| AuthError::MalformedAuth)?;
                (AuthClass::Session, encoded)
            }
            _ => return Err(AuthError::InvalidAuth),
        };

        Ok(Auth { class, data, len })
    }
}