use {
crate::{
NetworkId,
PeerId,
UniqueId,
primitives::{Pretty, encoding},
tee::tdx::Measurements,
tickets::{Expiration, Ticket},
unique_id,
},
chrono::{DateTime, Utc},
serde::{
Deserialize,
Deserializer,
Serialize,
Serializer,
ser::SerializeStruct,
},
std::{sync::OnceLock, time::Duration},
tdx_quote::Quote,
};
#[derive(Debug, Clone, thiserror::Error)]
pub enum TdxTicketError {
#[error("Invalid ticket format")]
InvalidFormat,
#[error("TDX Quote verification failed")]
QuoteVerificationFailed,
#[error("Ticket has expired")]
TicketExpired,
#[error("TDX Quote's report_data does not match expected hash of ExtraData")]
InvalidReportData,
#[error("Ticket is not a mosaik TDX ticket")]
InvalidTicketType,
}
/// A TDX attestation ticket: a raw quote plus the metadata that the quote's
/// report data commits to.
///
/// Equality compares the raw quote bytes, the extra data and the cached
/// parsed quote.
#[derive(Eq, PartialEq)]
pub struct TdxTicket {
    /// Encoded TDX quote, kept verbatim so it can be re-serialized as-is.
    quote_bytes: Vec<u8>,
    /// Metadata whose signature must match the quote's report data.
    extra: ExtraData,
    /// Parsed form of `quote_bytes`; the constructors populate it once the
    /// quote has been parsed and its report data checked.
    quote: OnceLock<Quote>,
}
impl TdxTicket {
    /// Encoded TDX quote as carried by this ticket.
    pub const fn quote_bytes(&self) -> &[u8] {
        self.quote_bytes.as_slice()
    }

    /// Parsed TDX quote.
    ///
    /// # Panics
    ///
    /// Panics if the parsed-quote cache was never filled. The constructors in
    /// this file fill it right after verification, so a panic here indicates
    /// a ticket built some other way (a bug).
    pub fn quote(&self) -> &Quote {
        self.quote.get().expect("quote should be initialized")
    }

    /// Measurement registers extracted from the parsed quote.
    pub fn measurements(&self) -> Measurements {
        self.quote().into()
    }

    /// Peer identity bound into the quote's report data.
    pub const fn peer_id(&self) -> &PeerId {
        &self.extra.peer_id
    }

    /// Network identifier bound into the quote's report data.
    pub const fn network_id(&self) -> &NetworkId {
        &self.extra.network_id
    }

    /// Start timestamp recorded in the ticket; `uptime` is measured from it.
    pub const fn started_at(&self) -> &DateTime<Utc> {
        &self.extra.started_at
    }

    /// Quote timestamp recorded in the ticket.
    pub const fn quoted_at(&self) -> &DateTime<Utc> {
        &self.extra.quoted_at
    }

    /// Time elapsed since `started_at`, according to the local clock.
    ///
    /// Returns zero when `started_at` lies in the future (e.g. clock skew),
    /// since a negative span cannot be represented as a `Duration`.
    pub fn uptime(&self) -> Duration {
        let since_start = Utc::now() - *self.started_at();
        since_start.to_std().unwrap_or(Duration::ZERO)
    }

