// atlas_rs/dstack/verifier.rs

1//! DstackTDXVerifier implementation.
2
3use std::collections::{BTreeMap, HashMap};
4use std::sync::{Arc, RwLock};
5
6use dcap_qvl::collateral::get_collateral;
7use dcap_qvl::quote::Quote;
8use dcap_qvl::verify::{verify, VerifiedReport};
9use dcap_qvl::QuoteCollateralV3;
10use dstack_sdk_types::dstack::{EventLog, GetQuoteResponse};
11use log::{debug, warn};
12use sha2::{Digest, Sha256, Sha512};
13
14use crate::dstack::compose_hash::get_compose_hash;
15use crate::dstack::config::DstackTDXVerifierConfig;
16use crate::error::AtlsVerificationError;
17use crate::verifier::{AsyncByteStream, AsyncReadExt, AsyncWriteExt, AtlsVerifier, Report};
18
19pub use crate::dstack::config::DstackTDXVerifierBuilder;
20
/// Key under which fetched collateral is cached: `(pccs_url, fmspc, ca)`.
///
/// The PCCS URL and the upper-case hex FMSPC are owned strings; the CA
/// identifier is the `&'static str` obtained from the parsed quote.
type CollateralCacheKey = (String, String, &'static str);
23
24/// Cached collateral with timestamp for TTL expiration.
25#[derive(Clone)]
26struct CachedCollateral {
27    collateral: QuoteCollateralV3,
28    cached_at_secs: u64,
29}
30
/// How long a cached collateral entry stays usable: 8 hours, expressed in seconds.
const COLLATERAL_CACHE_TTL_SECS: u64 = 8 * 60 * 60;
33
34/// Response from the /tdx_quote endpoint.
35#[derive(Debug, serde::Deserialize)]
36struct QuoteEndpointResponse {
37    quote: GetQuoteResponse,
38}
39
40/// DstackTDXVerifier performs TDX attestation verification for dstack deployments.
41///
42/// This verifier implements the full verification flow:
43/// 1. Fetch quote from remote server
44/// 2. Verify DCAP quote using Intel PCS
45/// 3. Verify certificate binding to event log
46/// 4. Verify RTMR replay
47/// 5. Verify bootchain measurements (MRTD, RTMR0-2)
48/// 6. Verify app compose hash
49/// 7. Verify OS image hash
50pub struct DstackTDXVerifier {
51    config: DstackTDXVerifierConfig,
52    /// Cached collateral keyed by (pccs_url, fmspc, ca) with TTL expiration.
53    cached_collateral: Arc<RwLock<HashMap<CollateralCacheKey, CachedCollateral>>>,
54}
55
56impl DstackTDXVerifier {
57    /// Create a new DstackTDXVerifier with the given configuration.
58    pub fn new(config: DstackTDXVerifierConfig) -> Result<Self, AtlsVerificationError> {
59        // Validation: bootchain and os_image_hash must be provided together
60        if !config.disable_runtime_verification {
61            if config.expected_bootchain.is_none() || config.os_image_hash.is_none() {
62                return Err(AtlsVerificationError::Configuration(
63                    "expected_bootchain and os_image_hash must be provided together".into(),
64                ));
65            }
66            if config.app_compose.is_none() {
67                return Err(AtlsVerificationError::Configuration(
68                    "app_compose must be provided".into(),
69                ));
70            }
71        }
72        Ok(Self {
73            config,
74            cached_collateral: Arc::new(RwLock::new(HashMap::new())),
75        })
76    }
77
78    /// Create a new builder for DstackTDXVerifier.
79    pub fn builder() -> DstackTDXVerifierBuilder {
80        DstackTDXVerifierBuilder::new()
81    }
82
83    /// Verify quote using dcap-qvl directly.
84    async fn verify_quote(&self, quote: &[u8]) -> Result<VerifiedReport, AtlsVerificationError> {
85        let pccs_url = self.config.pccs_url.as_deref().unwrap_or_default();
86        let pccs_url = if pccs_url.is_empty() {
87            "https://api.trustedservices.intel.com"
88        } else {
89            pccs_url
90        };
91
92        // Parse quote to get cache key components (FMSPC and CA)
93        let parsed_quote = Quote::parse(quote)
94            .map_err(|e| AtlsVerificationError::Quote(format!("Failed to parse quote: {}", e)))?;
95        let fmspc = hex::encode_upper(
96            parsed_quote
97                .fmspc()
98                .map_err(|e| AtlsVerificationError::Quote(format!("Failed to get FMSPC: {}", e)))?,
99        );
100        let ca = parsed_quote
101            .ca()
102            .map_err(|e| AtlsVerificationError::Quote(format!("Failed to get CA: {}", e)))?;
103
104        let cache_key = (pccs_url.to_string(), fmspc.clone(), ca);
105
106        // Get current time - platform specific (needed for cache TTL and verification)
107        #[cfg(not(target_arch = "wasm32"))]
108        let now_secs = std::time::SystemTime::now()
109            .duration_since(std::time::UNIX_EPOCH)
110            .map_err(|e| {
111                AtlsVerificationError::Quote(format!("Failed to get current time: {}", e))
112            })?
113            .as_secs();
114
115        #[cfg(target_arch = "wasm32")]
116        let now_secs = (js_sys::Date::now() / 1000.0) as u64;
117
118        // Try to get collateral from cache (with TTL check)
119        let cached = if self.config.cache_collateral {
120            match self.cached_collateral.read() {
121                Ok(guard) => guard.get(&cache_key).and_then(|entry| {
122                    if now_secs.saturating_sub(entry.cached_at_secs) < COLLATERAL_CACHE_TTL_SECS {
123                        Some(entry.collateral.clone())
124                    } else {
125                        debug!("Cached collateral expired for FMSPC={}, CA={}", fmspc, ca);
126                        None
127                    }
128                }),
129                Err(_) => {
130                    warn!("Collateral cache lock poisoned, treating as cache miss");
131                    None
132                }
133            }
134        } else {
135            None
136        };
137
138        let collateral = match cached {
139            Some(c) => {
140                debug!(
141                    "Using cached collateral for PCCS={}, FMSPC={}, CA={}",
142                    pccs_url, fmspc, ca
143                );
144                c
145            }
146            None => {
147                debug!("Fetching collateral from {}", pccs_url);
148                let c = get_collateral(pccs_url, quote)
149                    .await
150                    .map_err(|e| {
151                        AtlsVerificationError::Quote(format!("Failed to get collateral: {}", e))
152                    })?;
153
154                // Cache if enabled
155                if self.config.cache_collateral {
156                    match self.cached_collateral.write() {
157                        Ok(mut guard) => {
158                            debug!("Caching collateral for FMSPC={}, CA={}", fmspc, ca);
159                            guard.insert(cache_key, CachedCollateral {
160                                collateral: c.clone(),
161                                cached_at_secs: now_secs,
162                            });
163                        }
164                        Err(_) => {
165                            warn!("Collateral cache lock poisoned, skipping cache write");
166                        }
167                    }
168                }
169                c
170            }
171        };
172
173        debug!("Collateral received, verifying DCAP quote");
174
175        // Verify the quote
176        let report = verify(quote, &collateral, now_secs)
177            .map_err(|e| AtlsVerificationError::Quote(format!("DCAP verification failed: {}", e)))?;
178
179        debug!("DCAP verification complete, TCB status: {}", report.status);
180
181        // Check TCB status
182        let tcb_allowed = self
183            .config
184            .allowed_tcb_status
185            .iter()
186            .any(|s| s == &report.status);
187
188        debug!(
189            "TCB status '{}' allowed: {}",
190            report.status, tcb_allowed
191        );
192
193        if !tcb_allowed {
194            return Err(AtlsVerificationError::TcbStatusNotAllowed {
195                status: report.status.clone(),
196                allowed: self.config.allowed_tcb_status.clone(),
197            });
198        }
199
200        Ok(report)
201    }
202
203    /// Verify bootchain measurements (MRTD, RTMR0-2) using the trusted verified report.
204    ///
205    /// Compares the cryptographically verified measurements from the report
206    /// against the expected bootchain configuration.
207    ///
208    /// Fails if `expected_bootchain` is not configured.
209    fn verify_bootchain(
210        &self,
211        verified_report: &VerifiedReport,
212    ) -> Result<(), AtlsVerificationError> {
213        let bootchain = self.config.expected_bootchain.as_ref().ok_or_else(|| {
214            AtlsVerificationError::Configuration("expected_bootchain is required".into())
215        })?;
216
217        // Get the trusted TD report from DCAP verification
218        let td_report = verified_report.report.as_td10().ok_or_else(|| {
219            AtlsVerificationError::TeeTypeMismatch(
220                "expected TDX report but got SGX enclave report".into(),
221            )
222        })?;
223
224        debug!("Verifying bootchain measurements against verified report");
225
226        // Check MRTD (convert from bytes to hex string)
227        let actual_mrtd = hex::encode(td_report.mr_td);
228        debug!("MRTD expected: {}", bootchain.mrtd);
229        debug!("MRTD actual:   {}", actual_mrtd);
230        let mrtd_match = actual_mrtd == bootchain.mrtd;
231        debug!("MRTD match: {}", mrtd_match);
232
233        if !mrtd_match {
234            return Err(AtlsVerificationError::BootchainMismatch {
235                field: "mrtd".into(),
236                expected: bootchain.mrtd.clone(),
237                actual: actual_mrtd,
238            });
239        }
240
241        // Check RTMR0-2 (convert from bytes to hex strings)
242        let actual_rtmrs = [
243            hex::encode(td_report.rt_mr0),
244            hex::encode(td_report.rt_mr1),
245            hex::encode(td_report.rt_mr2),
246        ];
247        let expected_rtmrs = [&bootchain.rtmr0, &bootchain.rtmr1, &bootchain.rtmr2];
248
249        for idx in 0..3usize {
250            debug!("RTMR{} expected: {}", idx, expected_rtmrs[idx]);
251            debug!("RTMR{} actual:   {}", idx, actual_rtmrs[idx]);
252            let rtmr_match = &actual_rtmrs[idx] == expected_rtmrs[idx];
253            debug!("RTMR{} match: {}", idx, rtmr_match);
254
255            if !rtmr_match {
256                return Err(AtlsVerificationError::BootchainMismatch {
257                    field: format!("rtmr{}", idx),
258                    expected: expected_rtmrs[idx].clone(),
259                    actual: actual_rtmrs[idx].clone(),
260                });
261            }
262        }
263
264        debug!("Bootchain verification successful");
265        Ok(())
266    }
267
268    /// Verify certificate is in event log (using dstack-sdk EventLog type).
269    ///
270    /// Returns Ok(true) if cert matches, Ok(false) if cert not found,
271    /// or Err if parsing fails.
272    fn verify_cert_in_eventlog(
273        &self,
274        cert_der: &[u8],
275        events: &[EventLog],
276    ) -> Result<bool, AtlsVerificationError> {
277        let cert_hash = hex::encode(Sha256::digest(cert_der));
278        debug!("Certificate hash: {}", cert_hash);
279
280        // Find last "New TLS Certificate" event
281        let cert_event = events
282            .iter()
283            .rfind(|e| e.event == "New TLS Certificate");
284
285        match cert_event {
286            Some(event) => {
287                // event_payload is hex-encoded, decode it to get the cert hash string
288                let decoded = hex::decode(&event.event_payload).map_err(|e| {
289                    AtlsVerificationError::EventLogParse(format!(
290                        "failed to hex-decode certificate event payload: {}",
291                        e
292                    ))
293                })?;
294
295                let eventlog_cert_hash = String::from_utf8(decoded).map_err(|e| {
296                    AtlsVerificationError::EventLogParse(format!(
297                        "certificate event payload is not valid UTF-8: {}",
298                        e
299                    ))
300                })?;
301
302                debug!("Certificate hash from event log: {}", eventlog_cert_hash);
303                let cert_match = eventlog_cert_hash == cert_hash;
304                debug!("Certificate hash match: {}", cert_match);
305                Ok(cert_match)
306            }
307            None => {
308                debug!("No 'New TLS Certificate' event found in event log");
309                Ok(false)
310            }
311        }
312    }
313
314    /// Verify app compose hash using the trusted event log.
315    ///
316    /// The event log integrity is guaranteed by RTMR replay verification against
317    /// the cryptographically verified report.
318    ///
319    /// Fails if `app_compose` is not configured.
320    fn verify_app_compose(&self, events: &[EventLog]) -> Result<(), AtlsVerificationError> {
321        let app_compose = self.config.app_compose.as_ref().ok_or_else(|| {
322            AtlsVerificationError::Configuration("app_compose is required".into())
323        })?;
324        let expected = get_compose_hash(app_compose).map_err(|e| {
325            AtlsVerificationError::Configuration(format!(
326                "Failed to serialize app_compose for hashing: {}",
327                e
328            ))
329        })?;
330
331        debug!("Verifying app compose hash against trusted event log");
332        debug!("App compose hash expected: {}", expected);
333
334        // Verify against event log (trusted after RTMR replay verification)
335        let event = events
336            .iter()
337            .find(|e| e.event == "compose-hash")
338            .ok_or_else(|| {
339                AtlsVerificationError::AppComposeHashMismatch {
340                    expected: expected.clone(),
341                    actual: "<not found in event log>".to_string(),
342                }
343            })?;
344
345        debug!("App compose hash from event log: {}", event.event_payload);
346        let eventlog_match = event.event_payload == expected;
347        debug!("App compose hash match: {}", eventlog_match);
348
349        if !eventlog_match {
350            return Err(AtlsVerificationError::AppComposeHashMismatch {
351                expected,
352                actual: event.event_payload.clone(),
353            });
354        }
355
356        debug!("App compose verification successful");
357        Ok(())
358    }
359
360    /// Verify OS image hash using the trusted event log.
361    ///
362    /// The event log integrity is guaranteed by RTMR replay verification against
363    /// the cryptographically verified report.
364    ///
365    /// Fails if `os_image_hash` is not configured.
366    fn verify_os_image_hash(&self, events: &[EventLog]) -> Result<(), AtlsVerificationError> {
367        let expected = self.config.os_image_hash.as_ref().ok_or_else(|| {
368            AtlsVerificationError::Configuration("os_image_hash is required".into())
369        })?;
370
371        debug!("Verifying OS image hash against trusted event log");
372        debug!("OS image hash expected: {}", expected);
373
374        // Verify against event log (trusted after RTMR replay verification)
375        let event = events
376            .iter()
377            .find(|e| e.event == "os-image-hash")
378            .ok_or_else(|| AtlsVerificationError::OsImageHashMismatch {
379                expected: expected.clone(),
380                actual: Some("<not found in event log>".to_string()),
381            })?;
382
383        debug!("OS image hash from event log: {}", event.event_payload);
384        let eventlog_match = &event.event_payload == expected;
385        debug!("OS image hash match: {}", eventlog_match);
386
387        if !eventlog_match {
388            return Err(AtlsVerificationError::OsImageHashMismatch {
389                expected: expected.clone(),
390                actual: Some(event.event_payload.clone()),
391            });
392        }
393
394        debug!("OS image hash verification successful");
395        Ok(())
396    }
397
398    /// Verify RTMR replay using dstack-sdk's built-in replay_rtmrs().
399    ///
400    /// Compares replayed RTMRs from the event log against the trusted values
401    /// from the cryptographically verified report.
402    fn verify_rtmr_replay(
403        &self,
404        quote_response: &GetQuoteResponse,
405        verified_report: &VerifiedReport,
406    ) -> Result<(), AtlsVerificationError> {
407        debug!("Verifying RTMR replay against verified report");
408
409        // Get the trusted TD report from DCAP verification
410        let td_report = verified_report.report.as_td10().ok_or_else(|| {
411            AtlsVerificationError::TeeTypeMismatch(
412                "expected TDX report but got SGX enclave report".into(),
413            )
414        })?;
415
416        // Use dstack-sdk-types' built-in replay_rtmrs()
417        let replayed: BTreeMap<u8, String> = quote_response
418            .replay_rtmrs()
419            .map_err(AtlsVerificationError::Other)?;
420
421        // Get trusted RTMRs from verified report (as hex strings)
422        let trusted_rtmrs = [
423            hex::encode(td_report.rt_mr0),
424            hex::encode(td_report.rt_mr1),
425            hex::encode(td_report.rt_mr2),
426            hex::encode(td_report.rt_mr3),
427        ];
428
429        for i in 0..4u8 {
430            let replayed_rtmr = replayed.get(&i).cloned().ok_or_else(|| {
431                AtlsVerificationError::Quote(format!(
432                    "RTMR{} missing from event log replay - malformed event log",
433                    i
434                ))
435            })?;
436            debug!(
437                "RTMR{} from verified report: {}",
438                i, trusted_rtmrs[i as usize]
439            );
440            debug!("RTMR{} replayed:             {}", i, replayed_rtmr);
441            let rtmr_match = replayed_rtmr == trusted_rtmrs[i as usize];
442            debug!("RTMR{} replay match: {}", i, rtmr_match);
443
444            if !rtmr_match {
445                return Err(AtlsVerificationError::RtmrMismatch {
446                    index: i,
447                    expected: trusted_rtmrs[i as usize].clone(),
448                    actual: replayed_rtmr,
449                });
450            }
451        }
452
453        debug!("RTMR replay verification successful");
454        Ok(())
455    }
456
457    /// Verify report data (nonce + session EKM) against the verified report.
458    ///
459    /// This prevents replay and relay attacks by ensuring the quote was generated specifically
460    /// for this verification request, within the current TLS session (identified by EKM).
461    fn verify_report_data(
462        &self,
463        nonce: &[u8; 32],
464        session_ekm: &[u8; 32],
465        verified_report: &VerifiedReport,
466    ) -> Result<(), AtlsVerificationError> {
467        debug!("Verifying report data against verified report");
468
469        // Compute report_data = SHA512(nonce || session_ekm)
470        let mut hasher = Sha512::new();
471        hasher.update(nonce);
472        hasher.update(session_ekm);
473        let report_data: [u8; 64] = hasher.finalize().into();
474
475        // Get the trusted TD report from DCAP verification
476        let td_report = verified_report.report.as_td10().ok_or_else(|| {
477            AtlsVerificationError::TeeTypeMismatch(
478                "expected TDX report but got SGX enclave report".into(),
479            )
480        })?;
481
482        let expected = hex::encode(report_data);
483        let actual = hex::encode(td_report.report_data);
484
485        debug!("Report data expected: {}", expected);
486        debug!("Report data actual:   {}", actual);
487
488        
489        if expected != actual {
490            return Err(AtlsVerificationError::ReportDataMismatch { expected, actual });
491        }
492
493        debug!("Report data verification successful");
494        Ok(())
495    }
496}
497
498impl AtlsVerifier for DstackTDXVerifier {
499    async fn verify<S>(
500        &self,
501        stream: &mut S,
502        peer_cert: &[u8],
503        session_ekm: &[u8],
504        hostname: &str,
505    ) -> Result<Report, AtlsVerificationError>
506    where
507        S: AsyncByteStream,
508    {
509        debug!("Starting DStack TDX verification for {}", hostname);
510
511        // 1. Generate nonce and get quote via HTTP POST to /tdx_quote
512        let mut nonce = [0u8; 32];
513        rand::Rng::fill(&mut rand::thread_rng(), &mut nonce);
514
515        // Get quote via HTTP POST to /tdx_quote
516        let quote_response = get_quote_over_http(stream, &nonce, hostname).await?;
517
518        // 2. Parse event log using dstack-sdk-types
519        debug!("Parsing event log");
520        let events = quote_response
521            .decode_event_log()
522            .map_err(|e| AtlsVerificationError::Other(e.into()))?;
523        debug!("Event log parsed, {} events found", events.len());
524
525        // 3. Verify certificate in event log
526        debug!("Verifying certificate in event log");
527        let cert_in_eventlog = self.verify_cert_in_eventlog(peer_cert, &events)?;
528        if !cert_in_eventlog {
529            return Err(AtlsVerificationError::CertificateNotInEventLog);
530        }
531
532        // 4. Verify DCAP quote using dcap-qvl directly
533        debug!("Decoding quote for DCAP verification");
534        let quote_bytes = quote_response
535            .decode_quote()
536            .map_err(|e| AtlsVerificationError::Other(anyhow::anyhow!("Failed to decode quote: {}", e)))?;
537        debug!("Quote decoded ({} bytes)", quote_bytes.len());
538
539        // Async quote verification - no blocking!
540        let verified_report = self.verify_quote(&quote_bytes).await?;
541
542        // 5. Verify report data
543        let session_ekm: &[u8; 32] = session_ekm.try_into().map_err(|_| {
544            AtlsVerificationError::Configuration(
545                "session_ekm must be exactly 32 bytes".into(),
546            )
547        })?;
548        self.verify_report_data(&nonce, session_ekm, &verified_report)?;
549
550        // 6. Verify RTMR replay against the verified report
551        self.verify_rtmr_replay(&quote_response, &verified_report)?;
552
553        // Skip remaining checks if runtime verification is disabled
554        if self.config.disable_runtime_verification {
555            debug!("Runtime verification disabled, skipping bootchain/app-compose/os-image checks");
556            return Ok(Report::Tdx(verified_report));
557        }
558
559        // 7. Verify bootchain (MRTD, RTMR0-2) against verified report
560        self.verify_bootchain(&verified_report)?;
561
562        // 8. Verify app compose hash against trusted event log
563        self.verify_app_compose(&events)?;
564
565        // 9. Verify OS image hash against trusted event log
566        self.verify_os_image_hash(&events)?;
567
568        debug!("DStack TDX verification complete");
569        Ok(Report::Tdx(verified_report))
570    }
571}
572
573/// Fetch quote over HTTP from /tdx_quote endpoint (async version).
574async fn get_quote_over_http<S>(
575    stream: &mut S,
576    nonce: &[u8; 32],
577    hostname: &str,
578) -> Result<GetQuoteResponse, AtlsVerificationError>
579where
580    S: AsyncByteStream,
581{
582    debug!("Sending POST /tdx_quote request to {}", hostname);
583
584    // Build HTTP POST request for the /tdx_quote endpoint with EKM binding
585    let body = serde_json::json!({
586        "nonce_hex": hex::encode(nonce)
587    });
588    let body_str = body.to_string();
589
590    let request = format!(
591        "POST /tdx_quote HTTP/1.1\r\n\
592         Host: {}\r\n\
593         Content-Type: application/json\r\n\
594         Content-Length: {}\r\n\
595         Connection: keep-alive\r\n\
596         \r\n\
597         {}",
598        hostname,
599        body_str.len(),
600        body_str
601    );
602
603    stream
604        .write_all(request.as_bytes())
605        .await
606        .map_err(|e| AtlsVerificationError::Io(e.to_string()))?;
607    stream
608        .flush()
609        .await
610        .map_err(|e| AtlsVerificationError::Io(e.to_string()))?;
611
612    // Read HTTP response
613    let mut response_buf = Vec::new();
614    let mut chunk = [0u8; 4096];
615
616    // Read until we have the complete response
617    loop {
618        let n = stream
619            .read(&mut chunk)
620            .await
621            .map_err(|e| AtlsVerificationError::Io(e.to_string()))?;
622        if n == 0 {
623            break;
624        }
625        response_buf.extend_from_slice(&chunk[..n]);
626
627        // Check if we have the complete response (look for end of body)
628        if let Some(body_start) = find_http_body_start(&response_buf) {
629            // Try to parse content-length header
630            if let Some(content_length) = parse_content_length(&response_buf[..body_start]) {
631                if response_buf.len() >= body_start + content_length {
632                    break;
633                }
634            }
635        }
636    }
637
638    debug!("Received quote response ({} bytes)", response_buf.len());
639
640    // Parse HTTP response
641    let body_start = find_http_body_start(&response_buf)
642        .ok_or_else(|| AtlsVerificationError::Io("Invalid HTTP response".into()))?;
643    let response_body = &response_buf[body_start..];
644
645    let response: QuoteEndpointResponse = serde_json::from_slice(response_body)
646        .map_err(|e| {
647            AtlsVerificationError::Quote(format!(
648                "Failed to parse /tdx_quote response: {}",
649                e
650            ))
651        })?;
652
653    Ok(response.quote)
654}
655
/// Locate the first byte of the HTTP body.
///
/// Returns the index just past the first `\r\n\r\n` sequence in `data`, or
/// `None` when the header terminator has not been received yet.
fn find_http_body_start(data: &[u8]) -> Option<usize> {
    const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";

    data.windows(HEADER_TERMINATOR.len())
        .position(|window| window == HEADER_TERMINATOR)
        .map(|offset| offset + HEADER_TERMINATOR.len())
}
665
/// Extract the `Content-Length` value from a raw HTTP header section.
///
/// Lines are examined as bytes, so a non-UTF-8 byte in some unrelated header
/// does not hide the `Content-Length` header. The header name is matched
/// ASCII case-insensitively at the start of a line. The first matching line
/// decides the result: its value (everything after the colon, surrounding
/// whitespace trimmed) is parsed as `usize`.
///
/// Returns `None` when no `Content-Length` line exists or when the first
/// one does not hold a valid unsigned integer.
fn parse_content_length(headers: &[u8]) -> Option<usize> {
    const HEADER_NAME: &[u8] = b"content-length:";

    for line in headers.split(|&byte| byte == b'\n') {
        let name_matches = line
            .get(..HEADER_NAME.len())
            .map_or(false, |prefix| prefix.eq_ignore_ascii_case(HEADER_NAME));
        if !name_matches {
            continue;
        }

        // A trailing `\r` from the CRLF line ending is removed by `trim`.
        let value = std::str::from_utf8(&line[HEADER_NAME.len()..]).ok()?;
        return value.trim().parse().ok();
    }
    None
}