use std::collections::HashMap;
use jsonwebtoken::jwk::JwkSet;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
use crate::VerifyError;
/// NRAS v3 GPU attestation endpoint.
pub const NRAS_GPU_URL: &str = "https://nras.attestation.nvidia.com/v3/attest/gpu";
/// NRAS well-known JWKS document; a URL suitable for [`NvidiaEatKey::fetch_jwks`].
pub const NVIDIA_NRAS_JWKS_URL: &str = "https://nras.attestation.nvidia.com/.well-known/jwks.json";
/// Trust anchor used to verify the signature on NRAS-issued EAT tokens.
///
/// Construct one with [`NvidiaEatKey::from_ec_pem`], [`NvidiaEatKey::from_jwks_json`],
/// [`NvidiaEatKey::fetch_jwks`], or [`NvidiaEatKey::unconfigured`].
pub struct NvidiaEatKey(KeySource);

/// Where verification keys come from.
enum KeySource {
    /// One pinned key, used for every token regardless of its `kid` header.
    Single(DecodingKey),
    /// Keys indexed by JWK `kid`; a token must name one of them.
    Jwks(HashMap<String, DecodingKey>),
    /// No key configured; every token fails verification.
    Unconfigured,
}

// Hand-written so the output never includes key material: only the kind of
// key source and, for a JWKS, the configured `kid`s are shown.
impl std::fmt::Debug for NvidiaEatKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.0 {
            KeySource::Single(_) => f.write_str("NvidiaEatKey::Single"),
            KeySource::Jwks(map) => {
                let mut kids: Vec<&str> = map.keys().map(String::as_str).collect();
                // HashMap iteration order is unspecified; sort for stable output.
                kids.sort_unstable();
                f.debug_struct("NvidiaEatKey::Jwks")
                    .field("kids", &kids)
                    .finish()
            }
            KeySource::Unconfigured => f.write_str("NvidiaEatKey::Unconfigured"),
        }
    }
}
impl NvidiaEatKey {
pub fn from_ec_pem(pem: &[u8]) -> Result<Self, VerifyError> {
DecodingKey::from_ec_pem(pem)
.map(|k| Self(KeySource::Single(k)))
.map_err(|e| VerifyError::Malformed {
what: "nvidia eat key",
detail: e.to_string(),
})
}
pub fn from_jwks_json(bytes: &[u8]) -> Result<Self, VerifyError> {
let set: JwkSet = serde_json::from_slice(bytes).map_err(|e| VerifyError::Malformed {
what: "nvidia jwks",
detail: e.to_string(),
})?;
let mut map = HashMap::new();
for jwk in &set.keys {
if let (Some(kid), Ok(key)) = (jwk.common.key_id.clone(), DecodingKey::from_jwk(jwk)) {
map.insert(kid, key);
}
}
if map.is_empty() {
return Err(VerifyError::Malformed {
what: "nvidia jwks",
detail: "no usable keys with a kid".to_string(),
});
}
Ok(Self(KeySource::Jwks(map)))
}
pub async fn fetch_jwks(url: &str) -> Result<Self, VerifyError> {
let body = reqwest::Client::new()
.get(url)
.send()
.await
.map_err(|e| VerifyError::Transport {
what: "nvidia jwks",
source: Box::new(e),
})?
.error_for_status()
.map_err(|e| VerifyError::Transport {
what: "nvidia jwks",
source: Box::new(e),
})?
.bytes()
.await
.map_err(|e| VerifyError::Transport {
what: "nvidia jwks",
source: Box::new(e),
})?;
Self::from_jwks_json(&body)
}
pub fn unconfigured() -> Self {
Self(KeySource::Unconfigured)
}
pub(crate) fn resolve(&self, kid: Option<&str>) -> Option<&DecodingKey> {
match &self.0 {
KeySource::Single(key) => Some(key),
KeySource::Jwks(map) => kid.and_then(|k| map.get(k)),
KeySource::Unconfigured => None,
}
}
}
/// Outcome of checking an NRAS platform EAT with [`check_nras_eat`].
///
/// Each flag records one independent check. Use [`NrasVerdict::passed`] to
/// decide whether to trust the attestation.
#[must_use = "an attestation verdict must be checked, e.g. with `passed()`"]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NrasVerdict {
    /// The token's signature verified against the configured key.
    pub signature_verified: bool,
    /// The `x-nvidia-overall-att-result` claim was present and `true`.
    pub overall_pass: bool,
    /// The `eat_nonce` claim matched the expected nonce.
    pub nonce_matches: bool,
}

