use crate::{
collateral::Collateral,
measurements::{FullMeasurements, MeasurementsError},
quote::QuoteBytes,
report_data::ReportData,
tcb_info::{EventLog, HexBytes, TcbInfo},
};
use alloc::{
format,
string::{String, ToString},
vec::Vec,
};
use borsh::{BorshDeserialize, BorshSerialize};
use core::fmt;
use dcap_qvl::verify::VerifiedReport;
use derive_more::Constructor;
use serde::{Deserialize, Serialize};
use serde_json::json;
use sha2::{Digest as _, Sha256, Sha384};
/// TCB status string `dcap_qvl` must report for a quote to be accepted.
const EXPECTED_QUOTE_STATUS: &str = "UpToDate";
/// Event type that every RTMR3 event-log entry must carry (decimal 134217729).
const DSTACK_EVENT_TYPE: u32 = 0x0800_0001;
/// Name of the RTMR3 event whose payload is the SHA-256 of the app-compose config.
const COMPOSE_HASH_EVENT: &str = "compose-hash";
/// Name of the RTMR3 event whose digest is compared against the accepted key-provider digest.
pub(crate) const KEY_PROVIDER_EVENT: &str = "key-provider";
/// Index of the runtime measurement register that the event log is replayed into.
const RTMR3_INDEX: u32 = 3;
/// Attestation bundle checked by [`DstackAttestation::verify`].
///
/// NOTE: field declaration order is part of the Borsh wire format; do not reorder fields.
#[derive(Clone, Constructor, Serialize, Deserialize, BorshDeserialize, BorshSerialize)]
pub struct DstackAttestation {
    /// Quote bytes handed to `dcap_qvl::verify::verify` (presumably a raw TDX quote — confirm).
    pub quote: QuoteBytes,
    /// Collateral handed to `dcap_qvl::verify::verify` together with the quote.
    pub collateral: Collateral,
    /// Claimed measurements, app-compose config and event log; `verify` cross-checks these
    /// against the quote and against the accepted measurement sets.
    pub tcb_info: TcbInfo,
}
#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
pub enum VerificationError {
#[error("could not parse embedded measurements: {0}")]
EmbeddedMeasurementsParsing(MeasurementsError),
#[error("dcap verification failed: {0}")]
DcapVerification(String),
#[error("verification report is not TD10")]
ReportNotTd10,
#[error("TCB status `{0}` is not up to date")]
TcbStatusNotUpToDate(String),
#[error("ouststanding advisories reported: {0}")]
NonEmptyAdvisoryIds(String),
#[error("wrong {name} hash (found {found} expected {expected})")]
WrongHash {
name: &'static str,
found: String,
expected: String,
},
#[error("invalid event type {0}")]
InvalidEventType(u32),
#[error("failed to decode event digest `{0}`")]
EventDecoding(String),
#[error("failed to parse app compose JSON: {0}")]
AppComposeParsing(String),
#[error("no {0} event in event log")]
MissingEvent(&'static str),
#[error("duplicate {0} events in event log")]
DuplicateEvent(&'static str),
#[error("invalid app compose config: `{0}`")]
InvalidAppComposeConfig(String),
#[error("app-compose event payload had an unexpected size of {0}")]
AppComposeEventPayloadWrongSize(usize),
#[error("app-compose event payload `{0}` is not a hex string")]
AppComposeEventPayloadNotHex(String),
#[error(
"the attestation certificate with timestap {attestation_time} has expired since {expiry_time}"
)]
ExpiredCertificate {
attestation_time: u64,
expiry_time: u64,
},
#[error("PPID must be 32 bytes, got {0}")]
PpidWrongSize(usize),
#[error("the mock attestation is invalid per definition")]
InvalidMockAttestation,
#[error("custom error: `{0}`")]
Custom(String),
}
impl fmt::Debug for DstackAttestation {
    /// Formats the attestation with each field's `Debug` output capped at `MAX_BYTES` bytes,
    /// so large quotes or collateral do not flood logs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MAX_BYTES: usize = 2048;

        /// Renders `value` with `{:?}` and truncates the result to at most `max_bytes` bytes.
        ///
        /// The cut is moved back to the nearest UTF-8 character boundary. Slicing a `str` in
        /// the middle of a multi-byte character panics, and `Debug` output of string fields
        /// may contain non-ASCII characters.
        fn truncate_debug<T: fmt::Debug>(value: &T, max_bytes: usize) -> String {
            let debug_str = format!("{:?}", value);
            if debug_str.len() <= max_bytes {
                return debug_str;
            }
            // `is_char_boundary(0)` is always true, so this loop terminates.
            let mut cut = max_bytes;
            while !debug_str.is_char_boundary(cut) {
                cut -= 1;
            }
            format!(
                "{}... (truncated {} bytes)",
                &debug_str[..cut],
                debug_str.len() - cut
            )
        }

