use crate::{ErrorCode, InternalError};
use std::{borrow::Cow, collections::BTreeSet, fmt::Debug, num::NonZeroU32};
/// Error code this module returns for rejected arguments: a flag list that is empty or fails
/// `Flags::validate_set` in `Flags::as_string`, or an unrecognized value in `BlobType::try_from`.
const INVALID_ARGS: ErrorCode = ErrorCode::InternalError(InternalError::InvalidArguments);
/// Common behaviour of the flag enums in this module: every flag renders to a string token,
/// and a list of flags is serialized as one comma-separated string.
pub(crate) trait Flags: Copy + Clone + Sized + Ord {
    /// Returns the textual token for a single flag.
    fn stringify(self) -> Cow<'static, str>;

    /// Returns `true` if the given (already de-duplicated) set of flags is an allowed
    /// combination.
    fn validate_set(list: &BTreeSet<Self>) -> bool;

    /// Serializes an optional list of flags into a comma-separated string.
    ///
    /// Returns `Ok(None)` when `list` is `None`. Otherwise, duplicate flags are collapsed and
    /// the tokens are emitted in the order given by `Ord`, not in the order of `list`.
    ///
    /// # Errors
    ///
    /// Returns `INVALID_ARGS` if `list` is `Some` but empty, or if the set of flags is
    /// rejected by [`Flags::validate_set`].
    fn as_string(list: Option<&[Self]>) -> Result<Option<String>, ErrorCode> {
        let Some(flags) = list else {
            return Ok(None);
        };
        // A BTreeSet both removes duplicates and gives a deterministic (sorted) order, so the
        // same set of flags always yields the same string. Collecting an empty slice simply
        // produces an empty set, so no separate emptiness branch is needed here.
        let unique: BTreeSet<Self> = flags.iter().copied().collect();
        // An explicitly supplied but empty list is treated as invalid (distinct from `None`).
        if unique.is_empty() || !Self::validate_set(&unique) {
            return Err(INVALID_ARGS);
        }
        let tokens: Vec<Cow<'static, str>> = unique.into_iter().map(Self::stringify).collect();
        Ok(Some(tokens.join(",")))
    }
}
/// Key flags; a list of them is serialized into a comma-separated string through the
/// crate-internal `Flags` trait.
///
/// The declaration order of the variants matters: the derived `Ord` decides the order in
/// which tokens are emitted.
///
/// NOTE(review): the tokens look like TSS2 FAPI key "type" keywords — confirm against the FAPI
/// specification.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, Hash)]
#[non_exhaustive]
pub enum KeyFlags {
    /// Rendered as `decrypt`.
    Decrypt,
    /// Rendered as `exportable`.
    Exportable,
    /// Rendered as `noda`.
    NoDA,
    /// Rendered as `restricted`.
    Restricted,
    /// Rendered as `sign`.
    Sign,
    /// Rendered as `system`.
    System,
    /// Rendered as `user`.
    User,
    /// A handle, rendered as `0x` followed by eight upper-case hexadecimal digits.
    Persistent(NonZeroU32),
}
impl Flags for KeyFlags {
    /// Maps the flag to its keyword; only the handle variant needs an owned string.
    fn stringify(self) -> Cow<'static, str> {
        match self {
            Self::Decrypt => "decrypt".into(),
            Self::Exportable => "exportable".into(),
            Self::NoDA => "noda".into(),
            Self::Restricted => "restricted".into(),
            Self::Sign => "sign".into(),
            Self::System => "system".into(),
            Self::User => "user".into(),
            Self::Persistent(handle) => format!("0x{handle:08X}").into(),
        }
    }

    /// Key flags impose no combination constraints: every set is accepted.
    fn validate_set(_list: &BTreeSet<Self>) -> bool {
        true
    }
}
/// NV index flags; a list of them is serialized into a comma-separated string through the
/// crate-internal `Flags` trait.
///
/// `BitField`, `Counter` and `PCR` are "type" flags: at most one of them may be present in a
/// list. Variant order drives the derived `Ord`, and therefore the token order.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, Hash)]
#[non_exhaustive]
pub enum NvFlags {
    /// Type flag, rendered as `bitfield`.
    BitField,
    /// Type flag, rendered as `counter`.
    Counter,
    /// Type flag, rendered as `pcr`.
    PCR,
    /// Rendered as `noda`.
    NoDA,
    /// Rendered as `system`.
    System,
    /// An explicit index handle, rendered as `0x` followed by eight upper-case hex digits.
    Index(NonZeroU32),
}
impl NvFlags {
    /// Whether this flag is one of the mutually exclusive type flags.
    fn is_type_flag(&self) -> bool {
        match self {
            Self::BitField | Self::Counter | Self::PCR => true,
            Self::NoDA | Self::System | Self::Index(_) => false,
        }
    }

