use std::{collections::BTreeMap, path::PathBuf};
use anyhow::{Context, Result};
use base64::{engine::general_purpose::STANDARD, Engine as _};
use chrono::{DateTime, Utc};
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use rand::{rngs::OsRng, RngCore};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use zeroize::Zeroize;
use crate::{fs_util, logs};
/// Registry of agent identities, grouped per project.
///
/// This is the value read by `load_agents` and written by `save_agents`; it is
/// (de)serialized through serde with camelCase field names.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct AgentState {
/// Identities keyed by the project string passed to `ensure_agent`.
/// A document without a `projects` key deserializes to an empty map.
#[serde(default)]
pub projects: BTreeMap<String, Vec<AgentIdentity>>,
}
/// A single agent's Ed25519 key material plus bookkeeping timestamps.
///
/// (De)serialized through serde with camelCase field names.
///
/// NOTE(review): this type derives `Debug` and `Serialize`, so `private_seed`
/// is included in `{:?}` output and in the serialized form — avoid logging
/// values of this type.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct AgentIdentity {
/// Name the agent is looked up by within its project.
pub agent_name: String,
/// Identifier produced by `key_id`: `agent:` followed by the hex-encoded
/// SHA-256 digest of the raw public key bytes.
pub agent_key_id: String,
/// Standard-alphabet base64 encoding of the 32-byte verifying (public) key.
pub public_key: String,
/// Standard-alphabet base64 encoding of the 32-byte signing seed (secret).
pub private_seed: String,
/// Time at which `ensure_agent` created this identity.
pub created_at: DateTime<Utc>,
/// Time of the most recent `ensure_agent` call that returned this identity.
pub last_used: DateTime<Utc>,
}
/// A signed payload together with the identity that signed it.
///
/// Produced by `sign_payload` and checked by `verify_proof`; (de)serialized
/// through serde with camelCase field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AgentProof {
/// Name of the signing agent.
pub agent_name: String,
/// Key identifier of the signing agent (see `AgentIdentity::agent_key_id`).
pub agent_key_id: String,
/// The exact text whose UTF-8 bytes were signed.
pub payload: String,
/// Standard-alphabet base64 encoding of the Ed25519 signature bytes.
pub signature: String,
}
/// Returns the path of the agent registry file, `agents.json`, directly inside
/// the directory reported by `logs::ward_home()`.
pub fn agents_path() -> PathBuf {
    let home = logs::ward_home();
    home.join("agents.json")
}
/// Returns the identity for `agent_name` within `project`, creating it when absent.
///
/// An existing identity has its `last_used` timestamp refreshed. A new identity
/// receives a fresh Ed25519 key pair derived from 32 bytes of `OsRng` output,
/// with `created_at` and `last_used` set to the same instant. In both cases the
/// whole agent state is written back through [`save_agents`] before returning.
///
/// # Errors
///
/// Returns an error if the agent state cannot be loaded or saved.
pub fn ensure_agent(project: &str, agent_name: &str) -> Result<AgentIdentity> {
    let mut state = load_agents()?;
    let agents = state.projects.entry(project.to_string()).or_default();

    if let Some(existing) = agents
        .iter_mut()
        .find(|candidate| candidate.agent_name == agent_name)
    {
        existing.last_used = Utc::now();
        // Clone before saving so the mutable borrow of `state` has ended.
        let snapshot = existing.clone();
        save_agents(&state)?;
        return Ok(snapshot);
    }

    let mut seed = [0_u8; 32];
    OsRng.fill_bytes(&mut seed);
    let public_key = SigningKey::from_bytes(&seed).verifying_key().to_bytes();
    let private_seed = STANDARD.encode(seed);
    // The raw seed is no longer needed once it has been encoded; wipe it.
    seed.zeroize();

    // One timestamp for both fields so a brand-new identity reports
    // `created_at == last_used` instead of two slightly different instants.
    let now = Utc::now();
    let agent = AgentIdentity {
        agent_name: agent_name.to_string(),
        agent_key_id: key_id(&public_key),
        public_key: STANDARD.encode(public_key),
        private_seed,
        created_at: now,
        last_used: now,
    };
    agents.push(agent.clone());
    save_agents(&state)?;
    Ok(agent)
}
/// Signs `payload` with the Ed25519 key of `agent_name` in `project`.
///
/// The identity is obtained through [`ensure_agent`], so a missing agent is
/// created and an existing one has its `last_used` timestamp refreshed before
/// signing. The returned proof carries the payload verbatim and the signature
/// encoded as standard base64.
///
/// # Errors
///
/// Returns an error if the identity cannot be loaded or saved, or if the
/// stored private seed is not valid base64 or does not decode to exactly
/// 32 bytes.
pub fn sign_payload(project: &str, agent_name: &str, payload: &str) -> Result<AgentProof> {
    let agent = ensure_agent(project, agent_name)?;
    let mut seed = STANDARD
        .decode(&agent.private_seed)
        .context("agent private seed is not valid base64")?;

    // Copy into the fixed-size array first, then wipe the heap buffer
    // unconditionally so the decoded seed is cleared even when the length
    // check below fails and we return early.
    let seed_array = <[u8; 32]>::try_from(seed.as_slice());
    seed.zeroize();
    let mut seed_array =
        seed_array.map_err(|_| anyhow::anyhow!("agent private seed has invalid length"))?;

    let signing_key = SigningKey::from_bytes(&seed_array);
    // `seed_array` is an independent stack copy of the secret; wipe it as soon
    // as the signing key has been constructed from it.
    seed_array.zeroize();

    let signature = signing_key.sign(payload.as_bytes());
    Ok(AgentProof {
        agent_name: agent.agent_name,
        agent_key_id: agent.agent_key_id,
        payload: payload.to_string(),
        signature: STANDARD.encode(signature.to_bytes()),
    })
}
/// Checks `proof` against the identity stored for `project`.
///
/// Returns `Ok(false)` when no stored identity in `project` matches both the
/// proof's agent name and key id, or when the signature does not verify for
/// the payload. Returns `Ok(true)` only when the signature verifies under the
/// stored public key.
///
/// # Errors
///
/// Returns an error if the agent state cannot be loaded, if the stored public
/// key is not valid base64, is not 32 bytes long, or is rejected by
/// `VerifyingKey::from_bytes`, or if the proof's signature is not valid base64
/// or cannot be parsed as an Ed25519 signature.
pub fn verify_proof(project: &str, proof: &AgentProof) -> Result<bool> {
    let state = load_agents()?;
    let Some(agent) = state.projects.get(project).and_then(|agents| {
        agents.iter().find(|agent| {
            agent.agent_name == proof.agent_name && agent.agent_key_id == proof.agent_key_id
        })
    }) else {
        return Ok(false);
    };

    let public_key = STANDARD
        .decode(&agent.public_key)
        .context("agent public key is not valid base64")?;
    let public_key = <[u8; 32]>::try_from(public_key.as_slice())
        .map_err(|_| anyhow::anyhow!("agent public key has invalid length"))?;
    let signature = STANDARD
        .decode(&proof.signature)
        .context("agent signature is not valid base64")?;

    // Attach context here as well so every failure in this function names the
    // input that was malformed instead of surfacing a bare library error.
    let verifying_key = VerifyingKey::from_bytes(&public_key)
        .context("agent public key is not a valid Ed25519 key")?;
    let signature = Signature::try_from(signature.as_slice())
        .context("agent signature is malformed")?;

    Ok(verifying_key
        .verify(proof.payload.as_bytes(), &signature)
        .is_ok())
}
pub fn load_agents() -> Result<AgentState> {
let path = agents_path();
if !path.exists() {
return Ok(AgentState::default());
}
let contents =
std::fs::read_to_string(&path).context(format!("failed to read {}", path.display()))?;
serde_json::from_str(&contents).context(format!("failed to parse {}", path.display()))
}
/// Writes `state` to [`agents_path`] as pretty-printed JSON followed by a
/// trailing newline.
///
/// The ward home directory is passed to `fs_util::ensure_private_dir` first,
/// and the file itself is written through `fs_util::write_private_file`.
///
/// # Errors
///
/// Returns whatever error either `fs_util` helper reports.
///
/// # Panics
///
/// Panics if `state` cannot be serialized to JSON.
pub fn save_agents(state: &AgentState) -> Result<()> {
    let home = logs::ward_home();
    fs_util::ensure_private_dir(&home)?;

    let mut serialized =
        serde_json::to_string_pretty(state).expect("agent state should serialize");
    serialized.push('\n');

    fs_util::write_private_file(&agents_path(), serialized.as_bytes())
}
/// Derives the identifier for a public key: the literal prefix `agent:`
/// followed by the hex-encoded SHA-256 digest of the raw key bytes.
fn key_id(public_key: &[u8; 32]) -> String {
    let digest = Sha256::digest(public_key);
    let fingerprint = hex::encode(digest);
    format!("agent:{fingerprint}")
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::path::Path;
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};

    /// Acquires the module-wide lock that guards the `WARD_HOME` variable.
    ///
    /// A poisoned lock is recovered instead of unwrapped: the mutex protects no
    /// data, and propagating the poison would make every later test fail after
    /// a single panic, hiding the original failure.
    fn env_lock() -> MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    /// Points `WARD_HOME` at a fresh temporary directory for as long as the
    /// value is alive, and removes the variable again on drop — including when
    /// the owning test panics.
    struct ScopedWardHome {
        dir: tempfile::TempDir,
        // Declared after `dir` so the lock is released last, once the variable
        // has been removed and the directory deleted.
        _guard: MutexGuard<'static, ()>,
    }

    impl ScopedWardHome {
        fn new() -> Self {
            let guard = env_lock();
            let dir = tempfile::tempdir().unwrap();
            std::env::set_var("WARD_HOME", dir.path());
            Self { dir, _guard: guard }
        }

        fn path(&self) -> &Path {
            self.dir.path()
        }
    }

    impl Drop for ScopedWardHome {
        fn drop(&mut self) {
            std::env::remove_var("WARD_HOME");
        }
    }

    /// Loads the registry, applies `edit` to the named agent, and saves it back.
    fn edit_agent(project: &str, agent_name: &str, edit: impl FnOnce(&mut AgentIdentity)) {
        let mut state = load_agents().unwrap();
        let agent = state
            .projects
            .get_mut(project)
            .unwrap()
            .iter_mut()
            .find(|agent| agent.agent_name == agent_name)
            .unwrap();
        edit(agent);
        save_agents(&state).unwrap();
    }

    #[test]
    #[serial]
    fn agent_identity_signs_and_verifies_payloads() {
        let _home = ScopedWardHome::new();

        let first = ensure_agent("demo", "codex").unwrap();
        let second = ensure_agent("demo", "codex").unwrap();
        assert_eq!(first.agent_key_id, second.agent_key_id);

        let proof = sign_payload("demo", "codex", "payload").unwrap();
        assert!(verify_proof("demo", &proof).unwrap());

        // A tampered payload must not verify.
        let mut bad = proof.clone();
        bad.payload = "changed".to_string();
        assert!(!verify_proof("demo", &bad).unwrap());

        // Neither an unknown project nor an unknown agent matches an identity.
        assert!(!verify_proof("other", &proof).unwrap());
        let mut missing_agent = proof.clone();
        missing_agent.agent_name = "claude".to_string();
        assert!(!verify_proof("demo", &missing_agent).unwrap());
    }

    #[test]
    #[serial]
    fn agent_state_reports_invalid_files() {
        let home = ScopedWardHome::new();

        std::fs::create_dir_all(home.path()).unwrap();
        std::fs::write(agents_path(), "{bad-json}").unwrap();
        assert!(load_agents().is_err());
    }

    #[test]
    #[serial]
    fn agent_signing_reports_invalid_key_lengths() {
        let _home = ScopedWardHome::new();

        let proof = sign_payload("demo", "codex", "payload").unwrap();

        // A seed that decodes to the wrong length must be rejected when signing.
        edit_agent("demo", "codex", |agent| {
            agent.private_seed = STANDARD.encode([1_u8, 2, 3]);
        });
        assert!(sign_payload("demo", "codex", "payload").is_err());

        // A public key that decodes to the wrong length must be rejected when
        // verifying.
        edit_agent("demo", "codex", |agent| {
            agent.private_seed = STANDARD.encode([7_u8; 32]);
            agent.public_key = STANDARD.encode([1_u8, 2, 3]);
        });
        assert!(verify_proof("demo", &proof).is_err());
    }
}