use std::io::Write as _;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use anyhow::{Context, Result};
use axum::Router;
use clap::{Args, CommandFactory, Parser, Subcommand};
use clap_complete::{Shell, generate};
use rusqlite::Connection;
use tokio::sync::{Mutex, Notify};
use tokio::task::JoinHandle;
use tracing_subscriber::EnvFilter;
use crate::cli::agents::{AgentsArgs, PendingArgs};
use crate::cli::archive::ArchiveArgs;
use crate::cli::audit::AuditArgs;
use crate::cli::backup::{BackupArgs, RestoreArgs};
use crate::cli::boot::BootArgs;
use crate::cli::consolidate::{AutoConsolidateArgs, ConsolidateArgs};
use crate::cli::crud::{DeleteArgs, GetArgs, ListArgs};
use crate::cli::curator::CuratorArgs;
use crate::cli::forget::ForgetArgs;
use crate::cli::install::InstallArgs;
use crate::cli::io::{ImportArgs, MineArgs};
use crate::cli::link::{LinkArgs, ResolveArgs};
use crate::cli::logs::LogsArgs;
use crate::cli::promote::PromoteArgs;
use crate::cli::recall::RecallArgs;
use crate::cli::search::SearchArgs;
use crate::cli::store::StoreArgs;
use crate::cli::sync::{SyncArgs, SyncDaemonArgs};
use crate::cli::update::UpdateArgs;
use crate::cli::wrap::WrapArgs;
use crate::config::{AppConfig, FeatureTier};
use crate::embeddings::Embedder;
use crate::handlers::{ApiKeyState, AppState, Db};
use crate::hnsw::VectorIndex;
use crate::{bench, cli, db, embeddings, federation, hnsw, llm, mcp, tls};
#[cfg(feature = "sal")]
use crate::migrate;
/// Fallback value of the global `--db` flag (also settable via `AI_MEMORY_DB`).
const DEFAULT_DB: &str = "ai-memory.db";
/// Default TCP port for `ai-memory serve`.
const DEFAULT_PORT: u16 = 9077;
/// Seconds between background GC sweeps in the HTTP daemon (30 minutes).
const GC_INTERVAL_SECS: u64 = 30 * 60;
/// Seconds between periodic WAL checkpoints in the HTTP daemon (10 minutes).
const WAL_CHECKPOINT_INTERVAL_SECS: u64 = 10 * 60;
// Top-level command line of `ai-memory`. Plain `//` comments are used on this
// clap-derived type because `///` doc comments would become `--help` text.
#[derive(Parser)]
#[command(
    name = "ai-memory",
    version,
    about = "AI-agnostic persistent memory — MCP server, HTTP API, and CLI for any AI platform"
)]
pub struct Cli {
    // The subcommand to execute; dispatched by `run`.
    #[command(subcommand)]
    pub command: Command,
    // Database path (or `AI_MEMORY_DB`). `run` resolves the final path through
    // `AppConfig::effective_db`.
    #[arg(long, env = "AI_MEMORY_DB", default_value = DEFAULT_DB, global = true)]
    pub db: PathBuf,
    // Request JSON output from subcommands that accept the `json` flag.
    #[arg(long, global = true, default_value_t = false)]
    pub json: bool,
    // Caller identity (or `AI_MEMORY_AGENT_ID`), forwarded by `run` to handlers
    // such as store, delete, sync and import; `boot` audits "anonymous" when unset.
    #[arg(long, env = "AI_MEMORY_AGENT_ID", global = true)]
    pub agent_id: Option<String>,
    // File containing the database passphrase; `run` reads it and exports it
    // as `AI_MEMORY_DB_PASSPHRASE` before dispatching.
    #[arg(long, global = true, value_name = "PATH")]
    pub db_passphrase_file: Option<PathBuf>,
}
// Subcommands of `ai-memory`, each dispatched by `run`. Plain `//` comments are
// used because `///` doc comments on clap-derived items become `--help` text.
#[derive(Subcommand)]
pub enum Command {
    // HTTP(S) API daemon (`serve`).
    Serve(ServeArgs),
    // MCP server. `tier` is resolved via `AppConfig::effective_tier`; `profile`
    // (or `AI_MEMORY_PROFILE`) via `AppConfig::effective_profile`, and an
    // invalid profile makes `run` exit the process with code 2.
    Mcp {
        #[arg(long, default_value = "semantic")]
        tier: String,
        #[arg(long, env = "AI_MEMORY_PROFILE")]
        profile: Option<String>,
    },
    // cli::store / cli::update / cli::recall / cli::search handlers.
    Store(StoreArgs),
    Update(UpdateArgs),
    Recall(RecallArgs),
    Search(SearchArgs),
    // cli::crud handlers.
    Get(GetArgs),
    List(ListArgs),
    Delete(DeleteArgs),
    // cli::promote / cli::forget / cli::link / cli::consolidate handlers.
    Promote(PromoteArgs),
    Forget(ForgetArgs),
    Link(LinkArgs),
    Consolidate(ConsolidateArgs),
    // cli::gc handlers (run_gc / run_stats / run_namespaces).
    Gc,
    Stats,
    Namespaces,
    // cli::io::export / cli::io::import.
    Export,
    Import(ImportArgs),
    // cli::link::cmd_resolve.
    Resolve(ResolveArgs),
    // cli::shell::run.
    Shell,
    // One-shot sync (cli::sync::run) and the long-running daemon (run_daemon).
    Sync(SyncArgs),
    SyncDaemon(SyncDaemonArgs),
    // cli::consolidate::run_auto.
    AutoConsolidate(AutoConsolidateArgs),
    // Shell completion script / man page for this CLI, written to stdout.
    Completions(CompletionsArgs),
    Man,
    // cli::io::mine.
    Mine(MineArgs),
    // cli::archive / cli::agents handlers.
    Archive(ArchiveArgs),
    Agents(AgentsArgs),
    Pending(PendingArgs),
    // cli::backup::run_backup / run_restore.
    Backup(BackupArgs),
    Restore(RestoreArgs),
    // cli::curator::run.
    Curator(CuratorArgs),
    // Built-in benchmark (`cmd_bench`).
    Bench(BenchArgs),
    // Store-to-store migration (`cmd_migrate`), only with the `sal` feature.
    #[cfg(feature = "sal")]
    Migrate(MigrateArgs),
    // Health checks / token report (cli::doctor).
    Doctor(DoctorCliArgs),
    // cli::boot::run (also emits a SessionBoot audit event).
    Boot(BootArgs),
    // cli::install / cli::wrap / cli::logs / cli::audit handlers.
    Install(InstallArgs),
    Wrap(WrapArgs),
    Logs(LogsArgs),
    Audit(AuditArgs),
}
// Flags for `ai-memory doctor` (plain `//` comments: `///` would change help text).
#[derive(Args)]
pub struct DoctorCliArgs {
    // Forwarded to `cli::doctor::DoctorArgs::remote`.
    #[arg(long, value_name = "URL")]
    pub remote: Option<String>,
    // JSON output for both the health report and the token report.
    #[arg(long)]
    pub json: bool,
    // Forwarded to `cli::doctor::DoctorArgs::fail_on_warn`.
    #[arg(long)]
    pub fail_on_warn: bool,
    // Run the token report (`cli::doctor::run_tokens`) instead of the health checks.
    #[arg(long)]
    pub tokens: bool,
    // Profile used by the token report.
    #[arg(long, value_name = "PROFILE")]
    pub profile: Option<String>,
    // Also selects the token report; forwarded as `TokensArgs::raw_table`.
    #[arg(long)]
    pub raw_table: bool,
}
// Flags for `ai-memory bench` (plain `//` comments: `///` would change help text).
#[derive(Args)]
pub struct BenchArgs {
    // Measured iterations per operation; `cmd_bench` clamps it to 1..=100_000.
    #[arg(long, default_value_t = bench::DEFAULT_ITERATIONS)]
    pub iterations: usize,
    // Unmeasured warm-up iterations; `cmd_bench` caps it at 10_000.
    #[arg(long, default_value_t = bench::DEFAULT_WARMUP)]
    pub warmup: usize,
    // Print results as pretty JSON instead of a table.
    #[arg(long)]
    pub json: bool,
    // Baseline file to compare against (loaded by `bench::load_baseline`).
    #[arg(long, value_name = "PATH")]
    pub baseline: Option<String>,
    // Regression threshold in percent; `cmd_bench` clamps it to 0.0..=1000.0.
    #[arg(long, default_value_t = bench::DEFAULT_REGRESSION_THRESHOLD_PCT)]
    pub regression_threshold: f64,
    // File that each run is appended to (`bench::append_history`).
    #[arg(long, value_name = "PATH")]
    pub history: Option<PathBuf>,
}
// Flags for `ai-memory migrate` (only with the `sal` feature). Plain `//`
// comments are used because `///` would change help text.
#[cfg(feature = "sal")]
#[derive(Args)]
pub struct MigrateArgs {
    // Source store URL, opened with `migrate::open_store`.
    #[arg(long)]
    pub from: String,
    // Destination store URL, opened with `migrate::open_store`.
    #[arg(long)]
    pub to: String,
    // Batch size passed to `migrate::migrate`.
    #[arg(long, default_value_t = 1000)]
    pub batch: usize,
    // Passed to `migrate::migrate`; presumably restricts the copy to one
    // namespace — confirm in `migrate::migrate`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Passed to `migrate::migrate` and echoed in the report.
    #[arg(long)]
    pub dry_run: bool,
    // Print the report as pretty JSON.
    #[arg(long)]
    pub json: bool,
}
// Flags for `ai-memory serve` (plain `//` comments: `///` would change help text).
#[derive(Args)]
pub struct ServeArgs {
    // Listen host; joined with `port` into the listen address.
    #[arg(long, default_value = "127.0.0.1")]
    pub host: String,
    #[arg(long, default_value_t = DEFAULT_PORT)]
    pub port: u16,
    // TLS certificate and key; each requires the other, and together they
    // switch `serve` to HTTPS.
    #[arg(long, requires = "tls_key")]
    pub tls_cert: Option<PathBuf>,
    #[arg(long, requires = "tls_cert")]
    pub tls_key: Option<PathBuf>,
    // Client-certificate allowlist; enables mTLS on the HTTPS listener.
    #[arg(long, requires = "tls_cert")]
    pub mtls_allowlist: Option<PathBuf>,
    // Grace period passed to the HTTPS server's graceful shutdown; the
    // plain-HTTP path does not use it.
    #[arg(long, default_value_t = 30)]
    pub shutdown_grace_secs: u64,
    // Federation / quorum settings, all passed to `FederationConfig::build`.
    // NOTE(review): the default of 0 write quorum presumably means federation
    // is disabled — confirm in `federation::FederationConfig::build`.
    #[arg(long, default_value_t = 0)]
    pub quorum_writes: usize,
    #[arg(long, value_delimiter = ',')]
    pub quorum_peers: Vec<String>,
    #[arg(long, default_value_t = 2000)]
    pub quorum_timeout_ms: u64,
    #[arg(long)]
    pub quorum_client_cert: Option<PathBuf>,
    #[arg(long)]
    pub quorum_client_key: Option<PathBuf>,
    #[arg(long)]
    pub quorum_ca_cert: Option<PathBuf>,
    // Federation catch-up polling interval in seconds; 0 disables the loop.
    #[arg(long, default_value_t = 30)]
    pub catchup_interval_secs: u64,
}
// Flags for `ai-memory completions` (plain `//` comments: `///` would change help text).
#[derive(Args)]
pub struct CompletionsArgs {
    // Target shell for the generated completion script.
    pub shell: Shell,
}
#[allow(clippy::too_many_lines)]
pub async fn run(cli: Cli, app_config: &AppConfig) -> Result<()> {
if let Some(path) = &cli.db_passphrase_file {
let passphrase = passphrase_from_file(path)?;
unsafe { std::env::set_var("AI_MEMORY_DB_PASSPHRASE", passphrase) };
}
let db_path = app_config.effective_db(&cli.db);
let j = cli.json;
let cli_agent_id: Option<String> = cli.agent_id.clone();
let needs_checkpoint = is_write_command(&cli.command);
let db_path_for_checkpoint = if needs_checkpoint {
Some(db_path.clone())
} else {
None
};
let result = match cli.command {
Command::Serve(a) => serve(db_path, a, app_config).await,
Command::Mcp { tier, profile } => {
let feature_tier = app_config.effective_tier(Some(&tier));
let resolved_profile = match app_config.effective_profile(profile.as_deref()) {
Ok(p) => p,
Err(e) => {
eprintln!("ai-memory mcp: invalid profile: {e}");
std::process::exit(2);
}
};
mcp::run_mcp_server(&db_path, feature_tier, app_config, &resolved_profile)?;
Ok(())
}
Command::Store(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::store::run(
&db_path,
a,
j,
app_config,
cli_agent_id.as_deref(),
&mut out,
)
}
Command::Update(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::update::run(&db_path, &a, j, &mut out)
}
Command::Recall(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::recall::run(&db_path, &a, j, app_config, &mut out)
}
Command::Search(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::search::run(&db_path, &a, j, &mut out)
}
Command::Get(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::crud::cmd_get(&db_path, &a, j, &mut out)
}
Command::List(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::crud::cmd_list(&db_path, &a, j, app_config, &mut out)
}
Command::Delete(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::crud::cmd_delete(&db_path, &a, j, cli_agent_id.as_deref(), &mut out)
}
Command::Promote(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::promote::cmd_promote(&db_path, &a, j, cli_agent_id.as_deref(), &mut out)
}
Command::Forget(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::forget::cmd_forget(&db_path, &a, j, &mut out)
}
Command::Link(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::link::cmd_link(&db_path, &a, j, &mut out)
}
Command::Consolidate(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::consolidate::run(&db_path, a, j, cli_agent_id.as_deref(), &mut out)
}
Command::Resolve(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::link::cmd_resolve(&db_path, &a, j, &mut out)
}
Command::Shell => cli::shell::run(&db_path),
Command::Sync(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::sync::run(&db_path, &a, j, cli_agent_id.as_deref(), &mut out)
}
Command::SyncDaemon(a) => cli::sync::run_daemon(&db_path, a, cli_agent_id.as_deref()).await,
Command::AutoConsolidate(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::consolidate::run_auto(&db_path, &a, j, cli_agent_id.as_deref(), &mut out)
}
Command::Gc => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::gc::run_gc(&db_path, j, app_config, &mut out)
}
Command::Stats => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::gc::run_stats(&db_path, j, &mut out)
}
Command::Namespaces => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::gc::run_namespaces(&db_path, j, &mut out)
}
Command::Export => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::io::export(&db_path, &mut out)
}
Command::Import(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::io::import(&db_path, &a, j, cli_agent_id.as_deref(), &mut out)
}
Command::Completions(a) => {
generate(
a.shell,
&mut Cli::command(),
"ai-memory",
&mut std::io::stdout(),
);
Ok(())
}
Command::Man => {
let cmd = Cli::command();
let man = clap_mangen::Man::new(cmd);
man.render(&mut std::io::stdout())?;
Ok(())
}
Command::Mine(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::io::mine(
&db_path,
a,
j,
app_config,
cli_agent_id.as_deref(),
&mut out,
)
}
Command::Archive(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::archive::run(&db_path, a, j, &mut out)
}
Command::Agents(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::agents::run_agents(&db_path, a, j, &mut out)
}
Command::Pending(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::agents::run_pending(&db_path, a, j, cli_agent_id.as_deref(), &mut out)
}
Command::Backup(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::backup::run_backup(&db_path, &a, j, &mut out)
}
Command::Restore(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::backup::run_restore(&db_path, &a, j, &mut out)
}
Command::Curator(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::curator::run(&db_path, &a, app_config, &mut out).await
}
Command::Bench(a) => cmd_bench(&a),
#[cfg(feature = "sal")]
Command::Migrate(a) => cmd_migrate(&a).await,
Command::Doctor(a) => {
let db_path_doctor = db_path.clone();
if a.tokens || a.raw_table {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
let exit = cli::doctor::run_tokens(
cli::doctor::TokensArgs {
json: a.json,
raw_table: a.raw_table,
profile: a.profile,
},
&mut out,
)?;
std::process::exit(exit);
}
let args = cli::doctor::DoctorArgs {
remote: a.remote,
json: a.json,
fail_on_warn: a.fail_on_warn,
};
let join = tokio::task::spawn_blocking(move || {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::doctor::run(&db_path_doctor, &args, &mut out)
})
.await;
match join {
Ok(Ok(0)) => Ok(()),
Ok(Ok(code)) => std::process::exit(code),
Ok(Err(e)) => Err(e),
Err(e) => Err(anyhow::anyhow!("doctor task join failed: {e}")),
}
}
Command::Boot(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
crate::audit::emit(crate::audit::EventBuilder::new(
crate::audit::AuditAction::SessionBoot,
crate::audit::actor(
cli_agent_id.as_deref().unwrap_or("anonymous"),
"explicit_or_default",
None,
),
crate::audit::target_sweep(a.namespace.as_deref().unwrap_or("auto")),
));
cli::boot::run(&db_path, &a, app_config, &mut out)
}
Command::Install(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::install::run(&a, &mut out)
}
Command::Wrap(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
let code = cli::wrap::run(&db_path, &a, app_config, &mut out)?;
drop(out);
drop(so);
drop(se);
if code == 0 {
Ok(())
} else {
std::process::exit(code);
}
}
Command::Logs(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
cli::logs::run(a, app_config, &mut out)
}
Command::Audit(a) => {
let stdout = std::io::stdout();
let stderr = std::io::stderr();
let mut so = stdout.lock();
let mut se = stderr.lock();
let mut out = cli::CliOutput::from_std(&mut so, &mut se);
match cli::audit::run(a, app_config, &mut out)? {
0 => Ok(()),
code => std::process::exit(code),
}
}
};
if result.is_ok()
&& let Some(cp_path) = db_path_for_checkpoint
&& let Ok(conn) = db::open(&cp_path)
{
let _ = db::checkpoint(&conn);
}
result
}
/// Report whether `cmd` may modify the database. A `true` result makes [`run`]
/// attempt a best-effort WAL checkpoint after the command succeeds.
///
/// `Mine` is counted as a write. Like `Store` and `Import`, it is dispatched
/// with the caller's agent id, and it ingests memories into the store.
///
/// NOTE(review): `Archive`, `Pending`, `Agents`, `Restore` and `Curator` may
/// also write, depending on the chosen action. Consider adding them once
/// confirmed against their handlers. A spurious checkpoint is harmless.
#[must_use]
pub fn is_write_command(cmd: &Command) -> bool {
    matches!(
        cmd,
        Command::Store(_)
            | Command::Update(_)
            | Command::Delete(_)
            | Command::Promote(_)
            | Command::Forget(_)
            | Command::Link(_)
            | Command::Consolidate(_)
            | Command::Resolve(_)
            | Command::Sync(_)
            | Command::SyncDaemon(_)
            | Command::Import(_)
            | Command::Mine(_)
            | Command::AutoConsolidate(_)
            | Command::Gc
    )
}
/// Read a database passphrase from the file at `path`.
///
/// Trailing `\n` / `\r` characters (such as a final newline) are stripped. All
/// other whitespace is kept as part of the passphrase.
///
/// # Errors
///
/// Fails if the file cannot be read as UTF-8 text, or if nothing remains after
/// the trailing line endings are stripped.
pub fn passphrase_from_file(path: &Path) -> Result<String> {
    let contents = std::fs::read_to_string(path)
        .with_context(|| format!("reading passphrase file {}", path.display()))?;
    let passphrase = contents.trim_end_matches(['\n', '\r']);
    anyhow::ensure!(
        !passphrase.is_empty(),
        "passphrase file {} is empty",
        path.display()
    );
    Ok(passphrase.to_owned())
}
/// Export `AI_MEMORY_ANONYMIZE=1` when the config enables anonymization by
/// default.
///
/// Nothing is set when `AI_MEMORY_ANONYMIZE` is already readable from the
/// environment, so an explicit setting always wins.
pub fn apply_anonymize_default(app_config: &AppConfig) {
    if !app_config.effective_anonymize_default() {
        return;
    }
    if std::env::var("AI_MEMORY_ANONYMIZE").is_ok() {
        return;
    }
    // SAFETY: `set_var` is only sound while no other thread reads or writes the
    // process environment. NOTE(review): confirm callers invoke this before
    // any worker threads that might read the environment are started.
    unsafe { std::env::set_var("AI_MEMORY_ANONYMIZE", "1") };
}
pub async fn build_embedder(feature_tier: FeatureTier, app_config: &AppConfig) -> Option<Embedder> {
let tier_config = feature_tier.config();
let Some(emb_model) = tier_config.embedding_model else {
tracing::info!(
"embedder disabled — tier={} keyword-only (FTS5); semantic recall not wired",
feature_tier.as_str()
);
return None;
};
let embed_url = app_config.effective_embed_url().to_string();
let build = match tokio::task::spawn_blocking(move || {
let embed_client = llm::OllamaClient::new_with_url(&embed_url, "nomic-embed-text")
.ok()
.map(Arc::new);
embeddings::Embedder::for_model(emb_model, embed_client)
})
.await
{
Ok(b) => b,
Err(e) => {
tracing::error!("embedder spawn_blocking join failed: {e}");
return None;
}
};
match build {
Ok(emb) => {
tracing::info!(
"embedder loaded ({}) — tier={} semantic recall enabled",
emb.model_description(),
feature_tier.as_str()
);
Some(emb)
}
Err(e) => {
tracing::error!(
"EMBEDDER LOAD FAILED — tier={} requested semantic features, \
but embedder init errored: {e}. Daemon falls back to keyword-only. \
Semantic recall, sync_push embedding refresh (#322), and HNSW index \
will be NO-OPS. Check network egress to HuggingFace Hub + available \
memory for model weights. To force keyword-only explicitly (silences \
this error), set `tier = \"keyword\"` in config.toml.",
feature_tier.as_str()
);
None
}
}
}
#[must_use]
pub fn build_vector_index(conn: &Connection, embedder_present: bool) -> Option<VectorIndex> {
if !embedder_present {
return None;
}
match db::get_all_embeddings(conn) {
Ok(entries) if !entries.is_empty() => Some(hnsw::VectorIndex::build(entries)),
_ => Some(hnsw::VectorIndex::empty()),
}
}
/// Spawn the background garbage-collection loop for the HTTP daemon.
///
/// Each pass runs one `interval` after the previous one; the first pass runs
/// one `interval` after spawning. A pass holds the database lock while it runs
/// two steps:
///
/// 1. `db::gc`, passed the state tuple's fourth element (the `archive_on_gc`
///    flag, per `bootstrap_serve`).
/// 2. `db::auto_purge_archive` with `archive_max_days`.
///
/// Non-zero counts are logged at info level. Failures are logged as warnings;
/// they were previously discarded. The loop keeps running after a failure.
#[must_use]
pub fn spawn_gc_loop(
    state: Db,
    archive_max_days: Option<i64>,
    interval: Duration,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        loop {
            tokio::time::sleep(interval).await;
            let lock = state.lock().await;
            match db::gc(&lock.0, lock.3) {
                Ok(n) if n > 0 => tracing::info!("gc: expired {n} memories"),
                Ok(_) => {}
                Err(e) => tracing::warn!("gc: expiry sweep failed: {e}"),
            }
            match db::auto_purge_archive(&lock.0, archive_max_days) {
                Ok(n) if n > 0 => tracing::info!("gc: purged {n} old archived memories"),
                Ok(_) => {}
                Err(e) => tracing::warn!("gc: archive purge failed: {e}"),
            }
        }
    })
}
/// Spawn the periodic WAL-checkpoint loop for the HTTP daemon.
///
/// The first checkpoint runs after `interval / 2` and then once per `interval`.
/// The half-period offset presumably staggers it against other periodic loops;
/// that rationale is not stated in code. Each checkpoint holds the database
/// lock only for the duration of `db::checkpoint`. Success is logged at debug
/// level and failure as a warning; the loop never exits on its own.
#[must_use]
pub fn spawn_wal_checkpoint_loop(state: Db, interval: Duration) -> JoinHandle<()> {
    let initial_delay = interval / 2;
    tokio::spawn(async move {
        tokio::time::sleep(initial_delay).await;
        loop {
            {
                let guard = state.lock().await;
                if let Err(err) = db::checkpoint(&guard.0) {
                    tracing::warn!("wal checkpoint failed: {err}");
                } else {
                    tracing::debug!("wal checkpoint: ok");
                }
            }
            tokio::time::sleep(interval).await;
        }
    })
}
/// Build the axum [`Router`] for the HTTP API.
///
/// Thin wrapper over `crate::build_router`, which takes the same two values in
/// the opposite order (`api_key_state` first).
#[must_use]
pub fn build_router(app_state: AppState, api_key_state: ApiKeyState) -> Router {
    crate::build_router(api_key_state, app_state)
}
/// Shared state and background tasks produced by [`bootstrap_serve`].
pub struct ServeBootstrap {
    /// Handler state handed to the router.
    pub app_state: AppState,
    /// API-key authentication settings; `key` is copied from `AppConfig::api_key`.
    pub api_key_state: ApiKeyState,
    /// The shared database handle (the same `Arc` is stored in `app_state.db`).
    pub db_state: Db,
    /// Archive retention in days, copied from `AppConfig::archive_max_days`.
    pub archive_max_days: Option<i64>,
    /// Handles of the GC and WAL-checkpoint loops spawned during bootstrap.
    /// NOTE(review): the federation catch-up loop is also spawned, but any
    /// handle it returns is not stored here.
    pub task_handles: Vec<JoinHandle<()>>,
}
/// Assemble everything `serve` needs before it starts listening.
///
/// The steps are:
///
/// 1. Open the database at `db_path`.
/// 2. Build the embedder for the configured tier and, when an embedder exists,
///    the HNSW index over stored embeddings.
/// 3. Build the optional federation/quorum configuration from the `--quorum-*`
///    flags.
/// 4. Spawn the background loops. The federation catch-up loop is spawned only
///    when federation is enabled and `--catchup-interval-secs` is non-zero;
///    the GC loop and WAL-checkpoint loop are always spawned.
///
/// # Errors
///
/// Fails if the database cannot be opened or the federation configuration is
/// invalid.
pub async fn bootstrap_serve(
    db_path: &Path,
    args: &ServeArgs,
    app_config: &AppConfig,
) -> Result<ServeBootstrap> {
    let ttl = app_config.effective_ttl();
    let archive_on_gc = app_config.effective_archive_on_gc();
    let conn = db::open(db_path)?;
    let feature_tier = app_config.effective_tier(None);
    let tier_config = feature_tier.config();
    let embedder = build_embedder(feature_tier, app_config).await;
    let vector_index = build_vector_index(&conn, embedder.is_some());
    let db_state: Db = Arc::new(Mutex::new((
        conn,
        db_path.to_path_buf(),
        ttl,
        archive_on_gc,
    )));

    let node_id = format!("host:{}", gethostname::gethostname().to_string_lossy());
    let federation_cfg = federation::FederationConfig::build(
        args.quorum_writes,
        &args.quorum_peers,
        Duration::from_millis(args.quorum_timeout_ms),
        args.quorum_client_cert.as_deref(),
        args.quorum_client_key.as_deref(),
        args.quorum_ca_cert.as_deref(),
        node_id,
    )
    .context("federation config")?;

    if let Some(fed) = &federation_cfg {
        tracing::info!(
            "federation enabled: W={} over {} peer(s), timeout {}ms",
            fed.policy.w,
            fed.peer_count(),
            args.quorum_timeout_ms,
        );
        if args.catchup_interval_secs == 0 {
            tracing::info!("catchup loop disabled (--catchup-interval-secs=0)");
        } else {
            tracing::info!(
                "catchup loop enabled: polling {} peer(s) every {}s",
                fed.peer_count(),
                args.catchup_interval_secs,
            );
            federation::spawn_catchup_loop(
                fed.clone(),
                db_state.clone(),
                Duration::from_secs(args.catchup_interval_secs),
            );
        }
    }

    let app_state = AppState {
        db: db_state.clone(),
        embedder: Arc::new(embedder),
        vector_index: Arc::new(Mutex::new(vector_index)),
        federation: Arc::new(federation_cfg),
        tier_config: Arc::new(tier_config),
        scoring: Arc::new(app_config.effective_scoring()),
    };

    // Spawned in this order: GC loop first, then the WAL-checkpoint loop.
    let task_handles: Vec<JoinHandle<()>> = vec![
        spawn_gc_loop(
            db_state.clone(),
            app_config.archive_max_days,
            Duration::from_secs(GC_INTERVAL_SECS),
        ),
        spawn_wal_checkpoint_loop(
            db_state.clone(),
            Duration::from_secs(WAL_CHECKPOINT_INTERVAL_SECS),
        ),
    ];

    let api_key_state = ApiKeyState {
        key: app_config.api_key.clone(),
    };
    if api_key_state.key.is_some() {
        tracing::info!("API key authentication enabled");
    }

    Ok(ServeBootstrap {
        app_state,
        api_key_state,
        db_state,
        archive_max_days: app_config.archive_max_days,
        task_handles,
    })
}
/// Install the global `tracing` subscriber for the daemon.
///
/// The filter starts from the environment (`EnvFilter::from_default_env`) and
/// adds the `ai_memory=info` and `tower_http=info` directives. Installation
/// errors are ignored: `try_init` fails when a global subscriber is already set.
fn init_tracing() {
    let filter = EnvFilter::from_default_env()
        .add_directive("ai_memory=info".parse().unwrap())
        .add_directive("tower_http=info".parse().unwrap());
    let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init();
}
/// Run the long-lived HTTP(S) API daemon until Ctrl-C.
///
/// Initializes tracing and bootstraps shared state via [`bootstrap_serve`].
/// It then serves HTTPS when `--tls-cert` and `--tls-key` are set (optionally
/// with an mTLS client allowlist), or plain HTTP otherwise.
///
/// On Ctrl-C, the database lock is taken and a WAL checkpoint is attempted
/// before shutdown begins. The HTTPS path passes `--shutdown-grace-secs` to
/// the server's graceful shutdown; the plain-HTTP path does not use it.
///
/// Bare IPv6 hosts (e.g. `--host ::1`) are bracketed when building the listen
/// address. Previously the ambiguous `::1:9077` form was produced, and both
/// address parsing and binding reject it.
///
/// # Errors
///
/// Fails if bootstrapping, TLS configuration loading, address parsing, binding
/// or serving fails.
#[allow(clippy::too_many_lines)]
pub async fn serve(db_path: PathBuf, args: ServeArgs, app_config: &AppConfig) -> Result<()> {
    init_tracing();
    let bootstrap = bootstrap_serve(&db_path, &args, app_config).await?;
    let addr = format_listen_addr(&args.host, args.port);
    tracing::info!("database: {}", db_path.display());
    let shutdown_state = bootstrap.db_state.clone();
    // Resolves on Ctrl-C, after a best-effort WAL checkpoint under the DB lock.
    let shutdown = async move {
        let _ = tokio::signal::ctrl_c().await;
        tracing::info!("shutting down — checkpointing WAL");
        let lock = shutdown_state.lock().await;
        let _ = db::checkpoint(&lock.0);
    };
    if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) {
        // `install_default` errors if a process-wide provider is already
        // installed; that case is fine to ignore.
        let _ = rustls::crypto::ring::default_provider().install_default();
        let tls_config = if let Some(allowlist_path) = &args.mtls_allowlist {
            tracing::info!(
                "mTLS enabled — client certs required. Allowlist: {}",
                allowlist_path.display()
            );
            tls::load_mtls_rustls_config(cert, key, allowlist_path).await?
        } else {
            tracing::warn!(
                "TLS enabled but mTLS NOT configured — sync endpoints \
                 (/api/v1/sync/push, /api/v1/sync/since) accept any client. \
                 Set --mtls-allowlist for production peer-mesh deployments \
                 (red-team #231)."
            );
            tls::load_rustls_config(cert, key).await?
        };
        let app = build_router(bootstrap.app_state, bootstrap.api_key_state);
        tracing::info!("ai-memory listening on https://{addr}");
        let socket_addr: std::net::SocketAddr = addr.parse()?;
        let grace = std::time::Duration::from_secs(args.shutdown_grace_secs);
        let handle = axum_server::Handle::new();
        let handle_clone = handle.clone();
        tokio::spawn(async move {
            shutdown.await;
            handle_clone.graceful_shutdown(Some(grace));
        });
        axum_server::bind_rustls(socket_addr, tls_config)
            .handle(handle)
            .serve(app.into_make_service())
            .await?;
    } else {
        tracing::warn!(
            "TLS NOT enabled — sync endpoints (/api/v1/sync/push, \
             /api/v1/sync/since) accept any caller over plain HTTP. \
             Set --tls-cert + --tls-key + --mtls-allowlist for production \
             peer-mesh deployments (red-team #231)."
        );
        tracing::info!("ai-memory listening on http://{addr}");
        serve_http_with_shutdown_future(
            &addr,
            bootstrap.api_key_state,
            bootstrap.app_state,
            shutdown,
        )
        .await?;
    }
    Ok(())
}

