1use std::collections::BTreeMap;
32
33use base64::Engine;
34use base64::engine::general_purpose::STANDARD as B64;
35use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
36
/// Credential format version written by this binary. `verify_against` refuses
/// any credential whose `cred_version` is greater than this value.
pub const CRED_VERSION: u16 = 1;

/// Name of the header that carries an encoded credential.
/// NOTE(review): not referenced elsewhere in this file; presumably consumed by
/// the transport layer — confirm against the caller.
pub const CREDENTIAL_HEADER: &str = "x-memory-cred";

/// Prefix that `to_header_value` puts in front of the base64 payload and that
/// `from_header_value` requires before decoding.
pub const CREDENTIAL_PREFIX: &str = "v1=";

/// Length in bytes of the signature stored in a `SignedCredential`
/// (the Ed25519 signature length exported by `ed25519_dalek`).
pub const CREDENTIAL_SIG_LEN: usize = ed25519_dalek::SIGNATURE_LENGTH;

/// Length in bytes of the subject public key stored in a
/// `FederationCredential` (the Ed25519 public key length).
pub const SUBJECT_PUBKEY_LEN: usize = ed25519_dalek::PUBLIC_KEY_LENGTH;

/// Environment variable read by `SignedCredential::load_from_env`; its value
/// is used as the path of the credential file.
pub const FED_CREDENTIAL_PATH_ENV: &str = "AI_MEMORY_FED_CRED_PATH";

// Text keys of the signed claims map. They are declared here in ascending
// string order, which is also the order the `BTreeMap` in
// `canonical_claims_bytes` emits them.
const FIELD_CRED_VERSION: &str = "cred_version";
const FIELD_ISSUER_ID: &str = "issuer_id";
const FIELD_NOT_AFTER: &str = "not_after";
const FIELD_NOT_BEFORE: &str = "not_before";
const FIELD_SUBJECT_AGENT_ID: &str = "subject_agent_id";
const FIELD_SUBJECT_PUBKEY: &str = "subject_pubkey";
const FIELD_TRUST_DOMAIN: &str = "trust_domain";

