atlas_rs/dstack/
policy.rs1use crate::dstack::{DstackTDXVerifier, DstackTDXVerifierBuilder};
4use crate::tdx::{ExpectedBootchain, TCB_STATUS_LIST};
5use crate::verifier::IntoVerifier;
6use crate::AtlsVerificationError;
7use serde::{Deserialize, Serialize};
8
/// PCCS base URL used when a policy does not specify `pccs_url`.
///
/// This is the serde default for [`DstackTdxPolicy::pccs_url`] (via
/// `default_pccs_url`) and the value produced by [`DstackTdxPolicy::default`].
pub const DEFAULT_PCCS_URL: &str = "https://pccs.phala.network/tdx/certification/v4";
11
12fn default_pccs_url() -> Option<String> {
13 Some(DEFAULT_PCCS_URL.to_string())
14}
15
/// Serde default for `DstackTdxPolicy::allowed_tcb_status`: accept only a
/// fully up-to-date TCB.
fn default_allowed_tcb_status() -> Vec<String> {
    let only_up_to_date = String::from("UpToDate");
    vec![only_up_to_date]
}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct DstackTdxPolicy {
23 #[serde(default, skip_serializing_if = "Option::is_none")]
25 pub expected_bootchain: Option<ExpectedBootchain>,
26
27 #[serde(default, skip_serializing_if = "Option::is_none")]
29 pub app_compose: Option<serde_json::Value>,
30
31 #[serde(default, skip_serializing_if = "Option::is_none")]
33 pub os_image_hash: Option<String>,
34
35 #[serde(default = "default_allowed_tcb_status")]
37 pub allowed_tcb_status: Vec<String>,
38
39 #[serde(default = "default_pccs_url", skip_serializing_if = "Option::is_none")]
42 pub pccs_url: Option<String>,
43
44 #[serde(default)]
46 pub cache_collateral: bool,
47
48 #[serde(default)]
54 pub disable_runtime_verification: bool,
55}
56
57impl Default for DstackTdxPolicy {
58 fn default() -> Self {
59 Self {
60 expected_bootchain: None,
61 app_compose: None,
62 os_image_hash: None,
63 allowed_tcb_status: default_allowed_tcb_status(),
64 pccs_url: default_pccs_url(),
65 cache_collateral: false,
66 disable_runtime_verification: false,
67 }
68 }
69}
70
/// Returns `true` if `s` is a non-empty, even-length string made only of
/// lowercase hexadecimal digits (`0-9`, `a-f`).
///
/// Even length is required because the values checked with this helper (OS
/// image hash, TDX boot-chain measurements) are hex encodings of raw bytes,
/// two digits per byte. An odd-length string can never be decoded, so it is
/// rejected up front instead of failing later during verification.
fn is_valid_hex(s: &str) -> bool {
    // Every accepted byte is ASCII, so the byte length equals the digit count.
    !s.is_empty()
        && s.len() % 2 == 0
        && s.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
}
75
76impl DstackTdxPolicy {
77 pub fn dev() -> Self {
82 Self {
83 disable_runtime_verification: true,
84 allowed_tcb_status: vec![
85 "UpToDate".into(),
86 "SWHardeningNeeded".into(),
87 "OutOfDate".into(),
88 ],
89 ..Default::default()
90 }
91 }
92
93 pub fn validate(&self) -> Result<(), AtlsVerificationError> {
100 for status in &self.allowed_tcb_status {
102 if !TCB_STATUS_LIST.contains(&status.as_str()) {
103 return Err(AtlsVerificationError::Configuration(format!(
104 "invalid TCB status '{}', valid values are: {:?}",
105 status, TCB_STATUS_LIST
106 )));
107 }
108 }
109
110 if let Some(ref hash) = self.os_image_hash {
112 if !is_valid_hex(hash) {
113 return Err(AtlsVerificationError::Configuration(
114 "os_image_hash must be a lowercase hex string".into(),
115 ));
116 }
117 }
118
119 if let Some(ref bootchain) = self.expected_bootchain {
121 if !is_valid_hex(&bootchain.mrtd) {
122 return Err(AtlsVerificationError::Configuration(
123 "expected_bootchain.mrtd must be a lowercase hex string".into(),
124 ));
125 }
126 if !is_valid_hex(&bootchain.rtmr0) {
127 return Err(AtlsVerificationError::Configuration(
128 "expected_bootchain.rtmr0 must be a lowercase hex string".into(),
129 ));
130 }
131 if !is_valid_hex(&bootchain.rtmr1) {
132 return Err(AtlsVerificationError::Configuration(
133 "expected_bootchain.rtmr1 must be a lowercase hex string".into(),
134 ));
135 }
136 if !is_valid_hex(&bootchain.rtmr2) {
137 return Err(AtlsVerificationError::Configuration(
138 "expected_bootchain.rtmr2 must be a lowercase hex string".into(),
139 ));
140 }
141 }
142
143 Ok(())
144 }
145}
146
147impl IntoVerifier for DstackTdxPolicy {
148 type Verifier = DstackTDXVerifier;
149
150 fn into_verifier(self) -> Result<DstackTDXVerifier, AtlsVerificationError> {
151 self.validate()?;
153
154 let mut builder = DstackTDXVerifierBuilder::new();
155
156 if self.disable_runtime_verification {
158 builder = builder.disable_runtime_verification();
159 }
160
161 if let Some(bootchain) = self.expected_bootchain {
163 builder = builder.expected_bootchain(bootchain);
164 }
165 if let Some(app_compose) = self.app_compose {
166 builder = builder.app_compose(app_compose);
167 }
168 if let Some(os_hash) = self.os_image_hash {
169 builder = builder.os_image_hash(os_hash);
170 }
171
172 builder = builder.allowed_tcb_status(self.allowed_tcb_status);
173
174 if let Some(pccs) = self.pccs_url {
175 builder = builder.pccs_url(pccs);
176 }
177
178 builder = builder.cache_collateral(self.cache_collateral);
179
180 builder.build()
181 }
182}
183
#[cfg(test)]
mod tests {
    //! Unit tests for `DstackTdxPolicy` defaults, serde round-tripping,
    //! validation, and conversion into a verifier.

