use anyhow::Result;
use serde_json::{json, Value};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{broadcast, OnceCell};
use tracing::info;
use trusty_common::ChatProvider;
use trusty_mcp_core::{error_codes, initialize_response, Request, Response};
use trusty_memory_core::embed::FastEmbedder;
use trusty_memory_core::store::ChatSessionStore;
use trusty_memory_core::PalaceRegistry;
pub mod openrpc;
pub mod service;
pub mod tools;
pub mod web;
pub use service::MemoryMcpService;
pub use tools::MemoryMcpServer;
/// Events broadcast to live subscribers via [`AppState::emit`] (for example
/// the `/sse` stream served by `sse_handler`).
///
/// Serialized as internally-tagged JSON: the variant name, converted to
/// snake_case, goes in a `"type"` field next to the variant's own fields,
/// e.g. `{"type":"palace_created","id":"…","name":"…"}`.
#[derive(Clone, Debug, serde::Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DaemonEvent {
    /// A palace was created.
    PalaceCreated {
        id: String,
        name: String,
    },
    /// A drawer was added to `palace_id`.
    /// NOTE(review): `drawer_count` is presumably the palace's drawer total
    /// after the addition — confirm at the emit sites.
    DrawerAdded {
        palace_id: String,
        drawer_count: usize,
    },
    /// A drawer was deleted from `palace_id`.
    /// NOTE(review): `drawer_count` is presumably the total after deletion.
    DrawerDeleted {
        palace_id: String,
        drawer_count: usize,
    },
    /// A "dream" maintenance pass finished, with per-category counters.
    /// NOTE(review): `palace_id == None` presumably means the pass covered
    /// all palaces, and `duration_ms` is wall time in milliseconds — confirm.
    DreamCompleted {
        palace_id: Option<String>,
        merged: usize,
        pruned: usize,
        compacted: usize,
        closets_updated: usize,
        duration_ms: u64,
    },
    /// Snapshot of aggregate totals (drawers, vectors, knowledge-graph
    /// triples). Presumably daemon-wide rather than per palace — confirm.
    StatusChanged {
        total_drawers: usize,
        total_vectors: usize,
        total_kg_triples: usize,
    },
}
/// Shared daemon state.
///
/// Handed to axum handlers as `State<AppState>` and wrapped in an `Arc` for the
/// stdio loop. Every shared member sits behind an `Arc`, so clones observe the
/// same registry, caches and event channel.
#[derive(Clone)]
pub struct AppState {
    /// Crate version (`CARGO_PKG_VERSION`), reported by `initialize`.
    pub version: String,
    /// In-memory registry of palaces.
    pub registry: Arc<PalaceRegistry>,
    /// Root data directory; per-palace files such as `chat_sessions.db`
    /// live under `<data_root>/<palace_id>/`.
    pub data_root: PathBuf,
    /// Lazily constructed embedder; obtain it through [`AppState::embedder`].
    pub embedder: Arc<OnceCell<Arc<FastEmbedder>>>,
    /// Palace used when a tool call omits `palace`; echoed in `initialize`.
    pub default_palace: Option<String>,
    /// Lazily resolved chat provider. An initialised `None` means no provider
    /// is configured. Obtain it through [`AppState::chat_provider`].
    pub chat_provider: Arc<OnceCell<Option<Arc<dyn ChatProvider>>>>,
    /// Cache of opened chat-session stores, keyed by palace id.
    pub session_stores: Arc<dashmap::DashMap<String, Arc<ChatSessionStore>>>,
    /// Sender side of the [`DaemonEvent`] broadcast channel.
    pub events: Arc<broadcast::Sender<DaemonEvent>>,
}
impl AppState {
    /// Build a fresh state rooted at `data_root`.
    ///
    /// The registry, session-store cache and lazy cells all start empty, and
    /// no default palace is set. `version` comes from this crate's
    /// `CARGO_PKG_VERSION`.
    pub fn new(data_root: PathBuf) -> Self {
        // The initial receiver is dropped immediately; subscribers attach
        // later with `events.subscribe()`. 128 is the channel capacity a
        // receiver may fall behind by before it starts lagging.
        let (events_tx, _) = broadcast::channel::<DaemonEvent>(128);
        Self {
            version: env!("CARGO_PKG_VERSION").to_string(),
            registry: Arc::new(PalaceRegistry::new()),
            data_root,
            embedder: Arc::new(OnceCell::new()),
            default_palace: None,
            chat_provider: Arc::new(OnceCell::new()),
            session_stores: Arc::new(dashmap::DashMap::new()),
            events: Arc::new(events_tx),
        }
    }

