bitrouter_attestation/near/
nvidia.rs1use std::collections::HashMap;
18
19use jsonwebtoken::jwk::JwkSet;
20use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
21
22use crate::VerifyError;
23
/// NVIDIA Remote Attestation Service (NRAS) v3 GPU attestation endpoint.
pub const NRAS_GPU_URL: &str = "https://nras.attestation.nvidia.com/v3/attest/gpu";

/// Well-known JWKS document listing the NRAS token-signing keys; pass it to
/// [`NvidiaEatKey::fetch_jwks`] to build a `kid`-indexed key set.
pub const NVIDIA_NRAS_JWKS_URL: &str = "https://nras.attestation.nvidia.com/.well-known/jwks.json";
29
30pub struct NvidiaEatKey(KeySource);
37
38enum KeySource {
39 Single(DecodingKey),
41 Jwks(HashMap<String, DecodingKey>),
43 Unconfigured,
45}
46
47impl NvidiaEatKey {
48 pub fn from_ec_pem(pem: &[u8]) -> Result<Self, VerifyError> {
52 DecodingKey::from_ec_pem(pem)
53 .map(|k| Self(KeySource::Single(k)))
54 .map_err(|e| VerifyError::Malformed {
55 what: "nvidia eat key",
56 detail: e.to_string(),
57 })
58 }
59
60 pub fn from_jwks_json(bytes: &[u8]) -> Result<Self, VerifyError> {
63 let set: JwkSet = serde_json::from_slice(bytes).map_err(|e| VerifyError::Malformed {
64 what: "nvidia jwks",
65 detail: e.to_string(),
66 })?;
67 let mut map = HashMap::new();
68 for jwk in &set.keys {
69 if let (Some(kid), Ok(key)) = (jwk.common.key_id.clone(), DecodingKey::from_jwk(jwk)) {
70 map.insert(kid, key);
71 }
72 }
73 if map.is_empty() {
74 return Err(VerifyError::Malformed {
75 what: "nvidia jwks",
76 detail: "no usable keys with a kid".to_string(),
77 });
78 }
79 Ok(Self(KeySource::Jwks(map)))
80 }
81
82 pub async fn fetch_jwks(url: &str) -> Result<Self, VerifyError> {
85 let body = reqwest::Client::new()
86 .get(url)
87 .send()
88 .await
89 .map_err(|e| VerifyError::Transport {
90 what: "nvidia jwks",
91 source: Box::new(e),
92 })?
93 .error_for_status()
94 .map_err(|e| VerifyError::Transport {
95 what: "nvidia jwks",
96 source: Box::new(e),
97 })?
98 .bytes()
99 .await
100 .map_err(|e| VerifyError::Transport {
101 what: "nvidia jwks",
102 source: Box::new(e),
103 })?;
104 Self::from_jwks_json(&body)
105 }
106
107 pub fn unconfigured() -> Self {
110 Self(KeySource::Unconfigured)
111 }
112
113 pub(crate) fn resolve(&self, kid: Option<&str>) -> Option<&DecodingKey> {
116 match &self.0 {
117 KeySource::Single(key) => Some(key),
118 KeySource::Jwks(map) => kid.and_then(|k| map.get(k)),
119 KeySource::Unconfigured => None,
120 }
121 }
122}
123
/// Result of [`check_nras_eat`], with each check reported separately for diagnostics.
///
/// Only [`NrasVerdict::passed`] should be used to accept or reject attestation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NrasVerdict {
    /// The platform token's signature verified against the configured key.
    pub signature_verified: bool,
    /// The token's `x-nvidia-overall-att-result` claim was `true`.
    pub overall_pass: bool,
    /// The token's `eat_nonce` claim matched the expected nonce.
    pub nonce_matches: bool,
}

impl NrasVerdict {
    /// A verdict with every check marked as failed, used for any early rejection.
    pub fn failed() -> Self {
        Self { signature_verified: false, overall_pass: false, nonce_matches: false }
    }