    /// Returns `true` if `list` is present and contains at least one type flag
    /// (`BitField`, `Counter` or `PCR`); `false` for `None` or a list without one.
    ///
    /// NOTE(review): presumably these NV types have a size implied by their type, hence the
    /// name — confirm with the caller.
    pub fn implicit_size(list: &Option<&[Self]>) -> bool {
        match list {
            Some(flags) => flags.iter().any(|flag| flag.is_type_flag()),
            None => false,
        }
    }
}
impl Flags for NvFlags {
    /// Maps the flag to its keyword; only the index variant needs an owned string.
    fn stringify(self) -> Cow<'static, str> {
        match self {
            Self::BitField => "bitfield".into(),
            Self::Counter => "counter".into(),
            Self::NoDA => "noda".into(),
            Self::PCR => "pcr".into(),
            Self::System => "system".into(),
            Self::Index(handle) => format!("0x{handle:08X}").into(),
        }
    }

    /// Accepts the set unless it holds two or more type flags.
    fn validate_set(flags: &BTreeSet<Self>) -> bool {
        // A second type flag existing is exactly what makes the set invalid.
        flags.iter().filter(|flag| flag.is_type_flag()).nth(1).is_none()
    }
}
/// Flags for sealed data; a list of them is serialized into a comma-separated string through
/// the crate-internal `Flags` trait. Variant order drives the derived `Ord`, and therefore the
/// token order.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, Hash)]
#[non_exhaustive]
pub enum SealFlags {
    /// Rendered as `noda`.
    NoDA,
    /// Rendered as `system`.
    System,
    /// An explicit index handle, rendered as `0x` followed by eight upper-case hex digits.
    Index(NonZeroU32),
}
impl Flags for SealFlags {
    /// Returns a borrowed keyword, or an owned hex string for the index variant.
    fn stringify(self) -> Cow<'static, str> {
        let keyword = match self {
            Self::NoDA => "noda",
            Self::System => "system",
            Self::Index(handle) => return Cow::Owned(format!("0x{:08X}", handle)),
        };
        Cow::Borrowed(keyword)
    }

    /// No combination constraints apply to seal flags.
    fn validate_set(_: &BTreeSet<Self>) -> bool {
        true
    }
}
/// Quote type flags, serialized through the crate-internal `Flags` trait.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, Hash)]
#[non_exhaustive]
pub enum QuoteFlags {
    /// Rendered as `TPM-Quote`.
    TpmQuote,
}
impl Flags for QuoteFlags {
    /// Returns the fixed token for the (single) quote flag.
    fn stringify(self) -> Cow<'static, str> {
        let token = match self {
            Self::TpmQuote => "TPM-Quote",
        };
        Cow::from(token)
    }

    /// Every set of quote flags is accepted.
    fn validate_set(_: &BTreeSet<Self>) -> bool {
        true
    }
}
/// Padding scheme selection, serialized through the crate-internal `Flags` trait.
/// At most one scheme may be selected.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, Hash)]
#[non_exhaustive]
pub enum PaddingFlags {
    /// Rendered as `RSA_SSA`.
    RsaSsa,
    /// Rendered as `RSA_PSS`.
    RsaPss,
}
impl Flags for PaddingFlags {
    /// Returns the upper-case scheme name.
    fn stringify(self) -> Cow<'static, str> {
        Cow::Borrowed(match self {
            Self::RsaSsa => "RSA_SSA",
            Self::RsaPss => "RSA_PSS",
        })
    }

    /// The schemes are mutually exclusive: a set is valid only with zero or one entry.
    fn validate_set(flags: &BTreeSet<Self>) -> bool {
        flags.len() <= 1
    }
}
/// Kind of a blob, decoded from a raw numeric code via `TryFrom<u8>`.
///
/// NOTE(review): the codes 1 and 2 presumably mirror the FAPI `FAPI_ESYSBLOB_CONTEXTLOAD` /
/// `FAPI_ESYSBLOB_DESERIALIZE` constants — confirm against the TSS headers.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, Hash)]
#[non_exhaustive]
pub enum BlobType {
    /// Decoded from code `1`.
    ContextLoad,
    /// Decoded from code `2`.
    Deserialize,
}
impl TryFrom<u8> for BlobType {
    type Error = ErrorCode;

    /// Decodes a raw blob-type code; any value other than 1 or 2 yields `INVALID_ARGS`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let blob_type = match value {
            1 => Self::ContextLoad,
            2 => Self::Deserialize,
            _ => return Err(INVALID_ARGS),
        };
        Ok(blob_type)
    }
}
#[cfg(test)]
mod tests {
    use super::{BlobType, Flags, KeyFlags, NvFlags, PaddingFlags, QuoteFlags, SealFlags};
    use std::num::NonZeroU32;