        f.debug_struct("DstackAttestation")
            .field("quote", &truncate_debug(&self.quote, MAX_BYTES))
            .field("collateral", &truncate_debug(&self.collateral, MAX_BYTES))
            .field("tcb_info", &truncate_debug(&self.tcb_info, MAX_BYTES))
            .finish()
    }
}
/// Builds a placeholder [`DstackAttestation`] made only of empty values.
///
/// The quote is empty, every collateral string is `""`, and every TCB-info field is its
/// default or empty value. It is not a real attestation.
///
/// # Panics
///
/// Panics if `Collateral::try_from_json` rejects the all-empty collateral JSON.
pub fn create_mock_dstack_attestation() -> DstackAttestation {
    let mock_collateral = Collateral::try_from_json(json!({
        "tcb_info_issuer_chain": "",
        "tcb_info": "",
        "tcb_info_signature": "",
        "qe_identity_issuer_chain": "",
        "qe_identity": "",
        "qe_identity_signature": "",
        "pck_crl_issuer_chain": "",
        "root_ca_crl": "",
        "pck_crl": ""
    }))
    .expect("mock collateral is valid");

    let blank_tcb_info = TcbInfo {
        mrtd: HexBytes::default(),
        rtmr0: HexBytes::default(),
        rtmr1: HexBytes::default(),
        rtmr2: HexBytes::default(),
        rtmr3: HexBytes::default(),
        os_image_hash: None,
        compose_hash: HexBytes::default(),
        device_id: HexBytes::default(),
        app_compose: String::new(),
        event_log: Vec::new(),
    };
    let empty_quote = QuoteBytes::from(Vec::new());

    DstackAttestation::new(empty_quote, mock_collateral, blank_tcb_info)
}
impl DstackAttestation {
    /// Verifies the attestation and returns the matched measurements together with the PPID.
    ///
    /// The checks run in this order and stop at the first failure:
    /// 1. `dcap_qvl` verifies the quote against the collateral at `timestamp_seconds`.
    /// 2. The verified report must be a TD10 report.
    /// 3. The TCB status must be `"UpToDate"` with no advisory IDs.
    /// 4. The report data must equal `expected_report_data`.
    /// 5. The PPID reported by `dcap_qvl` must be one of `accepted_ppids`.
    /// 6. `tcb_info.rtmr3` must equal the quote's RTMR3, and replaying the RTMR3 event log must
    ///    reproduce it.
    /// 7. The `compose-hash` event must match `tcb_info.compose_hash` and the SHA-256 of
    ///    `tcb_info.app_compose`.
    /// 8. MRTD/RTMR0-2, the `key-provider` event digest and the compose hash must all match a
    ///    single entry of `accepted_measurements`; the first such entry is returned.
    ///
    /// # Errors
    ///
    /// Returns the [`VerificationError`] of the first check that fails.
    pub fn verify(
        &self,
        expected_report_data: ReportData,
        timestamp_seconds: u64,
        accepted_measurements: &[FullMeasurements],
        accepted_ppids: &[HexBytes<16>],
    ) -> Result<(FullMeasurements, HexBytes<16>), VerificationError> {
        let verification_result =
            dcap_qvl::verify::verify(&self.quote, &self.collateral, timestamp_seconds)
                .map_err(|e| VerificationError::DcapVerification(e.to_string()))?;
        let td_report = verification_result
            .report
            .as_td10()
            .ok_or(VerificationError::ReportNotTd10)?;
        self.verify_tcb_status(&verification_result)?;
        self.verify_report_data(&expected_report_data, td_report)?;
        let ppid = self.verify_ppid(verification_result.ppid, accepted_ppids)?;
        self.verify_rtmr3(td_report, &self.tcb_info)?;
        self.verify_app_compose(&self.tcb_info)?;
        let measurements =
            self.verify_any_measurements(td_report, &self.tcb_info, accepted_measurements)?;
        Ok((measurements, ppid))
    }