    /// Broadcast `event` to all current subscribers.
    ///
    /// Best-effort: `broadcast::Sender::send` only fails when there are no
    /// receivers (e.g. no SSE client connected), which is a normal state, so
    /// the error is deliberately discarded.
    pub fn emit(&self, event: DaemonEvent) {
        let _ = self.events.send(event);
    }

    /// Return the chat-session store for `palace_id`, opening it on first use.
    ///
    /// The store lives at `<data_root>/<palace_id>/chat_sessions.db`; the
    /// palace directory is created if it does not exist. Concurrent callers
    /// for the same palace always receive the same shared instance.
    ///
    /// # Errors
    /// Fails if the palace directory cannot be created or the store cannot be
    /// opened.
    pub fn session_store(&self, palace_id: &str) -> Result<Arc<ChatSessionStore>> {
        if let Some(cached) = self.session_stores.get(palace_id) {
            return Ok(Arc::clone(cached.value()));
        }
        let dir = self.data_root.join(palace_id);
        std::fs::create_dir_all(&dir)
            .map_err(|e| anyhow::anyhow!("create palace dir {}: {e}", dir.display()))?;
        let opened = Arc::new(ChatSessionStore::open(&dir.join("chat_sessions.db"))?);
        // Two callers can race past the cache miss above and both open the
        // store. A plain `insert` would let the later one silently replace the
        // earlier, leaving callers with two distinct handles to the same
        // database. Going through `entry` makes the first insert win; a losing
        // duplicate is simply dropped. The file I/O above stays outside the
        // shard lock.
        let shared = self
            .session_stores
            .entry(palace_id.to_string())
            .or_insert(opened);
        Ok(Arc::clone(shared.value()))
    }

    /// Builder-style setter for the palace used when tool calls omit `palace`.
    pub fn with_default_palace(mut self, name: Option<String>) -> Self {
        self.default_palace = name;
        self
    }

    /// Resolve the chat provider, building it from the user config on first call.
    ///
    /// Preference order:
    /// 1. A local provider, when `local_model.enabled` is set and
    ///    `auto_detect_local_provider` finds one at `local_model.base_url`.
    ///    Its model is overridden by `local_model.model`.
    /// 2. OpenRouter, when `openrouter_api_key` is non-empty.
    /// 3. `None`.
    ///
    /// A config load failure falls back to the default config. The outcome,
    /// including `None`, is cached in the `OnceCell` for the lifetime of this
    /// state, so later config edits are not picked up.
    pub async fn chat_provider(&self) -> Option<Arc<dyn ChatProvider>> {
        self.chat_provider
            .get_or_init(|| async {
                let cfg = crate::web::load_user_config().unwrap_or_default();
                if cfg.local_model.enabled {
                    if let Some(mut p) =
                        trusty_common::auto_detect_local_provider(&cfg.local_model.base_url).await
                    {
                        p.model = cfg.local_model.model.clone();
                        return Some(Arc::new(p) as Arc<dyn ChatProvider>);
                    }
                }
                if !cfg.openrouter_api_key.is_empty() {
                    return Some(Arc::new(trusty_common::OpenRouterProvider::new(
                        cfg.openrouter_api_key,
                        cfg.openrouter_model,
                    )) as Arc<dyn ChatProvider>);
                }
                None
            })
            .await
            .clone()
    }

