use crate::config::{effective_mcp_auth_mode, effective_mcp_bearer_token};
use crate::mcp::{handle_mcp, handle_mcp_delete, handle_mcp_get, McpSession, McpState};
use crate::raw::RawOps;
use crate::wiki::WikiOps;
use anyhow::Result;
use axum::{
extract::State,
response::Response,
routing::{get, post},
Json, Router,
};
use serde_json::json;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
pub async fn run(config: crate::config::AppConfig) -> Result<()> {
let state = build_state(config.clone());
tokio::spawn(prune_sessions_task(
state.sessions.clone(),
Duration::from_secs(config.mcp.session_ttl_seconds.max(1)),
));
let app = build_app(state);
let addr: std::net::SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
tracing::info!("writestead daemon listening on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
let serve_result = axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await;
if let Err(err) = crate::daemon::cleanup_pid_file_if_current_process() {
tracing::warn!("failed to cleanup pid file on shutdown: {}", err);
}
serve_result?;
Ok(())
}
pub fn build_state(config: crate::config::AppConfig) -> McpState {
let sessions = Arc::new(RwLock::new(HashMap::<uuid::Uuid, McpSession>::new()));
let request_count = Arc::new(AtomicU64::new(0));
let tool_call_count = Arc::new(AtomicU64::new(0));
let tool_error_count = Arc::new(AtomicU64::new(0));
let tool_call_by_name = Arc::new(RwLock::new(HashMap::<String, u64>::new()));
let tool_error_by_name = Arc::new(RwLock::new(HashMap::<String, u64>::new()));
let raw_read_by_format = Arc::new(RwLock::new(HashMap::<String, u64>::new()));
let raw_read_failure_by_extractor = Arc::new(RwLock::new(HashMap::<String, u64>::new()));
let wiki = Arc::new(WikiOps::new(config.clone()));
let raw = Arc::new(RawOps::new(config.clone()));
McpState {
config,
wiki,
raw,
sessions,
server_version: env!("CARGO_PKG_VERSION").to_string(),
started_at: std::time::Instant::now(),
request_count,
tool_call_count,
tool_error_count,
tool_call_by_name,
tool_error_by_name,
raw_upload_count: Arc::new(AtomicU64::new(0)),
raw_upload_bytes_total: Arc::new(AtomicU64::new(0)),
raw_read_count: Arc::new(AtomicU64::new(0)),
raw_read_by_format,
raw_read_failure_by_extractor,
}
}
pub fn build_app(state: McpState) -> Router {
Router::new()
.route(
"/mcp",
post(handle_mcp)
.get(handle_mcp_get)
.delete(handle_mcp_delete),
)
.route("/health", get(health))
.route("/metrics", get(metrics))
.with_state(state)
.layer(tower_http::trace::TraceLayer::new_for_http())
}
async fn health(State(state): State<McpState>) -> Json<serde_json::Value> {
Json(json!({
"ok": true,
"name": state.config.name,
"vault_path": state.config.vault_path,
"sync_backend": state.config.sync.backend.to_string(),
"mcp_auth_mode": effective_mcp_auth_mode(&state.config).to_string(),
"mcp_has_bearer_token": effective_mcp_bearer_token(&state.config).is_some(),
"mcp_session_ttl_seconds": state.config.mcp.session_ttl_seconds,
"mcp_requests_total": state.request_count.load(Ordering::Relaxed),
"mcp_tool_calls_total": state.tool_call_count.load(Ordering::Relaxed),
"mcp_tool_errors_total": state.tool_error_count.load(Ordering::Relaxed),
"mcp_tool_calls_by_name": state.tool_call_by_name.read().await.clone(),
"mcp_tool_errors_by_name": state.tool_error_by_name.read().await.clone(),
"raw_uploads_total": state.raw_upload_count.load(Ordering::Relaxed),
"raw_upload_bytes_total": state.raw_upload_bytes_total.load(Ordering::Relaxed),
"raw_reads_total": state.raw_read_count.load(Ordering::Relaxed),
"raw_reads_by_format": state.raw_read_by_format.read().await.clone(),
"raw_read_failures_by_extractor": state.raw_read_failure_by_extractor.read().await.clone(),
"sync_metrics": crate::syncer::metrics_snapshot_json(),
"uptime_sec": state.started_at.elapsed().as_secs(),
"version": env!("CARGO_PKG_VERSION"),
}))
}
async fn metrics(State(state): State<McpState>) -> Response<String> {
let session_count = state.sessions.read().await.len();
let request_count = state.request_count.load(Ordering::Relaxed);
let tool_call_count = state.tool_call_count.load(Ordering::Relaxed);
let tool_error_count = state.tool_error_count.load(Ordering::Relaxed);
let by_tool = state.tool_call_by_name.read().await.clone();
let error_by_tool = state.tool_error_by_name.read().await.clone();
let raw_uploads_total = state.raw_upload_count.load(Ordering::Relaxed);
let raw_upload_bytes_total = state.raw_upload_bytes_total.load(Ordering::Relaxed);
let raw_reads_total = state.raw_read_count.load(Ordering::Relaxed);
let raw_reads_by_format = state.raw_read_by_format.read().await.clone();
let raw_read_failures_by_extractor = state.raw_read_failure_by_extractor.read().await.clone();
let sync_metrics = crate::syncer::metrics_snapshot();
let mut body = String::new();
body.push_str("# HELP writestead_uptime_seconds Daemon uptime in seconds\n");
body.push_str("# TYPE writestead_uptime_seconds gauge\n");
body.push_str(&format!(
"writestead_uptime_seconds {}\n",
state.started_at.elapsed().as_secs()
));
body.push_str("# HELP writestead_mcp_sessions_active Active MCP sessions\n");
body.push_str("# TYPE writestead_mcp_sessions_active gauge\n");
body.push_str(&format!(
"writestead_mcp_sessions_active {}\n",
session_count
));
body.push_str("# HELP writestead_mcp_requests_total Total MCP endpoint requests\n");
body.push_str("# TYPE writestead_mcp_requests_total counter\n");
body.push_str(&format!(
"writestead_mcp_requests_total {}\n",
request_count
));
body.push_str("# HELP writestead_mcp_tool_calls_total Total MCP tools/call requests\n");
body.push_str("# TYPE writestead_mcp_tool_calls_total counter\n");
body.push_str(&format!(
"writestead_mcp_tool_calls_total {}\n",
tool_call_count
));
body.push_str("# HELP writestead_mcp_tool_calls_by_tool_total MCP tool calls by tool name\n");
body.push_str("# TYPE writestead_mcp_tool_calls_by_tool_total counter\n");
let mut names: Vec<String> = by_tool.keys().cloned().collect();
names.sort();
for name in names {
if let Some(value) = by_tool.get(&name) {
body.push_str(&format!(
"writestead_mcp_tool_calls_by_tool_total{{tool=\"{}\"}} {}\n",
prometheus_escape_label(&name),
value
));
}
}
body.push_str("# HELP writestead_mcp_tool_errors_total Total MCP tool execution errors\n");
body.push_str("# TYPE writestead_mcp_tool_errors_total counter\n");
body.push_str(&format!(
"writestead_mcp_tool_errors_total {}\n",
tool_error_count
));
body.push_str(
"# HELP writestead_mcp_tool_errors_by_tool_total MCP tool execution errors by tool name\n",
);
body.push_str("# TYPE writestead_mcp_tool_errors_by_tool_total counter\n");
let mut error_names: Vec<String> = error_by_tool.keys().cloned().collect();
error_names.sort();
for name in error_names {
if let Some(value) = error_by_tool.get(&name) {
body.push_str(&format!(
"writestead_mcp_tool_errors_by_tool_total{{tool=\"{}\"}} {}\n",
prometheus_escape_label(&name),
value
));
}
}
body.push_str("# HELP writestead_raw_uploads_total Total raw uploads\n");
body.push_str("# TYPE writestead_raw_uploads_total counter\n");
body.push_str(&format!(
"writestead_raw_uploads_total {}\n",
raw_uploads_total
));
body.push_str("# HELP writestead_raw_upload_bytes_total Total raw upload bytes\n");
body.push_str("# TYPE writestead_raw_upload_bytes_total counter\n");
body.push_str(&format!(
"writestead_raw_upload_bytes_total {}\n",
raw_upload_bytes_total
));
body.push_str("# HELP writestead_raw_reads_total Total raw reads\n");
body.push_str("# TYPE writestead_raw_reads_total counter\n");
body.push_str(&format!("writestead_raw_reads_total {}\n", raw_reads_total));
body.push_str("# HELP writestead_raw_reads_by_format_total Raw reads by format\n");
body.push_str("# TYPE writestead_raw_reads_by_format_total counter\n");
let mut formats: Vec<String> = raw_reads_by_format.keys().cloned().collect();
formats.sort();
for format_name in formats {
if let Some(value) = raw_reads_by_format.get(&format_name) {
body.push_str(&format!(
"writestead_raw_reads_by_format_total{{format=\"{}\"}} {}\n",
prometheus_escape_label(&format_name),
value
));
}
}
body.push_str(
"# HELP writestead_raw_read_failures_by_extractor_total Raw read failures by extractor\n",
);
body.push_str("# TYPE writestead_raw_read_failures_by_extractor_total counter\n");
let mut extractor_names: Vec<String> = raw_read_failures_by_extractor.keys().cloned().collect();
extractor_names.sort();
for extractor in extractor_names {
if let Some(value) = raw_read_failures_by_extractor.get(&extractor) {
body.push_str(&format!(
"writestead_raw_read_failures_by_extractor_total{{extractor=\"{}\"}} {}\n",
prometheus_escape_label(&extractor),
value
));
}
}
body.push_str("# HELP writestead_sync_runs_total Total sync backend runs\n");
body.push_str("# TYPE writestead_sync_runs_total counter\n");
let mut sync_run_triggers: Vec<String> = sync_metrics.runs_by_trigger.keys().cloned().collect();
sync_run_triggers.sort();
for trigger in sync_run_triggers {
if let Some(value) = sync_metrics.runs_by_trigger.get(&trigger) {
body.push_str(&format!(
"writestead_sync_runs_total{{trigger=\"{}\"}} {}\n",
prometheus_escape_label(&trigger),
value
));
}
}
body.push_str("# HELP writestead_sync_errors_total Total sync backend errors\n");
body.push_str("# TYPE writestead_sync_errors_total counter\n");
let mut sync_error_triggers: Vec<String> =
sync_metrics.errors_by_trigger.keys().cloned().collect();
sync_error_triggers.sort();
for trigger in sync_error_triggers {
if let Some(value) = sync_metrics.errors_by_trigger.get(&trigger) {
body.push_str(&format!(
"writestead_sync_errors_total{{trigger=\"{}\"}} {}\n",
prometheus_escape_label(&trigger),
value
));
}
}
body.push_str("# HELP writestead_sync_duration_seconds Sync backend duration seconds\n");
body.push_str("# TYPE writestead_sync_duration_seconds summary\n");
body.push_str(&format!(
"writestead_sync_duration_seconds_sum {:.6}\n",
sync_metrics.duration_seconds_sum
));
body.push_str(&format!(
"writestead_sync_duration_seconds_count {}\n",
sync_metrics.duration_seconds_count
));
axum::response::Response::builder()
.header("content-type", "text/plain; version=0.0.4")
.body(body)
.unwrap()
}
async fn prune_sessions_task(
sessions: Arc<RwLock<HashMap<uuid::Uuid, McpSession>>>,
session_ttl: Duration,
) {
let prune_interval = Duration::from_secs((session_ttl.as_secs() / 12).clamp(1, 300));
loop {
tokio::time::sleep(prune_interval).await;
let now = std::time::Instant::now();
let mut guard = sessions.write().await;
guard.retain(|_, session| now.duration_since(session.created_at) < session_ttl);
}
}
fn prometheus_escape_label(input: &str) -> String {
input
.replace('\\', "\\\\")
.replace('"', "\\\"")
.replace('\n', "\\n")
}
async fn shutdown_signal() {
let ctrl_c = async {
let _ = tokio::signal::ctrl_c().await;
};
#[cfg(unix)]
let terminate = async {
if let Ok(mut signal) =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
{
let _ = signal.recv().await;
}
};
#[cfg(unix)]
{
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
}
#[cfg(not(unix))]
{
ctrl_c.await;
}
}