use std::fs;
use std::io;
use std::path::{Path, PathBuf};
use anyhow::{Context, Result, anyhow, bail};
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use ed25519_dalek::{SigningKey, VerifyingKey};
use crate::validate;
/// File-name suffix of a public key file: `<agent_id>.pub` inside the key directory.
const PUB_SUFFIX: &str = ".pub";
/// File-name suffix of a private key file: `<agent_id>.priv` inside the key directory.
const PRIV_SUFFIX: &str = ".priv";
/// Exact byte length of a raw public key as stored on disk (no encoding, no header).
const PUBLIC_KEY_LEN: usize = ed25519_dalek::PUBLIC_KEY_LENGTH;
/// Exact byte length of a raw private key as stored on disk (no encoding, no header).
const SECRET_KEY_LEN: usize = ed25519_dalek::SECRET_KEY_LENGTH;
/// An agent's Ed25519 identity: its id, its public key and, when available,
/// its private key.
///
/// NOTE(review): the derived `Debug` relies on `SigningKey`'s own `Debug`
/// impl not printing secret bytes — confirm for the ed25519-dalek version in use.
#[derive(Debug, Clone)]
pub struct AgentKeypair {
    /// Agent identifier. It doubles as the file stem of `<agent_id>.pub` /
    /// `<agent_id>.priv` inside the key directory.
    pub agent_id: String,
    /// The verifying (public) key.
    pub public: VerifyingKey,
    /// The signing (private) key. It is `None` when only the public half is
    /// known, e.g. entries returned by `list`, or a `load` with no `.priv` file.
    pub private: Option<SigningKey>,
}
impl AgentKeypair {
    /// Returns `true` when the private half is present, i.e. this keypair can
    /// produce signatures.
    #[must_use]
    pub fn can_sign(&self) -> bool {
        matches!(self.private, Some(_))
    }

    /// Encodes the raw public key bytes as URL-safe base64 without padding.
    #[must_use]
    pub fn public_base64(&self) -> String {
        let raw = self.public.to_bytes();
        URL_SAFE_NO_PAD.encode(raw)
    }
}
#[cfg(test)]
/// Process-wide mutex that tests hold while mutating `AI_MEMORY_KEY_DIR`.
///
/// Environment variables are shared by all test threads. Every test that
/// sets or clears the override takes this lock so the tests do not race.
pub(crate) fn key_dir_env_lock() -> &'static std::sync::Mutex<()> {
    use std::sync::{Mutex, OnceLock};

    static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    ENV_LOCK.get_or_init(Mutex::default)
}
/// Environment variable that, when set to a non-empty value, overrides the
/// key directory returned by `default_key_dir`.
pub const KEY_DIR_ENV: &str = "AI_MEMORY_KEY_DIR";

/// Returns the key directory override from [`KEY_DIR_ENV`], if one is set.
///
/// An unset or empty variable yields `None`. The value is read as an
/// `OsString`, so a path that is not valid UTF-8 (legal on Unix) is still
/// honoured. It is not silently dropped, which would fall back to the
/// default key directory.
#[must_use]
pub fn key_dir_env_override() -> Option<PathBuf> {
    std::env::var_os(KEY_DIR_ENV)
        .filter(|value| !value.is_empty())
        .map(PathBuf::from)
}
/// Resolves the directory where agent keypairs are stored.
///
/// A non-empty `AI_MEMORY_KEY_DIR` override wins. Otherwise the directory is
/// `<OS config dir>/ai-memory/keys`.
///
/// # Errors
///
/// Fails when no override is set and the OS does not report a config directory.
pub fn default_key_dir() -> Result<PathBuf> {
    match key_dir_env_override() {
        Some(overridden) => Ok(overridden),
        None => dirs::config_dir()
            .map(|base| base.join("ai-memory").join("keys"))
            .ok_or_else(|| anyhow!("OS did not advertise a config directory for key storage")),
    }
}
/// Agent id label used for the daemon's own keypair.
pub const DAEMON_KEYPAIR_LABEL: &str = "daemon";

