use core::ptr::write_volatile;
use core::sync::atomic;
use serde::{de::Visitor, ser::SerializeMap, Deserialize, Serialize};
use trussed_core::{types::Bytes, Error};
use zeroize::Zeroize;
use crate::config::{MAX_KEY_MATERIAL_LENGTH, MAX_SERIALIZED_KEY_LENGTH};
pub type Material = Bytes<MAX_KEY_MATERIAL_LENGTH>;
pub type SerializedKeyBytes = Bytes<MAX_SERIALIZED_KEY_LENGTH>;
/// Key material together with its [`Flags`] and [`Kind`].
///
/// [`Key::serialize`] encodes a key as
/// `flags (u16, big-endian) || kind code (u16, big-endian) || material`, and
/// [`Key::try_deserialize`] parses that encoding back. Deriving `Zeroize` allows all fields,
/// including the material, to be wiped in place.
#[derive(Clone, Debug, /*DeserializeIndexed,*/ Eq, PartialEq, /*SerializeIndexed,*/ Zeroize)]
pub struct Key {
/// Attribute bits of the key; see [`Flags`].
pub flags: Flags,
/// Key type; its [`Kind::code`] is written into the serialized header.
pub kind: Kind,
/// Raw key bytes, bounded by `MAX_KEY_MATERIAL_LENGTH`.
pub material: Material,
}
/// Metadata of a key without its material: the same `flags` and `kind` fields as [`Key`].
#[derive(Clone, Debug, /*DeserializeIndexed,*/ Eq, PartialEq, /*SerializeIndexed,*/ Zeroize)]
pub struct Info {
/// Attribute bits; `From<Kind>` fills these with `Flags::default()`.
pub flags: Flags,
/// Key type.
pub kind: Kind,
}
impl Info {
    /// Builder-style helper: returns this `Info` with [`Flags::LOCAL`] added to whatever
    /// flags were already set. All other fields are carried over unchanged.
    pub fn with_local_flag(self) -> Self {
        Self {
            flags: self.flags | Flags::LOCAL,
            ..self
        }
    }
}
impl From<Kind> for Info {
fn from(kind: Kind) -> Self {
Self {
flags: Default::default(),
kind,
}
}
}
/// Type of a key.
///
/// The numeric codes used in the binary key format come from [`Kind::code`], not from the
/// `#[repr(u16)]` discriminant. The `usize` payload of `Shared`, `Symmetric` and
/// `Symmetric32Nonce` holds a length in bytes, which [`Kind::try_from`] reconstructs from the
/// length of the serialized material.
///
/// NOTE: the derived serde impls encode variants by their declaration index, so reordering or
/// inserting variants changes that encoding. Append new variants at the end.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Zeroize)]
#[repr(u16)]
pub enum Kind {
Shared(usize),
Symmetric(usize),
Symmetric32Nonce(usize),
Rsa2048,
Rsa3072,
Rsa4096,
Ed255,
P256,
P384,
P521,
BrainpoolP256R1,
BrainpoolP384R1,
BrainpoolP512R1,
X255,
Secp256k1,
}
bitflags::bitflags! {
/// Attribute bits attached to a key.
///
/// Only the bits defined below are valid. Both `Key::try_deserialize` and the serde
/// `Deserialize` impl reject any other bit via `Flags::from_bits`. Bits 2 and 3 are
/// currently unassigned. `Flags::default()` is `SENSITIVE` alone.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub struct Flags: u16 {
/// Set by `Info::with_local_flag`. NOTE(review): presumably marks locally generated keys — confirm.
const LOCAL = 1 << 0;
/// Part of the default flag set.
const SENSITIVE = 1 << 1;
/// NOTE(review): presumably allows the key to be exported/serialized — confirm with callers.
const SERIALIZABLE = 1 << 4;
}
}
/// Distinguishes public from secret key data.
///
/// NOTE(review): not referenced elsewhere in this file; its exact use is defined by callers.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Secrecy {
Public,
Secret,
}
impl Key {
pub fn serialize(&self) -> SerializedKeyBytes {
let mut buffer = SerializedKeyBytes::new();
buffer
.extend_from_slice(&self.flags.bits().to_be_bytes())
.unwrap();
buffer
.extend_from_slice(&(self.kind.code()).to_be_bytes())
.unwrap();
buffer.extend_from_slice(&self.material).unwrap();
buffer
}
pub fn try_deserialize(bytes: &[u8]) -> Result<Self, Error> {
if bytes.len() < 4 {
return Err(Error::InvalidSerializedKey);
}
let (info, material) = bytes.split_at(4);
let flags_bits = u16::from_be_bytes([info[0], info[1]]);
let flags = Flags::from_bits(flags_bits).ok_or(Error::InvalidSerializedKey)?;
let kind_bits = u16::from_be_bytes([info[2], info[3]]);
let kind =
Kind::try_from(kind_bits, material.len()).map_err(|_| Error::InvalidSerializedKey)?;
Ok(Key {
flags,
kind,
material: Material::try_from(material).map_err(|_| Error::InvalidSerializedKey)?,
})
}
}
impl Default for Flags {
fn default() -> Self {
Flags::SENSITIVE
}
}
impl Zeroize for Flags {
    /// Clears every flag bit.
    ///
    /// The volatile write and the compiler fence keep the compiler from eliding the store as
    /// dead.
    fn zeroize(&mut self) {
        let target: *mut Flags = self;
        // SAFETY: `target` comes from a live `&mut Flags`, so it is non-null, aligned and valid
        // for writes. `Flags` is `Copy`, so overwriting it without dropping the old value
        // leaks nothing.
        unsafe { write_volatile(target, Flags::empty()) };
        atomic::compiler_fence(atomic::Ordering::SeqCst);
    }
}
impl Serialize for Flags {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = serializer.serialize_map(Some(1))?;
map.serialize_key(&0usize)?;
map.serialize_value(&self.bits())?;
map.end()
}
}
impl<'de> Deserialize<'de> for Flags {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct FlagsVisitor;
impl<'vis_de> Visitor<'vis_de> for FlagsVisitor {
type Value = Flags;
fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(formatter, "A flag structure")
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'vis_de>,
{
if !matches!(map.next_key()?, Some(0usize)) {
return Err(serde::de::Error::missing_field("bits"));
}
let bits = map.next_value()?;
let flags = Flags::from_bits(bits)
.ok_or_else(|| serde::de::Error::custom("Wrong bit layout"))?;
Ok(flags)
}
}
deserializer.deserialize_map(FlagsVisitor)
}
}
impl Kind {
pub fn code(self) -> u16 {
match self {
Kind::Shared(_) => 1,
Kind::Symmetric(_) => 2,
Kind::Symmetric32Nonce(_) => 3,
Kind::Ed255 => 4,
Kind::P256 => 5,
Kind::X255 => 6,
Kind::Rsa2048 => 7,
Kind::Rsa3072 => 8,
Kind::Rsa4096 => 9,
Kind::P384 => 10,
Kind::P521 => 11,
Kind::BrainpoolP256R1 => 12,
Kind::BrainpoolP384R1 => 13,
Kind::BrainpoolP512R1 => 14,
Kind::Secp256k1 => 15,
}
}
pub fn try_from(code: u16, length: usize) -> Result<Self, Error> {
Ok(match code {
1 => Self::Shared(length),
2 => Self::Symmetric(length),
3 => Self::Symmetric32Nonce(length - 32),
4 => Self::Ed255,
5 => Self::P256,
6 => Self::X255,
7 => Kind::Rsa2048,
8 => Kind::Rsa3072,
9 => Kind::Rsa4096,
10 => Kind::P384,
11 => Kind::P521,
12 => Kind::BrainpoolP256R1,
13 => Kind::BrainpoolP384R1,
14 => Kind::BrainpoolP512R1,
15 => Kind::Secp256k1,
_ => return Err(Error::InvalidSerializedKey),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_test::{assert_tokens, Token};

    /// Every `Flags` value must round-trip through serde as the one-entry map `{0: bits}`.
    #[test]
    fn keyflags_format() {
        let cases = [
            (Flags::empty(), 0),
            (Flags::LOCAL, 0b1),
            (Flags::LOCAL | Flags::SENSITIVE, 0b11),
            (Flags::LOCAL | Flags::SENSITIVE | Flags::SERIALIZABLE, 0b10011),
            (Flags::SENSITIVE, 0b10),
            (Flags::SENSITIVE | Flags::SERIALIZABLE, 0b10010),
            (Flags::SERIALIZABLE, 0b10000),
        ];

        for &(flags, bits) in &cases {
            assert_tokens(
                &flags,
                &[
                    Token::Map { len: Some(1) },
                    Token::U64(0),
                    Token::U16(bits),
                    Token::MapEnd,
                ],
            );
        }
    }
}