    /// Replays the RTMR3 event log and checks that it reproduces `expected_digest`.
    ///
    /// Starting from 48 zero bytes, each event with `imr == RTMR3_INDEX` is folded in as
    /// `SHA384(running || event.digest)`. Before folding, the event must have type
    /// `DSTACK_EVENT_TYPE`, and its recorded digest must equal [`Self::event_digest`] of its
    /// type, name and hex-decoded payload. Events for other IMRs are ignored.
    fn verify_event_log_rtmr3(
        event_log: &[EventLog],
        expected_digest: [u8; 48],
    ) -> Result<(), VerificationError> {
        let mut replayed = [0u8; 48];
        for event in event_log.iter().filter(|e| e.imr == RTMR3_INDEX) {
            if event.event_type != DSTACK_EVENT_TYPE {
                return Err(VerificationError::InvalidEventType(event.event_type));
            }
            let payload_bytes = hex::decode(&event.event_payload).map_err(|e| {
                tracing::error!(
                    "Failed to decode hex payload of event `{}`: {:?}",
                    event.event,
                    e
                );
                VerificationError::EventDecoding(hex::encode(*event.digest))
            })?;
            // The replay below only consumes `event.digest`. Binding that digest to the
            // event's type, name and payload ensures the payloads read later (e.g. the
            // compose hash) are the ones actually measured.
            let computed_event_digest =
                Self::event_digest(event.event_type, &event.event, &payload_bytes);
            compare_hashes(
                "event_digest",
                event.digest.as_slice(),
                &computed_event_digest,
            )?;
            let mut hasher = Sha384::new();
            hasher.update(replayed);
            hasher.update(event.digest.as_slice());
            replayed = hasher.finalize().into();
        }
        compare_hashes("event_log", &replayed, &expected_digest)
    }

    /// Checks that `expected_event_payload_hex` hex-decodes to SHA-256(`app_compose`).
    ///
    /// # Errors
    ///
    /// - [`VerificationError::AppComposeEventPayloadNotHex`] if the payload is not valid hex.
    /// - [`VerificationError::AppComposeEventPayloadWrongSize`] if it does not decode to
    ///   exactly 32 bytes.
    /// - [`VerificationError::WrongHash`] if the hash does not match.
    fn validate_app_compose_payload(
        expected_event_payload_hex: &str,
        app_compose: &str,
    ) -> Result<(), VerificationError> {
        let decoded = hex::decode(expected_event_payload_hex).map_err(|e| {
            tracing::error!(
                "Failed to decode hex string for compose-hash event: {:?}",
                e
            );
            VerificationError::AppComposeEventPayloadNotHex(
                expected_event_payload_hex.to_string(),
            )
        })?;
        let expected_payload = <[u8; 32]>::try_from(decoded.as_slice())
            .map_err(|_| VerificationError::AppComposeEventPayloadWrongSize(decoded.len()))?;
        let app_compose_hash: [u8; 32] = Sha256::digest(app_compose.as_bytes()).into();
        compare_hashes("app_compose_payload", &app_compose_hash, &expected_payload)
    }

    /// Requires the TCB status to be `EXPECTED_QUOTE_STATUS` and the advisory list to be empty.
    fn verify_tcb_status(
        &self,
        verification_result: &VerifiedReport,
    ) -> Result<(), VerificationError> {
        (verification_result.status == EXPECTED_QUOTE_STATUS).or_err(|| {
            VerificationError::TcbStatusNotUpToDate(verification_result.status.clone())
        })?;
        verification_result.advisory_ids.is_empty().or_err(|| {
            VerificationError::NonEmptyAdvisoryIds(verification_result.advisory_ids.join(", "))
        })
    }

    /// Requires the TD report's `report_data` to equal the caller's expected report data.
    fn verify_report_data(
        &self,
        expected: &ReportData,
        actual: &dcap_qvl::quote::TDReport10,
    ) -> Result<(), VerificationError> {
        compare_hashes("report_data", &actual.report_data, &expected.to_bytes())
    }

    /// Validates that `ppid` is exactly 16 bytes and is listed in `accepted_ppids`.
    ///
    /// # Errors
    ///
    /// - [`VerificationError::PpidWrongSize`] if the PPID is not 16 bytes long.
    /// - [`VerificationError::Custom`] if the PPID is not in the allow-list.
    fn verify_ppid(
        &self,
        ppid: Vec<u8>,
        accepted_ppids: &[HexBytes<16>],
    ) -> Result<HexBytes<16>, VerificationError> {
        let ppid_array = <[u8; 16]>::try_from(ppid.as_slice())
            .map_err(|_| VerificationError::PpidWrongSize(ppid.len()))?;
        let ppid_hex_bytes = HexBytes::from(ppid_array);
        if !accepted_ppids.contains(&ppid_hex_bytes) {
            return Err(VerificationError::Custom(format!(
                "PPID {} is not in the allowed PPIDs list",
                hex::encode(ppid_hex_bytes.as_ref())
            )));
        }
        Ok(ppid_hex_bytes)
    }

