#![recursion_limit = "256"]
mod autonomy;
mod color;
mod config;
mod curator;
mod db;
mod embeddings;
mod errors;
mod federation;
mod handlers;
mod hnsw;
mod identity;
mod llm;
mod mcp;
mod metrics;
#[cfg(feature = "sal")]
mod migrate;
mod mine;
mod models;
mod replication;
mod reranker;
#[cfg(feature = "sal")]
mod store;
mod subscriptions;
mod toon;
mod validate;
use anyhow::{Context, Result};
use axum::{
Router,
extract::DefaultBodyLimit,
routing::{delete, get, post, put},
};
use chrono::{Duration, Utc};
use clap::{Args, CommandFactory, Parser, Subcommand};
use clap_complete::{Shell, generate};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::Mutex;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tracing_subscriber::EnvFilter;
use crate::models::Tier;
/// Default value of the global `--db` option (also overridable via `AI_MEMORY_DB`).
const DEFAULT_DB: &str = "ai-memory.db";
/// Default TCP port for `serve` when `--port` is not given.
const DEFAULT_PORT: u16 = 9077;
/// Seconds between background GC sweeps in `serve` (1800 s = 30 min).
const GC_INTERVAL_SECS: u64 = 1800;
/// Seconds between periodic WAL checkpoints in `serve` (600 s = 10 min); the
/// first checkpoint runs after half of this interval.
const WAL_CHECKPOINT_INTERVAL_SECS: u64 = 600;
/// Returns a short display prefix of `id`: at most its first 8 bytes, cut back
/// to the nearest preceding UTF-8 character boundary so the slice never splits
/// a multi-byte character (and therefore never panics).
fn id_short(id: &str) -> &str {
    // Byte 0 is always a char boundary, so the search always succeeds.
    let cut = (0..=id.len().min(8))
        .rev()
        .find(|&idx| id.is_char_boundary(idx))
        .unwrap_or(0);
    &id[..cut]
}
// Top-level command-line interface parsed in `main`.
//
// Plain `//` comments are used on clap-derived items on purpose: `///` doc
// comments are turned into `--help` / man-page text by clap's derive macro,
// which would change user-visible output.
#[derive(Parser)]
#[command(
    name = "ai-memory",
    version,
    about = "AI-agnostic persistent memory — MCP server, HTTP API, and CLI for any AI platform"
)]
struct Cli {
    // The subcommand to run; dispatched in `main`.
    #[command(subcommand)]
    command: Command,
    // Database path: `--db`, else `AI_MEMORY_DB`, else `DEFAULT_DB`. `main`
    // passes it through `AppConfig::effective_db` before use.
    #[arg(long, env = "AI_MEMORY_DB", default_value = DEFAULT_DB, global = true)]
    db: PathBuf,
    // JSON output flag; marked `global` so clap propagates it to subcommands.
    #[arg(long, global = true, default_value_t = false)]
    json: bool,
    // Caller identity (`--agent-id` or `AI_MEMORY_AGENT_ID`), forwarded to
    // the write-oriented command handlers by `main`.
    #[arg(long, env = "AI_MEMORY_AGENT_ID", global = true)]
    agent_id: Option<String>,
    // File containing the DB passphrase. `main` strips trailing CR/LF, rejects
    // an empty result, and exports it as `AI_MEMORY_DB_PASSPHRASE`.
    #[arg(long, global = true, value_name = "PATH")]
    db_passphrase_file: Option<PathBuf>,
}
// All subcommands; each variant is dispatched one-to-one in `main`.
// Variants marked [checkpoint] are in `main`'s `is_write_command` list, so a
// successful run is followed by a WAL checkpoint.
//
// `//` (not `///`) so clap's generated help text is unchanged.
#[derive(Subcommand)]
enum Command {
    // HTTP API server (`serve`).
    Serve(ServeArgs),
    // MCP server via `mcp::run_mcp_server`; `--tier` (default "semantic") is
    // resolved through `AppConfig::effective_tier`.
    Mcp {
        #[arg(long, default_value = "semantic")]
        tier: String,
    },
    // `cmd_store`: builds and stores a new memory. [checkpoint]
    Store(StoreArgs),
    // `cmd_update`: memory id plus optional replacement fields. [checkpoint]
    Update(UpdateArgs),
    // `cmd_recall`: free-text context plus filters.
    Recall(RecallArgs),
    // `cmd_search`: query plus filters.
    Search(SearchArgs),
    // `cmd_get` by memory id.
    Get(GetArgs),
    // `cmd_list` with filters and offset/limit paging.
    List(ListArgs),
    // `cmd_delete` by memory id. [checkpoint]
    Delete(DeleteArgs),
    // `cmd_promote` by memory id, optionally to another namespace. [checkpoint]
    Promote(PromoteArgs),
    // `cmd_forget` filtered by namespace / pattern / tier. [checkpoint]
    Forget(ForgetArgs),
    // `cmd_link`: source id, target id and a relation. [checkpoint]
    Link(LinkArgs),
    // `cmd_consolidate`: ids plus a title and summary. [checkpoint]
    Consolidate(ConsolidateArgs),
    // `cmd_gc`. [checkpoint]
    Gc,
    // `cmd_stats`.
    Stats,
    // `cmd_namespaces`.
    Namespaces,
    // `cmd_export` (takes no `--json` flag).
    Export,
    // `cmd_import`. [checkpoint]
    Import(ImportArgs),
    // `cmd_resolve`: winner id and loser id. [checkpoint]
    Resolve(ResolveArgs),
    // `cmd_shell`.
    Shell,
    // `cmd_sync` against another database file. [checkpoint]
    Sync(SyncArgs),
    // `cmd_sync_daemon` against HTTP peers (async). [checkpoint]
    SyncDaemon(SyncDaemonArgs),
    // `cmd_auto_consolidate`. [checkpoint]
    AutoConsolidate(AutoConsolidateArgs),
    // Prints a shell-completion script for the chosen shell to stdout.
    Completions(CompletionsArgs),
    // Renders a man page (via `clap_mangen`) to stdout.
    Man,
    // `cmd_mine`: imports conversations from a file.
    Mine(MineArgs),
    // `cmd_archive` subcommands.
    Archive(ArchiveArgs),
    // `cmd_agents` subcommands.
    Agents(AgentsArgs),
    // `cmd_pending` subcommands.
    Pending(PendingArgs),
    // `cmd_backup`.
    Backup(BackupArgs),
    // `cmd_restore`.
    Restore(RestoreArgs),
    // `cmd_curator` (async).
    Curator(CuratorArgs),
    // `cmd_migrate` (async); only compiled with the `sal` feature.
    #[cfg(feature = "sal")]
    Migrate(MigrateArgs),
}
// Options for `curator` (handled by `cmd_curator`).
// `//` (not `///`) so clap's generated help text is unchanged.
//
// NOTE(review): this struct declares its own `--json`, sharing the id of the
// global `Cli::json`. `main` does not pass the global flag to `cmd_curator`,
// so only this field reaches the curator — confirm that
// `ai-memory --json curator ...` behaves as intended.
#[derive(Args)]
// Each bool is an independent CLI flag, so the count mirrors the CLI surface.
#[allow(clippy::struct_excessive_bools)]
struct CuratorArgs {
    // Mutually exclusive with `--daemon`.
    #[arg(long, conflicts_with = "daemon")]
    once: bool,
    #[arg(long)]
    daemon: bool,
    // Default 3600 s.
    #[arg(long, default_value_t = 3600)]
    interval_secs: u64,
    // Default 100; presumably a per-pass cap — TODO confirm in `curator`.
    #[arg(long, default_value_t = 100)]
    max_ops: usize,
    #[arg(long)]
    dry_run: bool,
    // Repeatable `--include-namespace` filter.
    #[arg(long = "include-namespace")]
    include_namespaces: Vec<String>,
    // Repeatable `--exclude-namespace` filter.
    #[arg(long = "exclude-namespace")]
    exclude_namespaces: Vec<String>,
    #[arg(long)]
    json: bool,
    // Cannot be combined with `--once` or `--daemon`.
    #[arg(long, conflicts_with_all = ["once", "daemon"])]
    rollback: Option<String>,
    // NOTE(review): unlike `--rollback`, no conflicts are declared with
    // `--once` / `--daemon` / `--rollback`; confirm that is intended.
    #[arg(long)]
    rollback_last: Option<usize>,
}
// Options for `migrate` (handled by `cmd_migrate`); only compiled with the
// `sal` feature. `//` (not `///`) so clap help text is unchanged.
#[cfg(feature = "sal")]
#[derive(Args)]
struct MigrateArgs {
    // Source store; format is interpreted by `cmd_migrate` (not visible here).
    #[arg(long)]
    from: String,
    // Destination store; format is interpreted by `cmd_migrate`.
    #[arg(long)]
    to: String,
    // Default 1000.
    #[arg(long, default_value_t = 1000)]
    batch: usize,
    // Optional namespace restriction.
    #[arg(long)]
    namespace: Option<String>,
    #[arg(long)]
    dry_run: bool,
    // Local `--json`; shares its id with the global `Cli::json`.
    #[arg(long)]
    json: bool,
}
// Options for `backup` (handled by `cmd_backup`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct BackupArgs {
    // Output directory; default "./backups" (relative to the working directory).
    #[arg(long, default_value = "./backups")]
    to: PathBuf,
    // Default 48; presumably the number of backups retained — TODO confirm.
    #[arg(long, default_value_t = 48)]
    keep: usize,
}
// Options for `restore` (handled by `cmd_restore`).
#[derive(Args)]
struct RestoreArgs {
    // Backup to restore from (required).
    #[arg(long)]
    from: PathBuf,
    // Presumably skips an integrity check before restoring — TODO confirm.
    #[arg(long)]
    skip_verify: bool,
}
// Options for `pending` (handled by `cmd_pending`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct PendingArgs {
    #[command(subcommand)]
    action: PendingAction,
}
// `pending` subcommands.
#[derive(Subcommand)]
enum PendingAction {
    // Optional status filter; at most `--limit` (default 100) entries.
    List {
        #[arg(long)]
        status: Option<String>,
        #[arg(long, default_value_t = 100)]
        limit: usize,
    },
    // Positional pending-item id.
    Approve { id: String },
    // Positional pending-item id.
    Reject { id: String },
}
// Options for `agents` (handled by `cmd_agents`). The subcommand is optional;
// what `cmd_agents` does when it is omitted is not visible here.
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct AgentsArgs {
    #[command(subcommand)]
    action: Option<AgentsAction>,
}
// `agents` subcommands.
#[derive(Subcommand)]
enum AgentsAction {
    List,
    // NOTE(review): `--agent-id` here shares its id with the global
    // `Cli::agent_id`; confirm the intended precedence.
    Register {
        #[arg(long)]
        agent_id: String,
        #[arg(long)]
        agent_type: String,
        // Defaults to an empty string; format interpreted by `cmd_agents`.
        #[arg(long, default_value = "")]
        capabilities: String,
    },
}
// Options for `archive` (handled by `cmd_archive`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct ArchiveArgs {
    #[command(subcommand)]
    action: ArchiveAction,
}
// `archive` subcommands.
#[derive(Subcommand)]
enum ArchiveAction {
    // Optional namespace filter, paged with `--limit` (default 50) and
    // `--offset` (default 0).
    List {
        #[arg(long, short)]
        namespace: Option<String>,
        #[arg(long, default_value_t = 50)]
        limit: usize,
        #[arg(long, default_value_t = 0)]
        offset: usize,
    },
    // Positional archived-memory id.
    Restore { id: String },
    // Optional age threshold in days.
    Purge {
        #[arg(long)]
        older_than_days: Option<i64>,
    },
    Stats,
}
// Options for `mine` (handled by `cmd_mine`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct MineArgs {
    // Positional input path.
    path: PathBuf,
    // Required (`-f`); accepted values are defined by `cmd_mine`.
    #[arg(long, short)]
    format: String,
    // Optional (`-n`).
    #[arg(long, short)]
    namespace: Option<String>,
    // `-t`, default "mid".
    #[arg(long, short, default_value = "mid")]
    tier: String,
    // Default 3; presumably the minimum conversation length to import — TODO confirm.
    #[arg(long, default_value_t = 3)]
    min_messages: usize,
    #[arg(long, default_value_t = false)]
    dry_run: bool,
}
// Options for `serve`.
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct ServeArgs {
    // Bind address; defaults to loopback only.
    #[arg(long, default_value = "127.0.0.1")]
    host: String,
    #[arg(long, default_value_t = DEFAULT_PORT)]
    port: u16,
    // `--tls-cert` and `--tls-key` require each other; when both are given,
    // `serve` listens on HTTPS.
    #[arg(long, requires = "tls_key")]
    tls_cert: Option<PathBuf>,
    #[arg(long, requires = "tls_cert")]
    tls_key: Option<PathBuf>,
    // SHA-256 fingerprint allowlist enabling mandatory client certificates;
    // requires TLS.
    #[arg(long, requires = "tls_cert")]
    mtls_allowlist: Option<PathBuf>,
    // Graceful-shutdown grace period. Only applied on the TLS path in `serve`;
    // the plain-HTTP path has no grace timeout.
    #[arg(long, default_value_t = 30)]
    shutdown_grace_secs: u64,
    // Federation settings, forwarded to `federation::FederationConfig::build`.
    #[arg(long, default_value_t = 0)]
    quorum_writes: usize,
    // Comma-separated peer list.
    #[arg(long, value_delimiter = ',')]
    quorum_peers: Vec<String>,
    #[arg(long, default_value_t = 2000)]
    quorum_timeout_ms: u64,
    // NOTE(review): unlike `SyncDaemonArgs`, the client cert/key pair here does
    // not declare `requires` on each other — confirm `FederationConfig::build`
    // validates the pairing.
    #[arg(long)]
    quorum_client_cert: Option<PathBuf>,
    #[arg(long)]
    quorum_client_key: Option<PathBuf>,
    #[arg(long)]
    quorum_ca_cert: Option<PathBuf>,
    // Federation catch-up polling period; 0 disables the catch-up loop.
    #[arg(long, default_value_t = 30)]
    catchup_interval_secs: u64,
}
// Options for `store` (handled by `cmd_store`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct StoreArgs {
    // `-t`, default "mid"; parsed by `Tier::from_str` (short, mid, long).
    #[arg(long, short, default_value = "mid")]
    tier: String,
    // `-n`; when omitted, `cmd_store` falls back to `auto_namespace()`.
    #[arg(long, short)]
    namespace: Option<String>,
    // `-T`; hyphen-leading values are accepted as data.
    #[arg(long, short = 'T', allow_hyphen_values = true)]
    title: String,
    // `-c`; the literal "-" makes `cmd_store` read the content from stdin.
    #[arg(long, short, allow_hyphen_values = true)]
    content: String,
    // Comma-separated; whitespace is trimmed and empty items dropped.
    #[arg(long, default_value = "")]
    tags: String,
    // `-p`, default 5; validated and clamped to 1..=10 by `cmd_store`.
    #[arg(long, short, default_value_t = 5)]
    priority: i32,
    // Default 1.0; validated and clamped to 0.0..=1.0 by `cmd_store`.
    #[arg(long, default_value_t = 1.0)]
    confidence: f64,
    // `-S`, default "cli".
    #[arg(long, short = 'S', default_value = "cli")]
    source: String,
    // Explicit expiry; takes precedence over `--ttl-secs` and the tier TTL.
    #[arg(long)]
    expires_at: Option<String>,
    // Expiry as "now + seconds"; overrides the tier's configured TTL.
    #[arg(long)]
    ttl_secs: Option<i64>,
    // Validated and stored under `metadata.scope`.
    #[arg(long)]
    scope: Option<String>,
}
// Options for `update` (handled by `cmd_update`); every field except `id` is
// optional. `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct UpdateArgs {
    // Positional memory id.
    id: String,
    #[arg(long, short = 'T', allow_hyphen_values = true)]
    title: Option<String>,
    #[arg(long, short, allow_hyphen_values = true)]
    content: Option<String>,
    #[arg(long, short)]
    tier: Option<String>,
    #[arg(long, short)]
    namespace: Option<String>,
    #[arg(long)]
    tags: Option<String>,
    #[arg(long, short)]
    priority: Option<i32>,
    #[arg(long)]
    confidence: Option<f64>,
    #[arg(long)]
    expires_at: Option<String>,
}
// Options for `recall` (handled by `cmd_recall`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct RecallArgs {
    // Positional free-text context; may start with '-'.
    #[arg(allow_hyphen_values = true)]
    context: String,
    #[arg(long, short)]
    namespace: Option<String>,
    // Default 10.
    #[arg(long, default_value_t = 10)]
    limit: usize,
    #[arg(long)]
    tags: Option<String>,
    // Time bounds; accepted format is defined by `cmd_recall`.
    #[arg(long)]
    since: Option<String>,
    #[arg(long)]
    until: Option<String>,
    // Note: `-T` here (not `-t` as in `search`/`list`).
    #[arg(long, short = 'T')]
    tier: Option<String>,
    #[arg(long)]
    as_agent: Option<String>,
    #[arg(long)]
    budget_tokens: Option<usize>,
    // Comma-separated list.
    #[arg(long, value_delimiter = ',')]
    context_tokens: Option<Vec<String>>,
}
// Options for `search` (handled by `cmd_search`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct SearchArgs {
    // Positional query; may start with '-'.
    #[arg(allow_hyphen_values = true)]
    query: String,
    #[arg(long, short)]
    namespace: Option<String>,
    #[arg(long, short)]
    tier: Option<String>,
    // Default 20.
    #[arg(long, default_value_t = 20)]
    limit: usize,
    #[arg(long)]
    since: Option<String>,
    #[arg(long)]
    until: Option<String>,
    #[arg(long)]
    tags: Option<String>,
    // NOTE(review): shares its id with the global `Cli::agent_id`; here it is a
    // search filter — confirm the interaction is intended.
    #[arg(long)]
    agent_id: Option<String>,
    #[arg(long)]
    as_agent: Option<String>,
}
// Options for `get` (handled by `cmd_get`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct GetArgs {
    // Positional memory id.
    id: String,
}
// Options for `list` (handled by `cmd_list`).
#[derive(Args)]
struct ListArgs {
    #[arg(long, short)]
    namespace: Option<String>,
    #[arg(long, short)]
    tier: Option<String>,
    // Page size, default 20.
    #[arg(long, default_value_t = 20)]
    limit: usize,
    #[arg(long)]
    since: Option<String>,
    #[arg(long)]
    until: Option<String>,
    #[arg(long)]
    tags: Option<String>,
    // Page offset, default 0.
    #[arg(long, default_value_t = 0)]
    offset: usize,
    // NOTE(review): shares its id with the global `Cli::agent_id`.
    #[arg(long)]
    agent_id: Option<String>,
}
// Options for `delete` (handled by `cmd_delete`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct DeleteArgs {
    // Positional memory id.
    id: String,
}
// Options for `promote` (handled by `cmd_promote`).
#[derive(Args)]
struct PromoteArgs {
    // Positional memory id.
    id: String,
    // Optional destination namespace.
    #[arg(long)]
    to_namespace: Option<String>,
}
// Options for `forget` (handled by `cmd_forget`); all filters are optional.
// What happens when none is given is decided by `cmd_forget`.
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct ForgetArgs {
    #[arg(long, short)]
    namespace: Option<String>,
    // Pattern syntax is interpreted by `cmd_forget`.
    #[arg(long, short)]
    pattern: Option<String>,
    #[arg(long, short)]
    tier: Option<String>,
}
// Options for `link` (handled by `cmd_link`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct LinkArgs {
    // Positional memory ids.
    source_id: String,
    target_id: String,
    // `-r`, default "related_to".
    #[arg(long, short, default_value = "related_to")]
    relation: String,
}
// Options for `consolidate` (handled by `cmd_consolidate`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct ConsolidateArgs {
    // Positional id list as a single string; splitting is done by
    // `cmd_consolidate` (presumably comma-separated — TODO confirm).
    ids: String,
    #[arg(long, short = 'T', allow_hyphen_values = true)]
    title: String,
    #[arg(long, short = 's', allow_hyphen_values = true)]
    summary: String,
    #[arg(long, short)]
    namespace: Option<String>,
}
// Options for `resolve` (handled by `cmd_resolve`).
#[derive(Args)]
struct ResolveArgs {
    // Positional memory ids.
    winner_id: String,
    loser_id: String,
}
// Options for `sync-daemon` (handled by `cmd_sync_daemon`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct SyncDaemonArgs {
    // Comma-separated peer list.
    #[arg(long, value_delimiter = ',')]
    peers: Vec<String>,
    // Default 2; presumably seconds between sync rounds — TODO confirm.
    #[arg(long, default_value_t = 2)]
    interval: u64,
    #[arg(long)]
    api_key: Option<String>,
    // Default 500.
    #[arg(long, default_value_t = 500)]
    batch_size: usize,
    // Client certificate and key require each other.
    #[arg(long, requires = "client_key")]
    client_cert: Option<PathBuf>,
    #[arg(long, requires = "client_cert")]
    client_key: Option<PathBuf>,
    // NOTE(review): `build_rustls_client_config` installs
    // `DangerousAnyServerVerifier` unconditionally. Confirm that
    // `cmd_sync_daemon` only takes that path when this flag is set.
    #[arg(long, default_value_t = false)]
    insecure_skip_server_verify: bool,
}
// Options for `sync` (handled by `cmd_sync`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct SyncArgs {
    // Positional path of the other database file.
    remote_db: PathBuf,
    // `-d`, default "merge"; accepted values are defined by `cmd_sync`.
    #[arg(long, short, default_value = "merge")]
    direction: String,
    #[arg(long, default_value_t = false)]
    trust_source: bool,
    #[arg(long, default_value_t = false)]
    dry_run: bool,
}
// Options for `import` (handled by `cmd_import`).
#[derive(Args)]
struct ImportArgs {
    #[arg(long, default_value_t = false)]
    trust_source: bool,
}
// Options for `auto-consolidate` (handled by `cmd_auto_consolidate`).
// `//` (not `///`) so clap help text is unchanged.
#[derive(Args)]
struct AutoConsolidateArgs {
    #[arg(long, short)]
    namespace: Option<String>,
    #[arg(long, default_value_t = false)]
    short_only: bool,
    // Default 3.
    #[arg(long, default_value_t = 3)]
    min_count: usize,
    #[arg(long, default_value_t = false)]
    dry_run: bool,
}
// Options for `completions`.
#[derive(Args)]
struct CompletionsArgs {
    // Positional target shell (a `clap_complete::Shell` value).
    shell: Shell,
}
/// Derives a default namespace for the current working directory.
///
/// Resolution order:
/// 1. the repository name from `git remote get-url origin` (see
///    [`namespace_from_remote_url`]), if git exits successfully and yields a
///    non-empty name;
/// 2. the final component of the current directory;
/// 3. the literal `"global"`.
fn auto_namespace() -> String {
    let from_git = std::process::Command::new("git")
        .args(["remote", "get-url", "origin"])
        .stderr(std::process::Stdio::null())
        .output()
        .ok()
        // A failing git (not a repo, no `origin`) must not be trusted even if
        // it happened to write something to stdout.
        .filter(|out| out.status.success())
        .and_then(|out| namespace_from_remote_url(&String::from_utf8_lossy(&out.stdout)));
    if let Some(name) = from_git {
        return name;
    }
    std::env::current_dir()
        .ok()
        .and_then(|p| p.file_name().map(|n| n.to_string_lossy().to_string()))
        .unwrap_or_else(|| "global".to_string())
}