// Text keys of the outer wire envelope: the encoded claims bytes and the
// signature over them.
const WIRE_CLAIMS_KEY: &str = "claims";
const WIRE_SIG_KEY: &str = "sig";
74
/// Reasons a federation credential can fail to decode or verify.
///
/// Each variant maps to a stable string via `tag`, which `Display` also uses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CredentialError {
    /// The input could not be decoded: bad header prefix, bad base64, bad
    /// CBOR, or a missing / wrongly typed / wrongly sized field.
    Malformed,
    /// The signature did not verify against the supplied issuer key.
    BadSignature,
    /// The verification time is earlier than `not_before`.
    NotYetValid,
    /// The verification time is later than `not_after`.
    Expired,
    /// The credential (or its header version marker) declares a version this
    /// binary does not handle; the payload is the declared version.
    UnsupportedVersion(u16),
    /// The subject public key bytes are not accepted as an Ed25519
    /// verifying key.
    BadSubjectKey,
    /// NOTE(review): never produced in this file; presumably returned by
    /// callers that cannot resolve `issuer_id` to a key — confirm.
    UnknownIssuer,
    /// NOTE(review): never produced in this file; presumably returned by
    /// callers that compare `trust_domain` against their own — confirm.
    WrongTrustDomain,
}
99
100impl CredentialError {
101 #[must_use]
103 pub fn tag(&self) -> &'static str {
104 match self {
105 Self::Malformed => "credential_malformed",
106 Self::BadSignature => "credential_bad_signature",
107 Self::NotYetValid => "credential_not_yet_valid",
108 Self::Expired => "credential_expired",
109 Self::UnsupportedVersion(_) => "credential_unsupported_version",
110 Self::BadSubjectKey => "credential_bad_subject_key",
111 Self::UnknownIssuer => "credential_unknown_issuer",
112 Self::WrongTrustDomain => "credential_wrong_trust_domain",
113 }
114 }
115}
116
117impl std::fmt::Display for CredentialError {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 match self {
120 Self::UnsupportedVersion(v) => {
121 write!(
122 f,
123 "{} (got v{v}, this binary speaks v{CRED_VERSION})",
124 self.tag()
125 )
126 }
127 _ => f.write_str(self.tag()),
128 }
129 }
130}
131
// Marker impl with no overridden methods: it lets `CredentialError` be used as
// a `std::error::Error`, e.g. wrapped by `std::io::Error::new` in
// `load_from_path`.
impl std::error::Error for CredentialError {}
133
/// The unsigned claim set of a federation credential.
///
/// `canonical_claims_bytes` encodes exactly these seven fields; `sign` turns
/// the claim set into a `SignedCredential`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FederationCredential {
    /// Identifier of the agent the credential is issued to.
    pub subject_agent_id: String,
    /// Raw public key bytes of the subject; `subject_verifying_key` parses
    /// them as an Ed25519 verifying key.
    pub subject_pubkey: [u8; SUBJECT_PUBKEY_LEN],
    /// Identifier of the issuer.
    /// NOTE(review): presumably used by callers to look up the issuer key
    /// passed to `verify_against` — confirm.
    pub issuer_id: String,
    /// Trust domain the credential belongs to (not interpreted in this file).
    pub trust_domain: String,
    /// Earliest time (inclusive) at which the credential is valid; compared
    /// directly against `now_unix` (presumably Unix seconds — confirm).
    pub not_before: i64,
    /// Latest time (inclusive) at which the credential is valid; same unit as
    /// `not_before`.
    pub not_after: i64,
    /// Format version of this credential; see `CRED_VERSION`.
    pub cred_version: u16,
}
156
157impl FederationCredential {
158 pub fn canonical_claims_bytes(&self) -> Result<Vec<u8>, CredentialError> {
167 let mut map: BTreeMap<&str, ciborium::Value> = BTreeMap::new();
168 map.insert(
169 FIELD_SUBJECT_AGENT_ID,
170 ciborium::Value::Text(self.subject_agent_id.clone()),
171 );
172 map.insert(
173 FIELD_SUBJECT_PUBKEY,
174 ciborium::Value::Bytes(self.subject_pubkey.to_vec()),
175 );
176 map.insert(
177 FIELD_ISSUER_ID,
178 ciborium::Value::Text(self.issuer_id.clone()),
179 );
180 map.insert(
181 FIELD_TRUST_DOMAIN,
182 ciborium::Value::Text(self.trust_domain.clone()),
183 );
184 map.insert(FIELD_NOT_BEFORE, int_value(self.not_before));
185 map.insert(FIELD_NOT_AFTER, int_value(self.not_after));
186 map.insert(FIELD_CRED_VERSION, int_value(i64::from(self.cred_version)));
187
188 let entries: Vec<(ciborium::Value, ciborium::Value)> = map
189 .into_iter()
190 .map(|(k, v)| (ciborium::Value::Text(k.to_string()), v))
191 .collect();
192 let value = ciborium::Value::Map(entries);
193 let mut out = Vec::with_capacity(128);
194 ciborium::ser::into_writer(&value, &mut out).map_err(|_| CredentialError::Malformed)?;
195 Ok(out)
196 }
197
198 pub fn sign(&self, ca_signing_key: &SigningKey) -> Result<SignedCredential, CredentialError> {
204 let claims_bytes = self.canonical_claims_bytes()?;
205 let sig: Signature = ca_signing_key.sign(&claims_bytes);
206 Ok(SignedCredential {
207 credential: self.clone(),
208 claims_bytes,
209 signature: sig.to_bytes(),
210 })
211 }
212
213 fn from_claims_bytes(bytes: &[u8]) -> Result<Self, CredentialError> {
215 let value: ciborium::Value =
216 ciborium::de::from_reader(bytes).map_err(|_| CredentialError::Malformed)?;
217 let entries = match value {
218 ciborium::Value::Map(e) => e,
219 _ => return Err(CredentialError::Malformed),
220 };
221 let mut map: BTreeMap<String, ciborium::Value> = BTreeMap::new();
222 for (k, v) in entries {
223 if let ciborium::Value::Text(key) = k {
224 map.insert(key, v);
225 } else {
226 return Err(CredentialError::Malformed);
227 }
228 }
229 let subject_pubkey_vec = take_bytes(&mut map, FIELD_SUBJECT_PUBKEY)?;
230 if subject_pubkey_vec.len() != SUBJECT_PUBKEY_LEN {
231 return Err(CredentialError::Malformed);
232 }
233 let mut subject_pubkey = [0u8; SUBJECT_PUBKEY_LEN];
234 subject_pubkey.copy_from_slice(&subject_pubkey_vec);
235
236 let cred_version_i = take_int(&mut map, FIELD_CRED_VERSION)?;
237 let cred_version = u16::try_from(cred_version_i).map_err(|_| CredentialError::Malformed)?;
238
239 Ok(Self {
240 subject_agent_id: take_text(&mut map, FIELD_SUBJECT_AGENT_ID)?,
241 subject_pubkey,
242 issuer_id: take_text(&mut map, FIELD_ISSUER_ID)?,
243 trust_domain: take_text(&mut map, FIELD_TRUST_DOMAIN)?,
244 not_before: take_int(&mut map, FIELD_NOT_BEFORE)?,
245 not_after: take_int(&mut map, FIELD_NOT_AFTER)?,
246 cred_version,
247 })
248 }
249
250 pub fn subject_verifying_key(&self) -> Result<VerifyingKey, CredentialError> {
256 VerifyingKey::from_bytes(&self.subject_pubkey).map_err(|_| CredentialError::BadSubjectKey)
257 }
258}
259
/// A claim set together with the exact bytes that were signed and the
/// signature over them.
///
/// Fields are private: values are created by `FederationCredential::sign` or
/// by the `from_*` / `load_*` constructors.
#[derive(Debug, Clone)]
pub struct SignedCredential {
    /// Claims decoded from (or used to produce) `claims_bytes`.
    credential: FederationCredential,
    /// The encoded claims exactly as signed or received. Verification runs
    /// over these bytes; the claims are never re-encoded for that purpose.
    claims_bytes: Vec<u8>,
    /// Raw signature bytes over `claims_bytes`.
    signature: [u8; CREDENTIAL_SIG_LEN],
}
270
271impl SignedCredential {
272 #[must_use]
274 pub fn credential(&self) -> &FederationCredential {
275 &self.credential
276 }
277
278 pub fn to_wire_bytes(&self) -> Result<Vec<u8>, CredentialError> {
283 let entries: Vec<(ciborium::Value, ciborium::Value)> = vec![
284 (
285 ciborium::Value::Text(WIRE_CLAIMS_KEY.to_string()),
286 ciborium::Value::Bytes(self.claims_bytes.clone()),
287 ),
288 (
289 ciborium::Value::Text(WIRE_SIG_KEY.to_string()),
290 ciborium::Value::Bytes(self.signature.to_vec()),
291 ),
292 ];
293 let value = ciborium::Value::Map(entries);
294 let mut out = Vec::with_capacity(self.claims_bytes.len() + CREDENTIAL_SIG_LEN + 16);
295 ciborium::ser::into_writer(&value, &mut out).map_err(|_| CredentialError::Malformed)?;
296 Ok(out)
297 }
298
299 pub fn from_wire_bytes(bytes: &[u8]) -> Result<Self, CredentialError> {
305 let value: ciborium::Value =
306 ciborium::de::from_reader(bytes).map_err(|_| CredentialError::Malformed)?;
307 let entries = match value {
308 ciborium::Value::Map(e) => e,
309 _ => return Err(CredentialError::Malformed),
310 };
311 let mut claims_bytes: Option<Vec<u8>> = None;
312 let mut signature_vec: Option<Vec<u8>> = None;
313 for (k, v) in entries {
314 let key = match k {
315 ciborium::Value::Text(s) => s,
316 _ => return Err(CredentialError::Malformed),
317 };
318 match (key.as_str(), v) {
319 (WIRE_CLAIMS_KEY, ciborium::Value::Bytes(b)) => claims_bytes = Some(b),
320 (WIRE_SIG_KEY, ciborium::Value::Bytes(b)) => signature_vec = Some(b),
321 _ => return Err(CredentialError::Malformed),
322 }
323 }
324 let claims_bytes = claims_bytes.ok_or(CredentialError::Malformed)?;
325 let signature_vec = signature_vec.ok_or(CredentialError::Malformed)?;
326 if signature_vec.len() != CREDENTIAL_SIG_LEN {
327 return Err(CredentialError::Malformed);
328 }
329 let mut signature = [0u8; CREDENTIAL_SIG_LEN];
330 signature.copy_from_slice(&signature_vec);
331 let credential = FederationCredential::from_claims_bytes(&claims_bytes)?;
332 Ok(Self {
333 credential,
334 claims_bytes,
335 signature,
336 })
337 }
338
339 pub fn to_header_value(&self) -> Result<String, CredentialError> {
345 let wire = self.to_wire_bytes()?;
346 Ok(format!("{CREDENTIAL_PREFIX}{}", B64.encode(wire)))
347 }
348
349 pub fn from_header_value(value: &str) -> Result<Self, CredentialError> {
356 let b64 = value
357 .strip_prefix(CREDENTIAL_PREFIX)
358 .ok_or_else(|| unsupported_or_malformed(value))?;
359 let wire = B64.decode(b64).map_err(|_| CredentialError::Malformed)?;
360 Self::from_wire_bytes(&wire)
361 }
362
363 pub fn verify_against(
377 &self,
378 issuer_pub: &VerifyingKey,
379 now_unix: i64,
380 ) -> Result<(), CredentialError> {
381 if self.credential.cred_version > CRED_VERSION {
382 return Err(CredentialError::UnsupportedVersion(
383 self.credential.cred_version,
384 ));
385 }
386 let sig = Signature::from_bytes(&self.signature);
387 issuer_pub
388 .verify(&self.claims_bytes, &sig)
389 .map_err(|_| CredentialError::BadSignature)?;
390 self.check_validity(now_unix)
391 }
392
393 fn check_validity(&self, now_unix: i64) -> Result<(), CredentialError> {
395 if now_unix < self.credential.not_before {
396 return Err(CredentialError::NotYetValid);
397 }
398 if now_unix > self.credential.not_after {
399 return Err(CredentialError::Expired);
400 }
401 Ok(())
402 }
403
404 pub fn load_from_path(path: &std::path::Path) -> std::io::Result<Option<Self>> {
413 let raw = match std::fs::read_to_string(path) {
414 Ok(s) => s,
415 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
416 Err(e) => return Err(e),
417 };
418 let cred = Self::from_header_value(raw.trim())
419 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
420 Ok(Some(cred))
421 }
422
423 pub fn load_from_env() -> std::io::Result<Option<Self>> {
429 match std::env::var(FED_CREDENTIAL_PATH_ENV) {
430 Ok(path) => Self::load_from_path(std::path::Path::new(&path)),
431 Err(_) => Ok(None),
432 }
433 }
434}
435
436fn unsupported_or_malformed(value: &str) -> CredentialError {
439 if let Some(rest) = value.strip_prefix('v') {
440 if let Some((digits, _)) = rest.split_once('=') {
441 if let Ok(v) = digits.parse::<u16>() {
442 return CredentialError::UnsupportedVersion(v);
443 }
444 }
445 }
446 CredentialError::Malformed
447}
448
449fn int_value(n: i64) -> ciborium::Value {
451 ciborium::Value::Integer(n.into())
452}
453
454fn take_text(
455 map: &mut BTreeMap<String, ciborium::Value>,
456 key: &str,
457) -> Result<String, CredentialError> {
458 match map.remove(key) {
459 Some(ciborium::Value::Text(s)) => Ok(s),
460 _ => Err(CredentialError::Malformed),
461 }
462}
463
464fn take_bytes(
465 map: &mut BTreeMap<String, ciborium::Value>,
466 key: &str,
467) -> Result<Vec<u8>, CredentialError> {
468 match map.remove(key) {
469 Some(ciborium::Value::Bytes(b)) => Ok(b),
470 _ => Err(CredentialError::Malformed),
471 }
472}
473
474fn take_int(
475 map: &mut BTreeMap<String, ciborium::Value>,
476 key: &str,
477) -> Result<i64, CredentialError> {
478 match map.remove(key) {
479 Some(ciborium::Value::Integer(i)) => {
480 i64::try_from(i128::from(i)).map_err(|_| CredentialError::Malformed)
481 }
482 _ => Err(CredentialError::Malformed),
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use ed25519_dalek::SigningKey;

    /// Deterministic issuer (CA) key: all 32 secret bytes equal `seed`.
    fn ca_key(seed: u8) -> SigningKey {
        SigningKey::from_bytes(&[seed; 32])
    }

    /// Deterministic subject key: all 32 secret bytes equal `seed`.
    fn subject_key(seed: u8) -> SigningKey {
        SigningKey::from_bytes(&[seed; 32])
    }

    /// Claim set for the subject key with seed 7, valid from `now - 10`
    /// through `now + 3600`.
    fn sample(now: i64) -> FederationCredential {
        let subj = subject_key(7);
        FederationCredential {
            subject_agent_id: "region/nyc/node-7".to_string(),
            subject_pubkey: subj.verifying_key().to_bytes(),
            issuer_id: "trust-domain-root".to_string(),
            trust_domain: "fleet.example".to_string(),
            not_before: now - 10,
            not_after: now + 3600,
            cred_version: CRED_VERSION,
        }
    }

    /// A freshly signed credential verifies under the signing CA's key.
    #[test]
    fn sign_then_verify_round_trips() {
        let ca = ca_key(1);
        let now = 1_900_000_000;
        let signed = sample(now).sign(&ca).expect("sign");
        signed
            .verify_against(&ca.verifying_key(), now)
            .expect("valid credential verifies");
    }

    /// Encoding to wire bytes and decoding again keeps the claims equal and
    /// the signature valid.
    #[test]
    fn wire_round_trip_preserves_claims_and_verifies() {
        let ca = ca_key(2);
        let now = 1_900_000_000;
        let signed = sample(now).sign(&ca).expect("sign");
        let wire = signed.to_wire_bytes().expect("wire encode");
        let parsed = SignedCredential::from_wire_bytes(&wire).expect("wire decode");
        assert_eq!(parsed.credential(), signed.credential());
        parsed
            .verify_against(&ca.verifying_key(), now)
            .expect("re-parsed credential still verifies");
    }

    /// The text (header) form carries the prefix and round-trips.
    #[test]
    fn header_value_round_trip() {
        let ca = ca_key(3);
        let now = 1_900_000_000;
        let signed = sample(now).sign(&ca).expect("sign");
        let header = signed.to_header_value().expect("header encode");
        assert!(header.starts_with(CREDENTIAL_PREFIX));
        let parsed = SignedCredential::from_header_value(&header).expect("header decode");
        parsed
            .verify_against(&ca.verifying_key(), now)
            .expect("verifies");
    }

    /// Verifying under a key other than the signer's fails with
    /// `BadSignature`.
    #[test]
    fn wrong_issuer_key_is_rejected() {
        let ca = ca_key(4);
        let attacker = ca_key(5);
        let now = 1_900_000_000;
        let signed = sample(now).sign(&ca).expect("sign");
        assert_eq!(
            signed.verify_against(&attacker.verifying_key(), now),
            Err(CredentialError::BadSignature)
        );
    }

    /// Corrupting one byte of the wire form must never yield a credential
    /// that verifies.
    #[test]
    fn tampered_claims_break_signature() {
        let ca = ca_key(6);
        let now = 1_900_000_000;
        let signed = sample(now).sign(&ca).expect("sign");
        let mut wire = signed.to_wire_bytes().expect("wire");
        // Flip every bit of the byte at offset 10.
        wire[10] ^= 0xFF;
        // Either outcome is acceptable: the corrupted bytes still decode and
        // then fail signature verification, or they fail to decode at all.
        match SignedCredential::from_wire_bytes(&wire) {
            Ok(parsed) => assert_eq!(
                parsed.verify_against(&ca.verifying_key(), now),
                Err(CredentialError::BadSignature)
            ),
            Err(e) => assert_eq!(e, CredentialError::Malformed),
        }
    }

    /// Times before `not_before` and after `not_after` are reported with the
    /// matching error.
    #[test]
    fn not_yet_valid_and_expired_windows() {
        let ca = ca_key(7);
        let now = 1_900_000_000;
        let signed = sample(now).sign(&ca).expect("sign");
        assert_eq!(
            signed.verify_against(&ca.verifying_key(), now - 100),
            Err(CredentialError::NotYetValid)
        );
        assert_eq!(
            signed.verify_against(&ca.verifying_key(), now + 100_000),
            Err(CredentialError::Expired)
        );
    }

    /// A validly signed credential with a version above `CRED_VERSION` is
    /// still refused.
    #[test]
    fn unsupported_future_version_is_refused() {
        let ca = ca_key(8);
        let now = 1_900_000_000;
        let mut cred = sample(now);
        cred.cred_version = CRED_VERSION + 1;
        let signed = cred.sign(&ca).expect("sign");
        assert_eq!(
            signed.verify_against(&ca.verifying_key(), now),
            Err(CredentialError::UnsupportedVersion(CRED_VERSION + 1))
        );
    }

    /// The key parsed from `subject_pubkey` equals the key it was built from.
    #[test]
    fn subject_verifying_key_matches_issued_subject() {
        let now = 1_900_000_000;
        let subj = subject_key(7);
        let cred = sample(now);
        assert_eq!(
            cred.subject_verifying_key().expect("valid point"),
            subj.verifying_key()
        );
    }

    /// A value without any recognisable version marker is `Malformed`.
    #[test]
    fn malformed_header_prefix_is_malformed() {
        assert_eq!(
            SignedCredential::from_header_value("garbage").unwrap_err(),
            CredentialError::Malformed
        );
    }

    /// A `v<N>=` marker other than the supported prefix reports the number.
    #[test]
    fn future_header_version_marker_is_unsupported_version() {
        assert_eq!(
            SignedCredential::from_header_value("v9=AAAA").unwrap_err(),
            CredentialError::UnsupportedVersion(9)
        );
    }

    /// Bytes that are not a valid envelope are `Malformed`.
    #[test]
    fn truncated_wire_is_malformed() {
        assert_eq!(
            SignedCredential::from_wire_bytes(&[0x01, 0x02, 0x03]).unwrap_err(),
            CredentialError::Malformed
        );
    }

    /// Returns (creating it if needed) the scratch directory
    /// `<crate root>/.local-runs/test-tmp` used by the loader tests.
    fn loader_scratch_dir() -> std::path::PathBuf {
        // `CARGO_MANIFEST_DIR` is resolved at compile time by `env!`.
        let mut dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        dir.push(".local-runs");
        dir.push("test-tmp");
        std::fs::create_dir_all(&dir).expect("create scratch dir");
        dir
    }

    /// Builds a scratch file path made unique by `label` plus the current
    /// time in nanoseconds (0 if the clock reads before the Unix epoch).
    fn unique_cred_path(label: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        loader_scratch_dir().join(format!("cred-{label}-{nanos}.cred"))
    }

    /// A credential written to disk (with a trailing newline) loads back
    /// equal and still verifies.
    #[test]
    fn load_from_path_round_trips_a_written_credential() {
        let ca = ca_key(11);
        let now = 1_900_000_000;
        let signed = sample(now).sign(&ca).expect("sign");
        let header = signed.to_header_value().expect("encode");
        let path = unique_cred_path("roundtrip");
        std::fs::write(&path, format!("{header}\n")).expect("write cred file");

        let loaded = SignedCredential::load_from_path(&path)
            .expect("io ok")
            .expect("present");
        assert_eq!(loaded.credential(), signed.credential());
        loaded
            .verify_against(&ca.verifying_key(), now)
            .expect("loaded credential still verifies");
        // Best-effort cleanup; a failure to delete does not fail the test.
        let _ = std::fs::remove_file(&path);
    }

    /// A path that does not exist yields `Ok(None)`, not an error.
    #[test]
    fn load_from_path_missing_file_is_none() {
        let path = unique_cred_path("missing");
        assert!(
            SignedCredential::load_from_path(&path)
                .expect("missing file is not an error")
                .is_none()
        );
    }

    /// A file whose content is not a credential yields an `InvalidData`
    /// I/O error.
    #[test]
    fn load_from_path_malformed_content_is_invalid_data() {
        let path = unique_cred_path("garbage");
        std::fs::write(&path, "not-a-credential").expect("write");
        let err = SignedCredential::load_from_path(&path).expect_err("malformed must error");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        // Best-effort cleanup; a failure to delete does not fail the test.
        let _ = std::fs::remove_file(&path);
    }

    /// With the environment variable removed, `load_from_env` yields
    /// `Ok(None)`.
    #[test]
    fn load_from_env_unset_is_none() {
        // SAFETY: NOTE(review) — `remove_var` is unsafe because changing the
        // process environment can race with other threads that read or write
        // it. This assumes no other test touches the environment while this
        // one runs — confirm.
        unsafe {
            std::env::remove_var(FED_CREDENTIAL_PATH_ENV);
        }
        assert!(
            SignedCredential::load_from_env()
                .expect("unset env is not an error")
                .is_none()
        );
    }
}