use std::{borrow::Cow, error::Error as StdError, fmt, io, mem};
use lexe_byte_array::ByteArray;
use lexe_hex::hex;
use lexe_serde::impl_serde_hexstr_or_bytes;
use lexe_sha256::sha256;
#[cfg(any(test, feature = "test-utils"))]
use proptest_derive::Arbitrary;
use ref_cast::RefCast;
/// A 32-byte SGX measurement.
///
/// Used both for enclave measurements (see `Measurement::compute_from_sgxs`)
/// and for signer measurements (see the `*_SIGNER` constants). These
/// presumably correspond to SGX MRENCLAVE and MRSIGNER respectively — confirm.
///
/// In JSON it serializes as a hex string (see the `serde_roundtrips` test).
/// Per the macro name it may also (de)serialize as raw bytes in non
/// human-readable formats — confirm in `lexe_serde`.
#[cfg_attr(any(test, feature = "test-utils"), derive(Arbitrary))]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd, RefCast)]
// `repr(transparent)` is required by the `RefCast` derive.
#[repr(transparent)]
pub struct Measurement([u8; 32]);
impl_serde_hexstr_or_bytes!(Measurement);
/// The first 4 bytes of a [`Measurement`].
///
/// Build one with `Measurement::short` (or `From<&Measurement>`). Use
/// `MrShort::is_prefix_of` to check it against a full measurement.
#[cfg_attr(any(test, feature = "test-utils"), derive(Arbitrary))]
#[derive(Copy, Clone, Hash, Eq, PartialEq, RefCast)]
// `repr(transparent)` is required by the `RefCast` derive.
#[repr(transparent)]
pub struct MrShort([u8; 4]);
impl_serde_hexstr_or_bytes!(MrShort);
/// Errors from enclave sealing, unsealing, and related (de)serialization.
///
/// `Debug` is implemented by hand in terms of `Display`.
pub enum Error {
    /// An SGX error code (`sgx_isa::ErrorCode`).
    SgxError(sgx_isa::ErrorCode),
    /// Sealing: the input data is too large.
    SealInputTooLarge,
    /// Unsealing: the ciphertext is too small.
    UnsealInputTooSmall,
    /// A keyrequest did not have a valid length.
    InvalidKeyRequestLength,
    /// Unsealing failed: the ciphertext or metadata may be corrupted.
    UnsealDecryptionError,
    /// Deserialization (e.g. `Sealed::deserialize`) found malformed input.
    DeserializationError,
}
/// An opaque 16-byte machine identifier.
///
/// In JSON it serializes as a hex string (see the `serde_roundtrips` test).
/// NOTE(review): how this id is derived (e.g. per-CPU/platform) is not visible
/// here — confirm with the code that constructs it.
#[cfg_attr(any(test, feature = "test-utils"), derive(Arbitrary))]
#[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd, RefCast)]
// `repr(transparent)` is required by the `RefCast` derive.
#[repr(transparent)]
pub struct MachineId([u8; 16]);
impl_serde_hexstr_or_bytes!(MachineId);
/// A sealed blob: a key request plus the encrypted payload.
///
/// `Sealed::serialize` / `Sealed::deserialize` use this wire format:
/// `u32 LE keyrequest len || keyrequest || u32 LE ciphertext len || ciphertext`.
///
/// The fields are `Cow` so that `deserialize` can borrow directly from the
/// input buffer, while other constructors can own their data.
#[derive(Eq, PartialEq)]
pub struct Sealed<'a> {
    /// Serialized key request. Presumably the SGX `KEYREQUEST` used to
    /// re-derive the sealing key on unseal — confirm against the sealing code.
    pub keyrequest: Cow<'a, [u8]>,
    /// The encrypted payload.
    pub ciphertext: Cow<'a, [u8]>,
}
/// A 16-byte CPUSVN (CPU security version number) value.
///
/// NOTE(review): from the name this is a *minimum* CPUSVN, presumably placed
/// in sealing key requests — confirm at the use sites (crate-private).
#[derive(Copy, Clone)]
#[repr(transparent)]
pub(crate) struct MinCpusvn([u8; 16]);
impl StdError for Error {}
/// Human-readable description of each error variant.
///
/// `SgxError` includes the `Debug` form of the wrapped SGX error code. Every
/// other variant maps to a fixed message.
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            // The only variant with a payload; format it directly instead of
            // going through a placeholder message.
            Self::SgxError(err) => return write!(f, "SGX error: {err:?}"),
            Self::SealInputTooLarge => "sealing: input data is too large",
            Self::UnsealInputTooSmall => "unsealing: ciphertext is too small",
            Self::InvalidKeyRequestLength => "keyrequest is not a valid length",
            Self::UnsealDecryptionError =>
                "unseal error: ciphertext or metadata may be corrupted",
            Self::DeserializationError => "deserialize: input is malformed",
        };
        f.write_str(msg)
    }
}
/// `Debug` wraps the `Display` message as `enclave::Error(<message>)`.
impl fmt::Debug for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("enclave::Error(")?;
        write!(f, "{self}")?;
        f.write_str(")")
    }
}
impl From<sgx_isa::ErrorCode> for Error {
fn from(err: sgx_isa::ErrorCode) -> Self {
Self::SgxError(err)
}
}
impl Measurement {
    /// Placeholder enclave measurement. Presumably reported by enclaves
    /// running in mock (non-SGX) mode — confirm at the use sites.
    pub const MOCK_ENCLAVE: Self =
        Self::new(*b"~~~~~~~ LEXE MOCK ENCLAVE ~~~~~~");
    /// Placeholder signer measurement, returned by
    /// [`Self::expected_signer`] when `use_sgx` is false.
    pub const MOCK_SIGNER: Self =
        Self::new(*b"======= LEXE MOCK SIGNER =======");
    /// Expected signer measurement for SGX builds with `is_dev == true`.
    pub const DEV_SIGNER: Self = Self::new(hex::decode_const(
        b"9affcfae47b848ec2caf1c49b4b283531e1cc425f93582b36806e52a43d78d1a",
    ));
    /// Expected signer measurement for SGX builds with `is_dev == false`.
    pub const PROD_SIGNER: Self = Self::new(hex::decode_const(
        b"02d07f56b7f4a71d32211d6821beaeb316fbf577d02bab0dfe1f18a73de08a8e",
    ));

