// mosaik 0.3.13
//
// A Rust runtime for building self-organizing, leaderless distributed systems.
// Documentation
use {
	crate::{
		NetworkId,
		PeerId,
		UniqueId,
		primitives::{Pretty, encoding},
		tee::tdx::Measurements,
		tickets::{Expiration, Ticket},
		unique_id,
	},
	chrono::{DateTime, Utc},
	serde::{
		Deserialize,
		Deserializer,
		Serialize,
		Serializer,
		ser::SerializeStruct,
	},
	std::{sync::OnceLock, time::Duration},
	tdx_quote::Quote,
};

#[derive(Debug, Clone, thiserror::Error)]
pub enum TdxTicketError {
	#[error("Invalid ticket format")]
	InvalidFormat,

	#[error("TDX Quote verification failed")]
	QuoteVerificationFailed,

	#[error("Ticket has expired")]
	TicketExpired,

	#[error("TDX Quote's report_data does not match expected hash of ExtraData")]
	InvalidReportData,

	#[error("Ticket is not a mosaik TDX ticket")]
	InvalidTicketType,
}

/// A mosaik ticket backed by an Intel TDX attestation.
///
/// A `TdxTicket` is what gets serialized into the `data` of a generic
/// [`Ticket`](crate::Ticket) when a TDX ticket is issued. It is decoded back
/// from a `Ticket` when a TDX ticket is validated.
///
/// It holds two things:
/// - the raw TDX Quote bytes produced by the TDX hardware;
/// - an `ExtraData` payload with mosaik-specific information, such as the
///   peer's identity and network.
///
/// The quote's `report_data` field carries a hash of the `ExtraData`. This is
/// what ties the two together.
#[derive(Eq, PartialEq)]
pub struct TdxTicket {
	/// Raw TDX Quote bytes, exactly as emitted by the hardware.
	quote_bytes: Vec<u8>,
	/// Mosaik-specific data whose signature is embedded in the quote's
	/// `report_data`.
	extra: ExtraData,
	/// Parsed form of `quote_bytes`, filled in once by the constructors
	/// after the quote has been parsed and checked.
	quote: OnceLock<Quote>,
}

// Public API
impl TdxTicket {
	/// Returns the raw bytes of the TDX Quote contained in this ticket as emitted
	/// by the TDX hardware.
	pub const fn quote_bytes(&self) -> &[u8] {
		self.quote_bytes.as_slice()
	}

	/// Returns the TDX Quote contained in this ticket, parsed from the raw
	/// bytes.
	///
	/// # Panics
	/// Should never panic. There is no way to construct a `TdxTicket` with
	/// invalid quote bytes, and every code path that creates a `TdxTicket`
	/// sets the `quote` field after checking the quote bytes. A panic here
	/// indicates a bug in this module.
	pub fn quote(&self) -> &Quote {
		self.quote.get().expect("quote should be initialized")
	}

	/// Returns the TDX measurements (`MR_TD` and `RTMR`s) contained in this
	/// ticket's TDX Quote.
	///
	/// # Panics
	/// Only under the same internal-invariant violation described on
	/// [`TdxTicket::quote`].
	pub fn measurements(&self) -> Measurements {
		Measurements::from(self.quote())
	}

	/// Returns the global Peer ID of the peer process carrying this ticket.
	pub const fn peer_id(&self) -> &PeerId {
		&self.extra.peer_id
	}

	/// Returns the network ID that the peer process carrying this ticket is
	/// connected to.
	pub const fn network_id(&self) -> &NetworkId {
		&self.extra.network_id
	}

	/// Returns the time at which the peer process carrying this ticket was
	/// started. This is based on the peer machine's local clock and should not
	/// be blindly trusted.
	pub const fn started_at(&self) -> &DateTime<Utc> {
		&self.extra.started_at
	}

	/// Returns the time at which the TDX Quote contained in this ticket was
	/// generated. This is based on the peer machine's local clock and should
	/// not be blindly trusted.
	pub const fn quoted_at(&self) -> &DateTime<Utc> {
		&self.extra.quoted_at
	}