    /// Expiration recorded in the ticket.
    pub const fn expiration(&self) -> &Expiration {
        &self.extra.expiration
    }
}
impl core::fmt::Debug for TdxTicket {
    /// Shows the parsed quote and the extra data. The raw quote bytes are
    /// left out, hence the non-exhaustive marker.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let parsed = self.quote();
        let mut out = f.debug_struct("TdxTicket");
        out.field("quote", parsed);
        out.field("extra", &self.extra);
        out.finish_non_exhaustive()
    }
}
/// Metadata carried next to the TDX quote inside a ticket.
///
/// `ExtraData::signature()` must equal the first 32 bytes of the quote's
/// report data; that check is what ties these fields to the quote.
/// NOTE(review): the derived serde impls encode fields in declaration order,
/// which may be part of the wire format depending on `encoding`'s format —
/// do not reorder without confirming.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub(super) struct ExtraData {
    /// Peer identity bound into the signature.
    pub peer_id: PeerId,
    /// Network identifier bound into the signature.
    pub network_id: NetworkId,
    /// Start timestamp; only whole seconds are bound into the signature.
    pub started_at: DateTime<Utc>,
    /// Quote timestamp; only whole seconds are bound into the signature.
    pub quoted_at: DateTime<Utc>,
    /// Expiration; bound as whole seconds, with `Never` encoded as zero.
    pub expiration: Expiration,
}
impl ExtraData {
    /// Derives the identifier that must appear in the first 32 bytes of the
    /// TDX quote's report data, binding the quote to this metadata.
    ///
    /// The hash chain absorbs, in this exact order: the peer id bytes, the
    /// network id bytes, `started_at`, `quoted_at`, and the expiration. The
    /// timestamps are absorbed as little-endian Unix seconds, and
    /// `Expiration::Never` is absorbed as eight zero bytes. The order is part
    /// of the protocol: reordering changes every signature.
    ///
    /// NOTE(review): only whole seconds are bound, so sub-second parts of the
    /// timestamps are not covered by the quote. `Expiration::Never` also
    /// encodes identically to `Expiration::At` the Unix epoch. Changing either
    /// would be a wire-incompatible protocol change.
    pub fn signature(&self) -> UniqueId {
        let started_secs = self.started_at.timestamp().to_le_bytes();
        let quoted_secs = self.quoted_at.timestamp().to_le_bytes();
        let expiry_secs = match self.expiration {
            Expiration::Never => [0u8; 8],
            Expiration::At(when) => when.timestamp().to_le_bytes(),
        };

        unique_id!("mosaik.tdx.ticket.extra")
            .derive(self.peer_id.as_bytes())
            .derive(self.network_id.as_bytes())
            .derive(started_secs)
            .derive(quoted_secs)
            .derive(expiry_secs)
    }
}
impl TdxTicket {
pub(super) fn new(
quote_bytes: Vec<u8>,
extra: ExtraData,
) -> Result<Self, TdxTicketError> {
let ticket = Self {
quote_bytes,
extra,
quote: OnceLock::new(),
};
let quote = ticket.verify_quote()?;
ticket.quote.set(quote).ok();
Ok(ticket)
}
fn verify_report_data(&self, quote: &Quote) -> Result<(), TdxTicketError> {
let report_data = quote.report_input_data();
if report_data[..32] != *self.extra.signature().as_bytes() {
return Err(TdxTicketError::InvalidReportData);
}
Ok(())
}
fn verify_quote(&self) -> Result<Quote, TdxTicketError> {
let quote = Quote::from_bytes(&self.quote_bytes)
.map_err(|_| TdxTicketError::QuoteVerificationFailed)?;
self.verify_report_data("e)?;
Ok(quote)
}
}
impl Serialize for TdxTicket {
    /// Serializes the ticket as a two-field struct: the raw quote bytes
    /// followed by the extra data.
    ///
    /// The cached parsed quote is deliberately not written. Deserialization
    /// re-parses and re-verifies it from `quote_bytes`.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        const FIELD_COUNT: usize = 2;
        let mut out = serializer.serialize_struct("TdxTicket", FIELD_COUNT)?;
        out.serialize_field("quote_bytes", &self.quote_bytes)?;
        out.serialize_field("extra", &self.extra)?;
        out.end()
    }
}
impl<'de> Deserialize<'de> for TdxTicket {
fn deserialize<D: Deserializer<'de>>(
deserializer: D,
) -> Result<Self, D::Error> {
#[derive(Deserialize)]
struct Wire {
quote_bytes: Vec<u8>,
extra: ExtraData,
}
let wire = Wire::deserialize(deserializer)?;
let ticket = Self {
quote_bytes: wire.quote_bytes,
extra: wire.extra,
quote: OnceLock::new(),
};
let quote = ticket.verify_quote().map_err(serde::de::Error::custom)?;
ticket.quote.set(quote).ok();
Ok(ticket)
}
}
impl TryFrom<TdxTicket> for Ticket {
type Error = TdxTicketError;
fn try_from(value: TdxTicket) -> Result<Self, Self::Error> {
let serialized = encoding::try_serialize(&value)
.map_err(|_| TdxTicketError::QuoteVerificationFailed)?;
let compressed =
zstd::bulk::compress(&serialized, TICKET_COMPRESSION_LEVEL)
.map_err(|_| TdxTicketError::InvalidFormat)?;
Ok(Self::new(super::TICKET_CLASS, compressed.into()))
}
}
impl TryFrom<&[u8]> for TdxTicket {
type Error = TdxTicketError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
let decompressed = zstd::bulk::decompress(value, MAX_TICKET_SIZE)
.map_err(|_| TdxTicketError::InvalidFormat)?;
encoding::deserialize(&decompressed)
.map_err(|_| TdxTicketError::InvalidFormat)
}
}
impl TryFrom<Ticket> for TdxTicket {
    type Error = TdxTicketError;