    /// Returns the signer measurement expected for the given environment.
    ///
    /// Without SGX this is [`Self::MOCK_SIGNER`]. With SGX it is
    /// [`Self::DEV_SIGNER`] or [`Self::PROD_SIGNER`], depending on `is_dev`.
    pub const fn expected_signer(use_sgx: bool, is_dev: bool) -> Self {
        if use_sgx {
            if is_dev {
                Self::DEV_SIGNER
            } else {
                Self::PROD_SIGNER
            }
        } else {
            Self::MOCK_SIGNER
        }
    }

    /// Computes the SHA-256 digest of everything read from `sgxs_reader`
    /// until EOF, and returns it as a measurement.
    ///
    /// NOTE: this relies on the SGXS format property that a `.sgxs` file's
    /// SHA-256 equals the enclave's MRENCLAVE — confirm for new formats.
    ///
    /// # Errors
    ///
    /// Returns the first reader error other than
    /// [`io::ErrorKind::Interrupted`]. Interrupted reads are retried, as the
    /// [`io::Read::read`] contract requires.
    pub fn compute_from_sgxs(
        mut sgxs_reader: impl io::Read,
    ) -> io::Result<Self> {
        let mut buf = [0u8; 4096];
        let mut digest = sha256::Context::new();
        loop {
            let n = match sgxs_reader.read(&mut buf) {
                // EOF: the whole stream has been hashed.
                Ok(0) => break,
                Ok(n) => n,
                // A signal interrupted the read before any bytes were
                // consumed; this is non-fatal, so just try again.
                Err(err) if err.kind() == io::ErrorKind::Interrupted =>
                    continue,
                Err(err) => return Err(err),
            };
            digest.update(&buf[..n]);
        }
        let hash = digest.finish();
        Ok(Self::new(hash.to_array()))
    }