/// Extracts a repository name from a git remote URL.
///
/// Takes the last segment after `/` or `:`, ignores trailing slashes and
/// surrounding whitespace, and drops a single `.git` suffix. This handles
/// `https://host/owner/repo.git`, `git@host:owner/repo.git`, `git@host:repo.git`
/// and `https://host/owner/repo/`. Returns `None` when no name remains.
fn namespace_from_remote_url(url: &str) -> Option<String> {
    let url = url.trim().trim_end_matches('/');
    let last = url.rsplit(['/', ':']).next()?;
    let name = last.strip_suffix(".git").unwrap_or(last);
    (!name.is_empty()).then(|| name.to_string())
}
/// Formats an RFC 3339 timestamp as a coarse relative age: "just now",
/// "{n}m ago", "{n}h ago", "{n}d ago" or "{n}mo ago".
///
/// Input that does not parse as RFC 3339 is returned unchanged. Timestamps
/// less than a minute old, including ones in the future, yield "just now".
/// Months are approximated as 30 days.
fn human_age(iso: &str) -> String {
    let stamp = match chrono::DateTime::parse_from_rfc3339(iso) {
        Ok(parsed) => parsed,
        Err(_) => return iso.to_string(),
    };
    let elapsed = Utc::now().signed_duration_since(stamp);
    let (minutes, hours, days) = (
        elapsed.num_minutes(),
        elapsed.num_hours(),
        elapsed.num_days(),
    );
    if elapsed.num_seconds() < 60 {
        "just now".to_string()
    } else if minutes < 60 {
        format!("{minutes}m ago")
    } else if hours < 24 {
        format!("{hours}h ago")
    } else if days < 30 {
        format!("{days}d ago")
    } else {
        format!("{}mo ago", days / 30)
    }
}
/// Process entry point: loads configuration, parses the CLI, and dispatches to
/// the selected subcommand.
///
/// After a successful command listed in `is_write_command`, a WAL checkpoint
/// is attempted on a fresh connection. Failures to open or checkpoint are
/// ignored, and the command's own result is returned unchanged.
#[tokio::main]
#[allow(clippy::too_many_lines)]
async fn main() -> Result<()> {
    color::init();
    // Configuration is loaded before CLI parsing so that it can seed
    // environment variables that later code reads.
    let app_config = config::AppConfig::load();
    config::AppConfig::write_default_if_missing();
    // The config-level anonymization default is exported via the environment;
    // an explicitly set AI_MEMORY_ANONYMIZE always wins.
    if app_config.effective_anonymize_default() && std::env::var("AI_MEMORY_ANONYMIZE").is_err() {
        // SAFETY(review): `set_var` is unsafe because concurrent environment
        // access from other threads is undefined behaviour on some platforms.
        // No task has been spawned by this function yet, but `#[tokio::main]`
        // has already built the runtime — presumably its worker threads are
        // idle here; TODO confirm.
        unsafe { std::env::set_var("AI_MEMORY_ANONYMIZE", "1") };
    }
    let cli = Cli::parse();
    // Export the DB passphrase before any database is opened. Only trailing
    // CR/LF is stripped, so other whitespace stays part of the passphrase.
    // The DB layer presumably reads AI_MEMORY_DB_PASSPHRASE — TODO confirm.
    if let Some(path) = &cli.db_passphrase_file {
        let raw = std::fs::read_to_string(path)
            .with_context(|| format!("reading passphrase file {}", path.display()))?;
        let passphrase = raw.trim_end_matches(['\n', '\r']).to_string();
        if passphrase.is_empty() {
            anyhow::bail!("passphrase file {} is empty", path.display());
        }
        // SAFETY(review): same single-threaded-setup assumption as above.
        unsafe { std::env::set_var("AI_MEMORY_DB_PASSPHRASE", passphrase) };
    }
    let db_path = app_config.effective_db(&cli.db);
    let j = cli.json;
    let cli_agent_id: Option<String> = cli.agent_id.clone();
    // Commands followed by a WAL checkpoint on success.
    // NOTE(review): Mine, Archive, Agents, Pending, Restore, Curator and Migrate
    // are not listed even though some of their actions presumably write to the
    // DB — confirm whether they should be.
    let is_write_command = matches!(
        cli.command,
        Command::Store(_)
            | Command::Update(_)
            | Command::Delete(_)
            | Command::Promote(_)
            | Command::Forget(_)
            | Command::Link(_)
            | Command::Consolidate(_)
            | Command::Resolve(_)
            | Command::Sync(_)
            | Command::SyncDaemon(_)
            | Command::Import(_)
            | Command::AutoConsolidate(_)
            | Command::Gc
    );
    let db_path_for_checkpoint = if is_write_command {
        Some(db_path.clone())
    } else {
        None
    };
    let result = match cli.command {
        Command::Serve(a) => serve(db_path, a, &app_config).await,
        Command::Mcp { tier } => {
            let feature_tier = app_config.effective_tier(Some(&tier));
            mcp::run_mcp_server(&db_path, feature_tier, &app_config)?;
            Ok(())
        }
        Command::Store(a) => cmd_store(&db_path, a, j, &app_config, cli_agent_id.as_deref()),
        Command::Update(a) => cmd_update(&db_path, &a, j),
        Command::Recall(a) => cmd_recall(&db_path, &a, j, &app_config),
        Command::Search(a) => cmd_search(&db_path, &a, j, &app_config),
        Command::Get(a) => cmd_get(&db_path, &a, j),
        Command::List(a) => cmd_list(&db_path, &a, j, &app_config),
        Command::Delete(a) => cmd_delete(&db_path, &a, j, cli_agent_id.as_deref()),
        Command::Promote(a) => cmd_promote(&db_path, &a, j, cli_agent_id.as_deref()),
        Command::Forget(a) => cmd_forget(&db_path, &a, j),
        Command::Link(a) => cmd_link(&db_path, &a, j),
        Command::Consolidate(a) => cmd_consolidate(&db_path, a, j, cli_agent_id.as_deref()),
        Command::Resolve(a) => cmd_resolve(&db_path, &a, j),
        Command::Shell => cmd_shell(&db_path),
        Command::Sync(a) => cmd_sync(&db_path, &a, j, cli_agent_id.as_deref()),
        Command::SyncDaemon(a) => cmd_sync_daemon(&db_path, a, cli_agent_id.as_deref()).await,
        Command::AutoConsolidate(a) => {
            cmd_auto_consolidate(&db_path, &a, j, cli_agent_id.as_deref())
        }
        Command::Gc => cmd_gc(&db_path, j, &app_config),
        Command::Stats => cmd_stats(&db_path, j),
        Command::Namespaces => cmd_namespaces(&db_path, j),
        Command::Export => cmd_export(&db_path),
        Command::Import(a) => cmd_import(&db_path, &a, j, cli_agent_id.as_deref()),
        // Completion script and man page are written straight to stdout.
        Command::Completions(a) => {
            generate(
                a.shell,
                &mut Cli::command(),
                "ai-memory",
                &mut std::io::stdout(),
            );
            Ok(())
        }
        Command::Man => {
            let cmd = Cli::command();
            let man = clap_mangen::Man::new(cmd);
            man.render(&mut std::io::stdout())?;
            Ok(())
        }
        Command::Mine(a) => cmd_mine(&db_path, a, j, &app_config, cli_agent_id.as_deref()),
        Command::Archive(a) => cmd_archive(&db_path, a, j),
        Command::Agents(a) => cmd_agents(&db_path, a, j),
        Command::Pending(a) => cmd_pending(&db_path, a, j, cli_agent_id.as_deref()),
        Command::Backup(a) => cmd_backup(&db_path, &a, j),
        Command::Restore(a) => cmd_restore(&db_path, &a, j),
        // NOTE(review): the global `--json` (`j`) is not forwarded here; the
        // curator only sees `CuratorArgs::json`.
        Command::Curator(a) => cmd_curator(&db_path, &a, &app_config).await,
        #[cfg(feature = "sal")]
        Command::Migrate(a) => cmd_migrate(&a).await,
    };
    // Best-effort checkpoint after successful writes; its errors never mask
    // or replace the command's own result.
    if result.is_ok()
        && let Some(cp_path) = db_path_for_checkpoint
        && let Ok(conn) = db::open(&cp_path)
    {
        let _ = db::checkpoint(&conn);
    }
    result
}
/// Runs the HTTP API server until Ctrl-C.
///
/// Startup sequence:
/// 1. install a global `tracing` subscriber (env filter plus
///    `ai_memory=info` and `tower_http=info` directives);
/// 2. open the database. If the config-file feature tier names an embedding
///    model, build the embedder and an HNSW vector index over all stored
///    embeddings;
/// 3. build the optional quorum federation and, when enabled, its catch-up loop;
/// 4. spawn the background GC task (every `GC_INTERVAL_SECS`) and the
///    WAL-checkpoint task (every `WAL_CHECKPOINT_INTERVAL_SECS`);
/// 5. serve the router over HTTPS when `--tls-cert` / `--tls-key` are given
///    (mandatory client certs with `--mtls-allowlist`), otherwise over plain
///    HTTP.
///
/// On Ctrl-C the WAL is checkpointed and graceful shutdown begins.
///
/// # Errors
/// Fails if the DB cannot be opened, the federation or TLS configuration is
/// invalid, the listen address cannot be parsed or bound, or the server errors.
#[allow(clippy::too_many_lines)]
async fn serve(db_path: PathBuf, args: ServeArgs, app_config: &config::AppConfig) -> Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter(
            EnvFilter::from_default_env()
                .add_directive("ai_memory=info".parse()?)
                .add_directive("tower_http=info".parse()?),
        )
        .init();
    let resolved_ttl = app_config.effective_ttl();
    let archive_on_gc = app_config.effective_archive_on_gc();
    let conn = db::open(&db_path)?;
    // The server always uses the config-file tier (no CLI override, unlike `mcp`).
    let feature_tier = app_config.effective_tier(None);
    let tier_config = feature_tier.config();
    // Embedder construction runs on the blocking thread pool. A failure
    // downgrades the daemon to keyword-only search instead of aborting startup.
    let embedder: Option<embeddings::Embedder> =
        if let Some(emb_model) = tier_config.embedding_model {
            let embed_url = app_config.effective_embed_url().to_string();
            let build = tokio::task::spawn_blocking(move || {
                // A client that fails to construct is passed on as `None`.
                let embed_client = llm::OllamaClient::new_with_url(&embed_url, "nomic-embed-text")
                    .ok()
                    .map(Arc::new);
                embeddings::Embedder::for_model(emb_model, embed_client)
            })
            .await?;
            match build {
                Ok(emb) => {
                    tracing::info!(
                        "embedder loaded ({}) — tier={} semantic recall enabled",
                        emb.model_description(),
                        feature_tier.as_str()
                    );
                    Some(emb)
                }
                Err(e) => {
                    tracing::error!(
                        "⚠️ EMBEDDER LOAD FAILED — tier={} requested semantic features, \
                         but embedder init errored: {e}. Daemon falls back to keyword-only. \
                         Semantic recall, sync_push embedding refresh (#322), and HNSW index \
                         will be NO-OPS. Check network egress to HuggingFace Hub + available \
                         memory for model weights. To force keyword-only explicitly (silences \
                         this error), set `tier = \"keyword\"` in config.toml.",
                        feature_tier.as_str()
                    );
                    None
                }
            }
        } else {
            tracing::info!(
                "embedder disabled — tier={} keyword-only (FTS5); semantic recall not wired",
                feature_tier.as_str()
            );
            None
        };
    // A vector index exists only alongside an embedder.
    // NOTE(review): an error from `get_all_embeddings` is treated like "no
    // embeddings" (empty index) and is not logged.
    let vector_index = if embedder.is_some() {
        match db::get_all_embeddings(&conn) {
            Ok(entries) if !entries.is_empty() => Some(hnsw::VectorIndex::build(entries)),
            _ => Some(hnsw::VectorIndex::empty()),
        }
    } else {
        None
    };
    // Shared DB state: (connection, db path, resolved TTL config, archive-on-GC
    // flag) behind a single async mutex. Handlers and background tasks below
    // all serialise on this lock.
    let db_state: handlers::Db = Arc::new(Mutex::new((
        conn,
        db_path.clone(),
        resolved_ttl,
        archive_on_gc,
    )));
    // `build` returns an Option; federation features run only when it is Some.
    // This node identifies itself to peers as "host:<hostname>".
    let federation = federation::FederationConfig::build(
        args.quorum_writes,
        &args.quorum_peers,
        std::time::Duration::from_millis(args.quorum_timeout_ms),
        args.quorum_client_cert.as_deref(),
        args.quorum_client_key.as_deref(),
        args.quorum_ca_cert.as_deref(),
        format!("host:{}", gethostname::gethostname().to_string_lossy()),
    )
    .context("federation config")?;
    if let Some(ref fed) = federation {
        tracing::info!(
            "federation enabled: W={} over {} peer(s), timeout {}ms",
            fed.policy.w,
            fed.peer_count(),
            args.quorum_timeout_ms,
        );
        // `--catchup-interval-secs 0` disables the catch-up loop.
        if args.catchup_interval_secs > 0 {
            let interval = std::time::Duration::from_secs(args.catchup_interval_secs);
            tracing::info!(
                "catchup loop enabled: polling {} peer(s) every {}s",
                fed.peer_count(),
                args.catchup_interval_secs,
            );
            federation::spawn_catchup_loop(fed.clone(), db_state.clone(), interval);
        } else {
            tracing::info!("catchup loop disabled (--catchup-interval-secs=0)");
        }
    }
    let app_state = handlers::AppState {
        db: db_state.clone(),
        embedder: Arc::new(embedder),
        vector_index: Arc::new(Mutex::new(vector_index)),
        federation: Arc::new(federation),
        tier_config: Arc::new(tier_config.clone()),
        scoring: Arc::new(app_config.effective_scoring()),
    };
    let state = db_state;
    // Background GC runs first after one full GC_INTERVAL_SECS, then on every
    // interval. It holds the DB lock for both the expiry pass and the archive
    // purge (bounded by `archive_max_days`). Errors and zero counts are silent.
    let gc_state = state.clone();
    let archive_max_days = app_config.archive_max_days;
    tokio::spawn(async move {
        loop {
            tokio::time::sleep(tokio::time::Duration::from_secs(GC_INTERVAL_SECS)).await;
            let lock = gc_state.lock().await;
            // `lock.3` is the archive-on-GC flag.
            match db::gc(&lock.0, lock.3) {
                Ok(n) if n > 0 => tracing::info!("gc: expired {n} memories"),
                _ => {}
            }
            match db::auto_purge_archive(&lock.0, archive_max_days) {
                Ok(n) if n > 0 => tracing::info!("gc: purged {n} old archived memories"),
                _ => {}
            }
        }
    });
    // Periodic WAL checkpoint: first after half an interval (offsetting it
    // from the GC task's cadence), then every full interval. The inner block
    // releases the lock before sleeping.
    let checkpoint_state = state.clone();
    tokio::spawn(async move {
        tokio::time::sleep(tokio::time::Duration::from_secs(
            WAL_CHECKPOINT_INTERVAL_SECS / 2,
        ))
        .await;
        loop {
            {
                let lock = checkpoint_state.lock().await;
                match db::checkpoint(&lock.0) {
                    Ok(()) => tracing::debug!("wal checkpoint: ok"),
                    Err(e) => tracing::warn!("wal checkpoint failed: {e}"),
                }
            }
            tokio::time::sleep(tokio::time::Duration::from_secs(
                WAL_CHECKPOINT_INTERVAL_SECS,
            ))
            .await;
        }
    });
    // Shutdown trigger shared by both server variants. It completes after
    // Ctrl-C plus a WAL checkpoint under the DB lock.
    // NOTE(review):
    // - only Ctrl-C is awaited; SIGTERM (e.g. `docker stop`, systemd) is not
    //   handled here and skips this checkpoint;
    // - if registering the Ctrl-C handler fails, `ctrl_c()` returns Err
    //   immediately, which is ignored and shuts the server down at once;
    // - requests still in flight during graceful shutdown may write after
    //   this checkpoint.
    let shutdown_state = state.clone();
    let shutdown = async move {
        let _ = tokio::signal::ctrl_c().await;
        tracing::info!("shutting down — checkpointing WAL");
        let lock = shutdown_state.lock().await;
        let _ = db::checkpoint(&lock.0);
    };
    // API key from config. Behaviour when it is None is decided inside
    // `handlers::api_key_auth` (not visible here).
    let api_key_state = handlers::ApiKeyState {
        key: app_config.api_key.clone(),
    };
    if api_key_state.key.is_some() {
        tracing::info!("API key authentication enabled");
    }
    // Every route below, including health and metrics, passes through
    // `api_key_auth`; any exemptions must live inside that middleware.
    // Layers wrap outward in call order, so CORS is outermost, then the 2 MiB
    // body limit, then tracing, then auth.
    // NOTE(review): `CorsLayer::new()` is unconfigured, which under tower-http
    // defaults grants no cross-origin access — confirm that is intended.
    let app = Router::new()
        .route("/api/v1/health", get(handlers::health))
        .route("/metrics", get(handlers::prometheus_metrics))
        .route("/api/v1/metrics", get(handlers::prometheus_metrics))
        .route("/api/v1/memories", get(handlers::list_memories))
        .route("/api/v1/memories", post(handlers::create_memory))
        .route("/api/v1/memories/bulk", post(handlers::bulk_create))
        .route("/api/v1/memories/{id}", get(handlers::get_memory))
        .route("/api/v1/memories/{id}", put(handlers::update_memory))
        .route("/api/v1/memories/{id}", delete(handlers::delete_memory))
        .route(
            "/api/v1/memories/{id}/promote",
            post(handlers::promote_memory),
        )
        .route("/api/v1/search", get(handlers::search_memories))
        .route("/api/v1/recall", get(handlers::recall_memories_get))
        .route("/api/v1/recall", post(handlers::recall_memories_post))
        .route("/api/v1/forget", post(handlers::forget_memories))
        .route("/api/v1/consolidate", post(handlers::consolidate_memories))
        .route(
            "/api/v1/contradictions",
            get(handlers::detect_contradictions),
        )
        .route("/api/v1/links", post(handlers::create_link))
        .route("/api/v1/links", delete(handlers::delete_link))
        .route("/api/v1/links/{id}", get(handlers::get_links))
        .route(
            "/api/v1/namespaces",
            get(handlers::get_namespace_standard_qs),
        )
        .route(
            "/api/v1/namespaces",
            post(handlers::set_namespace_standard_qs),
        )
        .route(
            "/api/v1/namespaces",
            delete(handlers::clear_namespace_standard_qs),
        )
        .route(
            "/api/v1/namespaces/{ns}/standard",
            post(handlers::set_namespace_standard),
        )
        .route(
            "/api/v1/namespaces/{ns}/standard",
            get(handlers::get_namespace_standard),
        )
        .route(
            "/api/v1/namespaces/{ns}/standard",
            delete(handlers::clear_namespace_standard),
        )
        .route("/api/v1/stats", get(handlers::get_stats))
        .route("/api/v1/gc", post(handlers::run_gc))
        .route("/api/v1/export", get(handlers::export_memories))
        .route("/api/v1/import", post(handlers::import_memories))
        .route("/api/v1/archive", get(handlers::list_archive))
        .route("/api/v1/archive", post(handlers::archive_by_ids))
        .route("/api/v1/archive", delete(handlers::purge_archive))
        .route(
            "/api/v1/archive/{id}/restore",
            post(handlers::restore_archive),
        )
        .route("/api/v1/archive/stats", get(handlers::archive_stats))
        .route("/api/v1/agents", get(handlers::list_agents))
        .route("/api/v1/agents", post(handlers::register_agent))
        .route("/api/v1/pending", get(handlers::list_pending))
        .route(
            "/api/v1/pending/{id}/approve",
            post(handlers::approve_pending),
        )
        .route(
            "/api/v1/pending/{id}/reject",
            post(handlers::reject_pending),
        )
        .route("/api/v1/sync/push", post(handlers::sync_push))
        .route("/api/v1/sync/since", get(handlers::sync_since))
        .route("/api/v1/capabilities", get(handlers::get_capabilities))
        .route("/api/v1/notify", post(handlers::notify))
        .route("/api/v1/inbox", get(handlers::get_inbox))
        .route("/api/v1/subscriptions", post(handlers::subscribe))
        .route("/api/v1/subscriptions", delete(handlers::unsubscribe))
        .route("/api/v1/subscriptions", get(handlers::list_subscriptions))
        .route("/api/v1/session/start", post(handlers::session_start))
        .layer(axum::middleware::from_fn_with_state(
            api_key_state,
            handlers::api_key_auth,
        ))
        .layer(TraceLayer::new_for_http())
        .layer(DefaultBodyLimit::max(2 * 1024 * 1024))
        .layer(CorsLayer::new())
        .with_state(app_state);
    let addr = format!("{}:{}", args.host, args.port);
    tracing::info!("database: {}", db_path.display());
    if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) {
        // Make `ring` the process-wide rustls provider. The Err returned when a
        // provider is already installed is deliberately ignored.
        let _ = rustls::crypto::ring::default_provider().install_default();
        let tls_config = if let Some(allowlist_path) = &args.mtls_allowlist {
            tracing::info!(
                "mTLS enabled — client certs required. Allowlist: {}",
                allowlist_path.display()
            );
            load_mtls_rustls_config(cert, key, allowlist_path).await?
        } else {
            tracing::warn!(
                "TLS enabled but mTLS NOT configured — sync endpoints \
                 (/api/v1/sync/push, /api/v1/sync/since) accept any client. \
                 Set --mtls-allowlist for production peer-mesh deployments \
                 (red-team #231)."
            );
            load_rustls_config(cert, key).await?
        };
        tracing::info!("ai-memory listening on https://{addr}");
        // axum_server needs a parsed SocketAddr, so `--host` must be an IP
        // literal here (a hostname fails to parse).
        let socket_addr: std::net::SocketAddr = addr.parse()?;
        // `--shutdown-grace-secs` bounds graceful shutdown on this path only.
        let grace = std::time::Duration::from_secs(args.shutdown_grace_secs);
        let handle = axum_server::Handle::new();
        let handle_clone = handle.clone();
        tokio::spawn(async move {
            shutdown.await;
            handle_clone.graceful_shutdown(Some(grace));
        });
        axum_server::bind_rustls(socket_addr, tls_config)
            .handle(handle)
            .serve(app.into_make_service())
            .await?;
    } else {
        tracing::warn!(
            "TLS NOT enabled — sync endpoints (/api/v1/sync/push, \
             /api/v1/sync/since) accept any caller over plain HTTP. \
             Set --tls-cert + --tls-key + --mtls-allowlist for production \
             peer-mesh deployments (red-team #231)."
        );
        tracing::info!("ai-memory listening on http://{addr}");
        let listener = tokio::net::TcpListener::bind(&addr).await?;
        // No grace timeout on this path.
        axum::serve(listener, app)
            .with_graceful_shutdown(shutdown)
            .await?;
    }
    Ok(())
}
/// Builds a server TLS configuration without client authentication from a PEM
/// certificate (chain) file and a PEM private-key file.
///
/// # Errors
/// Fails if either file cannot be read or the PEM contents cannot be parsed.
async fn load_rustls_config(
    cert_path: &Path,
    key_path: &Path,
) -> Result<axum_server::tls_rustls::RustlsConfig> {
    let cert_bytes = tokio::fs::read(cert_path)
        .await
        .with_context(|| format!("failed to read TLS cert from {}", cert_path.display()))?;
    let key_bytes = tokio::fs::read(key_path)
        .await
        .with_context(|| format!("failed to read TLS key from {}", key_path.display()))?;
    axum_server::tls_rustls::RustlsConfig::from_pem(cert_bytes, key_bytes)
        .await
        .context(
            "failed to parse TLS cert/key — ensure PEM-encoded (cert may be fullchain; \
             key must be PKCS#8 or RSA)",
        )
}
/// Builds a server TLS configuration that requires every client to present a
/// certificate whose SHA-256 fingerprint appears in the allowlist file.
///
/// An empty allowlist is rejected, so a misconfiguration cannot silently
/// accept every peer.
///
/// # Errors
/// Fails if the allowlist is unreadable, malformed or empty. Also fails if the
/// cert/key files cannot be read or parsed, or rustls rejects the pair.
async fn load_mtls_rustls_config(
    cert_path: &Path,
    key_path: &Path,
    allowlist_path: &Path,
) -> Result<axum_server::tls_rustls::RustlsConfig> {
    let allowlist = load_fingerprint_allowlist(allowlist_path).await?;
    anyhow::ensure!(
        !allowlist.is_empty(),
        "mTLS allowlist at {} is empty — refuse to start rather than silently accept all peers",
        allowlist_path.display()
    );
    let (cert_pem, key_pem) = (
        tokio::fs::read(cert_path)
            .await
            .with_context(|| format!("failed to read TLS cert from {}", cert_path.display()))?,
        tokio::fs::read(key_path)
            .await
            .with_context(|| format!("failed to read TLS key from {}", key_path.display()))?,
    );
    let chain = rustls_pki_pem_iter_certs(&cert_pem)?;
    let private_key = rustls_pki_pem_parse_private_key(&key_pem)?;
    let client_verifier = Arc::new(FingerprintAllowlistVerifier { allowlist });
    let server_config = rustls::ServerConfig::builder()
        .with_client_cert_verifier(client_verifier)
        .with_single_cert(chain, private_key)
        .context("failed to build rustls ServerConfig for mTLS")?;
    Ok(axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(
        server_config,
    )))
}
/// Reads and parses the mTLS client-certificate allowlist at `path`.
///
/// See [`parse_fingerprint_allowlist`] for the accepted file format.
///
/// # Errors
/// Fails if the file cannot be read or any entry is malformed.
async fn load_fingerprint_allowlist(path: &Path) -> Result<std::collections::HashSet<[u8; 32]>> {
    let text = tokio::fs::read_to_string(path)
        .await
        .with_context(|| format!("failed to read mTLS allowlist from {}", path.display()))?;
    parse_fingerprint_allowlist(&text)
}