	/// Returns the uptime of the peer process carrying this ticket, computed as
	/// the difference between the current time and `started_at`.
	///
	/// `started_at` comes from the peer's local clock and should not be blindly
	/// trusted. If it lies in the future (the difference is negative), the
	/// conversion to `std::time::Duration` fails and a zero duration is
	/// returned instead.
	pub fn uptime(&self) -> Duration {
		(Utc::now() - self.started_at())
			.to_std()
			.unwrap_or_default()
	}

	/// Returns the expiration time of this ticket.
	pub const fn expiration(&self) -> &Expiration {
		&self.extra.expiration
	}
}

impl core::fmt::Debug for TdxTicket {
	/// Shows the parsed quote and the accompanying `ExtraData`.
	///
	/// The raw `quote_bytes` are not printed. `finish_non_exhaustive` marks
	/// their omission with a trailing `..`.
	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
		let mut builder = f.debug_struct("TdxTicket");
		builder.field("quote", self.quote());
		builder.field("extra", &self.extra);
		builder.finish_non_exhaustive()
	}
}

/// Additional mosaik-specific data carried alongside the TDX Quote in a TDX
/// ticket: the peer's identity and network, when the peer process started,
/// when the quote was generated, and when the ticket expires.
///
/// This data is not part of the TDX Quote itself. Instead, the quote's
/// 64-byte `report_data` field is set as follows:
/// - the first 32 bytes hold the hash of this `ExtraData` struct, computed
///   with the `signature` method;
/// - the remaining 32 bytes are zero.
///
/// Because `report_data` is signed by the TDX hardware as part of the quote,
/// this cryptographically binds the quote to this additional data. When a
/// ticket is validated, the first 32 bytes of the quote's `report_data` are
/// compared against the `signature` of the accompanying `ExtraData`. This
/// ensures the ticket really belongs to the claimed peer identity and network.
///
/// NOTE(review): the validation code in this file only compares the first
/// 32 bytes; the trailing 32 bytes are not checked to be zero.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub(super) struct ExtraData {
	/// Global identity of the peer process carrying the ticket.
	pub peer_id: PeerId,
	/// Network the peer process is connected to.
	pub network_id: NetworkId,
	/// Start time of the peer process, taken from the peer's local clock.
	pub started_at: DateTime<Utc>,
	/// Time the TDX Quote was generated, taken from the peer's local clock.
	pub quoted_at: DateTime<Utc>,
	/// When the ticket stops being valid.
	pub expiration: Expiration,
}

impl ExtraData {
	/// Derives the `UniqueId` that must appear in the first 32 bytes of the
	/// TDX Quote's `report_data` for a ticket carrying this `ExtraData`.
	///
	/// Starting from the fixed tag `"mosaik.tdx.ticket.extra"`, the following
	/// are folded in, in this order:
	/// 1. the peer ID bytes;
	/// 2. the network ID bytes;
	/// 3. `started_at`;
	/// 4. `quoted_at`;
	/// 5. the expiration.
	///
	/// Every timestamp is encoded as little-endian whole seconds since the Unix
	/// epoch. `Expiration::Never` is encoded as eight zero bytes.
	///
	/// NOTE(review): this encoding has two weaknesses:
	/// - Only whole seconds are hashed, so the sub-second parts of the
	///   timestamps are not bound by this value.
	/// - `Expiration::Never` produces the same bytes as `Expiration::At` the
	///   Unix epoch.
	///
	/// Changing either would presumably alter the value and break
	/// compatibility with quotes produced under the current scheme.
	pub fn signature(&self) -> UniqueId {
		let expiration_secs: [u8; 8] = match self.expiration {
			Expiration::At(when) => when.timestamp().to_le_bytes(),
			Expiration::Never => [0u8; 8],
		};
		let started_secs = self.started_at.timestamp().to_le_bytes();
		let quoted_secs = self.quoted_at.timestamp().to_le_bytes();

		unique_id!("mosaik.tdx.ticket.extra")
			.derive(self.peer_id.as_bytes())
			.derive(self.network_id.as_bytes())
			.derive(started_secs)
			.derive(quoted_secs)
			.derive(expiration_secs)
	}
}

