Skip to main content

atlas_rs/dstack/
policy.rs

1//! DStack-specific policy types.
2
3use crate::dstack::{DstackTDXVerifier, DstackTDXVerifierBuilder};
4use crate::tdx::{ExpectedBootchain, TCB_STATUS_LIST};
5use crate::verifier::IntoVerifier;
6use crate::AtlsVerificationError;
7use serde::{Deserialize, Serialize};
8
/// Default PCCS URL for TDX collateral fetching.
///
/// Used whenever a policy does not specify its own `pccs_url`.
pub const DEFAULT_PCCS_URL: &str = "https://pccs.phala.network/tdx/certification/v4";

/// Serde/`Default` fallback for `DstackTdxPolicy::pccs_url`: always
/// `Some(DEFAULT_PCCS_URL)` as an owned string.
fn default_pccs_url() -> Option<String> {
    let url = String::from(DEFAULT_PCCS_URL);
    Some(url)
}
15
/// Serde/`Default` fallback for `allowed_tcb_status`: only the strict
/// `UpToDate` TCB status is accepted.
fn default_allowed_tcb_status() -> Vec<String> {
    ["UpToDate"].iter().map(|&status| String::from(status)).collect()
}
19
20/// Policy configuration for dstack TDX verification.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct DstackTdxPolicy {
23    /// Expected bootchain measurements (MRTD, RTMR0-2).
24    #[serde(default, skip_serializing_if = "Option::is_none")]
25    pub expected_bootchain: Option<ExpectedBootchain>,
26
27    /// Expected app compose configuration.
28    #[serde(default, skip_serializing_if = "Option::is_none")]
29    pub app_compose: Option<serde_json::Value>,
30
31    /// Expected OS image hash (SHA256).
32    #[serde(default, skip_serializing_if = "Option::is_none")]
33    pub os_image_hash: Option<String>,
34
35    /// Allowed TCB status values.
36    #[serde(default = "default_allowed_tcb_status")]
37    pub allowed_tcb_status: Vec<String>,
38
39    /// PCCS URL for collateral fetching.
40    /// Defaults to `https://pccs.phala.network/tdx/certification/v4`.
41    #[serde(default = "default_pccs_url", skip_serializing_if = "Option::is_none")]
42    pub pccs_url: Option<String>,
43
44    /// Cache collateral to avoid repeated fetches.
45    #[serde(default)]
46    pub cache_collateral: bool,
47
48    /// Disable runtime verification (NOT RECOMMENDED for production).
49    ///
50    /// When false (default), all runtime fields (expected_bootchain, app_compose,
51    /// os_image_hash) must be provided or verification will fail.
52    /// Set to true only for development/testing.
53    #[serde(default)]
54    pub disable_runtime_verification: bool,
55}
56
57impl Default for DstackTdxPolicy {
58    fn default() -> Self {
59        Self {
60            expected_bootchain: None,
61            app_compose: None,
62            os_image_hash: None,
63            allowed_tcb_status: default_allowed_tcb_status(),
64            pccs_url: default_pccs_url(),
65            cache_collateral: false,
66            disable_runtime_verification: false,
67        }
68    }
69}
70
/// Returns `true` when `s` is non-empty and consists solely of the lowercase
/// hexadecimal characters `0-9` and `a-f`.
///
/// Uppercase digits, `0x` prefixes, whitespace and non-ASCII characters are
/// rejected. The length is not checked beyond being non-empty.
fn is_valid_hex(s: &str) -> bool {
    if s.is_empty() {
        return false;
    }
    // Non-ASCII characters encode to bytes >= 0x80, which never match below.
    s.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
}
75
76impl DstackTdxPolicy {
77    /// Relaxed policy for development.
78    ///
79    /// Accepts common TCB statuses and disables runtime verification
80    /// (bootchain, app_compose, os_image_hash checks are skipped).
81    pub fn dev() -> Self {
82        Self {
83            disable_runtime_verification: true,
84            allowed_tcb_status: vec![
85                "UpToDate".into(),
86                "SWHardeningNeeded".into(),
87                "OutOfDate".into(),
88            ],
89            ..Default::default()
90        }
91    }
92
93    /// Validate the policy configuration.
94    ///
95    /// Checks that:
96    /// - `allowed_tcb_status` values are valid TCB status strings
97    /// - `os_image_hash` is a valid hex string (if provided)
98    /// - `expected_bootchain` fields are valid hex strings (if provided)
99    pub fn validate(&self) -> Result<(), AtlsVerificationError> {
100        // Validate TCB status values
101        for status in &self.allowed_tcb_status {
102            if !TCB_STATUS_LIST.contains(&status.as_str()) {
103                return Err(AtlsVerificationError::Configuration(format!(
104                    "invalid TCB status '{}', valid values are: {:?}",
105                    status, TCB_STATUS_LIST
106                )));
107            }
108        }
109
110        // Validate os_image_hash is hex
111        if let Some(ref hash) = self.os_image_hash {
112            if !is_valid_hex(hash) {
113                return Err(AtlsVerificationError::Configuration(
114                    "os_image_hash must be a lowercase hex string".into(),
115                ));
116            }
117        }
118
119        // Validate bootchain fields are hex
120        if let Some(ref bootchain) = self.expected_bootchain {
121            if !is_valid_hex(&bootchain.mrtd) {
122                return Err(AtlsVerificationError::Configuration(
123                    "expected_bootchain.mrtd must be a lowercase hex string".into(),
124                ));
125            }
126            if !is_valid_hex(&bootchain.rtmr0) {
127                return Err(AtlsVerificationError::Configuration(
128                    "expected_bootchain.rtmr0 must be a lowercase hex string".into(),
129                ));
130            }
131            if !is_valid_hex(&bootchain.rtmr1) {
132                return Err(AtlsVerificationError::Configuration(
133                    "expected_bootchain.rtmr1 must be a lowercase hex string".into(),
134                ));
135            }
136            if !is_valid_hex(&bootchain.rtmr2) {
137                return Err(AtlsVerificationError::Configuration(
138                    "expected_bootchain.rtmr2 must be a lowercase hex string".into(),
139                ));
140            }
141        }
142
143        Ok(())
144    }
145}
146
147impl IntoVerifier for DstackTdxPolicy {
148    type Verifier = DstackTDXVerifier;
149
150    fn into_verifier(self) -> Result<DstackTDXVerifier, AtlsVerificationError> {
151        // Validate configuration before building
152        self.validate()?;
153
154        let mut builder = DstackTDXVerifierBuilder::new();
155
156        // Only disable runtime verification if explicitly requested
157        if self.disable_runtime_verification {
158            builder = builder.disable_runtime_verification();
159        }
160
161        // Pass all fields through - validation happens in DstackTDXVerifier::new()
162        if let Some(bootchain) = self.expected_bootchain {
163            builder = builder.expected_bootchain(bootchain);
164        }
165        if let Some(app_compose) = self.app_compose {
166            builder = builder.app_compose(app_compose);
167        }
168        if let Some(os_hash) = self.os_image_hash {
169            builder = builder.os_image_hash(os_hash);
170        }
171
172        builder = builder.allowed_tcb_status(self.allowed_tcb_status);
173
174        if let Some(pccs) = self.pccs_url {
175            builder = builder.pccs_url(pccs);
176        }
177
178        builder = builder.cache_collateral(self.cache_collateral);
179
180        builder.build()
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// A bootchain whose four measurements are all valid lowercase hex.
    fn valid_bootchain() -> ExpectedBootchain {
        ExpectedBootchain {
            mrtd: "abc123".into(),
            rtmr0: "abc123".into(),
            rtmr1: "def456".into(),
            rtmr2: "789abc".into(),
        }
    }

    /// Asserts that `validate()` rejects `bootchain` with an error naming `field`.
    fn assert_bootchain_field_rejected(bootchain: ExpectedBootchain, field: &str) {
        let policy = DstackTdxPolicy {
            expected_bootchain: Some(bootchain),
            disable_runtime_verification: true,
            ..Default::default()
        };
        let err = policy.validate().unwrap_err().to_string();
        assert!(
            err.contains(field),
            "expected error mentioning {}, got: {}",
            field,
            err
        );
    }

    #[test]
    fn test_dstack_tdx_policy_default() {
        let policy = DstackTdxPolicy::default();
        assert_eq!(policy.allowed_tcb_status, vec!["UpToDate"]);
        assert!(policy.expected_bootchain.is_none());
        assert!(!policy.disable_runtime_verification);
    }

    #[test]
    fn test_dstack_tdx_policy_dev() {
        let policy = DstackTdxPolicy::dev();
        assert!(policy.allowed_tcb_status.contains(&"SWHardeningNeeded".to_string()));
        assert!(policy.disable_runtime_verification);
    }

    #[test]
    fn test_dstack_tdx_policy_json_roundtrip() {
        let policy = DstackTdxPolicy {
            allowed_tcb_status: vec!["UpToDate".into(), "SWHardeningNeeded".into()],
            os_image_hash: Some("abcd1234".into()),
            cache_collateral: true,
            ..Default::default()
        };

        let json = serde_json::to_string(&policy).unwrap();
        let parsed: DstackTdxPolicy = serde_json::from_str(&json).unwrap();

        // Compare values, not just lengths, so a serde regression is caught.
        assert_eq!(parsed.allowed_tcb_status, policy.allowed_tcb_status);
        assert_eq!(parsed.os_image_hash, policy.os_image_hash);
        assert_eq!(parsed.pccs_url, policy.pccs_url);
        assert_eq!(parsed.cache_collateral, policy.cache_collateral);
        assert_eq!(
            parsed.disable_runtime_verification,
            policy.disable_runtime_verification
        );
        assert!(parsed.expected_bootchain.is_none());
        assert!(parsed.app_compose.is_none());
    }

    #[test]
    fn test_empty_json_deserializes_to_defaults() {
        let parsed: DstackTdxPolicy = serde_json::from_str("{}").unwrap();
        assert_eq!(parsed.allowed_tcb_status, vec!["UpToDate"]);
        assert_eq!(parsed.pccs_url.as_deref(), Some(DEFAULT_PCCS_URL));
        assert!(!parsed.cache_collateral);
        assert!(!parsed.disable_runtime_verification);
        assert!(parsed.expected_bootchain.is_none());
        assert!(parsed.app_compose.is_none());
        assert!(parsed.os_image_hash.is_none());
    }

    #[test]
    fn test_default_policy_requires_all_fields() {
        // Default policy with no runtime fields should fail to build verifier
        let policy = DstackTdxPolicy::default();
        let result = policy.into_verifier();
        assert!(result.is_err());
    }

    #[test]
    fn test_dev_policy_builds_without_runtime_fields() {
        // Dev policy explicitly disables runtime verification
        let policy = DstackTdxPolicy::dev();
        let result = policy.into_verifier();
        assert!(result.is_ok());
    }

    #[test]
    fn test_invalid_tcb_status_rejected() {
        let policy = DstackTdxPolicy {
            allowed_tcb_status: vec!["InvalidStatus".into()],
            disable_runtime_verification: true,
            ..Default::default()
        };
        let result = policy.validate();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("invalid TCB status"));
    }

    #[test]
    fn test_invalid_hex_os_image_hash_rejected() {
        let policy = DstackTdxPolicy {
            os_image_hash: Some("not-valid-hex!".into()),
            disable_runtime_verification: true,
            ..Default::default()
        };
        let result = policy.validate();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("os_image_hash must be a lowercase hex string"));
    }

    #[test]
    fn test_empty_os_image_hash_rejected() {
        let policy = DstackTdxPolicy {
            os_image_hash: Some(String::new()),
            disable_runtime_verification: true,
            ..Default::default()
        };
        assert!(policy.validate().is_err());
    }

    #[test]
    fn test_uppercase_hex_rejected() {
        let policy = DstackTdxPolicy {
            os_image_hash: Some("ABCD1234".into()),
            disable_runtime_verification: true,
            ..Default::default()
        };
        let result = policy.validate();
        assert!(result.is_err());
    }

    #[test]
    fn test_valid_hex_accepted() {
        let policy = DstackTdxPolicy {
            os_image_hash: Some("abcd1234".into()),
            disable_runtime_verification: true,
            ..Default::default()
        };
        let result = policy.validate();
        assert!(result.is_ok());
    }

    #[test]
    fn test_valid_bootchain_accepted() {
        let policy = DstackTdxPolicy {
            expected_bootchain: Some(valid_bootchain()),
            disable_runtime_verification: true,
            ..Default::default()
        };
        assert!(policy.validate().is_ok());
    }

    #[test]
    fn test_invalid_bootchain_hex_rejected() {
        assert_bootchain_field_rejected(
            ExpectedBootchain {
                mrtd: "invalid_hex".into(),
                ..valid_bootchain()
            },
            "mrtd",
        );
    }

    #[test]
    fn test_invalid_rtmr_hex_rejected() {
        // Each RTMR is validated independently; only mrtd was covered before.
        assert_bootchain_field_rejected(
            ExpectedBootchain {
                rtmr0: "ABC123".into(),
                ..valid_bootchain()
            },
            "rtmr0",
        );
        assert_bootchain_field_rejected(
            ExpectedBootchain {
                rtmr1: String::new(),
                ..valid_bootchain()
            },
            "rtmr1",
        );
        assert_bootchain_field_rejected(
            ExpectedBootchain {
                rtmr2: "zz".into(),
                ..valid_bootchain()
            },
            "rtmr2",
        );
    }
}
293        assert!(result.is_err());
294        let err = result.unwrap_err().to_string();
295        assert!(err.contains("mrtd"));
296    }
297}