use std::{borrow::Cow, error::Error as StdError, fmt, io, mem, str::FromStr};
use lexe_byte_array::ByteArray;
use lexe_hex::hex;
use lexe_serde::hexstr_or_bytes;
use lexe_sha256::sha256;
#[cfg(any(test, feature = "test-utils"))]
use proptest_derive::Arbitrary;
use ref_cast::RefCast;
use serde::{Deserialize, Serialize};
/// A 32-byte SGX measurement digest.
///
/// The same type holds both enclave and signer measurements (see the
/// `MOCK_ENCLAVE` / `*_SIGNER` constants). Serde goes through
/// `hexstr_or_bytes` (presumably a hex string for human-readable formats and
/// raw bytes otherwise — confirm).
#[cfg_attr(any(test, feature = "test-utils"), derive(Arbitrary))]
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(RefCast, Serialize, Deserialize)]
#[repr(transparent)]
pub struct Measurement(#[serde(with = "hexstr_or_bytes")] [u8; 32]);
/// The first four bytes of a [`Measurement`], used as a short identifier.
#[cfg_attr(any(test, feature = "test-utils"), derive(Arbitrary))]
#[derive(Copy, Clone, PartialEq, Eq, Hash, RefCast, Serialize, Deserialize)]
#[repr(transparent)]
pub struct MrShort(#[serde(with = "hexstr_or_bytes")] [u8; 4]);
/// Errors produced by the enclave helpers in this module.
///
/// `Display` renders each variant as a short message; `Debug` wraps that
/// message as `enclave::Error(...)`.
pub enum Error {
    /// A raw SGX error code (from `sgx_isa`).
    SgxError(sgx_isa::ErrorCode),
    /// Sealing: the input data is too large.
    SealInputTooLarge,
    /// Unsealing: the ciphertext is too small.
    UnsealInputTooSmall,
    /// The key request does not have a valid length.
    InvalidKeyRequestLength,
    /// Unsealing failed; the ciphertext or metadata may be corrupted.
    UnsealDecryptionError,
    /// A serialized value (e.g. a `Sealed` blob) is malformed.
    DeserializationError,
}
/// A 16-byte machine identifier.
///
/// Serialized through `hexstr_or_bytes`, like the other byte-array newtypes
/// in this module.
#[cfg_attr(any(test, feature = "test-utils"), derive(Arbitrary))]
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(RefCast, Serialize, Deserialize)]
#[repr(transparent)]
pub struct MachineId(#[serde(with = "hexstr_or_bytes")] [u8; 16]);
/// A 16-byte CPUSVN value used as a floor (presumably the minimum platform
/// security version Lexe accepts — confirm).
#[cfg_attr(any(test, feature = "test-utils"), derive(Arbitrary))]
#[derive(Copy, Clone, PartialEq, Eq, Hash, RefCast, Serialize, Deserialize)]
#[repr(transparent)]
pub struct MinCpusvn(#[serde(with = "hexstr_or_bytes")] [u8; 16]);
/// A sealed blob: a key request plus the encrypted payload.
///
/// Both fields are `Cow` so a value can either borrow from an input buffer
/// (as `Sealed::deserialize` does) or own its bytes.
#[derive(PartialEq, Eq)]
pub struct Sealed<'a> {
    /// Key request bytes (presumably a serialized SGX `Keyrequest` needed to
    /// re-derive the sealing key — confirm).
    pub keyrequest: Cow<'a, [u8]>,
    /// The encrypted payload.
    pub ciphertext: Cow<'a, [u8]>,
}
impl StdError for Error {}
impl fmt::Display for Error {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            // The only variant with a payload: include the code's `Debug`.
            Self::SgxError(code) => return write!(f, "SGX error: {code:?}"),
            Self::SealInputTooLarge => "sealing: input data is too large",
            Self::UnsealInputTooSmall => "unsealing: ciphertext is too small",
            Self::InvalidKeyRequestLength => "keyrequest is not a valid length",
            Self::UnsealDecryptionError =>
                "unseal error: ciphertext or metadata may be corrupted",
            Self::DeserializationError => "deserialize: input is malformed",
        };
        f.write_str(msg)
    }
}
impl fmt::Debug for Error {
    /// Wraps the `Display` message as `enclave::Error(<message>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("enclave::Error(")?;
        fmt::Display::fmt(self, f)?;
        f.write_str(")")
    }
}
impl From<sgx_isa::ErrorCode> for Error {
fn from(err: sgx_isa::ErrorCode) -> Self {
Self::SgxError(err)
}
}
impl Measurement {
    /// Placeholder enclave measurement for mock (non-SGX) builds; the bytes
    /// are the ASCII text `~~~~~~~ LEXE MOCK ENCLAVE ~~~~~~`.
    pub const MOCK_ENCLAVE: Self =
        Self::new(*b"~~~~~~~ LEXE MOCK ENCLAVE ~~~~~~");
    /// Placeholder signer measurement for mock (non-SGX) builds; the bytes
    /// are the ASCII text `======= LEXE MOCK SIGNER =======`.
    pub const MOCK_SIGNER: Self =
        Self::new(*b"======= LEXE MOCK SIGNER =======");
    /// Signer measurement expected for dev SGX builds (presumably the
    /// MRSIGNER of Lexe's dev signing key — confirm).
    pub const DEV_SIGNER: Self = Self::new(hex::decode_const(
        b"9affcfae47b848ec2caf1c49b4b283531e1cc425f93582b36806e52a43d78d1a",
    ));
    /// Signer measurement expected for prod SGX builds (presumably the
    /// MRSIGNER of Lexe's prod signing key — confirm).
    pub const PROD_SIGNER: Self = Self::new(hex::decode_const(
        b"02d07f56b7f4a71d32211d6821beaeb316fbf577d02bab0dfe1f18a73de08a8e",
    ));