    /// `true` only when the signature, the overall result, and the nonce all checked out.
    pub fn passed(&self) -> bool {
        matches!(
            self,
            Self { signature_verified: true, overall_pass: true, nonce_matches: true }
        )
    }
}
151
152#[derive(serde::Deserialize)]
153struct EatClaims {
154 #[serde(rename = "x-nvidia-overall-att-result")]
155 overall_att_result: Option<bool>,
156 eat_nonce: Option<String>,
157}
158
159const NRAS_ALGORITHMS: &[Algorithm] = &[
161 Algorithm::ES384,
162 Algorithm::ES256,
163 Algorithm::RS256,
164 Algorithm::PS256,
165];
166
167fn platform_jwt(response_body: &[u8]) -> Result<String, VerifyError> {
171 let v: serde_json::Value =
172 serde_json::from_slice(response_body).map_err(|e| VerifyError::Malformed {
173 what: "nras response",
174 detail: e.to_string(),
175 })?;
176 let entry = v
177 .get(0)
178 .and_then(serde_json::Value::as_array)
179 .ok_or(VerifyError::Malformed {
180 what: "nras response",
181 detail: "expected a non-empty token array".to_string(),
182 })?;
183 if entry.first().and_then(serde_json::Value::as_str) != Some("JWT") {
184 return Err(VerifyError::Malformed {
185 what: "nras response",
186 detail: "platform token is not in [\"JWT\", …] form".to_string(),
187 });
188 }
189 entry
190 .get(1)
191 .and_then(serde_json::Value::as_str)
192 .map(str::to_string)
193 .ok_or(VerifyError::Malformed {
194 what: "nras response",
195 detail: "platform token missing the JWT string".to_string(),
196 })
197}
198
199pub fn check_nras_eat(response_body: &[u8], nonce: &str, key: &NvidiaEatKey) -> NrasVerdict {
207 let Ok(jwt) = platform_jwt(response_body) else {
208 return NrasVerdict::failed();
209 };
210
211 let Ok(header) = decode_header(&jwt) else {
215 return NrasVerdict::failed();
216 };
217 if !NRAS_ALGORITHMS.contains(&header.alg) {
218 return NrasVerdict::failed();
219 }
220 let Some(decoding_key) = key.resolve(header.kid.as_deref()) else {
223 return NrasVerdict::failed();
224 };
225 let mut validation = Validation::new(header.alg);
226 validation.algorithms = vec![header.alg];
227 validation.validate_exp = false;
228 validation.validate_aud = false;
229 validation.required_spec_claims.clear();
230
231 let Ok(token) = decode::<EatClaims>(&jwt, decoding_key, &validation) else {
232 return NrasVerdict::failed();
234 };
235
236 let overall_pass = token.claims.overall_att_result == Some(true);
237 let nonce_matches = token
238 .claims
239 .eat_nonce
240 .as_deref()
241 .is_some_and(|n| n.eq_ignore_ascii_case(nonce));
242
243 NrasVerdict {
244 signature_verified: true,
245 overall_pass,
246 nonce_matches,
247 }
248}
249
250pub async fn post_nras(
255 http: &reqwest::Client,
256 nras_url: &str,
257 nvidia_payload: &str,
258) -> Result<Vec<u8>, VerifyError> {
259 let resp = http
260 .post(nras_url)
261 .header("accept", "application/json")
262 .header("content-type", "application/json")
263 .body(nvidia_payload.to_string())
264 .send()
265 .await
266 .map_err(|e| VerifyError::Transport {
267 what: "nras attestation",
268 source: Box::new(e),
269 })?
270 .error_for_status()
271 .map_err(|e| VerifyError::Transport {
272 what: "nras attestation",
273 source: Box::new(e),
274 })?;
275 resp.bytes()
276 .await
277 .map(|b| b.to_vec())
278 .map_err(|e| VerifyError::Transport {
279 what: "nras attestation",
280 source: Box::new(e),
281 })
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    // Test-only P-256 key pair; the public half is also published in `TEST_JWKS`.
    const TEST_EC_PRIVATE_PKCS8_PEM: &str =
        include_str!("../../tests/fixtures/nras_test_ec_private_pkcs8.pem");
    const TEST_EC_PUBLIC_PEM: &str = include_str!("../../tests/fixtures/nras_test_ec_public.pem");

    const NONCE: &str = "9a01356cb451dc2c3c0ce9a195245a0be984a3f73617f55f87913fc2f059cba7";

    const TEST_JWKS: &str = include_str!("../../tests/fixtures/nras_test_jwks.json");

    fn signing_key() -> EncodingKey {
        EncodingKey::from_ec_pem(TEST_EC_PRIVATE_PKCS8_PEM.as_bytes()).expect("test priv key")
    }

    fn pinned_key() -> NvidiaEatKey {
        NvidiaEatKey::from_ec_pem(TEST_EC_PUBLIC_PEM.as_bytes()).expect("test pub key")
    }

    /// The two claims `check_nras_eat` inspects.
    fn eat_claims(overall: bool, eat_nonce: &str) -> serde_json::Value {
        serde_json::json!({
            "x-nvidia-overall-att-result": overall,
            "eat_nonce": eat_nonce,
        })
    }

    /// An ES256 token signed with the test private key, optionally carrying `kid`.
    fn signed_eat(overall: bool, eat_nonce: &str, kid: Option<&str>) -> String {
        let mut header = Header::new(Algorithm::ES256);
        header.kid = kid.map(str::to_string);
        encode(&header, &eat_claims(overall, eat_nonce), &signing_key()).unwrap()
    }

    /// Wraps `jwt` in the NRAS response shape `[[tag, jwt], {}]`.
    fn wrap_platform_jwt(tag: &str, jwt: &str) -> Vec<u8> {
        serde_json::to_vec(&serde_json::json!([[tag, jwt], {}])).unwrap()
    }

    fn nras_body_kid(overall: bool, eat_nonce: &str, kid: Option<&str>) -> Vec<u8> {
        wrap_platform_jwt("JWT", &signed_eat(overall, eat_nonce, kid))
    }

    fn nras_body(overall: bool, eat_nonce: &str) -> Vec<u8> {
        nras_body_kid(overall, eat_nonce, None)
    }

    #[test]
    fn accepts_a_passing_signed_eat_with_matching_nonce() {
        let body = nras_body(true, NONCE);
        let verdict = check_nras_eat(&body, NONCE, &pinned_key());
        assert!(verdict.passed());
        assert!(verdict.signature_verified && verdict.overall_pass && verdict.nonce_matches);
    }

    #[test]
    fn rejects_a_failing_result_claim() {
        let body = nras_body(false, NONCE);
        let verdict = check_nras_eat(&body, NONCE, &pinned_key());
        assert!(verdict.signature_verified);
        assert!(!verdict.overall_pass);
        assert!(!verdict.passed());
    }

    #[test]
    fn rejects_a_replayed_nonce() {
        let body = nras_body(true, "00000000000000000000000000000000");
        let verdict = check_nras_eat(&body, NONCE, &pinned_key());
        assert!(!verdict.nonce_matches);
        assert!(!verdict.passed());
    }

    #[test]
    fn unconfigured_key_fails_closed() {
        let body = nras_body(true, NONCE);
        // A validly signed, passing token must still be rejected when no key is configured.
        let verdict = check_nras_eat(&body, NONCE, &NvidiaEatKey::unconfigured());
        assert!(!verdict.signature_verified);
        assert!(!verdict.passed());
    }

    #[test]
    fn jwks_resolves_the_signing_key_by_kid() {
        let jwks = NvidiaEatKey::from_jwks_json(TEST_JWKS.as_bytes()).expect("jwks parses");

        let ok = nras_body_kid(true, NONCE, Some("test-kid-1"));
        assert!(check_nras_eat(&ok, NONCE, &jwks).passed());

        let unknown = nras_body_kid(true, NONCE, Some("rotated-away-kid"));
        assert!(!check_nras_eat(&unknown, NONCE, &jwks).signature_verified);
    }

    #[test]
    fn rejects_a_malformed_response_body() {
        let verdict = check_nras_eat(b"[\"not jwt shaped\"]", NONCE, &pinned_key());
        assert_eq!(verdict, NrasVerdict::failed());
    }

    #[test]
    fn rejects_a_token_not_tagged_jwt() {
        let body = wrap_platform_jwt("JWS", &signed_eat(true, NONCE, None));
        assert_eq!(check_nras_eat(&body, NONCE, &pinned_key()), NrasVerdict::failed());
    }

    #[test]
    fn rejects_an_hmac_token_keyed_with_the_public_key() {
        // Algorithm-confusion attack: "sign" with HS256 using the verifier's public key bytes
        // as the HMAC secret. The allow-list must reject it before any key is used.
        let forged = encode(
            &Header::new(Algorithm::HS256),
            &eat_claims(true, NONCE),
            &EncodingKey::from_secret(TEST_EC_PUBLIC_PEM.as_bytes()),
        )
        .unwrap();
        let verdict = check_nras_eat(&wrap_platform_jwt("JWT", &forged), NONCE, &pinned_key());
        assert_eq!(verdict, NrasVerdict::failed());
    }

    #[test]
    fn rejects_a_tampered_signature() {
        let jwt = signed_eat(true, NONCE, None);
        let (signed_part, signature) = jwt.rsplit_once('.').expect("compact JWS");
        // Replace the first base64url character of the signature with a different valid one.
        let replacement = if signature.starts_with('A') { 'B' } else { 'A' };
        let tampered = format!("{}.{}{}", signed_part, replacement, &signature[1..]);
        let verdict = check_nras_eat(&wrap_platform_jwt("JWT", &tampered), NONCE, &pinned_key());
        assert!(!verdict.signature_verified);
        assert!(!verdict.passed());
    }
}