use std::sync::OnceLock;
use anyhow::Result;
use crate::validate;
pub mod keypair;
pub mod sign;
pub mod verify;
pub mod replay;
pub mod attest;
pub mod sentinels;
/// Name of the environment variable consulted for an agent id. It is read by
/// `resolve_agent_id` (after the explicit argument) and by
/// `resolve_read_visibility_caller`; an empty value is treated as unset by both.
const ENV_AGENT_ID: &str = "AI_MEMORY_AGENT_ID";
/// Name of the environment variable that opts out of the `host:` default id.
/// `anonymize_default_enabled` treats `1`, `true`, `yes` and `on` as truthy
/// (ASCII case-insensitive, surrounding whitespace ignored); anything else,
/// including an unset variable, is falsy.
pub const ENV_ANONYMIZE: &str = "AI_MEMORY_ANONYMIZE";
/// Reports whether the `ENV_ANONYMIZE` environment variable holds a truthy value.
///
/// The value is trimmed and compared ASCII case-insensitively against
/// `1`, `true`, `yes` and `on`. An unset (or unreadable) variable and any
/// other value yield `false`.
fn anonymize_default_enabled() -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    std::env::var(ENV_ANONYMIZE).is_ok_and(|raw| {
        let flag = raw.trim();
        TRUTHY.iter().any(|accepted| flag.eq_ignore_ascii_case(accepted))
    })
}
/// Returns a per-process discriminator of the form `pid-<pid>-<8 hex chars>`.
///
/// The value is built on first call and cached in a `OnceLock`, so every
/// later call in the same process returns the same string.
pub fn process_discriminator() -> &'static str {
    static CACHED: OnceLock<String> = OnceLock::new();
    let build = || format!("pid-{}-{}", std::process::id(), short_uuid());
    CACHED.get_or_init(build).as_str()
}
/// Returns the machine hostname with surrounding whitespace removed, or
/// `None` when the trimmed hostname is empty.
///
/// Non-UTF-8 bytes in the OS-provided name are replaced lossily.
fn hostname_opt() -> Option<String> {
    let raw = gethostname::gethostname();
    let lossy = raw.to_string_lossy();
    let trimmed = lossy.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}
