1use crate::errors::{AuthError, Result};
20use serde::{Deserialize, Serialize};
21use sha2::{Digest, Sha256};
22use std::collections::HashMap;
23use std::sync::Arc;
24use std::time::{Duration, SystemTime, UNIX_EPOCH};
25use tokio::sync::RwLock;
26
/// A SPIFFE ID, rendered as `spiffe://<trust_domain><path>`.
///
/// Values returned by [`SpiffeId::parse`] have been validated. The fields are public and
/// `Deserialize` is derived, so values built directly or deserialized are NOT validated.
/// Equality and hashing compare both fields exactly.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SpiffeId {
    /// Trust domain name, e.g. `example.org` (no scheme, no slashes).
    pub trust_domain: String,
    /// Path component: empty, or beginning with `/` (e.g. `/ns/prod/svc/api`) when produced
    /// by `parse`.
    pub path: String,
}
39
40impl SpiffeId {
41 pub fn parse(uri: &str) -> Result<Self> {
49 let stripped = uri
50 .strip_prefix("spiffe://")
51 .ok_or_else(|| AuthError::validation("SPIFFE ID must start with 'spiffe://'"))?;
52
53 if stripped.is_empty() {
54 return Err(AuthError::validation("SPIFFE ID trust domain is empty"));
55 }
56
57 let (trust_domain, path) = match stripped.find('/') {
59 Some(idx) => (&stripped[..idx], &stripped[idx..]),
60 None => (stripped, ""),
61 };
62
63 if trust_domain.is_empty() {
65 return Err(AuthError::validation("SPIFFE ID trust domain is empty"));
66 }
67
68 for ch in trust_domain.chars() {
69 if !ch.is_ascii_alphanumeric() && ch != '-' && ch != '.' && ch != '_' {
70 return Err(AuthError::validation(&format!(
71 "SPIFFE ID trust domain contains invalid character: '{ch}'"
72 )));
73 }
74 }
75
76 if path.contains('?') || path.contains('#') {
78 return Err(AuthError::validation(
79 "SPIFFE ID must not contain query or fragment",
80 ));
81 }
82
83 if path.len() > 1 && path.ends_with('/') {
85 return Err(AuthError::validation(
86 "SPIFFE ID path must not end with '/'",
87 ));
88 }
89
90 if path.contains("//") {
92 return Err(AuthError::validation(
93 "SPIFFE ID path must not contain empty segments",
94 ));
95 }
96
97 for segment in path.split('/').skip(1) {
99 if segment == "." || segment == ".." {
100 return Err(AuthError::validation(
101 "SPIFFE ID path must not contain '.' or '..' segments",
102 ));
103 }
104 }
105
106 Ok(Self {
107 trust_domain: trust_domain.to_string(),
108 path: path.to_string(),
109 })
110 }
111
112 pub fn to_uri(&self) -> String {
114 format!("spiffe://{}{}", self.trust_domain, self.path)
115 }
116
117 pub fn is_member_of(&self, trust_domain: &str) -> bool {
119 self.trust_domain == trust_domain
120 }
121
122 pub fn matches_path_prefix(&self, prefix: &str) -> bool {
124 self.path.starts_with(prefix)
125 }
126}
127
128impl std::fmt::Display for SpiffeId {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 write!(f, "spiffe://{}{}", self.trust_domain, self.path)
131 }
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct JwtSvidClaims {
139 pub sub: String,
141 pub aud: Vec<String>,
143 pub exp: u64,
145 #[serde(default)]
147 pub iat: Option<u64>,
148}
149
/// Decoded parts of a JWT-SVID that passed [`validate_jwt_svid`].
///
/// NOTE: `validate_jwt_svid` does not verify the token signature, so "validated" here means
/// structurally and temporally checked, not authenticated.
#[derive(Debug, Clone)]
pub struct ValidatedJwtSvid {
    /// SPIFFE ID parsed from the `sub` claim.
    pub spiffe_id: SpiffeId,
    /// Decoded payload claims.
    pub claims: JwtSvidClaims,
    /// Decoded JOSE header, kept as raw JSON.
    pub header: serde_json::Value,
}
160
161pub fn validate_jwt_svid(token: &str, expected_audience: &str) -> Result<ValidatedJwtSvid> {
172 let b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;
173 use base64::Engine;
174
175 let parts: Vec<&str> = token.split('.').collect();
176 if parts.len() != 3 {
177 return Err(AuthError::validation("JWT-SVID must have 3 parts"));
178 }
179
180 let header_bytes = b64
182 .decode(parts[0])
183 .map_err(|_| AuthError::validation("Invalid JWT-SVID header encoding"))?;
184 let header: serde_json::Value = serde_json::from_slice(&header_bytes)
185 .map_err(|_| AuthError::validation("Invalid JWT-SVID header JSON"))?;
186
187 let alg = header
189 .get("alg")
190 .and_then(|v| v.as_str())
191 .ok_or_else(|| AuthError::validation("JWT-SVID header missing 'alg'"))?;
192 if alg.eq_ignore_ascii_case("none") {
193 return Err(AuthError::validation(
194 "JWT-SVID must not use 'none' algorithm",
195 ));
196 }
197
198 let claims_bytes = b64
200 .decode(parts[1])
201 .map_err(|_| AuthError::validation("Invalid JWT-SVID claims encoding"))?;
202 let claims: JwtSvidClaims = serde_json::from_slice(&claims_bytes)
203 .map_err(|_| AuthError::validation("Invalid JWT-SVID claims JSON"))?;
204
205 let spiffe_id = SpiffeId::parse(&claims.sub)?;
207
208 let now = SystemTime::now()
210 .duration_since(UNIX_EPOCH)
211 .unwrap_or_default()
212 .as_secs();
213 if claims.exp <= now {
214 return Err(AuthError::validation("JWT-SVID has expired"));
215 }
216
217 if !claims.aud.iter().any(|a| a == expected_audience) {
219 return Err(AuthError::validation(
220 "JWT-SVID audience does not match expected audience",
221 ));
222 }
223
224 Ok(ValidatedJwtSvid {
225 spiffe_id,
226 claims,
227 header,
228 })
229}
230
/// Identity information extracted from an X.509-SVID certificate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct X509SvidInfo {
    /// SPIFFE ID found in the certificate.
    pub spiffe_id: SpiffeId,
    /// Lowercase hex SHA-256 digest of the certificate bytes passed to
    /// `extract_spiffe_id_from_der` (64 characters).
    pub fingerprint: String,
    /// Certificate serial number. Not populated by `extract_spiffe_id_from_der`
    /// (always `None` there).
    #[serde(default)]
    pub serial: Option<String>,
    /// Start of the validity window; presumably Unix seconds — TODO confirm.
    /// Not populated by `extract_spiffe_id_from_der`.
    #[serde(default)]
    pub not_before: Option<u64>,
    /// End of the validity window; presumably Unix seconds — TODO confirm.
    /// Not populated by `extract_spiffe_id_from_der`.
    #[serde(default)]
    pub not_after: Option<u64>,
}
250
251pub fn extract_spiffe_id_from_der(cert_der: &[u8]) -> Result<X509SvidInfo> {
256 let fingerprint = hex::encode(Sha256::digest(cert_der));
258
259 let cert_str = String::from_utf8_lossy(cert_der);
261 let spiffe_uri = find_spiffe_uri_in_bytes(cert_der)
262 .or_else(|| {
263 cert_str.find("spiffe://").map(|idx| {
265 let end = cert_str[idx..]
266 .find(|c: char| c.is_control() || c == '\0')
267 .unwrap_or(cert_str.len() - idx);
268 cert_str[idx..idx + end].to_string()
269 })
270 })
271 .ok_or_else(|| {
272 AuthError::validation("No SPIFFE ID (spiffe:// URI) found in certificate SAN")
273 })?;
274
275 let spiffe_id = SpiffeId::parse(&spiffe_uri)?;
276
277 Ok(X509SvidInfo {
278 spiffe_id,
279 fingerprint,
280 serial: None,
281 not_before: None,
282 not_after: None,
283 })
284}
285
/// Returns the first `spiffe://` URI embedded in `data`, if any.
///
/// Each occurrence of the `spiffe://` marker is extended over the following run of visible
/// ASCII bytes (`0x21..=0x7E`, the only bytes a URI may contain unescaped). The run stops at
/// the first byte outside that range, such as NUL, a DER tag/length byte below `0x21`, or
/// any non-ASCII byte. Occurrences with nothing after the scheme are skipped, so a stray
/// marker cannot hide a later real URI.
///
/// NOTE(review): this is a byte-level heuristic, not an ASN.1 parse. A visible-ASCII byte
/// that happens to follow the URI in the encoding would be included in the result.
fn find_spiffe_uri_in_bytes(data: &[u8]) -> Option<String> {
    const SCHEME: &[u8] = b"spiffe://";

    data.windows(SCHEME.len())
        .enumerate()
        .filter(|(_, window)| *window == SCHEME)
        .find_map(|(start, _)| {
            let len = data[start..]
                .iter()
                .take_while(|byte| byte.is_ascii_graphic())
                .count();
            // Every counted byte is ASCII, so each maps to exactly one char.
            (len > SCHEME.len())
                .then(|| data[start..start + len].iter().map(|&byte| char::from(byte)).collect())
        })
}
308
/// An allow rule evaluated by `SpiffeTrustManager::is_authorized`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpiffeAuthzPolicy {
    /// Exact SPIFFE ID URI of the caller, or `"*"` to match any caller.
    pub source: String,
    /// Exact SPIFFE ID URI of the callee, or `"*"` to match any callee.
    pub destination: String,
    /// Action names this rule allows; an entry `"*"` allows every action.
    pub allowed_actions: Vec<String>,
}
321
/// Holds per-trust-domain CA bundles and SPIFFE authorization policies.
pub struct SpiffeTrustManager {
    /// Trust domain name -> CA certificates (passed in as `ca_certs_der`).
    trust_bundles: Arc<RwLock<HashMap<String, Vec<Vec<u8>>>>>,
    /// Allow rules; a request is authorized if any rule matches.
    policies: Arc<RwLock<Vec<SpiffeAuthzPolicy>>>,
}
329
330impl SpiffeTrustManager {
331 pub fn new() -> Self {
332 Self {
333 trust_bundles: Arc::new(RwLock::new(HashMap::new())),
334 policies: Arc::new(RwLock::new(Vec::new())),
335 }
336 }
337
338 pub async fn add_trust_bundle(&self, trust_domain: &str, ca_certs_der: Vec<Vec<u8>>) {
340 self.trust_bundles
341 .write()
342 .await
343 .insert(trust_domain.to_string(), ca_certs_der);
344 }
345
346 pub async fn has_trust_bundle(&self, trust_domain: &str) -> bool {
348 self.trust_bundles.read().await.contains_key(trust_domain)
349 }
350
351 pub async fn get_trust_bundle(&self, trust_domain: &str) -> Option<Vec<Vec<u8>>> {
353 self.trust_bundles.read().await.get(trust_domain).cloned()
354 }
355
356 pub async fn remove_trust_bundle(&self, trust_domain: &str) -> bool {
358 self.trust_bundles
359 .write()
360 .await
361 .remove(trust_domain)
362 .is_some()
363 }
364
365 pub async fn add_policy(&self, policy: SpiffeAuthzPolicy) {
367 self.policies.write().await.push(policy);
368 }
369
370 pub async fn is_authorized(
372 &self,
373 source: &SpiffeId,
374 destination: &SpiffeId,
375 action: &str,
376 ) -> bool {
377 let policies = self.policies.read().await;
378 let source_uri = source.to_uri();
379 let dest_uri = destination.to_uri();
380
381 policies.iter().any(|p| {
382 (p.source == source_uri || p.source == "*")
383 && (p.destination == dest_uri || p.destination == "*")
384 && (p.allowed_actions.contains(&action.to_string())
385 || p.allowed_actions.contains(&"*".to_string()))
386 })
387 }
388
389 pub async fn verify_jwt_svid(
391 &self,
392 token: &str,
393 expected_audience: &str,
394 ) -> Result<ValidatedJwtSvid> {
395 let result = validate_jwt_svid(token, expected_audience)?;
396
397 if !self.has_trust_bundle(&result.spiffe_id.trust_domain).await {
399 return Err(AuthError::validation(&format!(
400 "No trust bundle for domain '{}'",
401 result.spiffe_id.trust_domain
402 )));
403 }
404
405 Ok(result)
406 }
407
408 pub async fn trust_domains(&self) -> Vec<String> {
410 self.trust_bundles.read().await.keys().cloned().collect()
411 }
412}
413
414impl Default for SpiffeTrustManager {
415 fn default() -> Self {
416 Self::new()
417 }
418}
419
420#[derive(Debug, Clone, Serialize, Deserialize)]
424pub enum SvidResponse {
425 X509 {
427 spiffe_id: String,
428 cert_chain: Vec<Vec<u8>>,
429 private_key: Vec<u8>,
430 bundle: Vec<Vec<u8>>,
432 expires_at: u64,
434 },
435 Jwt {
437 spiffe_id: String,
438 token: String,
439 expires_at: u64,
440 },
441}
442
/// Configuration for [`WorkloadApiClient`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkloadApiConfig {
    /// Workload API endpoint; the default value is a filesystem socket path. This is only
    /// stored and returned by `WorkloadApiClient::endpoint` in the code visible here.
    pub endpoint: String,
    /// Rotation window in seconds. `needs_rotation` reports X.509 SVIDs whose remaining
    /// lifetime is below this, and `rotation_interval` returns it as a `Duration`.
    pub rotation_interval_secs: u64,
    /// JWT audiences; not read by the code visible here.
    pub jwt_audiences: Vec<String>,
}
454
455impl Default for WorkloadApiConfig {
456 fn default() -> Self {
457 Self {
458 endpoint: "/tmp/spire-agent/public/api.sock".to_string(),
459 rotation_interval_secs: 300,
460 jwt_audiences: Vec::new(),
461 }
462 }
463}
464
/// In-memory cache of SVIDs and trust bundles for a workload.
pub struct WorkloadApiClient {
    /// Client configuration.
    config: WorkloadApiConfig,
    /// X.509 SVIDs keyed by their SPIFFE ID string.
    x509_svids: Arc<RwLock<HashMap<String, SvidResponse>>>,
    /// JWT SVIDs keyed by their SPIFFE ID string.
    jwt_svids: Arc<RwLock<HashMap<String, SvidResponse>>>,
    /// Trust bundles keyed by trust domain, filled from stored X.509 SVIDs.
    bundles: Arc<RwLock<HashMap<String, Vec<Vec<u8>>>>>,
}
479
480impl WorkloadApiClient {
481 pub fn new(config: WorkloadApiConfig) -> Self {
483 Self {
484 config,
485 x509_svids: Arc::new(RwLock::new(HashMap::new())),
486 jwt_svids: Arc::new(RwLock::new(HashMap::new())),
487 bundles: Arc::new(RwLock::new(HashMap::new())),
488 }
489 }
490
491 pub fn endpoint(&self) -> &str {
493 &self.config.endpoint
494 }
495
496 pub fn rotation_interval(&self) -> Duration {
498 Duration::from_secs(self.config.rotation_interval_secs)
499 }
500
501 pub async fn store_x509_svid(&self, svid: SvidResponse) {
503 if let SvidResponse::X509 {
504 ref spiffe_id,
505 ref bundle,
506 ..
507 } = svid
508 {
509 if let Ok(id) = SpiffeId::parse(spiffe_id) {
511 self.bundles
512 .write()
513 .await
514 .insert(id.trust_domain.clone(), bundle.clone());
515 }
516 self.x509_svids
517 .write()
518 .await
519 .insert(spiffe_id.clone(), svid);
520 }
521 }
522
523 pub async fn store_jwt_svid(&self, svid: SvidResponse) {
525 if let SvidResponse::Jwt { ref spiffe_id, .. } = svid {
526 self.jwt_svids.write().await.insert(spiffe_id.clone(), svid);
527 }
528 }
529
530 pub async fn get_x509_svid(&self, spiffe_id: &str) -> Option<SvidResponse> {
532 self.x509_svids.read().await.get(spiffe_id).cloned()
533 }
534
535 pub async fn get_jwt_svid(&self, spiffe_id: &str) -> Option<SvidResponse> {
537 self.jwt_svids.read().await.get(spiffe_id).cloned()
538 }
539
540 pub async fn get_bundle(&self, trust_domain: &str) -> Option<Vec<Vec<u8>>> {
542 self.bundles.read().await.get(trust_domain).cloned()
543 }
544
545 pub async fn needs_rotation(&self) -> Vec<String> {
547 let now = SystemTime::now()
548 .duration_since(UNIX_EPOCH)
549 .unwrap_or_default()
550 .as_secs();
551
552 let svids = self.x509_svids.read().await;
553 let mut needs = Vec::new();
554 for (id, svid) in svids.iter() {
555 if let SvidResponse::X509 { expires_at, .. } = svid {
556 let remaining = expires_at.saturating_sub(now);
558 let threshold = self.config.rotation_interval_secs;
559 if remaining < threshold {
560 needs.push(id.clone());
561 }
562 }
563 }
564 needs
565 }
566
567 pub async fn cleanup_expired(&self) {
569 let now = SystemTime::now()
570 .duration_since(UNIX_EPOCH)
571 .unwrap_or_default()
572 .as_secs();
573
574 self.x509_svids.write().await.retain(|_, svid| {
575 if let SvidResponse::X509 { expires_at, .. } = svid {
576 *expires_at > now
577 } else {
578 true
579 }
580 });
581
582 self.jwt_svids.write().await.retain(|_, svid| {
583 if let SvidResponse::Jwt { expires_at, .. } = svid {
584 *expires_at > now
585 } else {
586 true
587 }
588 });
589 }
590
591 pub async fn x509_count(&self) -> usize {
593 self.x509_svids.read().await.len()
594 }
595
596 pub async fn jwt_count(&self) -> usize {
598 self.jwt_svids.read().await.len()
599 }
600}
601
/// Raw evidence presented for workload attestation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttestationEvidence {
    /// Attestor name; `RegistrationStore::attest` uses it as every derived selector's type.
    pub attestor: String,
    /// Key/value facts; each pair becomes a selector value `"<key>:<value>"` in `attest`.
    pub payload: HashMap<String, String>,
}
612
/// Outcome of a successful `RegistrationStore::attest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttestationResult {
    /// SPIFFE IDs of the registration entries whose selectors were all satisfied.
    pub spiffe_ids: Vec<SpiffeId>,
    /// Selectors derived from the attestation evidence.
    pub selectors: Vec<WorkloadSelector>,
}
621
/// A typed workload property (e.g. type `unix`, value `uid:1000`) used to match
/// registration entries. Equality compares both fields exactly.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct WorkloadSelector {
    /// Selector type, e.g. `unix`, `k8s`, `docker`, or an attestor name.
    pub selector_type: String,
    /// Type-specific value, e.g. `uid:1000` or `sa:<namespace>:<name>`.
    pub value: String,
}
630
631impl WorkloadSelector {
632 pub fn unix_uid(uid: u32) -> Self {
634 Self {
635 selector_type: "unix".to_string(),
636 value: format!("uid:{uid}"),
637 }
638 }
639
640 pub fn unix_gid(gid: u32) -> Self {
642 Self {
643 selector_type: "unix".to_string(),
644 value: format!("gid:{gid}"),
645 }
646 }
647
648 pub fn k8s_sa(namespace: &str, name: &str) -> Self {
650 Self {
651 selector_type: "k8s".to_string(),
652 value: format!("sa:{namespace}:{name}"),
653 }
654 }
655
656 pub fn k8s_pod_label(key: &str, value: &str) -> Self {
658 Self {
659 selector_type: "k8s".to_string(),
660 value: format!("pod-label:{key}:{value}"),
661 }
662 }
663
664 pub fn docker_image_id(image_id: &str) -> Self {
666 Self {
667 selector_type: "docker".to_string(),
668 value: format!("image_id:{image_id}"),
669 }
670 }
671}
672
/// Maps a set of workload selectors to a SPIFFE ID.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistrationEntry {
    /// Identity granted to workloads that satisfy `selectors`.
    pub spiffe_id: SpiffeId,
    /// Parent identity; not read by the code visible here.
    pub parent_id: SpiffeId,
    /// Selectors a workload must present (all of them) to match this entry.
    pub selectors: Vec<WorkloadSelector>,
    /// SVID time-to-live; presumably seconds — TODO confirm. Not read by the code visible here.
    pub ttl: u64,
    /// Downstream flag; not read by the code visible here.
    pub downstream: bool,
}
687
/// In-memory list of registration entries used for workload attestation.
pub struct RegistrationStore {
    /// Registered entries in insertion order.
    entries: Arc<RwLock<Vec<RegistrationEntry>>>,
}
692
693impl RegistrationStore {
694 pub fn new() -> Self {
696 Self {
697 entries: Arc::new(RwLock::new(Vec::new())),
698 }
699 }
700
701 pub async fn register(&self, entry: RegistrationEntry) {
703 self.entries.write().await.push(entry);
704 }
705
706 pub async fn match_selectors(&self, workload_selectors: &[WorkloadSelector]) -> Vec<SpiffeId> {
711 let entries = self.entries.read().await;
712 entries
713 .iter()
714 .filter(|entry| {
715 entry
716 .selectors
717 .iter()
718 .all(|s| workload_selectors.contains(s))
719 })
720 .map(|entry| entry.spiffe_id.clone())
721 .collect()
722 }
723
724 pub async fn attest(&self, evidence: &AttestationEvidence) -> Result<AttestationResult> {
726 let selectors: Vec<WorkloadSelector> = evidence
728 .payload
729 .iter()
730 .map(|(key, value)| WorkloadSelector {
731 selector_type: evidence.attestor.clone(),
732 value: format!("{key}:{value}"),
733 })
734 .collect();
735
736 if selectors.is_empty() {
737 return Err(AuthError::validation(
738 "Attestation evidence contains no selectors",
739 ));
740 }
741
742 let spiffe_ids = self.match_selectors(&selectors).await;
743 if spiffe_ids.is_empty() {
744 return Err(AuthError::validation(
745 "No registration entries match the workload selectors",
746 ));
747 }
748
749 Ok(AttestationResult {
750 spiffe_ids,
751 selectors,
752 })
753 }
754
755 pub async fn count(&self) -> usize {
757 self.entries.read().await.len()
758 }
759
760 pub async fn remove_by_spiffe_id(&self, id: &SpiffeId) {
762 self.entries.write().await.retain(|e| &e.spiffe_id != id);
763 }
764}
765
766impl Default for RegistrationStore {
767 fn default() -> Self {
768 Self::new()
769 }
770}
771
/// CA bundle obtained for a federated (foreign) trust domain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedBundle {
    /// Trust domain this bundle belongs to; used as the storage key.
    pub trust_domain: String,
    /// CA certificates for the domain.
    pub ca_certs: Vec<Vec<u8>>,
    /// Last refresh time in seconds since the Unix epoch (compared against the current
    /// Unix time in `cleanup_stale`).
    pub refreshed_at: u64,
    /// Bundle sequence number; stored but not compared by the code visible here.
    pub sequence_number: u64,
}
786
/// Tracks trust bundles and bundle endpoints for domains federated with a local domain.
pub struct FederatedTrustBundleManager {
    /// The local trust domain, which is always trusted.
    local_domain: String,
    /// Federated bundles keyed by trust domain.
    bundles: Arc<RwLock<HashMap<String, FederatedBundle>>>,
    /// Bundle endpoint URLs keyed by trust domain.
    endpoints: Arc<RwLock<HashMap<String, String>>>,
}
798
799impl FederatedTrustBundleManager {
800 pub fn new(local_domain: impl Into<String>) -> Self {
802 Self {
803 local_domain: local_domain.into(),
804 bundles: Arc::new(RwLock::new(HashMap::new())),
805 endpoints: Arc::new(RwLock::new(HashMap::new())),
806 }
807 }
808
809 pub fn local_domain(&self) -> &str {
811 &self.local_domain
812 }
813
814 pub async fn add_federation_endpoint(&self, trust_domain: &str, endpoint_url: &str) {
816 self.endpoints
817 .write()
818 .await
819 .insert(trust_domain.to_string(), endpoint_url.to_string());
820 }
821
822 pub async fn store_bundle(&self, bundle: FederatedBundle) {
824 self.bundles
825 .write()
826 .await
827 .insert(bundle.trust_domain.clone(), bundle);
828 }
829
830 pub async fn get_bundle(&self, trust_domain: &str) -> Option<FederatedBundle> {
832 self.bundles.read().await.get(trust_domain).cloned()
833 }
834
835 pub async fn get_endpoint(&self, trust_domain: &str) -> Option<String> {
837 self.endpoints.read().await.get(trust_domain).cloned()
838 }
839
840 pub async fn federated_domains(&self) -> Vec<String> {
842 self.bundles.read().await.keys().cloned().collect()
843 }
844
845 pub async fn is_federated_id_trusted(&self, id: &SpiffeId) -> bool {
847 if id.trust_domain == self.local_domain {
848 return true; }
850 self.bundles.read().await.contains_key(&id.trust_domain)
851 }
852
853 pub async fn remove_bundle(&self, trust_domain: &str) -> bool {
855 self.bundles.write().await.remove(trust_domain).is_some()
856 }
857
858 pub async fn cleanup_stale(&self, max_age: Duration) {
860 let now = SystemTime::now()
861 .duration_since(UNIX_EPOCH)
862 .unwrap_or_default()
863 .as_secs();
864 let max_age_secs = max_age.as_secs();
865 self.bundles
866 .write()
867 .await
868 .retain(|_, b| now.saturating_sub(b.refreshed_at) <= max_age_secs);
869 }
870
871 pub async fn bundle_count(&self) -> usize {
873 self.bundles.read().await.len()
874 }
875}
876
877#[cfg(test)]
878mod tests {
879 use super::*;
880 use base64::Engine;
881
882 #[test]
885 fn test_parse_valid_spiffe_id() {
886 let id = SpiffeId::parse("spiffe://example.org/service/web").unwrap();
887 assert_eq!(id.trust_domain, "example.org");
888 assert_eq!(id.path, "/service/web");
889 assert_eq!(id.to_uri(), "spiffe://example.org/service/web");
890 }
891
892 #[test]
893 fn test_parse_spiffe_id_no_path() {
894 let id = SpiffeId::parse("spiffe://example.org").unwrap();
895 assert_eq!(id.trust_domain, "example.org");
896 assert_eq!(id.path, "");
897 }
898
899 #[test]
900 fn test_parse_spiffe_id_deeply_nested() {
901 let id = SpiffeId::parse("spiffe://prod.example.com/ns/default/sa/api-server").unwrap();
902 assert_eq!(id.trust_domain, "prod.example.com");
903 assert_eq!(id.path, "/ns/default/sa/api-server");
904 }
905
906 #[test]
907 fn test_parse_missing_scheme() {
908 assert!(SpiffeId::parse("https://example.org/svc").is_err());
909 }
910
911 #[test]
912 fn test_parse_empty_trust_domain() {
913 assert!(SpiffeId::parse("spiffe://").is_err());
914 }
915
916 #[test]
917 fn test_parse_invalid_td_char() {
918 assert!(SpiffeId::parse("spiffe://ex ample.org/svc").is_err());
919 }
920
921 #[test]
922 fn test_parse_query_rejected() {
923 assert!(SpiffeId::parse("spiffe://example.org/svc?q=1").is_err());
924 }
925
926 #[test]
927 fn test_parse_fragment_rejected() {
928 assert!(SpiffeId::parse("spiffe://example.org/svc#frag").is_err());
929 }
930
931 #[test]
932 fn test_parse_trailing_slash_rejected() {
933 assert!(SpiffeId::parse("spiffe://example.org/svc/").is_err());
934 }
935
936 #[test]
937 fn test_parse_empty_segment_rejected() {
938 assert!(SpiffeId::parse("spiffe://example.org//svc").is_err());
939 }
940
941 #[test]
942 fn test_parse_dot_segment_rejected() {
943 assert!(SpiffeId::parse("spiffe://example.org/./svc").is_err());
944 assert!(SpiffeId::parse("spiffe://example.org/../svc").is_err());
945 }
946
947 #[test]
948 fn test_is_member_of() {
949 let id = SpiffeId::parse("spiffe://example.org/svc").unwrap();
950 assert!(id.is_member_of("example.org"));
951 assert!(!id.is_member_of("other.org"));
952 }
953
954 #[test]
955 fn test_matches_path_prefix() {
956 let id = SpiffeId::parse("spiffe://example.org/ns/prod/svc/api").unwrap();
957 assert!(id.matches_path_prefix("/ns/prod"));
958 assert!(!id.matches_path_prefix("/ns/staging"));
959 }
960
961 #[test]
962 fn test_display() {
963 let id = SpiffeId::parse("spiffe://td/path").unwrap();
964 assert_eq!(format!("{id}"), "spiffe://td/path");
965 }
966
967 fn make_jwt_svid(sub: &str, aud: &[&str], exp: u64, alg: &str) -> String {
970 let b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;
971 let header = serde_json::json!({"alg": alg, "typ": "JWT"});
972 let claims = serde_json::json!({
973 "sub": sub,
974 "aud": aud,
975 "exp": exp,
976 });
977 let h = b64.encode(header.to_string().as_bytes());
978 let c = b64.encode(claims.to_string().as_bytes());
979 format!("{h}.{c}.fake-signature")
980 }
981
982 #[test]
983 fn test_validate_jwt_svid_valid() {
984 let future = SystemTime::now()
985 .duration_since(UNIX_EPOCH)
986 .unwrap()
987 .as_secs()
988 + 3600;
989 let token = make_jwt_svid(
990 "spiffe://example.org/svc/api",
991 &["https://service.example.org"],
992 future,
993 "ES256",
994 );
995 let result = validate_jwt_svid(&token, "https://service.example.org").unwrap();
996 assert_eq!(result.spiffe_id.trust_domain, "example.org");
997 assert_eq!(result.spiffe_id.path, "/svc/api");
998 }
999
1000 #[test]
1001 fn test_validate_jwt_svid_expired() {
1002 let past = 1_000_000;
1003 let token = make_jwt_svid("spiffe://example.org/svc", &["aud"], past, "ES256");
1004 assert!(validate_jwt_svid(&token, "aud").is_err());
1005 }
1006
1007 #[test]
1008 fn test_validate_jwt_svid_wrong_audience() {
1009 let future = SystemTime::now()
1010 .duration_since(UNIX_EPOCH)
1011 .unwrap()
1012 .as_secs()
1013 + 3600;
1014 let token = make_jwt_svid(
1015 "spiffe://example.org/svc",
1016 &["expected-aud"],
1017 future,
1018 "ES256",
1019 );
1020 assert!(validate_jwt_svid(&token, "wrong-aud").is_err());
1021 }
1022
1023 #[test]
1024 fn test_validate_jwt_svid_none_algorithm_rejected() {
1025 let future = SystemTime::now()
1026 .duration_since(UNIX_EPOCH)
1027 .unwrap()
1028 .as_secs()
1029 + 3600;
1030 let token = make_jwt_svid("spiffe://example.org/svc", &["aud"], future, "none");
1031 assert!(validate_jwt_svid(&token, "aud").is_err());
1032 }
1033
1034 #[test]
1035 fn test_validate_jwt_svid_invalid_sub() {
1036 let future = SystemTime::now()
1037 .duration_since(UNIX_EPOCH)
1038 .unwrap()
1039 .as_secs()
1040 + 3600;
1041 let token = make_jwt_svid("https://not-spiffe.example.org", &["aud"], future, "ES256");
1042 assert!(validate_jwt_svid(&token, "aud").is_err());
1043 }
1044
1045 #[test]
1046 fn test_validate_jwt_svid_malformed() {
1047 assert!(validate_jwt_svid("not.a.valid.jwt.token", "aud").is_err());
1048 assert!(validate_jwt_svid("only-one-part", "aud").is_err());
1049 }
1050
1051 #[test]
1054 fn test_extract_spiffe_id_from_synthetic_der() {
1055 let mut data = vec![0x30, 0x82]; data.extend_from_slice(&[0x00, 0x50]); data.extend_from_slice(b"some-cert-fields-");
1059 data.extend_from_slice(b"spiffe://example.org/workload/web");
1060 data.push(0x00); data.extend_from_slice(&[0xFF; 20]); let info = extract_spiffe_id_from_der(&data).unwrap();
1064 assert_eq!(info.spiffe_id.trust_domain, "example.org");
1065 assert_eq!(info.spiffe_id.path, "/workload/web");
1066 assert!(!info.fingerprint.is_empty());
1067 assert_eq!(info.fingerprint.len(), 64); }
1069
1070 #[test]
1071 fn test_extract_spiffe_id_no_uri() {
1072 let data = b"no spiffe uri here at all";
1073 assert!(extract_spiffe_id_from_der(data).is_err());
1074 }
1075
1076 #[tokio::test]
1079 async fn test_trust_manager_bundle_operations() {
1080 let mgr = SpiffeTrustManager::new();
1081 assert!(!mgr.has_trust_bundle("example.org").await);
1082
1083 mgr.add_trust_bundle("example.org", vec![vec![1, 2, 3]])
1084 .await;
1085 assert!(mgr.has_trust_bundle("example.org").await);
1086
1087 let bundle = mgr.get_trust_bundle("example.org").await.unwrap();
1088 assert_eq!(bundle.len(), 1);
1089
1090 let domains = mgr.trust_domains().await;
1091 assert_eq!(domains, vec!["example.org"]);
1092
1093 assert!(mgr.remove_trust_bundle("example.org").await);
1094 assert!(!mgr.has_trust_bundle("example.org").await);
1095 }
1096
1097 #[tokio::test]
1098 async fn test_trust_manager_verify_jwt_svid_no_bundle() {
1099 let mgr = SpiffeTrustManager::new();
1100 let future = SystemTime::now()
1101 .duration_since(UNIX_EPOCH)
1102 .unwrap()
1103 .as_secs()
1104 + 3600;
1105 let token = make_jwt_svid("spiffe://example.org/svc", &["aud"], future, "ES256");
1106 assert!(mgr.verify_jwt_svid(&token, "aud").await.is_err());
1108 }
1109
1110 #[tokio::test]
1111 async fn test_trust_manager_verify_jwt_svid_with_bundle() {
1112 let mgr = SpiffeTrustManager::new();
1113 mgr.add_trust_bundle("example.org", vec![vec![0xCA]]).await;
1114
1115 let future = SystemTime::now()
1116 .duration_since(UNIX_EPOCH)
1117 .unwrap()
1118 .as_secs()
1119 + 3600;
1120 let token = make_jwt_svid("spiffe://example.org/svc", &["aud"], future, "ES256");
1121 let result = mgr.verify_jwt_svid(&token, "aud").await.unwrap();
1122 assert_eq!(result.spiffe_id.trust_domain, "example.org");
1123 }
1124
1125 #[tokio::test]
1128 async fn test_authz_policy_exact_match() {
1129 let mgr = SpiffeTrustManager::new();
1130 mgr.add_policy(SpiffeAuthzPolicy {
1131 source: "spiffe://td/frontend".to_string(),
1132 destination: "spiffe://td/backend".to_string(),
1133 allowed_actions: vec!["GET".to_string(), "POST".to_string()],
1134 })
1135 .await;
1136
1137 let src = SpiffeId::parse("spiffe://td/frontend").unwrap();
1138 let dst = SpiffeId::parse("spiffe://td/backend").unwrap();
1139
1140 assert!(mgr.is_authorized(&src, &dst, "GET").await);
1141 assert!(mgr.is_authorized(&src, &dst, "POST").await);
1142 assert!(!mgr.is_authorized(&src, &dst, "DELETE").await);
1143 }
1144
1145 #[tokio::test]
1146 async fn test_authz_policy_wildcard() {
1147 let mgr = SpiffeTrustManager::new();
1148 mgr.add_policy(SpiffeAuthzPolicy {
1149 source: "*".to_string(),
1150 destination: "spiffe://td/public-api".to_string(),
1151 allowed_actions: vec!["*".to_string()],
1152 })
1153 .await;
1154
1155 let any_src = SpiffeId::parse("spiffe://other/svc").unwrap();
1156 let dst = SpiffeId::parse("spiffe://td/public-api").unwrap();
1157
1158 assert!(mgr.is_authorized(&any_src, &dst, "GET").await);
1159 assert!(mgr.is_authorized(&any_src, &dst, "DELETE").await);
1160 }
1161
1162 #[tokio::test]
1163 async fn test_authz_policy_no_match() {
1164 let mgr = SpiffeTrustManager::new();
1165 let src = SpiffeId::parse("spiffe://td/svc1").unwrap();
1166 let dst = SpiffeId::parse("spiffe://td/svc2").unwrap();
1167 assert!(!mgr.is_authorized(&src, &dst, "GET").await);
1168 }
1169
1170 #[test]
1173 fn test_workload_api_config_defaults() {
1174 let cfg = WorkloadApiConfig::default();
1175 assert!(cfg.endpoint.contains("spire-agent"));
1176 assert_eq!(cfg.rotation_interval_secs, 300);
1177 assert!(cfg.jwt_audiences.is_empty());
1178 }
1179
1180 #[tokio::test]
1181 async fn test_workload_api_store_x509_svid() {
1182 let client = WorkloadApiClient::new(WorkloadApiConfig::default());
1183 let svid = SvidResponse::X509 {
1184 spiffe_id: "spiffe://example.org/web".to_string(),
1185 cert_chain: vec![vec![0x30, 0x82]],
1186 private_key: vec![0x01],
1187 bundle: vec![vec![0xCA]],
1188 expires_at: 9999999999,
1189 };
1190 client.store_x509_svid(svid).await;
1191 assert_eq!(client.x509_count().await, 1);
1192 assert!(
1193 client
1194 .get_x509_svid("spiffe://example.org/web")
1195 .await
1196 .is_some()
1197 );
1198 assert!(client.get_bundle("example.org").await.is_some());
1200 }
1201
1202 #[tokio::test]
1203 async fn test_workload_api_store_jwt_svid() {
1204 let client = WorkloadApiClient::new(WorkloadApiConfig::default());
1205 let svid = SvidResponse::Jwt {
1206 spiffe_id: "spiffe://example.org/api".to_string(),
1207 token: "eyJ...".to_string(),
1208 expires_at: 9999999999,
1209 };
1210 client.store_jwt_svid(svid).await;
1211 assert_eq!(client.jwt_count().await, 1);
1212 assert!(
1213 client
1214 .get_jwt_svid("spiffe://example.org/api")
1215 .await
1216 .is_some()
1217 );
1218 }
1219
1220 #[tokio::test]
1221 async fn test_workload_api_cleanup_expired() {
1222 let client = WorkloadApiClient::new(WorkloadApiConfig::default());
1223 let svid = SvidResponse::X509 {
1225 spiffe_id: "spiffe://example.org/old".to_string(),
1226 cert_chain: vec![],
1227 private_key: vec![],
1228 bundle: vec![],
1229 expires_at: 1, };
1231 client.store_x509_svid(svid).await;
1232 assert_eq!(client.x509_count().await, 1);
1233 client.cleanup_expired().await;
1234 assert_eq!(client.x509_count().await, 0);
1235 }
1236
1237 #[tokio::test]
1238 async fn test_workload_api_needs_rotation() {
1239 let client = WorkloadApiClient::new(WorkloadApiConfig::default());
1240 let now = SystemTime::now()
1242 .duration_since(UNIX_EPOCH)
1243 .unwrap()
1244 .as_secs();
1245 let svid = SvidResponse::X509 {
1246 spiffe_id: "spiffe://example.org/expiring".to_string(),
1247 cert_chain: vec![],
1248 private_key: vec![],
1249 bundle: vec![],
1250 expires_at: now + 10,
1251 };
1252 client.store_x509_svid(svid).await;
1253 let needs = client.needs_rotation().await;
1254 assert_eq!(needs.len(), 1);
1255 assert_eq!(needs[0], "spiffe://example.org/expiring");
1256 }
1257
1258 #[test]
1259 fn test_workload_api_rotation_interval() {
1260 let cfg = WorkloadApiConfig {
1261 rotation_interval_secs: 600,
1262 ..WorkloadApiConfig::default()
1263 };
1264 let client = WorkloadApiClient::new(cfg);
1265 assert_eq!(client.rotation_interval(), Duration::from_secs(600));
1266 }
1267
1268 #[test]
1271 fn test_workload_selector_unix_uid() {
1272 let s = WorkloadSelector::unix_uid(1000);
1273 assert_eq!(s.selector_type, "unix");
1274 assert_eq!(s.value, "uid:1000");
1275 }
1276
1277 #[test]
1278 fn test_workload_selector_k8s_sa() {
1279 let s = WorkloadSelector::k8s_sa("default", "api-server");
1280 assert_eq!(s.selector_type, "k8s");
1281 assert_eq!(s.value, "sa:default:api-server");
1282 }
1283
1284 #[test]
1285 fn test_workload_selector_docker_image() {
1286 let s = WorkloadSelector::docker_image_id("sha256:abc123");
1287 assert_eq!(s.selector_type, "docker");
1288 assert_eq!(s.value, "image_id:sha256:abc123");
1289 }
1290
1291 #[tokio::test]
1294 async fn test_registration_store_match_selectors() {
1295 let store = RegistrationStore::new();
1296 let entry = RegistrationEntry {
1297 spiffe_id: SpiffeId::parse("spiffe://example.org/web").unwrap(),
1298 parent_id: SpiffeId::parse("spiffe://example.org/node1").unwrap(),
1299 selectors: vec![WorkloadSelector::unix_uid(1000)],
1300 ttl: 3600,
1301 downstream: false,
1302 };
1303 store.register(entry).await;
1304
1305 let ids = store
1307 .match_selectors(&[WorkloadSelector::unix_uid(1000)])
1308 .await;
1309 assert_eq!(ids.len(), 1);
1310 assert_eq!(ids[0].path, "/web");
1311
1312 let ids = store
1314 .match_selectors(&[
1315 WorkloadSelector::unix_uid(1000),
1316 WorkloadSelector::unix_gid(100),
1317 ])
1318 .await;
1319 assert_eq!(ids.len(), 1);
1320
1321 let ids = store
1323 .match_selectors(&[WorkloadSelector::unix_uid(2000)])
1324 .await;
1325 assert!(ids.is_empty());
1326 }
1327
1328 #[tokio::test]
1329 async fn test_registration_store_attest() {
1330 let store = RegistrationStore::new();
1331 let entry = RegistrationEntry {
1332 spiffe_id: SpiffeId::parse("spiffe://example.org/api").unwrap(),
1333 parent_id: SpiffeId::parse("spiffe://example.org/node1").unwrap(),
1334 selectors: vec![WorkloadSelector {
1335 selector_type: "unix".to_string(),
1336 value: "uid:1000".to_string(),
1337 }],
1338 ttl: 3600,
1339 downstream: false,
1340 };
1341 store.register(entry).await;
1342
1343 let evidence = AttestationEvidence {
1344 attestor: "unix".to_string(),
1345 payload: HashMap::from([("uid".to_string(), "1000".to_string())]),
1346 };
1347 let result = store.attest(&evidence).await.unwrap();
1348 assert_eq!(result.spiffe_ids.len(), 1);
1349 assert_eq!(result.spiffe_ids[0].path, "/api");
1350 assert_eq!(result.selectors.len(), 1);
1351 }
1352
1353 #[tokio::test]
1354 async fn test_registration_store_attest_no_match() {
1355 let store = RegistrationStore::new();
1356 let evidence = AttestationEvidence {
1357 attestor: "unix".to_string(),
1358 payload: HashMap::from([("uid".to_string(), "9999".to_string())]),
1359 };
1360 assert!(store.attest(&evidence).await.is_err());
1361 }
1362
1363 #[tokio::test]
1364 async fn test_registration_store_remove_by_spiffe_id() {
1365 let store = RegistrationStore::new();
1366 let id = SpiffeId::parse("spiffe://example.org/web").unwrap();
1367 store
1368 .register(RegistrationEntry {
1369 spiffe_id: id.clone(),
1370 parent_id: SpiffeId::parse("spiffe://example.org/node").unwrap(),
1371 selectors: vec![],
1372 ttl: 3600,
1373 downstream: false,
1374 })
1375 .await;
1376 assert_eq!(store.count().await, 1);
1377 store.remove_by_spiffe_id(&id).await;
1378 assert_eq!(store.count().await, 0);
1379 }
1380
1381 #[tokio::test]
1384 async fn test_federated_bundle_manager_local_domain() {
1385 let mgr = FederatedTrustBundleManager::new("example.org");
1386 assert_eq!(mgr.local_domain(), "example.org");
1387 }
1388
1389 #[tokio::test]
1390 async fn test_federated_bundle_store_and_retrieve() {
1391 let mgr = FederatedTrustBundleManager::new("local.org");
1392 let bundle = FederatedBundle {
1393 trust_domain: "remote.org".to_string(),
1394 ca_certs: vec![vec![0xCA, 0xFE]],
1395 refreshed_at: 1000000,
1396 sequence_number: 1,
1397 };
1398 mgr.store_bundle(bundle).await;
1399 assert_eq!(mgr.bundle_count().await, 1);
1400
1401 let b = mgr.get_bundle("remote.org").await.unwrap();
1402 assert_eq!(b.sequence_number, 1);
1403 assert_eq!(b.ca_certs.len(), 1);
1404 }
1405
1406 #[tokio::test]
1407 async fn test_federated_bundle_is_trusted() {
1408 let mgr = FederatedTrustBundleManager::new("local.org");
1409 let local_id = SpiffeId::parse("spiffe://local.org/svc").unwrap();
1410 let remote_id = SpiffeId::parse("spiffe://remote.org/svc").unwrap();
1411
1412 assert!(mgr.is_federated_id_trusted(&local_id).await);
1414 assert!(!mgr.is_federated_id_trusted(&remote_id).await);
1416
1417 mgr.store_bundle(FederatedBundle {
1419 trust_domain: "remote.org".to_string(),
1420 ca_certs: vec![vec![0x01]],
1421 refreshed_at: 9999999999,
1422 sequence_number: 1,
1423 })
1424 .await;
1425 assert!(mgr.is_federated_id_trusted(&remote_id).await);
1426 }
1427
1428 #[tokio::test]
1429 async fn test_federated_bundle_remove() {
1430 let mgr = FederatedTrustBundleManager::new("local.org");
1431 mgr.store_bundle(FederatedBundle {
1432 trust_domain: "remote.org".to_string(),
1433 ca_certs: vec![],
1434 refreshed_at: 0,
1435 sequence_number: 0,
1436 })
1437 .await;
1438 assert!(mgr.remove_bundle("remote.org").await);
1439 assert!(!mgr.remove_bundle("remote.org").await);
1440 assert_eq!(mgr.bundle_count().await, 0);
1441 }
1442
1443 #[tokio::test]
1444 async fn test_federated_bundle_cleanup_stale() {
1445 let mgr = FederatedTrustBundleManager::new("local.org");
1446 mgr.store_bundle(FederatedBundle {
1448 trust_domain: "stale.org".to_string(),
1449 ca_certs: vec![],
1450 refreshed_at: 0,
1451 sequence_number: 1,
1452 })
1453 .await;
1454 let now = SystemTime::now()
1456 .duration_since(UNIX_EPOCH)
1457 .unwrap()
1458 .as_secs();
1459 mgr.store_bundle(FederatedBundle {
1460 trust_domain: "fresh.org".to_string(),
1461 ca_certs: vec![],
1462 refreshed_at: now,
1463 sequence_number: 1,
1464 })
1465 .await;
1466 assert_eq!(mgr.bundle_count().await, 2);
1467
1468 mgr.cleanup_stale(Duration::from_secs(3600)).await;
1469 assert_eq!(mgr.bundle_count().await, 1);
1470 assert!(mgr.get_bundle("fresh.org").await.is_some());
1471 assert!(mgr.get_bundle("stale.org").await.is_none());
1472 }
1473
1474 #[tokio::test]
1475 async fn test_federated_bundle_endpoints() {
1476 let mgr = FederatedTrustBundleManager::new("local.org");
1477 mgr.add_federation_endpoint("remote.org", "https://remote.org/.well-known/spiffe-bundle")
1478 .await;
1479 let ep = mgr.get_endpoint("remote.org").await.unwrap();
1480 assert!(ep.contains("spiffe-bundle"));
1481 assert!(mgr.get_endpoint("unknown.org").await.is_none());
1482 }
1483
1484 #[tokio::test]
1485 async fn test_federated_bundle_list_domains() {
1486 let mgr = FederatedTrustBundleManager::new("local.org");
1487 mgr.store_bundle(FederatedBundle {
1488 trust_domain: "a.org".to_string(),
1489 ca_certs: vec![],
1490 refreshed_at: 0,
1491 sequence_number: 0,
1492 })
1493 .await;
1494 mgr.store_bundle(FederatedBundle {
1495 trust_domain: "b.org".to_string(),
1496 ca_certs: vec![],
1497 refreshed_at: 0,
1498 sequence_number: 0,
1499 })
1500 .await;
1501 let mut domains = mgr.federated_domains().await;
1502 domains.sort();
1503 assert_eq!(domains, vec!["a.org", "b.org"]);
1504 }
1505}