    use super::*;

    /// Default policy with runtime verification switched off. Used as the base
    /// for validation tests so that only the field under test can fail.
    fn runtime_disabled() -> DstackTdxPolicy {
        DstackTdxPolicy {
            disable_runtime_verification: true,
            ..Default::default()
        }
    }

    #[test]
    fn test_dstack_tdx_policy_default() {
        let p = DstackTdxPolicy::default();
        assert_eq!(p.allowed_tcb_status, ["UpToDate"]);
        assert!(p.expected_bootchain.is_none());
        assert!(!p.disable_runtime_verification);
    }

    #[test]
    fn test_dstack_tdx_policy_dev() {
        let p = DstackTdxPolicy::dev();
        assert!(p.allowed_tcb_status.iter().any(|s| s == "SWHardeningNeeded"));
        assert!(p.disable_runtime_verification);
    }

    #[test]
    fn test_dstack_tdx_policy_json_roundtrip() {
        let original = DstackTdxPolicy {
            allowed_tcb_status: vec!["UpToDate".into(), "SWHardeningNeeded".into()],
            ..Default::default()
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: DstackTdxPolicy = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.allowed_tcb_status.len(), 2);
    }

    #[test]
    fn test_default_policy_requires_all_fields() {
        // Runtime verification is on by default and no expected values are
        // set, so building a verifier must fail.
        assert!(DstackTdxPolicy::default().into_verifier().is_err());
    }

    #[test]
    fn test_dev_policy_builds_without_runtime_fields() {
        assert!(DstackTdxPolicy::dev().into_verifier().is_ok());
    }

    #[test]
    fn test_invalid_tcb_status_rejected() {
        let policy = DstackTdxPolicy {
            allowed_tcb_status: vec!["InvalidStatus".into()],
            ..runtime_disabled()
        };
        let msg = policy.validate().unwrap_err().to_string();
        assert!(msg.contains("invalid TCB status"));
    }

    #[test]
    fn test_invalid_hex_os_image_hash_rejected() {
        let policy = DstackTdxPolicy {
            os_image_hash: Some("not-valid-hex!".into()),
            ..runtime_disabled()
        };
        let msg = policy.validate().unwrap_err().to_string();
        assert!(msg.contains("os_image_hash must be a lowercase hex string"));
    }

    #[test]
    fn test_uppercase_hex_rejected() {
        let policy = DstackTdxPolicy {
            os_image_hash: Some("ABCD1234".into()),
            ..runtime_disabled()
        };
        assert!(policy.validate().is_err());
    }

    #[test]
    fn test_valid_hex_accepted() {
        let policy = DstackTdxPolicy {
            os_image_hash: Some("abcd1234".into()),
            ..runtime_disabled()
        };
        assert!(policy.validate().is_ok());
    }

    #[test]
    fn test_invalid_bootchain_hex_rejected() {
        let policy = DstackTdxPolicy {
            expected_bootchain: Some(ExpectedBootchain {
                mrtd: "invalid_hex".into(),
                rtmr0: "abc123".into(),
                rtmr1: "def456".into(),
                rtmr2: "789abc".into(),
            }),
            ..runtime_disabled()
        };
        let msg = policy.validate().unwrap_err().to_string();
        assert!(msg.contains("mrtd"));
    }
}