use crate::{ErrorCode, InternalError};
use std::{borrow::Cow, collections::BTreeSet, fmt::Debug, num::NonZeroU32};
/// Shared behaviour of the flag enums defined in this module.
///
/// Every flag has a textual form and a numeric ordinal, and each flag type
/// decides which combinations of its flags are acceptable.
pub(crate) trait Flags<T>
where
    T: Ord,
{
    /// Returns the textual form of this flag.
    fn as_string(&self) -> Cow<'static, str>;

    /// Returns the numeric identifier of this flag.
    ///
    /// `contains_duplicates` treats two flags with the same ordinal as
    /// duplicates of each other.
    fn ordinal(&self) -> usize;

    /// Returns `true` if `list` is an acceptable combination of flags.
    fn validate(list: &[T]) -> bool;
}
/// Flags for keys.
///
/// The derived ordering follows the declaration order of the variants, which
/// is also the order in which `flags_to_string` emits them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum KeyFlags {
    /// Rendered as `decrypt`.
    Decrypt,
    /// Rendered as `exportable`.
    Exportable,
    /// Rendered as `noda`.
    NoDA,
    /// Rendered as `restricted`.
    Restricted,
    /// Rendered as `sign`.
    Sign,
    /// Rendered as `system`.
    System,
    /// Rendered as `user`.
    User,
    /// Carries a non-zero 32-bit handle, rendered as `0x` followed by eight
    /// uppercase hexadecimal digits.
    Persistent(NonZeroU32),
}
impl Flags<Self> for KeyFlags {
    /// Returns the lowercase keyword of the flag; a persistent handle is
    /// rendered as `0x` followed by eight uppercase hexadecimal digits.
    fn as_string(&self) -> Cow<'static, str> {
        let keyword = match *self {
            // Only the handle variant needs an allocation.
            Self::Persistent(handle) => return Cow::Owned(format!("0x{:08X}", handle)),
            Self::Decrypt => "decrypt",
            Self::Exportable => "exportable",
            Self::NoDA => "noda",
            Self::Restricted => "restricted",
            Self::Sign => "sign",
            Self::System => "system",
            Self::User => "user",
        };
        Cow::Borrowed(keyword)
    }

    /// Returns a distinct single-bit value per variant.
    ///
    /// All `Persistent` handles share one ordinal regardless of the handle
    /// value, so two persistent handles in one list count as duplicates.
    fn ordinal(&self) -> usize {
        match *self {
            Self::Decrypt => 0b0000_0001,
            Self::Exportable => 0b0000_0010,
            Self::NoDA => 0b0000_0100,
            Self::Restricted => 0b0000_1000,
            Self::Sign => 0b0001_0000,
            Self::System => 0b0010_0000,
            Self::User => 0b0100_0000,
            Self::Persistent(_) => 0b1000_0000,
        }
    }

    /// Accepts every combination of key flags.
    fn validate(_list: &[Self]) -> bool {
        true
    }
}
/// Flags for NV indices.
///
/// The derived ordering follows the declaration order of the variants, which
/// is also the order in which `flags_to_string` emits them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum NvFlags {
    /// Rendered as `bitfield`; mutually exclusive with `Counter` and `PCR`.
    BitField,
    /// Rendered as `counter`; mutually exclusive with `BitField` and `PCR`.
    Counter,
    /// Rendered as `pcr`; mutually exclusive with `BitField` and `Counter`.
    PCR,
    /// Rendered as `noda`.
    NoDA,
    /// Rendered as `system`.
    System,
    /// Carries a non-zero 32-bit handle, rendered as `0x` followed by eight
    /// uppercase hexadecimal digits.
    Index(NonZeroU32),
}
impl Flags<Self> for NvFlags {
    /// Returns the lowercase keyword of the flag; an index handle is rendered
    /// as `0x` followed by eight uppercase hexadecimal digits.
    fn as_string(&self) -> Cow<'static, str> {
        let keyword = match *self {
            // Only the handle variant needs an allocation.
            Self::Index(handle) => return Cow::Owned(format!("0x{:08X}", handle)),
            Self::BitField => "bitfield",
            Self::Counter => "counter",
            Self::NoDA => "noda",
            Self::PCR => "pcr",
            Self::System => "system",
        };
        Cow::Borrowed(keyword)
    }

    /// Returns a distinct single-bit value per variant.
    ///
    /// All `Index` handles share one ordinal regardless of the handle value.
    fn ordinal(&self) -> usize {
        match *self {
            Self::BitField => 0b00_0001,
            Self::Counter => 0b00_0010,
            Self::NoDA => 0b00_0100,
            Self::PCR => 0b00_1000,
            Self::System => 0b01_0000,
            Self::Index(_) => 0b10_0000,
        }
    }

    /// Accepts a list only if at most one of its entries is `BitField`,
    /// `Counter` or `PCR`.
    fn validate(list: &[Self]) -> bool {
        let exclusive_kinds = list
            .iter()
            .filter(|&flag| matches!(flag, Self::BitField | Self::Counter | Self::PCR))
            .count();
        exclusive_kinds <= 1
    }
}
/// Flags for sealed objects.
///
/// The derived ordering follows the declaration order of the variants, which
/// is also the order in which `flags_to_string` emits them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SealFlags {
    /// Rendered as `noda`.
    NoDA,
    /// Rendered as `system`.
    System,
    /// Carries a non-zero 32-bit handle, rendered as `0x` followed by eight
    /// uppercase hexadecimal digits.
    Index(NonZeroU32),
}
impl Flags<Self> for SealFlags {
    /// Returns the lowercase keyword of the flag; an index handle is rendered
    /// as `0x` followed by eight uppercase hexadecimal digits.
    fn as_string(&self) -> Cow<'static, str> {
        let keyword = match *self {
            // Only the handle variant needs an allocation.
            Self::Index(handle) => return Cow::Owned(format!("0x{:08X}", handle)),
            Self::NoDA => "noda",
            Self::System => "system",
        };
        Cow::Borrowed(keyword)
    }

    /// Returns a distinct single-bit value per variant.
    ///
    /// All `Index` handles share one ordinal regardless of the handle value.
    fn ordinal(&self) -> usize {
        match *self {
            Self::NoDA => 0b001,
            Self::System => 0b010,
            Self::Index(_) => 0b100,
        }
    }

    /// Accepts every combination of seal flags.
    fn validate(_list: &[Self]) -> bool {
        true
    }
}
/// Flags for quotes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QuoteFlags {
    /// Rendered as `TPM-Quote`.
    TpmQuote,
}
impl Flags<Self> for QuoteFlags {
    /// Returns the textual form of the flag.
    fn as_string(&self) -> Cow<'static, str> {
        let keyword = match *self {
            Self::TpmQuote => "TPM-Quote",
        };
        Cow::Borrowed(keyword)
    }

    /// Returns the single-bit ordinal of the flag.
    fn ordinal(&self) -> usize {
        match *self {
            Self::TpmQuote => 0b1,
        }
    }

    /// Accepts every combination of quote flags.
    fn validate(_list: &[Self]) -> bool {
        true
    }
}
/// Flags selecting a padding scheme.
///
/// `validate` accepts at most one padding flag per list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum PaddingFlags {
    /// Rendered as `RSA_SSA`.
    RsaSsa,
    /// Rendered as `RSA_PSS`.
    RsaPss,
}
impl Flags<Self> for PaddingFlags {
    /// Returns the textual form of the padding scheme.
    fn as_string(&self) -> Cow<'static, str> {
        let keyword = match *self {
            Self::RsaSsa => "RSA_SSA",
            Self::RsaPss => "RSA_PSS",
        };
        Cow::Borrowed(keyword)
    }

    /// Returns a distinct single-bit value per variant.
    fn ordinal(&self) -> usize {
        match *self {
            Self::RsaSsa => 0b01,
            Self::RsaPss => 0b10,
        }
    }

    /// Accepts a list only if it names at most one padding scheme.
    fn validate(list: &[Self]) -> bool {
        list.len() <= 1
    }
}
/// Kind of a blob.
///
/// Can be obtained from a raw `u8` through the `TryFrom<u8>` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum BlobType {
    /// Corresponds to the raw value `1`.
    ContextLoad,
    /// Corresponds to the raw value `2`.
    Deserialize,
}
/// Error returned when a raw value does not correspond to any known flag.
///
/// Produced by `BlobType::try_from` for every `u8` other than `1` and `2`.
/// Implements [`std::error::Error`] so that it can be propagated with `?`
/// into `Box<dyn Error>` and similar error containers.
#[derive(Clone, Copy, Debug)]
pub struct UnknownFlagError;

