use super::{
audience::{AudienceError, validate_audience_shape, validate_role_grants},
canonical::shard_key_hash,
};
use crate::{cdk::types::Principal, dto::auth::DelegationCert};
use thiserror::Error;
/// Upper bounds, in nanoseconds, that [`validate_cert_issuance_rules`] enforces on a
/// delegation cert.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DelegatedAuthTtlLimits {
    /// Largest allowed cert lifetime, measured as `expires_at_ns - not_before_ns`.
    pub max_cert_ttl_ns: u64,
    /// Largest value the cert's own `max_token_ttl_ns` field may carry.
    pub max_token_ttl_ns: u64,
}
#[derive(Debug, Eq, Error, PartialEq)]
pub enum CertRuleError {
#[error("delegated auth cert root pid mismatch (expected {expected}, found {found})")]
RootPidMismatch {
expected: Principal,
found: Principal,
},
#[error("delegated auth cert expires_at must be greater than issued_at")]
InvalidCertWindow,
#[error("delegated auth cert ttl {ttl_ns}ns exceeds max {max_ttl_ns}ns")]
CertTtlExceeded { ttl_ns: u64, max_ttl_ns: u64 },
#[error("delegated auth max token ttl must be greater than zero")]
TokenTtlZero,
#[error("delegated auth max token ttl {ttl_ns}ns exceeds max {max_ttl_ns}ns")]
TokenTtlExceeded { ttl_ns: u64, max_ttl_ns: u64 },
#[error("delegated auth max token ttl {token_ttl_ns}ns exceeds cert ttl {cert_ttl_ns}ns")]
TokenTtlOutlivesCert { token_ttl_ns: u64, cert_ttl_ns: u64 },
#[error("delegated auth shard public key must be compressed SEC1 secp256k1 encoding")]
ShardPublicKeyEncodingInvalid,
#[error("delegated auth shard public key hash mismatch")]
ShardPublicKeyHashMismatch,
#[error(transparent)]
Audience(#[from] AudienceError),
}
/// Byte length of a compressed SEC1 point on a 256-bit curve: one parity prefix byte
/// followed by the 32-byte x-coordinate.
const COMPRESSED_SEC1_KEY_LEN: usize = 33;

/// Returns `true` when `key` is shaped like a compressed SEC1 point: exactly
/// [`COMPRESSED_SEC1_KEY_LEN`] bytes beginning with the `0x02` or `0x03` prefix.
///
/// This is a structural check only; it does not verify that the x-coordinate lies on
/// the curve.
fn is_compressed_sec1_point(key: &[u8]) -> bool {
    key.len() == COMPRESSED_SEC1_KEY_LEN && matches!(key.first(), Some(&(0x02 | 0x03)))
}

/// Checks a [`DelegationCert`] against the supplied TTL `limits` and the structural
/// issuance rules.
///
/// The checks run in this order, and the first failure is returned:
///
/// 1. `cert.root_pid` must equal `expected_root_pid`.
/// 2. The validity window must satisfy `issued_at_ns <= not_before_ns < expires_at_ns`.
/// 3. The cert TTL, measured as `expires_at_ns - not_before_ns`, must not exceed
///    `limits.max_cert_ttl_ns`.
/// 4. `cert.max_token_ttl_ns` must be non-zero. It must not exceed
///    `limits.max_token_ttl_ns`, and it must not exceed the cert TTL.
/// 5. The audience and role grants must pass `validate_audience_shape` and
///    `validate_role_grants`.
/// 6. The shard public key must be a compressed SEC1 point: 33 bytes with a
///    `0x02`/`0x03` prefix.
/// 7. Hashing the key together with `shard_sig_alg` and `shard_key_binding` must
///    reproduce `cert.shard_key_hash`.
///
/// # Errors
///
/// Returns the [`CertRuleError`] variant describing the first rule that failed.
/// Audience and grant failures are wrapped in [`CertRuleError::Audience`].
pub fn validate_cert_issuance_rules(
    cert: &DelegationCert,
    limits: DelegatedAuthTtlLimits,
    expected_root_pid: Principal,
) -> Result<(), CertRuleError> {
    if cert.root_pid != expected_root_pid {
        return Err(CertRuleError::RootPidMismatch {
            expected: expected_root_pid,
            found: cert.root_pid,
        });
    }

    // A cert must not become valid before it was issued.
    if cert.not_before_ns < cert.issued_at_ns {
        return Err(CertRuleError::InvalidCertWindow);
    }

    // The TTL is the usable span, starting at `not_before`. `checked_sub` rejects
    // `expires_at < not_before`, and the `filter` rejects an empty window.
    let cert_ttl_ns = cert
        .expires_at_ns
        .checked_sub(cert.not_before_ns)
        .filter(|&ttl| ttl > 0)
        .ok_or(CertRuleError::InvalidCertWindow)?;
    if cert_ttl_ns > limits.max_cert_ttl_ns {
        return Err(CertRuleError::CertTtlExceeded {
            ttl_ns: cert_ttl_ns,
            max_ttl_ns: limits.max_cert_ttl_ns,
        });
    }

    if cert.max_token_ttl_ns == 0 {
        return Err(CertRuleError::TokenTtlZero);
    }
    if cert.max_token_ttl_ns > limits.max_token_ttl_ns {
        return Err(CertRuleError::TokenTtlExceeded {
            ttl_ns: cert.max_token_ttl_ns,
            max_ttl_ns: limits.max_token_ttl_ns,
        });
    }
    // A token minted from this cert must not be allowed to live longer than the
    // cert itself.
    if cert.max_token_ttl_ns > cert_ttl_ns {
        return Err(CertRuleError::TokenTtlOutlivesCert {
            token_ttl_ns: cert.max_token_ttl_ns,
            cert_ttl_ns,
        });
    }

    validate_audience_shape(&cert.aud)?;
    validate_role_grants(&cert.grants)?;

    // A length check alone would accept 33-byte blobs with an uncompressed (0x04) or
    // otherwise invalid prefix. Like the original length check and the error message,
    // this assumes the shard key is a secp256k1 point whatever `shard_sig_alg` is.
    if !is_compressed_sec1_point(&cert.shard_public_key_sec1) {
        return Err(CertRuleError::ShardPublicKeyEncodingInvalid);
    }

    // The hash binds the key bytes to the signature algorithm and key binding.
    if shard_key_hash(
        cert.shard_sig_alg,
        &cert.shard_public_key_sec1,
        cert.shard_key_binding,
    ) != cert.shard_key_hash
    {
        return Err(CertRuleError::ShardPublicKeyHashMismatch);
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        dto::auth::{
            DelegatedRoleGrant, DelegationAudience, ShardKeyBinding, ShardSignatureAlgorithm,
        },
        ids::CanisterRole,
    };

    fn p(id: u8) -> Principal {
        Principal::from_slice(&[id; 29])
    }

    fn limits() -> DelegatedAuthTtlLimits {
        DelegatedAuthTtlLimits {
            max_cert_ttl_ns: 600,
            max_token_ttl_ns: 120,
        }
    }

    fn sample_cert() -> DelegationCert {
        let role = CanisterRole::new("project_instance");
        let shard_public_key_sec1 = vec![2; 33];
        let shard_key_binding = ShardKeyBinding::IcThresholdEcdsaSecp256k1 {
            key_name_hash: [5; 32],
            derivation_path_hash: [6; 32],
        };
        let shard_sig_alg = ShardSignatureAlgorithm::IcThresholdEcdsaSecp256k1;
        let shard_key_hash =
            shard_key_hash(shard_sig_alg, &shard_public_key_sec1, shard_key_binding);
        DelegationCert {
            root_pid: p(1),
            shard_pid: p(2),
            shard_key_id: "shard-key".to_string(),
            shard_sig_alg,
            shard_public_key_sec1,
            shard_key_hash,
            shard_key_binding,
            issued_at_ns: 100,
            not_before_ns: 100,
            expires_at_ns: 500,
            max_token_ttl_ns: 120,
            aud: DelegationAudience::Project("test".to_string()),
            grants: vec![DelegatedRoleGrant {
                target: role,
                scopes: vec!["read".to_string()],
            }],
        }
    }

    #[test]
    fn cert_rules_accept_well_formed_cert() {
        let cert = sample_cert();
        validate_cert_issuance_rules(&cert, limits(), p(1)).unwrap();
    }

    #[test]
    fn cert_rules_enforce_root_pid_binding() {
        let cert = sample_cert();
        assert_eq!(
            validate_cert_issuance_rules(&cert, limits(), p(9)),
            Err(CertRuleError::RootPidMismatch {
                expected: p(9),
                found: p(1),
            })
        );
    }

    #[test]
    fn cert_rules_reject_not_before_preceding_issued_at() {
        let mut cert = sample_cert();
        cert.not_before_ns = 99;
        assert_eq!(
            validate_cert_issuance_rules(&cert, limits(), p(1)),
            Err(CertRuleError::InvalidCertWindow)
        );
    }

    #[test]
    fn cert_rules_reject_empty_window() {
        let mut cert = sample_cert();
        cert.expires_at_ns = cert.not_before_ns;
        assert_eq!(
            validate_cert_issuance_rules(&cert, limits(), p(1)),
            Err(CertRuleError::InvalidCertWindow)
        );
    }

    #[test]
    fn cert_rules_reject_expiry_before_not_before() {
        let mut cert = sample_cert();
        cert.expires_at_ns = 50;
        assert_eq!(
            validate_cert_issuance_rules(&cert, limits(), p(1)),
            Err(CertRuleError::InvalidCertWindow)
        );
    }

    #[test]
    fn cert_rules_enforce_cert_ttl_bound_at_root() {
        let mut cert = sample_cert();
        cert.expires_at_ns = 900;
        assert_eq!(
            validate_cert_issuance_rules(&cert, limits(), p(1)),
            Err(CertRuleError::CertTtlExceeded {
                ttl_ns: 800,
                max_ttl_ns: 600,
            })
        );
    }

    #[test]
    fn cert_rules_reject_zero_token_ttl() {
        let mut cert = sample_cert();
        cert.max_token_ttl_ns = 0;
        assert_eq!(
            validate_cert_issuance_rules(&cert, limits(), p(1)),
            Err(CertRuleError::TokenTtlZero)
        );
    }

    #[test]
    fn cert_rules_enforce_token_ttl_bound_at_root() {
        let mut cert = sample_cert();
        cert.max_token_ttl_ns = 121;
        assert_eq!(
            validate_cert_issuance_rules(&cert, limits(), p(1)),
            Err(CertRuleError::TokenTtlExceeded {
                ttl_ns: 121,
                max_ttl_ns: 120,
            })
        );
    }

    #[test]
    fn cert_rules_reject_token_ttl_outliving_cert() {
        let mut cert = sample_cert();
        cert.expires_at_ns = 150;
        assert_eq!(
            validate_cert_issuance_rules(&cert, limits(), p(1)),
            Err(CertRuleError::TokenTtlOutlivesCert {
                token_ttl_ns: 120,
                cert_ttl_ns: 50,
            })
        );
    }

    #[test]
    fn cert_rules_enforce_role_grant_shape() {
        let mut cert = sample_cert();
        cert.grants = Vec::new();
        assert_eq!(
            validate_cert_issuance_rules(&cert, limits(), p(1)),
            Err(CertRuleError::Audience(AudienceError::GrantsEmpty))
        );
    }

    #[test]
    fn cert_rules_enforce_shard_public_key_hash_binding() {
        let mut cert = sample_cert();
        // Valid compressed prefix (0x03) but different bytes than the hashed key.
        cert.shard_public_key_sec1 = vec![3; 33];
        assert_eq!(
            validate_cert_issuance_rules(&cert, limits(), p(1)),
            Err(CertRuleError::ShardPublicKeyHashMismatch)
        );
    }

    #[test]
    fn cert_rules_enforce_compressed_shard_public_key_encoding() {
        let mut cert = sample_cert();
        cert.shard_public_key_sec1 = vec![7; 65];
        assert_eq!(
            validate_cert_issuance_rules(&cert, limits(), p(1)),
            Err(CertRuleError::ShardPublicKeyEncodingInvalid)
        );
    }

    #[test]
    fn cert_rules_reject_invalid_compressed_prefix() {
        for prefix in [0x00_u8, 0x04, 0x07] {
            let mut cert = sample_cert();
            cert.shard_public_key_sec1 = vec![prefix; 33];
            assert_eq!(
                validate_cert_issuance_rules(&cert, limits(), p(1)),
                Err(CertRuleError::ShardPublicKeyEncodingInvalid),
                "prefix {prefix:#04x} must be rejected"
            );
        }
    }

    #[test]
    fn compressed_sec1_shape_check() {
        assert!(is_compressed_sec1_point(&[0x02; 33]));
        assert!(is_compressed_sec1_point(&[0x03; 33]));
        assert!(!is_compressed_sec1_point(&[0x04; 33]));
        assert!(!is_compressed_sec1_point(&[0x02; 32]));
        assert!(!is_compressed_sec1_point(&[0x04; 65]));
        assert!(!is_compressed_sec1_point(&[]));
    }
}