    /// Return the shared embedder, constructing it on first use.
    ///
    /// # Errors
    /// Propagates a `FastEmbedder::new` failure. Nothing is cached on failure,
    /// so a later call retries the construction.
    pub async fn embedder(&self) -> Result<Arc<FastEmbedder>> {
        let embedder = self
            .embedder
            .get_or_try_init(|| async {
                let e = FastEmbedder::new().await?;
                Ok::<Arc<FastEmbedder>, anyhow::Error>(Arc::new(e))
            })
            .await?;
        Ok(Arc::clone(embedder))
    }
}
impl std::fmt::Debug for AppState {
    /// Compact diagnostic view.
    ///
    /// Prints the version, the data root and the number of registered
    /// palaces. The caches, lazy cells and event channel are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AppState");
        out.field("version", &self.version);
        out.field("data_root", &self.data_root);
        out.field("registry_len", &self.registry.len());
        out.finish()
    }
}
/// Dispatch a single JSON-RPC 2.0 message and build its response.
///
/// Supported methods are `initialize`, `tools/list`, `rpc.discover`,
/// `tools/call` and `ping`. Any other method yields a `-32601`
/// ("Method not found") error. A failing tool call yields a `-32603` error
/// carrying the tool's error message.
///
/// Returns `Value::Null` when no reply may be sent. Every `notifications/*`
/// method is a JSON-RPC notification, and the spec forbids replying to
/// notifications, even with an error.
pub async fn handle_message(state: &AppState, msg: Value) -> Value {
    // A missing id is echoed back as `null`.
    let id = msg.get("id").cloned().unwrap_or(Value::Null);
    let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
    // MCP clients send several notifications (initialized, cancelled,
    // progress, roots/list_changed, ...). Answering any of them, even with
    // "Method not found", would put an unsolicited message on the wire.
    if method.starts_with("notifications/") {
        return Value::Null;
    }
    match method {
        "initialize" => {
            // Pass the default palace (if any) to `initialize_response` as
            // extra server info.
            let extra = state
                .default_palace
                .as_ref()
                .map(|dp| json!({ "default_palace": dp }));
            let result = initialize_response("trusty-memory", &state.version, extra);
            json!({
                "jsonrpc": "2.0",
                "id": id,
                "result": result,
            })
        }
        "tools/list" => json!({
            "jsonrpc": "2.0",
            "id": id,
            "result": tools::tool_definitions_with(state.default_palace.is_some())
        }),
        "rpc.discover" => json!({
            "jsonrpc": "2.0",
            "id": id,
            "result": openrpc::build_discover_response(
                &state.version,
                state.default_palace.is_some(),
            ),
        }),
        "tools/call" => {
            let params = msg.get("params").cloned().unwrap_or_default();
            let tool_name = params
                .get("name")
                .and_then(|n| n.as_str())
                .unwrap_or("")
                .to_string();
            let args = params.get("arguments").cloned().unwrap_or_default();
            match tools::dispatch_tool(state, &tool_name, args).await {
                // The tool's output is serialized with `to_string` and
                // wrapped as a single MCP text content item.
                Ok(content) => json!({
                    "jsonrpc": "2.0",
                    "id": id,
                    "result": {
                        "content": [{"type": "text", "text": content.to_string()}]
                    }
                }),
                Err(e) => json!({
                    "jsonrpc": "2.0",
                    "id": id,
                    "error": {"code": -32603, "message": e.to_string()}
                }),
            }
        }
        "ping" => json!({"jsonrpc": "2.0", "id": id, "result": {}}),
        _ => json!({
            "jsonrpc": "2.0",
            "id": id,
            "error": {
                "code": -32601,
                "message": format!("Method not found: {method}")
            }
        }),
    }
}
pub async fn run_stdio(state: AppState) -> Result<()> {
info!("trusty-memory MCP stdio server starting");
let state = Arc::new(state);
trusty_mcp_core::run_stdio_loop(move |req: Request| {
let state = state.clone();
async move {
let msg = json!({
"jsonrpc": req.jsonrpc.unwrap_or_else(|| "2.0".to_string()),
"id": req.id.clone().unwrap_or(Value::Null),
"method": req.method,
"params": req.params.unwrap_or(Value::Null),
});
let resp_value = handle_message(&state, msg).await;
if resp_value.is_null() {
return Response::suppressed();
}
let id = resp_value.get("id").cloned();
if let Some(result) = resp_value.get("result").cloned() {
Response::ok(id, result)
} else if let Some(err) = resp_value.get("error") {
let code =
err.get("code")
.and_then(|c| c.as_i64())
.unwrap_or(error_codes::INTERNAL_ERROR as i64) as i32;
let message = err
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("internal error")
.to_string();
Response::err(id, code, message)
} else {
Response::err(
id,
error_codes::INTERNAL_ERROR,
"malformed handler response",
)
}
}
})
.await
}
/// Serve the HTTP API on an already-bound `listener` until the server stops.
///
/// Routes come from `web::router()`, with the live event stream mounted at
/// `/sse`. The bound address, when it can be determined, is logged and echoed
/// to stderr.
pub async fn run_http_on(state: AppState, listener: tokio::net::TcpListener) -> Result<()> {
    let app = web::router()
        .route("/sse", axum::routing::get(sse_handler))
        .with_state(state);
    if let Ok(bound) = listener.local_addr() {
        info!("HTTP server listening on http://{bound}");
        eprintln!("HTTP server listening on http://{bound}");
    }
    axum::serve(listener, app).await?;
    Ok(())
}
/// Bind a TCP listener on `addr` and serve the HTTP API on it.
///
/// # Errors
/// Fails if the address cannot be bound or the server exits with an error.
pub async fn run_http(state: AppState, addr: std::net::SocketAddr) -> Result<()> {
    run_http_on(state, tokio::net::TcpListener::bind(addr).await?).await
}
pub(crate) async fn sse_handler(
axum::extract::State(state): axum::extract::State<AppState>,
) -> impl axum::response::IntoResponse {
use futures::StreamExt;
use tokio_stream::wrappers::BroadcastStream;
let rx = state.events.subscribe();
let initial = futures::stream::once(async {
Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(
"data: {\"type\":\"connected\"}\n\n",
))
});
let events = BroadcastStream::new(rx).map(|res| {
let frame = match res {
Ok(event) => match serde_json::to_string(&event) {
Ok(json) => format!("data: {json}\n\n"),
Err(e) => format!("data: {{\"type\":\"error\",\"message\":\"{e}\"}}\n\n"),
},
Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => {
format!("data: {{\"type\":\"lag\",\"skipped\":{n}}}\n\n")
}
};
Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(frame))
});
let stream = initial.chain(events);
axum::response::Response::builder()
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache")
.header("X-Accel-Buffering", "no")
.body(axum::body::Body::from_stream(stream))
.expect("valid SSE response")
}
// Unit tests for the JSON-RPC dispatcher (`handle_message`) and `AppState`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `AppState` rooted in a fresh temporary directory.
    fn test_state() -> AppState {
        let tmp = tempfile::tempdir().expect("tempdir");
        let root = tmp.path().to_path_buf();
        // Deliberately leak the `TempDir` guard. Dropping it would delete the
        // directory while the returned state still points at it, so the
        // directory is left behind after the run.
        std::mem::forget(tmp);
        AppState::new(root)
    }

    #[tokio::test]
    async fn initialize_returns_protocol_version_and_capabilities() {
        let state = test_state();
        let req = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "test", "version": "0"}
            }
        });
        let resp = handle_message(&state, req).await;
        assert_eq!(resp["jsonrpc"], "2.0");
        assert_eq!(resp["id"], 1);
        // The result payload is built by `trusty_mcp_core::initialize_response`.
        assert_eq!(resp["result"]["protocolVersion"], "2024-11-05");
        assert!(resp["result"]["capabilities"]["tools"].is_object());
        assert_eq!(resp["result"]["serverInfo"]["name"], "trusty-memory");
    }

    #[tokio::test]
    async fn initialized_notification_returns_null() {
        let state = test_state();
        let req = json!({
            "jsonrpc": "2.0",
            "method": "notifications/initialized",
            "params": {}
        });
        // A `null` return means "send no response".
        let resp = handle_message(&state, req).await;
        assert!(resp.is_null());
    }

    #[tokio::test]
    async fn tools_list_returns_all_tools() {
        let state = test_state();
        let req = json!({"jsonrpc": "2.0", "id": 2, "method": "tools/list"});
        let resp = handle_message(&state, req).await;
        let tools = resp["result"]["tools"].as_array().expect("tools array");
        // Update this count whenever a tool is added to or removed from
        // `tools::tool_definitions_with`.
        assert_eq!(tools.len(), 12);
    }

    #[tokio::test]
    async fn unknown_method_returns_error() {
        let state = test_state();
        let req = json!({"jsonrpc": "2.0", "id": 4, "method": "wat"});
        let resp = handle_message(&state, req).await;
        // -32601 is the JSON-RPC 2.0 "Method not found" code.
        assert_eq!(resp["error"]["code"], -32601);
    }

    #[tokio::test]
    async fn ping_returns_empty_result() {
        let state = test_state();
        let req = json!({"jsonrpc": "2.0", "id": 5, "method": "ping"});
        let resp = handle_message(&state, req).await;
        assert!(resp["result"].is_object());
    }

    #[tokio::test]
    async fn app_state_default_constructs() {
        let s = test_state();
        assert!(!s.version.is_empty());
        assert!(s.registry.is_empty());
        assert!(s.default_palace.is_none());
    }

    /// End-to-end check of the default-palace feature:
    /// - `initialize` echoes the default palace;
    /// - `tools/list` no longer marks `palace` as required;
    /// - `tools/call` without `palace` stores into the default palace.
    #[tokio::test]
    async fn default_palace_used_when_arg_omitted() {
        // `tmp` lives until the end of the test, so the directory persists
        // for every call below.
        let tmp = tempfile::tempdir().expect("tempdir");
        let root = tmp.path().to_path_buf();
        // The palace is created through a separate registry instance, not
        // `state.registry`. NOTE(review): presumably tool dispatch locates it
        // on disk under `root` — confirm.
        let registry = trusty_memory_core::PalaceRegistry::new();
        let palace = trusty_memory_core::Palace {
            id: trusty_memory_core::PalaceId::new("default-pal"),
            name: "default-pal".to_string(),
            description: None,
            created_at: chrono::Utc::now(),
            data_dir: root.join("default-pal"),
        };
        registry
            .create_palace(&root, palace)
            .expect("create_palace");
        let state = AppState::new(root).with_default_palace(Some("default-pal".to_string()));

        // 1. `initialize` advertises the default palace.
        let init = handle_message(
            &state,
            json!({"jsonrpc": "2.0", "id": 1, "method": "initialize"}),
        )
        .await;
        assert_eq!(
            init["result"]["serverInfo"]["default_palace"], "default-pal",
            "initialize must echo default_palace in serverInfo"
        );

        // 2. `memory_remember` requires only `text` once a default exists.
        let list = handle_message(
            &state,
            json!({"jsonrpc": "2.0", "id": 2, "method": "tools/list"}),
        )
        .await;
        let tools = list["result"]["tools"].as_array().expect("tools array");
        let remember = tools
            .iter()
            .find(|t| t["name"] == "memory_remember")
            .expect("memory_remember tool");
        let required: Vec<&str> = remember["inputSchema"]["required"]
            .as_array()
            .expect("required array")
            .iter()
            .filter_map(|v| v.as_str())
            .collect();
        assert!(
            !required.contains(&"palace"),
            "palace must not be required when default is configured; got {required:?}"
        );
        assert!(required.contains(&"text"));

        // 3. Calling the tool without `palace` routes to the default palace.
        let call = handle_message(
            &state,
            json!({
                "jsonrpc": "2.0",
                "id": 3,
                "method": "tools/call",
                "params": {
                    "name": "memory_remember",
                    "arguments": {"text": "default-palace test memory"},
                },
            }),
        )
        .await;
        // The tool result is JSON encoded inside the first text content item.
        let text = call["result"]["content"][0]["text"]
            .as_str()
            .unwrap_or_else(|| panic!("expected success result, got {call}"));
        let parsed: Value = serde_json::from_str(text).expect("parse content json");
        assert_eq!(parsed["palace"], "default-pal");
        assert_eq!(parsed["status"], "stored");
        assert!(parsed["drawer_id"].as_str().is_some());
    }

    /// Without a default palace, omitting `palace` must fail with a
    /// descriptive internal error (-32603).
    #[tokio::test]
    async fn missing_palace_without_default_errors() {
        let state = test_state();
        let resp = handle_message(
            &state,
            json!({
                "jsonrpc": "2.0",
                "id": 7,
                "method": "tools/call",
                "params": {
                    "name": "memory_recall",
                    "arguments": {"query": "anything"},
                },
            }),
        )
        .await;
        assert_eq!(resp["error"]["code"], -32603);
        let msg = resp["error"]["message"].as_str().unwrap_or("");
        assert!(
            msg.contains("missing 'palace'"),
            "expected helpful error, got: {msg}"
        );
    }

    /// With no default palace configured, `initialize` must not emit the field.
    #[tokio::test]
    async fn initialize_without_default_palace_omits_field() {
        let state = test_state();
        let init = handle_message(
            &state,
            json!({"jsonrpc": "2.0", "id": 1, "method": "initialize"}),
        )
        .await;
        assert!(init["result"]["serverInfo"]["default_palace"].is_null());
    }
}