/// Returns the first 8 characters of a freshly generated v4 UUID rendered in
/// its "simple" (hyphen-free) form.
fn short_uuid() -> String {
    let mut hex = uuid::Uuid::new_v4().simple().to_string();
    // The simple form is pure ASCII, so byte index 8 is a char boundary.
    hex.truncate(8);
    hex
}
/// Makes `input` safe to embed as one component of a synthesized agent id.
///
/// ASCII alphanumerics plus `_`, `-` and `.` are kept verbatim. Every other
/// character is replaced by `-`, with a run of consecutive replaced
/// characters producing a single `-`. Leading and trailing dashes are
/// stripped from the result, which may therefore be empty.
fn sanitize_component(input: &str) -> String {
    let mut cleaned = String::with_capacity(input.len());
    // True only while the most recent output char is a *substituted* dash;
    // a literal `-` from the input resets it, exactly like any other kept char.
    let mut in_replaced_run = false;
    for ch in input.chars() {
        let keep = ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-' | '.');
        match (keep, in_replaced_run) {
            (true, _) => {
                cleaned.push(ch);
                in_replaced_run = false;
            }
            (false, false) => {
                cleaned.push('-');
                in_replaced_run = true;
            }
            // Already emitted a dash for this run of unsafe characters.
            (false, true) => {}
        }
    }
    cleaned.trim_matches('-').to_owned()
}
/// Resolves the agent id to use, trying each source in priority order.
///
/// 1. `explicit` — a non-empty caller-supplied id, checked with
///    `validate::validate_agent_id`.
/// 2. The `ENV_AGENT_ID` environment variable — a non-empty value, checked
///    with `validate::validate_agent_id_shape`.
/// 3. `mcp_client` — a non-empty client name, synthesized into
///    `ai:<client>@<host>:pid-<pid>`. Both components are sanitized; the host
///    falls back to `unknown` when the hostname is unavailable or sanitizes to
///    nothing. The branch is skipped when the client name sanitizes to an
///    empty string or the synthesized id fails validation.
/// 4. `host:<host>:<process discriminator>` — unless anonymization is enabled
///    through `ENV_ANONYMIZE`, the hostname is unavailable, or the synthesized
///    id fails validation.
/// 5. `anonymous:<process discriminator>`.
///
/// # Errors
///
/// Returns an error when the explicit id or the environment-supplied id fails
/// validation (these do not fall through to later sources), or when the final
/// `anonymous:` id fails validation.
pub fn resolve_agent_id(explicit: Option<&str>, mcp_client: Option<&str>) -> Result<String> {
    // 1. Explicit caller-supplied id: invalid input is an error, not a fall-through.
    if let Some(id) = explicit
        && !id.is_empty()
    {
        validate::validate_agent_id(id)?;
        return Ok(id.to_string());
    }

    // 2. Environment-supplied id.
    if let Ok(v) = std::env::var(ENV_AGENT_ID)
        && !v.is_empty()
    {
        validate::validate_agent_id_shape(&v)?;
        return Ok(v);
    }

    // 3. Id synthesized from the MCP client name.
    if let Some(client) = mcp_client
        && !client.is_empty()
    {
        let client_s = sanitize_component(client);
        // A name made up solely of unsafe characters sanitizes to "", which
        // would produce `ai:@host:...`; skip the branch instead, mirroring the
        // emptiness guard used for the `host:` default below.
        if !client_s.is_empty() {
            // Treat a hostname that sanitizes to "" like a missing hostname so
            // the id never contains an empty host component.
            let host_s = hostname_opt()
                .map(|h| sanitize_component(&h))
                .filter(|h| !h.is_empty())
                .unwrap_or_else(|| "unknown".to_string());
            let pid = std::process::id();
            let id = format!("ai:{client_s}@{host_s}:pid-{pid}");
            if validate::validate_agent_id(&id).is_ok() {
                return Ok(id);
            }
        }
    }

    // 4. Host-based default, unless the operator asked for anonymity.
    if !anonymize_default_enabled()
        && let Some(host) = hostname_opt()
    {
        let host_s = sanitize_component(&host);
        if !host_s.is_empty() {
            let discriminator = process_discriminator();
            let id = format!("host:{host_s}:{discriminator}");
            if validate::validate_agent_id(&id).is_ok() {
                return Ok(id);
            }
        }
    }

    // 5. Last resort: anonymous id, stable for the process lifetime.
    let discriminator = process_discriminator();
    let id = format!("anonymous:{discriminator}");
    validate::validate_agent_id(&id)?;
    Ok(id)
}
/// Returns the agent id from the `ENV_AGENT_ID` environment variable for use
/// as the caller on read paths.
///
/// Yields `None` when the variable is unset or unreadable, empty, or rejected
/// by `validate::validate_agent_id_shape`; no fallback id is synthesized.
#[must_use]
pub fn resolve_read_visibility_caller() -> Option<String> {
    std::env::var(ENV_AGENT_ID)
        .ok()
        .filter(|candidate| !candidate.is_empty())
        .filter(|candidate| validate::validate_agent_id_shape(candidate).is_ok())
}
/// Builds a fresh anonymous request id: `sentinels::ANONYMOUS_REQ_PREFIX`
/// followed by a new 8-character UUID fragment, so each call returns a
/// different id.
pub fn anonymous_request_id() -> String {
    let prefix = sentinels::ANONYMOUS_REQ_PREFIX;
    let suffix = short_uuid();
    format!("{prefix}{suffix}")
}
/// Resolves the agent id for an HTTP write from the request header and body.
///
/// The header value is authoritative: a non-empty `header` is validated and
/// used as the id. Without one, a per-request anonymous id is generated (and
/// a warning is logged). A non-empty `body` value is only a claim — it is
/// validated and must equal the resolved id.
///
/// # Errors
///
/// Returns an error when the header, the generated anonymous id, or the body
/// claim fails `validate::validate_agent_id`, or when the body claim differs
/// from the resolved id (message tagged `agent_id_body_header_mismatch`).
pub fn resolve_http_agent_id(body: Option<&str>, header: Option<&str>) -> Result<String> {
    let resolved = match header.filter(|h| !h.is_empty()) {
        Some(id) => {
            validate::validate_agent_id(id)?;
            id.to_owned()
        }
        None => {
            let anon = anonymous_request_id();
            tracing::warn!(
                "HTTP memory write without agent_id body field or X-Agent-Id header; assigned {anon}"
            );
            validate::validate_agent_id(&anon)?;
            anon
        }
    };

    // An absent or empty body field makes no claim; nothing more to check.
    let Some(claim) = body.filter(|b| !b.is_empty()) else {
        return Ok(resolved);
    };

    validate::validate_agent_id(claim)?;
    if claim != resolved {
        anyhow::bail!(
            "agent_id_body_header_mismatch: body-supplied agent_id {claim:?} disagrees \
             with authenticated header-resolved id {resolved:?}"
        );
    }
    Ok(resolved)
}
/// Merges metadata for an update so that an already-recorded `agent_id`
/// cannot be overwritten.
///
/// The result starts as a copy of `incoming` when it is a JSON object, or an
/// empty object otherwise. If `existing` carries an `agent_id` key, its value
/// replaces (or adds) `agent_id` in the result; all other keys come from
/// `incoming`.
pub fn preserve_agent_id(
    existing: &serde_json::Value,
    incoming: &serde_json::Value,
) -> serde_json::Value {
    let mut fields = match incoming {
        serde_json::Value::Object(map) => map.clone(),
        _ => serde_json::Map::new(),
    };
    if let Some(original_id) = existing.get("agent_id") {
        fields.insert("agent_id".to_string(), original_id.clone());
    }
    serde_json::Value::Object(fields)
}
// Unit tests for agent-id resolution. Tests that read or mutate the process
// environment take `env_var_lock()` first so they do not interleave with each
// other when the test harness runs them on multiple threads.
#[cfg(test)]
mod tests {
use super::*;
// Returns a guard on a process-wide mutex used to serialize environment
// mutation across tests. A poisoned lock (a previous test panicked while
// holding it) is recovered with `into_inner` so one failure does not cascade.
fn env_var_lock() -> std::sync::MutexGuard<'static, ()> {
use std::sync::{Mutex, OnceLock};
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
// Two calls must return the same cached value of the form `pid-<pid>-<hex>`.
#[test]
fn process_discriminator_is_stable() {
let a = process_discriminator();
let b = process_discriminator();
assert_eq!(
a, b,
"discriminator must be stable for the process lifetime"
);
assert!(a.starts_with("pid-"));
// Shortest possible shape: "pid-", a one-digit pid, "-", and 8 hex chars.
assert!(a.len() >= "pid-1-0000000a".len());
}
// The UUID fragment is exactly 8 lowercase hex characters.
#[test]
fn short_uuid_is_8_hex_chars() {
let s = short_uuid();
assert_eq!(s.len(), 8);
assert!(
s.chars()
.all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase())
);
}
// Allowed characters (alphanumerics, `_`, `-`, `.`) pass through unchanged.
#[test]
fn sanitize_component_preserves_safe_chars() {
assert_eq!(sanitize_component("claude-code"), "claude-code");
assert_eq!(sanitize_component("host.example.com"), "host.example.com");
assert_eq!(sanitize_component("devbox_1"), "devbox_1");
}
// Disallowed characters become `-`; leading/trailing dashes are trimmed.
#[test]
fn sanitize_component_replaces_unsafe_chars() {
assert_eq!(sanitize_component("my host"), "my-host");
assert_eq!(sanitize_component("a/b"), "a-b");
assert_eq!(sanitize_component("a b"), "a-b"); assert_eq!(sanitize_component("a;b|c"), "a-b-c");
assert_eq!(sanitize_component("---foo---"), "foo");
}
// An explicit id takes priority over the MCP client name. The explicit
// path returns before the environment is read, so no env lock is taken.
#[test]
fn resolve_explicit_caller_wins() {
let id = resolve_agent_id(Some("alice"), Some("claude-code")).unwrap();
assert_eq!(id, "alice");
}
// An invalid explicit id is an error rather than a fall-through.
#[test]
fn resolve_validates_explicit_caller() {
assert!(resolve_agent_id(Some("alice bob"), None).is_err());
assert!(resolve_agent_id(Some("a\0null"), None).is_err());
}
// An empty explicit id is treated as absent.
#[test]
fn resolve_empty_explicit_falls_through() {
let _g = env_var_lock();
// SAFETY: `_g` holds the env lock, so no other lock-taking test mutates
// the environment concurrently.
unsafe {
std::env::remove_var(ENV_AGENT_ID);
}
let id = resolve_agent_id(Some(""), None).unwrap();
assert!(id.starts_with("host:") || id.starts_with("anonymous:"));
}
// With no explicit id and no env override, the client name yields `ai:<client>@...`.
#[test]
fn resolve_mcp_client_synthesizes_ai_prefix() {
let _g = env_var_lock();
// SAFETY: `_g` holds the env lock for the duration of the mutation.
unsafe {
std::env::remove_var(ENV_AGENT_ID);
}
let id = resolve_agent_id(None, Some("claude-code")).unwrap();
assert!(id.starts_with("ai:claude-code@"));
assert!(id.contains(":pid-"));
}
// The client name is sanitized before being embedded in the id.
#[test]
fn resolve_mcp_client_sanitizes_name() {
let _g = env_var_lock();
// SAFETY: `_g` holds the env lock for the duration of the mutation.
unsafe {
std::env::remove_var(ENV_AGENT_ID);
}
let id = resolve_agent_id(None, Some("weird client!")).unwrap();
assert!(id.starts_with("ai:weird-client@"));
}
// With no inputs at all the result is a `host:` or `anonymous:` id; which
// one depends on the hostname and the anonymize setting of the test machine.
#[test]
fn resolve_default_is_host_or_anonymous() {
let _g = env_var_lock();
// SAFETY: `_g` holds the env lock for the duration of the mutation.
unsafe {
std::env::remove_var(ENV_AGENT_ID);
}
let id = resolve_agent_id(None, None).unwrap();
assert!(
id.starts_with("host:") || id.starts_with("anonymous:"),
"got: {id}"
);
}
// A valid env value is returned as the read-visibility caller.
#[test]
fn read_visibility_caller_returns_env_when_set() {
let _g = env_var_lock();
// SAFETY: `_g` holds the env lock for the duration of the mutation.
unsafe {
std::env::set_var(ENV_AGENT_ID, "ai:alice");
}
let got = resolve_read_visibility_caller();
// Clean up before asserting so a failure does not leak the variable.
// SAFETY: `_g` is still held.
unsafe {
std::env::remove_var(ENV_AGENT_ID);
}
assert_eq!(got.as_deref(), Some("ai:alice"));
}
// An unset variable yields no caller.
#[test]
fn read_visibility_caller_none_when_unset() {
let _g = env_var_lock();
// SAFETY: `_g` holds the env lock for the duration of the mutation.
unsafe {
std::env::remove_var(ENV_AGENT_ID);
}
assert_eq!(resolve_read_visibility_caller(), None);
}
// Empty and shape-invalid values are both treated as "no caller".
#[test]
fn read_visibility_caller_none_when_empty_or_shape_invalid() {
let _g = env_var_lock();
// SAFETY: `_g` holds the env lock for every mutation in this test.
unsafe {
std::env::set_var(ENV_AGENT_ID, "");
}
assert_eq!(resolve_read_visibility_caller(), None);
unsafe {
std::env::set_var(ENV_AGENT_ID, "has space");
}
assert_eq!(resolve_read_visibility_caller(), None);
unsafe {
std::env::remove_var(ENV_AGENT_ID);
}
}
// A body claim that differs from the header id is rejected with a tagged error.
#[test]
fn resolve_http_body_mismatch_is_err() {
let r = resolve_http_agent_id(Some("alice"), Some("bob"));
assert!(r.is_err(), "mismatch must be Err, got Ok({r:?})");
let msg = r.unwrap_err().to_string();
assert!(
msg.contains("agent_id_body_header_mismatch"),
"error must carry tag agent_id_body_header_mismatch, got: {msg}"
);
assert!(!msg.is_empty());
}
// A body claim equal to the header id is accepted.
#[test]
fn resolve_http_body_matching_header_is_ok() {
let id = resolve_http_agent_id(Some("alice"), Some("alice")).unwrap();
assert_eq!(id, "alice");
}
// An empty body value makes no claim, so the header id is used.
#[test]
fn resolve_http_empty_body_is_no_claim() {
let id = resolve_http_agent_id(Some(""), Some("bob")).unwrap();
assert_eq!(id, "bob");
}
// Without a header the resolved id is a generated anonymous id, which a
// body claim such as "alice" cannot equal, so the call fails as a mismatch.
#[test]
fn resolve_http_body_without_header_uses_anonymous_and_mismatches() {
let r = resolve_http_agent_id(Some("alice"), None);
assert!(r.is_err(), "body without header must be Err, got Ok({r:?})");
let msg = r.unwrap_err().to_string();
assert!(
msg.contains("agent_id_body_header_mismatch"),
"error must carry tag agent_id_body_header_mismatch, got: {msg}"
);
}
// The header alone is sufficient.
#[test]
fn resolve_http_header_used_when_body_missing() {
let id = resolve_http_agent_id(None, Some("bob")).unwrap();
assert_eq!(id, "bob");
}
// With neither input, each call gets its own `anonymous:req-` id.
#[test]
fn resolve_http_fallback_is_anonymous_req() {
let id = resolve_http_agent_id(None, None).unwrap();
assert!(id.starts_with("anonymous:req-"), "got: {id}");
let id2 = resolve_http_agent_id(None, None).unwrap();
assert_ne!(id, id2);
}
// Both the body claim and the header value are validated.
#[test]
fn resolve_http_validates_caller_input() {
assert!(resolve_http_agent_id(Some("has space"), None).is_err());
assert!(resolve_http_agent_id(None, Some("has\0null")).is_err());
}
// The stored `agent_id` wins; every other key comes from `incoming`.
#[test]
fn preserve_agent_id_copies_existing() {
let existing = serde_json::json!({"agent_id": "alice", "foo": "old"});
let incoming = serde_json::json!({"agent_id": "bob", "foo": "new", "bar": 1});
let merged = preserve_agent_id(&existing, &incoming);
assert_eq!(merged["agent_id"], "alice");
assert_eq!(merged["foo"], "new");
assert_eq!(merged["bar"], 1);
}
// With no stored `agent_id`, the incoming one is kept as-is.
#[test]
fn preserve_agent_id_no_op_when_existing_has_none() {
let existing = serde_json::json!({"foo": "x"});
let incoming = serde_json::json!({"agent_id": "bob"});
let merged = preserve_agent_id(&existing, &incoming);
assert_eq!(merged["agent_id"], "bob");
}
// A non-object `incoming` is replaced by an object holding only the stored id.
#[test]
fn preserve_agent_id_handles_non_object_incoming() {
let existing = serde_json::json!({"agent_id": "alice"});
let incoming = serde_json::json!("not-an-object");
let merged = preserve_agent_id(&existing, &incoming);
assert!(merged.is_object());
assert_eq!(merged["agent_id"], "alice");
}
// Truthy spellings are matched case-insensitively and after trimming.
#[test]
fn anonymize_default_enabled_truthy_variants() {
let _g = env_var_lock();
for v in ["1", "true", "yes", "on", "TRUE", " yes ", "On", "YES"] {
// SAFETY: `_g` holds the env lock for every mutation in this test.
unsafe {
std::env::set_var(ENV_ANONYMIZE, v);
}
assert!(anonymize_default_enabled(), "value {v:?} must be truthy");
}
unsafe {
std::env::remove_var(ENV_ANONYMIZE);
}
}
// Anything outside the truthy list, including the empty string, is falsy.
#[test]
fn anonymize_default_enabled_falsy_variants() {
let _g = env_var_lock();
for v in ["0", "false", "no", "off", "", "garbage"] {
// SAFETY: `_g` holds the env lock for every mutation in this test.
unsafe {
std::env::set_var(ENV_ANONYMIZE, v);
}
assert!(!anonymize_default_enabled(), "value {v:?} must be falsy");
}
unsafe {
std::env::remove_var(ENV_ANONYMIZE);
}
}
// An unset variable is falsy.
#[test]
fn anonymize_default_enabled_unset_is_falsy() {
let _g = env_var_lock();
// SAFETY: `_g` holds the env lock for the duration of the mutation.
unsafe {
std::env::remove_var(ENV_ANONYMIZE);
}
assert!(!anonymize_default_enabled());
}
// The env variable supplies the id when no explicit id is given.
#[test]
fn resolve_uses_env_agent_id_when_no_explicit_no_mcp() {
let _g = env_var_lock();
// SAFETY: `_g` holds the env lock for every mutation in this test.
unsafe {
std::env::set_var(ENV_AGENT_ID, "env-alice");
}
let id = resolve_agent_id(None, None).unwrap();
assert_eq!(id, "env-alice");
unsafe {
std::env::remove_var(ENV_AGENT_ID);
}
}
// With anonymization enabled the `host:` default is skipped.
#[test]
fn resolve_anonymize_forces_anonymous_prefix() {
let _g = env_var_lock();
// SAFETY: `_g` holds the env lock for every mutation in this test.
unsafe {
std::env::remove_var(ENV_AGENT_ID);
std::env::set_var(ENV_ANONYMIZE, "1");
}
let id = resolve_agent_id(None, None).unwrap();
assert!(
id.starts_with("anonymous:"),
"AI_MEMORY_ANONYMIZE=1 must skip host: default, got: {id}"
);
unsafe {
std::env::remove_var(ENV_ANONYMIZE);
}
}
// An empty env value is treated as unset.
// NOTE(review): the `ai:` alternative looks unreachable here because
// `mcp_client` is `None` — confirm whether it is intentional slack.
#[test]
fn resolve_empty_env_falls_through() {
let _g = env_var_lock();
// SAFETY: `_g` holds the env lock for every mutation in this test.
unsafe {
std::env::set_var(ENV_AGENT_ID, "");
}
let id = resolve_agent_id(None, None).unwrap();
assert!(
id.starts_with("host:") || id.starts_with("anonymous:") || id.starts_with("ai:"),
"empty env must fall through to host/anonymous default, got: {id}"
);
unsafe {
std::env::remove_var(ENV_AGENT_ID);
}
}
}