use crate::errors::{AuthError, Result};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
/// A SPIFFE ID split into its trust domain and path: `spiffe://<trust_domain><path>`.
///
/// Use [`SpiffeId::parse`] to obtain a validated value. The fields are public, so a value
/// constructed directly is not validated.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SpiffeId {
/// Trust domain, without the `spiffe://` scheme and without any path (e.g. `example.org`).
pub trust_domain: String,
/// Path including its leading `/` (e.g. `/ns/prod/svc`), or empty when the ID has no path.
pub path: String,
}
impl SpiffeId {
    /// Parses and validates a SPIFFE ID URI of the form `spiffe://<trust-domain>[<path>]`.
    ///
    /// Validation follows the SPIFFE ID specification:
    ///
    /// * the scheme must be exactly `spiffe://`;
    /// * the trust domain must be non-empty and contain only lowercase ASCII letters,
    ///   digits, `-`, `.` and `_` (uppercase is rejected so that equal IDs compare
    ///   byte-for-byte);
    /// * the path, if present, must not contain a query (`?`) or fragment (`#`), must not
    ///   end with `/` (a bare `/` included), must not contain empty, `.` or `..` segments,
    ///   and each segment may contain only ASCII letters, digits, `-`, `.` and `_`
    ///   (no percent-encoding).
    ///
    /// # Errors
    ///
    /// Returns a validation error describing the first rule that is violated.
    pub fn parse(uri: &str) -> Result<Self> {
        let stripped = uri
            .strip_prefix("spiffe://")
            .ok_or_else(|| AuthError::validation("SPIFFE ID must start with 'spiffe://'"))?;

        // The trust domain runs up to the first '/'; the path keeps its leading '/'.
        let (trust_domain, path) = match stripped.find('/') {
            Some(idx) => stripped.split_at(idx),
            None => (stripped, ""),
        };
        if trust_domain.is_empty() {
            return Err(AuthError::validation("SPIFFE ID trust domain is empty"));
        }
        let is_trust_domain_char = |c: char| {
            c.is_ascii_lowercase() || c.is_ascii_digit() || matches!(c, '-' | '.' | '_')
        };
        if let Some(ch) = trust_domain.chars().find(|&c| !is_trust_domain_char(c)) {
            return Err(AuthError::validation(&format!(
                "SPIFFE ID trust domain contains invalid character: '{ch}'"
            )));
        }

        // Checked before the per-character rule so callers get the more specific message.
        if path.contains('?') || path.contains('#') {
            return Err(AuthError::validation(
                "SPIFFE ID must not contain query or fragment",
            ));
        }
        // A bare "/" is rejected too: "spiffe://td/" and "spiffe://td" must not both be valid.
        if path.ends_with('/') {
            return Err(AuthError::validation(
                "SPIFFE ID path must not end with '/'",
            ));
        }
        if path.contains("//") {
            return Err(AuthError::validation(
                "SPIFFE ID path must not contain empty segments",
            ));
        }
        let is_path_char = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '.' | '_');
        // `path` is either empty or starts with '/', so the first split item is always "".
        for segment in path.split('/').skip(1) {
            if segment == "." || segment == ".." {
                return Err(AuthError::validation(
                    "SPIFFE ID path must not contain '.' or '..' segments",
                ));
            }
            if let Some(ch) = segment.chars().find(|&c| !is_path_char(c)) {
                return Err(AuthError::validation(&format!(
                    "SPIFFE ID path contains invalid character: '{ch}'"
                )));
            }
        }