/// Join `host` and `port` into a `host:port` listen address.
///
/// A bare IPv6 literal (any host containing `:` that is not already bracketed)
/// is wrapped in brackets, yielding e.g. `[::1]:9077`. That form is accepted by
/// both `SocketAddr` parsing and `TcpListener::bind`. Hostnames, IPv4 addresses
/// and bracketed IPv6 literals are used unchanged.
fn format_listen_addr(host: &str, port: u16) -> String {
    if host.contains(':') && !host.starts_with('[') {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    }
}
/// Run the built-in latency benchmark against a fresh in-memory database.
///
/// Inputs are clamped before use:
///
/// - `iterations` to `1..=100_000`;
/// - `warmup` to at most `10_000`;
/// - `regression_threshold` to `0.0..=1000.0` percent.
///
/// Results are printed as pretty JSON (`--json`) or as a table. They are
/// optionally compared against a `--baseline` file and appended to a
/// `--history` file.
///
/// # Errors
///
/// Fails if the benchmark or any file I/O fails. It also fails when any
/// operation exceeded its p95 budget by more than 10% (`Status::Fail`), or
/// regressed beyond the threshold relative to the baseline. When both happen,
/// a single combined error is reported.
fn cmd_bench(args: &BenchArgs) -> Result<()> {
    let iterations = args.iterations.clamp(1, 100_000);
    let warmup = args.warmup.min(10_000);
    let regression_threshold = args.regression_threshold.clamp(0.0, 1000.0);
    let conn = db::open(Path::new(":memory:"))?;
    let config = bench::BenchConfig {
        iterations,
        warmup,
        namespace: bench::BENCH_NAMESPACE.to_string(),
    };
    let results = bench::run(&conn, &config)?;
    let regressions = match &args.baseline {
        Some(path) => {
            let baseline = bench::load_baseline(Path::new(path))?;
            Some(bench::compare_against_baseline(
                &results,
                &baseline,
                regression_threshold,
            ))
        }
        None => None,
    };

    if args.json {
        println!(
            "{}",
            serde_json::to_string_pretty(&serde_json::json!({
                "iterations": iterations,
                "warmup": warmup,
                "results": results,
                "regressions": regressions,
            }))?
        );
    } else {
        print!("{}", bench::render_table(&results));
        if let Some(rows) = &regressions {
            println!();
            print!("{}", bench::render_regression_table(rows));
        }
    }

    if let Some(history_path) = &args.history {
        let captured_at = chrono::Utc::now().to_rfc3339();
        bench::append_history(history_path, &captured_at, iterations, warmup, &results)?;
        let mut stderr = std::io::stderr().lock();
        // Informational only; a failed write to stderr must not fail the run.
        let _ = writeln!(
            stderr,
            "bench: appended run to history file {}",
            history_path.display()
        );
    }

    let budget_failed = results
        .iter()
        .any(|r| matches!(r.status, bench::Status::Fail));
    let regression_failed = regressions
        .as_ref()
        .is_some_and(|rows| rows.iter().any(|r| r.regressed));
    match (budget_failed, regression_failed) {
        (true, true) => anyhow::bail!(
            "bench: at least one operation exceeded its p95 budget by >10% AND regressed >{regression_threshold:.1}% vs baseline"
        ),
        (true, false) => {
            anyhow::bail!("bench: at least one operation exceeded its p95 budget by >10%")
        }
        (false, true) => anyhow::bail!(
            "bench: at least one operation regressed >{regression_threshold:.1}% vs baseline"
        ),
        (false, false) => Ok(()),
    }
}
/// Copy memories from the store at `--from` into the store at `--to`.
///
/// Both stores are opened with `migrate::open_store`. `migrate::migrate` then
/// runs with `--batch`, `--namespace` and `--dry-run`. A report is printed
/// either as pretty JSON (`--json`) or as plain text, listing every error.
///
/// # Errors
///
/// Fails if either store cannot be opened, if JSON serialization fails, or if
/// the migration report contains any errors.
#[cfg(feature = "sal")]
async fn cmd_migrate(args: &MigrateArgs) -> Result<()> {
    let source = migrate::open_store(&args.from)
        .await
        .context("open source store")?;
    let destination = migrate::open_store(&args.to)
        .await
        .context("open destination store")?;
    let report = migrate::migrate(
        source.as_ref(),
        destination.as_ref(),
        args.batch,
        args.namespace.clone(),
        args.dry_run,
    )
    .await;

    if args.json {
        let summary = serde_json::json!({
            "from_url": args.from,
            "to_url": args.to,
            "memories_read": report.memories_read,
            "memories_written": report.memories_written,
            "batches": report.batches,
            "errors": report.errors,
            "dry_run": report.dry_run,
        });
        println!("{}", serde_json::to_string_pretty(&summary)?);
    } else {
        println!("migration report");
        println!(" from: {}", args.from);
        println!(" to: {}", args.to);
        println!(" memories_read: {}", report.memories_read);
        println!(" memories_written: {}", report.memories_written);
        println!(" batches: {}", report.batches);
        println!(" dry_run: {}", report.dry_run);
        println!(" errors: {}", report.errors.len());
        for failure in &report.errors {
            println!(" - {failure}");
        }
    }

    anyhow::ensure!(
        report.errors.is_empty(),
        "migration completed with {} error(s)",
        report.errors.len()
    );
    Ok(())
}
/// Serve the plain-HTTP API on `addr` until `shutdown` is notified.
///
/// Convenience wrapper over [`serve_http_with_shutdown_future`] that adapts a
/// shared `Notify` into the shutdown future.
///
/// # Errors
///
/// Propagates bind and serve errors from [`serve_http_with_shutdown_future`].
pub async fn serve_http_with_shutdown(
    addr: &str,
    api_key_state: ApiKeyState,
    app_state: AppState,
    shutdown: Arc<Notify>,
) -> Result<()> {
    let on_shutdown = async move { shutdown.notified().await };
    serve_http_with_shutdown_future(addr, api_key_state, app_state, on_shutdown).await
}
/// Bind `addr` and serve the API router over plain HTTP until `shutdown`
/// resolves. At that point axum's graceful shutdown begins.
///
/// # Errors
///
/// Fails with context `bind {addr}` if the listener cannot be bound, or
/// `axum::serve` if serving fails.
pub async fn serve_http_with_shutdown_future<F>(
    addr: &str,
    api_key_state: ApiKeyState,
    app_state: AppState,
    shutdown: F,
) -> Result<()>
where
    F: std::future::Future<Output = ()> + Send + 'static,
{
    let router = crate::build_router(api_key_state, app_state);
    let listener = tokio::net::TcpListener::bind(addr)
        .await
        .with_context(|| format!("bind {addr}"))?;
    axum::serve(listener, router)
        .with_graceful_shutdown(shutdown)
        .await
        .context("axum::serve")
}
/// Run one pull-then-push synchronization round with a single peer.
///
/// 1. Pull: request up to `batch_size` memories from
///    `{peer_url}/api/v1/sync/since`, resuming from the watermark stored for
///    this peer. Each pulled memory that passes validation is applied with
///    `db::insert_if_newer`. The peer watermark is then advanced to the
///    `updated_at` of the last pulled memory.
/// 2. Push: send up to `batch_size` local memories updated since the last
///    recorded push to `{peer_url}/api/v1/sync/push`. On success, record the
///    new push watermark.
///
/// `x-agent-id` is always sent, and `x-api-key` when `api_key` is given. A
/// fresh connection is opened for each database step; none is held across a
/// network await.
///
/// # Errors
///
/// Fails on database errors, HTTP transport or decoding errors, and non-success
/// HTTP status from either endpoint. Failures to apply individual pulled
/// memories are logged, not returned.
pub async fn sync_cycle_once(
    client: &reqwest::Client,
    db_path: &Path,
    local_agent_id: &str,
    peer_url: &str,
    api_key: Option<&str>,
    batch_size: usize,
) -> Result<()> {
    let peer_url = peer_url.trim_end_matches('/');

    // ---- Pull: fetch peer changes newer than our stored watermark. ----
    let since = {
        let conn = db::open(db_path)?;
        db::sync_state_load(&conn, local_agent_id)?
            .entries
            .get(peer_url)
            .cloned()
    };
    let mut pull_url = format!(
        "{peer_url}/api/v1/sync/since?limit={batch_size}&peer={}",
        urlencoding_minimal(local_agent_id)
    );
    if let Some(ref s) = since {
        pull_url.push_str("&since=");
        pull_url.push_str(&urlencoding_minimal(s));
    }
    let mut req = client.get(&pull_url).header("x-agent-id", local_agent_id);
    if let Some(key) = api_key {
        req = req.header("x-api-key", key);
    }
    let resp = req.send().await?;
    if !resp.status().is_success() {
        anyhow::bail!("sync-daemon: pull status {}", resp.status());
    }
    let pulled: SyncSinceResponse = resp.json().await?;
    let pull_count = pulled.memories.len();
    // NOTE(review): assumes the peer returns memories ordered by `updated_at`,
    // so the last one is the newest. Confirm against the `/sync/since` handler.
    let latest_pulled = pulled.memories.last().map(|m| m.updated_at.clone());
    {
        let conn = db::open(db_path)?;
        for mem in &pulled.memories {
            if crate::validate::validate_memory(mem).is_err() {
                continue;
            }
            // Apply failures are logged. The watermark below still advances
            // past them, so such a memory is not re-pulled on the next cycle.
            if let Err(e) = db::insert_if_newer(&conn, mem) {
                tracing::warn!("sync-daemon: peer={peer_url} failed to apply pulled memory: {e}");
            }
        }
        if let Some(ref at) = latest_pulled {
            db::sync_state_observe(&conn, local_agent_id, peer_url, at)?;
        }
    }

    // ---- Push: send local changes made since our last recorded push. ----
    let outgoing = {
        let conn = db::open(db_path)?;
        let last_pushed = db::sync_state_last_pushed(&conn, local_agent_id, peer_url);
        db::memories_updated_since(&conn, last_pushed.as_deref(), batch_size)?
    };
    let push_count = outgoing.len();
    let latest_pushed = outgoing.last().map(|m| m.updated_at.clone());
    if !outgoing.is_empty() {
        let body = serde_json::json!({
            "sender_agent_id": local_agent_id,
            "sender_clock": { "entries": {} },
            "memories": outgoing,
            "dry_run": false,
        });
        let mut req = client
            .post(format!("{peer_url}/api/v1/sync/push"))
            .header("x-agent-id", local_agent_id)
            .header("content-type", "application/json")
            .json(&body);
        if let Some(key) = api_key {
            req = req.header("x-api-key", key);
        }
        let resp = req.send().await?;
        if !resp.status().is_success() {
            anyhow::bail!("sync-daemon: push status {}", resp.status());
        }
        if let Some(at) = latest_pushed {
            let conn = db::open(db_path)?;
            db::sync_state_record_push(&conn, local_agent_id, peer_url, &at)?;
        }
    }
    tracing::info!("sync-daemon: peer={peer_url} pulled={pull_count} pushed={push_count}");
    Ok(())
}
/// Run the sync daemon with a default HTTP client until `shutdown` is notified.
///
/// The client uses a 30-second request timeout. For the loop semantics, see
/// [`run_sync_daemon_with_shutdown_using_client`].
///
/// # Errors
///
/// Fails if the HTTP client cannot be built, or with any error returned by the
/// daemon loop.
pub async fn run_sync_daemon_with_shutdown(
    db_path: PathBuf,
    local_agent_id: String,
    peers: Vec<String>,
    api_key: Option<String>,
    interval_secs: u64,
    batch_size: usize,
    shutdown: Arc<Notify>,
) -> Result<()> {
    let request_timeout = Duration::from_secs(30);
    let http_client = reqwest::Client::builder()
        .timeout(request_timeout)
        .build()?;
    run_sync_daemon_with_shutdown_using_client(
        http_client,
        db_path,
        local_agent_id,
        peers,
        api_key,
        interval_secs,
        batch_size,
        shutdown,
    )
    .await
}
/// Run the sync daemon loop with a caller-supplied HTTP client until
/// `shutdown` is notified.
///
/// Each cycle runs [`sync_cycle_once`] against all peers concurrently and waits
/// for every peer to finish. It then sleeps for `interval_secs` (at least 1s);
/// `batch_size` is likewise raised to at least 1. Per-peer failures and
/// panicked peer tasks are logged and do not stop the daemon.
///
/// Shutdown is observed between cycles; a cycle already in progress runs to
/// completion first. The `Notified` future is created once, before the first
/// cycle. Tokio guarantees that a `Notified` receives `notify_waiters()`
/// wake-ups from the moment it is created, so a `notify_waiters()` issued
/// mid-cycle is no longer missed. `notify_one()` works as well.
///
/// # Errors
///
/// Currently always returns `Ok(())` once shutdown is observed.
#[allow(clippy::too_many_arguments)]
pub async fn run_sync_daemon_with_shutdown_using_client(
    client: reqwest::Client,
    db_path: PathBuf,
    local_agent_id: String,
    peers: Vec<String>,
    api_key: Option<String>,
    interval_secs: u64,
    batch_size: usize,
    shutdown: Arc<Notify>,
) -> Result<()> {
    let interval = interval_secs.max(1);
    let batch_size = batch_size.max(1);
    // Cheaply clonable shared copies for the per-peer tasks.
    let db_path_owned: Arc<Path> = Arc::from(db_path.as_path());
    let local_agent_id_arc: Arc<str> = Arc::from(local_agent_id.as_str());
    let api_key_arc: Option<Arc<str>> = api_key.as_deref().map(Arc::from);
    let peers_arc: Vec<Arc<str>> = peers.iter().map(|s| Arc::from(s.as_str())).collect();
    let mut shutdown_signal = std::pin::pin!(shutdown.notified());
    loop {
        let mut set: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
        for peer_url in &peers_arc {
            let client = client.clone();
            let db_path = db_path_owned.clone();
            let local_agent_id = local_agent_id_arc.clone();
            let peer_url = peer_url.clone();
            let api_key = api_key_arc.clone();
            set.spawn(async move {
                if let Err(e) = sync_cycle_once(
                    &client,
                    &db_path,
                    &local_agent_id,
                    &peer_url,
                    api_key.as_deref(),
                    batch_size,
                )
                .await
                {
                    tracing::warn!("sync-daemon: peer {peer_url} cycle failed: {e}");
                }
            });
        }
        while let Some(joined) = set.join_next().await {
            if let Err(e) = joined {
                tracing::warn!("sync-daemon: peer task did not complete: {e}");
            }
        }
        tokio::select! {
            () = tokio::time::sleep(Duration::from_secs(interval)) => {}
            () = shutdown_signal.as_mut() => {
                tracing::info!("sync-daemon: shutdown signal received");
                return Ok(());
            }
        }
    }
}
pub async fn run_curator_daemon_with_shutdown(
db_path: PathBuf,
cfg: crate::curator::CuratorConfig,
shutdown: Arc<Notify>,
) -> Result<()> {
let shutdown_flag = Arc::new(AtomicBool::new(false));
let shutdown_flag_for_signal = shutdown_flag.clone();
tokio::spawn(async move {
shutdown.notified().await;
shutdown_flag_for_signal.store(true, Ordering::Relaxed);
});
let llm_arc: Option<Arc<crate::llm::OllamaClient>> = None;
let db_owned = db_path;
tokio::task::spawn_blocking(move || {
crate::curator::run_daemon(db_owned, llm_arc, cfg, shutdown_flag);
})
.await
.map_err(|e| anyhow::anyhow!("curator daemon join: {e}"))?;
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub async fn run_curator_daemon_with_primitives(
db_path: PathBuf,
interval_secs: u64,
max_ops_per_cycle: usize,
dry_run: bool,
include_namespaces: Vec<String>,
exclude_namespaces: Vec<String>,
ollama_model: Option<String>,
shutdown: Arc<Notify>,
) -> Result<()> {
let cfg = crate::curator::CuratorConfig {
interval_secs,
max_ops_per_cycle,
dry_run,
include_namespaces,
exclude_namespaces,
};
let llm: Option<Arc<crate::llm::OllamaClient>> =
ollama_model.and_then(|m| crate::llm::OllamaClient::new(&m).ok().map(Arc::new));
let shutdown_flag = Arc::new(AtomicBool::new(false));
let shutdown_flag_for_signal = shutdown_flag.clone();
tokio::spawn(async move {
shutdown.notified().await;
shutdown_flag_for_signal.store(true, Ordering::Relaxed);
});
tokio::task::spawn_blocking(move || {
crate::curator::run_daemon(db_path, llm, cfg, shutdown_flag);
})
.await
.map_err(|e| anyhow::anyhow!("curator daemon join: {e}"))?;
Ok(())
}
/// Percent-encodes every byte of `s` outside the RFC 3986 unreserved set
/// (`A-Z a-z 0-9 - _ . ~`) as `%XX` with uppercase hex digits.
///
/// It works byte by byte, so a multi-byte UTF-8 character becomes one `%XX`
/// per byte. Unreserved bytes pass through unchanged.
fn urlencoding_minimal(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
/// Page of memories deserialized from a peer during a sync pull.
///
/// NOTE(review): presumably the body returned by the peer's
/// `/api/v1/sync/since` endpoint; confirm against the caller. The field
/// names are wire keys and must match the peer's response.
#[derive(serde::Deserialize)]
struct SyncSinceResponse {
// Required in the payload: serde rejects a body that lacks it, because the
// field is neither `Option` nor `#[serde(default)]`. It is not read locally,
// hence the allow.
#[allow(dead_code)]
count: usize,
// Same as `count`: required on the wire, unused here.
#[allow(dead_code)]
limit: usize,
/// The memories carried by this page.
memories: Vec<crate::models::Memory>,
}
// Never called. It exists only so the `Instant` and `Duration` imports count
// as used, whichever features or cfgs are enabled; the allow silences the
// resulting dead-code lint.
#[allow(dead_code)]
fn _imports_in_use(_instant: Instant, _duration: Duration) {}
#[cfg(test)]
mod tests {
    // Unit and dispatch tests for this module. Tests that read or mutate
    // process environment variables serialize on `env_var_lock`.
    use super::*;
    use crate::cli::test_utils::TestEnv;
    use crate::config::ResolvedTtl;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt as _;

    /// `ServeArgs` for `127.0.0.1` on port 0: no TLS/mTLS files,
    /// `quorum_writes = 0`, no quorum peers, `catchup_interval_secs = 0`.
    /// The path argument is currently unused.
    fn args_with_db(_db: &Path) -> ServeArgs {
        ServeArgs {
            host: "127.0.0.1".to_string(),
            port: 0,
            tls_cert: None,
            tls_key: None,
            mtls_allowlist: None,
            shutdown_grace_secs: 30,
            quorum_writes: 0,
            quorum_peers: vec![],
            quorum_timeout_ms: 2000,
            quorum_client_cert: None,
            quorum_client_key: None,
            quorum_ca_cert: None,
            catchup_interval_secs: 0,
        }
    }

    /// Opens `db_path` and wraps the connection in the shared `Db` tuple:
    /// (connection, path, default `ResolvedTtl`, `true`).
    fn db_state(db_path: &Path) -> Db {
        let conn = db::open(db_path).unwrap();
        Arc::new(Mutex::new((
            conn,
            db_path.to_path_buf(),
            ResolvedTtl::default(),
            true,
        )))
    }

    /// Keyword-tier `AppState` over `db_path`, with no embedder, vector
    /// index, or federation.
    fn keyword_app_state(db_path: &Path) -> AppState {
        AppState {
            db: db_state(db_path),
            embedder: Arc::new(None),
            vector_index: Arc::new(Mutex::new(None)),
            federation: Arc::new(None),
            tier_config: Arc::new(FeatureTier::Keyword.config()),
            scoring: Arc::new(crate::config::ResolvedScoring::default()),
        }
    }

    /// Empty-bodied `GET` request for `uri`, for driving a router via `oneshot`.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// Parses `ai-memory --db <db> <rest...>`, panicking on a parse error.
    fn cli_with_db(db: &Path, rest: &[&str]) -> Cli {
        let mut argv = vec!["ai-memory", "--db", db.to_str().unwrap()];
        argv.extend_from_slice(rest);
        Cli::try_parse_from(argv).unwrap()
    }

    /// Dispatches `ai-memory --db <db> <rest...>` with a default config and
    /// asserts that `run` succeeds.
    async fn run_ok(db: &Path, rest: &[&str]) {
        let cfg = AppConfig::default();
        run(cli_with_db(db, rest), &cfg).await.unwrap();
    }

    /// A memory with a random id, `created_at`/`updated_at` = now, priority 5,
    /// confidence 1.0, no tags, and no expiry.
    fn sample_memory(
        tier: crate::models::Tier,
        namespace: &str,
        title: &str,
        content: &str,
    ) -> crate::models::Memory {
        let now = chrono::Utc::now().to_rfc3339();
        crate::models::Memory {
            id: uuid::Uuid::new_v4().to_string(),
            tier,
            namespace: namespace.to_string(),
            title: title.to_string(),
            content: content.to_string(),
            tags: vec![],
            priority: 5,
            confidence: 1.0,
            source: "test".to_string(),
            access_count: 0,
            created_at: now.clone(),
            updated_at: now,
            last_accessed_at: None,
            expires_at: None,
            metadata: crate::models::default_metadata(),
        }
    }

    /// Process-wide lock serializing tests that touch environment variables.
    /// A poisoned lock is recovered so one failing test does not cascade.
    fn env_var_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
        LOCK.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Guard taken by dispatch tests; currently just `env_var_lock`.
    fn no_config_env() -> std::sync::MutexGuard<'static, ()> {
        env_var_lock()
    }

    #[test]
    fn test_is_write_command_all_variants() {
        let writes: &[&[&str]] = &[
            &["ai-memory", "store", "title", "content"],
            &["ai-memory", "update", "id123", "--title", "t"],
            &["ai-memory", "delete", "id123"],
            &["ai-memory", "promote", "id123"],
            &["ai-memory", "forget", "pattern"],
            &["ai-memory", "link", "a", "b"],
            &["ai-memory", "consolidate", "ids"],
            &["ai-memory", "resolve", "a", "b"],
            &["ai-memory", "sync", "--peer", "/tmp/peer.db"],
            &[
                "ai-memory",
                "sync-daemon",
                "--peers",
                "http://x",
                "--interval-secs",
                "60",
            ],
            &["ai-memory", "import"],
            &["ai-memory", "auto-consolidate"],
            &["ai-memory", "gc"],
        ];
        // Argument vectors that do not parse (e.g. required flags missing) are
        // skipped; the minimum-count assertion below keeps coverage honest.
        let mut writes_checked = 0;
        for argv in writes {
            if let Ok(cli) = Cli::try_parse_from(*argv) {
                assert!(
                    is_write_command(&cli.command),
                    "expected write for {argv:?}"
                );
                writes_checked += 1;
            }
        }
        assert!(
            writes_checked >= 5,
            "expected at least 5 write variants checked, got {writes_checked}"
        );

        let reads: &[&[&str]] = &[
            &["ai-memory", "mcp"],
            &["ai-memory", "recall", "context"],
            &["ai-memory", "search", "query"],
            &["ai-memory", "get", "id"],
            &["ai-memory", "list"],
            &["ai-memory", "stats"],
            &["ai-memory", "namespaces"],
            &["ai-memory", "export"],
            &["ai-memory", "shell"],
            &["ai-memory", "man"],
            &["ai-memory", "completions", "bash"],
            &["ai-memory", "archive", "list"],
            &["ai-memory", "agents", "list"],
            &["ai-memory", "pending", "list"],
            &["ai-memory", "bench"],
            &["ai-memory", "serve", "--host", "127.0.0.1", "--port", "0"],
        ];
        let mut reads_checked = 0;
        for argv in reads {
            if let Ok(cli) = Cli::try_parse_from(*argv) {
                assert!(
                    !is_write_command(&cli.command),
                    "expected read for {argv:?}"
                );
                reads_checked += 1;
            }
        }
        assert!(
            reads_checked >= 8,
            "expected at least 8 read variants checked, got {reads_checked}"
        );

        assert!(is_write_command(&Command::Gc));
        assert!(!is_write_command(&Command::Stats));
        assert!(!is_write_command(&Command::Namespaces));
        assert!(!is_write_command(&Command::Export));
        assert!(!is_write_command(&Command::Shell));
        assert!(!is_write_command(&Command::Man));
        assert!(!is_write_command(&Command::Mcp {
            tier: "keyword".to_string(),
            profile: None,
        }));
    }

    #[tokio::test]
    async fn test_router_has_health_endpoint() {
        let env = TestEnv::fresh();
        let router = build_router(keyword_app_state(&env.db_path), ApiKeyState { key: None });
        let resp = router.oneshot(get_request("/api/v1/health")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_router_has_metrics_at_both_paths() {
        let env = TestEnv::fresh();
        let app_state = keyword_app_state(&env.db_path);
        let api_key_state = ApiKeyState { key: None };
        let r1 = build_router(app_state.clone(), api_key_state.clone())
            .oneshot(get_request("/metrics"))
            .await
            .unwrap();
        assert_eq!(r1.status(), StatusCode::OK);
        let r2 = build_router(app_state, api_key_state)
            .oneshot(get_request("/api/v1/metrics"))
            .await
            .unwrap();
        assert_eq!(r2.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_router_lists_all_v1_memory_routes() {
        let env = TestEnv::fresh();
        let router = build_router(keyword_app_state(&env.db_path), ApiKeyState { key: None });
        let resp = router.oneshot(get_request("/api/v1/memories")).await.unwrap();
        assert!(resp.status().is_success(), "got {}", resp.status());
    }

    #[tokio::test]
    async fn test_router_applies_api_key_middleware_when_key_set() {
        let env = TestEnv::fresh();
        let api_key_state = ApiKeyState {
            key: Some("s3cret".to_string()),
        };
        let router = build_router(keyword_app_state(&env.db_path), api_key_state);
        let resp = router.oneshot(get_request("/api/v1/memories")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_router_skips_api_key_middleware_when_key_none() {
        let env = TestEnv::fresh();
        let router = build_router(keyword_app_state(&env.db_path), ApiKeyState { key: None });
        let resp = router.oneshot(get_request("/api/v1/memories")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_build_embedder_keyword_tier_returns_none() {
        let cfg = AppConfig::default();
        let emb = build_embedder(FeatureTier::Keyword, &cfg).await;
        assert!(emb.is_none());
    }

    // The body was empty, so the test passed without checking anything.
    // It is ignored until it can force a model-load failure and assert
    // that `build_embedder` returns `None`.
    #[tokio::test]
    #[ignore = "TODO: no way yet to force an embedder load failure; test body is empty"]
    async fn test_build_embedder_load_failure_returns_none() {}

    #[test]
    fn test_build_vector_index_no_embedder_returns_none() {
        let env = TestEnv::fresh();
        let conn = db::open(&env.db_path).unwrap();
        assert!(build_vector_index(&conn, false).is_none());
    }

    #[test]
    fn test_build_vector_index_empty_db_returns_empty_index() {
        let env = TestEnv::fresh();
        let conn = db::open(&env.db_path).unwrap();
        let idx = build_vector_index(&conn, true);
        assert!(
            idx.is_some(),
            "empty DB with embedder must yield empty index"
        );
        assert_eq!(idx.unwrap().len(), 0);
    }

    #[tokio::test(start_paused = true)]
    async fn test_spawn_gc_loop_runs_and_can_be_aborted() {
        let env = TestEnv::fresh();
        let h = spawn_gc_loop(db_state(&env.db_path), None, Duration::from_secs(60));
        tokio::time::advance(Duration::from_secs(61)).await;
        tokio::task::yield_now().await;
        h.abort();
        let err = h.await.unwrap_err();
        assert!(err.is_cancelled());
    }

    #[tokio::test(start_paused = true)]
    async fn test_spawn_wal_checkpoint_loop_runs_and_can_be_aborted() {
        let env = TestEnv::fresh();
        let h = spawn_wal_checkpoint_loop(db_state(&env.db_path), Duration::from_secs(60));
        tokio::time::advance(Duration::from_secs(31)).await;
        tokio::task::yield_now().await;
        tokio::time::advance(Duration::from_secs(60)).await;
        tokio::task::yield_now().await;
        h.abort();
        let err = h.await.unwrap_err();
        assert!(err.is_cancelled());
    }

    #[test]
    fn test_passphrase_strips_trailing_newline() {
        let dir = tempfile::tempdir().unwrap();
        let p = dir.path().join("pass");
        std::fs::write(&p, "secret\n").unwrap();
        assert_eq!(passphrase_from_file(&p).unwrap(), "secret");
    }

    #[test]
    fn test_passphrase_strips_trailing_crlf() {
        let dir = tempfile::tempdir().unwrap();
        let p = dir.path().join("pass");
        std::fs::write(&p, "secret\r\n").unwrap();
        assert_eq!(passphrase_from_file(&p).unwrap(), "secret");
    }

    #[test]
    fn test_passphrase_empty_file_errors() {
        let dir = tempfile::tempdir().unwrap();
        let p = dir.path().join("empty");
        std::fs::write(&p, "").unwrap();
        let err = passphrase_from_file(&p).unwrap_err();
        assert!(
            err.to_string().contains("empty"),
            "expected 'empty' error, got: {err}"
        );
    }

    #[test]
    fn test_passphrase_empty_after_trim_errors() {
        let dir = tempfile::tempdir().unwrap();
        let p = dir.path().join("nl-only");
        std::fs::write(&p, "\n").unwrap();
        let err = passphrase_from_file(&p).unwrap_err();
        assert!(err.to_string().contains("empty"));
    }

    #[test]
    fn test_passphrase_nonexistent_file_errors() {
        let dir = tempfile::tempdir().unwrap();
        let p = dir.path().join("does-not-exist");
        let err = passphrase_from_file(&p).unwrap_err();
        assert!(
            err.to_string().contains("reading passphrase file")
                || err.chain().any(|e| e.to_string().contains("No such file"))
                || err.chain().any(|e| e.to_string().contains("cannot find")),
            "got: {err:#}"
        );
    }

    #[test]
    fn test_passphrase_preserves_internal_whitespace() {
        let dir = tempfile::tempdir().unwrap();
        let p = dir.path().join("pass");
        std::fs::write(&p, "my pass phrase\n").unwrap();
        assert_eq!(passphrase_from_file(&p).unwrap(), "my pass phrase");
    }

    #[test]
    fn test_anonymize_set_when_config_true_and_env_unset() {
        let _g = env_var_lock();
        // SAFETY: env mutation is serialized with the other env-touching tests
        // by `_g`. NOTE(review): tests that do not take the lock could still
        // read the environment concurrently.
        unsafe { std::env::remove_var("AI_MEMORY_ANONYMIZE") };
        let mut cfg = AppConfig::default();
        cfg.identity = Some(crate::config::IdentityConfig {
            anonymize_default: true,
        });
        apply_anonymize_default(&cfg);
        assert_eq!(std::env::var("AI_MEMORY_ANONYMIZE").unwrap(), "1");
        // SAFETY: still holding `_g`; see above.
        unsafe { std::env::remove_var("AI_MEMORY_ANONYMIZE") };
    }

    #[test]
    fn test_anonymize_unchanged_when_env_already_set() {
        let _g = env_var_lock();
        // SAFETY: env mutation is serialized by `_g`.
        unsafe { std::env::set_var("AI_MEMORY_ANONYMIZE", "0") };
        let mut cfg = AppConfig::default();
        cfg.identity = Some(crate::config::IdentityConfig {
            anonymize_default: true,
        });
        apply_anonymize_default(&cfg);
        assert_eq!(std::env::var("AI_MEMORY_ANONYMIZE").unwrap(), "0");
        // SAFETY: still holding `_g`.
        unsafe { std::env::remove_var("AI_MEMORY_ANONYMIZE") };
    }

    #[test]
    fn test_anonymize_unchanged_when_config_false() {
        let _g = env_var_lock();
        // SAFETY: env mutation is serialized by `_g`.
        unsafe { std::env::remove_var("AI_MEMORY_ANONYMIZE") };
        let cfg = AppConfig::default();
        apply_anonymize_default(&cfg);
        assert!(std::env::var("AI_MEMORY_ANONYMIZE").is_err());
    }

    #[tokio::test]
    async fn test_bootstrap_serve_keyword_tier_no_embedder() {
        let env = TestEnv::fresh();
        let mut cfg = AppConfig::default();
        cfg.tier = Some("keyword".to_string());
        let args = args_with_db(&env.db_path);
        let bs = bootstrap_serve(&env.db_path, &args, &cfg).await.unwrap();
        assert!(bs.app_state.embedder.is_none());
        let vi = bs.app_state.vector_index.lock().await;
        assert!(vi.is_none());
        assert_eq!(bs.task_handles.len(), 2);
        for h in bs.task_handles {
            h.abort();
        }
    }

    #[tokio::test]
    async fn test_bootstrap_serve_with_api_key_logs_enabled() {
        let env = TestEnv::fresh();
        let mut cfg = AppConfig::default();
        cfg.tier = Some("keyword".to_string());
        cfg.api_key = Some("test-key".to_string());
        let args = args_with_db(&env.db_path);
        let bs = bootstrap_serve(&env.db_path, &args, &cfg).await.unwrap();
        assert_eq!(bs.api_key_state.key.as_deref(), Some("test-key"));
        for h in bs.task_handles {
            h.abort();
        }
    }

    #[tokio::test]
    async fn test_bootstrap_serve_federation_disabled_when_quorum_zero() {
        let env = TestEnv::fresh();
        let mut cfg = AppConfig::default();
        cfg.tier = Some("keyword".to_string());
        let args = args_with_db(&env.db_path);
        let bs = bootstrap_serve(&env.db_path, &args, &cfg).await.unwrap();
        assert!(bs.app_state.federation.is_none());
        for h in bs.task_handles {
            h.abort();
        }
    }

    #[tokio::test]
    async fn test_bootstrap_serve_federation_enabled_attaches_config() {
        let env = TestEnv::fresh();
        let mut cfg = AppConfig::default();
        cfg.tier = Some("keyword".to_string());
        let mut args = args_with_db(&env.db_path);
        args.quorum_writes = 1;
        args.quorum_peers = vec!["http://127.0.0.1:65530".to_string()];
        args.quorum_timeout_ms = 100;
        args.catchup_interval_secs = 0;
        let bs = bootstrap_serve(&env.db_path, &args, &cfg).await.unwrap();
        assert!(bs.app_state.federation.is_some());
        for h in bs.task_handles {
            h.abort();
        }
    }

    #[tokio::test]
    async fn test_bootstrap_serve_federation_enabled_with_catchup_loop() {
        let env = TestEnv::fresh();
        let mut cfg = AppConfig::default();
        cfg.tier = Some("keyword".to_string());
        let mut args = args_with_db(&env.db_path);
        args.quorum_writes = 1;
        args.quorum_peers = vec!["http://127.0.0.1:65531".to_string()];
        args.quorum_timeout_ms = 100;
        args.catchup_interval_secs = 3600;
        let bs = bootstrap_serve(&env.db_path, &args, &cfg).await.unwrap();
        assert!(bs.app_state.federation.is_some());
        for h in bs.task_handles {
            h.abort();
        }
    }

    #[tokio::test]
    async fn test_bootstrap_serve_federation_invalid_peer_errors() {
        let env = TestEnv::fresh();
        let mut cfg = AppConfig::default();
        cfg.tier = Some("keyword".to_string());
        let mut args = args_with_db(&env.db_path);
        args.quorum_writes = 1;
        // The same peer twice, differing only by a trailing slash.
        args.quorum_peers = vec![
            "http://127.0.0.1:65532".to_string(),
            "http://127.0.0.1:65532/".to_string(),
        ];
        let err = match bootstrap_serve(&env.db_path, &args, &cfg).await {
            Ok(_) => panic!("expected error from duplicate peer URLs"),
            Err(e) => e,
        };
        let s = format!("{err:#}");
        assert!(
            s.contains("federation") || s.contains("duplicate"),
            "got: {s}"
        );
    }

    #[test]
    fn test_build_vector_index_populated_db_returns_built_index() {
        let env = TestEnv::fresh();
        let conn = db::open(&env.db_path).unwrap();
        let mem = sample_memory(crate::models::Tier::Mid, "ns", "t", "c");
        let id = db::insert(&conn, &mem).unwrap();
        db::set_embedding(&conn, &id, &[1.0, 0.0, 0.0]).unwrap();
        let idx = build_vector_index(&conn, true).expect("populated index");
        assert!(
            idx.len() >= 1,
            "expected non-empty index, got len={}",
            idx.len()
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_spawn_gc_loop_purges_expired_memories() {
        let env = TestEnv::fresh();
        let conn = db::open(&env.db_path).unwrap();
        let past = (chrono::Utc::now() - chrono::Duration::days(1)).to_rfc3339();
        let mut mem = sample_memory(crate::models::Tier::Short, "ns-gc", "stale", "stale");
        mem.priority = 1;
        mem.expires_at = Some(past);
        db::insert(&conn, &mem).unwrap();
        drop(conn);
        let state = db_state(&env.db_path);
        let h = spawn_gc_loop(state.clone(), Some(1), Duration::from_secs(60));
        tokio::time::advance(Duration::from_secs(61)).await;
        tokio::task::yield_now().await;
        tokio::time::advance(Duration::from_secs(61)).await;
        tokio::task::yield_now().await;
        h.abort();
        // NOTE(review): this only exercises the loop; it does not yet assert
        // that the expired row was actually purged.
        let _ = h.await;
    }

    #[tokio::test(start_paused = true)]
    async fn test_spawn_wal_checkpoint_loop_runs_multiple_cycles() {
        let env = TestEnv::fresh();
        let h = spawn_wal_checkpoint_loop(db_state(&env.db_path), Duration::from_secs(2));
        for _ in 0..4 {
            tokio::time::advance(Duration::from_secs(2)).await;
            tokio::task::yield_now().await;
        }
        h.abort();
        let _ = h.await;
    }

    #[test]
    fn test_urlencoding_minimal_round_trip() {
        assert_eq!(urlencoding_minimal("abcXYZ-_.~"), "abcXYZ-_.~");
        assert_eq!(urlencoding_minimal("0123456789"), "0123456789");
        assert_eq!(urlencoding_minimal("a:b"), "a%3Ab");
        assert_eq!(urlencoding_minimal("a/b"), "a%2Fb");
        assert_eq!(urlencoding_minimal("a@b"), "a%40b");
        assert_eq!(urlencoding_minimal("a+b"), "a%2Bb");
        assert_eq!(urlencoding_minimal(" "), "%20");
        assert_eq!(urlencoding_minimal(""), "");
        assert_eq!(
            urlencoding_minimal("2024-01-02T03:04:05+00:00"),
            "2024-01-02T03%3A04%3A05%2B00%3A00"
        );
    }

    #[tokio::test]
    async fn test_run_dispatch_stats_command() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        run_ok(&env.db_path, &["stats"]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_namespaces_command() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        run_ok(&env.db_path, &["namespaces"]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_export_command() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        run_ok(&env.db_path, &["export"]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_list_command() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        run_ok(&env.db_path, &["list"]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_search_command() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        run_ok(&env.db_path, &["search", "anyq"]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_archive_list_command() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        run_ok(&env.db_path, &["archive", "list"]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_agents_list_command() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        run_ok(&env.db_path, &["agents", "list"]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_pending_list_command() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        run_ok(&env.db_path, &["pending", "list"]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_completions_command() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        run_ok(&env.db_path, &["completions", "bash"]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_man_command() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        run_ok(&env.db_path, &["man"]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_gc_triggers_post_run_checkpoint() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        run_ok(&env.db_path, &["gc"]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_resolve_command() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        let id_a = crate::cli::test_utils::seed_memory(&env.db_path, "ns", "old", "old fact");
        let id_b = crate::cli::test_utils::seed_memory(&env.db_path, "ns", "new", "new fact");
        run_ok(&env.db_path, &["resolve", &id_a, &id_b]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_get_command() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        let id = crate::cli::test_utils::seed_memory(&env.db_path, "ns", "t", "c");
        run_ok(&env.db_path, &["get", &id]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_promote_triggers_write_checkpoint() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        let id = crate::cli::test_utils::seed_memory(&env.db_path, "ns", "t", "c");
        run_ok(&env.db_path, &["promote", &id]).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_bench_smoke_runs_one_iteration() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        let cfg = AppConfig::default();
        let cli = cli_with_db(
            &env.db_path,
            &["bench", "--iterations", "1", "--warmup", "0"],
        );
        // Outcome deliberately not asserted: this only exercises dispatch.
        let _ = run(cli, &cfg).await;
    }

    #[tokio::test]
    async fn test_run_dispatch_bench_json_with_history() {
        let _g = no_config_env();
        let env = TestEnv::fresh();
        let history = env.db_path.with_file_name("hist.jsonl");
        let cfg = AppConfig::default();
        let cli = cli_with_db(
            &env.db_path,
            &[
                "bench",
                "--iterations",
                "1",
                "--warmup",
                "0",
                "--json",
                "--history",
                history.to_str().unwrap(),
            ],
        );
        // Outcome deliberately not asserted: this only exercises dispatch.
        let _ = run(cli, &cfg).await;
        if history.exists() {
            let content = std::fs::read_to_string(&history).unwrap();
            // Equivalent to the former `contains("captured_at") || !is_empty()`:
            // any content containing the key is already non-empty.
            assert!(!content.is_empty(), "bench history file exists but is empty");
        }
    }

    #[cfg(feature = "sal")]
    #[tokio::test]
    async fn test_run_dispatch_migrate_sqlite_to_sqlite_dry_run() {
        let _g = no_config_env();
        let src_env = TestEnv::fresh();
        let dst_env = TestEnv::fresh();
        crate::cli::test_utils::seed_memory(&src_env.db_path, "ns-mig", "t", "c");
        let from = format!("sqlite://{}", src_env.db_path.display());
        let to = format!("sqlite://{}", dst_env.db_path.display());
        run_ok(
            &src_env.db_path,
            &["migrate", "--from", &from, "--to", &to, "--dry-run"],
        )
        .await;
    }

    #[cfg(feature = "sal")]
    #[tokio::test]
    async fn test_run_dispatch_migrate_json_output() {
        let _g = no_config_env();
        let src_env = TestEnv::fresh();
        let dst_env = TestEnv::fresh();
        crate::cli::test_utils::seed_memory(&src_env.db_path, "ns-mig", "t", "c");
        let from = format!("sqlite://{}", src_env.db_path.display());
        let to = format!("sqlite://{}", dst_env.db_path.display());
        run_ok(
            &src_env.db_path,
            &["migrate", "--from", &from, "--to", &to, "--json"],
        )
        .await;
    }

    #[tokio::test]
    async fn test_run_with_db_passphrase_file_exports_env() {
        let _g = env_var_lock();
        // SAFETY: env mutation is serialized by `_g`.
        unsafe { std::env::remove_var("AI_MEMORY_DB_PASSPHRASE") };
        let env = TestEnv::fresh();
        let pass_path = env.db_path.with_file_name("pass");
        std::fs::write(&pass_path, "test-passphrase\n").unwrap();
        run_ok(
            &env.db_path,
            &["--db-passphrase-file", pass_path.to_str().unwrap(), "stats"],
        )
        .await;
        assert_eq!(
            std::env::var("AI_MEMORY_DB_PASSPHRASE").unwrap(),
            "test-passphrase"
        );
        // SAFETY: still holding `_g`.
        unsafe { std::env::remove_var("AI_MEMORY_DB_PASSPHRASE") };
    }

    #[test]
    fn test_init_tracing_is_idempotent() {
        init_tracing();
        init_tracing();
    }

    #[tokio::test]
    async fn test_serve_http_with_shutdown_future_serves_then_stops() {
        let env = TestEnv::fresh();
        let app_state = keyword_app_state(&env.db_path);
        let api_key_state = ApiKeyState { key: None };
        // Reserve an ephemeral port, then release it for the server to bind.
        let port = {
            let l = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
            let p = l.local_addr().unwrap().port();
            drop(l);
            p
        };
        let addr = format!("127.0.0.1:{port}");
        let shutdown = Arc::new(Notify::new());
        let shutdown_clone = shutdown.clone();
        let handle = tokio::spawn(async move {
            serve_http_with_shutdown_future(&addr, api_key_state, app_state, async move {
                shutdown_clone.notified().await;
            })
            .await
        });
        // Poll the health endpoint (up to ~2 s) until the server is accepting.
        for _ in 0..40 {
            if let Ok(client) = reqwest::Client::builder()
                .timeout(Duration::from_millis(200))
                .build()
                && client
                    .get(format!("http://127.0.0.1:{port}/api/v1/health"))
                    .send()
                    .await
                    .is_ok()
            {
                break;
            }
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
        shutdown.notify_one();
        let res = handle.await.unwrap();
        assert!(res.is_ok(), "serve future returned: {res:?}");
    }

    #[tokio::test]
    async fn test_serve_http_with_shutdown_future_bind_failure_errors() {
        let env = TestEnv::fresh();
        let app_state = keyword_app_state(&env.db_path);
        let api_key_state = ApiKeyState { key: None };
        let res = serve_http_with_shutdown_future(
            "definitely-not-an-address:99999",
            api_key_state,
            app_state,
            async {},
        )
        .await;
        assert!(res.is_err(), "expected bind error, got: {res:?}");
    }
}