impl NrasVerdict {
    /// A verdict with every check failed.
    pub const fn failed() -> Self {
        Self {
            signature_verified: false,
            overall_pass: false,
            nonce_matches: false,
        }
    }

    /// `true` only when the signature, overall result and nonce checks all succeeded.
    #[must_use]
    pub const fn passed(&self) -> bool {
        self.signature_verified && self.overall_pass && self.nonce_matches
    }
}
/// The subset of platform EAT claims this module inspects.
///
/// Both claims are optional at the parse level. A missing claim deserializes
/// to `None`, which the verdict logic treats as a failed check.
#[derive(serde::Deserialize)]
struct EatClaims {
    /// Nonce echoed in the token, compared against the caller's nonce.
    eat_nonce: Option<String>,
    /// `x-nvidia-overall-att-result`; only `Some(true)` counts as a pass.
    #[serde(rename = "x-nvidia-overall-att-result")]
    overall_att_result: Option<bool>,
}
/// JWS algorithms accepted in an NRAS EAT header.
///
/// This list is checked against the untrusted token header before any key
/// lookup. Anything outside it, including the symmetric HS* family, is
/// rejected outright.
const NRAS_ALGORITHMS: &[Algorithm] = &[
    // Elliptic-curve signatures.
    Algorithm::ES256,
    Algorithm::ES384,
    // RSA signatures (PKCS#1 v1.5 and PSS).
    Algorithm::RS256,
    Algorithm::PS256,
];
/// Extracts the platform JWT from an NRAS response body.
///
/// The body must be a JSON array whose first element is itself an array of the
/// form `["JWT", "<compact JWT>", ...]`. Any later elements of the outer array
/// are ignored here.
///
/// # Errors
///
/// [`VerifyError::Malformed`] (with `what = "nras response"`) if the body is not
/// JSON, or does not have the shape described above.
fn platform_jwt(response_body: &[u8]) -> Result<String, VerifyError> {
    // Builds the error on demand so the success path allocates nothing for it.
    let malformed = |detail: &str| VerifyError::Malformed {
        what: "nras response",
        detail: detail.to_string(),
    };
    let v: serde_json::Value =
        serde_json::from_slice(response_body).map_err(|e| VerifyError::Malformed {
            what: "nras response",
            detail: e.to_string(),
        })?;
    let entry = v
        .get(0)
        .and_then(serde_json::Value::as_array)
        .ok_or_else(|| malformed("expected a non-empty token array"))?;
    if entry.first().and_then(serde_json::Value::as_str) != Some("JWT") {
        return Err(malformed("platform token is not in [\"JWT\", …] form"));
    }
    entry
        .get(1)
        .and_then(serde_json::Value::as_str)
        .map(str::to_string)
        .ok_or_else(|| malformed("platform token missing the JWT string"))
}
/// Verifies an NRAS attestation response and reports which checks passed.
///
/// The steps are:
/// 1. extract the platform JWT from `response_body`;
/// 2. require its header algorithm to be in the accepted allowlist;
/// 3. resolve the verification key from `key` (by `kid` for a JWKS);
/// 4. verify the signature;
/// 5. inspect the `x-nvidia-overall-att-result` and `eat_nonce` claims.
///
/// This never returns an error: any parse, key-lookup or signature failure
/// yields [`NrasVerdict::failed`]. `exp` and `aud` validation are disabled and
/// no registered claims are required, so replay protection rests on the
/// `eat_nonce` comparison.
///
/// `nonce` is compared ASCII case-insensitively, so hex in either case matches.
/// An empty `nonce` never matches. A caller that forgot to supply one therefore
/// cannot accidentally accept a token carrying an empty `eat_nonce`.
pub fn check_nras_eat(response_body: &[u8], nonce: &str, key: &NvidiaEatKey) -> NrasVerdict {
    let Ok(jwt) = platform_jwt(response_body) else {
        return NrasVerdict::failed();
    };
    let Ok(header) = decode_header(&jwt) else {
        return NrasVerdict::failed();
    };
    // The header is attacker-controlled: allowlist its algorithm before using it.
    if !NRAS_ALGORITHMS.contains(&header.alg) {
        return NrasVerdict::failed();
    }
    let Some(decoding_key) = key.resolve(header.kid.as_deref()) else {
        return NrasVerdict::failed();
    };
    let mut validation = Validation::new(header.alg);
    // Pinned explicitly so the accepted set stays exactly the checked algorithm.
    validation.algorithms = vec![header.alg];
    validation.validate_exp = false;
    validation.validate_aud = false;
    validation.required_spec_claims.clear();
    let Ok(token) = decode::<EatClaims>(&jwt, decoding_key, &validation) else {
        return NrasVerdict::failed();
    };
    let claims = token.claims;
    let overall_pass = claims.overall_att_result == Some(true);
    // Fail closed on an empty expected nonce: it would provide no freshness.
    let nonce_matches = !nonce.is_empty()
        && claims
            .eat_nonce
            .as_deref()
            .is_some_and(|n| n.eq_ignore_ascii_case(nonce));
    NrasVerdict {
        signature_verified: true,
        overall_pass,
        nonce_matches,
    }
}
pub async fn post_nras(
http: &reqwest::Client,
nras_url: &str,
nvidia_payload: &str,
) -> Result<Vec<u8>, VerifyError> {
let resp = http
.post(nras_url)
.header("accept", "application/json")
.header("content-type", "application/json")
.body(nvidia_payload.to_string())
.send()
.await
.map_err(|e| VerifyError::Transport {
what: "nras attestation",
source: Box::new(e),
})?
.error_for_status()
.map_err(|e| VerifyError::Transport {
what: "nras attestation",
source: Box::new(e),
})?;
resp.bytes()
.await
.map(|b| b.to_vec())
.map_err(|e| VerifyError::Transport {
what: "nras attestation",
source: Box::new(e),
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    const TEST_EC_PRIVATE_PKCS8_PEM: &str =
        include_str!("../../tests/fixtures/nras_test_ec_private_pkcs8.pem");
    const TEST_EC_PUBLIC_PEM: &str = include_str!("../../tests/fixtures/nras_test_ec_public.pem");
    const NONCE: &str = "9a01356cb451dc2c3c0ce9a195245a0be984a3f73617f55f87913fc2f059cba7";
    const TEST_JWKS: &str = include_str!("../../tests/fixtures/nras_test_jwks.json");

    fn signing_key() -> EncodingKey {
        EncodingKey::from_ec_pem(TEST_EC_PRIVATE_PKCS8_PEM.as_bytes()).expect("test priv key")
    }

    fn pinned_key() -> NvidiaEatKey {
        NvidiaEatKey::from_ec_pem(TEST_EC_PUBLIC_PEM.as_bytes()).expect("test pub key")
    }

    /// Claims shaped like an NRAS platform EAT.
    fn eat_claims(overall: bool, eat_nonce: &str) -> serde_json::Value {
        serde_json::json!({
            "x-nvidia-overall-att-result": overall,
            "eat_nonce": eat_nonce,
        })
    }

    /// A platform EAT signed (ES256) with the test EC key.
    fn signed_jwt(overall: bool, eat_nonce: &str, kid: Option<&str>) -> String {
        let mut header = Header::new(Algorithm::ES256);
        header.kid = kid.map(str::to_string);
        encode(&header, &eat_claims(overall, eat_nonce), &signing_key()).unwrap()
    }

    /// Wraps a platform JWT in the `[["JWT", <jwt>], {}]` NRAS response envelope.
    fn nras_envelope(jwt: &str) -> Vec<u8> {
        serde_json::to_vec(&serde_json::json!([["JWT", jwt], {}])).unwrap()
    }

    fn nras_body_kid(overall: bool, eat_nonce: &str, kid: Option<&str>) -> Vec<u8> {
        nras_envelope(&signed_jwt(overall, eat_nonce, kid))
    }

    fn nras_body(overall: bool, eat_nonce: &str) -> Vec<u8> {
        nras_body_kid(overall, eat_nonce, None)
    }

    #[test]
    fn accepts_a_passing_signed_eat_with_matching_nonce() {
        let body = nras_body(true, NONCE);
        let verdict = check_nras_eat(&body, NONCE, &pinned_key());
        assert!(verdict.passed());
        assert!(verdict.signature_verified && verdict.overall_pass && verdict.nonce_matches);
    }

    #[test]
    fn rejects_a_failing_result_claim() {
        let body = nras_body(false, NONCE);
        let verdict = check_nras_eat(&body, NONCE, &pinned_key());
        assert!(verdict.signature_verified);
        assert!(!verdict.overall_pass);
        assert!(!verdict.passed());
    }

    #[test]
    fn rejects_a_replayed_nonce() {
        let body = nras_body(true, "00000000000000000000000000000000");
        let verdict = check_nras_eat(&body, NONCE, &pinned_key());
        assert!(!verdict.nonce_matches);
        assert!(!verdict.passed());
    }

    #[test]
    fn nonce_comparison_ignores_hex_case() {
        let body = nras_body(true, &NONCE.to_ascii_uppercase());
        assert!(check_nras_eat(&body, NONCE, &pinned_key()).passed());
    }

    #[test]
    fn unconfigured_key_fails_closed() {
        let body = nras_body(true, NONCE);
        let verdict = check_nras_eat(&body, NONCE, &NvidiaEatKey::unconfigured());
        assert!(!verdict.signature_verified);
        assert!(!verdict.passed());
    }

    #[test]
    fn jwks_resolves_the_signing_key_by_kid() {
        let jwks = NvidiaEatKey::from_jwks_json(TEST_JWKS.as_bytes()).expect("jwks parses");
        let ok = nras_body_kid(true, NONCE, Some("test-kid-1"));
        assert!(check_nras_eat(&ok, NONCE, &jwks).passed());
        let unknown = nras_body_kid(true, NONCE, Some("rotated-away-kid"));
        assert!(!check_nras_eat(&unknown, NONCE, &jwks).signature_verified);
    }

    #[test]
    fn jwks_without_usable_keys_is_rejected() {
        assert!(NvidiaEatKey::from_jwks_json(br#"{"keys":[]}"#).is_err());
        assert!(NvidiaEatKey::from_jwks_json(b"not json").is_err());
    }

    #[test]
    fn rejects_a_token_carrying_another_tokens_signature() {
        // The signing input claims a pass, but the signature was computed over a
        // different (failing) payload, so signature verification must fail.
        let honest = signed_jwt(true, NONCE, None);
        let other = signed_jwt(false, NONCE, None);
        let (signing_input, _) = honest.rsplit_once('.').expect("jwt has a signature");
        let (_, foreign_signature) = other.rsplit_once('.').expect("jwt has a signature");
        let forged = format!("{signing_input}.{foreign_signature}");
        let verdict = check_nras_eat(&nras_envelope(&forged), NONCE, &pinned_key());
        assert_eq!(verdict, NrasVerdict::failed());
    }

    #[test]
    fn rejects_a_symmetric_algorithm() {
        let jwt = encode(
            &Header::new(Algorithm::HS256),
            &eat_claims(true, NONCE),
            &EncodingKey::from_secret(b"attacker-chosen-secret"),
        )
        .unwrap();
        let verdict = check_nras_eat(&nras_envelope(&jwt), NONCE, &pinned_key());
        assert_eq!(verdict, NrasVerdict::failed());
    }

    #[test]
    fn rejects_a_non_jwt_token_tag() {
        let body =
            serde_json::to_vec(&serde_json::json!([["JWS", signed_jwt(true, NONCE, None)], {}]))
                .unwrap();
        assert_eq!(check_nras_eat(&body, NONCE, &pinned_key()), NrasVerdict::failed());
    }

    #[test]
    fn rejects_a_malformed_response_body() {
        let verdict = check_nras_eat(b"[\"not jwt shaped\"]", NONCE, &pinned_key());
        assert_eq!(verdict, NrasVerdict::failed());
    }
}