    /// Wraps raw measurement bytes.
    pub const fn new(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }

    /// Returns the 4-byte prefix of this measurement.
    pub fn short(&self) -> MrShort {
        MrShort::from(self)
    }
}
// Trait impls for `Measurement` generated by `lexe_byte_array` macros.
// Judging by the macro names (confirm in `lexe_byte_array`): a byte-array
// impl (presumably `ByteArray`), hex `FromStr`, and hex `Debug`/`Display`.
// The hex `FromStr`/`Display` round trip is exercised by the
// `fromstr_display_roundtrips` test.
lexe_byte_array::impl_byte_array!(Measurement, 32);
lexe_byte_array::impl_fromstr_fromhex!(Measurement, 32);
lexe_byte_array::impl_debug_display_as_hex!(Measurement);
impl MrShort {
pub const fn new(bytes: [u8; 4]) -> Self {
Self(bytes)
}
pub fn is_prefix_of(&self, long: &Measurement) -> bool {
self.0 == long.0[..4]
}
}
// Trait impls for `MrShort` generated by `lexe_byte_array` macros. Judging by
// the macro names (confirm in `lexe_byte_array`): a byte-array impl
// (presumably `ByteArray`), hex `FromStr`, and hex `Debug`/`Display`.
lexe_byte_array::impl_byte_array!(MrShort, 4);
lexe_byte_array::impl_fromstr_fromhex!(MrShort, 4);
lexe_byte_array::impl_debug_display_as_hex!(MrShort);
impl From<&Measurement> for MrShort {
fn from(long: &Measurement) -> Self {
(long.0)[..4].try_into().map(Self).unwrap()
}
}
impl MachineId {
pub const MOCK: Self =
MachineId::new(hex::decode_const(b"52bc575eb9618084083ca7b3a45a2a76"));
pub const fn new(bytes: [u8; 16]) -> Self {
Self(bytes)
}
}
// Trait impls for `MachineId` generated by `lexe_byte_array` macros. Judging
// by the macro names (confirm in `lexe_byte_array`): a byte-array impl
// (presumably `ByteArray`), hex `FromStr`, and hex `Debug`/`Display`.
lexe_byte_array::impl_byte_array!(MachineId, 16);
lexe_byte_array::impl_fromstr_fromhex!(MachineId, 16);
lexe_byte_array::impl_debug_display_as_hex!(MachineId);
impl<'a> Sealed<'a> {
pub const TAG_LEN: usize = 16;
pub fn serialize(&self) -> Vec<u8> {
let out_len = mem::size_of::<u32>()
+ self.keyrequest.len()
+ mem::size_of::<u32>()
+ self.ciphertext.len();
let mut out = Vec::with_capacity(out_len);
out.extend_from_slice(&(self.keyrequest.len() as u32).to_le_bytes());
out.extend_from_slice(self.keyrequest.as_ref());
out.extend_from_slice(&(self.ciphertext.len() as u32).to_le_bytes());
out.extend_from_slice(self.ciphertext.as_ref());
out
}
pub fn deserialize(bytes: &'a [u8]) -> Result<Self, Error> {
let (keyrequest, bytes) = Self::read_bytes(bytes)?;
let (ciphertext, bytes) = Self::read_bytes(bytes)?;
if bytes.is_empty() {
Ok(Self {
keyrequest: Cow::Borrowed(keyrequest),
ciphertext: Cow::Borrowed(ciphertext),
})
} else {
Err(Error::DeserializationError)
}
}
fn read_bytes(bytes: &[u8]) -> Result<(&[u8], &[u8]), Error> {
let (len, bytes) = Self::read_u32_le(bytes)?;
let len = len as usize;
if bytes.len() >= len {
Ok(bytes.split_at(len))
} else {
Err(Error::DeserializationError)
}
}
fn read_u32_le(bytes: &[u8]) -> Result<(u32, &[u8]), Error> {
match bytes.split_first_chunk::<4>() {
Some((val, rest)) => Ok((u32::from_le_bytes(*val), rest)),
None => Err(Error::DeserializationError),
}
}
}
/// Debug output shows both byte fields hex-encoded.
impl fmt::Debug for Sealed<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyrequest = hex::display(&self.keyrequest);
        let ciphertext = hex::display(&self.ciphertext);
        let mut out = f.debug_struct("Sealed");
        out.field("keyrequest", &keyrequest);
        out.field("ciphertext", &ciphertext);
        out.finish()
    }
}
impl MinCpusvn {
pub const CURRENT: Self =
Self::new(hex::decode_const(b"0e0e100fffff01000000000000000000"));
pub const fn new(bytes: [u8; 16]) -> Self {
Self(bytes)
}
pub const fn to_array(self) -> [u8; 16] {
self.0
}
}
/// SGX `ATTRIBUTES.FLAGS` values that Lexe enclaves request, plus the mask of
/// flags that are checked.
pub mod attributes {
    use sgx_isa::AttributesFlags as Flags;