        Ok(Self {
            trust_domain: trust_domain.to_string(),
            path: path.to_string(),
        })
    }

    /// Renders the ID back to its URI form, `spiffe://<trust_domain><path>`.
    pub fn to_uri(&self) -> String {
        format!("spiffe://{}{}", self.trust_domain, self.path)
    }

    /// Returns `true` when this ID belongs to exactly `trust_domain` (case-sensitive).
    pub fn is_member_of(&self, trust_domain: &str) -> bool {
        self.trust_domain == trust_domain
    }

    /// Returns `true` when `prefix` matches this ID's path on a segment boundary.
    ///
    /// `/ns/prod` matches `/ns/prod` and `/ns/prod/svc`, but not `/ns/production`.
    /// A prefix ending in `/` matches any path that starts with it. The empty prefix
    /// matches every path.
    pub fn matches_path_prefix(&self, prefix: &str) -> bool {
        match self.path.strip_prefix(prefix) {
            // A plain `starts_with` would let "/ns/prod" authorize "/ns/production".
            Some(rest) => rest.is_empty() || rest.starts_with('/') || prefix.ends_with('/'),
            None => false,
        }
    }
}
impl std::fmt::Display for SpiffeId {
    /// Writes the canonical URI form: the `spiffe://` scheme, then the trust domain,
    /// then the path (which already carries its leading `/` when non-empty).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("spiffe://")?;
        f.write_str(&self.trust_domain)?;
        f.write_str(&self.path)
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtSvidClaims {
pub sub: String,
pub aud: Vec<String>,
pub exp: u64,
#[serde(default)]
pub iat: Option<u64>,
}
/// Result of a successful `validate_jwt_svid` call.
///
/// NOTE(review): that function does not verify the token signature, so these contents are
/// only as trustworthy as whatever signature check the caller performs.
#[derive(Debug, Clone)]
pub struct ValidatedJwtSvid {
/// SPIFFE ID parsed from the `sub` claim.
pub spiffe_id: SpiffeId,
/// Decoded claims (`sub`, `aud`, `exp`, optional `iat`).
pub claims: JwtSvidClaims,
/// Decoded JOSE header as raw JSON (contains at least a string `alg`).
pub header: serde_json::Value,
}
/// Structurally validates a JWT-SVID and extracts its SPIFFE ID.
///
/// Checks, in order:
///
/// * the token has exactly three `.`-separated parts and a non-empty signature part;
/// * the header and claims are base64url (no padding) encoded JSON;
/// * the header `alg` is one of the asymmetric algorithms the JWT-SVID specification
///   permits (`RS256/384/512`, `ES256/384/512`, `PS256/384/512`); `none` and HMAC (`HS*`)
///   are rejected;
/// * `sub` is a valid SPIFFE ID;
/// * `exp` lies in the future;
/// * `expected_audience` is one of the `aud` values.
///
/// # Security
///
/// The signature is **not** verified. A token that passes this function may still be
/// forged. Before trusting the result, callers must verify the signature against the JWT
/// authorities of the subject's trust domain.
///
/// # Errors
///
/// Returns a validation error describing the first failed check.
pub fn validate_jwt_svid(token: &str, expected_audience: &str) -> Result<ValidatedJwtSvid> {
    use base64::Engine;

    // RFC 7518 §3.3 (RSASSA-PKCS1-v1_5), §3.4 (ECDSA) and §3.5 (RSASSA-PSS).
    const ALLOWED_ALGS: [&str; 9] = [
        "RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512",
    ];
    let b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;

    let parts: Vec<&str> = token.split('.').collect();
    if parts.len() != 3 {
        return Err(AuthError::validation("JWT-SVID must have 3 parts"));
    }
    if parts[2].is_empty() {
        return Err(AuthError::validation("JWT-SVID signature is missing"));
    }

    let header_bytes = b64
        .decode(parts[0])
        .map_err(|_| AuthError::validation("Invalid JWT-SVID header encoding"))?;
    let header: serde_json::Value = serde_json::from_slice(&header_bytes)
        .map_err(|_| AuthError::validation("Invalid JWT-SVID header JSON"))?;
    let alg = header
        .get("alg")
        .and_then(|v| v.as_str())
        .ok_or_else(|| AuthError::validation("JWT-SVID header missing 'alg'"))?;
    if alg.eq_ignore_ascii_case("none") {
        return Err(AuthError::validation(
            "JWT-SVID must not use 'none' algorithm",
        ));
    }
    // Symmetric (HS*) and unknown algorithms are not valid for JWT-SVIDs. The attacker-supplied
    // value is deliberately not echoed in the message.
    if !ALLOWED_ALGS.contains(&alg) {
        return Err(AuthError::validation(
            "JWT-SVID 'alg' must be one of RS256/384/512, ES256/384/512, PS256/384/512",
        ));
    }

    let claims_bytes = b64
        .decode(parts[1])
        .map_err(|_| AuthError::validation("Invalid JWT-SVID claims encoding"))?;
    let claims: JwtSvidClaims = serde_json::from_slice(&claims_bytes)
        .map_err(|_| AuthError::validation("Invalid JWT-SVID claims JSON"))?;
    let spiffe_id = SpiffeId::parse(&claims.sub)?;

    // A clock before the Unix epoch reads as 0, so such a clock never rejects a token as expired.
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    if claims.exp <= now {
        return Err(AuthError::validation("JWT-SVID has expired"));
    }
    if !claims.aud.iter().any(|a| a == expected_audience) {
        return Err(AuthError::validation(
            "JWT-SVID audience does not match expected audience",
        ));
    }

    Ok(ValidatedJwtSvid {
        spiffe_id,
        claims,
        header,
    })
}
/// Identity information extracted from an X.509-SVID certificate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct X509SvidInfo {
/// SPIFFE ID found in the certificate bytes.
pub spiffe_id: SpiffeId,
/// Lowercase hex SHA-256 digest of the full certificate bytes, as computed by
/// `extract_spiffe_id_from_der`.
pub fingerprint: String,
/// Certificate serial number; `extract_spiffe_id_from_der` leaves this `None`.
#[serde(default)]
pub serial: Option<String>,
/// Start of validity; `extract_spiffe_id_from_der` leaves this `None`.
/// NOTE(review): presumably Unix seconds, like the other `u64` timestamps here — confirm.
#[serde(default)]
pub not_before: Option<u64>,
/// End of validity; `extract_spiffe_id_from_der` leaves this `None`.
/// NOTE(review): presumably Unix seconds — confirm.
#[serde(default)]
pub not_after: Option<u64>,
}
/// Extracts the SPIFFE ID from the bytes of a DER-encoded X.509-SVID.
///
/// The URI is located with the byte-level heuristic in `find_spiffe_uri_in_bytes`, not a
/// full ASN.1/X.509 parse. The certificate is therefore **not** validated (no signature,
/// chain, validity-period or extension checks), and `serial`, `not_before` and `not_after`
/// are always `None`. `fingerprint` is the lowercase hex SHA-256 of the entire input.
///
/// # Errors
///
/// Returns a validation error when no `spiffe://` URI can be found, or when the URI found
/// is not a valid SPIFFE ID.
pub fn extract_spiffe_id_from_der(cert_der: &[u8]) -> Result<X509SvidInfo> {
    // There is deliberately no lossy-UTF-8 fallback. A candidate the byte scanner rejects
    // contains invalid UTF-8. Decoding it lossily would splice U+FFFD into the identity
    // instead of reporting that no usable SPIFFE ID exists.
    let spiffe_uri = find_spiffe_uri_in_bytes(cert_der).ok_or_else(|| {
        AuthError::validation("No SPIFFE ID (spiffe:// URI) found in certificate SAN")
    })?;
    let spiffe_id = SpiffeId::parse(&spiffe_uri)?;
    Ok(X509SvidInfo {
        spiffe_id,
        fingerprint: hex::encode(Sha256::digest(cert_der)),
        serial: None,
        not_before: None,
        not_after: None,
    })
}
/// Finds the first `spiffe://` URI embedded in raw certificate bytes.
///
/// This is a byte-level heuristic, not an ASN.1 parser. For each occurrence of
/// `spiffe://`, the URI's extent is chosen as follows:
///
/// * If the occurrence is immediately preceded by a DER header for
///   `GeneralName::uniformResourceIdentifier` (tag `0x86` with a short-form length, or a
///   long-form `0x81 len` / `0x82 hi lo` length), exactly that many bytes are taken. This
///   stops the URI from absorbing following DER bytes that happen to be printable. For
///   example, a `0x30` SEQUENCE tag is ASCII `'0'`.
/// * Otherwise the URI runs up to the first byte that cannot appear in a URI (anything
///   outside printable ASCII `0x21..=0x7E`), or to the end of the buffer.
///
/// Returns `None` when no occurrence yields valid UTF-8.
fn find_spiffe_uri_in_bytes(data: &[u8]) -> Option<String> {
    const NEEDLE: &[u8] = b"spiffe://";
    // Context-specific primitive tag [6]: GeneralName::uniformResourceIdentifier (IA5String).
    const URI_TAG: u8 = 0x86;

    // Content length declared by a DER URI header that ends exactly at `start`, if any.
    let declared_len = |start: usize| -> Option<usize> {
        match *&data[..start] {
            [.., URI_TAG, len] if len < 0x80 => Some(usize::from(len)),
            [.., URI_TAG, 0x81, len] => Some(usize::from(len)),
            [.., URI_TAG, 0x82, hi, lo] => Some(usize::from(u16::from_be_bytes([hi, lo]))),
            _ => None,
        }
    };

    // Fallback extent: up to the first byte that cannot be part of a URI.
    let scanned_end = |start: usize| -> usize {
        data[start..]
            .iter()
            .position(|&b| !(0x21..=0x7e).contains(&b))
            .map_or(data.len(), |offset| start + offset)
    };

    // `windows` also visits the last possible position, which the old index range skipped.
    data.windows(NEEDLE.len())
        .enumerate()
        .filter(|(_, window)| *window == NEEDLE)
        .find_map(|(start, _)| {
            let end = declared_len(start)
                .and_then(|len| start.checked_add(len))
                // Ignore a declared length that cannot cover the scheme or overruns the buffer.
                .filter(|&end| end >= start + NEEDLE.len() && end <= data.len())
                .unwrap_or_else(|| scanned_end(start));
            std::str::from_utf8(&data[start..end])
                .ok()
                .map(str::to_owned)
        })
}
/// An allow rule consulted by `SpiffeTrustManager::is_authorized`.
///
/// A request is permitted when any stored policy matches its source, destination and
/// action. There are no deny rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpiffeAuthzPolicy {
/// Full SPIFFE ID URI of the caller (compared against `SpiffeId::to_uri`), or `"*"` for any.
pub source: String,
/// Full SPIFFE ID URI of the callee, or `"*"` for any.
pub destination: String,
/// Permitted action names (compared exactly), or `"*"` to allow every action.
pub allowed_actions: Vec<String>,
}
/// Per-trust-domain CA bundles plus SPIFFE-ID-based authorization policies.
///
/// Both collections are held in `Arc<tokio::sync::RwLock<..>>`, and every accessor is `async`.
pub struct SpiffeTrustManager {
/// Trust domain name -> CA certificates as raw bytes (presumably DER — confirm with callers).
trust_bundles: Arc<RwLock<HashMap<String, Vec<Vec<u8>>>>>,
/// Allow rules consulted by `is_authorized`; stored in insertion order.
policies: Arc<RwLock<Vec<SpiffeAuthzPolicy>>>,
}
impl SpiffeTrustManager {
    /// Creates a manager with no trust bundles and no authorization policies.
    pub fn new() -> Self {
        Self {
            trust_bundles: Arc::new(RwLock::new(HashMap::new())),
            policies: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Installs the CA certificates trusted for `trust_domain`, replacing any existing bundle.
    pub async fn add_trust_bundle(&self, trust_domain: &str, ca_certs_der: Vec<Vec<u8>>) {
        self.trust_bundles
            .write()
            .await
            .insert(trust_domain.to_string(), ca_certs_der);
    }

    /// Returns whether a bundle is registered for `trust_domain`.
    pub async fn has_trust_bundle(&self, trust_domain: &str) -> bool {
        self.trust_bundles.read().await.contains_key(trust_domain)
    }

    /// Returns a copy of the bundle registered for `trust_domain`, if any.
    pub async fn get_trust_bundle(&self, trust_domain: &str) -> Option<Vec<Vec<u8>>> {
        self.trust_bundles.read().await.get(trust_domain).cloned()
    }

    /// Removes the bundle for `trust_domain`; returns `true` if one was present.
    pub async fn remove_trust_bundle(&self, trust_domain: &str) -> bool {
        self.trust_bundles
            .write()
            .await
            .remove(trust_domain)
            .is_some()
    }

    /// Appends an allow policy. Policies are never merged or deduplicated.
    pub async fn add_policy(&self, policy: SpiffeAuthzPolicy) {
        self.policies.write().await.push(policy);
    }

    /// Returns `true` when any policy allows `source` to perform `action` on `destination`.
    ///
    /// A policy matches when its `source` and `destination` equal the IDs' URI forms (or are
    /// `"*"`), and its `allowed_actions` contains `action` (or `"*"`). With no policies,
    /// everything is denied.
    pub async fn is_authorized(
        &self,
        source: &SpiffeId,
        destination: &SpiffeId,
        action: &str,
    ) -> bool {
        // Build the URIs before taking the lock, and compare actions by reference rather
        // than allocating two Strings per policy per call.
        let source_uri = source.to_uri();
        let dest_uri = destination.to_uri();
        let policies = self.policies.read().await;
        policies.iter().any(|p| {
            (p.source == "*" || p.source == source_uri)
                && (p.destination == "*" || p.destination == dest_uri)
                && p.allowed_actions.iter().any(|a| a == "*" || a == action)
        })
    }

    /// Runs `validate_jwt_svid`, then requires that a bundle is registered for the token's
    /// trust domain.
    ///
    /// # Security
    ///
    /// The registered bundle is only checked for presence. The token signature is **not**
    /// verified against it, so a forged token for a known trust domain is accepted.
    ///
    /// # Errors
    ///
    /// Any error from `validate_jwt_svid`, or a validation error when the trust domain has no
    /// registered bundle.
    pub async fn verify_jwt_svid(
        &self,
        token: &str,
        expected_audience: &str,
    ) -> Result<ValidatedJwtSvid> {
        let result = validate_jwt_svid(token, expected_audience)?;
        if !self.has_trust_bundle(&result.spiffe_id.trust_domain).await {
            return Err(AuthError::validation(&format!(
                "No trust bundle for domain '{}'",
                result.spiffe_id.trust_domain
            )));
        }
        Ok(result)
    }

    /// Names of all trust domains with a registered bundle, sorted for deterministic output.
    pub async fn trust_domains(&self) -> Vec<String> {
        let mut domains: Vec<String> = self.trust_bundles.read().await.keys().cloned().collect();
        domains.sort_unstable();
        domains
    }
}
impl Default for SpiffeTrustManager {
fn default() -> Self {
Self::new()
}
}
/// An SVID delivered by the Workload API, in either X.509 or JWT form.
///
/// `Debug` is implemented by hand so that secret material (the X.509 `private_key` and the
/// bearer JWT `token`) is never written to logs. `Serialize` still emits every field.
#[derive(Clone, Serialize, Deserialize)]
pub enum SvidResponse {
    /// An X.509-SVID together with its private key and trust bundle.
    X509 {
        /// SPIFFE ID URI the SVID was issued for.
        spiffe_id: String,
        /// Certificate chain as raw bytes (presumably DER, leaf first — confirm with producer).
        cert_chain: Vec<Vec<u8>>,
        /// Private key bytes for the leaf certificate. Secret.
        private_key: Vec<u8>,
        /// CA certificates of the SVID's trust domain.
        bundle: Vec<Vec<u8>>,
        /// Expiry, compared against seconds since the Unix epoch.
        expires_at: u64,
    },
    /// A JWT-SVID.
    Jwt {
        /// SPIFFE ID URI the token was issued for.
        spiffe_id: String,
        /// Compact-serialized JWT. A bearer credential, so it is secret.
        token: String,
        /// Expiry, compared against seconds since the Unix epoch.
        expires_at: u64,
    },
}

impl std::fmt::Debug for SvidResponse {
    /// Same shape as a derived `Debug`, with secrets replaced by placeholders.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::X509 {
                spiffe_id,
                cert_chain,
                private_key,
                bundle,
                expires_at,
            } => f
                .debug_struct("X509")
                .field("spiffe_id", spiffe_id)
                .field("cert_chain", cert_chain)
                .field(
                    "private_key",
                    &format_args!("<redacted {} bytes>", private_key.len()),
                )
                .field("bundle", bundle)
                .field("expires_at", expires_at)
                .finish(),
            Self::Jwt {
                spiffe_id,
                token: _,
                expires_at,
            } => f
                .debug_struct("Jwt")
                .field("spiffe_id", spiffe_id)
                .field("token", &format_args!("<redacted>"))
                .field("expires_at", expires_at)
                .finish(),
        }
    }
}
/// Settings for a `WorkloadApiClient`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkloadApiConfig {
/// Workload API endpoint, returned by `WorkloadApiClient::endpoint`. The `Default` value is
/// a filesystem path ending in `api.sock`.
pub endpoint: String,
/// X.509-SVIDs with fewer than this many seconds of remaining lifetime are reported by
/// `WorkloadApiClient::needs_rotation`. Also exposed as a `Duration` via `rotation_interval`.
pub rotation_interval_secs: u64,
/// JWT audiences. NOTE(review): nothing in this file reads this field — confirm intended use.
pub jwt_audiences: Vec<String>,
}
impl Default for WorkloadApiConfig {
fn default() -> Self {
Self {
endpoint: "/tmp/spire-agent/public/api.sock".to_string(),
rotation_interval_secs: 300,
jwt_audiences: Vec::new(),
}
}
}
/// In-memory cache of SVIDs and trust bundles for a Workload API endpoint.
///
/// NOTE(review): no code in this file connects to `config.endpoint`; the caches are filled
/// only through the `store_*` methods.
pub struct WorkloadApiClient {
/// Endpoint and rotation settings.
config: WorkloadApiConfig,
/// X.509-SVIDs keyed by SPIFFE ID URI string (only `SvidResponse::X509` values are inserted).
x509_svids: Arc<RwLock<HashMap<String, SvidResponse>>>,
/// JWT-SVIDs keyed by SPIFFE ID URI string (only `SvidResponse::Jwt` values are inserted).
jwt_svids: Arc<RwLock<HashMap<String, SvidResponse>>>,
/// Trust domain -> CA certificates, taken from stored X.509-SVIDs.
bundles: Arc<RwLock<HashMap<String, Vec<Vec<u8>>>>>,
}
impl WorkloadApiClient {
pub fn new(config: WorkloadApiConfig) -> Self {
Self {
config,
x509_svids: Arc::new(RwLock::new(HashMap::new())),
jwt_svids: Arc::new(RwLock::new(HashMap::new())),
bundles: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn endpoint(&self) -> &str {
&self.config.endpoint
}
pub fn rotation_interval(&self) -> Duration {
Duration::from_secs(self.config.rotation_interval_secs)
}
pub async fn store_x509_svid(&self, svid: SvidResponse) {
if let SvidResponse::X509 {
ref spiffe_id,
ref bundle,
..
} = svid
{
if let Ok(id) = SpiffeId::parse(spiffe_id) {
self.bundles
.write()
.await
.insert(id.trust_domain.clone(), bundle.clone());
}
self.x509_svids
.write()
.await
.insert(spiffe_id.clone(), svid);
}
}
pub async fn store_jwt_svid(&self, svid: SvidResponse) {
if let SvidResponse::Jwt { ref spiffe_id, .. } = svid {
self.jwt_svids.write().await.insert(spiffe_id.clone(), svid);
}
}
pub async fn get_x509_svid(&self, spiffe_id: &str) -> Option<SvidResponse> {
self.x509_svids.read().await.get(spiffe_id).cloned()
}
pub async fn get_jwt_svid(&self, spiffe_id: &str) -> Option<SvidResponse> {
self.jwt_svids.read().await.get(spiffe_id).cloned()
}
pub async fn get_bundle(&self, trust_domain: &str) -> Option<Vec<Vec<u8>>> {
self.bundles.read().await.get(trust_domain).cloned()
}
pub async fn needs_rotation(&self) -> Vec<String> {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let svids = self.x509_svids.read().await;
let mut needs = Vec::new();
for (id, svid) in svids.iter() {
if let SvidResponse::X509 { expires_at, .. } = svid {
let remaining = expires_at.saturating_sub(now);
let threshold = self.config.rotation_interval_secs;
if remaining < threshold {
needs.push(id.clone());
}
}
}
needs
}
pub async fn cleanup_expired(&self) {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
self.x509_svids.write().await.retain(|_, svid| {
if let SvidResponse::X509 { expires_at, .. } = svid {
*expires_at > now
} else {
true
}
});
self.jwt_svids.write().await.retain(|_, svid| {
if let SvidResponse::Jwt { expires_at, .. } = svid {
*expires_at > now
} else {
true
}
});
}
pub async fn x509_count(&self) -> usize {
self.x509_svids.read().await.len()
}
pub async fn jwt_count(&self) -> usize {
self.jwt_svids.read().await.len()
}
}
/// Raw evidence about a workload, consumed by `RegistrationStore::attest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttestationEvidence {
/// Attestor name; becomes the `selector_type` of every derived selector (e.g. `"unix"`).
pub attestor: String,
/// Attested facts; each `(key, value)` pair becomes a selector value `"key:value"`.
pub payload: HashMap<String, String>,
}
/// Successful outcome of `RegistrationStore::attest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttestationResult {
/// SPIFFE IDs of every registration entry the workload matched (never empty on success).
pub spiffe_ids: Vec<SpiffeId>,
/// Selectors derived from the evidence, in the evidence map's iteration order.
pub selectors: Vec<WorkloadSelector>,
}
/// A typed workload attribute used to match registration entries.
///
/// Two selectors are equal only if both `selector_type` and `value` match exactly.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct WorkloadSelector {
/// Selector namespace, e.g. `"unix"`, `"k8s"` or `"docker"` (see the helper constructors).
pub selector_type: String,
/// Attribute value within that namespace, e.g. `"uid:1000"`.
pub value: String,
}
impl WorkloadSelector {
pub fn unix_uid(uid: u32) -> Self {
Self {
selector_type: "unix".to_string(),
value: format!("uid:{uid}"),
}
}
pub fn unix_gid(gid: u32) -> Self {
Self {
selector_type: "unix".to_string(),
value: format!("gid:{gid}"),
}
}
pub fn k8s_sa(namespace: &str, name: &str) -> Self {
Self {
selector_type: "k8s".to_string(),
value: format!("sa:{namespace}:{name}"),
}
}
pub fn k8s_pod_label(key: &str, value: &str) -> Self {
Self {
selector_type: "k8s".to_string(),
value: format!("pod-label:{key}:{value}"),
}
}
pub fn docker_image_id(image_id: &str) -> Self {
Self {
selector_type: "docker".to_string(),
value: format!("image_id:{image_id}"),
}
}
}
/// Maps a set of workload selectors to the SPIFFE ID such workloads should receive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistrationEntry {
/// Identity granted to workloads that match every selector.
pub spiffe_id: SpiffeId,
/// Parent identity. NOTE(review): not consulted by `RegistrationStore` matching in this file.
pub parent_id: SpiffeId,
/// Selectors that must all be present in a workload's selectors for the entry to match.
pub selectors: Vec<WorkloadSelector>,
/// SVID TTL. NOTE(review): unit presumably seconds; not enforced anywhere in this file.
pub ttl: u64,
/// Downstream flag. NOTE(review): not read anywhere in this file — confirm intended use.
pub downstream: bool,
}
/// In-memory list of registration entries used for selector matching and attestation.
pub struct RegistrationStore {
/// Entries in registration order; duplicates are allowed.
entries: Arc<RwLock<Vec<RegistrationEntry>>>,
}
impl RegistrationStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            entries: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Appends `entry`. Entries with duplicate SPIFFE IDs are not detected or merged.
    pub async fn register(&self, entry: RegistrationEntry) {
        self.entries.write().await.push(entry);
    }

    /// Returns the SPIFFE IDs of all entries whose selectors are all present in
    /// `workload_selectors`. Extra workload selectors do not prevent a match.
    ///
    /// An entry with an empty selector list never matches. Without that rule it would
    /// vacuously match every workload (`all` over nothing is `true`) and hand its identity
    /// to anyone who attests.
    pub async fn match_selectors(&self, workload_selectors: &[WorkloadSelector]) -> Vec<SpiffeId> {
        let entries = self.entries.read().await;
        entries
            .iter()
            .filter(|entry| {
                !entry.selectors.is_empty()
                    && entry
                        .selectors
                        .iter()
                        .all(|s| workload_selectors.contains(s))
            })
            .map(|entry| entry.spiffe_id.clone())
            .collect()
    }

    /// Converts attestation evidence into selectors and resolves them to SPIFFE IDs.
    ///
    /// Each `(key, value)` pair in `evidence.payload` becomes a selector of type
    /// `evidence.attestor` with value `"key:value"`.
    ///
    /// # Errors
    ///
    /// Returns a validation error when the evidence yields no selectors, or when no
    /// registration entry matches them.
    pub async fn attest(&self, evidence: &AttestationEvidence) -> Result<AttestationResult> {
        let selectors: Vec<WorkloadSelector> = evidence
            .payload
            .iter()
            .map(|(key, value)| WorkloadSelector {
                selector_type: evidence.attestor.clone(),
                value: format!("{key}:{value}"),
            })
            .collect();
        if selectors.is_empty() {
            return Err(AuthError::validation(
                "Attestation evidence contains no selectors",
            ));
        }
        let spiffe_ids = self.match_selectors(&selectors).await;
        if spiffe_ids.is_empty() {
            return Err(AuthError::validation(
                "No registration entries match the workload selectors",
            ));
        }
        Ok(AttestationResult {
            spiffe_ids,
            selectors,
        })
    }

    /// Number of registered entries.
    pub async fn count(&self) -> usize {
        self.entries.read().await.len()
    }

    /// Removes every entry whose `spiffe_id` equals `id`.
    pub async fn remove_by_spiffe_id(&self, id: &SpiffeId) {
        self.entries.write().await.retain(|e| &e.spiffe_id != id);
    }
}
impl Default for RegistrationStore {
fn default() -> Self {
Self::new()
}
}
/// A trust bundle obtained from a federated (non-local) trust domain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedBundle {
/// Trust domain the bundle belongs to; used as the storage key.
pub trust_domain: String,
/// CA certificates as raw bytes (presumably DER — confirm with producer).
pub ca_certs: Vec<Vec<u8>>,
/// Last refresh time in seconds since the Unix epoch; `cleanup_stale` ages bundles by this.
pub refreshed_at: u64,
/// Bundle sequence number. NOTE(review): `FederatedTrustBundleManager` stores it but never
/// compares it, so an older bundle can replace a newer one.
pub sequence_number: u64,
}
/// Tracks federated trust bundles and their endpoints alongside the local trust domain.
pub struct FederatedTrustBundleManager {
/// The local trust domain; IDs in it are always considered trusted.
local_domain: String,
/// Trust domain -> most recently stored federated bundle.
bundles: Arc<RwLock<HashMap<String, FederatedBundle>>>,
/// Trust domain -> bundle endpoint URL (stored only; nothing in this file fetches from it).
endpoints: Arc<RwLock<HashMap<String, String>>>,
}
impl FederatedTrustBundleManager {
pub fn new(local_domain: impl Into<String>) -> Self {
Self {
local_domain: local_domain.into(),
bundles: Arc::new(RwLock::new(HashMap::new())),
endpoints: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn local_domain(&self) -> &str {
&self.local_domain
}
pub async fn add_federation_endpoint(&self, trust_domain: &str, endpoint_url: &str) {
self.endpoints
.write()
.await
.insert(trust_domain.to_string(), endpoint_url.to_string());
}
pub async fn store_bundle(&self, bundle: FederatedBundle) {
self.bundles
.write()
.await
.insert(bundle.trust_domain.clone(), bundle);
}
pub async fn get_bundle(&self, trust_domain: &str) -> Option<FederatedBundle> {
self.bundles.read().await.get(trust_domain).cloned()
}
pub async fn get_endpoint(&self, trust_domain: &str) -> Option<String> {
self.endpoints.read().await.get(trust_domain).cloned()
}
pub async fn federated_domains(&self) -> Vec<String> {
self.bundles.read().await.keys().cloned().collect()
}
pub async fn is_federated_id_trusted(&self, id: &SpiffeId) -> bool {
if id.trust_domain == self.local_domain {
return true; }
self.bundles.read().await.contains_key(&id.trust_domain)
}
pub async fn remove_bundle(&self, trust_domain: &str) -> bool {
self.bundles.write().await.remove(trust_domain).is_some()
}
pub async fn cleanup_stale(&self, max_age: Duration) {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let max_age_secs = max_age.as_secs();
self.bundles
.write()
.await
.retain(|_, b| now.saturating_sub(b.refreshed_at) <= max_age_secs);
}
pub async fn bundle_count(&self) -> usize {
self.bundles.read().await.len()
}
}
// Unit tests covering SPIFFE ID parsing, JWT-SVID validation, X.509-SVID extraction, the
// trust manager and authorization policies, the Workload API cache, selectors,
// registration/attestation, and federated bundles.
#[cfg(test)]
mod tests {
use super::*;
use base64::Engine;
// --- SpiffeId parsing and helpers ---
#[test]
fn test_parse_valid_spiffe_id() {
let id = SpiffeId::parse("spiffe://example.org/service/web").unwrap();
assert_eq!(id.trust_domain, "example.org");
assert_eq!(id.path, "/service/web");
assert_eq!(id.to_uri(), "spiffe://example.org/service/web");
}
#[test]
fn test_parse_spiffe_id_no_path() {
let id = SpiffeId::parse("spiffe://example.org").unwrap();
assert_eq!(id.trust_domain, "example.org");
assert_eq!(id.path, "");
}
#[test]
fn test_parse_spiffe_id_deeply_nested() {
let id = SpiffeId::parse("spiffe://prod.example.com/ns/default/sa/api-server").unwrap();
assert_eq!(id.trust_domain, "prod.example.com");
assert_eq!(id.path, "/ns/default/sa/api-server");
}
#[test]
fn test_parse_missing_scheme() {
assert!(SpiffeId::parse("https://example.org/svc").is_err());
}
#[test]
fn test_parse_empty_trust_domain() {
assert!(SpiffeId::parse("spiffe://").is_err());
}
#[test]
fn test_parse_invalid_td_char() {
assert!(SpiffeId::parse("spiffe://ex ample.org/svc").is_err());
}
#[test]
fn test_parse_query_rejected() {
assert!(SpiffeId::parse("spiffe://example.org/svc?q=1").is_err());
}
#[test]
fn test_parse_fragment_rejected() {
assert!(SpiffeId::parse("spiffe://example.org/svc#frag").is_err());
}
#[test]
fn test_parse_trailing_slash_rejected() {
assert!(SpiffeId::parse("spiffe://example.org/svc/").is_err());
}
#[test]
fn test_parse_empty_segment_rejected() {
assert!(SpiffeId::parse("spiffe://example.org//svc").is_err());
}
#[test]
fn test_parse_dot_segment_rejected() {
assert!(SpiffeId::parse("spiffe://example.org/./svc").is_err());
assert!(SpiffeId::parse("spiffe://example.org/../svc").is_err());
}
#[test]
fn test_is_member_of() {
let id = SpiffeId::parse("spiffe://example.org/svc").unwrap();
assert!(id.is_member_of("example.org"));
assert!(!id.is_member_of("other.org"));
}
#[test]
fn test_matches_path_prefix() {
let id = SpiffeId::parse("spiffe://example.org/ns/prod/svc/api").unwrap();
assert!(id.matches_path_prefix("/ns/prod"));
assert!(!id.matches_path_prefix("/ns/staging"));
}
#[test]
fn test_display() {
let id = SpiffeId::parse("spiffe://td/path").unwrap();
assert_eq!(format!("{id}"), "spiffe://td/path");
}
// --- JWT-SVID validation ---
// Builds a JWT-SVID whose header and claims are real base64url (no padding) JSON. The
// third part is a placeholder string, because validate_jwt_svid does not verify signatures.
fn make_jwt_svid(sub: &str, aud: &[&str], exp: u64, alg: &str) -> String {
let b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;
let header = serde_json::json!({"alg": alg, "typ": "JWT"});
let claims = serde_json::json!({
"sub": sub,
"aud": aud,
"exp": exp,
});
let h = b64.encode(header.to_string().as_bytes());
let c = b64.encode(claims.to_string().as_bytes());
format!("{h}.{c}.fake-signature")
}
#[test]
fn test_validate_jwt_svid_valid() {
let future = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ 3600;
let token = make_jwt_svid(
"spiffe://example.org/svc/api",
&["https://service.example.org"],
future,
"ES256",
);
let result = validate_jwt_svid(&token, "https://service.example.org").unwrap();
assert_eq!(result.spiffe_id.trust_domain, "example.org");
assert_eq!(result.spiffe_id.path, "/svc/api");
}
#[test]
fn test_validate_jwt_svid_expired() {
let past = 1_000_000;
let token = make_jwt_svid("spiffe://example.org/svc", &["aud"], past, "ES256");
assert!(validate_jwt_svid(&token, "aud").is_err());
}
#[test]
fn test_validate_jwt_svid_wrong_audience() {
let future = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ 3600;
let token = make_jwt_svid(
"spiffe://example.org/svc",
&["expected-aud"],
future,
"ES256",
);
assert!(validate_jwt_svid(&token, "wrong-aud").is_err());
}
#[test]
fn test_validate_jwt_svid_none_algorithm_rejected() {
let future = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ 3600;
let token = make_jwt_svid("spiffe://example.org/svc", &["aud"], future, "none");
assert!(validate_jwt_svid(&token, "aud").is_err());
}
#[test]
fn test_validate_jwt_svid_invalid_sub() {
let future = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ 3600;
let token = make_jwt_svid("https://not-spiffe.example.org", &["aud"], future, "ES256");
assert!(validate_jwt_svid(&token, "aud").is_err());
}
#[test]
fn test_validate_jwt_svid_malformed() {
assert!(validate_jwt_svid("not.a.valid.jwt.token", "aud").is_err());
assert!(validate_jwt_svid("only-one-part", "aud").is_err());
}
// --- X.509-SVID extraction ---
#[test]
fn test_extract_spiffe_id_from_synthetic_der() {
// Synthetic input: a fake SEQUENCE-style header (0x30 0x82 0x00 0x50), filler text, the
// URI, a NUL terminator, then 20 bytes of 0xFF (not valid UTF-8).
let mut data = vec![0x30, 0x82]; data.extend_from_slice(&[0x00, 0x50]); data.extend_from_slice(b"some-cert-fields-");
data.extend_from_slice(b"spiffe://example.org/workload/web");
data.push(0x00); data.extend_from_slice(&[0xFF; 20]);
let info = extract_spiffe_id_from_der(&data).unwrap();
assert_eq!(info.spiffe_id.trust_domain, "example.org");
assert_eq!(info.spiffe_id.path, "/workload/web");
assert!(!info.fingerprint.is_empty());
// 64 hex characters = a 32-byte SHA-256 digest.
assert_eq!(info.fingerprint.len(), 64); }
#[test]
fn test_extract_spiffe_id_no_uri() {
let data = b"no spiffe uri here at all";
assert!(extract_spiffe_id_from_der(data).is_err());
}
// --- SpiffeTrustManager: bundles, JWT-SVID verification, authorization policies ---
#[tokio::test]
async fn test_trust_manager_bundle_operations() {
let mgr = SpiffeTrustManager::new();
assert!(!mgr.has_trust_bundle("example.org").await);
mgr.add_trust_bundle("example.org", vec![vec![1, 2, 3]])
.await;
assert!(mgr.has_trust_bundle("example.org").await);
let bundle = mgr.get_trust_bundle("example.org").await.unwrap();
assert_eq!(bundle.len(), 1);
let domains = mgr.trust_domains().await;
assert_eq!(domains, vec!["example.org"]);
assert!(mgr.remove_trust_bundle("example.org").await);
assert!(!mgr.has_trust_bundle("example.org").await);
}
#[tokio::test]
async fn test_trust_manager_verify_jwt_svid_no_bundle() {
let mgr = SpiffeTrustManager::new();
let future = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ 3600;
let token = make_jwt_svid("spiffe://example.org/svc", &["aud"], future, "ES256");
assert!(mgr.verify_jwt_svid(&token, "aud").await.is_err());
}
#[tokio::test]
async fn test_trust_manager_verify_jwt_svid_with_bundle() {
let mgr = SpiffeTrustManager::new();
mgr.add_trust_bundle("example.org", vec![vec![0xCA]]).await;
let future = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ 3600;
let token = make_jwt_svid("spiffe://example.org/svc", &["aud"], future, "ES256");
let result = mgr.verify_jwt_svid(&token, "aud").await.unwrap();
assert_eq!(result.spiffe_id.trust_domain, "example.org");
}
#[tokio::test]
async fn test_authz_policy_exact_match() {
let mgr = SpiffeTrustManager::new();
mgr.add_policy(SpiffeAuthzPolicy {
source: "spiffe://td/frontend".to_string(),
destination: "spiffe://td/backend".to_string(),
allowed_actions: vec!["GET".to_string(), "POST".to_string()],
})
.await;
let src = SpiffeId::parse("spiffe://td/frontend").unwrap();
let dst = SpiffeId::parse("spiffe://td/backend").unwrap();
assert!(mgr.is_authorized(&src, &dst, "GET").await);
assert!(mgr.is_authorized(&src, &dst, "POST").await);
assert!(!mgr.is_authorized(&src, &dst, "DELETE").await);
}
#[tokio::test]
async fn test_authz_policy_wildcard() {
let mgr = SpiffeTrustManager::new();
mgr.add_policy(SpiffeAuthzPolicy {
source: "*".to_string(),
destination: "spiffe://td/public-api".to_string(),
allowed_actions: vec!["*".to_string()],
})
.await;
let any_src = SpiffeId::parse("spiffe://other/svc").unwrap();
let dst = SpiffeId::parse("spiffe://td/public-api").unwrap();
assert!(mgr.is_authorized(&any_src, &dst, "GET").await);
assert!(mgr.is_authorized(&any_src, &dst, "DELETE").await);
}
#[tokio::test]
async fn test_authz_policy_no_match() {
let mgr = SpiffeTrustManager::new();
let src = SpiffeId::parse("spiffe://td/svc1").unwrap();
let dst = SpiffeId::parse("spiffe://td/svc2").unwrap();
assert!(!mgr.is_authorized(&src, &dst, "GET").await);
}
// --- Workload API client cache ---
#[test]
fn test_workload_api_config_defaults() {
let cfg = WorkloadApiConfig::default();
assert!(cfg.endpoint.contains("spire-agent"));
assert_eq!(cfg.rotation_interval_secs, 300);
assert!(cfg.jwt_audiences.is_empty());
}
#[tokio::test]
async fn test_workload_api_store_x509_svid() {
let client = WorkloadApiClient::new(WorkloadApiConfig::default());
let svid = SvidResponse::X509 {
spiffe_id: "spiffe://example.org/web".to_string(),
cert_chain: vec![vec![0x30, 0x82]],
private_key: vec![0x01],
bundle: vec![vec![0xCA]],
expires_at: 9999999999,
};
client.store_x509_svid(svid).await;
assert_eq!(client.x509_count().await, 1);
assert!(
client
.get_x509_svid("spiffe://example.org/web")
.await
.is_some()
);
// Storing the X.509-SVID also records its bundle under the trust domain.
assert!(client.get_bundle("example.org").await.is_some());
}
#[tokio::test]
async fn test_workload_api_store_jwt_svid() {
let client = WorkloadApiClient::new(WorkloadApiConfig::default());
let svid = SvidResponse::Jwt {
spiffe_id: "spiffe://example.org/api".to_string(),
token: "eyJ...".to_string(),
expires_at: 9999999999,
};
client.store_jwt_svid(svid).await;
assert_eq!(client.jwt_count().await, 1);
assert!(
client
.get_jwt_svid("spiffe://example.org/api")
.await
.is_some()
);
}
#[tokio::test]
async fn test_workload_api_cleanup_expired() {
let client = WorkloadApiClient::new(WorkloadApiConfig::default());
// expires_at = 1 second after the Unix epoch, i.e. long expired.
let svid = SvidResponse::X509 {
spiffe_id: "spiffe://example.org/old".to_string(),
cert_chain: vec![],
private_key: vec![],
bundle: vec![],
expires_at: 1, };
client.store_x509_svid(svid).await;
assert_eq!(client.x509_count().await, 1);
client.cleanup_expired().await;
assert_eq!(client.x509_count().await, 0);
}
#[tokio::test]
async fn test_workload_api_needs_rotation() {
let client = WorkloadApiClient::new(WorkloadApiConfig::default());
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
// 10 seconds of remaining lifetime is below the default rotation_interval_secs (300).
let svid = SvidResponse::X509 {
spiffe_id: "spiffe://example.org/expiring".to_string(),
cert_chain: vec![],
private_key: vec![],
bundle: vec![],
expires_at: now + 10,
};
client.store_x509_svid(svid).await;
let needs = client.needs_rotation().await;
assert_eq!(needs.len(), 1);
assert_eq!(needs[0], "spiffe://example.org/expiring");
}
#[test]
fn test_workload_api_rotation_interval() {
let cfg = WorkloadApiConfig {
rotation_interval_secs: 600,
..WorkloadApiConfig::default()
};
let client = WorkloadApiClient::new(cfg);
assert_eq!(client.rotation_interval(), Duration::from_secs(600));
}
// --- Workload selectors, registration store and attestation ---
#[test]
fn test_workload_selector_unix_uid() {
let s = WorkloadSelector::unix_uid(1000);
assert_eq!(s.selector_type, "unix");
assert_eq!(s.value, "uid:1000");
}
#[test]
fn test_workload_selector_k8s_sa() {
let s = WorkloadSelector::k8s_sa("default", "api-server");
assert_eq!(s.selector_type, "k8s");
assert_eq!(s.value, "sa:default:api-server");
}
#[test]
fn test_workload_selector_docker_image() {
let s = WorkloadSelector::docker_image_id("sha256:abc123");
assert_eq!(s.selector_type, "docker");
assert_eq!(s.value, "image_id:sha256:abc123");
}
#[tokio::test]
async fn test_registration_store_match_selectors() {
let store = RegistrationStore::new();
let entry = RegistrationEntry {
spiffe_id: SpiffeId::parse("spiffe://example.org/web").unwrap(),
parent_id: SpiffeId::parse("spiffe://example.org/node1").unwrap(),
selectors: vec![WorkloadSelector::unix_uid(1000)],
ttl: 3600,
downstream: false,
};
store.register(entry).await;
let ids = store
.match_selectors(&[WorkloadSelector::unix_uid(1000)])
.await;
assert_eq!(ids.len(), 1);
assert_eq!(ids[0].path, "/web");
// Extra workload selectors do not prevent a match.
let ids = store
.match_selectors(&[
WorkloadSelector::unix_uid(1000),
WorkloadSelector::unix_gid(100),
])
.await;
assert_eq!(ids.len(), 1);
// A missing required selector means no match.
let ids = store
.match_selectors(&[WorkloadSelector::unix_uid(2000)])
.await;
assert!(ids.is_empty());
}
#[tokio::test]
async fn test_registration_store_attest() {
let store = RegistrationStore::new();
let entry = RegistrationEntry {
spiffe_id: SpiffeId::parse("spiffe://example.org/api").unwrap(),
parent_id: SpiffeId::parse("spiffe://example.org/node1").unwrap(),
selectors: vec![WorkloadSelector {
selector_type: "unix".to_string(),
value: "uid:1000".to_string(),
}],
ttl: 3600,
downstream: false,
};
store.register(entry).await;
// Payload {"uid": "1000"} with attestor "unix" becomes the selector ("unix", "uid:1000").
let evidence = AttestationEvidence {
attestor: "unix".to_string(),
payload: HashMap::from([("uid".to_string(), "1000".to_string())]),
};
let result = store.attest(&evidence).await.unwrap();
assert_eq!(result.spiffe_ids.len(), 1);
assert_eq!(result.spiffe_ids[0].path, "/api");
assert_eq!(result.selectors.len(), 1);
}
#[tokio::test]
async fn test_registration_store_attest_no_match() {
let store = RegistrationStore::new();
let evidence = AttestationEvidence {
attestor: "unix".to_string(),
payload: HashMap::from([("uid".to_string(), "9999".to_string())]),
};
assert!(store.attest(&evidence).await.is_err());
}
#[tokio::test]
async fn test_registration_store_remove_by_spiffe_id() {
let store = RegistrationStore::new();
let id = SpiffeId::parse("spiffe://example.org/web").unwrap();
store
.register(RegistrationEntry {
spiffe_id: id.clone(),
parent_id: SpiffeId::parse("spiffe://example.org/node").unwrap(),
selectors: vec![],
ttl: 3600,
downstream: false,
})
.await;
assert_eq!(store.count().await, 1);
store.remove_by_spiffe_id(&id).await;
assert_eq!(store.count().await, 0);
}
// --- Federated trust bundles ---
#[tokio::test]
async fn test_federated_bundle_manager_local_domain() {
let mgr = FederatedTrustBundleManager::new("example.org");
assert_eq!(mgr.local_domain(), "example.org");
}
#[tokio::test]
async fn test_federated_bundle_store_and_retrieve() {
let mgr = FederatedTrustBundleManager::new("local.org");
let bundle = FederatedBundle {
trust_domain: "remote.org".to_string(),
ca_certs: vec![vec![0xCA, 0xFE]],
refreshed_at: 1000000,
sequence_number: 1,
};
mgr.store_bundle(bundle).await;
assert_eq!(mgr.bundle_count().await, 1);
let b = mgr.get_bundle("remote.org").await.unwrap();
assert_eq!(b.sequence_number, 1);
assert_eq!(b.ca_certs.len(), 1);
}
#[tokio::test]
async fn test_federated_bundle_is_trusted() {
let mgr = FederatedTrustBundleManager::new("local.org");
let local_id = SpiffeId::parse("spiffe://local.org/svc").unwrap();
let remote_id = SpiffeId::parse("spiffe://remote.org/svc").unwrap();
// The local domain is trusted without a bundle; a remote domain needs one.
assert!(mgr.is_federated_id_trusted(&local_id).await);
assert!(!mgr.is_federated_id_trusted(&remote_id).await);
mgr.store_bundle(FederatedBundle {
trust_domain: "remote.org".to_string(),
ca_certs: vec![vec![0x01]],
refreshed_at: 9999999999,
sequence_number: 1,
})
.await;
assert!(mgr.is_federated_id_trusted(&remote_id).await);
}
#[tokio::test]
async fn test_federated_bundle_remove() {
let mgr = FederatedTrustBundleManager::new("local.org");
mgr.store_bundle(FederatedBundle {
trust_domain: "remote.org".to_string(),
ca_certs: vec![],
refreshed_at: 0,
sequence_number: 0,
})
.await;
assert!(mgr.remove_bundle("remote.org").await);
assert!(!mgr.remove_bundle("remote.org").await);
assert_eq!(mgr.bundle_count().await, 0);
}
#[tokio::test]
async fn test_federated_bundle_cleanup_stale() {
let mgr = FederatedTrustBundleManager::new("local.org");
// refreshed_at = 0 (the Unix epoch) is far older than the one-hour max_age used below.
mgr.store_bundle(FederatedBundle {
trust_domain: "stale.org".to_string(),
ca_certs: vec![],
refreshed_at: 0,
sequence_number: 1,
})
.await;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
mgr.store_bundle(FederatedBundle {
trust_domain: "fresh.org".to_string(),
ca_certs: vec![],
refreshed_at: now,
sequence_number: 1,
})
.await;
assert_eq!(mgr.bundle_count().await, 2);
mgr.cleanup_stale(Duration::from_secs(3600)).await;
assert_eq!(mgr.bundle_count().await, 1);
assert!(mgr.get_bundle("fresh.org").await.is_some());
assert!(mgr.get_bundle("stale.org").await.is_none());
}
#[tokio::test]
async fn test_federated_bundle_endpoints() {
let mgr = FederatedTrustBundleManager::new("local.org");
mgr.add_federation_endpoint("remote.org", "https://remote.org/.well-known/spiffe-bundle")
.await;
let ep = mgr.get_endpoint("remote.org").await.unwrap();
assert!(ep.contains("spiffe-bundle"));
assert!(mgr.get_endpoint("unknown.org").await.is_none());
}
#[tokio::test]
async fn test_federated_bundle_list_domains() {
let mgr = FederatedTrustBundleManager::new("local.org");
mgr.store_bundle(FederatedBundle {
trust_domain: "a.org".to_string(),
ca_certs: vec![],
refreshed_at: 0,
sequence_number: 0,
})
.await;
mgr.store_bundle(FederatedBundle {
trust_domain: "b.org".to_string(),
ca_certs: vec![],
refreshed_at: 0,
sequence_number: 0,
})
.await;
// federated_domains follows HashMap order, so sort before comparing.
let mut domains = mgr.federated_domains().await;
domains.sort();
assert_eq!(domains, vec!["a.org", "b.org"]);
}
}