    /// Wraps an expected rendering in the shape `Flags::as_string(..).ok()` has on success.
    ///
    /// Comparing through `.ok()` keeps these assertions independent of whether `ErrorCode`
    /// implements `Debug`/`PartialEq`.
    fn rendered(text: &str) -> Option<Option<String>> {
        Some(Some(text.to_owned()))
    }

    #[test]
    fn test_flags_to_string() {
        let index = NonZeroU32::new(1u32).unwrap();
        let persistent = NonZeroU32::new(0x8100_0001u32).unwrap();

        // Key flags: every token, emitted in `Ord` order with the handle last.
        assert_eq!(Flags::as_string(None::<&[KeyFlags]>).ok(), Some(None));
        assert_eq!(
            Flags::as_string(Some(&[
                KeyFlags::Decrypt,
                KeyFlags::Exportable,
                KeyFlags::NoDA,
                KeyFlags::Persistent(index),
                KeyFlags::Restricted,
                KeyFlags::Sign,
                KeyFlags::System,
                KeyFlags::User
            ]))
            .ok(),
            rendered("decrypt,exportable,noda,restricted,sign,system,user,0x00000001")
        );
        // Duplicates are collapsed and the input order does not matter.
        assert_eq!(
            Flags::as_string(Some(&[KeyFlags::Sign, KeyFlags::NoDA, KeyFlags::Sign])).ok(),
            rendered("noda,sign")
        );
        assert_eq!(
            Flags::as_string(Some(&[KeyFlags::Persistent(persistent)])).ok(),
            rendered("0x81000001")
        );
        // An explicitly empty list is rejected, unlike `None`.
        assert!(Flags::as_string(Some::<&[KeyFlags]>(&[])).is_err());

        // NV flags: at most one of bitfield / counter / pcr.
        assert_eq!(Flags::as_string(None::<&[NvFlags]>).ok(), Some(None));
        assert_eq!(
            Flags::as_string(Some(&[NvFlags::BitField, NvFlags::Index(index), NvFlags::NoDA, NvFlags::System])).ok(),
            rendered("bitfield,noda,system,0x00000001")
        );
        assert_eq!(
            Flags::as_string(Some(&[NvFlags::PCR, NvFlags::Index(index), NvFlags::NoDA, NvFlags::System])).ok(),
            rendered("pcr,noda,system,0x00000001")
        );
        assert_eq!(
            Flags::as_string(Some(&[NvFlags::Counter, NvFlags::Index(index), NvFlags::NoDA, NvFlags::System])).ok(),
            rendered("counter,noda,system,0x00000001")
        );
        assert!(Flags::as_string(Some(&[NvFlags::BitField, NvFlags::Counter])).is_err());
        assert!(Flags::as_string(Some(&[NvFlags::BitField, NvFlags::PCR])).is_err());
        assert!(Flags::as_string(Some(&[NvFlags::Counter, NvFlags::PCR])).is_err());

        // Seal flags.
        assert_eq!(Flags::as_string(None::<&[SealFlags]>).ok(), Some(None));
        assert_eq!(
            Flags::as_string(Some(&[SealFlags::NoDA, SealFlags::Index(index), SealFlags::System])).ok(),
            rendered("noda,system,0x00000001")
        );

        // Quote flags.
        assert_eq!(Flags::as_string(None::<&[QuoteFlags]>).ok(), Some(None));
        assert_eq!(Flags::as_string(Some(&[QuoteFlags::TpmQuote])).ok(), rendered("TPM-Quote"));

        // Padding flags: exactly one scheme.
        assert_eq!(Flags::as_string(None::<&[PaddingFlags]>).ok(), Some(None));
        assert_eq!(Flags::as_string(Some(&[PaddingFlags::RsaPss])).ok(), rendered("RSA_PSS"));
        assert_eq!(Flags::as_string(Some(&[PaddingFlags::RsaSsa])).ok(), rendered("RSA_SSA"));
        assert!(Flags::as_string(Some(&[PaddingFlags::RsaPss, PaddingFlags::RsaSsa])).is_err());
    }

    #[test]
    fn test_blob_types() {
        assert_eq!(BlobType::try_from(1u8).ok(), Some(BlobType::ContextLoad));
        assert_eq!(BlobType::try_from(2u8).ok(), Some(BlobType::Deserialize));
        assert!(BlobType::try_from(0u8).is_err());
        assert!(BlobType::try_from(3u8).is_err());
    }
}