/// Creates a fresh Ed25519 keypair for `agent_id` using the OS CSPRNG.
///
/// The returned keypair always carries its private key; nothing is written to disk.
///
/// # Errors
///
/// Fails if `agent_id` does not pass `validate::validate_agent_id_shape`.
pub fn generate(agent_id: &str) -> Result<AgentKeypair> {
    validate::validate_agent_id_shape(agent_id)?;
    let signing_key = SigningKey::generate(&mut rand_core::OsRng);
    Ok(AgentKeypair {
        agent_id: agent_id.to_owned(),
        public: signing_key.verifying_key(),
        private: Some(signing_key),
    })
}
pub fn save(keypair: &AgentKeypair, dir: &Path) -> Result<()> {
let private = keypair.private.as_ref().ok_or_else(|| {
anyhow!(
"AgentKeypair for {} has no private key to save",
keypair.agent_id
)
})?;
let pub_path = dir.join(format!("{}{PUB_SUFFIX}", keypair.agent_id));
let priv_path = dir.join(format!("{}{PRIV_SUFFIX}", keypair.agent_id));
ensure_parent(&pub_path)?;
ensure_parent(&priv_path)?;
write_with_mode(&pub_path, &keypair.public.to_bytes(), 0o644)
.with_context(|| format!("writing public key {}", pub_path.display()))?;
write_with_mode(&priv_path, &private.to_bytes(), 0o600)
.with_context(|| format!("writing private key {}", priv_path.display()))?;
Ok(())
}
pub fn save_public_only(keypair: &AgentKeypair, dir: &Path) -> Result<()> {
let pub_path = dir.join(format!("{}{PUB_SUFFIX}", keypair.agent_id));
ensure_parent(&pub_path)?;
write_with_mode(&pub_path, &keypair.public.to_bytes(), 0o644)
.with_context(|| format!("writing public key {}", pub_path.display()))?;
Ok(())
}
/// Loads the keypair for `agent_id` from `dir`.
///
/// `<agent_id>.pub` is required. `<agent_id>.priv` is optional: when it does
/// not exist, the returned keypair has `private: None`.
///
/// # Errors
///
/// Fails if:
/// - `agent_id` fails shape validation;
/// - the public key file cannot be read, is not exactly `PUBLIC_KEY_LEN`
///   bytes, or does not decode as a verifying key;
/// - on Unix, the private key file's mode grants any group/other permission
///   bit;
/// - the private key file exists but cannot be stat'd or read, or is not
///   exactly `SECRET_KEY_LEN` bytes;
/// - the private key does not derive the stored public key.
pub fn load(agent_id: &str, dir: &Path) -> Result<AgentKeypair> {
    validate::validate_agent_id_shape(agent_id)?;
    let pub_path = dir.join(format!("{agent_id}{PUB_SUFFIX}"));
    let priv_path = dir.join(format!("{agent_id}{PRIV_SUFFIX}"));
    let pub_bytes = fs::read(&pub_path)
        .with_context(|| format!("reading public key {}", pub_path.display()))?;
    if pub_bytes.len() != PUBLIC_KEY_LEN {
        bail!(
            "public key {} has {} bytes, expected {PUBLIC_KEY_LEN}",
            pub_path.display(),
            pub_bytes.len()
        );
    }
    let mut pub_arr = [0u8; PUBLIC_KEY_LEN];
    pub_arr.copy_from_slice(&pub_bytes);
    let public = VerifyingKey::from_bytes(&pub_arr)
        .with_context(|| format!("decoding public key {}", pub_path.display()))?;
    // Unix only: refuse a private key that grants any group/other permission
    // bit, and do so *before* reading its contents. `fs::metadata` follows
    // symlinks, so the target file's mode is what gets checked.
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        match fs::metadata(&priv_path) {
            Ok(meta) => {
                let mode = meta.permissions().mode() & 0o777;
                if mode & 0o077 != 0 {
                    bail!(
                        "private key {} has insecure mode {:o}; refusing to load. \
                         Restore with: chmod 0600 {}",
                        priv_path.display(),
                        mode,
                        priv_path.display()
                    );
                }
            }
            Err(e) if e.kind() == io::ErrorKind::NotFound => {
                // No private key: nothing to check. The read below maps
                // NotFound to a public-only keypair.
            }
            Err(e) => {
                return Err(anyhow!(e))
                    .with_context(|| format!("stat private key {}", priv_path.display()));
            }
        }
    }
    let private = match fs::read(&priv_path) {
        Ok(mut priv_bytes) => {
            if priv_bytes.len() != SECRET_KEY_LEN {
                // Capture the length before wiping, because zeroizing the Vec
                // also clears it.
                let actual_len = priv_bytes.len();
                use zeroize::Zeroize;
                priv_bytes.zeroize();
                bail!(
                    "private key {} has {} bytes, expected {SECRET_KEY_LEN}",
                    priv_path.display(),
                    actual_len
                );
            }
            let mut priv_arr = [0u8; SECRET_KEY_LEN];
            priv_arr.copy_from_slice(&priv_bytes);
            let signing = SigningKey::from_bytes(&priv_arr);
            // Wipe both intermediate copies of the secret once the
            // SigningKey has been built.
            {
                use zeroize::Zeroize;
                priv_bytes.zeroize();
                priv_arr.zeroize();
            }
            // Guard against a .priv/.pub pair that belong to different keys.
            if signing.verifying_key().to_bytes() != public.to_bytes() {
                bail!(
                    "private key {} does not match public key {}",
                    priv_path.display(),
                    pub_path.display()
                );
            }
            Some(signing)
        }
        Err(e) if e.kind() == io::ErrorKind::NotFound => None,
        Err(e) => {
            return Err(anyhow!(e))
                .with_context(|| format!("reading private key {}", priv_path.display()));
        }
    };
    Ok(AgentKeypair {
        agent_id: agent_id.to_string(),
        public,
        private,
    })
}
/// Enumerates every public key stored under `dir`, sorted by agent id.
///
/// `save` maps a slashed agent id such as `a/b/c` to the nested file
/// `a/b/c.pub`, so subdirectories are walked too. Each nested file's agent id
/// is its path relative to `dir` (joined with `/`) minus the `.pub` suffix.
/// Symlinked directories are not descended into, so link cycles cannot cause
/// unbounded recursion.
///
/// Listing is best-effort per entry. Non-UTF-8 names, stems that fail agent-id
/// validation, unreadable files or subdirectories, wrong-length files, and
/// bytes that are not a valid public key are skipped silently. Private keys
/// are never loaded, so every returned keypair has `private: None`.
///
/// # Errors
///
/// Returns `Ok(vec![])` when `dir` does not exist. Fails if `dir` exists but
/// cannot be read as a directory, or if iterating a directory yields an I/O
/// error.
pub fn list(dir: &Path) -> Result<Vec<AgentKeypair>> {
    if !dir.exists() {
        return Ok(Vec::new());
    }
    let entries =
        fs::read_dir(dir).with_context(|| format!("reading key directory {}", dir.display()))?;
    let mut out = Vec::new();
    collect_public_keys(entries, None, &mut out)?;
    out.sort_by(|a, b| a.agent_id.cmp(&b.agent_id));
    Ok(out)
}

