1use std::collections::{BTreeMap, HashMap};
4use std::sync::{Arc, RwLock};
5
6use dcap_qvl::collateral::get_collateral;
7use dcap_qvl::quote::Quote;
8use dcap_qvl::verify::{verify, VerifiedReport};
9use dcap_qvl::QuoteCollateralV3;
10use dstack_sdk_types::dstack::{EventLog, GetQuoteResponse};
11use log::{debug, warn};
12use sha2::{Digest, Sha256, Sha512};
13
14use crate::dstack::compose_hash::get_compose_hash;
15use crate::dstack::config::DstackTDXVerifierConfig;
16use crate::error::AtlsVerificationError;
17use crate::verifier::{AsyncByteStream, AsyncReadExt, AsyncWriteExt, AtlsVerifier, Report};
18
19pub use crate::dstack::config::DstackTDXVerifierBuilder;
20
/// Key for the collateral cache: `(PCCS URL, FMSPC, CA)`.
///
/// The FMSPC component is the uppercase hex encoding of the quote's FMSPC, and the CA
/// component is the string returned by `Quote::ca()` (both computed in `verify_quote`).
type CollateralCacheKey = (String, String, &'static str);
23
/// A cached DCAP collateral entry, stamped with the time it was stored.
#[derive(Clone)]
struct CachedCollateral {
    /// Collateral previously fetched from the PCCS for this cache key.
    collateral: QuoteCollateralV3,
    /// Unix time (seconds) at which the entry was inserted; lookups treat the entry as
    /// expired once `COLLATERAL_CACHE_TTL_SECS` have elapsed since this value.
    cached_at_secs: u64,
}
30
/// How long cached collateral is reused before being refetched, in seconds (8 hours).
const COLLATERAL_CACHE_TTL_SECS: u64 = 8 * 3600;
33
/// JSON body returned by the peer's `POST /tdx_quote` endpoint.
///
/// Only the `quote` field is deserialized; serde ignores any other fields by default.
#[derive(Debug, serde::Deserialize)]
struct QuoteEndpointResponse {
    /// Quote response whose quote bytes and event log are later decoded via
    /// `decode_quote` / `decode_event_log`.
    quote: GetQuoteResponse,
}
39
/// aTLS verifier for dstack workloads running under Intel TDX.
///
/// Verifies the peer's TDX quote with DCAP and, unless runtime verification is disabled in
/// the config, checks bootchain measurements, the OS image hash and the app-compose hash
/// against configured reference values.
pub struct DstackTDXVerifier {
    /// Reference values and policy used during verification.
    config: DstackTDXVerifierConfig,
    /// Collateral cache keyed by `(PCCS URL, FMSPC, CA)`; only consulted when
    /// `config.cache_collateral` is set. Wrapped in `Arc` so clones share one cache.
    cached_collateral: Arc<RwLock<HashMap<CollateralCacheKey, CachedCollateral>>>,
}
55
56impl DstackTDXVerifier {
57 pub fn new(config: DstackTDXVerifierConfig) -> Result<Self, AtlsVerificationError> {
59 if !config.disable_runtime_verification {
61 if config.expected_bootchain.is_none() || config.os_image_hash.is_none() {
62 return Err(AtlsVerificationError::Configuration(
63 "expected_bootchain and os_image_hash must be provided together".into(),
64 ));
65 }
66 if config.app_compose.is_none() {
67 return Err(AtlsVerificationError::Configuration(
68 "app_compose must be provided".into(),
69 ));
70 }
71 }
72 Ok(Self {
73 config,
74 cached_collateral: Arc::new(RwLock::new(HashMap::new())),
75 })
76 }
77
78 pub fn builder() -> DstackTDXVerifierBuilder {
80 DstackTDXVerifierBuilder::new()
81 }
82
83 async fn verify_quote(&self, quote: &[u8]) -> Result<VerifiedReport, AtlsVerificationError> {
85 let pccs_url = self.config.pccs_url.as_deref().unwrap_or_default();
86 let pccs_url = if pccs_url.is_empty() {
87 "https://api.trustedservices.intel.com"
88 } else {
89 pccs_url
90 };
91
92 let parsed_quote = Quote::parse(quote)
94 .map_err(|e| AtlsVerificationError::Quote(format!("Failed to parse quote: {}", e)))?;
95 let fmspc = hex::encode_upper(
96 parsed_quote
97 .fmspc()
98 .map_err(|e| AtlsVerificationError::Quote(format!("Failed to get FMSPC: {}", e)))?,
99 );
100 let ca = parsed_quote
101 .ca()
102 .map_err(|e| AtlsVerificationError::Quote(format!("Failed to get CA: {}", e)))?;
103
104 let cache_key = (pccs_url.to_string(), fmspc.clone(), ca);
105
106 #[cfg(not(target_arch = "wasm32"))]
108 let now_secs = std::time::SystemTime::now()
109 .duration_since(std::time::UNIX_EPOCH)
110 .map_err(|e| {
111 AtlsVerificationError::Quote(format!("Failed to get current time: {}", e))
112 })?
113 .as_secs();
114
115 #[cfg(target_arch = "wasm32")]
116 let now_secs = (js_sys::Date::now() / 1000.0) as u64;
117
118 let cached = if self.config.cache_collateral {
120 match self.cached_collateral.read() {
121 Ok(guard) => guard.get(&cache_key).and_then(|entry| {
122 if now_secs.saturating_sub(entry.cached_at_secs) < COLLATERAL_CACHE_TTL_SECS {
123 Some(entry.collateral.clone())
124 } else {
125 debug!("Cached collateral expired for FMSPC={}, CA={}", fmspc, ca);
126 None
127 }
128 }),
129 Err(_) => {
130 warn!("Collateral cache lock poisoned, treating as cache miss");
131 None
132 }
133 }
134 } else {
135 None
136 };
137
138 let collateral = match cached {
139 Some(c) => {
140 debug!(
141 "Using cached collateral for PCCS={}, FMSPC={}, CA={}",
142 pccs_url, fmspc, ca
143 );
144 c
145 }
146 None => {
147 debug!("Fetching collateral from {}", pccs_url);
148 let c = get_collateral(pccs_url, quote)
149 .await
150 .map_err(|e| {
151 AtlsVerificationError::Quote(format!("Failed to get collateral: {}", e))
152 })?;
153
154 if self.config.cache_collateral {
156 match self.cached_collateral.write() {
157 Ok(mut guard) => {
158 debug!("Caching collateral for FMSPC={}, CA={}", fmspc, ca);
159 guard.insert(cache_key, CachedCollateral {
160 collateral: c.clone(),
161 cached_at_secs: now_secs,
162 });
163 }
164 Err(_) => {
165 warn!("Collateral cache lock poisoned, skipping cache write");
166 }
167 }
168 }
169 c
170 }
171 };
172
173 debug!("Collateral received, verifying DCAP quote");
174
175 let report = verify(quote, &collateral, now_secs)
177 .map_err(|e| AtlsVerificationError::Quote(format!("DCAP verification failed: {}", e)))?;
178
179 debug!("DCAP verification complete, TCB status: {}", report.status);
180
181 let tcb_allowed = self
183 .config
184 .allowed_tcb_status
185 .iter()
186 .any(|s| s == &report.status);
187
188 debug!(
189 "TCB status '{}' allowed: {}",
190 report.status, tcb_allowed
191 );
192
193 if !tcb_allowed {
194 return Err(AtlsVerificationError::TcbStatusNotAllowed {
195 status: report.status.clone(),
196 allowed: self.config.allowed_tcb_status.clone(),
197 });
198 }
199
200 Ok(report)
201 }
202
203 fn verify_bootchain(
210 &self,
211 verified_report: &VerifiedReport,
212 ) -> Result<(), AtlsVerificationError> {
213 let bootchain = self.config.expected_bootchain.as_ref().ok_or_else(|| {
214 AtlsVerificationError::Configuration("expected_bootchain is required".into())
215 })?;
216
217 let td_report = verified_report.report.as_td10().ok_or_else(|| {
219 AtlsVerificationError::TeeTypeMismatch(
220 "expected TDX report but got SGX enclave report".into(),
221 )
222 })?;
223
224 debug!("Verifying bootchain measurements against verified report");
225
226 let actual_mrtd = hex::encode(td_report.mr_td);
228 debug!("MRTD expected: {}", bootchain.mrtd);
229 debug!("MRTD actual: {}", actual_mrtd);
230 let mrtd_match = actual_mrtd == bootchain.mrtd;
231 debug!("MRTD match: {}", mrtd_match);
232
233 if !mrtd_match {
234 return Err(AtlsVerificationError::BootchainMismatch {
235 field: "mrtd".into(),
236 expected: bootchain.mrtd.clone(),
237 actual: actual_mrtd,
238 });
239 }
240
241 let actual_rtmrs = [
243 hex::encode(td_report.rt_mr0),
244 hex::encode(td_report.rt_mr1),
245 hex::encode(td_report.rt_mr2),
246 ];
247 let expected_rtmrs = [&bootchain.rtmr0, &bootchain.rtmr1, &bootchain.rtmr2];
248
249 for idx in 0..3usize {
250 debug!("RTMR{} expected: {}", idx, expected_rtmrs[idx]);
251 debug!("RTMR{} actual: {}", idx, actual_rtmrs[idx]);
252 let rtmr_match = &actual_rtmrs[idx] == expected_rtmrs[idx];
253 debug!("RTMR{} match: {}", idx, rtmr_match);
254
255 if !rtmr_match {
256 return Err(AtlsVerificationError::BootchainMismatch {
257 field: format!("rtmr{}", idx),
258 expected: expected_rtmrs[idx].clone(),
259 actual: actual_rtmrs[idx].clone(),
260 });
261 }
262 }
263
264 debug!("Bootchain verification successful");
265 Ok(())
266 }
267
268 fn verify_cert_in_eventlog(
273 &self,
274 cert_der: &[u8],
275 events: &[EventLog],
276 ) -> Result<bool, AtlsVerificationError> {
277 let cert_hash = hex::encode(Sha256::digest(cert_der));
278 debug!("Certificate hash: {}", cert_hash);
279
280 let cert_event = events
282 .iter()
283 .rfind(|e| e.event == "New TLS Certificate");
284
285 match cert_event {
286 Some(event) => {
287 let decoded = hex::decode(&event.event_payload).map_err(|e| {
289 AtlsVerificationError::EventLogParse(format!(
290 "failed to hex-decode certificate event payload: {}",
291 e
292 ))
293 })?;
294
295 let eventlog_cert_hash = String::from_utf8(decoded).map_err(|e| {
296 AtlsVerificationError::EventLogParse(format!(
297 "certificate event payload is not valid UTF-8: {}",
298 e
299 ))
300 })?;
301
302 debug!("Certificate hash from event log: {}", eventlog_cert_hash);
303 let cert_match = eventlog_cert_hash == cert_hash;
304 debug!("Certificate hash match: {}", cert_match);
305 Ok(cert_match)
306 }
307 None => {
308 debug!("No 'New TLS Certificate' event found in event log");
309 Ok(false)
310 }
311 }
312 }
313
314 fn verify_app_compose(&self, events: &[EventLog]) -> Result<(), AtlsVerificationError> {
321 let app_compose = self.config.app_compose.as_ref().ok_or_else(|| {
322 AtlsVerificationError::Configuration("app_compose is required".into())
323 })?;
324 let expected = get_compose_hash(app_compose).map_err(|e| {
325 AtlsVerificationError::Configuration(format!(
326 "Failed to serialize app_compose for hashing: {}",
327 e
328 ))
329 })?;
330
331 debug!("Verifying app compose hash against trusted event log");
332 debug!("App compose hash expected: {}", expected);
333
334 let event = events
336 .iter()
337 .find(|e| e.event == "compose-hash")
338 .ok_or_else(|| {
339 AtlsVerificationError::AppComposeHashMismatch {
340 expected: expected.clone(),
341 actual: "<not found in event log>".to_string(),
342 }
343 })?;
344
345 debug!("App compose hash from event log: {}", event.event_payload);
346 let eventlog_match = event.event_payload == expected;
347 debug!("App compose hash match: {}", eventlog_match);
348
349 if !eventlog_match {
350 return Err(AtlsVerificationError::AppComposeHashMismatch {
351 expected,
352 actual: event.event_payload.clone(),
353 });
354 }
355
356 debug!("App compose verification successful");
357 Ok(())
358 }
359
360 fn verify_os_image_hash(&self, events: &[EventLog]) -> Result<(), AtlsVerificationError> {
367 let expected = self.config.os_image_hash.as_ref().ok_or_else(|| {
368 AtlsVerificationError::Configuration("os_image_hash is required".into())
369 })?;
370
371 debug!("Verifying OS image hash against trusted event log");
372 debug!("OS image hash expected: {}", expected);
373
374 let event = events
376 .iter()
377 .find(|e| e.event == "os-image-hash")
378 .ok_or_else(|| AtlsVerificationError::OsImageHashMismatch {
379 expected: expected.clone(),
380 actual: Some("<not found in event log>".to_string()),
381 })?;
382
383 debug!("OS image hash from event log: {}", event.event_payload);
384 let eventlog_match = &event.event_payload == expected;
385 debug!("OS image hash match: {}", eventlog_match);
386
387 if !eventlog_match {
388 return Err(AtlsVerificationError::OsImageHashMismatch {
389 expected: expected.clone(),
390 actual: Some(event.event_payload.clone()),
391 });
392 }
393
394 debug!("OS image hash verification successful");
395 Ok(())
396 }
397
398 fn verify_rtmr_replay(
403 &self,
404 quote_response: &GetQuoteResponse,
405 verified_report: &VerifiedReport,
406 ) -> Result<(), AtlsVerificationError> {
407 debug!("Verifying RTMR replay against verified report");
408
409 let td_report = verified_report.report.as_td10().ok_or_else(|| {
411 AtlsVerificationError::TeeTypeMismatch(
412 "expected TDX report but got SGX enclave report".into(),
413 )
414 })?;
415
416 let replayed: BTreeMap<u8, String> = quote_response
418 .replay_rtmrs()
419 .map_err(AtlsVerificationError::Other)?;
420
421 let trusted_rtmrs = [
423 hex::encode(td_report.rt_mr0),
424 hex::encode(td_report.rt_mr1),
425 hex::encode(td_report.rt_mr2),
426 hex::encode(td_report.rt_mr3),
427 ];
428
429 for i in 0..4u8 {
430 let replayed_rtmr = replayed.get(&i).cloned().ok_or_else(|| {
431 AtlsVerificationError::Quote(format!(
432 "RTMR{} missing from event log replay - malformed event log",
433 i
434 ))
435 })?;
436 debug!(
437 "RTMR{} from verified report: {}",
438 i, trusted_rtmrs[i as usize]
439 );
440 debug!("RTMR{} replayed: {}", i, replayed_rtmr);
441 let rtmr_match = replayed_rtmr == trusted_rtmrs[i as usize];
442 debug!("RTMR{} replay match: {}", i, rtmr_match);
443
444 if !rtmr_match {
445 return Err(AtlsVerificationError::RtmrMismatch {
446 index: i,
447 expected: trusted_rtmrs[i as usize].clone(),
448 actual: replayed_rtmr,
449 });
450 }
451 }
452
453 debug!("RTMR replay verification successful");
454 Ok(())
455 }
456
457 fn verify_report_data(
462 &self,
463 nonce: &[u8; 32],
464 session_ekm: &[u8; 32],
465 verified_report: &VerifiedReport,
466 ) -> Result<(), AtlsVerificationError> {
467 debug!("Verifying report data against verified report");
468
469 let mut hasher = Sha512::new();
471 hasher.update(nonce);
472 hasher.update(session_ekm);
473 let report_data: [u8; 64] = hasher.finalize().into();
474
475 let td_report = verified_report.report.as_td10().ok_or_else(|| {
477 AtlsVerificationError::TeeTypeMismatch(
478 "expected TDX report but got SGX enclave report".into(),
479 )
480 })?;
481
482 let expected = hex::encode(report_data);
483 let actual = hex::encode(td_report.report_data);
484
485 debug!("Report data expected: {}", expected);
486 debug!("Report data actual: {}", actual);
487
488
489 if expected != actual {
490 return Err(AtlsVerificationError::ReportDataMismatch { expected, actual });
491 }
492
493 debug!("Report data verification successful");
494 Ok(())
495 }
496}
497
498impl AtlsVerifier for DstackTDXVerifier {
499 async fn verify<S>(
500 &self,
501 stream: &mut S,
502 peer_cert: &[u8],
503 session_ekm: &[u8],
504 hostname: &str,
505 ) -> Result<Report, AtlsVerificationError>
506 where
507 S: AsyncByteStream,
508 {
509 debug!("Starting DStack TDX verification for {}", hostname);
510
511 let mut nonce = [0u8; 32];
513 rand::Rng::fill(&mut rand::thread_rng(), &mut nonce);
514
515 let quote_response = get_quote_over_http(stream, &nonce, hostname).await?;
517
518 debug!("Parsing event log");
520 let events = quote_response
521 .decode_event_log()
522 .map_err(|e| AtlsVerificationError::Other(e.into()))?;
523 debug!("Event log parsed, {} events found", events.len());
524
525 debug!("Verifying certificate in event log");
527 let cert_in_eventlog = self.verify_cert_in_eventlog(peer_cert, &events)?;
528 if !cert_in_eventlog {
529 return Err(AtlsVerificationError::CertificateNotInEventLog);
530 }
531
532 debug!("Decoding quote for DCAP verification");
534 let quote_bytes = quote_response
535 .decode_quote()
536 .map_err(|e| AtlsVerificationError::Other(anyhow::anyhow!("Failed to decode quote: {}", e)))?;
537 debug!("Quote decoded ({} bytes)", quote_bytes.len());
538
539 let verified_report = self.verify_quote("e_bytes).await?;
541
542 let session_ekm: &[u8; 32] = session_ekm.try_into().map_err(|_| {
544 AtlsVerificationError::Configuration(
545 "session_ekm must be exactly 32 bytes".into(),
546 )
547 })?;
548 self.verify_report_data(&nonce, session_ekm, &verified_report)?;
549
550 self.verify_rtmr_replay("e_response, &verified_report)?;
552
553 if self.config.disable_runtime_verification {
555 debug!("Runtime verification disabled, skipping bootchain/app-compose/os-image checks");
556 return Ok(Report::Tdx(verified_report));
557 }
558
559 self.verify_bootchain(&verified_report)?;
561
562 self.verify_app_compose(&events)?;
564
565 self.verify_os_image_hash(&events)?;
567
568 debug!("DStack TDX verification complete");
569 Ok(Report::Tdx(verified_report))
570 }
571}
572
573async fn get_quote_over_http<S>(
575 stream: &mut S,
576 nonce: &[u8; 32],
577 hostname: &str,
578) -> Result<GetQuoteResponse, AtlsVerificationError>
579where
580 S: AsyncByteStream,
581{
582 debug!("Sending POST /tdx_quote request to {}", hostname);
583
584 let body = serde_json::json!({
586 "nonce_hex": hex::encode(nonce)
587 });
588 let body_str = body.to_string();
589
590 let request = format!(
591 "POST /tdx_quote HTTP/1.1\r\n\
592 Host: {}\r\n\
593 Content-Type: application/json\r\n\
594 Content-Length: {}\r\n\
595 Connection: keep-alive\r\n\
596 \r\n\
597 {}",
598 hostname,
599 body_str.len(),
600 body_str
601 );
602
603 stream
604 .write_all(request.as_bytes())
605 .await
606 .map_err(|e| AtlsVerificationError::Io(e.to_string()))?;
607 stream
608 .flush()
609 .await
610 .map_err(|e| AtlsVerificationError::Io(e.to_string()))?;
611
612 let mut response_buf = Vec::new();
614 let mut chunk = [0u8; 4096];
615
616 loop {
618 let n = stream
619 .read(&mut chunk)
620 .await
621 .map_err(|e| AtlsVerificationError::Io(e.to_string()))?;
622 if n == 0 {
623 break;
624 }
625 response_buf.extend_from_slice(&chunk[..n]);
626
627 if let Some(body_start) = find_http_body_start(&response_buf) {
629 if let Some(content_length) = parse_content_length(&response_buf[..body_start]) {
631 if response_buf.len() >= body_start + content_length {
632 break;
633 }
634 }
635 }
636 }
637
638 debug!("Received quote response ({} bytes)", response_buf.len());
639
640 let body_start = find_http_body_start(&response_buf)
642 .ok_or_else(|| AtlsVerificationError::Io("Invalid HTTP response".into()))?;
643 let response_body = &response_buf[body_start..];
644
645 let response: QuoteEndpointResponse = serde_json::from_slice(response_body)
646 .map_err(|e| {
647 AtlsVerificationError::Quote(format!(
648 "Failed to parse /tdx_quote response: {}",
649 e
650 ))
651 })?;
652
653 Ok(response.quote)
654}
655
/// Locates the end of an HTTP header section.
///
/// Returns the byte offset just past the first `\r\n\r\n` terminator (where the body begins),
/// or `None` if no complete terminator is present in `data` yet.
fn find_http_body_start(data: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    data.windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)
        .map(|offset| offset + TERMINATOR.len())
}
665
/// Extracts the `Content-Length` value from a raw HTTP header section.
///
/// Header names are matched case-insensitively and the first matching line wins. Returns
/// `None` if the bytes are not UTF-8, no such header exists, or its value is not a `usize`.
fn parse_content_length(headers: &[u8]) -> Option<usize> {
    let text = std::str::from_utf8(headers).ok()?;
    let header_line = text
        .lines()
        .find(|line| line.to_lowercase().starts_with("content-length:"))?;
    let raw_value = header_line.split(':').nth(1)?;
    raw_value.trim().parse().ok()
}