    /// Selects the signer measurement to expect for a build configuration.
    ///
    /// - `use_sgx == false` -> [`Self::MOCK_SIGNER`] (`is_dev` is ignored)
    /// - `use_sgx && is_dev` -> [`Self::DEV_SIGNER`]
    /// - `use_sgx && !is_dev` -> [`Self::PROD_SIGNER`]
    pub const fn expected_signer(use_sgx: bool, is_dev: bool) -> Self {
        if use_sgx {
            if is_dev {
                Self::DEV_SIGNER
            } else {
                Self::PROD_SIGNER
            }
        } else {
            Self::MOCK_SIGNER
        }
    }

    /// Computes the SHA-256 digest of every byte `sgxs_reader` yields until
    /// EOF and returns it as a `Measurement`. For an SGXS stream this is
    /// intended to be the enclave's MRENCLAVE (confirm against the SGXS spec).
    ///
    /// # Errors
    ///
    /// Returns any I/O error from `sgxs_reader`, except
    /// [`io::ErrorKind::Interrupted`], which is retried.
    pub fn compute_from_sgxs(
        mut sgxs_reader: impl io::Read,
    ) -> io::Result<Self> {
        let mut buf = [0u8; 4096];
        let mut digest = sha256::Context::new();
        loop {
            let n = match sgxs_reader.read(&mut buf) {
                Ok(n) => n,
                // Per the `io::Read` contract, `Interrupted` is non-fatal and
                // the read should simply be retried (as `read_to_end` and
                // `io::copy` do), rather than aborting the whole hash.
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            };
            if n == 0 {
                let hash = digest.finish();
                return Ok(Self::new(hash.to_array()));
            }
            digest.update(&buf[..n]);
        }
    }

    /// Wraps 32 raw bytes as a `Measurement`.
    pub const fn new(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }

    /// Returns the first four bytes of this measurement as an [`MrShort`].
    pub fn short(&self) -> MrShort {
        MrShort::from(self)
    }
}
lexe_byte_array::impl_byte_array!(Measurement, 32);
lexe_byte_array::impl_fromstr_fromhex!(Measurement, 32);
lexe_byte_array::impl_debug_display_as_hex!(Measurement);
impl MrShort {
    /// Wraps four raw bytes as an `MrShort`.
    pub const fn new(bytes: [u8; 4]) -> Self {
        Self(bytes)
    }

    /// Returns `true` iff these four bytes are exactly the leading four
    /// bytes of `long`.
    pub fn is_prefix_of(&self, long: &Measurement) -> bool {
        long.0.starts_with(&self.0)
    }
}
// Byte-array, hex `FromStr`, and hex `Debug`/`Display` impls come from the
// `lexe_byte_array` helper macros.
lexe_byte_array::impl_byte_array!(MrShort, 4);
lexe_byte_array::impl_fromstr_fromhex!(MrShort, 4);
lexe_byte_array::impl_debug_display_as_hex!(MrShort);
impl From<&Measurement> for MrShort {
    /// Truncates a full measurement to its leading four bytes.
    fn from(long: &Measurement) -> Self {
        // Irrefutable array pattern: no slicing and no fallible conversion.
        let [b0, b1, b2, b3, ..] = long.0;
        Self([b0, b1, b2, b3])
    }
}
impl MachineId {
    /// Fixed placeholder machine id (presumably for mock / non-SGX
    /// environments — confirm).
    pub const MOCK: Self =
        Self::new(hex::decode_const(b"52bc575eb9618084083ca7b3a45a2a76"));

