use {
super::{
IntoMeasurement,
Measurements,
MeasurementsCriteria,
local::Error,
ticket::TdxTicket,
},
crate::{
UniqueId,
discovery::PeerEntry,
tickets::{Expiration, InvalidTicket, TicketValidator},
},
chrono::{DateTime, Utc},
std::time::Duration,
};
/// Ticket validator for [`TdxTicket`]s.
///
/// Combines measurement criteria (a mandatory baseline plus optional
/// alternative variants) with optional timing constraints on the quote.
#[derive(Clone)]
pub struct Tdx {
    /// Criteria that every accepted ticket's measurements must satisfy.
    baseline: MeasurementsCriteria,
    /// Alternative criteria; when non-empty, at least one entry must also
    /// match the ticket's measurements.
    any: Vec<MeasurementsCriteria>,
    /// Upper bound on the age of the quote at validation time.
    max_age: Option<Duration>,
    /// Earliest acceptable quote timestamp.
    not_before: Option<DateTime<Utc>>,
    /// Upper bound on the span between the quote time and the ticket's
    /// expiration.
    max_lifetime: Option<Duration>,
}
impl Tdx {
pub const CLASS: UniqueId = super::TICKET_CLASS;
#[must_use]
pub const fn baseline(baseline: MeasurementsCriteria) -> Self {
Self {
baseline,
any: Vec::new(),
max_age: None,
not_before: None,
max_lifetime: None,
}
}
#[must_use]
pub const fn empty() -> Self {
Self::baseline(MeasurementsCriteria::new())
}
#[must_use]
pub const fn new() -> Self {
Self::empty()
}
#[must_use]
pub fn require_mrtd(mut self, mrtd: impl IntoMeasurement) -> Self {
self.baseline = self.baseline.require_mrtd(mrtd.into_measurement());
self
}
#[must_use]
pub fn require_rtmr0(mut self, rtmr0: impl IntoMeasurement) -> Self {
self.baseline = self.baseline.require_rtmr0(rtmr0.into_measurement());
self
}
#[must_use]
pub fn require_rtmr1(mut self, rtmr1: impl IntoMeasurement) -> Self {
self.baseline = self.baseline.require_rtmr1(rtmr1.into_measurement());
self
}
#[must_use]
pub fn require_rtmr2(mut self, rtmr2: impl IntoMeasurement) -> Self {
self.baseline = self.baseline.require_rtmr2(rtmr2.into_measurement());
self
}
#[must_use]
pub fn require_rtmr3(mut self, rtmr3: impl IntoMeasurement) -> Self {
self.baseline = self.baseline.require_rtmr3(rtmr3.into_measurement());
self
}
pub fn from_local() -> Result<Self, Error> {
let local = Measurements::local()?;
Ok(Self::baseline(MeasurementsCriteria::from(&local)))
}
pub fn require_own_mrtd(self) -> Result<Self, Error> {
let local = Measurements::local()?;
Ok(self.require_mrtd(local.mrtd()))
}
pub fn require_own_rtmr0(self) -> Result<Self, Error> {
let local = Measurements::local()?;
Ok(self.require_rtmr0(local.rtmr0()))
}
pub fn require_own_rtmr1(self) -> Result<Self, Error> {
let local = Measurements::local()?;
Ok(self.require_rtmr1(local.rtmr1()))
}
pub fn require_own_rtmr2(self) -> Result<Self, Error> {
let local = Measurements::local()?;
Ok(self.require_rtmr2(local.rtmr2()))
}
pub fn require_own_rtmr3(self) -> Result<Self, Error> {
let local = Measurements::local()?;
Ok(self.require_rtmr3(local.rtmr3()))
}
#[must_use]
pub fn allow_variant(mut self, criteria: MeasurementsCriteria) -> Self {
self.any.push(criteria);
self
}
#[must_use]
pub const fn max_age(mut self, max_age: Duration) -> Self {
self.max_age = Some(max_age);
self
}
#[must_use]
pub const fn not_before(mut self, not_before: DateTime<Utc>) -> Self {
self.not_before = Some(not_before);
self
}
#[must_use]
pub const fn max_lifetime(mut self, max_lifetime: Duration) -> Self {
self.max_lifetime = Some(max_lifetime);
self
}
}
impl Default for Tdx {
    /// Produces the same validator as [`Tdx::new`] / [`Tdx::empty`].
    fn default() -> Self {
        Self::empty()
    }
}
impl TicketValidator for Tdx {
fn class(&self) -> UniqueId {
Self::CLASS
}
fn signature(&self) -> UniqueId {
let sig = self
.any
.iter()
.fold(self.class().derive(self.baseline.signature()), |s, c| {
s.derive(c.signature())
});
let sig = self.max_age.map_or_else(
|| sig.derive([0u8; 8]),
|d| sig.derive(d.as_secs().to_le_bytes()),
);
let sig = self.not_before.map_or_else(
|| sig.derive([0u8; 8]),
|dt| sig.derive(dt.timestamp().to_le_bytes()),
);
self.max_lifetime.map_or_else(
|| sig.derive([0u8; 8]),
|d| sig.derive(d.as_secs().to_le_bytes()),
)
}
fn validate(
&self,
bytes: &[u8],
peer: &PeerEntry,
) -> Result<Expiration, InvalidTicket> {
let Ok(ticket) = TdxTicket::try_from(bytes) else {
return Err(InvalidTicket);
};
let measurements = ticket.measurements();
if !self.baseline.matches(&measurements) {
return Err(InvalidTicket);
}
if !self.any.is_empty()
&& !self.any.iter().any(|c| c.matches(&measurements))
{
return Err(InvalidTicket);
}
if ticket.peer_id() != peer.id() {
return Err(InvalidTicket);
}
if ticket.network_id() != peer.network_id() {
return Err(InvalidTicket);
}
if let Some(max_age) = self.max_age {
let age = (Utc::now() - ticket.quoted_at())
.to_std()
.unwrap_or(Duration::MAX);
if age > max_age {
return Err(InvalidTicket);
}
}
if let Some(not_before) = self.not_before
&& *ticket.quoted_at() < not_before
{
return Err(InvalidTicket);
}
if let Some(max_lifetime) = self.max_lifetime {
match ticket.expiration() {
Expiration::Never => return Err(InvalidTicket),
Expiration::At(expires_at) => {
let lifetime = (*expires_at - *ticket.quoted_at())
.to_std()
.unwrap_or(Duration::MAX);
if lifetime > max_lifetime {
return Err(InvalidTicket);
}
}
}
}
if ticket.expiration().is_expired() {
return Err(InvalidTicket);
}
Ok(*ticket.expiration())
}
}