impl std::fmt::Display for UnknownFlagError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("unknown flag value")
    }
}

impl std::error::Error for UnknownFlagError {}
impl TryFrom<u8> for BlobType {
    type Error = UnknownFlagError;

    /// Maps `1` to [`BlobType::ContextLoad`] and `2` to
    /// [`BlobType::Deserialize`].
    ///
    /// # Errors
    ///
    /// Returns [`UnknownFlagError`] for any other value.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let blob_type = match value {
            1 => Self::ContextLoad,
            2 => Self::Deserialize,
            _ => return Err(UnknownFlagError),
        };
        Ok(blob_type)
    }
}
pub(crate) fn flags_to_string<T: Flags<T> + Ord + Debug + Copy>(list: Option<&[T]>) -> Result<Option<String>, ErrorCode> {
match list {
Some(flags) => {
if flags.is_empty() || contains_duplicates(flags) || (!T::validate(flags)) {
Err(crate::ErrorCode::InternalError(InternalError::InvalidArguments))
} else {
Ok(Some(BTreeSet::from_iter(flags).into_iter().map(T::as_string).collect::<Vec<Cow<'static, str>>>().join(",")))
}
}
None => Ok(None),
}
}
/// Reports whether any two entries of `list` share the same ordinal.
///
/// Entries are compared by `Flags::ordinal` only, so flags that differ merely
/// in their payload still count as duplicates when their ordinals are equal.
fn contains_duplicates<T: Flags<T> + Ord + Copy>(list: &[T]) -> bool {
    list.iter().enumerate().any(|(position, earlier)| {
        // Compare each entry only against the entries that follow it.
        list.iter()
            .skip(position + 1)
            .any(|later| earlier.ordinal() == later.ordinal())
    })
}
#[cfg(test)]
mod tests {
    use super::{BlobType, KeyFlags, NvFlags, PaddingFlags, QuoteFlags, SealFlags, flags_to_string};
    use std::num::NonZeroU32;

    /// Checks which flag lists `flags_to_string` accepts and which it rejects.
    #[test]
    fn test_flags_to_string() {
        let handle = NonZeroU32::new(1u32).unwrap();

        // An absent list is accepted for every flag type.
        assert!(flags_to_string::<KeyFlags>(None).is_ok());
        assert!(flags_to_string::<NvFlags>(None).is_ok());
        assert!(flags_to_string::<SealFlags>(None).is_ok());
        assert!(flags_to_string::<QuoteFlags>(None).is_ok());
        assert!(flags_to_string::<PaddingFlags>(None).is_ok());

        // Key flags: all distinct flags together are fine, a repeated flag is not.
        let every_key_flag = [
            KeyFlags::Decrypt,
            KeyFlags::Exportable,
            KeyFlags::NoDA,
            KeyFlags::Persistent(handle),
            KeyFlags::Restricted,
            KeyFlags::Sign,
            KeyFlags::System,
            KeyFlags::User,
        ];
        assert!(flags_to_string(Some(&every_key_flag[..])).is_ok());
        let repeated_key_flag = [KeyFlags::Decrypt, KeyFlags::Decrypt];
        assert!(flags_to_string(Some(&repeated_key_flag[..])).is_err());

        // NV flags: at most one of BitField / Counter / PCR, and no repeats.
        let nv_accepted = [NvFlags::BitField, NvFlags::Index(handle), NvFlags::NoDA, NvFlags::System];
        assert!(flags_to_string(Some(&nv_accepted[..])).is_ok());
        let nv_rejected = [
            [NvFlags::BitField, NvFlags::Counter],
            [NvFlags::BitField, NvFlags::PCR],
            [NvFlags::Counter, NvFlags::PCR],
            [NvFlags::BitField, NvFlags::BitField],
        ];
        for pair in nv_rejected.iter() {
            assert!(flags_to_string(Some(&pair[..])).is_err(), "{:?} should be rejected", pair);
        }

        // Seal flags: distinct flags are fine, a repeated flag is not.
        let seal_accepted = [SealFlags::NoDA, SealFlags::Index(handle), SealFlags::System];
        assert!(flags_to_string(Some(&seal_accepted[..])).is_ok());
        let seal_rejected = [SealFlags::NoDA, SealFlags::NoDA];
        assert!(flags_to_string(Some(&seal_rejected[..])).is_err());

        // Quote flags: a single flag is fine, the same flag twice is not.
        let single_quote_flag = [QuoteFlags::TpmQuote];
        assert!(flags_to_string(Some(&single_quote_flag[..])).is_ok());
        let repeated_quote_flag = [QuoteFlags::TpmQuote, QuoteFlags::TpmQuote];
        assert!(flags_to_string(Some(&repeated_quote_flag[..])).is_err());

        // Padding flags: exactly one scheme may be named.
        for scheme in [PaddingFlags::RsaPss, PaddingFlags::RsaSsa].iter() {
            let single = [*scheme];
            assert!(flags_to_string(Some(&single[..])).is_ok(), "{:?} should be accepted", scheme);
        }
        let padding_rejected = [
            [PaddingFlags::RsaPss, PaddingFlags::RsaPss],
            [PaddingFlags::RsaPss, PaddingFlags::RsaSsa],
        ];
        for pair in padding_rejected.iter() {
            assert!(flags_to_string(Some(&pair[..])).is_err(), "{:?} should be rejected", pair);
        }
    }

    /// Checks which raw values convert into a `BlobType`.
    #[test]
    fn test_blob_types() {
        for raw in 1u8..=2u8 {
            assert!(BlobType::try_from(raw).is_ok(), "{} should be a known blob type", raw);
        }
        assert!(BlobType::try_from(3u8).is_err());
    }
}