impl TdxTicket {
	/// Builds a ticket from raw TDX Quote bytes and the accompanying
	/// `ExtraData`.
	///
	/// The quote is parsed, checked to be bound to `extra`, and cached so that
	/// `quote()` can never observe an empty cell.
	///
	/// # Errors
	/// - `TdxTicketError::QuoteVerificationFailed` if `quote_bytes` cannot be
	///   parsed as a TDX Quote.
	/// - `TdxTicketError::InvalidReportData` if the first 32 bytes of the
	///   quote's `report_data` differ from `extra.signature()`.
	pub(super) fn new(
		quote_bytes: Vec<u8>,
		extra: ExtraData,
	) -> Result<Self, TdxTicketError> {
		let ticket = Self {
			quote_bytes,
			extra,
			quote: OnceLock::new(),
		};
		let quote = ticket.verify_quote()?;
		// The cell was created empty just above, so `set` cannot fail here.
		ticket.quote.set(quote).ok();
		Ok(ticket)
	}

	/// Checks that the first 32 bytes of the quote's `report_data` equal the
	/// signature of this ticket's `ExtraData`.
	///
	/// The trailing 32 bytes of `report_data` are not inspected.
	///
	/// # Errors
	/// Returns `TdxTicketError::InvalidReportData` on mismatch.
	fn verify_report_data(&self, quote: &Quote) -> Result<(), TdxTicketError> {
		let report_data = quote.report_input_data();
		if report_data[..32] != *self.extra.signature().as_bytes() {
			return Err(TdxTicketError::InvalidReportData);
		}
		Ok(())
	}

	/// Parses the TDX Quote from the raw bytes and checks that its
	/// `report_data` matches the `ExtraData` signature. On success, returns the
	/// parsed quote.
	///
	/// NOTE(review): this function only calls `Quote::from_bytes` and the
	/// `report_data` comparison. It does not call any explicit check of the
	/// quote's signature or attestation certificate chain. Confirm that such
	/// verification happens elsewhere before trusting the measurements,
	/// unless `Quote::from_bytes` itself verifies them (presumably it only
	/// parses; TODO confirm).
	///
	/// # Errors
	/// - `TdxTicketError::QuoteVerificationFailed` if the bytes do not parse
	///   as a quote.
	/// - `TdxTicketError::InvalidReportData` if the binding check fails.
	fn verify_quote(&self) -> Result<Quote, TdxTicketError> {
		let quote = Quote::from_bytes(&self.quote_bytes)
			.map_err(|_| TdxTicketError::QuoteVerificationFailed)?;
		self.verify_report_data(&quote)?;
		Ok(quote)
	}
}

impl Serialize for TdxTicket {
	fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
		let mut state = serializer.serialize_struct("TdxTicket", 2)?;
		state.serialize_field("quote_bytes", &self.quote_bytes)?;
		state.serialize_field("extra", &self.extra)?;
		state.end()
	}
}

impl<'de> Deserialize<'de> for TdxTicket {
	fn deserialize<D: Deserializer<'de>>(
		deserializer: D,
	) -> Result<Self, D::Error> {
		#[derive(Deserialize)]
		struct Wire {
			quote_bytes: Vec<u8>,
			extra: ExtraData,
		}

		let wire = Wire::deserialize(deserializer)?;
		let ticket = Self {
			quote_bytes: wire.quote_bytes,
			extra: wire.extra,
			quote: OnceLock::new(),
		};
		let quote = ticket.verify_quote().map_err(serde::de::Error::custom)?;
		ticket.quote.set(quote).ok();
		Ok(ticket)
	}
}

/// Converts a `TdxTicket` into a generic mosaik `Ticket` by serializing the
/// `TdxTicket` and using the serialized bytes as the `data` field of the
/// `Ticket` and setting its class to the TDX ticket class constant.
impl TryFrom<TdxTicket> for Ticket {
	type Error = TdxTicketError;

	fn try_from(value: TdxTicket) -> Result<Self, Self::Error> {
		let serialized = encoding::try_serialize(&value)
			.map_err(|_| TdxTicketError::QuoteVerificationFailed)?;
		let compressed =
			zstd::bulk::compress(&serialized, TICKET_COMPRESSION_LEVEL)
				.map_err(|_| TdxTicketError::InvalidFormat)?;
		Ok(Self::new(super::TICKET_CLASS, compressed.into()))
	}
}