    /// Wraps sixteen raw bytes as a `MachineId`.
    pub const fn new(bytes: [u8; 16]) -> Self {
        Self(bytes)
    }
}
// Byte-array, hex `FromStr`, and hex `Debug`/`Display` impls come from the
// `lexe_byte_array` helper macros.
lexe_byte_array::impl_byte_array!(MachineId, 16);
lexe_byte_array::impl_fromstr_fromhex!(MachineId, 16);
lexe_byte_array::impl_debug_display_as_hex!(MachineId);
impl MinCpusvn {
pub const CURRENT: Self =
Self::new(hex::decode_const(b"0e0e100fffff01000000000000000000"));
pub const fn new(bytes: [u8; 16]) -> Self {
Self(bytes)
}
pub const fn inner(self) -> [u8; 16] {
self.0
}
}
/// Exposes `MinCpusvn` as a plain 16-byte array.
impl ByteArray<16> for MinCpusvn {
    fn from_array(bytes: [u8; 16]) -> Self {
        Self(bytes)
    }

    fn to_array(&self) -> [u8; 16] {
        *self.as_array()
    }

    fn as_array(&self) -> &[u8; 16] {
        let Self(bytes) = self;
        bytes
    }
}
impl FromStr for MinCpusvn {
type Err = hex::DecodeError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::from_hex(s)
}
}
impl fmt::Display for MinCpusvn {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Self::fmt_as_hex(self, f)
}
}
impl fmt::Debug for MinCpusvn {
    /// Renders as `MinCpusvn(<hex>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shown = self.as_hex_display();
        f.debug_tuple("MinCpusvn").field(&shown).finish()
    }
}
impl<'a> Sealed<'a> {
pub const TAG_LEN: usize = 16;
pub fn serialize(&self) -> Vec<u8> {
let out_len = mem::size_of::<u32>()
+ self.keyrequest.len()
+ mem::size_of::<u32>()
+ self.ciphertext.len();
let mut out = Vec::with_capacity(out_len);
out.extend_from_slice(&(self.keyrequest.len() as u32).to_le_bytes());
out.extend_from_slice(self.keyrequest.as_ref());
out.extend_from_slice(&(self.ciphertext.len() as u32).to_le_bytes());
out.extend_from_slice(self.ciphertext.as_ref());
out
}
pub fn deserialize(bytes: &'a [u8]) -> Result<Self, Error> {
let (keyrequest, bytes) = Self::read_bytes(bytes)?;
let (ciphertext, bytes) = Self::read_bytes(bytes)?;
if bytes.is_empty() {
Ok(Self {
keyrequest: Cow::Borrowed(keyrequest),
ciphertext: Cow::Borrowed(ciphertext),
})
} else {
Err(Error::DeserializationError)
}
}
fn read_bytes(bytes: &[u8]) -> Result<(&[u8], &[u8]), Error> {
let (len, bytes) = Self::read_u32_le(bytes)?;
let len = len as usize;
if bytes.len() >= len {
Ok(bytes.split_at(len))
} else {
Err(Error::DeserializationError)
}
}
fn read_u32_le(bytes: &[u8]) -> Result<(u32, &[u8]), Error> {
match bytes.split_first_chunk::<4>() {
Some((val, rest)) => Ok((u32::from_le_bytes(*val), rest)),
None => Err(Error::DeserializationError),
}
}
}
impl fmt::Debug for Sealed<'_> {
    /// Shows both fields through `hex::display` rather than as byte lists.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyrequest = hex::display(&self.keyrequest);
        let ciphertext = hex::display(&self.ciphertext);
        f.debug_struct("Sealed")
            .field("keyrequest", &keyrequest)
            .field("ciphertext", &ciphertext)
            .finish()
    }
}
/// SGX `ATTRIBUTES` flag sets used by Lexe enclaves.
pub mod attributes {
    use sgx_isa::AttributesFlags;