    /// Decodes a generic `Ticket` into a `TdxTicket`.
    ///
    /// # Errors
    ///
    /// - `TdxTicketError::InvalidTicketType` when the ticket's class is not
    ///   `super::TICKET_CLASS`.
    /// - Otherwise, any error from decoding the ticket's payload bytes.
    fn try_from(ticket: Ticket) -> Result<Self, Self::Error> {
        if ticket.class == super::TICKET_CLASS {
            Self::try_from(ticket.data.as_ref())
        } else {
            Err(TdxTicketError::InvalidTicketType)
        }
    }
}
impl core::fmt::Debug for Pretty<'_, TdxTicket> {
    /// Human-oriented rendering of a TDX ticket. It shows, in order:
    /// - identities;
    /// - measurement registers;
    /// - the report-data commitment;
    /// - timestamps;
    /// - quote header metadata.
    ///
    /// `mrowner` is listed only when it is not all zeroes.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let ticket = &self.0;
        let quote = ticket.quote();
        let measurements = ticket.measurements();
        // First 32 bytes of report_data: the commitment compared against the
        // extra data's signature during verification.
        let aux_hash = &quote.report_input_data()[..32];
        let tdx_version = match quote.body.tdx_version {
            tdx_quote::TDXVersion::One => 1.0,
            tdx_quote::TDXVersion::OnePointFive => 1.5,
        };

        let mut d = f.debug_struct("TdxTicket");
        d.field("peer_id", &format_args!("{}", ticket.peer_id()))
            .field("network_id", &ticket.network_id())
            .field("mrtd", &format_args!("{}", measurements.mrtd()))
            .field("rtmr0", &format_args!("{}", measurements.rtmr0()))
            .field("rtmr1", &format_args!("{}", measurements.rtmr1()))
            .field("rtmr2", &format_args!("{}", measurements.rtmr2()))
            .field("aux_hash", &format_args!("{}", hex::encode(aux_hash)))
            .field("started_at", &ticket.started_at())
            .field("quoted_at", &ticket.quoted_at())
            .field("expiration", &ticket.expiration())
            .field("tdx_version", &tdx_version)
            .field("quote_version", &quote.header.version)
            .field("tee_type", &quote.header.tee_type);
        if quote.body.mrowner != [0u8; 48] {
            // Render as bare hex like every other hash above. Passing the
            // `String` directly would print it quoted.
            d.field(
                "mrowner",
                &format_args!("{}", hex::encode(quote.body.mrowner)),
            );
        }
        d.finish()
    }
}
/// Upper bound, in bytes, on a ticket's decompressed (serialized) size
/// (64 KiB). Used as the output capacity for zstd decompression.
const MAX_TICKET_SIZE: usize = 1 << 16;
/// zstd compression level used when packing a ticket.
const TICKET_COMPRESSION_LEVEL: i32 = 3;