    /// Returns the first entry of `accepted_measurements` that matches on all of: static
    /// registers (MRTD/RTMR0-2), the `key-provider` event digest, and the compose hash.
    ///
    /// Per-candidate mismatch errors are discarded. If nothing matches, including when the
    /// list is empty, a single [`VerificationError::WrongHash`] is returned.
    fn verify_any_measurements(
        &self,
        report_data: &dcap_qvl::quote::TDReport10,
        tcb_info: &TcbInfo,
        accepted_measurements: &[FullMeasurements],
    ) -> Result<FullMeasurements, VerificationError> {
        accepted_measurements
            .iter()
            .find(|expected| {
                self.verify_static_rtmrs(report_data, tcb_info, expected).is_ok()
                    && self
                        .verify_key_provider_digest(tcb_info, &expected.key_provider_event_digest)
                        .is_ok()
                    && self
                        .verify_app_compose_hash(tcb_info, &expected.app_compose_hash_payload)
                        .is_ok()
            })
            .copied()
            .ok_or_else(|| VerificationError::WrongHash {
                name: "expected_measurements",
                expected: "one of the embedded TCB info sets (prod or dev)".into(),
                found: "none matched".into(),
            })
    }

    /// Checks MRTD and RTMR0-2 from both the TD report and the TCB info against the
    /// expected set. Each comparison reports its own name on failure.
    fn verify_static_rtmrs(
        &self,
        report_data: &dcap_qvl::quote::TDReport10,
        tcb_info: &TcbInfo,
        expected_measurements: &FullMeasurements,
    ) -> Result<(), VerificationError> {
        compare_hashes(
            "rtmr0_report_data",
            &report_data.rt_mr0,
            &expected_measurements.rtmrs.rtmr0,
        )?;
        compare_hashes(
            "rtmr1_report_data",
            &report_data.rt_mr1,
            &expected_measurements.rtmrs.rtmr1,
        )?;
        compare_hashes(
            "rtmr2_report_data",
            &report_data.rt_mr2,
            &expected_measurements.rtmrs.rtmr2,
        )?;
        compare_hashes(
            "mrtd_report_data",
            &report_data.mr_td,
            &expected_measurements.rtmrs.mrtd,
        )?;
        compare_hashes(
            "rtmr0_tcb_info",
            tcb_info.rtmr0.as_slice(),
            &expected_measurements.rtmrs.rtmr0,
        )?;
        compare_hashes(
            "rtmr1_tcb_info",
            tcb_info.rtmr1.as_slice(),
            &expected_measurements.rtmrs.rtmr1,
        )?;
        compare_hashes(
            "rtmr2_tcb_info",
            tcb_info.rtmr2.as_slice(),
            &expected_measurements.rtmrs.rtmr2,
        )?;
        compare_hashes(
            "mrtd_tcb_info",
            tcb_info.mrtd.as_slice(),
            &expected_measurements.rtmrs.mrtd,
        )
    }

    /// Requires `tcb_info.rtmr3` to equal the quote's RTMR3, then requires the event-log
    /// replay to reproduce that value.
    fn verify_rtmr3(
        &self,
        report_data: &dcap_qvl::quote::TDReport10,
        tcb_info: &TcbInfo,
    ) -> Result<(), VerificationError> {
        compare_hashes("rtmr3", tcb_info.rtmr3.as_slice(), &report_data.rt_mr3)?;
        Self::verify_event_log_rtmr3(&tcb_info.event_log, report_data.rt_mr3)
    }

    /// Cross-checks the unique RTMR3 `compose-hash` event against the TCB info.
    ///
    /// The event's hex payload must equal the lowercase hex of `tcb_info.compose_hash`. This is
    /// an exact string comparison, so an uppercase payload is rejected. The payload must also
    /// be SHA-256 of `tcb_info.app_compose`.
    fn verify_app_compose(&self, tcb_info: &TcbInfo) -> Result<(), VerificationError> {
        let app_compose_event = tcb_info.get_single_event(COMPOSE_HASH_EVENT)?;
        compare_hex_hashes(
            "app_compose_event_hash",
            &app_compose_event.event_payload,
            &hex::encode(*tcb_info.compose_hash),
        )?;
        Self::validate_app_compose_payload(&app_compose_event.event_payload, &tcb_info.app_compose)
    }