/// Decodes a `TdxTicket` from the zstd-compressed bytes stored in a
/// `Ticket`'s `data`.
///
/// Decompression is capped at `MAX_TICKET_SIZE` bytes. Any failure is
/// reported as `TdxTicketError::InvalidFormat`. This covers decompression,
/// decoding, and the quote / `report_data` checks run during
/// deserialization.
impl TryFrom<&[u8]> for TdxTicket {
	type Error = TdxTicketError;

	fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
		let Ok(raw) = zstd::bulk::decompress(value, MAX_TICKET_SIZE) else {
			return Err(TdxTicketError::InvalidFormat);
		};
		encoding::deserialize(&raw).map_err(|_| TdxTicketError::InvalidFormat)
	}
}

/// Extracts a `TdxTicket` from a generic mosaik `Ticket`.
///
/// # Errors
/// - `TdxTicketError::InvalidTicketType` if the ticket's class is not the TDX
///   ticket class.
/// - `TdxTicketError::InvalidFormat` if the ticket's data cannot be decoded
///   into a valid TDX ticket.
impl TryFrom<Ticket> for TdxTicket {
	type Error = TdxTicketError;

	fn try_from(ticket: Ticket) -> Result<Self, Self::Error> {
		let is_tdx_ticket = ticket.class == super::TICKET_CLASS;
		if is_tdx_ticket {
			Self::try_from(ticket.data.as_ref())
		} else {
			Err(TdxTicketError::InvalidTicketType)
		}
	}
}

impl core::fmt::Debug for Pretty<'_, TdxTicket> {
	/// Human-oriented rendering of a TDX ticket.
	///
	/// Identities, measurements, and hashes are printed as bare display or hex
	/// strings instead of raw byte arrays. `mrowner` is only shown when it is
	/// not all zeros.
	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
		let ticket = &self.0;
		let quote = ticket.quote();
		let measurements = ticket.measurements();
		// First half of `report_data`: the `ExtraData` signature bound into the
		// quote.
		let aux_hash = &quote.report_input_data()[..32];
		let tdx_version = match quote.body.tdx_version {
			tdx_quote::TDXVersion::One => 1.0,
			tdx_quote::TDXVersion::OnePointFive => 1.5,
		};

		let mut d = f.debug_struct("TdxTicket");

		d.field("peer_id", &format_args!("{}", ticket.peer_id()))
			.field("network_id", &ticket.network_id())
			.field("mrtd", &format_args!("{}", measurements.mrtd()))
			.field("rtmr0", &format_args!("{}", measurements.rtmr0()))
			.field("rtmr1", &format_args!("{}", measurements.rtmr1()))
			.field("rtmr2", &format_args!("{}", measurements.rtmr2()))
			.field("aux_hash", &format_args!("{}", hex::encode(aux_hash)))
			.field("started_at", &ticket.started_at())
			.field("quoted_at", &ticket.quoted_at())
			.field("expiration", &ticket.expiration())
			.field("tdx_version", &tdx_version)
			.field("quote_version", &quote.header.version)
			.field("tee_type", &quote.header.tee_type);

		if quote.body.mrowner != [0u8; 48] {
			// Rendered like the other hex fields: bare, not as a quoted
			// `String` Debug.
			d.field(
				"mrowner",
				&format_args!("{}", hex::encode(quote.body.mrowner)),
			);
		}

		d.finish()
	}
}

/// Upper bound, in bytes, on the decompressed size of a TDX ticket
/// (64 KiB = 2^16 bytes).
///
/// It is passed as the capacity to zstd decompression, so that a small
/// malicious input cannot expand into an arbitrarily large allocation
/// (a "decompression bomb").
const MAX_TICKET_SIZE: usize = 1 << 16;

/// zstd compression level used when a serialized TDX ticket is packed into a
/// generic `Ticket`.
///
/// Level 3 is zstd's default. It trades off output size against
/// decompression speed well for small payloads such as these. In tests it
/// roughly halved the serialized ticket size compared to no compression,
/// while decompression stayed in the low-millisecond range.
const TICKET_COMPRESSION_LEVEL: i32 = 3;