    /// Flags for production enclaves: only `MODE64BIT`.
    pub const LEXE_FLAGS_PROD: AttributesFlags = AttributesFlags::MODE64BIT;
    /// Flags for debug enclaves: the production flags plus `DEBUG`.
    pub const LEXE_FLAGS_DEBUG: AttributesFlags =
        AttributesFlags::DEBUG.union(LEXE_FLAGS_PROD);
    /// Mask covering `INIT`, `DEBUG`, and `MODE64BIT` (presumably the bits
    /// that must match when attributes are compared or used for key
    /// derivation — confirm).
    pub const LEXE_MASK: AttributesFlags =
        LEXE_FLAGS_DEBUG.union(AttributesFlags::INIT);
}
/// XFRM (extended-feature request mask) bit sets for SGX enclaves.
///
/// Bit positions follow the XCR0 state-component layout from the Intel SDM.
pub mod xfrm {
    /// x87 (bit 0) and SSE (bit 1) state.
    pub const LEGACY: u64 = 1 | (1 << 1);
    /// SSE (bit 1) and AVX (bit 2) state.
    pub const AVX: u64 = (1 << 1) | (1 << 2);
    /// AVX state plus AVX-512 opmask (bit 5), ZMM_Hi256 (bit 6) and
    /// Hi16_ZMM (bit 7).
    pub const AVX512: u64 = AVX | (1 << 5) | (1 << 6) | (1 << 7);
    /// MPX BNDREGS (bit 3) and BNDCSR (bit 4).
    pub const MPX: u64 = (1 << 3) | (1 << 4);
    /// PKRU state (bit 9).
    pub const PKRU: u64 = 1 << 9;
    /// AMX XTILECFG (bit 17) and XTILEDATA (bit 18).
    pub const AMX: u64 = (1 << 17) | (1 << 18);
    /// Features Lexe enclaves enable: legacy state plus AVX-512 (`0xe7`).
    pub const LEXE_FLAGS: u64 = LEGACY | AVX512;
    /// The XFRM mask is exactly the enabled feature set.
    pub const LEXE_MASK: u64 = LEXE_FLAGS;
}
/// SGX `MISCSELECT` settings for Lexe enclaves: nothing is selected.
pub mod miscselect {
    use sgx_isa::Miscselect;

    /// No MISCSELECT bits are set.
    pub const LEXE_FLAGS: Miscselect = Miscselect::empty();
    /// The mask equals the flags, i.e. it is also empty (mirrors how
    /// `xfrm::LEXE_MASK` is defined).
    pub const LEXE_MASK: Miscselect = LEXE_FLAGS;
}
#[cfg(test)]
mod test {
    use proptest::{
        arbitrary::any, prop_assert, prop_assert_eq, prop_assert_ne, proptest,
        strategy::Strategy,
    };

    use super::*;

    /// Arbitrary owned `Sealed` values with random-length fields.
    fn arb_sealed() -> impl Strategy<Value = Sealed<'static>> {
        (any::<Vec<u8>>(), any::<Vec<u8>>()).prop_map(
            |(keyrequest, ciphertext)| Sealed {
                keyrequest: keyrequest.into(),
                ciphertext: ciphertext.into(),
            },
        )
    }

    #[test]
    fn test_mr_short() {
        proptest!(|(
            long1 in any::<Measurement>(),
            long2 in any::<Measurement>(),
        )| {
            let short1 = long1.short();
            let short2 = long2.short();

            prop_assert!(short1.is_prefix_of(&long1));
            prop_assert!(short2.is_prefix_of(&long2));

            // A short id matches another measurement exactly when the two
            // short ids are equal.
            prop_assert_eq!(short1.is_prefix_of(&long2), short1 == short2);
            prop_assert_eq!(short2.is_prefix_of(&long1), short1 == short2);

            if short1 != short2 {
                prop_assert_ne!(long1, long2);
                prop_assert!(!short1.is_prefix_of(&long2));
                prop_assert!(!short2.is_prefix_of(&long1));
            }
        });
    }

    #[test]
    fn test_sealed_serialization() {
        proptest!(|(sealed in arb_sealed())| {
            let bytes = sealed.serialize();
            let sealed2 = Sealed::deserialize(&bytes).unwrap();
            prop_assert_eq!(sealed, sealed2);
        });
    }

    /// `deserialize` must reject truncated input and trailing bytes.
    #[test]
    fn test_sealed_deserialize_rejects_malformed() {
        proptest!(|(sealed in arb_sealed(), trailing in any::<u8>())| {
            let bytes = sealed.serialize();

            // Every strict prefix ends inside a length prefix or a field
            // body, so none of them may parse.
            for len in 0..bytes.len() {
                prop_assert!(Sealed::deserialize(&bytes[..len]).is_err());
            }

            // Any byte after the ciphertext must be rejected.
            let mut padded = bytes;
            padded.push(trailing);
            prop_assert!(Sealed::deserialize(&padded).is_err());
        });
    }
}