1use std::collections::BTreeSet;
19
20use crate::VerifyError;
21use crate::near::report::AttestationInfo;
22
/// Errors raised while constructing an [`AciDcapVerifierPolicy`] from operator-supplied pins.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum PolicyError {
    /// Neither a workload id nor an image digest pin was supplied (after dropping empty entries).
    #[error("DCAP policy requires at least one accepted workload id or image digest")]
    EmptyPolicy,
    /// No dstack KMS root public key pin was supplied (after dropping empty entries).
    #[error("DCAP policy requires at least one accepted dstack KMS root public key")]
    EmptyKmsRootPolicy,
    /// A KMS root key pin was not hex, or was neither a SEC1 point nor a DER SPKI.
    /// The payload is a human-readable reason.
    #[error("invalid dstack KMS root public key: {0}")]
    InvalidKmsRootPublicKey(String),
    /// No base-measurement bundle (MRTD‖RTMR0‖RTMR1‖RTMR2) pin was supplied.
    #[error("DCAP policy requires at least one accepted base-measurement bundle (issue #567)")]
    EmptyBaseMeasurementPolicy,
    /// A base-measurement pin was not hex or did not decode to exactly 192 bytes.
    /// The payload is a human-readable reason.
    #[error("invalid base-measurement bundle: {0}")]
    InvalidBaseMeasurement(String),
}
36
/// Allow-list policy used to decide whether an attested TDX workload is trusted.
///
/// Every set holds values already canonicalised by [`AciDcapVerifierPolicy::new`]
/// (or the advisory builder), so lookups are plain set membership tests.
#[derive(Debug, Clone)]
pub struct AciDcapVerifierPolicy {
    // Pinned workload (app) ids, stored lowercased.
    accepted_workload_ids: BTreeSet<String>,
    // Pinned image digests, stored lowercased.
    accepted_image_digests: BTreeSet<String>,
    // Pinned KMS root keys, stored as the lowercase hex of the SEC1 point
    // (a DER SPKI pin is reduced to its embedded point).
    accepted_kms_root_public_keys: BTreeSet<String>,
    // Pinned base-measurement bundles: lowercase hex of MRTD‖RTMR0‖RTMR1‖RTMR2 (192 bytes).
    accepted_base_measurements: BTreeSet<String>,
    // TCB advisory ids tolerated on a non-UpToDate TCB status, stored trimmed and uppercased.
    allowed_tcb_advisory_ids: BTreeSet<String>,
}
62
63impl AciDcapVerifierPolicy {
64 pub fn new(
74 accepted_workload_ids: impl IntoIterator<Item = String>,
75 accepted_image_digests: impl IntoIterator<Item = String>,
76 accepted_kms_root_public_keys: impl IntoIterator<Item = String>,
77 accepted_base_measurements: impl IntoIterator<Item = String>,
78 ) -> Result<Self, PolicyError> {
79 let accepted_workload_ids = accepted_workload_ids
80 .into_iter()
81 .filter(|s| !s.is_empty())
82 .map(|s| s.to_lowercase())
83 .collect::<BTreeSet<_>>();
84 let accepted_image_digests = accepted_image_digests
85 .into_iter()
86 .filter(|s| !s.is_empty())
87 .map(|s| s.to_lowercase())
88 .collect::<BTreeSet<_>>();
89 let accepted_kms_root_public_keys = accepted_kms_root_public_keys
90 .into_iter()
91 .filter(|s| !s.is_empty())
92 .map(|key| canonical_ec_public_key(&key))
93 .collect::<Result<BTreeSet<_>, _>>()?;
94 let accepted_base_measurements = accepted_base_measurements
95 .into_iter()
96 .filter(|s| !s.is_empty())
97 .map(|m| canonical_base_measurements(&m))
98 .collect::<Result<BTreeSet<_>, _>>()?;
99 if accepted_workload_ids.is_empty() && accepted_image_digests.is_empty() {
100 return Err(PolicyError::EmptyPolicy);
101 }
102 if accepted_kms_root_public_keys.is_empty() {
103 return Err(PolicyError::EmptyKmsRootPolicy);
104 }
105 if accepted_base_measurements.is_empty() {
106 return Err(PolicyError::EmptyBaseMeasurementPolicy);
107 }
108 Ok(Self {
109 accepted_workload_ids,
110 accepted_image_digests,
111 accepted_kms_root_public_keys,
112 accepted_base_measurements,
113 allowed_tcb_advisory_ids: BTreeSet::new(),
116 })
117 }
118
119 #[must_use]
124 pub fn with_allowed_tcb_advisory_ids(mut self, ids: impl IntoIterator<Item = String>) -> Self {
125 self.allowed_tcb_advisory_ids = ids
126 .into_iter()
127 .map(|s| s.trim().to_uppercase())
128 .filter(|s| !s.is_empty())
129 .collect();
130 self
131 }
132
133 pub fn tcb_acceptable(&self, status: Option<&str>, advisory_ids: &[String]) -> bool {
141 match status {
142 Some("UpToDate") => true,
143 Some(_) => {
144 !advisory_ids.is_empty()
145 && advisory_ids.iter().all(|id| {
146 self.allowed_tcb_advisory_ids
147 .contains(&id.trim().to_uppercase())
148 })
149 }
150 None => false,
151 }
152 }
153
154 pub fn accepts(&self, workload_id: &str, image_digests: &[String]) -> bool {
157 self.accepted_workload_ids
158 .contains(&workload_id.to_lowercase())
159 || image_digests
160 .iter()
161 .any(|d| self.accepted_image_digests.contains(&d.to_lowercase()))
162 }
163
164 pub fn accepts_kms_root(&self, kms_root_public_key: &str) -> bool {
167 match canonical_ec_public_key(kms_root_public_key) {
168 Ok(k) => self.accepted_kms_root_public_keys.contains(&k),
169 Err(_) => false,
170 }
171 }
172
173 pub fn accepts_base_measurements(
181 &self,
182 mr_td: &[u8; 48],
183 rtmr0: &[u8; 48],
184 rtmr1: &[u8; 48],
185 rtmr2: &[u8; 48],
186 ) -> bool {
187 let bundle = base_measurement_bundle(mr_td, rtmr0, rtmr1, rtmr2);
188 self.accepted_base_measurements.contains(&bundle)
189 }
190}
191
192fn base_measurement_bundle(
195 mr_td: &[u8; 48],
196 rtmr0: &[u8; 48],
197 rtmr1: &[u8; 48],
198 rtmr2: &[u8; 48],
199) -> String {
200 let mut buf = [0u8; 192];
201 buf[..48].copy_from_slice(mr_td);
202 buf[48..96].copy_from_slice(rtmr0);
203 buf[96..144].copy_from_slice(rtmr1);
204 buf[144..192].copy_from_slice(rtmr2);
205 hex::encode(buf)
206}
207
208fn canonical_base_measurements(value: &str) -> Result<String, PolicyError> {
213 let bytes = hex::decode(value.trim())
214 .map_err(|e| PolicyError::InvalidBaseMeasurement(format!("not hex: {e}")))?;
215 if bytes.len() != 192 {
216 return Err(PolicyError::InvalidBaseMeasurement(format!(
217 "expected 192 bytes (MRTD‖RTMR0‖RTMR1‖RTMR2, 4×48), got {}",
218 bytes.len()
219 )));
220 }
221 Ok(hex::encode(bytes))
222}
223
/// Identity claims extracted from a model attestation's info block by [`model_identity`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelIdentity {
    /// The attested app id, copied verbatim from `AttestationInfo::app_id`.
    pub workload_id: String,
    /// `[os_image_hash, compose_hash]`, in that order, copied verbatim.
    pub image_digests: Vec<String>,
    /// The `id` field of the `key_provider_info` JSON: a hex KMS root key
    /// (the fixture shows a DER SPKI).
    pub kms_root_public_key: String,
}
232
/// The subset of the `key_provider_info` JSON document that the policy needs.
/// Other JSON fields are ignored (serde's default for unknown fields).
#[derive(serde::Deserialize)]
struct KeyProviderInfo {
    // Hex-encoded KMS root public key identifier.
    id: String,
}
237
238pub fn model_identity(info: &AttestationInfo) -> Result<ModelIdentity, VerifyError> {
244 let kpi: KeyProviderInfo =
245 serde_json::from_str(&info.key_provider_info).map_err(|e| VerifyError::Malformed {
246 what: "key_provider_info",
247 detail: e.to_string(),
248 })?;
249 Ok(ModelIdentity {
250 workload_id: info.app_id.clone(),
251 image_digests: vec![info.os_image_hash.clone(), info.compose_hash.clone()],
252 kms_root_public_key: kpi.id,
253 })
254}
255
256fn canonical_ec_public_key(public_key_hex: &str) -> Result<String, PolicyError> {
260 let bytes = hex::decode(public_key_hex.trim())
261 .map_err(|e| PolicyError::InvalidKmsRootPublicKey(format!("not hex: {e}")))?;
262 let point = sec1_point(&bytes).ok_or_else(|| {
263 PolicyError::InvalidKmsRootPublicKey(
264 "expected a SEC1 EC point or a DER SubjectPublicKeyInfo".to_string(),
265 )
266 })?;
267 Ok(hex::encode(point))
268}
269
270fn sec1_point(bytes: &[u8]) -> Option<Vec<u8>> {
277 if is_sec1_point(bytes) {
278 return Some(bytes.to_vec());
279 }
280 let point = spki_ec_point(bytes)?;
281 is_sec1_point(&point).then_some(point)
282}
283
/// Shape check for a 256-bit-curve SEC1 point.
///
/// Accepts 65 bytes uncompressed (`0x04` prefix) or 33 bytes compressed (`0x02`/`0x03`
/// prefix). It only checks length and prefix; it does not validate curve membership.
fn is_sec1_point(b: &[u8]) -> bool {
    match b.split_first() {
        Some((&0x04, coordinates)) => coordinates.len() == 64,
        Some((&(0x02 | 0x03), x_coordinate)) => x_coordinate.len() == 32,
        _ => false,
    }
}
289
/// DER content octets of OID 1.2.840.10045.2.1 (`id-ecPublicKey`), the
/// AlgorithmIdentifier algorithm for EC keys in a SubjectPublicKeyInfo.
const OID_EC_PUBLIC_KEY: &[u8] = &[0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01];
292
293fn spki_ec_point(der: &[u8]) -> Option<Vec<u8>> {
297 let (tag, spki, _) = der_tlv(der)?;
298 if tag != 0x30 {
299 return None; }
301 let (alg_tag, alg, after_alg) = der_tlv(spki)?;
302 if alg_tag != 0x30 {
303 return None; }
305 let (oid_tag, oid, _) = der_tlv(alg)?;
306 if oid_tag != 0x06 || oid != OID_EC_PUBLIC_KEY {
307 return None; }
309 let (bit_tag, bit_string, _) = der_tlv(after_alg)?;
310 if bit_tag != 0x03 {
311 return None; }
313 let (&unused_bits, point) = bit_string.split_first()?;
314 if unused_bits != 0 {
315 return None;
316 }
317 Some(point.to_vec())
318}
319
/// Splits one DER tag-length-value off the front of `input`.
///
/// Returns `(tag, value, remainder)`. Only single-octet tags are supported.
/// Lengths use either the short form (< 0x80) or the long form with 1–4 length
/// octets. DER requires the minimal length encoding (X.690 §10.1). The long form
/// is therefore rejected when it has a leading zero octet or encodes a length
/// below 0x80, which keeps each value to a single accepted encoding. The
/// indefinite form (0x80) is rejected. Returns `None` on truncated input.
fn der_tlv(input: &[u8]) -> Option<(u8, &[u8], &[u8])> {
    let (&tag, rest) = input.split_first()?;
    let (&len0, rest) = rest.split_first()?;
    let (len, rest) = if len0 < 0x80 {
        (usize::from(len0), rest)
    } else {
        let n = usize::from(len0 & 0x7f);
        if n == 0 || n > 4 || rest.len() < n {
            return None;
        }
        let (len_octets, rest) = rest.split_at(n);
        // Minimal encoding: no leading zero length octet.
        if len_octets[0] == 0 {
            return None;
        }
        let len = len_octets
            .iter()
            .fold(0usize, |acc, &b| (acc << 8) | usize::from(b));
        // Minimal encoding: lengths below 0x80 must use the short form.
        if len < 0x80 {
            return None;
        }
        (len, rest)
    };
    if rest.len() < len {
        return None;
    }
    let (value, remainder) = rest.split_at(len);
    Some((tag, value, remainder))
}
343
// Unit tests for the DCAP verifier policy. They run against the NEAR attestation
// fixture, so the identifiers and the KMS root key below are the values that
// fixture contains.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::near::report::AttestationReport;

    const FIXTURE: &str = include_str!("../../tests/fixtures/near_report.json");
    // Identity values taken from the fixture's first model attestation.
    const APP_ID: &str = "2c0a0c96cb6dbd659bf1446e2f3fce58172ff91b";
    const COMPOSE_HASH: &str = "c445f29994165e94e85bdfc4824f4bcba89b0a883f45e7912f1bfd7c2634a698";
    const OS_IMAGE_HASH: &str = "9b69bb1698bacbb6985409a2c272bcb892e09cdcea63d5399c6768b67d3ff677";
    // DER SubjectPublicKeyInfo (id-ecPublicKey) whose last 130 hex chars are the
    // 65-byte uncompressed SEC1 point.
    const KMS_ROOT_DER_SPKI: &str = "3059301306072a8648ce3d020106082a8648ce3d03010703420004228f800590a10442cba9d0e6adb2fa9f195eea9e75e23dd35990d52b59dda2415a63674c38adebde4ffd4d4b265bf818985933820c8053cee3ce29b5fb0fbcbc";

    // The info block of the fixture's first model attestation.
    fn fixture_info() -> AttestationInfo {
        let r: AttestationReport = serde_json::from_str(FIXTURE).unwrap();
        r.model_attestations[0].info.clone()
    }

    #[test]
    fn constructor_refuses_without_a_workload_or_image_pin() {
        let err = AciDcapVerifierPolicy::new(
            [],
            [],
            [KMS_ROOT_DER_SPKI.to_string()],
            [fixture_base_mrs()],
        )
        .unwrap_err();
        assert_eq!(err, PolicyError::EmptyPolicy);
    }

    #[test]
    fn constructor_refuses_without_a_kms_root_pin() {
        let err = AciDcapVerifierPolicy::new([APP_ID.to_string()], [], [], [fixture_base_mrs()])
            .unwrap_err();
        assert_eq!(err, PolicyError::EmptyKmsRootPolicy);
    }

    #[test]
    fn constructor_rejects_an_unparseable_kms_root() {
        let err = AciDcapVerifierPolicy::new(
            [APP_ID.to_string()],
            [],
            ["nothex!!".to_string()],
            [fixture_base_mrs()],
        )
        .unwrap_err();
        assert!(matches!(err, PolicyError::InvalidKmsRootPublicKey(_)));
    }

    #[test]
    fn model_identity_maps_the_info_block() {
        let id = model_identity(&fixture_info()).expect("identity");
        assert_eq!(id.workload_id, APP_ID);
        assert!(id.image_digests.contains(&OS_IMAGE_HASH.to_string()));
        assert!(id.image_digests.contains(&COMPOSE_HASH.to_string()));
        assert_eq!(id.kms_root_public_key, KMS_ROOT_DER_SPKI);
    }

    #[test]
    fn policy_accepts_the_legitimate_model_by_workload_id() {
        let policy = AciDcapVerifierPolicy::new(
            [APP_ID.to_string()],
            [],
            [KMS_ROOT_DER_SPKI.to_string()],
            [fixture_base_mrs()],
        )
        .unwrap();
        let id = model_identity(&fixture_info()).unwrap();
        assert!(policy.accepts(&id.workload_id, &id.image_digests));
        assert!(policy.accepts_kms_root(&id.kms_root_public_key));
    }

    #[test]
    fn policy_accepts_by_image_digest_alone() {
        let policy = AciDcapVerifierPolicy::new(
            [],
            [COMPOSE_HASH.to_string()],
            [KMS_ROOT_DER_SPKI.to_string()],
            [fixture_base_mrs()],
        )
        .unwrap();
        let id = model_identity(&fixture_info()).unwrap();
        assert!(policy.accepts(&id.workload_id, &id.image_digests));
    }

    #[test]
    fn policy_rejects_a_genuine_tee_running_a_different_model() {
        let policy = AciDcapVerifierPolicy::new(
            // Pins that do not match the fixture's workload id or either digest.
            ["some-other-workload".to_string()],
            ["deadbeef".to_string()],
            [KMS_ROOT_DER_SPKI.to_string()],
            [fixture_base_mrs()],
        )
        .unwrap();
        let id = model_identity(&fixture_info()).unwrap();
        assert!(!policy.accepts(&id.workload_id, &id.image_digests));
    }

    #[test]
    fn kms_root_matches_whether_pinned_as_der_spki_or_raw_point() {
        // Last 130 hex chars = the 65-byte uncompressed point inside the SPKI.
        let raw_point = &KMS_ROOT_DER_SPKI[KMS_ROOT_DER_SPKI.len() - 130..];
        let policy = AciDcapVerifierPolicy::new(
            [APP_ID.to_string()],
            [],
            [raw_point.to_string()],
            [fixture_base_mrs()],
        )
        .unwrap();
        assert!(policy.accepts_kms_root(KMS_ROOT_DER_SPKI));
    }

    #[test]
    fn rejects_a_crafted_der_blob_whose_tail_spoofs_a_pinned_point() {
        let raw_point = &KMS_ROOT_DER_SPKI[KMS_ROOT_DER_SPKI.len() - 130..];
        // SEQUENCE(0x43) { OCTET STRING(0x41) { point } }: ends with the pinned point
        // but is not an SPKI, so it must not be accepted by tail matching.
        let crafted = format!("30430441{raw_point}");
        let policy = AciDcapVerifierPolicy::new(
            [APP_ID.to_string()],
            [],
            [KMS_ROOT_DER_SPKI.to_string()],
            [fixture_base_mrs()],
        )
        .unwrap();
        assert!(!policy.accepts_kms_root(&crafted));
    }

    #[test]
    fn policy_rejects_an_unpinned_kms_root() {
        let policy = AciDcapVerifierPolicy::new(
            [APP_ID.to_string()],
            [],
            [KMS_ROOT_DER_SPKI.to_string()],
            [fixture_base_mrs()],
        )
        .unwrap();
        let mut other = KMS_ROOT_DER_SPKI.to_string();
        // Flip the final byte of the point: still well-formed, but a different key.
        other.replace_range(other.len() - 2.., "ff");
        assert!(!policy.accepts_kms_root(&other));
    }

    // Fixture-pinned policy with the given TCB advisory allowlist.
    fn tcb_policy(allowed: &[&str]) -> AciDcapVerifierPolicy {
        AciDcapVerifierPolicy::new(
            [APP_ID.to_string()],
            [],
            [KMS_ROOT_DER_SPKI.to_string()],
            [fixture_base_mrs()],
        )
        .unwrap()
        .with_allowed_tcb_advisory_ids(allowed.iter().map(|s| s.to_string()))
    }

    #[test]
    fn tcb_floor_accepts_up_to_date_only_by_default() {
        let p = tcb_policy(&[]);
        assert!(p.tcb_acceptable(Some("UpToDate"), &[]));
        assert!(!p.tcb_acceptable(Some("OutOfDate"), &["INTEL-SA-00615".to_string()]));
        assert!(!p.tcb_acceptable(Some("ConfigurationNeeded"), &[]));
        assert!(!p.tcb_acceptable(Some("SWHardeningNeeded"), &[]));
        assert!(!p.tcb_acceptable(None, &[]));
    }

    #[test]
    fn tcb_floor_allows_a_fully_allowlisted_non_current_status() {
        let p = tcb_policy(&["INTEL-SA-00615"]);
        assert!(p.tcb_acceptable(Some("OutOfDate"), &["INTEL-SA-00615".to_string()]));
        // Advisory ids are compared case-insensitively.
        assert!(p.tcb_acceptable(Some("OutOfDate"), &["intel-sa-00615".to_string()]));
    }

    #[test]
    fn tcb_floor_rejects_when_any_advisory_is_unlisted() {
        let p = tcb_policy(&["INTEL-SA-00615"]);
        assert!(!p.tcb_acceptable(
            Some("OutOfDate"),
            &["INTEL-SA-00615".to_string(), "INTEL-SA-00999".to_string()]
        ));
    }

    #[test]
    fn tcb_floor_never_accepts_a_non_current_status_with_no_named_advisory() {
        let p = tcb_policy(&["INTEL-SA-00615"]);
        assert!(!p.tcb_acceptable(Some("ConfigurationNeeded"), &[]));
    }

    #[test]
    fn tcb_floor_trims_advisory_ids_on_both_sides() {
        // A whitespace-only allowlist entry is dropped, so a blank quote id cannot match it.
        let p = tcb_policy(&[" "]);
        assert!(!p.tcb_acceptable(Some("OutOfDate"), &["".to_string()]));
        let p2 = tcb_policy(&["INTEL-SA-00615"]);
        // Surrounding whitespace on the quote side is ignored.
        assert!(p2.tcb_acceptable(Some("OutOfDate"), &[" INTEL-SA-00615 ".to_string()]));
    }

    #[test]
    fn tcb_floor_treats_revoked_as_any_non_current_status() {
        let p = tcb_policy(&["INTEL-SA-00615"]);
        assert!(!p.tcb_acceptable(Some("Revoked"), &[]));
    }

    use crate::near::tdx::parse_tdx_quote;

    // Raw TDX quote bytes from the fixture's first model attestation.
    fn fixture_quote() -> Vec<u8> {
        let r: AttestationReport = serde_json::from_str(FIXTURE).unwrap();
        hex::decode(&r.model_attestations[0].intel_quote).unwrap()
    }

    // The fixture quote's MRTD‖RTMR0‖RTMR1‖RTMR2 as lowercase hex, i.e. a valid pin.
    fn fixture_base_mrs() -> String {
        let m = parse_tdx_quote(&fixture_quote()).unwrap();
        format!(
            "{}{}{}{}",
            hex::encode(m.mr_td),
            hex::encode(m.rtmr0),
            hex::encode(m.rtmr1),
            hex::encode(m.rtmr2),
        )
    }

    // Policy pinned to every fixture value.
    fn base_policy() -> AciDcapVerifierPolicy {
        AciDcapVerifierPolicy::new(
            [APP_ID.to_string()],
            [],
            [KMS_ROOT_DER_SPKI.to_string()],
            [fixture_base_mrs()],
        )
        .unwrap()
    }

    #[test]
    fn base_pin_accepts_the_genuine_bundle_and_rejects_a_forged_base() {
        let policy = base_policy();
        let m = parse_tdx_quote(&fixture_quote()).unwrap();
        assert!(policy.accepts_base_measurements(&m.mr_td, &m.rtmr0, &m.rtmr1, &m.rtmr2));
        let mut forged_mr_td = m.mr_td;
        // Any single-byte change to any register must break the match.
        forged_mr_td[0] ^= 0xff;
        assert!(!policy.accepts_base_measurements(&forged_mr_td, &m.rtmr0, &m.rtmr1, &m.rtmr2));
        let mut forged_rtmr1 = m.rtmr1;
        forged_rtmr1[47] ^= 0x01;
        assert!(!policy.accepts_base_measurements(&m.mr_td, &m.rtmr0, &forged_rtmr1, &m.rtmr2));
    }

    #[test]
    fn base_pin_normalizes_hex_casing() {
        // Uppercase pin must still match the lowercase bundle built from the quote.
        let policy = AciDcapVerifierPolicy::new(
            [APP_ID.to_string()],
            [],
            [KMS_ROOT_DER_SPKI.to_string()],
            [fixture_base_mrs().to_uppercase()],
        )
        .unwrap();
        let m = parse_tdx_quote(&fixture_quote()).unwrap();
        assert!(policy.accepts_base_measurements(&m.mr_td, &m.rtmr0, &m.rtmr1, &m.rtmr2));
    }

    #[test]
    fn constructor_refuses_without_a_base_measurement_pin() {
        let err = AciDcapVerifierPolicy::new(
            [APP_ID.to_string()],
            [],
            [KMS_ROOT_DER_SPKI.to_string()],
            [],
        )
        .unwrap_err();
        assert_eq!(err, PolicyError::EmptyBaseMeasurementPolicy);
    }

    #[test]
    fn constructor_rejects_an_unparseable_base_measurement() {
        let err = AciDcapVerifierPolicy::new(
            [APP_ID.to_string()],
            [],
            [KMS_ROOT_DER_SPKI.to_string()],
            ["nothex!!".to_string()],
        )
        .unwrap_err();
        assert!(matches!(err, PolicyError::InvalidBaseMeasurement(_)));
    }

    #[test]
    fn constructor_rejects_a_base_measurement_of_the_wrong_length() {
        // Valid hex, but 2 bytes instead of 192.
        let err = AciDcapVerifierPolicy::new(
            [APP_ID.to_string()],
            [],
            [KMS_ROOT_DER_SPKI.to_string()],
            ["abcd".to_string()],
        )
        .unwrap_err();
        assert!(matches!(err, PolicyError::InvalidBaseMeasurement(_)));
    }
}