    /// Flags for production enclaves: 64-bit mode only.
    pub const LEXE_FLAGS_PROD: Flags = Flags::MODE64BIT;
    /// Flags for debug enclaves: the production flags plus `DEBUG`.
    pub const LEXE_FLAGS_DEBUG: Flags = Flags::DEBUG.union(LEXE_FLAGS_PROD);
    /// The flag bits covered by the mask: `MODE64BIT`, `DEBUG`, and `INIT`.
    pub const LEXE_MASK: Flags = Flags::MODE64BIT
        .union(Flags::DEBUG)
        .union(Flags::INIT);
}
/// SGX XFRM (XSAVE feature request mask) values.
///
/// Each constant is written as the XCR0 state-component bit positions it
/// sets (Intel SDM, XSAVE-managed state components).
pub mod xfrm {
    /// x87 (bit 0) and SSE (bit 1) state.
    pub const LEGACY: u64 = (1 << 0) | (1 << 1);
    /// SSE (bit 1) and AVX (bit 2) state.
    pub const AVX: u64 = (1 << 1) | (1 << 2);
    /// AVX state plus AVX-512 opmask (bit 5), ZMM_Hi256 (bit 6), and
    /// Hi16_ZMM (bit 7).
    pub const AVX512: u64 = AVX | (1 << 5) | (1 << 6) | (1 << 7);
    /// MPX BNDREGS (bit 3) and BNDCSR (bit 4).
    pub const MPX: u64 = (1 << 3) | (1 << 4);
    /// PKRU (bit 9).
    pub const PKRU: u64 = 1 << 9;
    /// AMX XTILECFG (bit 17) and XTILEDATA (bit 18).
    pub const AMX: u64 = (1 << 17) | (1 << 18);
    /// Lexe's requested XFRM: legacy plus AVX-512 state.
    pub const LEXE_FLAGS: u64 = AVX512 | LEGACY;
    /// The bits covered by the mask: exactly the requested bits.
    pub const LEXE_MASK: u64 = LEXE_FLAGS;
}
pub mod miscselect {
use sgx_isa::Miscselect;
pub const LEXE_FLAGS: Miscselect = Miscselect::empty();
pub const LEXE_MASK: Miscselect = Miscselect::empty();
}
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use proptest::{arbitrary::any, proptest, strategy::Strategy};
    use serde_core::{de::DeserializeOwned, ser::Serialize};

    use super::*;

    /// Asserts that `s1` -> `T` -> JSON string -> `T` round-trips, and that
    /// the re-serialized string is identical to `s1`.
    #[track_caller]
    fn json_string_roundtrip<
        T: DeserializeOwned + Serialize + PartialEq + fmt::Debug,
    >(
        s1: &str,
    ) {
        let x1: T = serde_json::from_str(s1).unwrap();
        let s2 = serde_json::to_string(&x1).unwrap();
        let x2: T = serde_json::from_str(&s2).unwrap();
        assert_eq!(x1, x2);
        assert_eq!(s1, s2);
    }

    /// Asserts that `s1` -> `T` (via `FromStr`) -> `Display` round-trips, and
    /// that the displayed string is identical to `s1`.
    #[track_caller]
    fn fromstr_display_roundtrip<
        T: FromStr + fmt::Display + PartialEq + fmt::Debug,
    >(
        s1: &str,
    ) {
        let x1 = T::from_str(s1).map_err(|_| ()).unwrap();
        let s2 = x1.to_string();
        let x2 = T::from_str(&s2).map_err(|_| ()).unwrap();
        assert_eq!(x1, x2);
        assert_eq!(s1, s2);
    }

    #[test]
    fn serde_roundtrips() {
        json_string_roundtrip::<Measurement>(
            "\"c4f249b8d3121b0e61170a93a526beda574058f782c0b3f339e74651c379f888\"",
        );
        json_string_roundtrip::<MachineId>(
            "\"df3d290e1371112bd3da4a6cdda1f245\"",
        );
    }

    #[test]
    fn fromstr_display_roundtrips() {
        fromstr_display_roundtrip::<Measurement>(
            "c4f249b8d3121b0e61170a93a526beda574058f782c0b3f339e74651c379f888",
        );
        fromstr_display_roundtrip::<MachineId>(
            "df3d290e1371112bd3da4a6cdda1f245",
        );
    }

    #[test]
    fn test_mr_short() {
        proptest!(|(
            long1 in any::<Measurement>(),
            long2 in any::<Measurement>(),
        )| {
            let short1 = long1.short();
            let short2 = long2.short();
            assert!(short1.is_prefix_of(&long1));
            assert!(short2.is_prefix_of(&long2));
            // Different prefixes imply different measurements.
            if short1 != short2 {
                assert_ne!(long1, long2);
                assert!(!short1.is_prefix_of(&long2));
                assert!(!short2.is_prefix_of(&long1));
            }
        });
    }

    #[test]
    fn test_sealed_serialization() {
        let arb_keyrequest = any::<Vec<u8>>();
        let arb_ciphertext = any::<Vec<u8>>();
        let arb_sealed = (arb_keyrequest, arb_ciphertext).prop_map(
            |(keyrequest, ciphertext)| Sealed {
                keyrequest: keyrequest.into(),
                ciphertext: ciphertext.into(),
            },
        );
        proptest!(|(sealed in arb_sealed)| {
            let bytes = sealed.serialize();
            let sealed2 = Sealed::deserialize(&bytes).unwrap();
            assert_eq!(sealed, sealed2);
        });
    }

    /// Pins the exact `Sealed` wire format and checks that every kind of
    /// malformed input is rejected with `DeserializationError`.
    #[test]
    fn test_sealed_deserialize_rejects_malformed() {
        fn is_malformed(input: &[u8]) -> bool {
            matches!(
                Sealed::deserialize(input),
                Err(Error::DeserializationError)
            )
        }

        let sealed = Sealed {
            keyrequest: vec![1u8, 2, 3].into(),
            ciphertext: vec![4u8, 5].into(),
        };
        let bytes = sealed.serialize();
        assert_eq!(bytes, [3, 0, 0, 0, 1, 2, 3, 2, 0, 0, 0, 4, 5]);
        assert_eq!(Sealed::deserialize(&bytes).unwrap(), sealed);

        // Empty input, or a truncated length prefix.
        assert!(is_malformed(&[]));
        assert!(is_malformed(&bytes[..3]));
        // A length prefix that claims more bytes than are present.
        assert!(is_malformed(&bytes[..bytes.len() - 1]));
        assert!(is_malformed(&[0xff, 0xff, 0xff, 0xff]));
        // Trailing bytes after the ciphertext.
        let mut trailing = bytes.clone();
        trailing.push(0);
        assert!(is_malformed(&trailing));
    }
}