/// Collects public keys from one directory level into `out`, recursing into
/// real subdirectories. `prefix` is the `/`-joined path of this level relative
/// to the key directory root (`None` at the root).
fn collect_public_keys(
    entries: fs::ReadDir,
    prefix: Option<&str>,
    out: &mut Vec<AgentKeypair>,
) -> Result<()> {
    for entry in entries {
        let entry = entry?;
        let name = entry.file_name();
        let Some(name_str) = name.to_str() else {
            continue;
        };
        let relative = match prefix {
            Some(parent) => format!("{parent}/{name_str}"),
            None => name_str.to_string(),
        };
        let path = entry.path();
        // `DirEntry::file_type` does not traverse symlinks, so only real
        // directories are descended into.
        if matches!(entry.file_type(), Ok(kind) if kind.is_dir()) {
            if let Ok(children) = fs::read_dir(&path) {
                collect_public_keys(children, Some(&relative), out)?;
            }
            continue;
        }
        let Some(stem) = relative.strip_suffix(PUB_SUFFIX) else {
            continue;
        };
        if validate::validate_agent_id_shape(stem).is_err() {
            continue;
        }
        let Ok(pub_bytes) = fs::read(&path) else {
            continue;
        };
        let Ok(pub_arr) = <[u8; PUBLIC_KEY_LEN]>::try_from(pub_bytes.as_slice()) else {
            continue;
        };
        let Ok(public) = VerifyingKey::from_bytes(&pub_arr) else {
            continue;
        };
        out.push(AgentKeypair {
            agent_id: stem.to_string(),
            public,
            private: None,
        });
    }
    Ok(())
}
pub fn decode_public_base64(s: &str) -> Result<VerifyingKey> {
let trimmed = s.trim();
let bytes = URL_SAFE_NO_PAD
.decode(trimmed)
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(trimmed))
.with_context(|| "decoding base64 public key".to_string())?;
if bytes.len() != PUBLIC_KEY_LEN {
bail!(
"decoded public key has {} bytes, expected {PUBLIC_KEY_LEN}",
bytes.len()
);
}
let mut arr = [0u8; PUBLIC_KEY_LEN];
arr.copy_from_slice(&bytes);
VerifyingKey::from_bytes(&arr).with_context(|| "decoding public key bytes".to_string())
}
pub fn read_raw_key_file(path: &Path) -> Result<[u8; SECRET_KEY_LEN]> {
let bytes = fs::read(path).with_context(|| format!("reading key file {}", path.display()))?;
if bytes.len() != SECRET_KEY_LEN {
bail!(
"key file {} has {} bytes, expected {SECRET_KEY_LEN}",
path.display(),
bytes.len()
);
}
let mut arr = [0u8; SECRET_KEY_LEN];
arr.copy_from_slice(&bytes);
Ok(arr)
}
/// What `ensure_keypair` did.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EnsureOutcome {
    /// A public key file already existed at `pub_path`, so nothing was written.
    AlreadyExists {
        pub_path: PathBuf,
    },
    /// A new keypair was generated and saved; `pub_path` is its public key file.
    Generated {
        pub_path: PathBuf,
    },
    /// Auto-generation was disabled by the caller. Nothing was validated,
    /// checked, or written.
    SkippedDisabled,
}
/// Makes sure a keypair for `agent_id` exists under `dir`, generating one if needed.
///
/// When `disabled` is true this only logs and returns
/// [`EnsureOutcome::SkippedDisabled`], without validating `agent_id` or
/// touching the filesystem. Otherwise the presence of `<agent_id>.pub` alone
/// decides whether a new keypair is generated and saved.
///
/// # Errors
///
/// Fails if `agent_id` fails shape validation, or if generating or saving the
/// new keypair fails.
pub fn ensure_keypair(agent_id: &str, dir: &Path, disabled: bool) -> Result<EnsureOutcome> {
    if disabled {
        tracing::info!(
            "identity: auto-gen disabled by config; link signing will be skipped at boot"
        );
        return Ok(EnsureOutcome::SkippedDisabled);
    }
    validate::validate_agent_id_shape(agent_id)?;
    let pub_path = dir.join(format!("{agent_id}{PUB_SUFFIX}"));
    if !pub_path.exists() {
        save(&generate(agent_id)?, dir)?;
        tracing::info!(
            "auto-generated identity keypair at {} — consider backing up",
            pub_path.display()
        );
        return Ok(EnsureOutcome::Generated { pub_path });
    }
    Ok(EnsureOutcome::AlreadyExists { pub_path })
}
/// Creates every missing ancestor directory of `path`.
///
/// This is a no-op when `path` has no parent.
///
/// # Errors
///
/// Fails, with the directory named in the context, if creation fails.
fn ensure_parent(path: &Path) -> Result<()> {
    let Some(parent) = path.parent() else {
        return Ok(());
    };
    fs::create_dir_all(parent)
        .with_context(|| format!("creating key directory {}", parent.display()))
}
#[cfg(unix)]
/// Writes `bytes` to a freshly created file at `path` with permission bits `mode`.
///
/// Any existing file is removed first, so that `create_new` allocates a new
/// inode and the requested mode applies. Permissions of a pre-existing file
/// are never inherited. The process umask can only clear bits from `mode`.
/// The data is fsync'd before returning.
///
/// # Errors
///
/// Fails if an existing file cannot be removed (any error other than
/// NotFound is reported as-is, not masked as a later `AlreadyExists`), or if
/// creating, writing or syncing the new file fails. After a failed write or
/// sync, a best-effort attempt removes the partial file, so that a truncated
/// key is not left behind.
fn write_with_mode(path: &Path, bytes: &[u8], mode: u32) -> io::Result<()> {
    use std::io::Write;
    use std::os::unix::fs::OpenOptionsExt;

    match fs::remove_file(path) {
        Ok(()) => {}
        Err(e) if e.kind() == io::ErrorKind::NotFound => {}
        Err(e) => return Err(e),
    }
    let mut file = fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .mode(mode)
        .open(path)?;
    let written = file.write_all(bytes).and_then(|()| file.sync_all());
    if written.is_err() {
        drop(file);
        // Best-effort cleanup; the original write error is what gets reported.
        let _ = fs::remove_file(path);
    }
    written
}
#[cfg(not(unix))]
/// Writes `bytes` to `path` on platforms without Unix permission bits.
///
/// `_mode` cannot be applied here. A one-time warning tells the operator that
/// key files inherit the parent directory's ACL.
fn write_with_mode(path: &Path, bytes: &[u8], _mode: u32) -> io::Result<()> {
    /// Emits the "mode bits are not applied" warning at most once per process.
    fn warn_mode_bits_unsupported() {
        static WARNED: std::sync::Once = std::sync::Once::new();
        WARNED.call_once(|| {
            tracing::warn!(
                target: "identity::keypair",
                "writing key material on a non-Unix platform: restrictive file-mode \
                 bits are not applied, so the key file inherits the parent directory \
                 ACL. Restrict the key directory's ACL manually, or use hardware-backed \
                 key storage, to protect private keys."
            );
        });
    }

    warn_mode_bits_unsupported();
    fs::write(path, bytes)
}
// Unit tests for keypair generation, on-disk persistence (including nested
// slashed agent ids and Unix file modes), listing, base64 import/export, and
// the `ensure_keypair` boot helper. Every test works in its own temporary
// directory. Tests that mutate `AI_MEMORY_KEY_DIR` serialise through
// `key_dir_env_lock()`.
#[cfg(test)]
mod tests {
    use super::*;
    use ed25519_dalek::Signer;
    use ed25519_dalek::Verifier;
    use tempfile::TempDir;
    // Fresh, auto-deleted directory for a single test.
    fn tmp_dir() -> TempDir {
        TempDir::new().expect("tempdir")
    }
    #[test]
    fn generate_yields_signing_keypair() {
        let kp = generate("alice").expect("generate");
        assert_eq!(kp.agent_id, "alice");
        assert!(
            kp.can_sign(),
            "freshly generated keypair must have private key"
        );
        let priv_pub = kp.private.as_ref().unwrap().verifying_key().to_bytes();
        assert_eq!(priv_pub, kp.public.to_bytes());
    }
    #[test]
    fn generate_rejects_invalid_agent_id() {
        assert!(generate("has space").is_err());
        assert!(generate("has\0null").is_err());
    }
    #[test]
    fn round_trip_save_then_load() {
        let dir = tmp_dir();
        let kp = generate("alice").unwrap();
        save(&kp, dir.path()).expect("save");
        let loaded = load("alice", dir.path()).expect("load");
        assert_eq!(loaded.agent_id, "alice");
        assert_eq!(loaded.public.to_bytes(), kp.public.to_bytes());
        assert!(loaded.can_sign(), "private key should round-trip");
        // A signature made with the loaded key must verify under the original public key.
        let msg = b"hello world";
        let sig = loaded.private.as_ref().unwrap().sign(msg);
        assert!(kp.public.verify(msg, &sig).is_ok());
    }
    #[test]
    fn round_trip_save_then_load_slashed_agent_id() {
        let dir = tmp_dir();
        // Slashed ids map to nested directories under the key dir.
        let agent_id = "hive-1461/nyc3/hive-peer-nyc3-01";
        let kp = generate(agent_id).expect("generate slashed id");
        save(&kp, dir.path()).expect("save slashed id must create nested parents");
        let pub_path = dir.path().join(format!("{agent_id}.pub"));
        let priv_path = dir.path().join(format!("{agent_id}.priv"));
        assert!(pub_path.exists(), "nested .pub must exist at {pub_path:?}");
        assert!(
            priv_path.exists(),
            "nested .priv must exist at {priv_path:?}"
        );
        let loaded = load(agent_id, dir.path()).expect("load slashed id");
        assert_eq!(loaded.agent_id, agent_id);
        assert_eq!(loaded.public.to_bytes(), kp.public.to_bytes());
        assert!(loaded.can_sign(), "private key should round-trip");
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let pub_mode = fs::metadata(&pub_path).unwrap().permissions().mode() & 0o777;
            let priv_mode = fs::metadata(&priv_path).unwrap().permissions().mode() & 0o777;
            assert_eq!(pub_mode, 0o644, "nested public key must be 0644");
            assert_eq!(priv_mode, 0o600, "nested private key must be 0600");
        }
    }
    #[test]
    fn save_public_only_slashed_agent_id_creates_nested_parent() {
        let dir = tmp_dir();
        let agent_id = "hive-1461/sfo2/hive-peer-sfo2-01";
        let kp = generate(agent_id).expect("generate");
        save_public_only(&kp, dir.path()).expect("save_public_only nested");
        let pub_path = dir.path().join(format!("{agent_id}.pub"));
        assert!(pub_path.exists(), "nested .pub must exist at {pub_path:?}");
        let loaded = load(agent_id, dir.path()).expect("load");
        assert!(!loaded.can_sign(), "public-only save must yield no private");
        assert_eq!(loaded.public.to_bytes(), kp.public.to_bytes());
    }
    #[test]
    fn load_without_private_yields_public_only() {
        let dir = tmp_dir();
        let kp = generate("alice").unwrap();
        save(&kp, dir.path()).expect("save");
        let priv_path = dir.path().join("alice.priv");
        fs::remove_file(&priv_path).expect("rm priv");
        let loaded = load("alice", dir.path()).expect("load");
        assert!(!loaded.can_sign(), "missing .priv must yield None private");
        assert_eq!(loaded.public.to_bytes(), kp.public.to_bytes());
    }
    // NOTE(review): the exact 0644 expectation assumes a umask that does not
    // clear group/other read bits (e.g. the common 022).
    #[cfg(unix)]
    #[test]
    fn save_writes_unix_mode_0600_and_0644() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tmp_dir();
        let kp = generate("alice").unwrap();
        save(&kp, dir.path()).expect("save");
        let pub_meta = fs::metadata(dir.path().join("alice.pub")).unwrap();
        let priv_meta = fs::metadata(dir.path().join("alice.priv")).unwrap();
        let pub_mode = pub_meta.permissions().mode() & 0o777;
        let priv_mode = priv_meta.permissions().mode() & 0o777;
        assert_eq!(
            priv_mode, 0o600,
            "private key must be 0600, got {priv_mode:o}"
        );
        assert_eq!(pub_mode, 0o644, "public key must be 0644, got {pub_mode:o}");
    }
    #[test]
    fn list_enumerates_saved_keypairs() {
        let dir = tmp_dir();
        let alice = generate("alice").unwrap();
        let bob = generate("bob").unwrap();
        save(&alice, dir.path()).unwrap();
        save(&bob, dir.path()).unwrap();
        let listed = list(dir.path()).expect("list");
        assert_eq!(listed.len(), 2);
        assert_eq!(listed[0].agent_id, "alice");
        assert_eq!(listed[1].agent_id, "bob");
        for kp in &listed {
            assert!(!kp.can_sign(), "list must not load private keys");
        }
        assert_eq!(listed[0].public.to_bytes(), alice.public.to_bytes());
        assert_eq!(listed[1].public.to_bytes(), bob.public.to_bytes());
    }
    #[test]
    fn list_on_missing_dir_returns_empty() {
        let dir = tmp_dir();
        let nonexistent = dir.path().join("does-not-exist");
        let listed = list(&nonexistent).expect("list");
        assert!(listed.is_empty());
    }
    #[test]
    fn list_skips_unrelated_files() {
        let dir = tmp_dir();
        let kp = generate("alice").unwrap();
        save(&kp, dir.path()).unwrap();
        // A non-.pub file and a .pub file of the wrong length must both be ignored.
        fs::write(dir.path().join("README.txt"), b"ignore me").unwrap();
        fs::write(dir.path().join("not-a-key.pub"), b"too short").unwrap();
        let listed = list(dir.path()).expect("list");
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].agent_id, "alice");
    }
    #[test]
    fn load_rejects_truncated_public_key() {
        let dir = tmp_dir();
        fs::write(dir.path().join("alice.pub"), b"short").unwrap();
        let err = load("alice", dir.path()).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("expected 32"), "got: {msg}");
    }
    #[test]
    fn load_rejects_priv_pub_mismatch() {
        let dir = tmp_dir();
        let alice = generate("alice").unwrap();
        // A second, independent keypair under the same id; called `bob` to
        // mark it as the foreign key whose private half gets swapped in.
        let bob = generate("alice").unwrap();
        save(&alice, dir.path()).unwrap();
        fs::remove_file(dir.path().join("alice.priv")).unwrap();
        let bob_priv = bob.private.as_ref().unwrap().to_bytes();
        write_with_mode(&dir.path().join("alice.priv"), &bob_priv, 0o600).unwrap();
        let err = load("alice", dir.path()).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("does not match"), "got: {msg}");
    }
    #[test]
    fn export_pub_round_trips_through_base64() {
        let kp = generate("alice").unwrap();
        let b64 = kp.public_base64();
        let decoded = decode_public_base64(&b64).expect("decode");
        assert_eq!(decoded.to_bytes(), kp.public.to_bytes());
    }
    #[test]
    fn decode_public_base64_accepts_padded_form() {
        let kp = generate("alice").unwrap();
        let padded = base64::engine::general_purpose::STANDARD.encode(kp.public.to_bytes());
        let decoded = decode_public_base64(&padded).expect("decode padded");
        assert_eq!(decoded.to_bytes(), kp.public.to_bytes());
    }
    #[test]
    fn read_raw_key_file_validates_length() {
        let dir = tmp_dir();
        let p = dir.path().join("short.bin");
        fs::write(&p, b"short").unwrap();
        let err = read_raw_key_file(&p).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("expected 32"), "got: {msg}");
    }
    #[test]
    fn save_refuses_public_only_keypair() {
        let dir = tmp_dir();
        let kp = AgentKeypair {
            agent_id: "alice".to_string(),
            public: generate("alice").unwrap().public,
            private: None,
        };
        let err = save(&kp, dir.path()).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("no private key to save"), "got: {msg}");
    }
    #[test]
    fn save_public_only_writes_pub_only() {
        let dir = tmp_dir();
        let kp = generate("alice").unwrap();
        let pub_only = AgentKeypair {
            agent_id: "alice".to_string(),
            public: kp.public,
            private: None,
        };
        save_public_only(&pub_only, dir.path()).expect("save_public_only");
        assert!(dir.path().join("alice.pub").exists());
        assert!(!dir.path().join("alice.priv").exists());
        let loaded = load("alice", dir.path()).expect("load");
        assert!(!loaded.can_sign());
    }
    #[test]
    fn default_key_dir_ends_in_ai_memory_keys() {
        // Recover from a poisoned lock so that one failed env test does not cascade.
        let _g = key_dir_env_lock().lock().unwrap_or_else(|e| e.into_inner());
        // SAFETY: every test that touches AI_MEMORY_KEY_DIR holds
        // key_dir_env_lock(), so no other test thread mutates it concurrently.
        unsafe {
            std::env::remove_var("AI_MEMORY_KEY_DIR");
        }
        let p = default_key_dir().expect("default dir");
        let s = p.to_string_lossy();
        assert!(s.ends_with("ai-memory/keys") || s.ends_with("ai-memory\\keys"));
    }
    // Local alias for the crate-level test lock (shadows the glob-imported name).
    fn key_dir_env_lock() -> &'static std::sync::Mutex<()> {
        super::key_dir_env_lock()
    }
    #[test]
    fn ensure_keypair_generates_when_missing() {
        let dir = tmp_dir();
        let outcome = ensure_keypair("alice", dir.path(), false).expect("ensure");
        match outcome {
            EnsureOutcome::Generated { pub_path } => {
                assert!(pub_path.exists(), "pub key must be on disk");
                let priv_path = dir.path().join("alice.priv");
                assert!(priv_path.exists(), "priv key must be on disk");
            }
            other => panic!("expected Generated, got {other:?}"),
        }
    }
    #[test]
    fn ensure_keypair_idempotent_on_second_call() {
        let dir = tmp_dir();
        let first = ensure_keypair("alice", dir.path(), false).expect("first");
        let pub_path = dir.path().join("alice.pub");
        let priv_path = dir.path().join("alice.priv");
        let pub_before = fs::read(&pub_path).unwrap();
        let priv_before = fs::read(&priv_path).unwrap();
        let second = ensure_keypair("alice", dir.path(), false).expect("second");
        match second {
            EnsureOutcome::AlreadyExists { pub_path: observed } => {
                assert_eq!(observed, pub_path);
            }
            other => panic!("expected AlreadyExists on second call, got {other:?}"),
        }
        // The second call must not have rewritten either key file.
        let pub_after = fs::read(&pub_path).unwrap();
        let priv_after = fs::read(&priv_path).unwrap();
        assert_eq!(pub_before, pub_after);
        assert_eq!(priv_before, priv_after);
        assert!(matches!(first, EnsureOutcome::Generated { .. }));
    }
    #[test]
    fn ensure_keypair_respects_disabled_flag() {
        let dir = tmp_dir();
        let outcome = ensure_keypair("alice", dir.path(), true).expect("ensure");
        assert_eq!(outcome, EnsureOutcome::SkippedDisabled);
        assert!(!dir.path().join("alice.pub").exists());
        assert!(!dir.path().join("alice.priv").exists());
    }
    #[test]
    fn ensure_keypair_validates_agent_id() {
        let dir = tmp_dir();
        let res = ensure_keypair("has space", dir.path(), false);
        assert!(res.is_err(), "must reject invalid agent_id");
    }
    #[test]
    fn save_returns_context_when_dir_is_a_file() {
        let dir = tmp_dir();
        // A regular file used as a path component makes create_dir_all fail.
        let blocker = dir.path().join("blocker");
        fs::write(&blocker, b"file").unwrap();
        let kp = generate("alice").unwrap();
        let sub = blocker.join("sub");
        let err = save(&kp, &sub).unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("creating key directory"),
            "expected wrapped context, got: {msg}"
        );
    }
    #[test]
    fn save_public_only_returns_context_when_dir_is_a_file() {
        let dir = tmp_dir();
        let blocker = dir.path().join("blocker");
        fs::write(&blocker, b"file").unwrap();
        let kp = generate("alice").unwrap();
        let sub = blocker.join("sub");
        let err = save_public_only(&kp, &sub).unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("creating key directory"),
            "expected wrapped context, got: {msg}"
        );
    }
    #[test]
    fn load_returns_context_when_pub_file_missing() {
        let dir = tmp_dir();
        let err = load("alice", dir.path()).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("reading public key"), "got: {msg}");
    }
    #[test]
    fn load_returns_decode_context_for_corrupt_public_key() {
        let dir = tmp_dir();
        let bytes = [0xFFu8; PUBLIC_KEY_LEN];
        fs::write(dir.path().join("alice.pub"), bytes).unwrap();
        let res = load("alice", dir.path());
        // Whether all-0xFF bytes decode as a curve point depends on the curve
        // library, so the error wording is only checked when decoding fails.
        if let Err(err) = res {
            let msg = format!("{err:#}");
            assert!(
                msg.contains("decoding public key") || msg.contains("expected"),
                "got: {msg}"
            );
        } else {
        }
    }
    #[test]
    fn load_with_truncated_priv_returns_length_error() {
        let dir = tmp_dir();
        let kp = generate("alice").unwrap();
        save(&kp, dir.path()).unwrap();
        // fs::write keeps the existing 0600 mode, so the mode check passes and
        // the length check is what fails.
        fs::write(dir.path().join("alice.priv"), b"shortie!").unwrap();
        let err = load("alice", dir.path()).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("expected 32"), "got: {msg}");
    }
    #[test]
    fn list_returns_context_on_unreadable_directory() {
        let dir = tmp_dir();
        let file = dir.path().join("not-a-dir");
        fs::write(&file, b"x").unwrap();
        let err = list(&file).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("reading key directory"), "got: {msg}");
    }
    #[test]
    fn decode_public_base64_rejects_garbage() {
        let err = decode_public_base64("not-valid-base64!!!").unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("decoding base64"), "got: {msg}");
    }
    #[test]
    fn decode_public_base64_rejects_wrong_length() {
        let short = URL_SAFE_NO_PAD.encode([0u8; 8]);
        let err = decode_public_base64(&short).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("expected 32"), "got: {msg}");
    }
    #[test]
    fn read_raw_key_file_returns_context_when_path_missing() {
        let dir = tmp_dir();
        let missing = dir.path().join("nope.bin");
        let err = read_raw_key_file(&missing).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("reading key file"), "got: {msg}");
    }
    // NOTE(review): relies on validate_agent_id_shape's error text containing
    // "invalid character"; confirm against the validate module.
    #[test]
    fn ensure_keypair_rejects_invalid_agent_id_when_enabled() {
        let dir = tmp_dir();
        let err = ensure_keypair("has space", dir.path(), false).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("invalid character"), "got: {msg}");
    }
    #[test]
    fn list_skips_pub_file_with_invalid_agent_id_stem() {
        let dir = tmp_dir();
        let kp = generate("alice").unwrap();
        save(&kp, dir.path()).unwrap();
        fs::write(dir.path().join("has space.pub"), [0u8; PUBLIC_KEY_LEN]).unwrap();
        let listed = list(dir.path()).expect("list");
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].agent_id, "alice");
    }
    #[cfg(unix)]
    #[test]
    fn list_skips_unreadable_pub_file_continues_iteration() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tmp_dir();
        let alice = generate("alice").unwrap();
        save(&alice, dir.path()).unwrap();
        let unreadable = dir.path().join("bob.pub");
        fs::write(&unreadable, [0u8; PUBLIC_KEY_LEN]).unwrap();
        fs::set_permissions(&unreadable, fs::Permissions::from_mode(0o000)).unwrap();
        let listed = list(dir.path()).expect("list");
        // Restore permissions before asserting so TempDir cleanup always succeeds.
        fs::set_permissions(&unreadable, fs::Permissions::from_mode(0o644)).unwrap();
        assert!(listed.iter().any(|k| k.agent_id == "alice"));
    }
    #[test]
    fn list_skips_pub_file_with_invalid_curve_point() {
        let dir = tmp_dir();
        let alice = generate("alice").unwrap();
        save(&alice, dir.path()).unwrap();
        // Search for a 32-byte value that VerifyingKey::from_bytes rejects. If
        // none is found the assertions are skipped rather than failing.
        let mut bogus: Option<[u8; PUBLIC_KEY_LEN]> = None;
        for seed in 0u8..=255 {
            let mut bytes = [seed; PUBLIC_KEY_LEN];
            bytes[31] = 0xFF;
            if VerifyingKey::from_bytes(&bytes).is_err() {
                bogus = Some(bytes);
                break;
            }
        }
        if let Some(b) = bogus {
            fs::write(dir.path().join("bogus.pub"), b).unwrap();
            let listed = list(dir.path()).expect("list");
            assert!(
                listed.iter().any(|k| k.agent_id == "alice"),
                "alice must survive a sibling invalid-curve-point .pub file"
            );
            assert!(
                !listed.iter().any(|k| k.agent_id == "bogus"),
                "bogus.pub with invalid curve point must be filtered out"
            );
        }
    }
    #[cfg(unix)]
    #[test]
    fn load_propagates_non_notfound_io_error_on_private_key() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tmp_dir();
        let kp = generate("alice").unwrap();
        save(&kp, dir.path()).unwrap();
        let priv_path = dir.path().join("alice.priv");
        // Mode 0o000 passes the group/other check but makes the read fail with
        // PermissionDenied, unless the test runs with root privileges. In that
        // case load succeeds and nothing is asserted.
        fs::set_permissions(&priv_path, fs::Permissions::from_mode(0o000)).unwrap();
        let res = load("alice", dir.path());
        fs::set_permissions(&priv_path, fs::Permissions::from_mode(0o600)).unwrap();
        if let Err(err) = res {
            let msg = format!("{err:#}");
            assert!(msg.contains("reading private key"), "got: {msg}");
        }
    }
    #[cfg(unix)]
    #[test]
    fn ensure_keypair_save_failure_propagates_context() {
        let dir = tmp_dir();
        let blocker = dir.path().join("blocker");
        fs::write(&blocker, b"file").unwrap();
        let sub = blocker.join("sub");
        let res = ensure_keypair("alice", &sub, false);
        assert!(res.is_err(), "save under a file-blocked dir must fail");
    }
    #[test]
    fn default_key_dir_honours_env_override() {
        let _g = key_dir_env_lock().lock().unwrap_or_else(|e| e.into_inner());
        let override_path = std::env::temp_dir().join("ai-memory-key-dir-override-probe");
        // SAFETY: key_dir_env_lock() is held, so no other test thread reads or
        // writes AI_MEMORY_KEY_DIR concurrently.
        unsafe {
            std::env::set_var("AI_MEMORY_KEY_DIR", &override_path);
        }
        let p = default_key_dir().expect("default dir");
        assert_eq!(p, override_path);
        // SAFETY: same lock as above is still held.
        unsafe {
            std::env::remove_var("AI_MEMORY_KEY_DIR");
        }
    }
    #[cfg(unix)]
    #[test]
    fn test_keypair_load_refuses_world_readable_priv() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tmp_dir();
        let kp = generate("alice").unwrap();
        save(&kp, dir.path()).unwrap();
        let priv_path = dir.path().join("alice.priv");
        fs::set_permissions(&priv_path, fs::Permissions::from_mode(0o777)).unwrap();
        let err = load("alice", dir.path()).unwrap_err();
        fs::set_permissions(&priv_path, fs::Permissions::from_mode(0o600)).unwrap();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("insecure mode"),
            "error must name the failure mode, got: {msg}"
        );
        assert!(
            msg.contains("chmod 0600"),
            "error must include the fix invocation, got: {msg}"
        );
    }
    #[cfg(unix)]
    #[test]
    fn test_keypair_load_refuses_group_readable_priv() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tmp_dir();
        let kp = generate("alice").unwrap();
        save(&kp, dir.path()).unwrap();
        let priv_path = dir.path().join("alice.priv");
        fs::set_permissions(&priv_path, fs::Permissions::from_mode(0o640)).unwrap();
        let err = load("alice", dir.path()).unwrap_err();
        fs::set_permissions(&priv_path, fs::Permissions::from_mode(0o600)).unwrap();
        let msg = format!("{err:#}");
        assert!(msg.contains("insecure mode"), "got: {msg}");
    }
    #[cfg(unix)]
    #[test]
    fn test_keypair_load_accepts_0600() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tmp_dir();
        let kp = generate("alice").unwrap();
        save(&kp, dir.path()).unwrap();
        let priv_path = dir.path().join("alice.priv");
        let mode = fs::metadata(&priv_path).unwrap().permissions().mode() & 0o777;
        assert_eq!(mode, 0o600, "save must write 0600, got {mode:o}");
        let loaded = load("alice", dir.path()).expect("0600 must load");
        assert!(loaded.can_sign(), "0600 mode must yield a signing keypair");
    }
    #[cfg(unix)]
    #[test]
    fn test_keypair_load_missing_priv_skips_mode_check() {
        let dir = tmp_dir();
        let kp = generate("alice").unwrap();
        save(&kp, dir.path()).unwrap();
        fs::remove_file(dir.path().join("alice.priv")).unwrap();
        let loaded = load("alice", dir.path()).expect("public-only load must succeed");
        assert!(!loaded.can_sign());
    }
}