use std::{path::PathBuf, sync::Arc};
use anyhow::Context;
use clap::{Args, Parser, Subcommand};
use mcp_session::{BoundedSessionManager, SessionConfig};
use rmcp::transport::streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService};
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use memory_mcp::auth::{self, AuthProvider, StoreBackend};
use memory_mcp::embedding::{CandleEmbeddingEngine, EmbeddingBackend, MODEL_ID};
use memory_mcp::index::ScopedIndex;
use memory_mcp::repo::MemoryRepo;
use memory_mcp::server::MemoryServer;
use memory_mcp::types::{validate_branch_name, AppState};
// Top-level command-line interface. Plain `//` comments are used on purpose
// here: clap turns `///` doc comments into help text, and `about` is already
// set explicitly in the attribute below.
#[derive(Parser)]
#[command(name = "memory-mcp", about = "Semantic memory MCP server for AI agents")]
struct Cli {
    // `None` for a bare `memory-mcp` invocation; `main` then falls back to
    // parsing `memory-mcp serve`.
    #[command(subcommand)]
    command: Option<Command>,
}
// Top-level subcommands. The `///` comment on each variant is rendered by clap
// as that subcommand's help text in `memory-mcp --help`.
#[derive(Subcommand)]
enum Command {
    /// Run the MCP server over streamable HTTP (the default when no subcommand is given)
    Serve(ServeArgs),
    /// Manage authentication (log in, show status)
    Auth(AuthCommand),
    /// Initialise the embedding model and run a single test embedding
    Warmup(WarmupArgs),
}
// Arguments of `memory-mcp auth`: nothing but the nested action to run.
#[derive(Args)]
struct AuthCommand {
    // The requested `auth` action (`login` or `status`).
    #[command(subcommand)]
    action: AuthAction,
}
// Actions available under `memory-mcp auth`. Variant `///` comments become the
// help text shown by clap.
#[derive(Subcommand)]
enum AuthAction {
    /// Authenticate via the device flow and store the resulting token
    Login(LoginArgs),
    /// Show the current authentication status
    Status,
}
// Options for `memory-mcp auth login`. Field `///` comments become clap help text.
#[derive(Args)]
struct LoginArgs {
    /// Credential store backend to save the token in
    #[arg(long, value_enum)]
    store: Option<StoreBackend>,
    /// Kubernetes namespace of the token secret (used with --store k8s-secret)
    #[cfg(feature = "k8s")]
    #[arg(long, default_value = "memory-mcp")]
    k8s_namespace: String,
    /// Name of the Kubernetes secret holding the token (used with --store k8s-secret)
    #[cfg(feature = "k8s")]
    #[arg(long, default_value = "memory-mcp-github-token")]
    k8s_secret_name: String,
}
// Options for `memory-mcp serve`. Each option can also be supplied through the
// environment variable named in its `env = ...` attribute; field `///` comments
// become clap help text.
#[derive(Args)]
struct ServeArgs {
    /// Socket address the HTTP server listens on
    #[arg(long, default_value = "127.0.0.1:8080", env = "MEMORY_MCP_BIND")]
    bind: String,
    /// Path of the memory repository, opened or initialised at startup
    /// (a leading ~ expands to the home directory)
    #[arg(long, default_value = "~/.memory-mcp", env = "MEMORY_MCP_REPO_PATH")]
    repo_path: String,
    /// HTTP path at which the MCP endpoint is mounted (must start with '/')
    #[arg(long, default_value = "/mcp", env = "MEMORY_MCP_PATH")]
    mcp_path: String,
    /// Remote URL for the memory repository (an empty value means no remote)
    #[arg(long, env = "MEMORY_MCP_REMOTE_URL")]
    remote_url: Option<String>,
    /// Repository branch to use
    #[arg(long, default_value = "main", env = "MEMORY_MCP_BRANCH")]
    branch: String,
    /// Maximum number of MCP sessions (must be at least 1)
    #[arg(
        long,
        default_value_t = 100,
        env = "MEMORY_MCP_MAX_SESSIONS",
        value_parser = parse_nonzero_usize
    )]
    max_sessions: usize,
    /// Sessions allowed per rate-limit window (0 disables session rate limiting)
    #[arg(long, default_value_t = 10, env = "MEMORY_MCP_SESSION_RATE_LIMIT")]
    session_rate_limit: usize,
    /// Length of the session rate-limit window in seconds (0 disables session rate limiting)
    #[arg(
        long,
        default_value_t = 60,
        env = "MEMORY_MCP_SESSION_RATE_WINDOW_SECS"
    )]
    session_rate_window_secs: u64,
}
// Options for `memory-mcp warmup`. The subcommand currently takes none; the
// empty struct keeps its shape uniform with the other subcommands.
#[derive(Args)]
struct WarmupArgs {}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
#[cfg(unix)]
{
unsafe {
libc::umask(0o077);
}
}
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "info".to_string().into()),
)
.with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr))
.init();
let cli = Cli::parse();
match cli.command {
None => {
let cli = Cli::parse_from(["memory-mcp", "serve"]);
match cli.command {
Some(Command::Serve(args)) => run_serve(args).await?,
_ => unreachable!(),
}
}
Some(Command::Serve(args)) => run_serve(args).await?,
Some(Command::Warmup(args)) => run_warmup(args).await?,
Some(Command::Auth(auth_cmd)) => match auth_cmd.action {
AuthAction::Login(login_args) => {
#[cfg(feature = "k8s")]
let k8s_config = if matches!(login_args.store, Some(StoreBackend::K8sSecret)) {
Some(auth::K8sSecretConfig {
namespace: login_args.k8s_namespace.clone(),
secret_name: login_args.k8s_secret_name.clone(),
})
} else {
None
};
auth::device_flow_login(
login_args.store,
#[cfg(feature = "k8s")]
k8s_config,
)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?;
}
AuthAction::Status => {
let provider = AuthProvider::default();
auth::print_auth_status(&provider);
}
},
}
Ok(())
}
async fn run_serve(args: ServeArgs) -> anyhow::Result<()> {
validate_branch_name(&args.branch).context("invalid --branch value")?;
let repo_path = expand_path(&args.repo_path)?;
info!("repo path: {}", repo_path.display());
let remote_url = args.remote_url.filter(|u| !u.is_empty());
let repo = MemoryRepo::init_or_open(&repo_path, remote_url.as_deref())
.with_context(|| format!("failed to open/init repo at {}", repo_path.display()))?;
let embedding: Box<dyn EmbeddingBackend> =
Box::new(CandleEmbeddingEngine::new().context("failed to init embedding engine")?);
let dimensions = embedding.dimensions();
let index_dir = repo_path.join(".memory-mcp-index");
let old_index = index_dir.join("index.usearch");
if old_index.exists() {
if let Err(e) = std::fs::remove_file(&old_index) {
tracing::warn!(error = %e, "failed to remove legacy index file");
}
let keys_file = index_dir.join("index.usearch.keys.json");
if let Err(e) = std::fs::remove_file(&keys_file) {
if e.kind() != std::io::ErrorKind::NotFound {
tracing::warn!(error = %e, "failed to remove legacy index keys file");
}
}
info!("removed legacy single-index files");
}
let index = ScopedIndex::load(&index_dir, dimensions).unwrap_or_else(|e| {
tracing::warn!("could not load scoped index ({}), creating fresh", e);
ScopedIndex::new(dimensions).expect("failed to create scoped index")
});
let auth = AuthProvider::new();
let state = Arc::new(AppState::new(
Arc::new(repo),
args.branch.clone(),
embedding,
index,
auth,
));
let state_for_shutdown = Arc::clone(&state);
let ct = CancellationToken::new();
let ct_child = ct.child_token();
#[allow(clippy::field_reassign_with_default)]
let service = StreamableHttpService::new(
move || Ok(MemoryServer::new(Arc::clone(&state))),
Arc::new({
let mut session_config = SessionConfig::default();
session_config.keep_alive = Some(std::time::Duration::from_secs(4 * 60 * 60));
let mgr = BoundedSessionManager::new(session_config, args.max_sessions);
if args.session_rate_limit > 0 && args.session_rate_window_secs > 0 {
mgr.with_rate_limit(
args.session_rate_limit,
std::time::Duration::from_secs(args.session_rate_window_secs),
)
} else {
mgr
}
}),
{
let mut server_config = StreamableHttpServerConfig::default();
server_config.cancellation_token = ct_child;
server_config
},
);
let mcp_path = args.mcp_path.clone();
let router = axum::Router::new()
.route(
"/healthz",
axum::routing::get(|| async {
axum::response::Json(serde_json::json!({"status": "ok"}))
}),
)
.nest_service(&mcp_path, service);
let listener = tokio::net::TcpListener::bind(&args.bind)
.await
.with_context(|| format!("failed to bind to {}", args.bind))?;
info!("listening on {} (MCP at {})", args.bind, args.mcp_path);
axum::serve(listener, router)
.with_graceful_shutdown(async move {
tokio::signal::ctrl_c()
.await
.expect("failed to listen for ctrl-c");
info!("shutdown signal received");
ct.cancel();
})
.await
.context("server error")?;
std::fs::create_dir_all(&index_dir)
.with_context(|| format!("failed to create index dir {}", index_dir.display()))?;
if let Err(e) = state_for_shutdown.index.save(&index_dir) {
tracing::warn!("failed to persist vector index on shutdown: {}", e);
} else {
info!("vector index saved to {}", index_dir.display());
}
Ok(())
}
/// Initialises the embedding engine for [`MODEL_ID`] and embeds a single probe string.
///
/// # Errors
/// Fails if the engine cannot be constructed or the probe embedding fails.
async fn run_warmup(_args: WarmupArgs) -> anyhow::Result<()> {
    info!("warming up embedding model '{}'", MODEL_ID);
    let engine = CandleEmbeddingEngine::new().context("failed to init embedding engine")?;
    let probe = ["warmup".to_string()];
    // Only success matters here; the resulting embedding is discarded.
    let _ = engine.embed(&probe).await.context("warmup embed failed")?;
    info!("warmup complete");
    Ok(())
}
/// Clap value parser that accepts only strictly positive `usize` values.
///
/// # Errors
/// Returns a human-readable message if `s` does not parse as a `usize`
/// (non-numeric, negative or out of range), or if it parses to zero.
fn parse_nonzero_usize(s: &str) -> Result<usize, String> {
    match s.parse::<usize>() {
        Ok(0) => Err("value must be at least 1".to_owned()),
        Ok(value) => Ok(value),
        Err(_) => Err(format!("'{s}' is not a valid integer")),
    }
}
/// Expands a leading `~` in `path` to the current user's home directory.
///
/// - `~` alone, or `~` followed by a path separator, is replaced by the home
///   directory. The whole run of separators after `~` is skipped, so the
///   remainder is always joined *under* home (`~//x` → `$HOME/x`). Stripping
///   only one separator would leave an absolute path, which `Path::join` would
///   substitute for home (`/x`).
/// - Separators are recognised with `std::path::is_separator`, so `~\x` is
///   also accepted on Windows.
/// - `~user/...` is rejected.
/// - Any other path is returned unchanged.
///
/// # Errors
/// Fails if the home directory cannot be determined, or if `~user` syntax is used.
fn expand_path(path: &str) -> anyhow::Result<PathBuf> {
    let rest = match path.strip_prefix('~') {
        Some(rest) => rest,
        None => return Ok(PathBuf::from(path)),
    };
    if !rest.is_empty() && !rest.starts_with(std::path::is_separator) {
        anyhow::bail!(
            "~user path expansion is not supported; \
             please use an absolute path or ~/..."
        );
    }
    let home = dirs::home_dir().ok_or_else(|| {
        anyhow::anyhow!(
            "could not expand '~': home directory could not be determined; \
             please provide --repo-path explicitly or set HOME"
        )
    })?;
    Ok(home.join(rest.trim_start_matches(std::path::is_separator)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;

    #[test]
    fn test_cli_bare_has_no_command() {
        let cli = Cli::try_parse_from(["memory-mcp"]).expect("bare invocation should parse");
        assert!(cli.command.is_none());
    }

    #[test]
    fn test_cli_serve_with_bind() {
        let cli = Cli::try_parse_from(["memory-mcp", "serve", "--bind", "0.0.0.0:9090"])
            .expect("serve --bind should parse");
        match cli.command {
            Some(Command::Serve(args)) => assert_eq!(args.bind, "0.0.0.0:9090"),
            _ => panic!("expected Serve command"),
        }
    }

    #[test]
    fn test_cli_warmup() {
        let cli = Cli::try_parse_from(["memory-mcp", "warmup"]).expect("warmup should parse");
        assert!(matches!(cli.command, Some(Command::Warmup(_))));
    }

    #[test]
    fn test_cli_auth_login_store_keyring() {
        let cli = Cli::try_parse_from(["memory-mcp", "auth", "login", "--store", "keyring"])
            .expect("auth login --store keyring should parse");
        match cli.command {
            Some(Command::Auth(auth_cmd)) => match auth_cmd.action {
                AuthAction::Login(login_args) => {
                    assert!(matches!(login_args.store, Some(StoreBackend::Keyring)));
                }
                _ => panic!("expected Login action"),
            },
            _ => panic!("expected Auth command"),
        }
    }

    #[test]
    fn test_cli_auth_status() {
        let cli = Cli::try_parse_from(["memory-mcp", "auth", "status"])
            .expect("auth status should parse");
        match cli.command {
            Some(Command::Auth(auth_cmd)) => {
                assert!(matches!(auth_cmd.action, AuthAction::Status));
            }
            _ => panic!("expected Auth command"),
        }
    }

    // Mirrors the fallback in `main`: a bare invocation is re-parsed as
    // `memory-mcp serve`, which must yield the Serve command. `try_parse_from`
    // is used so a parse failure fails the test instead of exiting the process.
    #[test]
    fn test_bare_serve_reparse_yields_serve_command() {
        let cli = Cli::try_parse_from(["memory-mcp", "serve"])
            .expect("`memory-mcp serve` should parse");
        assert!(matches!(cli.command, Some(Command::Serve(_))));
    }

    #[cfg(feature = "k8s")]
    #[test]
    fn test_cli_auth_login_store_k8s_secret() {
        let cli = Cli::try_parse_from(["memory-mcp", "auth", "login", "--store", "k8s-secret"])
            .expect("auth login --store k8s-secret should parse");
        match cli.command {
            Some(Command::Auth(auth_cmd)) => match auth_cmd.action {
                AuthAction::Login(login_args) => {
                    assert!(matches!(login_args.store, Some(StoreBackend::K8sSecret)));
                    assert_eq!(login_args.k8s_namespace, "memory-mcp");
                    assert_eq!(login_args.k8s_secret_name, "memory-mcp-github-token");
                }
                _ => panic!("expected Login action"),
            },
            _ => panic!("expected Auth command"),
        }
    }

    #[test]
    fn test_parse_nonzero_usize_zero_is_err() {
        assert!(parse_nonzero_usize("0").is_err());
    }

    #[test]
    fn test_parse_nonzero_usize_non_numeric_is_err() {
        assert!(parse_nonzero_usize("abc").is_err());
    }

    #[test]
    fn test_parse_nonzero_usize_negative_is_err() {
        assert!(parse_nonzero_usize("-1").is_err());
    }

    #[test]
    fn test_parse_nonzero_usize_overflow_is_err() {
        assert!(parse_nonzero_usize("99999999999999999999999").is_err());
    }

    #[test]
    fn test_parse_nonzero_usize_one_is_ok() {
        assert_eq!(parse_nonzero_usize("1").unwrap(), 1);
    }

    #[test]
    fn test_parse_nonzero_usize_hundred_is_ok() {
        assert_eq!(parse_nonzero_usize("100").unwrap(), 100);
    }

    #[test]
    fn test_expand_path_tilde_alone() {
        let result = expand_path("~").unwrap();
        assert_eq!(result, dirs::home_dir().unwrap());
    }

    #[test]
    fn test_expand_path_tilde_trailing_slash() {
        // `Path` equality compares components, so a trailing separator is ignored.
        let result = expand_path("~/").unwrap();
        assert_eq!(result, dirs::home_dir().unwrap());
    }

    #[test]
    fn test_expand_path_tilde_slash() {
        let result = expand_path("~/foo/bar").unwrap();
        assert_eq!(result, dirs::home_dir().unwrap().join("foo/bar"));
    }

    #[test]
    fn test_expand_path_absolute() {
        let result = expand_path("/tmp/repo").unwrap();
        assert_eq!(result, PathBuf::from("/tmp/repo"));
    }

    #[test]
    fn test_expand_path_relative() {
        let result = expand_path("relative/path").unwrap();
        assert_eq!(result, PathBuf::from("relative/path"));
    }

    #[test]
    fn test_expand_path_tilde_user_rejected() {
        let result = expand_path("~otheruser/path");
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("not supported"),
            "error should mention unsupported: {msg}"
        );
    }

    #[cfg(feature = "k8s")]
    #[test]
    fn test_cli_auth_login_k8s_namespace_override() {
        let cli = Cli::try_parse_from([
            "memory-mcp",
            "auth",
            "login",
            "--store",
            "k8s-secret",
            "--k8s-namespace",
            "custom-ns",
            "--k8s-secret-name",
            "custom-name",
        ])
        .expect("auth login with k8s flags should parse");
        match cli.command {
            Some(Command::Auth(auth_cmd)) => match auth_cmd.action {
                AuthAction::Login(login_args) => {
                    assert!(matches!(login_args.store, Some(StoreBackend::K8sSecret)));
                    assert_eq!(login_args.k8s_namespace, "custom-ns");
                    assert_eq!(login_args.k8s_secret_name, "custom-name");
                }
                _ => panic!("expected Login action"),
            },
            _ => panic!("expected Auth command"),
        }
    }
}