    /// Requires the unique RTMR3 `key-provider` event's digest to equal `expected_digest`.
    fn verify_key_provider_digest(
        &self,
        tcb_info: &TcbInfo,
        expected_digest: &[u8; 48],
    ) -> Result<(), VerificationError> {
        let key_provider_event = tcb_info.get_single_event(KEY_PROVIDER_EVENT)?;
        compare_hashes(
            "key_provider",
            key_provider_event.digest.as_slice(),
            expected_digest,
        )
    }

    /// Requires `tcb_info.compose_hash` to equal `expected_hash_payload`.
    fn verify_app_compose_hash(
        &self,
        tcb_info: &TcbInfo,
        expected_hash_payload: &[u8; 32],
    ) -> Result<(), VerificationError> {
        compare_hashes(
            "app_compose_hash",
            tcb_info.compose_hash.as_slice(),
            expected_hash_payload,
        )
    }

    /// Computes an event digest as `SHA384(event_type_le || ":" || event || ":" || payload)`.
    ///
    /// The event type is serialized little-endian explicitly. The previous `to_ne_bytes`
    /// produced the same bytes on little-endian hosts, but it would make verification
    /// depend on the verifier's endianness.
    fn event_digest(event_type: u32, event: &str, payload: &[u8]) -> [u8; 48] {
        let mut hasher = Sha384::new();
        hasher.update(event_type.to_le_bytes());
        hasher.update(b":");
        hasher.update(event.as_bytes());
        hasher.update(b":");
        hasher.update(payload);
        hasher.finalize().into()
    }
}
/// Compares two byte strings for exact equality.
///
/// On mismatch, returns [`VerificationError::WrongHash`] tagged with `name`, with both sides
/// hex-encoded for diagnostics.
fn compare_hashes(
    name: &'static str,
    found: &[u8],
    expected: &[u8],
) -> Result<(), VerificationError> {
    if found == expected {
        Ok(())
    } else {
        Err(VerificationError::WrongHash {
            name,
            found: hex::encode(found),
            expected: hex::encode(expected),
        })
    }
}
/// Compares two textual hash representations with `==`.
///
/// For strings the comparison is exact and therefore case-sensitive. On mismatch, both values
/// are rendered with `to_string()` into a [`VerificationError::WrongHash`] tagged with `name`.
fn compare_hex_hashes<S: ToString + Eq>(
    name: &'static str,
    found: S,
    expected: S,
) -> Result<(), VerificationError> {
    let matches = found == expected;
    if !matches {
        let found = found.to_string();
        let expected = expected.to_string();
        return Err(VerificationError::WrongHash {
            name,
            found,
            expected,
        });
    }
    Ok(())
}
/// Turns a boolean condition into a `Result`, building the error only when needed.
pub trait OrErr {
    /// Returns `Ok(())` if `self` holds; otherwise returns `Err(make_err())`.
    fn or_err<E>(self, make_err: impl FnOnce() -> E) -> Result<(), E>;
}

impl OrErr for bool {
    fn or_err<E>(self, make_err: impl FnOnce() -> E) -> Result<(), E> {
        if self {
            Ok(())
        } else {
            Err(make_err())
        }
    }
}
/// Looks up a uniquely-named event in the RTMR3 portion of an event log.
pub trait GetSingleEvent {
    /// Returns the only event with `imr == RTMR3_INDEX` whose name equals `event_name`.
    ///
    /// # Errors
    ///
    /// - [`VerificationError::MissingEvent`] if no such event exists.
    /// - [`VerificationError::DuplicateEvent`] if more than one exists.
    fn get_single_event(&self, event_name: &'static str) -> Result<&EventLog, VerificationError>;
}

impl GetSingleEvent for TcbInfo {
    fn get_single_event(&self, event_name: &'static str) -> Result<&EventLog, VerificationError> {
        let is_wanted = |entry: &&EventLog| entry.imr == RTMR3_INDEX && entry.event == event_name;
        let mut candidates = self.event_log.iter().filter(is_wanted);
        match (candidates.next(), candidates.next()) {
            (None, _) => Err(VerificationError::MissingEvent(event_name)),
            (Some(_), Some(_)) => Err(VerificationError::DuplicateEvent(event_name)),
            (Some(only), None) => Ok(only),
        }
    }
}