/// Parses allowlist text into a set of 32-byte SHA-256 fingerprints.
///
/// One entry per line. Lines that are blank or start with `#` (after trimming)
/// are ignored. An entry is 64 hex digits, optionally prefixed with `sha256:`
/// and optionally separated by `:` (e.g. `AB:CD:...`); hex case does not
/// matter. Duplicate entries collapse into one.
///
/// # Errors
/// Reports the 1-based line number of the first malformed entry.
fn parse_fingerprint_allowlist(text: &str) -> Result<std::collections::HashSet<[u8; 32]>> {
    let mut set = std::collections::HashSet::new();
    for (lineno, raw) in text.lines().enumerate() {
        let line = raw.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        let hex_part = line.strip_prefix("sha256:").unwrap_or(line);
        if let Some(bad) = hex_part
            .chars()
            .find(|c| !c.is_ascii_hexdigit() && *c != ':')
        {
            anyhow::bail!(
                "mTLS allowlist line {}: unexpected character {:?} — \
                 entries must be 64 hex chars with optional `:` separators",
                lineno + 1,
                bad
            );
        }
        let hex_clean: String = hex_part.chars().filter(|c| *c != ':').collect();
        if hex_clean.len() != 64 {
            anyhow::bail!(
                "mTLS allowlist line {}: expected 64 hex chars (optionally with `:` separators), got {}",
                lineno + 1,
                hex_clean.len()
            );
        }
        // `hex_clean` is pure ASCII hex at this point, so 2-byte slices never
        // split a character.
        let mut bytes = [0u8; 32];
        for (i, byte) in bytes.iter_mut().enumerate() {
            *byte = u8::from_str_radix(&hex_clean[i * 2..i * 2 + 2], 16)
                .with_context(|| format!("mTLS allowlist line {}: invalid hex", lineno + 1))?;
        }
        set.insert(bytes);
    }
    Ok(set)
}
/// Decodes every certificate in a PEM buffer into owned DER certificates, in
/// file order (leaf first for a typical fullchain file).
///
/// # Errors
/// Fails if the PEM cannot be parsed or contains no certificates.
fn rustls_pki_pem_iter_certs(
    pem: &[u8],
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
    use rustls::pki_types::pem::PemObject as _;
    let mut reader = std::io::Cursor::new(pem);
    let chain = rustls::pki_types::CertificateDer::pem_reader_iter(&mut reader)
        .collect::<std::result::Result<Vec<_>, _>>()
        .context("failed to parse TLS cert PEM")?;
    anyhow::ensure!(!chain.is_empty(), "TLS cert PEM contained no certificates");
    Ok(chain)
}
/// Decodes the first private key found in a PEM buffer.
///
/// # Errors
/// Fails if the buffer holds no key that `PrivateKeyDer::from_pem_reader`
/// can parse.
fn rustls_pki_pem_parse_private_key(
    pem: &[u8],
) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
    use rustls::pki_types::pem::PemObject as _;
    rustls::pki_types::PrivateKeyDer::from_pem_reader(&mut std::io::Cursor::new(pem))
        .context("failed to parse TLS key PEM — expected PKCS#8, RSA, or SEC1")
}
#[derive(Debug)]
struct FingerprintAllowlistVerifier {
allowlist: std::collections::HashSet<[u8; 32]>,
}
impl rustls::server::danger::ClientCertVerifier for FingerprintAllowlistVerifier {
fn offer_client_auth(&self) -> bool {
true
}
fn client_auth_mandatory(&self) -> bool {
true
}
fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] {
&[]
}
fn verify_client_cert(
&self,
end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_now: rustls::pki_types::UnixTime,
) -> std::result::Result<rustls::server::danger::ClientCertVerified, rustls::Error> {
use sha2::{Digest, Sha256};
let fp: [u8; 32] = Sha256::digest(end_entity.as_ref()).into();
if self.allowlist.contains(&fp) {
Ok(rustls::server::danger::ClientCertVerified::assertion())
} else {
Err(rustls::Error::General(format!(
"client cert fingerprint {} not in mTLS allowlist",
hex_short(&fp)
)))
}
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls12_signature(
message,
cert,
dss,
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(
message,
cert,
dss,
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::ring::default_provider()
.signature_verification_algorithms
.supported_schemes()
}
}
/// Renders the first 6 bytes of a 32-byte fingerprint as lowercase hex
/// followed by `…`, for compact log and error messages.
fn hex_short(fp: &[u8; 32]) -> String {
    let mut rendered: String = fp
        .iter()
        .take(6)
        .map(|byte| format!("{byte:02x}"))
        .collect();
    rendered.push('…');
    rendered
}
/// Builds a rustls client configuration that authenticates with the PEM
/// certificate chain at `cert_path` and the PEM private key at `key_path`.
///
/// # Security
/// The returned config installs `DangerousAnyServerVerifier`, so the server's
/// certificate is NOT verified; any certificate is accepted and only handshake
/// signatures are checked. This function does not consult any "insecure"
/// flag itself. NOTE(review): confirm every caller gates its use on an
/// explicit opt-in such as `--insecure-skip-server-verify`.
///
/// # Errors
/// Fails if either file cannot be read, the PEM is invalid, or rustls rejects
/// the certificate/key pair.
async fn build_rustls_client_config(
    cert_path: &Path,
    key_path: &Path,
) -> Result<rustls::ClientConfig> {
    let cert_bytes = tokio::fs::read(cert_path)
        .await
        .with_context(|| format!("failed to read client cert from {}", cert_path.display()))?;
    let key_bytes = tokio::fs::read(key_path)
        .await
        .with_context(|| format!("failed to read client key from {}", key_path.display()))?;
    // Parse before touching `ClientConfig::builder()`, so PEM errors surface
    // ahead of any crypto-provider resolution.
    let chain = rustls_pki_pem_iter_certs(&cert_bytes)?;
    let private_key = rustls_pki_pem_parse_private_key(&key_bytes)?;
    let server_verifier = Arc::new(DangerousAnyServerVerifier);
    rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(server_verifier)
        .with_client_auth_cert(chain, private_key)
        .context("failed to build rustls ClientConfig with client cert")
}
/// rustls server-certificate verifier that accepts ANY server certificate.
///
/// No chain, hostname, OCSP or validity-period check is performed, so a
/// client using it is open to man-in-the-middle attacks. Only the handshake
/// signatures are verified, with the `ring` provider's algorithms.
#[derive(Debug)]
struct DangerousAnyServerVerifier;

impl rustls::client::danger::ServerCertVerifier for DangerousAnyServerVerifier {
    // Unconditionally trusts the presented certificate; every input is ignored.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let provider = rustls::crypto::ring::default_provider();
        rustls::crypto::verify_tls12_signature(
            message,
            cert,
            dss,
            &provider.signature_verification_algorithms,
        )
    }

    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let provider = rustls::crypto::ring::default_provider();
        rustls::crypto::verify_tls13_signature(
            message,
            cert,
            dss,
            &provider.signature_verification_algorithms,
        )
    }

    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        let provider = rustls::crypto::ring::default_provider();
        provider.signature_verification_algorithms.supported_schemes()
    }
}
/// Stores a new memory built from CLI arguments.
///
/// `--content -` reads the content from stdin. Expiry is, in order of
/// precedence: an explicit `--expires-at`, else `now + --ttl-secs`, else
/// `now + ` the configured TTL for the tier (no expiry when none applies). The
/// memory's metadata is stamped with the resolved agent id and the optional
/// scope, and governance is consulted before inserting. A `Pending` decision
/// prints the pending id and returns without storing.
///
/// After the insert, memories reported by `db::find_contradictions` for the
/// same title and namespace are surfaced. In JSON mode they appear as
/// `potential_contradictions`; otherwise a count goes to stderr.
///
/// # Errors
/// Returns an error for invalid arguments (including an unknown tier), a
/// failed stdin read, or database failures. Exits the process with status 1
/// when governance denies the store.
#[allow(clippy::too_many_lines)]
fn cmd_store(
    db_path: &Path,
    args: StoreArgs,
    json_out: bool,
    app_config: &config::AppConfig,
    cli_agent_id: Option<&str>,
) -> Result<()> {
    let conn = db::open(db_path)?;
    let resolved_ttl = app_config.effective_ttl();
    // Best-effort housekeeping: a GC failure must not prevent the store.
    let _ = db::gc_if_needed(&conn, app_config.effective_archive_on_gc());
    let tier = Tier::from_str(&args.tier)
        .ok_or_else(|| anyhow::anyhow!("invalid tier: {} (use short, mid, long)", args.tier))?;
    let namespace = args.namespace.unwrap_or_else(auto_namespace);
    let content = if args.content == "-" {
        use std::io::Read;
        let mut buf = String::new();
        std::io::stdin().read_to_string(&mut buf)?;
        buf
    } else {
        args.content
    };
    let tags: Vec<String> = args
        .tags
        .split(',')
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
        .collect();
    validate::validate_title(&args.title)?;
    validate::validate_content(&content)?;
    validate::validate_namespace(&namespace)?;
    validate::validate_source(&args.source)?;
    validate::validate_tags(&tags)?;
    validate::validate_priority(args.priority)?;
    validate::validate_confidence(args.confidence)?;
    validate::validate_expires_at(args.expires_at.as_deref())?;
    validate::validate_ttl_secs(args.ttl_secs)?;
    let now = Utc::now();
    // Explicit timestamp wins; otherwise derive one from a TTL (CLI or per-tier config).
    let expires_at = args.expires_at.or_else(|| {
        args.ttl_secs
            .or(resolved_ttl.ttl_for_tier(&tier))
            .map(|s| (now + Duration::seconds(s)).to_rfc3339())
    });
    let agent_id = identity::resolve_agent_id(cli_agent_id, None)?;
    let mut metadata = models::default_metadata();
    if let Some(obj) = metadata.as_object_mut() {
        obj.insert(
            "agent_id".to_string(),
            serde_json::Value::String(agent_id.clone()),
        );
    }
    if let Some(ref s) = args.scope {
        validate::validate_scope(s)?;
        if let Some(obj) = metadata.as_object_mut() {
            obj.insert("scope".to_string(), serde_json::Value::String(s.clone()));
        }
    }
    let mem = models::Memory {
        id: uuid::Uuid::new_v4().to_string(),
        tier,
        namespace,
        title: args.title,
        content,
        tags,
        priority: args.priority.clamp(1, 10),
        confidence: args.confidence.clamp(0.0, 1.0),
        source: args.source,
        access_count: 0,
        created_at: now.to_rfc3339(),
        updated_at: now.to_rfc3339(),
        last_accessed_at: None,
        expires_at,
        metadata,
    };
    {
        use models::{GovernanceDecision, GovernedAction};
        let payload = serde_json::to_value(&mem).unwrap_or_default();
        match db::enforce_governance(
            &conn,
            GovernedAction::Store,
            &mem.namespace,
            &agent_id,
            None,
            None,
            &payload,
        )? {
            GovernanceDecision::Allow => {}
            GovernanceDecision::Deny(reason) => {
                eprintln!("store denied by governance: {reason}");
                std::process::exit(1);
            }
            GovernanceDecision::Pending(pending_id) => {
                if json_out {
                    println!(
                        "{}",
                        serde_json::json!({
                            "status": "pending",
                            "pending_id": pending_id,
                            "reason": "governance requires approval",
                            "action": "store",
                            "namespace": &mem.namespace,
                        })
                    );
                } else {
                    println!(
                        "store queued for approval: pending_id={pending_id} ns={}",
                        &mem.namespace
                    );
                }
                return Ok(());
            }
        }
    }
    // Contradiction detection is advisory; a lookup failure just means "none reported".
    let contradictions =
        db::find_contradictions(&conn, &mem.title, &mem.namespace).unwrap_or_default();
    let actual_id = db::insert(&conn, &mem)?;
    // `db::insert` returns the id actually stored (reported instead of `mem.id`),
    // so exclude both ids. Computed once so JSON and text output agree.
    let filtered: Vec<&String> = contradictions
        .iter()
        .filter(|c| c.id != mem.id && c.id != actual_id)
        .map(|c| &c.id)
        .collect();
    if json_out {
        let mut j = serde_json::to_value(&mem)?;
        j["id"] = serde_json::json!(actual_id);
        if !filtered.is_empty() {
            j["potential_contradictions"] = serde_json::json!(filtered);
        }
        println!("{}", serde_json::to_string(&j)?);
    } else {
        println!(
            "stored: {} [{}] (ns={})",
            actual_id, mem.tier, mem.namespace
        );
        if !filtered.is_empty() {
            eprintln!(
                "warning: {} similar memories found in same namespace (potential contradictions)",
                filtered.len()
            );
        }
    }
    Ok(())
}
fn cmd_update(db_path: &Path, args: &UpdateArgs, json_out: bool) -> Result<()> {
validate::validate_id(&args.id)?;
let conn = db::open(db_path)?;
let resolved_id = if db::get(&conn, &args.id)?.is_some() {
args.id.clone()
} else if let Some(mem) = db::get_by_prefix(&conn, &args.id)? {
mem.id
} else {
eprintln!("not found: {}", args.id);
std::process::exit(1);
};
let tier = args.tier.as_deref().and_then(Tier::from_str);
let tags: Option<Vec<String>> = args.tags.as_ref().map(|t| {
t.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect()
});
if let Some(ref t) = args.title {
validate::validate_title(t)?;
}
if let Some(ref c) = args.content {
validate::validate_content(c)?;
}
if let Some(ref ns) = args.namespace {
validate::validate_namespace(ns)?;
}
if let Some(ref tags) = tags {
validate::validate_tags(tags)?;
}
if let Some(p) = args.priority {
validate::validate_priority(p)?;
}
if let Some(c) = args.confidence {
validate::validate_confidence(c)?;
}
if let Some(ref ts) = args.expires_at
&& !ts.is_empty()
{
validate::validate_expires_at_format(ts)?;
}
let (found, _content_changed) = db::update(
&conn,
&resolved_id,
args.title.as_deref(),
args.content.as_deref(),
tier.as_ref(),
args.namespace.as_deref(),
tags.as_ref(),
args.priority,
args.confidence,
args.expires_at.as_deref(),
None,
)?;
if !found {
eprintln!("not found: {}", args.id);
std::process::exit(1);
}
if let Some(mem) = db::get(&conn, &resolved_id)? {
if json_out {
println!("{}", serde_json::to_string(&mem)?);
} else {
println!("updated: {} [{}]", mem.id, mem.title);
}
}
Ok(())
}
/// Recalls memories relevant to `args.context`.
///
/// Uses hybrid (keyword + vector) recall when the configured feature tier names
/// an embedding model and the embedder loads successfully. Otherwise, or if
/// embedding the query fails, it falls back to keyword-only recall.
///
/// Side effects: runs `db::gc_if_needed`. When an embedder loads, it embeds
/// and stores a vector for every memory returned by `db::get_unembedded_ids`
/// before the query runs. Progress and fallbacks are reported on stderr.
///
/// Output:
/// - With `json_out`: an object with `memories` (each carrying a `score` rounded
///   to 3 decimals), `count`, `mode` ("hybrid", "hybrid+rerank" or "keyword"),
///   `tokens_used`, and `budget_tokens` when one was supplied.
/// - Otherwise: a coloured listing with a 200-character content preview per memory.
///
/// # Errors
/// Returns an error for an invalid `--as-agent`, a database open failure, or a
/// failing recall query.
#[allow(clippy::too_many_lines)]
fn cmd_recall(
db_path: &Path,
args: &RecallArgs,
json_out: bool,
app_config: &config::AppConfig,
) -> Result<()> {
if let Some(ref a) = args.as_agent {
validate::validate_namespace(a)?;
}
let conn = db::open(db_path)?;
// Best-effort housekeeping; failures are ignored.
let _ = db::gc_if_needed(&conn, app_config.effective_archive_on_gc());
let feature_tier = app_config.effective_tier(args.tier.as_deref());
let tier_config = feature_tier.config();
// The embedder exists only when the feature tier configures an embedding model.
let embedder = if let Some(ref emb_model) = tier_config.embedding_model {
// An Ollama client (for the "nomic-embed-text" model) is only built when the
// tier also configures an LLM model.
let ollama_client = if tier_config.llm_model.is_some() {
let ollama_url = app_config.effective_ollama_url();
llm::OllamaClient::new_with_url(ollama_url, "nomic-embed-text")
.ok()
.map(Arc::new)
} else {
None
};
// Use a dedicated client when the embed URL differs from the Ollama URL,
// falling back to the Ollama client if that client cannot be created.
let embed_client = {
let embed_url = app_config.effective_embed_url();
let ollama_url = app_config.effective_ollama_url();
if embed_url == ollama_url {
ollama_client.clone()
} else {
llm::OllamaClient::new_with_url(embed_url, "nomic-embed-text")
.ok()
.map(Arc::new)
.or(ollama_client.clone())
}
};
match embeddings::Embedder::for_model(*emb_model, embed_client) {
Ok(emb) => {
eprintln!("ai-memory: embedder loaded ({})", emb.model_description());
// Backfill vectors for memories stored without one, so vector search can
// see them. Per-memory embed/store failures are skipped; only successes are counted.
if let Ok(unembedded) = db::get_unembedded_ids(&conn)
&& !unembedded.is_empty()
{
eprintln!("ai-memory: backfilling {} memories...", unembedded.len());
let mut ok = 0usize;
for (id, title, content) in &unembedded {
let text = format!("{title} {content}");
if let Ok(embedding) = emb.embed(&text)
&& db::set_embedding(&conn, id, &embedding).is_ok()
{
ok += 1;
}
}
eprintln!("ai-memory: backfilled {}/{}", ok, unembedded.len());
}
Some(emb)
}
Err(e) => {
eprintln!("ai-memory: embedder failed: {e}, falling back to keyword");
None
}
}
} else {
None
};
// Vector index over all stored embeddings (empty when none exist); only built
// when an embedder is available for the hybrid path.
let vector_index = if embedder.is_some() {
match db::get_all_embeddings(&conn) {
Ok(entries) if !entries.is_empty() => Some(hnsw::VectorIndex::build(entries)),
_ => Some(hnsw::VectorIndex::empty()),
}
} else {
None
};
// Cross-encoder reranking is applied only on the hybrid path.
let reranker = if tier_config.cross_encoder {
Some(reranker::CrossEncoder::new_neural())
} else {
None
};
let resolved_ttl = app_config.effective_ttl();
let resolved_scoring = app_config.effective_scoring();
let (results, tokens_used, mode) = if let Some(ref emb) = embedder {
match emb.embed(&args.context) {
Ok(primary_emb) => {
// Optional context tokens are embedded separately and fused with the primary
// query embedding via Embedder::fuse (weight argument 0.7). If that embed
// fails, the primary embedding is used alone.
let query_emb = match args.context_tokens.as_deref() {
Some(tokens) if !tokens.is_empty() => {
let joined = tokens.join(" ");
match emb.embed(&joined) {
Ok(ctx_emb) => embeddings::Embedder::fuse(&primary_emb, &ctx_emb, 0.7),
Err(e) => {
eprintln!(
"ai-memory: context_tokens embed failed: {e}, using primary only"
);
primary_emb
}
}
}
_ => primary_emb,
};
// NOTE(review): the hybrid path caps the limit at 50, but both keyword paths
// below pass `args.limit` through unchanged. Confirm whether db::recall
// applies its own cap.
let (results, tokens_used) = db::recall_hybrid(
&conn,
&args.context,
&query_emb,
args.namespace.as_deref(),
args.limit.min(50),
args.tags.as_deref(),
args.since.as_deref(),
args.until.as_deref(),
vector_index.as_ref(),
resolved_ttl.short_extend_secs,
resolved_ttl.mid_extend_secs,
args.as_agent.as_deref(),
args.budget_tokens,
&resolved_scoring,
)?;
if let Some(ref ce) = reranker {
(
ce.rerank(&args.context, results),
tokens_used,
"hybrid+rerank",
)
} else {
(results, tokens_used, "hybrid")
}
}
Err(e) => {
// Query embedding failed: degrade to keyword recall rather than error out.
eprintln!("ai-memory: embedding query failed: {e}, falling back to keyword");
let (results, tokens_used) = db::recall(
&conn,
&args.context,
args.namespace.as_deref(),
args.limit,
args.tags.as_deref(),
args.since.as_deref(),
args.until.as_deref(),
resolved_ttl.short_extend_secs,
resolved_ttl.mid_extend_secs,
args.as_agent.as_deref(),
args.budget_tokens,
)?;
(results, tokens_used, "keyword")
}
}
} else {
// No embedder configured or loaded: keyword-only recall.
let (results, tokens_used) = db::recall(
&conn,
&args.context,
args.namespace.as_deref(),
args.limit,
args.tags.as_deref(),
args.since.as_deref(),
args.until.as_deref(),
resolved_ttl.short_extend_secs,
resolved_ttl.mid_extend_secs,
args.as_agent.as_deref(),
args.budget_tokens,
)?;
(results, tokens_used, "keyword")
};
if json_out {
// Scores are rounded to 3 decimal places in JSON output.
let scored: Vec<serde_json::Value> = results
.iter()
.map(|(m, s)| {
let mut v = serde_json::to_value(m).unwrap_or_default();
if let Some(obj) = v.as_object_mut() {
obj.insert(
"score".to_string(),
serde_json::json!((s * 1000.0).round() / 1000.0),
);
}
v
})
.collect();
let mut body = serde_json::json!({
"memories": scored,
"count": results.len(),
"mode": mode,
"tokens_used": tokens_used,
});
if let Some(b) = args.budget_tokens {
body["budget_tokens"] = serde_json::json!(b);
}
println!("{}", serde_json::to_string(&body)?);
return Ok(());
}
if results.is_empty() {
eprintln!("no memories found for: {}", args.context);
return Ok(());
}
for (mem, score) in &results {
let age = human_age(&mem.updated_at);
// `config` is the optional confidence suffix (shown only below 100%); inside
// this loop it shadows the `config` module name.
let config = if mem.confidence < 1.0 {
format!(" conf={:.0}%", mem.confidence * 100.0)
} else {
String::new()
};
println!(
"[{}] {} {} score={:.2} (ns={}, {}x, {}{})",
color::tier_color(
mem.tier.as_str(),
&format!("{}/{}", mem.tier, id_short(&mem.id))
),
color::bold(&mem.title),
color::priority_bar(mem.priority),
score,
color::cyan(&mem.namespace),
mem.access_count,
color::dim(&age),
config
);
let preview: String = mem.content.chars().take(200).collect();
println!(" {}\n", color::dim(&preview));
}
println!("{} memory(ies) recalled [{}]", results.len(), mode);
Ok(())
}
fn cmd_search(
db_path: &Path,
args: &SearchArgs,
json_out: bool,
app_config: &config::AppConfig,
) -> Result<()> {
if let Some(ref aid) = args.agent_id {
validate::validate_agent_id(aid)?;
}
if let Some(ref a) = args.as_agent {
validate::validate_namespace(a)?;
}
let conn = db::open(db_path)?;
let _ = db::gc_if_needed(&conn, app_config.effective_archive_on_gc());
let tier = args.tier.as_deref().and_then(Tier::from_str);
let results = db::search(
&conn,
&args.query,
args.namespace.as_deref(),
tier.as_ref(),
args.limit,
None,
args.since.as_deref(),
args.until.as_deref(),
args.tags.as_deref(),
args.agent_id.as_deref(),
args.as_agent.as_deref(),
)?;
if json_out {
println!(
"{}",
serde_json::to_string(
&serde_json::json!({"results": results, "count": results.len()})
)?
);
return Ok(());
}
if results.is_empty() {
eprintln!("no results for: {}", args.query);
return Ok(());
}
for mem in &results {
let age = human_age(&mem.updated_at);
println!(
"[{}/{}] {} (p={}, ns={}, {})",
mem.tier,
id_short(&mem.id),
mem.title,
mem.priority,
mem.namespace,
age
);
}
println!("\n{} result(s)", results.len());
Ok(())
}
/// Shows one memory (addressed by full id or prefix) together with its links.
///
/// JSON mode prints `{"memory", "links"}`. Text mode pretty-prints the memory
/// and then lists each link as `source --[relation]--> target`. A failure to
/// load links is treated as "no links". Exits with status 1 when nothing matches.
fn cmd_get(db_path: &Path, args: &GetArgs, json_out: bool) -> Result<()> {
    validate::validate_id(&args.id)?;
    let conn = db::open(db_path)?;
    let Some(mem) = db::resolve_id(&conn, &args.id)? else {
        eprintln!("not found: {}", args.id);
        std::process::exit(1);
    };
    let links = db::get_links(&conn, &mem.id).unwrap_or_default();
    if json_out {
        let body = serde_json::json!({"memory": mem, "links": links});
        println!("{}", serde_json::to_string(&body)?);
        return Ok(());
    }
    println!("{}", serde_json::to_string_pretty(&mem)?);
    if links.is_empty() {
        return Ok(());
    }
    println!("\nlinks:");
    for link in &links {
        println!(" {} --[{}]--> {}", link.source_id, link.relation, link.target_id);
    }
    Ok(())
}
fn cmd_list(
db_path: &Path,
args: &ListArgs,
json_out: bool,
app_config: &config::AppConfig,
) -> Result<()> {
if let Some(ref aid) = args.agent_id {
validate::validate_agent_id(aid)?;
}
let conn = db::open(db_path)?;
let _ = db::gc_if_needed(&conn, app_config.effective_archive_on_gc());
let tier = args.tier.as_deref().and_then(Tier::from_str);
let results = db::list(
&conn,
args.namespace.as_deref(),
tier.as_ref(),
args.limit,
args.offset,
None,
args.since.as_deref(),
args.until.as_deref(),
args.tags.as_deref(),
args.agent_id.as_deref(),
)?;
if json_out {
println!(
"{}",
serde_json::to_string(
&serde_json::json!({"memories": results, "count": results.len()})
)?
);
return Ok(());
}
if results.is_empty() {
eprintln!("no memories stored");
return Ok(());
}
for mem in &results {
let age = human_age(&mem.updated_at);
println!(
"[{}/{}] {} (p={}, ns={}, {})",
mem.tier,
id_short(&mem.id),
mem.title,
mem.priority,
mem.namespace,
age
);
}
println!("\n{} memory(ies)", results.len());
Ok(())
}
/// Deletes one memory, addressed by full id or id prefix (via `db::resolve_id`).
///
/// Governance is consulted first, with the caller's agent id and the memory's
/// recorded `agent_id` as owner:
/// - `Deny` exits with status 1.
/// - `Pending` prints the pending approval id and returns without deleting.
///
/// Also exits with status 1 when the memory does not exist.
fn cmd_delete(
    db_path: &Path,
    args: &DeleteArgs,
    json_out: bool,
    cli_agent_id: Option<&str>,
) -> Result<()> {
    use models::{GovernanceDecision, GovernedAction};

    validate::validate_id(&args.id)?;
    let conn = db::open(db_path)?;
    let Some(target) = db::resolve_id(&conn, &args.id)? else {
        eprintln!("not found: {}", args.id);
        std::process::exit(1);
    };

    let caller_agent_id = identity::resolve_agent_id(cli_agent_id, None)?;
    let owner: Option<String> = target
        .metadata
        .get("agent_id")
        .and_then(serde_json::Value::as_str)
        .map(ToString::to_string);
    let payload = serde_json::json!({"id": target.id, "title": target.title});
    let decision = db::enforce_governance(
        &conn,
        GovernedAction::Delete,
        &target.namespace,
        &caller_agent_id,
        Some(&target.id),
        owner.as_deref(),
        &payload,
    )?;
    match decision {
        GovernanceDecision::Allow => {}
        GovernanceDecision::Deny(reason) => {
            eprintln!("delete denied by governance: {reason}");
            std::process::exit(1);
        }
        GovernanceDecision::Pending(pending_id) => {
            if json_out {
                let body = serde_json::json!({
                    "status": "pending",
                    "pending_id": pending_id,
                    "reason": "governance requires approval",
                    "action": "delete",
                    "memory_id": target.id,
                });
                println!("{body}");
            } else {
                println!(
                    "delete queued for approval: pending_id={pending_id} id={}",
                    target.id
                );
            }
            return Ok(());
        }
    }

    let deleted = db::delete(&conn, &target.id)?;
    if !deleted {
        eprintln!("not found: {}", args.id);
        std::process::exit(1);
    }
    if json_out {
        println!("{}", serde_json::json!({"deleted": true, "id": target.id}));
    } else {
        println!("deleted: {}", target.id);
    }
    Ok(())
}
/// Promotes a memory in one of two ways:
/// - in place, to the long tier, clearing its expiry (`expires_at` set to `""`);
/// - with `--to-namespace`, "vertically", by cloning it into another namespace
///   via `db::promote_to_namespace`.
///
/// The memory is looked up by exact id, then by id prefix. Governance is
/// evaluated against the memory's current namespace before anything changes;
/// `Pending` prints the pending id and returns. Exits with status 1 when the
/// memory is missing or governance denies the promotion.
#[allow(clippy::too_many_lines)]
fn cmd_promote(
    db_path: &Path,
    args: &PromoteArgs,
    json_out: bool,
    cli_agent_id: Option<&str>,
) -> Result<()> {
    use models::{GovernanceDecision, GovernedAction};

    validate::validate_id(&args.id)?;
    if let Some(to_ns) = args.to_namespace.as_ref() {
        validate::validate_namespace(to_ns)?;
    }
    let conn = db::open(db_path)?;
    let lookup = match db::get(&conn, &args.id)? {
        Some(exact) => Some(exact),
        None => db::get_by_prefix(&conn, &args.id)?,
    };
    let Some(target) = lookup else {
        eprintln!("not found: {}", args.id);
        std::process::exit(1);
    };
    let resolved_id = target.id.clone();

    let caller_agent_id = identity::resolve_agent_id(cli_agent_id, None)?;
    let owner: Option<String> = target
        .metadata
        .get("agent_id")
        .and_then(serde_json::Value::as_str)
        .map(ToString::to_string);
    let payload = serde_json::json!({
        "id": resolved_id,
        "to_namespace": args.to_namespace,
    });
    let decision = db::enforce_governance(
        &conn,
        GovernedAction::Promote,
        &target.namespace,
        &caller_agent_id,
        Some(&resolved_id),
        owner.as_deref(),
        &payload,
    )?;
    match decision {
        GovernanceDecision::Allow => {}
        GovernanceDecision::Deny(reason) => {
            eprintln!("promote denied by governance: {reason}");
            std::process::exit(1);
        }
        GovernanceDecision::Pending(pending_id) => {
            if json_out {
                let body = serde_json::json!({
                    "status": "pending",
                    "pending_id": pending_id,
                    "reason": "governance requires approval",
                    "action": "promote",
                    "memory_id": resolved_id,
                });
                println!("{body}");
            } else {
                println!(
                    "promote queued for approval: pending_id={pending_id} id={resolved_id}"
                );
            }
            return Ok(());
        }
    }

    // Vertical promotion: clone into the target namespace and stop there.
    if let Some(to_ns) = args.to_namespace.as_ref() {
        let clone_id = db::promote_to_namespace(&conn, &resolved_id, to_ns)?;
        if json_out {
            let body = serde_json::json!({
                "promoted": true,
                "mode": "vertical",
                "source_id": resolved_id,
                "clone_id": clone_id,
                "to_namespace": to_ns,
            });
            println!("{}", serde_json::to_string(&body)?);
        } else {
            println!(
                "promoted (vertical): {} → {} (clone: {})",
                id_short(&resolved_id),
                to_ns,
                id_short(&clone_id),
            );
        }
        return Ok(());
    }

    // In-place promotion: tier -> Long, expires_at -> "" (positional args as in cmd_update).
    let (found, _) = db::update(
        &conn,
        &resolved_id,
        None,
        None,
        Some(&Tier::Long),
        None,
        None,
        None,
        None,
        Some(""),
        None,
    )?;
    if !found {
        eprintln!("not found: {}", args.id);
        std::process::exit(1);
    }
    if json_out {
        let body = serde_json::json!({"promoted": true, "id": resolved_id, "tier": "long"});
        println!("{body}");
    } else {
        println!("promoted to long-term: {resolved_id}");
    }
    Ok(())
}
fn cmd_forget(db_path: &Path, args: &ForgetArgs, json_out: bool) -> Result<()> {
let tier = args.tier.as_deref().and_then(Tier::from_str);
let conn = db::open(db_path)?;
match db::forget(
&conn,
args.namespace.as_deref(),
args.pattern.as_deref(),
tier.as_ref(),
true, ) {
Ok(n) => {
if json_out {
println!("{}", serde_json::json!({"deleted": n}));
} else {
println!("forgot {n} memories");
}
}
Err(e) => {
eprintln!("error: {e}");
std::process::exit(1);
}
}
Ok(())
}
/// Creates a `source --[relation]--> target` link between two memories.
///
/// The link is validated before the database is opened.
fn cmd_link(db_path: &Path, args: &LinkArgs, json_out: bool) -> Result<()> {
    let (source, target, relation) = (&args.source_id, &args.target_id, &args.relation);
    validate::validate_link(source, target, relation)?;
    let conn = db::open(db_path)?;
    db::create_link(&conn, source, target, relation)?;
    if !json_out {
        println!("linked: {source} --[{relation}]--> {target}");
        return Ok(());
    }
    println!("{}", serde_json::json!({"linked": true}));
    Ok(())
}
/// Consolidates several memories into one new long-tier memory.
///
/// `args.ids` is a comma-separated id list; surrounding whitespace is trimmed
/// and blank entries are dropped. The result goes to `args.namespace`, or to the
/// auto-detected namespace when absent. It is created with the literal "cli"
/// argument and the resolved consolidating agent id.
fn cmd_consolidate(
    db_path: &Path,
    args: ConsolidateArgs,
    json_out: bool,
    cli_agent_id: Option<&str>,
) -> Result<()> {
    let mut ids: Vec<String> = Vec::new();
    for raw in args.ids.split(',') {
        let trimmed = raw.trim();
        if !trimmed.is_empty() {
            ids.push(trimmed.to_string());
        }
    }
    let namespace = args.namespace.unwrap_or_else(auto_namespace);
    validate::validate_consolidate(&ids, &args.title, &args.summary, &namespace)?;
    let conn = db::open(db_path)?;
    let consolidator_agent_id = identity::resolve_agent_id(cli_agent_id, None)?;
    let new_id = db::consolidate(
        &conn,
        &ids,
        &args.title,
        &args.summary,
        &namespace,
        &Tier::Long,
        "cli",
        &consolidator_agent_id,
    )?;
    let merged = ids.len();
    if json_out {
        println!(
            "{}",
            serde_json::json!({"id": new_id, "consolidated": merged})
        );
    } else {
        println!("consolidated {merged} memories into: {new_id}");
    }
    Ok(())
}
/// Runs garbage collection of expired memories immediately and reports the count.
///
/// Whether expired memories are archived is taken from the app config
/// (`effective_archive_on_gc`).
fn cmd_gc(db_path: &Path, json_out: bool, app_config: &config::AppConfig) -> Result<()> {
    let conn = db::open(db_path)?;
    let expired = db::gc(&conn, app_config.effective_archive_on_gc())?;
    if json_out {
        println!("{}", serde_json::json!({"expired_deleted": expired}));
        return Ok(());
    }
    println!("expired memories deleted: {expired}");
    Ok(())
}
/// Prints database statistics: totals, memories expiring within an hour, link
/// count, on-disk size, and per-tier / per-namespace counts.
///
/// JSON mode serialises the `db::stats` result as-is.
fn cmd_stats(db_path: &Path, json_out: bool) -> Result<()> {
    let conn = db::open(db_path)?;
    let stats = db::stats(&conn, db_path)?;
    if !json_out {
        println!("total memories: {}", stats.total);
        println!("expiring within 1h: {}", stats.expiring_soon);
        println!("links: {}", stats.links_count);
        println!("database size: {} bytes", stats.db_size_bytes);
        println!("\nby tier:");
        for tier_row in &stats.by_tier {
            println!(" {}: {}", tier_row.tier, tier_row.count);
        }
        println!("\nby namespace:");
        for ns_row in &stats.by_namespace {
            println!(" {}: {}", ns_row.namespace, ns_row.count);
        }
        return Ok(());
    }
    println!("{}", serde_json::to_string(&stats)?);
    Ok(())
}
/// Lists every namespace with its memory count.
///
/// JSON mode prints `{"namespaces": [...]}`; text mode prints one line per
/// namespace, or "no namespaces" on stderr when there are none.
fn cmd_namespaces(db_path: &Path, json_out: bool) -> Result<()> {
    let conn = db::open(db_path)?;
    let namespaces = db::list_namespaces(&conn)?;
    if json_out {
        let body = serde_json::json!({"namespaces": namespaces});
        println!("{}", serde_json::to_string(&body)?);
        return Ok(());
    }
    if namespaces.is_empty() {
        eprintln!("no namespaces");
        return Ok(());
    }
    for entry in &namespaces {
        println!(" {}: {} memories", entry.namespace, entry.count);
    }
    Ok(())
}
fn cmd_export(db_path: &Path) -> Result<()> {
let conn = db::open(db_path)?;
let memories = db::export_all(&conn)?;
let links = db::export_links(&conn)?;
println!(
"{}",
serde_json::to_string_pretty(&serde_json::json!({
"memories": memories, "links": links, "count": memories.len(),
"exported_at": Utc::now().to_rfc3339(),
}))?
);
Ok(())
}
fn cmd_import(
db_path: &Path,
args: &ImportArgs,
json_out: bool,
cli_agent_id: Option<&str>,
) -> Result<()> {
use std::io::Read;
let mut buf = String::new();
std::io::stdin().read_to_string(&mut buf)?;
let data: serde_json::Value = serde_json::from_str(&buf)?;
let memories: Vec<models::Memory> =
serde_json::from_value(data.get("memories").cloned().unwrap_or_default())?;
let links: Vec<models::MemoryLink> =
serde_json::from_value(data.get("links").cloned().unwrap_or_default()).unwrap_or_default();
let caller_id = identity::resolve_agent_id(cli_agent_id, None)?;
let conn = db::open(db_path)?;
let mut imported = 0usize;
let mut restamped = 0usize;
let mut errors = Vec::new();
for mut mem in memories {
if !args.trust_source {
let original = mem
.metadata
.get("agent_id")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
if let Some(obj) = mem.metadata.as_object_mut() {
obj.insert(
"agent_id".to_string(),
serde_json::Value::String(caller_id.clone()),
);
if let Some(orig) = original.as_ref()
&& orig.as_str() != caller_id
{
obj.insert(
"imported_from_agent_id".to_string(),
serde_json::Value::String(orig.clone()),
);
restamped += 1;
}
}
}
if let Err(e) = validate::validate_memory(&mem) {
errors.push(format!("{}: {}", mem.id, e));
continue;
}
match db::insert(&conn, &mem) {
Ok(_) => imported += 1,
Err(e) => errors.push(format!("{}: {}", mem.id, e)),
}
}
for link in links {
if validate::validate_link(&link.source_id, &link.target_id, &link.relation).is_err() {
continue;
}
let _ = db::create_link(&conn, &link.source_id, &link.target_id, &link.relation);
}
if json_out {
println!(
"{}",
serde_json::json!({
"imported": imported,
"restamped": restamped,
"trusted_source": args.trust_source,
"errors": errors
})
);
} else {
println!("imported: {imported} (restamped agent_id on {restamped})");
if args.trust_source {
eprintln!("warning: --trust-source: agent_id from imported JSON was preserved as-is");
}
if !errors.is_empty() {
for e in &errors {
eprintln!(" {e}");
}
}
}
Ok(())
}
/// Resolves a contradiction by declaring `winner_id` to supersede `loser_id`.
///
/// Steps:
/// 1. Record a `winner --[supersedes]--> loser` link.
/// 2. Demote the loser to priority 1 and confidence 0.1.
/// 3. Call `db::touch` on the winner with the built-in default TTL-extension
///    constants (not the app config).
///
/// # Errors
/// Returns an error if the link is invalid (checked before the database is
/// opened) or any database operation fails.
fn cmd_resolve(db_path: &Path, args: &ResolveArgs, json_out: bool) -> Result<()> {
    // Validate before touching the database, as `cmd_link` does, so bad input
    // never opens (and possibly creates) the database file.
    validate::validate_link(&args.winner_id, &args.loser_id, "supersedes")?;
    let conn = db::open(db_path)?;
    db::create_link(&conn, &args.winner_id, &args.loser_id, "supersedes")?;
    // Positional fields as in cmd_update: priority = Some(1), confidence = Some(0.1).
    let _ = db::update(
        &conn,
        &args.loser_id,
        None,
        None,
        None,
        None,
        None,
        Some(1),
        Some(0.1),
        None,
        None,
    )?;
    db::touch(
        &conn,
        &args.winner_id,
        models::SHORT_TTL_EXTEND_SECS,
        models::MID_TTL_EXTEND_SECS,
    )?;
    if json_out {
        println!(
            "{}",
            serde_json::json!({"resolved": true, "winner": args.winner_id, "loser": args.loser_id})
        );
    } else {
        println!(
            "resolved: {} supersedes {}",
            color::long(&args.winner_id),
            color::dim(&args.loser_id)
        );
    }
    Ok(())
}
#[allow(clippy::too_many_lines)]
fn cmd_shell(db_path: &Path) -> Result<()> {
let conn = db::open(db_path)?;
println!(
"{}",
color::bold("ai-memory shell — type 'help' for commands, 'quit' to exit")
);
let stdin = std::io::stdin();
loop {
eprint!("{} ", color::cyan("memory>"));
let mut line = String::new();
if stdin.read_line(&mut line)? == 0 {
break;
}
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.is_empty() {
continue;
}
match parts[0] {
"quit" | "exit" | "q" => break,
"help" | "h" => {
println!(" recall <context> — fuzzy recall");
println!(" search <query> — keyword search");
println!(" list [namespace] — list memories");
println!(" get <id> — show memory details");
println!(" stats — show statistics");
println!(" namespaces — list namespaces");
println!(" delete <id> — delete a memory");
println!(" quit — exit shell");
}
"recall" | "r" => {
let ctx = parts[1..].join(" ");
if ctx.is_empty() {
eprintln!("usage: recall <context>");
continue;
}
match db::recall(
&conn,
&ctx,
None,
10,
None,
None,
None,
models::SHORT_TTL_EXTEND_SECS,
models::MID_TTL_EXTEND_SECS,
None,
None,
) {
Ok((results, _tokens_used)) => {
for (mem, score) in &results {
println!(
" [{}] {} {} score={:.2}",
color::tier_color(mem.tier.as_str(), mem.tier.as_str()),
color::bold(&mem.title),
color::priority_bar(mem.priority),
score
);
let preview: String = mem.content.chars().take(100).collect();
println!(" {}", color::dim(&preview));
}
println!(" {} result(s)", results.len());
}
Err(e) => eprintln!("error: {e}"),
}
}
"search" | "s" => {
let q = parts[1..].join(" ");
if q.is_empty() {
eprintln!("usage: search <query>");
continue;
}
match db::search(
&conn, &q, None, None, 20, None, None, None, None, None, None,
) {
Ok(results) => {
for mem in &results {
println!(
" [{}] {} (p={})",
color::tier_color(mem.tier.as_str(), mem.tier.as_str()),
mem.title,
mem.priority
);
}
println!(" {} result(s)", results.len());
}
Err(e) => eprintln!("error: {e}"),
}
}
"list" | "ls" => {
let ns = parts.get(1).copied();
match db::list(&conn, ns, None, 20, 0, None, None, None, None, None) {
Ok(results) => {
for mem in &results {
let age = human_age(&mem.updated_at);
println!(
" [{}] {} (ns={}, {})",
color::tier_color(mem.tier.as_str(), mem.tier.as_str()),
mem.title,
mem.namespace,
color::dim(&age)
);
}
println!(" {} memory(ies)", results.len());
}
Err(e) => eprintln!("error: {e}"),
}
}
"get" => {
let id = parts.get(1).unwrap_or(&"");
if id.is_empty() {
eprintln!("usage: get <id>");
continue;
}
if let Err(e) = validate::validate_id(id) {
eprintln!("invalid id: {e}");
continue;
}
match db::get(&conn, id) {
Ok(Some(mem)) => {
println!("{}", serde_json::to_string_pretty(&mem).unwrap_or_default());
}
Ok(None) => eprintln!("not found"),
Err(e) => eprintln!("error: {e}"),
}
}
"stats" => match db::stats(&conn, db_path) {
Ok(s) => {
println!(" total: {}, links: {}", s.total, s.links_count);
for t in &s.by_tier {
println!(" {}: {}", color::tier_color(&t.tier, &t.tier), t.count);
}
}
Err(e) => eprintln!("error: {e}"),
},
"namespaces" | "ns" => match db::list_namespaces(&conn) {
Ok(ns) => {
for n in &ns {
println!(" {}: {}", color::cyan(&n.namespace), n.count);
}
}
Err(e) => eprintln!("error: {e}"),
},
"delete" | "del" | "rm" => {
let id = parts.get(1).unwrap_or(&"");
if id.is_empty() {
eprintln!("usage: delete <id>");
continue;
}
if let Err(e) = validate::validate_id(id) {
eprintln!("invalid id: {e}");
continue;
}
match db::delete(&conn, id) {
Ok(true) => println!(" deleted"),
Ok(false) => eprintln!(" not found"),
Err(e) => eprintln!("error: {e}"),
}
}
_ => eprintln!("unknown command: {}. Type 'help' for commands.", parts[0]),
}
}
println!("goodbye");
Ok(())
}
fn restamp_agent_id(mem: &mut models::Memory, caller_id: &str) {
let original = mem
.metadata
.get("agent_id")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
if let Some(obj) = mem.metadata.as_object_mut() {
obj.insert(
"agent_id".to_string(),
serde_json::Value::String(caller_id.to_string()),
);
if let Some(orig) = original
&& orig != caller_id
{
obj.insert(
"imported_from_agent_id".to_string(),
serde_json::Value::String(orig),
);
}
}
}
/// Counters produced by a `sync --dry-run` preview.
///
/// - `would_pull_*` classify remote memories against the local database.
/// - `would_push_*` classify local memories against the remote database.
/// - `would_*_links` is the number of links that would be offered to the other side.
// Every field intentionally shares the `would_` prefix, which triggers
// `clippy::struct_field_names`. (A stray `too_many_lines` allow, a lint for
// function bodies, used to be attached here and has been removed.)
#[allow(clippy::struct_field_names)]
#[derive(Default)]
struct SyncPreview {
    would_pull_new: usize,
    would_pull_update: usize,
    would_pull_noop: usize,
    would_push_new: usize,
    would_push_update: usize,
    would_push_noop: usize,
    would_pull_links: usize,
    would_push_links: usize,
}
impl SyncPreview {
    /// Decides what copying `remote` onto a side that holds `local` would do.
    ///
    /// Returns:
    /// - `New` when there is no local copy;
    /// - `Update` when the remote `updated_at` is strictly greater;
    /// - `Noop` otherwise, including ties.
    ///
    /// `updated_at` is a `String` (RFC 3339 text on the store path), so the
    /// comparison is lexicographic. NOTE(review): that is only chronological
    /// when both sides use the same offset and fractional-second format.
    fn classify(local: Option<&models::Memory>, remote: &models::Memory) -> MergeOutcome {
        match local {
            None => MergeOutcome::New,
            Some(existing) if remote.updated_at > existing.updated_at => MergeOutcome::Update,
            Some(_) => MergeOutcome::Noop,
        }
    }
}
/// What syncing a single memory onto the other side would do (dry-run preview).
enum MergeOutcome {
    /// The other side has no memory with this id.
    New,
    /// The incoming copy is newer than the existing one.
    Update,
    /// The existing copy is as new or newer; nothing would change.
    Noop,
}
#[allow(clippy::too_many_lines)] fn cmd_sync(
db_path: &Path,
args: &SyncArgs,
json_out: bool,
cli_agent_id: Option<&str>,
) -> Result<()> {
let local_conn = db::open(db_path)?;
let remote_conn = db::open(&args.remote_db)?;
let caller_id = identity::resolve_agent_id(cli_agent_id, None)?;
if args.dry_run {
return cmd_sync_dry_run(&local_conn, &remote_conn, &args.direction, json_out);
}
match args.direction.as_str() {
"pull" => {
let mems = db::export_all(&remote_conn)?;
let links = db::export_links(&remote_conn)?;
let mut n = 0;
for mem in &mems {
let mut owned = mem.clone();
if !args.trust_source {
restamp_agent_id(&mut owned, &caller_id);
}
if let Err(e) = validate::validate_memory(&owned) {
tracing::warn!("sync: skipping invalid memory {}: {}", owned.id, e);
continue;
}
if db::insert(&local_conn, &owned).is_ok() {
n += 1;
}
}
for link in &links {
if validate::validate_link(&link.source_id, &link.target_id, &link.relation)
.is_err()
{
continue;
}
let _ = db::create_link(
&local_conn,
&link.source_id,
&link.target_id,
&link.relation,
);
}
if json_out {
println!(
"{}",
serde_json::json!({"direction": "pull", "imported": n})
);
} else {
println!("pulled {n} memories from remote");
}
}
"push" => {
let mems = db::export_all(&local_conn)?;
let links = db::export_links(&local_conn)?;
let mut n = 0;
for mem in &mems {
if let Err(e) = validate::validate_memory(mem) {
tracing::warn!("sync: skipping invalid memory {}: {}", mem.id, e);
continue;
}
if db::insert(&remote_conn, mem).is_ok() {
n += 1;
}
}
for link in &links {
if validate::validate_link(&link.source_id, &link.target_id, &link.relation)
.is_err()
{
continue;
}
let _ = db::create_link(
&remote_conn,
&link.source_id,
&link.target_id,
&link.relation,
);
}
if json_out {
println!(
"{}",
serde_json::json!({"direction": "push", "exported": n})
);
} else {
println!("pushed {n} memories to remote");
}
}
"merge" => {
let r_mems = db::export_all(&remote_conn)?;
let r_links = db::export_links(&remote_conn)?;
let l_mems = db::export_all(&local_conn)?;
let l_links = db::export_links(&local_conn)?;
let (mut pulled, mut pushed) = (0, 0);
for mem in &r_mems {
let mut owned = mem.clone();
if !args.trust_source {
restamp_agent_id(&mut owned, &caller_id);
}
if validate::validate_memory(&owned).is_err() {
continue;
}
if db::insert_if_newer(&local_conn, &owned).is_ok() {
pulled += 1;
}
}
for link in &r_links {
if validate::validate_link(&link.source_id, &link.target_id, &link.relation)
.is_err()
{
continue;
}
let _ = db::create_link(
&local_conn,
&link.source_id,
&link.target_id,
&link.relation,
);
}
for mem in &l_mems {
if validate::validate_memory(mem).is_err() {
continue;
}
if db::insert_if_newer(&remote_conn, mem).is_ok() {
pushed += 1;
}
}
for link in &l_links {
if validate::validate_link(&link.source_id, &link.target_id, &link.relation)
.is_err()
{
continue;
}
let _ = db::create_link(
&remote_conn,
&link.source_id,
&link.target_id,
&link.relation,
);
}
if json_out {
println!(
"{}",
serde_json::json!({"direction": "merge", "pulled": pulled, "pushed": pushed})
);
} else {
println!("merged: pulled {pulled}, pushed {pushed}");
}
}
_ => anyhow::bail!(
"invalid direction: {} (use pull, push, merge)",
args.direction
),
}
Ok(())
}
/// Prints what `sync` would do in `direction` without writing anything.
///
/// Memories are matched by id and classified with `SyncPreview::classify`.
/// Links are only counted, not compared. Any direction other than `"push"`
/// previews the pull side, and any direction other than `"pull"` previews the
/// push side. The JSON output always contains both sections; the text output
/// shows only the previewed ones.
fn cmd_sync_dry_run(
    local_conn: &rusqlite::Connection,
    remote_conn: &rusqlite::Connection,
    direction: &str,
    json_out: bool,
) -> Result<()> {
    use std::collections::HashMap;

    let l_mems = db::export_all(local_conn)?;
    let r_mems = db::export_all(remote_conn)?;
    let l_links = db::export_links(local_conn)?;
    let r_links = db::export_links(remote_conn)?;

    let local_index: HashMap<&str, &models::Memory> =
        l_mems.iter().map(|mem| (mem.id.as_str(), mem)).collect();
    let remote_index: HashMap<&str, &models::Memory> =
        r_mems.iter().map(|mem| (mem.id.as_str(), mem)).collect();

    let show_pull = direction != "push";
    let show_push = direction != "pull";
    let mut preview = SyncPreview::default();

    if show_pull {
        for incoming in &r_mems {
            let existing = local_index.get(incoming.id.as_str()).copied();
            let counter = match SyncPreview::classify(existing, incoming) {
                MergeOutcome::New => &mut preview.would_pull_new,
                MergeOutcome::Update => &mut preview.would_pull_update,
                MergeOutcome::Noop => &mut preview.would_pull_noop,
            };
            *counter += 1;
        }
        preview.would_pull_links = r_links.len();
    }
    if show_push {
        for outgoing in &l_mems {
            let existing = remote_index.get(outgoing.id.as_str()).copied();
            let counter = match SyncPreview::classify(existing, outgoing) {
                MergeOutcome::New => &mut preview.would_push_new,
                MergeOutcome::Update => &mut preview.would_push_update,
                MergeOutcome::Noop => &mut preview.would_push_noop,
            };
            *counter += 1;
        }
        preview.would_push_links = l_links.len();
    }

    if !json_out {
        println!("DRY RUN — no changes written. Direction: {direction}");
        if show_pull {
            println!(
                " pull: {} new, {} update, {} noop, {} links",
                preview.would_pull_new,
                preview.would_pull_update,
                preview.would_pull_noop,
                preview.would_pull_links
            );
        }
        if show_push {
            println!(
                " push: {} new, {} update, {} noop, {} links",
                preview.would_push_new,
                preview.would_push_update,
                preview.would_push_noop,
                preview.would_push_links
            );
        }
        return Ok(());
    }
    let report = serde_json::json!({
        "dry_run": true,
        "direction": direction,
        "pull": {
            "new": preview.would_pull_new,
            "update": preview.would_pull_update,
            "noop": preview.would_pull_noop,
            "links": preview.would_pull_links,
        },
        "push": {
            "new": preview.would_push_new,
            "update": preview.would_push_update,
            "noop": preview.would_push_noop,
            "links": preview.would_push_links,
        }
    });
    println!("{report}");
    Ok(())
}
/// Runs the peer-to-peer sync daemon until Ctrl-C.
///
/// Each cycle runs one `sync_cycle_once` per `--peers` URL concurrently,
/// then sleeps `--interval` seconds (minimum 1). A failing or panicking peer
/// task is logged and retried on the next cycle; it never stops the daemon.
///
/// TLS:
/// - When both `--client-cert` and `--client-key` are given, the client is
///   built from `build_rustls_client_config` (mTLS).
/// - `--insecure-skip-server-verify` is refused unless both are present.
///
/// Shutdown: Ctrl-C is observed both while peer tasks are in flight and
/// during the sleep. Returning drops the `JoinSet`, which aborts in-flight
/// peer tasks at their next `.await` (an HTTP request in
/// `sync_cycle_once`).
///
/// # Errors
/// Fails in these cases:
/// - no peers are configured;
/// - the agent id cannot be resolved;
/// - a tracing directive fails to parse;
/// - the mTLS config or HTTP client cannot be built;
/// - the insecure flag is used without a client certificate.
async fn cmd_sync_daemon(
    db_path: &Path,
    args: SyncDaemonArgs,
    cli_agent_id: Option<&str>,
) -> Result<()> {
    if args.peers.is_empty() {
        anyhow::bail!("at least one --peers URL is required");
    }
    let interval = args.interval.max(1);
    let batch_size = args.batch_size.max(1);
    let local_agent_id = identity::resolve_agent_id(cli_agent_id, None)?;
    // `try_init` result ignored: tolerate an already-installed subscriber.
    let _ = tracing_subscriber::fmt()
        .with_env_filter(
            EnvFilter::from_default_env()
                .add_directive("ai_memory=info".parse()?)
                .add_directive("tower_http=info".parse()?),
        )
        .try_init();
    // Result ignored: a provider may already be installed for this process.
    let _ = rustls::crypto::ring::default_provider().install_default();
    if args.insecure_skip_server_verify && (args.client_cert.is_none() || args.client_key.is_none())
    {
        anyhow::bail!(
            "sync-daemon: --insecure-skip-server-verify requires both --client-cert \
            and --client-key as a compensating mTLS control. Running with neither side \
            of the TLS handshake verified is an open MITM surface and is refused."
        );
    }
    let client = if let (Some(cert_path), Some(key_path)) = (&args.client_cert, &args.client_key) {
        let rustls_config = build_rustls_client_config(cert_path, key_path).await?;
        let mut builder = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .use_preconfigured_tls(rustls_config);
        if args.insecure_skip_server_verify {
            tracing::warn!(
                "sync-daemon: --insecure-skip-server-verify set with --client-cert — \
                peer server certificates will NOT be validated; peer authenticates us \
                via mTLS allowlist (compensating control). Do NOT use in production."
            );
            // NOTE(review): with a preconfigured rustls config, reqwest may
            // ignore this flag (verification lives in the rustls config);
            // confirm against build_rustls_client_config.
            builder = builder.danger_accept_invalid_certs(true);
        }
        builder.build()?
    } else {
        reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()?
    };
    tracing::info!(
        "sync-daemon: local_agent_id={local_agent_id} peers={peers:?} interval={interval}s",
        peers = args.peers
    );
    let mut shutdown = Box::pin(tokio::signal::ctrl_c());
    // Arc-wrapped copies so each spawned ('static) peer task owns cheap clones.
    let db_path_owned: Arc<Path> = Arc::from(db_path);
    let local_agent_id_arc: Arc<str> = Arc::from(local_agent_id.as_str());
    let api_key_arc: Option<Arc<str>> = args.api_key.as_deref().map(Arc::from);
    let peers_arc: Vec<Arc<str>> = args.peers.iter().map(|s| Arc::from(s.as_str())).collect();
    loop {
        let mut set: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
        for peer_url in &peers_arc {
            let client = client.clone();
            let db_path = db_path_owned.clone();
            let local_agent_id = local_agent_id_arc.clone();
            let peer_url = peer_url.clone();
            let api_key = api_key_arc.clone();
            set.spawn(async move {
                if let Err(e) = sync_cycle_once(
                    &client,
                    &db_path,
                    &local_agent_id,
                    &peer_url,
                    api_key.as_deref(),
                    batch_size,
                )
                .await
                {
                    tracing::warn!("sync-daemon: peer {peer_url} cycle failed: {e}");
                }
            });
        }
        // Wait for every peer task, but stay responsive to Ctrl-C. This also
        // polls (and thereby arms) the signal listener during the first cycle.
        let drain = async {
            while let Some(joined) = set.join_next().await {
                if let Err(e) = joined {
                    tracing::warn!("sync-daemon: peer task did not complete: {e}");
                }
            }
        };
        tokio::select! {
            () = drain => {}
            _ = &mut shutdown => {
                tracing::info!("sync-daemon: shutdown signal received");
                return Ok(());
            }
        }
        tokio::select! {
            () = tokio::time::sleep(std::time::Duration::from_secs(interval)) => {}
            _ = &mut shutdown => {
                tracing::info!("sync-daemon: shutdown signal received");
                return Ok(());
            }
        }
    }
}
/// Runs one pull-then-push exchange with a single peer.
///
/// 1. Pull: `GET {peer}/api/v1/sync/since?limit=..&peer=..[&since=..]` using
///    the cursor stored for this peer. Each returned memory that passes
///    `validate::validate_memory` is applied with `db::insert_if_newer`.
///    The pull cursor then advances to the greatest `updated_at` received.
/// 2. Push: load up to `batch_size` local memories updated since the last
///    recorded push and `POST` them to `{peer}/api/v1/sync/push`. The push
///    cursor is recorded only after the peer answers with a success status.
///
/// At most one batch moves in each direction per call. A trailing `/` on
/// `peer_url` is ignored.
///
/// Memories that fail validation or whose insert errors are skipped, and
/// the pull cursor still moves past them. They are counted and logged
/// rather than failing the cycle.
///
/// # Errors
/// Network or JSON decoding failures, non-2xx responses, and database
/// errors from opening the DB or reading/writing sync state.
async fn sync_cycle_once(
    client: &reqwest::Client,
    db_path: &Path,
    local_agent_id: &str,
    peer_url: &str,
    api_key: Option<&str>,
    batch_size: usize,
) -> Result<()> {
    let peer_url = peer_url.trim_end_matches('/');
    let since = {
        let conn = db::open(db_path)?;
        db::sync_state_load(&conn, local_agent_id)?
            .entries
            .get(peer_url)
            .cloned()
    };
    let mut pull_url = format!(
        "{peer_url}/api/v1/sync/since?limit={batch_size}&peer={}",
        urlencoding_minimal(local_agent_id)
    );
    if let Some(ref s) = since {
        pull_url.push_str("&since=");
        pull_url.push_str(&urlencoding_minimal(s));
    }
    let mut req = client.get(&pull_url).header("x-agent-id", local_agent_id);
    if let Some(key) = api_key {
        req = req.header("x-api-key", key);
    }
    let resp = req.send().await?;
    if !resp.status().is_success() {
        anyhow::bail!("sync-daemon: pull status {}", resp.status());
    }
    let pulled: SyncSinceResponse = resp.json().await?;
    let pull_count = pulled.memories.len();
    // Take the maximum instead of the last element. That way the cursor can
    // never move backwards if the peer returns rows in another order.
    let latest_pulled = pulled.memories.iter().map(|m| &m.updated_at).max().cloned();
    let mut dropped = 0usize;
    {
        let conn = db::open(db_path)?;
        for mem in &pulled.memories {
            let stored =
                validate::validate_memory(mem).is_ok() && db::insert_if_newer(&conn, mem).is_ok();
            if !stored {
                dropped += 1;
            }
        }
        if let Some(ref at) = latest_pulled {
            db::sync_state_observe(&conn, local_agent_id, peer_url, at)?;
        }
    }
    if dropped > 0 {
        tracing::warn!(
            "sync-daemon: peer={peer_url} {dropped} of {pull_count} pulled memories were \
            rejected (invalid or insert error) and skipped"
        );
    }
    let outgoing = {
        let conn = db::open(db_path)?;
        let last_pushed = db::sync_state_last_pushed(&conn, local_agent_id, peer_url);
        db::memories_updated_since(&conn, last_pushed.as_deref(), batch_size)?
    };
    let push_count = outgoing.len();
    let latest_pushed = outgoing.iter().map(|m| &m.updated_at).max().cloned();
    if !outgoing.is_empty() {
        let body = serde_json::json!({
            "sender_agent_id": local_agent_id,
            "sender_clock": { "entries": {} },
            "memories": outgoing,
            "dry_run": false,
        });
        let mut req = client
            .post(format!("{peer_url}/api/v1/sync/push"))
            .header("x-agent-id", local_agent_id)
            .header("content-type", "application/json")
            .json(&body);
        if let Some(key) = api_key {
            req = req.header("x-api-key", key);
        }
        let resp = req.send().await?;
        if !resp.status().is_success() {
            anyhow::bail!("sync-daemon: push status {}", resp.status());
        }
        if let Some(at) = latest_pushed {
            let conn = db::open(db_path)?;
            db::sync_state_record_push(&conn, local_agent_id, peer_url, &at)?;
        }
    }
    tracing::info!("sync-daemon: peer={peer_url} pulled={pull_count} pushed={push_count}");
    Ok(())
}
/// Percent-encodes `s` for use inside a URL query component.
///
/// Only RFC 3986 "unreserved" characters (`A-Z`, `a-z`, `0-9`, `-`, `_`,
/// `.`, `~`) are kept as-is. Every other byte becomes `%XX` with uppercase
/// hex digits. This includes each byte of a multi-byte UTF-8 sequence.
fn urlencoding_minimal(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    s.bytes()
        .fold(String::with_capacity(s.len()), |mut encoded, byte| {
            let unreserved =
                byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
            if unreserved {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
            encoded
        })
}
/// JSON body returned by a peer's `GET /api/v1/sync/since` endpoint, as
/// decoded by `sync_cycle_once`.
#[derive(serde::Deserialize)]
struct SyncSinceResponse {
    // `count` and `limit` are part of the payload but not read by this binary.
    #[allow(dead_code)]
    count: usize,
    #[allow(dead_code)]
    limit: usize,
    /// The batch of memories sent by the peer.
    memories: Vec<models::Memory>,
}
/// Merges tag groups of memories into single long-term memories, per namespace.
///
/// Steps for each namespace (only `--namespace` if given, otherwise every
/// namespace from `db::list_namespaces`):
/// - Fetch memories with `db::list(.., 200, 0, ..)`, presumably limit 200 /
///   offset 0. `--short-only` restricts this to the short tier.
/// - Skip the namespace if it has fewer than `--min-count` memories.
/// - Group memories by tag; untagged ones go under the pseudo-tag
///   `_untagged`.
/// - Visit tags in sorted order, so results are deterministic. Each memory
///   joins only the first qualifying group it appears in.
/// - Merge each group still holding at least `--min-count` unclaimed
///   memories with `db::consolidate`. The new memory is a long-tier memory
///   whose content lists each member's title and first 200 bytes of
///   content.
///
/// With `--dry-run` nothing is written. The groups a real run would create
/// are reported instead, using the same one-group-per-memory rule.
///
/// # Errors
/// Propagates agent-id resolution and database errors.
#[allow(clippy::too_many_lines)]
fn cmd_auto_consolidate(
    db_path: &Path,
    args: &AutoConsolidateArgs,
    json_out: bool,
    cli_agent_id: Option<&str>,
) -> Result<()> {
    /// Longest prefix of `s` that is at most `max_bytes` long and ends on a
    /// UTF-8 character boundary. Plain `&s[..max_bytes]` panics mid-character.
    fn prefix_within(s: &str, max_bytes: usize) -> &str {
        let mut end = s.len().min(max_bytes);
        // Index 0 is always a boundary, so this terminates.
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        &s[..end]
    }

    let conn = db::open(db_path)?;
    let consolidator_agent_id = identity::resolve_agent_id(cli_agent_id, None)?;
    let tier_filter = if args.short_only {
        Some(Tier::Short)
    } else {
        None
    };
    let namespaces = if let Some(ref ns) = args.namespace {
        vec![models::NamespaceCount {
            namespace: ns.clone(),
            count: 0,
        }]
    } else {
        db::list_namespaces(&conn)?
    };
    let mut total = 0;
    let mut groups = Vec::new();
    for ns in &namespaces {
        let memories = db::list(
            &conn,
            Some(&ns.namespace),
            tier_filter.as_ref(),
            200,
            0,
            None,
            None,
            None,
            None,
            None,
        )?;
        if memories.len() < args.min_count {
            continue;
        }
        // A BTreeMap gives a stable tag order. With a HashMap, which group
        // claims a multi-tag memory varied between runs, and so between a
        // dry run and the real run.
        let mut tag_groups: std::collections::BTreeMap<String, Vec<&models::Memory>> =
            std::collections::BTreeMap::new();
        for mem in &memories {
            if mem.tags.is_empty() {
                tag_groups
                    .entry("_untagged".to_string())
                    .or_default()
                    .push(mem);
            } else {
                for tag in &mem.tags {
                    tag_groups.entry(tag.clone()).or_default().push(mem);
                }
            }
        }
        let mut consolidated_ids: std::collections::HashSet<String> =
            std::collections::HashSet::new();
        for (tag, group) in &tag_groups {
            let group: Vec<&&models::Memory> = group
                .iter()
                .filter(|m| !consolidated_ids.contains(&m.id))
                .collect();
            if group.len() < args.min_count {
                continue;
            }
            let ids: Vec<String> = group.iter().map(|m| m.id.clone()).collect();
            if args.dry_run {
                let titles: Vec<&str> = group.iter().map(|m| m.title.as_str()).collect();
                groups.push(serde_json::json!({"namespace": ns.namespace, "tag": tag, "count": group.len(), "titles": titles}));
            } else {
                let title = format!(
                    "Consolidated: {} ({} memories)",
                    if tag == "_untagged" {
                        &ns.namespace
                    } else {
                        tag
                    },
                    group.len()
                );
                let content: String = group
                    .iter()
                    .map(|m| format!("- {}: {}", m.title, prefix_within(&m.content, 200)))
                    .collect::<Vec<_>>()
                    .join("\n");
                db::consolidate(
                    &conn,
                    &ids,
                    &title,
                    &content,
                    &ns.namespace,
                    &Tier::Long,
                    "auto-consolidate",
                    &consolidator_agent_id,
                )?;
                total += group.len();
            }
            // Claim the ids in both modes, so the dry-run preview never
            // lists a memory in more than one group.
            consolidated_ids.extend(ids);
        }
    }
    if json_out {
        if args.dry_run {
            println!("{}", serde_json::json!({"dry_run": true, "groups": groups}));
        } else {
            println!("{}", serde_json::json!({"consolidated": total}));
        }
    } else if args.dry_run {
        println!("dry run — would consolidate:");
        for g in &groups {
            println!(
                " {} [{}]: {} memories",
                g["namespace"], g["tag"], g["count"]
            );
        }
    } else {
        println!("auto-consolidated {total} memories");
    }
    Ok(())
}
/// Handles `ai-memory archive <list|restore|purge|stats>`.
///
/// - `list`: prints archived memories (short id, title, archive time), or
///   the raw items as JSON.
/// - `restore <id>`: moves one memory back out of the archive. In text
///   mode an unknown id is reported on stderr and the process exits with
///   status 1.
/// - `purge`: calls `db::purge_archive` with `older_than_days` and reports
///   how many entries were removed.
/// - `stats`: prints the overall archived total plus per-namespace counts,
///   or the raw stats JSON.
fn cmd_archive(db_path: &Path, args: ArchiveArgs, json_out: bool) -> Result<()> {
    let conn = db::open(db_path)?;
    match args.action {
        ArchiveAction::List {
            namespace,
            limit,
            offset,
        } => {
            let archived = db::list_archived(&conn, namespace.as_deref(), limit, offset)?;
            if json_out {
                println!(
                    "{}",
                    serde_json::json!({"archived": archived, "count": archived.len()})
                );
                return Ok(());
            }
            if archived.is_empty() {
                println!("no archived memories");
                return Ok(());
            }
            for entry in &archived {
                let id = entry["id"].as_str().unwrap_or("");
                let title = entry["title"].as_str().unwrap_or("");
                let archived_at = entry["archived_at"].as_str().unwrap_or("");
                println!("[{}] {} (archived: {})", id_short(id), title, archived_at);
            }
            println!("{} archived memories", archived.len());
        }
        ArchiveAction::Restore { id } => {
            validate::validate_id(&id)?;
            let restored = db::restore_archived(&conn, &id)?;
            match (json_out, restored) {
                (true, _) => {
                    println!("{}", serde_json::json!({"restored": restored, "id": id}));
                }
                (false, true) => println!("restored: {}", id_short(&id)),
                (false, false) => {
                    eprintln!("not found in archive: {id}");
                    std::process::exit(1);
                }
            }
        }
        ArchiveAction::Purge { older_than_days } => {
            let purged = db::purge_archive(&conn, older_than_days)?;
            let line = if json_out {
                serde_json::json!({"purged": purged}).to_string()
            } else {
                format!("purged {purged} archived memories")
            };
            println!("{line}");
        }
        ArchiveAction::Stats => {
            let stats = db::archive_stats(&conn)?;
            if json_out {
                println!("{stats}");
                return Ok(());
            }
            println!("archived: {} total", stats["archived_total"]);
            for ns in stats["by_namespace"].as_array().into_iter().flatten() {
                println!(
                    " {}: {}",
                    ns["namespace"].as_str().unwrap_or(""),
                    ns["count"]
                );
            }
        }
    }
    Ok(())
}
/// Handles `ai-memory agents [list|register]`.
///
/// With no sub-action it behaves like `list`.
/// - `list`: prints every registered agent with its type, timestamps and
///   capabilities.
/// - `register`: validates the agent id, type and the comma-separated
///   capability list (blank entries dropped, whitespace trimmed), then
///   stores the agent with `db::register_agent`.
fn cmd_agents(db_path: &Path, args: AgentsArgs, json_out: bool) -> Result<()> {
    let conn = db::open(db_path)?;
    let action = args.action.unwrap_or(AgentsAction::List);
    match action {
        AgentsAction::List => {
            let agents = db::list_agents(&conn)?;
            if json_out {
                println!(
                    "{}",
                    serde_json::json!({"count": agents.len(), "agents": agents})
                );
                return Ok(());
            }
            if agents.is_empty() {
                println!("no registered agents");
                return Ok(());
            }
            for agent in &agents {
                let caps_suffix = match agent.capabilities.as_slice() {
                    [] => String::new(),
                    caps => format!(" [{}]", caps.join(",")),
                };
                println!(
                    "{} type={} registered={} last_seen={}{}",
                    agent.agent_id,
                    agent.agent_type,
                    agent.registered_at,
                    agent.last_seen_at,
                    caps_suffix
                );
            }
            println!("{} registered agents", agents.len());
        }
        AgentsAction::Register {
            agent_id,
            agent_type,
            capabilities,
        } => {
            validate::validate_agent_id(&agent_id)?;
            validate::validate_agent_type(&agent_type)?;
            let caps: Vec<String> = capabilities
                .split(',')
                .filter_map(|piece| {
                    let piece = piece.trim();
                    (!piece.is_empty()).then(|| piece.to_owned())
                })
                .collect();
            validate::validate_capabilities(&caps)?;
            let id = db::register_agent(&conn, &agent_id, &agent_type, &caps)?;
            if json_out {
                let payload = serde_json::json!({
                    "registered": true,
                    "id": id,
                    "agent_id": agent_id,
                    "agent_type": agent_type,
                    "capabilities": caps,
                });
                println!("{payload}");
            } else {
                let caps_display = if caps.is_empty() {
                    String::from("-")
                } else {
                    caps.join(",")
                };
                println!(
                    "registered {agent_id} (type={agent_type}, capabilities={caps_display})"
                );
            }
        }
    }
    Ok(())
}
/// Handles `ai-memory pending <list|approve|reject>`.
///
/// - `list`: shows pending actions, optionally filtered by status.
/// - `approve <id>`: records the caller's approval via
///   `db::approve_with_approver_type`.
///   - Approved: the action is executed immediately.
///   - Pending: the current vote count and quorum are reported.
///   - Rejected: the reason goes to stderr and the process exits with
///     status 1.
/// - `reject <id>`: records a rejection. Exits with status 1 if
///   `db::decide_pending_action` reports nothing was decided.
///
/// The acting agent comes from `identity::resolve_agent_id`, given the
/// CLI `--agent-id` value.
fn cmd_pending(
    db_path: &Path,
    args: PendingArgs,
    json_out: bool,
    cli_agent_id: Option<&str>,
) -> Result<()> {
    let conn = db::open(db_path)?;
    match args.action {
        PendingAction::List { status, limit } => {
            let pending = db::list_pending_actions(&conn, status.as_deref(), limit)?;
            if json_out {
                println!(
                    "{}",
                    serde_json::json!({"count": pending.len(), "pending": pending})
                );
                return Ok(());
            }
            if pending.is_empty() {
                println!("no pending actions");
                return Ok(());
            }
            for action in &pending {
                println!(
                    "[{}] {} ns={} action={} by={} ({})",
                    id_short(&action.id),
                    action.status,
                    action.namespace,
                    action.action_type,
                    action.requested_by,
                    action.requested_at
                );
            }
            println!("{} pending action(s)", pending.len());
        }
        PendingAction::Approve { id } => {
            use db::ApproveOutcome;
            validate::validate_id(&id)?;
            let agent = identity::resolve_agent_id(cli_agent_id, None)?;
            let outcome = db::approve_with_approver_type(&conn, &id, &agent)?;
            match outcome {
                ApproveOutcome::Approved => {
                    let executed = db::execute_pending_action(&conn, &id)?;
                    if !json_out {
                        println!("approved + executed: {id} (by {agent})");
                        return Ok(());
                    }
                    let payload = serde_json::json!({
                        "approved": true,
                        "id": id,
                        "decided_by": agent,
                        "executed": true,
                        "memory_id": executed,
                    });
                    println!("{payload}");
                }
                ApproveOutcome::Pending { votes, quorum } => {
                    if !json_out {
                        println!(
                            "approval recorded: {id} ({votes}/{quorum} consensus, not yet met)"
                        );
                        return Ok(());
                    }
                    let payload = serde_json::json!({
                        "approved": false,
                        "status": "pending",
                        "id": id,
                        "votes": votes,
                        "quorum": quorum,
                        "reason": "consensus threshold not yet reached",
                    });
                    println!("{payload}");
                }
                ApproveOutcome::Rejected(reason) => {
                    eprintln!("approve rejected: {reason}");
                    std::process::exit(1);
                }
            }
        }
        PendingAction::Reject { id } => {
            validate::validate_id(&id)?;
            let agent = identity::resolve_agent_id(cli_agent_id, None)?;
            if !db::decide_pending_action(&conn, &id, false, &agent)? {
                eprintln!("pending action not found or already decided: {id}");
                std::process::exit(1);
            }
            if json_out {
                let payload = serde_json::json!({"rejected": true, "id": id, "decided_by": agent});
                println!("{payload}");
            } else {
                println!("rejected: {id} (by {agent})");
            }
        }
    }
    Ok(())
}
/// Imports conversations from a chat export (`claude`, `chatgpt` or `slack`)
/// as memories.
///
/// Conversations with fewer than `--min-messages` messages are ignored.
/// Each remaining conversation becomes at most one memory via
/// `mine::conversation_to_memory`; conversations it rejects count as
/// `skipped`.
///
/// Every memory gets:
/// - the format's source tag;
/// - `agent_id` and `mined_from` metadata;
/// - priority 5 and confidence 0.8;
/// - the expiry the config's TTL assigns to the chosen tier.
///
/// The namespace defaults to `<format>-export`.
///
/// With `--dry-run` nothing is written; the would-be memories are listed
/// instead. Otherwise inserts run inside an explicit transaction. It is
/// committed after every 100 successful inserts and once more at the end.
/// A failed insert is reported on stderr and counted, but is not fatal.
///
/// # Errors
/// Invalid `--format`/`--tier`, export parse failures, and transaction
/// (`BEGIN`/`COMMIT`) errors.
#[allow(clippy::too_many_lines)]
fn cmd_mine(
    db_path: &Path,
    args: MineArgs,
    json_out: bool,
    app_config: &config::AppConfig,
    cli_agent_id: Option<&str>,
) -> Result<()> {
    let miner_agent_id = identity::resolve_agent_id(cli_agent_id, None)?;
    let format = mine::Format::from_str(&args.format).ok_or_else(|| {
        anyhow::anyhow!(
            "invalid format: {} (use claude, chatgpt, slack)",
            args.format
        )
    })?;
    let tier = Tier::from_str(&args.tier)
        .ok_or_else(|| anyhow::anyhow!("invalid tier: {} (use short, mid, long)", args.tier))?;
    let namespace = args.namespace.unwrap_or_else(|| match format {
        mine::Format::Claude => "claude-export".to_string(),
        mine::Format::ChatGpt => "chatgpt-export".to_string(),
        mine::Format::Slack => "slack-export".to_string(),
    });
    let path = std::path::Path::new(&args.path);
    let conversations = match format {
        mine::Format::Claude => mine::parse_claude(path)?,
        mine::Format::ChatGpt => mine::parse_chatgpt(path)?,
        mine::Format::Slack => mine::parse_slack(path)?,
    };
    let filtered: Vec<_> = conversations
        .iter()
        .filter(|c| c.messages.len() >= args.min_messages)
        .collect();
    if args.dry_run {
        if json_out {
            let items: Vec<serde_json::Value> = filtered
                .iter()
                .filter_map(|c| {
                    mine::conversation_to_memory(c, format).map(|m| {
                        serde_json::json!({
                            "title": m.title,
                            "content_length": m.content.len(),
                            "messages": c.messages.len(),
                            "source": m.source_format,
                        })
                    })
                })
                .collect();
            println!(
                "{}",
                serde_json::to_string_pretty(&serde_json::json!({
                    "dry_run": true,
                    "total_conversations": conversations.len(),
                    "filtered": filtered.len(),
                    "would_import": items.len(),
                    "namespace": namespace,
                    "tier": tier.as_str(),
                    "memories": items,
                }))?
            );
        } else {
            println!("Dry run — no memories will be stored\n");
            println!("Total conversations found: {}", conversations.len());
            println!(
                "After filter (>={} messages): {}",
                args.min_messages,
                filtered.len()
            );
            println!("Namespace: {namespace}");
            println!("Tier: {tier}\n");
            for c in &filtered {
                if let Some(m) = mine::conversation_to_memory(c, format) {
                    println!(
                        " {} ({} msgs, {} bytes)",
                        m.title,
                        c.messages.len(),
                        m.content.len()
                    );
                }
            }
        }
        return Ok(());
    }
    let conn = db::open(db_path)?;
    // Best-effort housekeeping before the import; a GC failure is not fatal.
    let _ = db::gc_if_needed(&conn, app_config.effective_archive_on_gc());
    let now = Utc::now();
    let now_rfc3339 = now.to_rfc3339();
    // Tier and "now" are fixed for the whole run, so every memory shares the
    // same expiry. Compute it once instead of per conversation.
    let expires_at = app_config
        .effective_ttl()
        .ttl_for_tier(&tier)
        .map(|s| (now + Duration::seconds(s)).to_rfc3339());
    let mut imported = 0usize;
    let mut skipped = 0usize;
    let mut errors = 0usize;
    conn.execute_batch("BEGIN")?;
    for conv in &filtered {
        let Some(mined) = mine::conversation_to_memory(conv, format) else {
            skipped += 1;
            continue;
        };
        let mut metadata = models::default_metadata();
        if let Some(obj) = metadata.as_object_mut() {
            obj.insert(
                "agent_id".to_string(),
                serde_json::Value::String(miner_agent_id.clone()),
            );
            obj.insert(
                "mined_from".to_string(),
                serde_json::Value::String(format.source_tag().to_string()),
            );
        }
        let mem = models::Memory {
            id: uuid::Uuid::new_v4().to_string(),
            tier: tier.clone(),
            namespace: namespace.clone(),
            title: mined.title,
            content: mined.content,
            tags: vec![format.source_tag().to_string()],
            priority: 5,
            confidence: 0.8,
            source: mined.source_format,
            access_count: 0,
            created_at: mined.created_at.unwrap_or_else(|| now_rfc3339.clone()),
            updated_at: now_rfc3339.clone(),
            last_accessed_at: None,
            expires_at: expires_at.clone(),
            metadata,
        };
        match db::insert(&conn, &mem) {
            Ok(_) => {
                imported += 1;
                // Commit in batches of 100 successful inserts. The check
                // sits here, not after every conversation. Otherwise each
                // failed insert that follows the 100th success would
                // trigger another redundant COMMIT/BEGIN.
                if imported.is_multiple_of(100) {
                    conn.execute_batch("COMMIT")?;
                    conn.execute_batch("BEGIN")?;
                }
            }
            Err(e) => {
                errors += 1;
                eprintln!("warning: failed to store '{}': {}", mem.title, e);
            }
        }
    }
    conn.execute_batch("COMMIT")?;
    if json_out {
        println!(
            "{}",
            serde_json::to_string(&serde_json::json!({
                "imported": imported,
                "skipped": skipped,
                "errors": errors,
                "total_conversations": conversations.len(),
                "namespace": namespace,
                "tier": tier.as_str(),
            }))?
        );
    } else {
        println!(
            "Imported {} memories from {} conversations (skipped: {}, errors: {})",
            imported,
            conversations.len(),
            skipped,
            errors
        );
        println!("Namespace: {namespace}, Tier: {tier}");
    }
    Ok(())
}
/// `chrono` format of the UTC timestamp embedded in backup file names, e.g.
/// `2024-05-01T120000Z`. The time part has no separators, which keeps it
/// safe in file names.
const BACKUP_TS_FMT: &str = "%Y-%m-%dT%H%M%SZ";
/// Manifest written by `backup` as `ai-memory-<ts>.manifest.json` next to
/// each snapshot. `restore` checks it unless `--skip-verify` is given.
/// Field order here is the key order of the serialized manifest.
#[derive(serde::Serialize, serde::Deserialize)]
struct BackupManifest {
    /// File name (not full path) of the snapshot.
    snapshot: String,
    /// Lowercase hex SHA-256 digest of the snapshot file.
    sha256: String,
    /// Snapshot size in bytes.
    bytes: u64,
    /// Path of the database that was backed up (lossy UTF-8 conversion).
    source_db: String,
    /// `CARGO_PKG_VERSION` of the binary that wrote the snapshot.
    version: String,
    /// RFC 3339 creation time.
    created_at: String,
}
/// Writes a snapshot of the database into `--to`, plus a manifest beside it.
///
/// The snapshot comes from SQLite's `VACUUM INTO` on an open connection.
/// Files are named `ai-memory-<UTC timestamp>.db` and
/// `ai-memory-<UTC timestamp>.manifest.json`. The manifest records the
/// snapshot's SHA-256 and size, which `restore` checks later. With a
/// non-zero `--keep`, older snapshots beyond that count are pruned.
///
/// # Errors
/// Fails in these cases:
/// - the backup directory cannot be created;
/// - a snapshot with the same (per-second) timestamp already exists;
/// - the snapshot path is not valid UTF-8;
/// - any SQLite or file I/O step fails.
fn cmd_backup(db_path: &Path, args: &BackupArgs, json_out: bool) -> Result<()> {
    use std::io::Read;
    std::fs::create_dir_all(&args.to)
        .with_context(|| format!("creating backup dir {}", args.to.display()))?;
    let conn = db::open(db_path).context("opening source DB for backup")?;
    // A single clock reading keeps the file name and the manifest's
    // `created_at` in agreement.
    let started = chrono::Utc::now();
    let ts = started.format(BACKUP_TS_FMT).to_string();
    let snapshot_name = format!("ai-memory-{ts}.db");
    let snapshot_path = args.to.join(&snapshot_name);
    if snapshot_path.exists() {
        anyhow::bail!(
            "refusing to overwrite existing snapshot {}",
            snapshot_path.display()
        );
    }
    // VACUUM INTO takes its target as text. A lossy conversion would write
    // the snapshot to a different, mangled file name, and the metadata and
    // hash steps below would then fail on the real path.
    let snapshot_target = snapshot_path.to_str().ok_or_else(|| {
        anyhow::anyhow!(
            "backup path {} is not valid UTF-8",
            snapshot_path.display()
        )
    })?;
    conn.execute("VACUUM INTO ?1", rusqlite::params![snapshot_target])
        .context("VACUUM INTO failed")?;
    drop(conn);
    let bytes = std::fs::metadata(&snapshot_path)?.len();
    let sha = {
        use sha2::Digest;
        let mut hasher = sha2::Sha256::new();
        let mut f = std::fs::File::open(&snapshot_path)?;
        let mut buf = vec![0u8; 64 * 1024];
        loop {
            let n = f.read(&mut buf)?;
            if n == 0 {
                break;
            }
            hasher.update(&buf[..n]);
        }
        format!("{:x}", hasher.finalize())
    };
    let manifest = BackupManifest {
        snapshot: snapshot_name.clone(),
        sha256: sha.clone(),
        bytes,
        source_db: db_path.to_string_lossy().into_owned(),
        version: env!("CARGO_PKG_VERSION").to_string(),
        created_at: started.to_rfc3339(),
    };
    let manifest_path = args.to.join(format!("ai-memory-{ts}.manifest.json"));
    let manifest_text = serde_json::to_string_pretty(&manifest)?;
    std::fs::write(&manifest_path, manifest_text.as_bytes())?;
    if args.keep > 0 {
        prune_old_snapshots(&args.to, args.keep)?;
    }
    if json_out {
        println!("{}", serde_json::to_string(&manifest)?);
    } else {
        println!("Snapshot: {}", snapshot_path.display());
        println!("Manifest: {}", manifest_path.display());
        println!("SHA-256 : {sha}");
        println!("Bytes : {bytes}");
    }
    Ok(())
}
/// Deletes all but the `keep` newest backup snapshots in `dir`, together
/// with their manifests.
///
/// Only files named exactly as `cmd_backup` names them are considered:
/// `ai-memory-YYYY-MM-DDTHHMMSSZ.db`. Another SQLite file that merely starts
/// with `ai-memory-`, such as a live `ai-memory-work.db` in the same
/// directory, is never removed. The timestamp part is fixed-width, so
/// lexical order of the names is chronological order. That ordering is
/// used instead of mtime, which changes when files are copied.
///
/// Deletion is best-effort: individual remove failures are ignored.
///
/// # Errors
/// Fails only if `dir` cannot be listed.
fn prune_old_snapshots(dir: &Path, keep: usize) -> Result<()> {
    /// True for `ai-memory-YYYY-MM-DDTHHMMSSZ.db`, the name `cmd_backup` writes.
    fn is_snapshot_name(name: &str) -> bool {
        let Some(stamp) = name
            .strip_prefix("ai-memory-")
            .and_then(|rest| rest.strip_suffix(".db"))
        else {
            return false;
        };
        stamp.len() == 18
            && stamp.bytes().enumerate().all(|(i, b)| match i {
                4 | 7 => b == b'-',
                10 => b == b'T',
                17 => b == b'Z',
                _ => b.is_ascii_digit(),
            })
    }

    let mut snaps: Vec<(String, PathBuf)> = std::fs::read_dir(dir)?
        .filter_map(std::result::Result::ok)
        .filter_map(|entry| {
            let name = entry.file_name().into_string().ok()?;
            is_snapshot_name(&name).then(|| (name, entry.path()))
        })
        .collect();
    // Newest first.
    snaps.sort_unstable_by(|a, b| b.0.cmp(&a.0));
    for (_, path) in snaps.into_iter().skip(keep) {
        let _ = std::fs::remove_file(&path);
        if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
            let _ = std::fs::remove_file(dir.join(format!("{stem}.manifest.json")));
        }
    }
    Ok(())
}
fn cmd_restore(db_path: &Path, args: &RestoreArgs, json_out: bool) -> Result<()> {
use std::io::Read;
let (snapshot_path, manifest_path) = if args.from.is_dir() {
let mut snaps: Vec<(std::time::SystemTime, PathBuf)> = std::fs::read_dir(&args.from)?
.filter_map(std::result::Result::ok)
.filter_map(|entry| {
let path = entry.path();
let name = path.file_name()?.to_str()?.to_owned();
let is_snapshot = name.starts_with("ai-memory-")
&& path
.extension()
.is_some_and(|ext| ext.eq_ignore_ascii_case("db"));
if is_snapshot {
let mtime = entry.metadata().ok()?.modified().ok()?;
Some((mtime, path))
} else {
None
}
})
.collect();
snaps.sort_by_key(|b| std::cmp::Reverse(b.0));
let snap = snaps
.into_iter()
.next()
.map(|(_, p)| p)
.ok_or_else(|| anyhow::anyhow!("no snapshots found in {}", args.from.display()))?;
let stem = snap.file_stem().and_then(|s| s.to_str()).unwrap_or("");
let manifest = args.from.join(format!("{stem}.manifest.json"));
(snap, manifest)
} else {
let snap = args.from.clone();
let stem = snap.file_stem().and_then(|s| s.to_str()).unwrap_or("");
let parent = snap.parent().unwrap_or_else(|| Path::new("."));
let manifest = parent.join(format!("{stem}.manifest.json"));
(snap, manifest)
};
if !snapshot_path.exists() {
anyhow::bail!("snapshot {} does not exist", snapshot_path.display());
}
if !args.skip_verify {
if !manifest_path.exists() {
anyhow::bail!(
"manifest {} not found; pass --skip-verify to restore anyway",
manifest_path.display()
);
}
let manifest_text = std::fs::read_to_string(&manifest_path)?;
let manifest: BackupManifest = serde_json::from_str(&manifest_text)
.with_context(|| format!("parsing manifest {}", manifest_path.display()))?;
let observed = {
use sha2::Digest;
let mut hasher = sha2::Sha256::new();
let mut f = std::fs::File::open(&snapshot_path)?;
let mut buf = vec![0u8; 64 * 1024];
loop {
let n = f.read(&mut buf)?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
}
format!("{:x}", hasher.finalize())
};
if observed != manifest.sha256 {
anyhow::bail!(
"sha256 mismatch — manifest says {}, snapshot is {}",
manifest.sha256,
observed
);
}
}
if db_path.exists() {
let ts = chrono::Utc::now().format(BACKUP_TS_FMT).to_string();
let aside = db_path.with_extension(format!("pre-restore-{ts}.db"));
std::fs::rename(db_path, &aside)
.with_context(|| format!("moving current DB aside to {}", aside.display()))?;
if !json_out {
println!("Previous DB moved to {}", aside.display());
}
}
std::fs::copy(&snapshot_path, db_path)
.with_context(|| format!("copying snapshot to {}", db_path.display()))?;
if json_out {
println!(
"{}",
serde_json::json!({
"status": "restored",
"from": snapshot_path.to_string_lossy(),
"to": db_path.to_string_lossy(),
})
);
} else {
println!(
"Restored {} → {}",
snapshot_path.display(),
db_path.display()
);
}
Ok(())
}
/// Entry point for `ai-memory curator`.
///
/// `--rollback <id>` and `--rollback-last N` are delegated to
/// `cmd_curator_rollback`. Otherwise a run mode is required:
/// - `--once`: runs one curation cycle and prints its report (JSON with
///   `--json`). Takes precedence if both modes are set.
/// - `--daemon`: runs `curator::run_daemon` on a blocking thread until
///   Ctrl-C sets the shared shutdown flag.
///
/// The LLM client is optional. `build_curator_llm` returns `None` when the
/// configured feature tier has no LLM model or the client cannot be built.
///
/// # Errors
/// Fails if no mode is selected, on database or report-serialization
/// errors, or if the daemon's blocking task cannot be joined.
async fn cmd_curator(
    db_path: &Path,
    args: &CuratorArgs,
    app_config: &config::AppConfig,
) -> Result<()> {
    if args.rollback.is_some() || args.rollback_last.is_some() {
        return cmd_curator_rollback(db_path, args);
    }
    if !args.once && !args.daemon {
        anyhow::bail!("curator requires --once, --daemon, --rollback <id>, or --rollback-last N");
    }
    let cfg = curator::CuratorConfig {
        interval_secs: args.interval_secs,
        max_ops_per_cycle: args.max_ops,
        dry_run: args.dry_run,
        include_namespaces: args.include_namespaces.clone(),
        exclude_namespaces: args.exclude_namespaces.clone(),
    };
    let feature_tier = app_config.effective_tier(None);
    let llm = build_curator_llm(feature_tier);
    if args.once {
        let conn = db::open(db_path)?;
        let report = curator::run_once(&conn, llm.as_ref(), &cfg)?;
        if args.json {
            println!("{}", serde_json::to_string_pretty(&report)?);
        } else {
            print_curator_report(&report);
        }
        return Ok(());
    }
    let shutdown = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
    let shutdown_for_signal = shutdown.clone();
    tokio::spawn(async move {
        match tokio::signal::ctrl_c().await {
            Ok(()) => shutdown_for_signal.store(true, std::sync::atomic::Ordering::Relaxed),
            // If the listener cannot be installed, keep running. Treating the
            // error as a shutdown request would stop the daemon immediately.
            Err(e) => eprintln!(
                "warning: curator cannot listen for Ctrl-C ({e}); terminate the process to stop it"
            ),
        }
    });
    let db_owned = db_path.to_path_buf();
    let llm_arc = llm.map(std::sync::Arc::new);
    // run_daemon is called synchronously, so it runs on the blocking pool to
    // keep the async runtime free.
    tokio::task::spawn_blocking(move || {
        curator::run_daemon(db_owned, llm_arc, cfg, shutdown);
    })
    .await
    .map_err(|e| anyhow::anyhow!("curator daemon join: {e}"))?;
    Ok(())
}
fn cmd_curator_rollback(db_path: &Path, args: &CuratorArgs) -> Result<()> {
let conn = db::open(db_path)?;
if let Some(id) = &args.rollback {
let Some(mem) = db::get(&conn, id)? else {
anyhow::bail!("rollback entry {id} not found");
};
let entry: autonomy::RollbackEntry = serde_json::from_str(&mem.content)
.context("rollback entry content is not a valid RollbackEntry JSON")?;
let applied = autonomy::reverse_rollback_entry(&conn, &entry)?;
let mut tags = mem.tags.clone();
if !tags.iter().any(|t| t == "_reversed") {
tags.push("_reversed".to_string());
db::update(
&conn,
&mem.id,
None,
None,
None,
None,
Some(&tags),
None,
None,
None,
None,
)?;
}
println!(
"rollback {id}: {}",
if applied { "applied" } else { "no-op" }
);
return Ok(());
}
if let Some(n) = args.rollback_last {
let log = db::list(
&conn,
Some("_curator/rollback"),
None,
n.max(1),
0,
None,
None,
None,
None,
None,
)?;
let mut reversed = 0usize;
for mem in &log {
if mem.tags.iter().any(|t| t == "_reversed") {
continue;
}
let Ok(entry) = serde_json::from_str::<autonomy::RollbackEntry>(&mem.content) else {
continue;
};
let applied = autonomy::reverse_rollback_entry(&conn, &entry)?;
if applied {
reversed += 1;
let mut tags = mem.tags.clone();
tags.push("_reversed".to_string());
db::update(
&conn,
&mem.id,
None,
None,
None,
None,
Some(&tags),
None,
None,
None,
None,
)?;
}
}
println!("reversed {reversed} rollback entries");
return Ok(());
}
unreachable!("cmd_curator_rollback entered without --rollback or --rollback-last");
}
/// Builds the Ollama client the curator uses for the given feature tier.
///
/// Returns `None` when the tier's config has no LLM model, or when
/// `llm::OllamaClient::new` fails (the error itself is discarded).
fn build_curator_llm(tier: config::FeatureTier) -> Option<llm::OllamaClient> {
    let model_id = tier
        .config()
        .llm_model
        .map(|model| model.ollama_model_id().to_string())?;
    llm::OllamaClient::new(&model_id).ok()
}
/// Prints one curator cycle report as ` label: value` lines under a header,
/// followed by one ` - <error>` line per recorded error.
fn print_curator_report(r: &curator::CuratorReport) {
    let summary = [
        ("started_at", r.started_at.to_string()),
        ("completed_at", r.completed_at.to_string()),
        ("duration_ms", r.cycle_duration_ms.to_string()),
        ("memories_scanned", r.memories_scanned.to_string()),
        ("memories_eligible", r.memories_eligible.to_string()),
        ("operations", r.operations_attempted.to_string()),
        ("auto_tagged", r.auto_tagged.to_string()),
        ("contradictions", r.contradictions_found.to_string()),
        ("skipped (cap)", r.operations_skipped_cap.to_string()),
        ("errors", r.errors.len().to_string()),
        ("dry_run", r.dry_run.to_string()),
    ];
    println!("curator cycle report");
    for (label, value) in &summary {
        println!(" {label}: {value}");
    }
    for err in &r.errors {
        println!(" - {err}");
    }
}
/// Copies memories from one store URL to another (`sal` feature only).
///
/// Both stores are opened with `migrate::open_store`. `migrate::migrate`
/// then runs with the requested batch size, optional namespace filter, and
/// dry-run flag. The resulting report is printed as pretty JSON (`--json`)
/// or as text.
///
/// # Errors
/// Fails if either store cannot be opened, or if the migration reported
/// any errors (after the report has been printed).
#[cfg(feature = "sal")]
async fn cmd_migrate(args: &MigrateArgs) -> Result<()> {
    let src = migrate::open_store(&args.from)
        .await
        .context("open source store")?;
    let dst = migrate::open_store(&args.to)
        .await
        .context("open destination store")?;
    let report = migrate::migrate(
        src.as_ref(),
        dst.as_ref(),
        args.batch,
        args.namespace.clone(),
        args.dry_run,
    )
    .await;
    if args.json {
        let summary = serde_json::json!({
            "from_url": args.from,
            "to_url": args.to,
            "memories_read": report.memories_read,
            "memories_written": report.memories_written,
            "batches": report.batches,
            "errors": report.errors,
            "dry_run": report.dry_run,
        });
        println!("{}", serde_json::to_string_pretty(&summary)?);
    } else {
        let rows = [
            ("from", args.from.to_string()),
            ("to", args.to.to_string()),
            ("memories_read", report.memories_read.to_string()),
            ("memories_written", report.memories_written.to_string()),
            ("batches", report.batches.to_string()),
            ("dry_run", report.dry_run.to_string()),
            ("errors", report.errors.len().to_string()),
        ];
        println!("migration report");
        for (label, value) in &rows {
            println!(" {label}: {value}");
        }
        for err in &report.errors {
            println!(" - {err}");
        }
    }
    match report.errors.len() {
        0 => Ok(()),
        n => anyhow::bail!("migration completed with {n} error(s)"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// RFC 3339 timestamp for the instant `delta` before now.
    fn ago(delta: chrono::Duration) -> String {
        (chrono::Utc::now() - delta).to_rfc3339()
    }

    /// Asserts that `human_age` of a timestamp `delta` ago mentions `unit`.
    fn assert_age_unit(delta: chrono::Duration, unit: &str) {
        let age = human_age(&ago(delta));
        assert!(age.contains(unit), "got: {age}");
    }

    #[test]
    fn id_short_truncates() {
        let long_id = "abcdefghijklmnop";
        assert_eq!(id_short(long_id), "abcdefgh");
    }

    #[test]
    fn id_short_short_input() {
        let short_id = "abc";
        assert_eq!(id_short(short_id), short_id);
    }

    #[test]
    fn id_short_empty() {
        assert!(id_short("").is_empty());
    }

    #[test]
    fn human_age_just_now() {
        assert_eq!(human_age(&ago(chrono::Duration::zero())), "just now");
    }

    #[test]
    fn human_age_minutes() {
        assert_age_unit(chrono::Duration::minutes(5), "m ago");
    }

    #[test]
    fn human_age_hours() {
        assert_age_unit(chrono::Duration::hours(3), "h ago");
    }

    #[test]
    fn human_age_days() {
        assert_age_unit(chrono::Duration::days(5), "d ago");
    }

    #[test]
    fn human_age_invalid_returns_input() {
        let garbage = "not-a-date";
        assert_eq!(human_age(garbage), garbage);
    }

    #[test]
    fn auto_namespace_returns_nonempty() {
        assert!(!auto_namespace().is_empty());
    }
}