use anyhow::Result;
use serde_json::{json, Value};
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};
use tokio::sync::{broadcast, OnceCell, RwLock};
use trusty_common::bm25_client::Bm25Client;
use trusty_common::mcp::initialize_response;
use trusty_common::memory_core::embed::FastEmbedder;
use trusty_common::memory_core::store::ChatSessionStore;
use trusty_common::memory_core::PalaceRegistry;
use trusty_common::ChatProvider;
#[cfg(feature = "axum-server")]
use tracing::info;
pub mod activity;
pub mod attribution;
pub mod bm25_supervisor;
pub mod bootstrap;
#[cfg(feature = "axum-server")]
pub mod chat;
pub mod commands;
pub mod discovery;
pub mod hook_emit;
pub mod kg_extract;
pub mod mcp_service;
pub mod messaging;
pub mod openrpc;
pub mod prompt_facts;
pub mod prompt_log;
pub mod service;
pub mod tools;
pub mod transport;
#[cfg(feature = "axum-server")]
pub mod web;
pub use activity::{ActivityEntry, ActivityFilter, ActivityLog, ActivitySource};
pub use attribution::{CreatorInfo, CreatorSource};
/// Maximum number of characters (not bytes) kept by [`hook_prompt_excerpt`].
pub const HOOK_PROMPT_EXCERPT_CHARS: usize = 80;

/// Produces a single-line excerpt of `prompt`.
///
/// Every run of whitespace is collapsed to one space and leading/trailing
/// whitespace is dropped. When the collapsed text is longer than
/// [`HOOK_PROMPT_EXCERPT_CHARS`] characters it is cut so that the result,
/// including the trailing `…`, is exactly that many characters long.
pub fn hook_prompt_excerpt(prompt: &str) -> String {
    // Collapse whitespace by re-joining the whitespace-separated words.
    let mut collapsed = String::with_capacity(prompt.len());
    for word in prompt.split_whitespace() {
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(word);
    }

    // A character at index LIMIT exists only when the text exceeds the limit.
    if collapsed.chars().nth(HOOK_PROMPT_EXCERPT_CHARS).is_none() {
        return collapsed;
    }

    // Reserve one character of the budget for the ellipsis and cut on a
    // character boundary so multi-byte text is never split mid code point.
    let budget = HOOK_PROMPT_EXCERPT_CHARS.saturating_sub(1);
    let cut = collapsed
        .char_indices()
        .nth(budget)
        .map_or(collapsed.len(), |(byte_idx, _)| byte_idx);
    collapsed.truncate(cut);
    collapsed.push('…');
    collapsed
}
pub use mcp_service::MemoryMcpService;
pub use tools::MemoryMcpServer;
/// Picks the directory that holds the palace registry.
///
/// Returns `<data_dir>/palaces` when that path exists and is a directory;
/// otherwise returns `data_dir` unchanged.
pub fn resolve_palace_registry_dir(data_dir: PathBuf) -> PathBuf {
    Some(data_dir.join("palaces"))
        .filter(|candidate| candidate.is_dir())
        .unwrap_or(data_dir)
}
/// Kind of hook recorded in the `hook_type` field of
/// [`DaemonEvent::HookFired`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum HookType {
    /// Reported by [`HookType::as_str`] as `"UserPromptSubmit"`.
    UserPromptSubmit,
    /// Reported by [`HookType::as_str`] as `"SessionStart"`.
    SessionStart,
}

impl HookType {
    /// Returns the variant name as a static string.
    pub fn as_str(&self) -> &'static str {
        match *self {
            HookType::SessionStart => "SessionStart",
            HookType::UserPromptSubmit => "UserPromptSubmit",
        }
    }
}
/// Kind of injection recorded in the `injection_kind` field of
/// [`DaemonEvent::HookFired`].
///
/// Serde renames the variants to kebab-case, which is the same spelling
/// [`InjectionKind::as_str`] returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum InjectionKind {
    /// Reported by [`InjectionKind::as_str`] as `"prompt-context"`.
    PromptContext,
    /// Reported by [`InjectionKind::as_str`] as `"inbox-check"`.
    InboxCheck,
}

impl InjectionKind {
    /// Returns the kebab-case name of the variant as a static string.
    pub fn as_str(&self) -> &'static str {
        match *self {
            InjectionKind::InboxCheck => "inbox-check",
            InjectionKind::PromptContext => "prompt-context",
        }
    }
}
/// Event broadcast to subscribers through [`AppState::emit`].
///
/// Serialized as an internally tagged object: the `type` field carries the
/// snake_case variant name (the same string [`DaemonEvent::type_str`]
/// returns) and the variant's fields sit alongside it.
#[derive(Clone, Debug, serde::Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DaemonEvent {
    /// A palace was created.
    PalaceCreated {
        /// Palace identifier; this is what [`DaemonEvent::palace_id`] returns.
        id: String,
        name: String,
        /// Origin of the action; reported by [`DaemonEvent::source`].
        source: ActivitySource,
    },
    /// A drawer was added to a palace.
    DrawerAdded {
        palace_id: String,
        #[serde(default)]
        palace_name: String,
        // presumably the palace's drawer count after the add — TODO confirm
        drawer_count: usize,
        timestamp: chrono::DateTime<chrono::Utc>,
        #[serde(default)]
        content_preview: String,
        source: ActivitySource,
    },
    /// A drawer was removed from a palace.
    DrawerDeleted {
        palace_id: String,
        // presumably the palace's drawer count after the delete — TODO confirm
        drawer_count: usize,
        source: ActivitySource,
    },
    /// A "dream" pass finished; the counters summarise what it did.
    DreamCompleted {
        /// `None` when the event is not tied to a single palace.
        palace_id: Option<String>,
        merged: usize,
        pruned: usize,
        compacted: usize,
        closets_updated: usize,
        duration_ms: u64,
        source: ActivitySource,
    },
    /// Aggregate totals snapshot. This is the only variant without a
    /// `source`, so [`AppState::emit`] broadcasts it without writing an
    /// activity-log entry.
    StatusChanged {
        total_drawers: usize,
        total_vectors: usize,
        total_kg_triples: usize,
    },
    /// A hook ran.
    HookFired {
        #[serde(default)]
        palace_id: Option<String>,
        #[serde(default)]
        palace_name: Option<String>,
        hook_type: HookType,
        injection_kind: InjectionKind,
        injection_length: u64,
        // presumably produced by `hook_prompt_excerpt` — verify against the emitter
        #[serde(default)]
        trigger_prompt_excerpt: String,
        timestamp: chrono::DateTime<chrono::Utc>,
        duration_ms: u64,
        source: ActivitySource,
    },
}
/// Opens the activity log, degrading gracefully instead of failing.
///
/// Attempts, in order:
/// 1. `ActivityLog::open(data_root)`;
/// 2. `ActivityLog::open` on a per-process directory under the system temp
///    dir (`trusty-memory-activity-<pid>`);
/// 3. `ActivityLog::discard()` when both opens fail.
///
/// Each failed attempt is logged at `warn` level; this function never
/// returns an error.
fn open_activity_log_with_fallback(data_root: &Path) -> Arc<ActivityLog> {
    let primary_err = match ActivityLog::open(data_root) {
        Ok(log) => return Arc::new(log),
        Err(err) => err,
    };
    tracing::warn!(
        "could not open activity log at {}: {primary_err:#}; falling back to per-process tempdir",
        data_root.display()
    );

    // The pid keeps concurrent daemons from sharing one fallback directory.
    let dir_name = format!("trusty-memory-activity-{}", std::process::id());
    let fallback = std::env::temp_dir().join(dir_name);
    let fallback_err = match ActivityLog::open(&fallback) {
        Ok(log) => return Arc::new(log),
        Err(err) => err,
    };
    tracing::warn!(
        "activity log tempdir fallback at {} also failed: {fallback_err:#}; \
         activity feed disabled for this process (no-op log)",
        fallback.display()
    );

    Arc::new(ActivityLog::discard())
}
impl DaemonEvent {
    /// Returns the snake_case tag of the variant, matching the `type` field
    /// of the serialized event.
    pub fn type_str(&self) -> &'static str {
        match self {
            Self::HookFired { .. } => "hook_fired",
            Self::StatusChanged { .. } => "status_changed",
            Self::DreamCompleted { .. } => "dream_completed",
            Self::DrawerDeleted { .. } => "drawer_deleted",
            Self::DrawerAdded { .. } => "drawer_added",
            Self::PalaceCreated { .. } => "palace_created",
        }
    }

    /// Returns the palace the event refers to, if it names one.
    ///
    /// `StatusChanged` never has one; `DreamCompleted` and `HookFired` carry
    /// an optional id.
    pub fn palace_id(&self) -> Option<&str> {
        match self {
            Self::StatusChanged { .. } => None,
            Self::PalaceCreated { id, .. } => Some(id.as_str()),
            Self::DrawerAdded { palace_id, .. } | Self::DrawerDeleted { palace_id, .. } => {
                Some(palace_id.as_str())
            }
            Self::DreamCompleted { palace_id, .. } | Self::HookFired { palace_id, .. } => {
                palace_id.as_deref()
            }
        }
    }

    /// Returns the origin of the event, or `None` for `StatusChanged`, the
    /// only variant without a `source` field.
    pub fn source(&self) -> Option<ActivitySource> {
        let origin = match self {
            Self::StatusChanged { .. } => return None,
            Self::PalaceCreated { source, .. }
            | Self::DrawerAdded { source, .. }
            | Self::DrawerDeleted { source, .. }
            | Self::DreamCompleted { source, .. }
            | Self::HookFired { source, .. } => source,
        };
        Some(*origin)
    }
}
/// Shared daemon state handed to the request handlers and background tasks
/// in this file.
///
/// The struct derives `Clone`; most fields are `Arc`-wrapped, so clones share
/// the same registry, caches, and channels.
#[derive(Clone)]
pub struct AppState {
    /// Crate version (`CARGO_PKG_VERSION`), reported by `initialize`.
    pub version: String,
    /// In-memory palace registry; filled by `load_palaces_from_disk`.
    pub registry: Arc<PalaceRegistry>,
    /// Root directory used for palace listing, per-palace session stores,
    /// and the activity log.
    pub data_root: PathBuf,
    /// Lazily created embedder; initialised on first call to `embedder()`.
    pub embedder: Arc<OnceCell<Arc<FastEmbedder>>>,
    /// Default palace name set via `with_default_palace`, if any.
    pub default_palace: Option<String>,
    /// Lazily resolved chat provider; the inner `Option` is `None` when
    /// `chat_provider()` found no provider.
    pub chat_provider: Arc<OnceCell<Option<Arc<dyn ChatProvider>>>>,
    /// Cache of opened chat-session stores keyed by palace id.
    pub session_stores: Arc<dashmap::DashMap<String, Arc<ChatSessionStore>>>,
    /// Broadcast channel that `emit` sends on and `sse_handler` subscribes to.
    pub events: Arc<broadcast::Sender<DaemonEvent>>,
    /// Instant captured when this state was constructed.
    pub started_at: std::time::Instant,
    /// Log buffer; replaceable through `with_log_buffer`.
    pub log_buffer: trusty_common::log_buffer::LogBuffer,
    /// Last measured size of `data_root` in bytes; refreshed by the disk
    /// size ticker, `0` until the first measurement.
    pub disk_bytes: Arc<std::sync::atomic::AtomicU64>,
    // NOTE(review): not read in this file — presumably sampled by the web layer.
    pub sys_metrics: Arc<tokio::sync::Mutex<trusty_common::sys_metrics::SysMetrics>>,
    /// Address the HTTP listener is bound to; set once by `run_http_on`.
    pub bound_addr: Arc<OnceLock<SocketAddr>>,
    // NOTE(review): not read in this file — see `prompt_facts` for its use.
    pub prompt_context_cache: Arc<RwLock<prompt_facts::PromptFactsCache>>,
    /// Activity log that `emit` appends to.
    pub activity_log: Arc<ActivityLog>,
    /// BM25 client; `None` unless enabled by `with_bm25_client_from_env`.
    pub bm25_client: Option<Arc<Bm25Client>>,
    /// BM25 supervisor; `None` unless enabled by `with_bm25_client_from_env`.
    /// Shut down by `run_http_on` after the server stops.
    pub bm25_supervisor: Option<Arc<bm25_supervisor::Bm25Supervisor>>,
    /// One async mutex per palace id, created on demand by `palace_write_lock`.
    pub palace_write_locks: Arc<dashmap::DashMap<String, Arc<tokio::sync::Mutex<()>>>>,
    /// Number of activity-log appends queued by `emit` that have not finished.
    pub pending_activity_writes: Arc<AtomicUsize>,
    /// Palace id → display name cache; filled by `load_palaces_from_disk`.
    pub palace_names: Arc<dashmap::DashMap<String, String>>,
    /// Sender half of the queue consumed by the BM25 index worker.
    pub bm25_index_tx: tokio::sync::mpsc::Sender<tools::Bm25IndexRequest>,
}
impl AppState {
/// Builds a fresh state rooted at `data_root`.
///
/// The registry, caches, and lock maps start empty, no default palace is
/// set, and BM25 client/supervisor are disabled. The activity log is opened
/// via `open_activity_log_with_fallback`, so construction never fails on
/// an unwritable data root. A BM25 index worker is spawned immediately with
/// no client and no supervisor attached.
pub fn new(data_root: PathBuf) -> Self {
    let (events_tx, _) = broadcast::channel::<DaemonEvent>(128);
    let activity_log = open_activity_log_with_fallback(&data_root);

    // Queue feeding the BM25 index worker; the worker owns the receiver.
    let (bm25_index_tx, bm25_index_rx) =
        tokio::sync::mpsc::channel::<tools::Bm25IndexRequest>(tools::BM25_INDEX_QUEUE_CAPACITY);
    tools::spawn_bm25_index_worker(bm25_index_rx, None, None);

    let version = env!("CARGO_PKG_VERSION").to_owned();
    let log_buffer = trusty_common::log_buffer::LogBuffer::new(
        trusty_common::log_buffer::DEFAULT_LOG_CAPACITY,
    );
    let sys_metrics = trusty_common::sys_metrics::SysMetrics::new();
    let prompt_facts_cache = prompt_facts::PromptFactsCache::default();

    Self {
        version,
        registry: Arc::new(PalaceRegistry::new()),
        data_root,
        embedder: Arc::new(OnceCell::new()),
        default_palace: None,
        chat_provider: Arc::new(OnceCell::new()),
        session_stores: Arc::new(dashmap::DashMap::new()),
        events: Arc::new(events_tx),
        started_at: std::time::Instant::now(),
        log_buffer,
        disk_bytes: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        sys_metrics: Arc::new(tokio::sync::Mutex::new(sys_metrics)),
        bound_addr: Arc::new(OnceLock::new()),
        prompt_context_cache: Arc::new(RwLock::new(prompt_facts_cache)),
        activity_log,
        bm25_client: None,
        bm25_supervisor: None,
        palace_write_locks: Arc::new(dashmap::DashMap::new()),
        pending_activity_writes: Arc::new(AtomicUsize::new(0)),
        palace_names: Arc::new(dashmap::DashMap::new()),
        bm25_index_tx,
    }
}
/// Returns the write mutex for `palace_id`, creating it on first use.
///
/// Repeated calls with the same id hand back clones of the same `Arc`.
pub fn palace_write_lock(&self, palace_id: &str) -> Arc<tokio::sync::Mutex<()>> {
    // Fast path: a plain lookup avoids allocating the key for known palaces.
    let cached = self
        .palace_write_locks
        .get(palace_id)
        .map(|slot| Arc::clone(slot.value()));

    cached.unwrap_or_else(|| {
        // Slow path: `entry` ensures concurrent first callers agree on one mutex.
        let slot = self
            .palace_write_locks
            .entry(palace_id.to_owned())
            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())));
        Arc::clone(&*slot)
    })
}
/// Enables the BM25 daemon integration when `TRUSTY_BM25_DAEMON` is
/// exactly `"1"`; otherwise returns `self` untouched.
///
/// When enabled this installs a client for the default palace (falling back
/// to the name `"default"`), creates a supervisor, and replaces the index
/// queue with a new one whose worker is wired to that client and supervisor.
/// The palace name is read from `self.default_palace` at call time, so call
/// `with_default_palace` first if a specific palace is wanted.
#[must_use]
pub fn with_bm25_client_from_env(mut self) -> Self {
    let enabled = matches!(
        std::env::var("TRUSTY_BM25_DAEMON").as_deref(),
        Ok("1")
    );
    if !enabled {
        return self;
    }

    let default_palace = self.default_palace.as_deref().unwrap_or("default");
    self.bm25_client = Some(Arc::new(Bm25Client::for_palace(default_palace)));
    self.bm25_supervisor = Some(Arc::new(bm25_supervisor::Bm25Supervisor::new()));

    // Swap in a fresh queue so index requests reach the newly wired worker.
    let (tx, rx) =
        tokio::sync::mpsc::channel::<tools::Bm25IndexRequest>(tools::BM25_INDEX_QUEUE_CAPACITY);
    tools::spawn_bm25_index_worker(rx, self.bm25_client.clone(), self.bm25_supervisor.clone());
    self.bm25_index_tx = tx;

    tracing::info!(
        palace = default_palace,
        "BM25 daemon client + spawn supervisor enabled (TRUSTY_BM25_DAEMON=1)"
    );
    self
}
pub async fn load_palaces_from_disk(&self) -> Result<usize> {
let registry_dir = self.data_root.clone();
let registry = self.registry.clone();
let palace_names = self.palace_names.clone();
let count = tokio::task::spawn_blocking(move || -> Result<usize> {
let palaces = PalaceRegistry::list_palaces(®istry_dir)?;
let total = palaces.len();
let mut loaded = 0usize;
let mut skipped = 0usize;
for palace in palaces {
match trusty_common::memory_core::PalaceHandle::open(&palace) {
Ok(handle) => {
tracing::debug!(
palace = %palace.id,
data_dir = %palace.data_dir.display(),
"loaded palace from disk"
);
palace_names.insert(palace.id.0.clone(), palace.name.clone());
registry.register_arc(handle);
loaded += 1;
}
Err(e) => {
tracing::warn!(
palace = %palace.id,
data_dir = %palace.data_dir.display(),
"skipping palace during startup hydration: {e:#}"
);
skipped += 1;
}
}
}
tracing::info!(
"palace hydration summary: loaded {loaded}/{total} ({skipped} skipped due to errors)"
);
Ok(loaded)
})
.await
.map_err(|e| anyhow::anyhow!("join load_palaces_from_disk: {e}"))??;
Ok(count)
}
#[must_use]
pub fn with_log_buffer(mut self, buffer: trusty_common::log_buffer::LogBuffer) -> Self {
self.log_buffer = buffer;
self
}
pub fn emit(&self, event: DaemonEvent) {
if let Some(source) = event.source() {
let event_type = event.type_str();
let palace_id = event.palace_id().map(|s| s.to_string());
let log = Arc::clone(&self.activity_log);
let event_for_log = event.clone();
let pending = Arc::clone(&self.pending_activity_writes);
let id = log.alloc_id();
pending.fetch_add(1, Ordering::SeqCst);
tokio::task::spawn_blocking(move || {
let result = log.append_with_id(id, source, palace_id, event_type, &event_for_log);
if let Err(e) = result {
tracing::warn!("activity_log.append failed for {event_type}: {e:#}");
}
pending.fetch_sub(1, Ordering::SeqCst);
});
}
let _ = self.events.send(event);
}
/// Waits until every activity-log append queued by `emit` has finished.
///
/// Polls `pending_activity_writes` once per millisecond; returns immediately
/// when nothing is pending.
pub async fn flush_activity_writes(&self) {
    let poll_every = std::time::Duration::from_millis(1);
    loop {
        if self.pending_activity_writes.load(Ordering::SeqCst) == 0 {
            break;
        }
        tokio::time::sleep(poll_every).await;
    }
}
/// Returns the chat-session store for `palace_id`, opening it on first use.
///
/// On a cache miss the directory `<data_root>/<palace_id>` is created if
/// needed, `chat_sessions.db` inside it is opened, and the store is cached
/// for later calls.
///
/// # Errors
///
/// Fails when the palace directory cannot be created or the store cannot be
/// opened.
pub fn session_store(&self, palace_id: &str) -> Result<Arc<ChatSessionStore>> {
    let cached = self
        .session_stores
        .get(palace_id)
        .map(|slot| Arc::clone(slot.value()));
    if let Some(store) = cached {
        return Ok(store);
    }

    let dir = self.data_root.join(palace_id);
    if let Err(e) = std::fs::create_dir_all(&dir) {
        return Err(anyhow::anyhow!("create palace dir {}: {e}", dir.display()));
    }

    let db_path = dir.join("chat_sessions.db");
    let opened = Arc::new(ChatSessionStore::open(&db_path)?);
    self.session_stores
        .insert(palace_id.to_owned(), Arc::clone(&opened));
    Ok(opened)
}
/// Builder-style setter for the default palace.
///
/// `Some(name)` makes `name` the palace used when a request omits one;
/// `None` clears the default. Marked `#[must_use]` like the other `with_*`
/// builders: the method consumes `self`, so discarding the result discards
/// the state.
#[must_use]
pub fn with_default_palace(mut self, name: Option<String>) -> Self {
    self.default_palace = name;
    self
}
/// Returns the chat provider, resolving it on the first call and caching
/// the outcome (including "none available") for all later calls.
///
/// Resolution order:
/// 1. a local provider, when enabled in the user config and auto-detected
///    at the configured base URL (its model is overridden from the config);
/// 2. OpenRouter, when an API key is configured;
/// 3. otherwise `None`.
///
/// A user config that fails to load is treated as the default config.
pub async fn chat_provider(&self) -> Option<Arc<dyn ChatProvider>> {
    let resolved = self
        .chat_provider
        .get_or_init(|| async {
            let cfg = crate::service::load_user_config().unwrap_or_default();

            // Only probe for a local provider when it is enabled in the config.
            let detected = if cfg.local_model.enabled {
                trusty_common::auto_detect_local_provider(&cfg.local_model.base_url).await
            } else {
                None
            };
            if let Some(mut provider) = detected {
                provider.model = cfg.local_model.model.clone();
                return Some(Arc::new(provider) as Arc<dyn ChatProvider>);
            }

            if cfg.openrouter_api_key.is_empty() {
                return None;
            }
            let remote = trusty_common::OpenRouterProvider::new(
                cfg.openrouter_api_key,
                cfg.openrouter_model,
            );
            Some(Arc::new(remote) as Arc<dyn ChatProvider>)
        })
        .await;
    resolved.clone()
}
/// Starts the `discover_aliases` tool for `palace` / `project_root` on a
/// background task and returns immediately.
///
/// The outcome is only logged: `info` on success (with the `new` and
/// `already_known` fields of the result), `warn` on failure.
pub fn spawn_alias_discovery(&self, palace: String, project_root: PathBuf) {
    let state = self.clone();
    let discovery = async move {
        let args = serde_json::json!({
            "palace": palace,
            "project_root": project_root.to_string_lossy(),
        });
        let outcome = tools::dispatch_tool(&state, "discover_aliases", args).await;
        match outcome {
            Err(e) => tracing::warn!("alias discovery failed: {e:#}"),
            Ok(result) => tracing::info!(
                new = ?result.get("new"),
                already_known = ?result.get("already_known"),
                "alias discovery complete"
            ),
        }
    };
    tokio::spawn(discovery);
}
pub async fn embedder(&self) -> Result<Arc<FastEmbedder>> {
let cell = self.embedder.clone();
let embedder = cell
.get_or_try_init(|| async {
let e = FastEmbedder::new().await?;
Ok::<Arc<FastEmbedder>, anyhow::Error>(Arc::new(e))
})
.await?
.clone();
Ok(embedder)
}
}
/// Compact `Debug` output: only the version, the data root, and the number
/// of registered palaces are printed.
impl std::fmt::Debug for AppState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let registry_len = self.registry.len();
        let mut out = f.debug_struct("AppState");
        out.field("version", &self.version);
        out.field("data_root", &self.data_root);
        out.field("registry_len", &registry_len);
        out.finish()
    }
}
pub async fn handle_message(state: &AppState, msg: Value) -> Value {
let id = msg.get("id").cloned().unwrap_or(Value::Null);
let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
match method {
"initialize" => {
let extra = state
.default_palace
.as_ref()
.map(|dp| json!({ "default_palace": dp }));
let result = initialize_response("trusty-memory", &state.version, extra);
json!({
"jsonrpc": "2.0",
"id": id,
"result": result,
})
}
"notifications/initialized" | "notifications/cancelled" => Value::Null,
"tools/list" => json!({
"jsonrpc": "2.0",
"id": id,
"result": tools::tool_definitions_with(state.default_palace.is_some())
}),
"rpc.discover" => json!({
"jsonrpc": "2.0",
"id": id,
"result": openrpc::build_discover_response(
&state.version,
state.default_palace.is_some(),
),
}),
"tools/call" => {
let params = msg.get("params").cloned().unwrap_or_default();
let tool_name = params
.get("name")
.and_then(|n| n.as_str())
.unwrap_or("")
.to_string();
let args = params.get("arguments").cloned().unwrap_or_default();
match tools::dispatch_tool(state, &tool_name, args).await {
Ok(content) => {
let text = match &content {
Value::String(s) => s.clone(),
other => other.to_string(),
};
json!({
"jsonrpc": "2.0",
"id": id,
"result": {
"content": [{"type": "text", "text": text}]
}
})
}
Err(e) => json!({
"jsonrpc": "2.0",
"id": id,
"error": {"code": -32603, "message": format!("{e:#}")}
}),
}
}
"ping" => json!({"jsonrpc": "2.0", "id": id, "result": {}}),
_ => json!({
"jsonrpc": "2.0",
"id": id,
"error": {
"code": -32601,
"message": format!("Method not found: {method}")
}
}),
}
}
/// Preferred loopback port for the HTTP server; `bind_dynamic_port` tries
/// this port first.
pub const DEFAULT_HTTP_PORT: u16 = 7070;
/// Size of the port window starting at `DEFAULT_HTTP_PORT` that
/// `bind_dynamic_port` scans before asking the OS for any free port
/// (its log message reports the window as 7070..7079).
const DYNAMIC_PORT_RANGE: u16 = 10;
/// Returns the path of the `http_addr` discovery file inside the
/// `trusty-memory` data directory, or `None` when that directory cannot be
/// resolved.
pub fn http_addr_path() -> Option<PathBuf> {
    match trusty_common::resolve_data_dir("trusty-memory") {
        Ok(data_dir) => Some(data_dir.join("http_addr")),
        Err(_) => None,
    }
}
/// Binds a loopback TCP listener, preferring the well-known port window.
///
/// First asks `trusty_common::bind_with_auto_port` for
/// `127.0.0.1:DEFAULT_HTTP_PORT` with `DYNAMIC_PORT_RANGE - 1` as its second
/// argument. If that fails, logs a warning and requests port `0` so the OS
/// picks a free port.
///
/// # Errors
///
/// Returns the error from the final OS-assigned bind attempt.
pub async fn bind_dynamic_port() -> Result<tokio::net::TcpListener> {
    let loopback = std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST);

    let preferred = SocketAddr::new(loopback, DEFAULT_HTTP_PORT);
    match trusty_common::bind_with_auto_port(preferred, DYNAMIC_PORT_RANGE - 1).await {
        Ok(listener) => Ok(listener),
        Err(_) => {
            tracing::warn!(
                "all ports {DEFAULT_HTTP_PORT}..{} in use; requesting OS-assigned port",
                DEFAULT_HTTP_PORT + DYNAMIC_PORT_RANGE - 1
            );
            let any = SocketAddr::new(loopback, 0);
            trusty_common::bind_with_auto_port(any, 0).await
        }
    }
}
#[cfg(feature = "axum-server")]
/// Atomically writes `addr` (followed by a newline) to `path`.
///
/// The parent directory is created if needed. The content is first written
/// and synced to a sibling temp file (`path` with its extension replaced by
/// `addr.tmp`) and then renamed over `path`, so readers never observe a
/// partially written file.
///
/// # Errors
///
/// Returns the first I/O error encountered. On any failure after the parent
/// directory exists, the temp file is removed on a best-effort basis so a
/// failed write does not leave stale staging files behind.
fn write_http_addr_file(path: &Path, addr: &SocketAddr) -> std::io::Result<()> {
    use std::io::Write;

    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }

    let tmp = path.with_extension("addr.tmp");
    let staged = std::fs::File::create(&tmp)
        .and_then(|mut f| {
            writeln!(f, "{addr}")?;
            // Flush to disk before the rename makes the file visible.
            f.sync_all()
        })
        // The file handle is closed by the time the rename runs.
        .and_then(|()| std::fs::rename(&tmp, path));

    if staged.is_err() {
        // Best-effort cleanup; the original error is what the caller sees.
        let _ = std::fs::remove_file(&tmp);
    }
    staged
}
/// Serves the HTTP API on an already-bound `listener` until the server
/// stops, then cleans up.
///
/// Startup, in order: spawns the disk-size and status-event tickers, records
/// the bound address in `state.bound_addr`, announces it via the log and
/// stderr, writes the `http_addr` discovery file (best effort), and starts
/// the UDS listener (best effort).
///
/// Shutdown, after `axum::serve` returns: removes the discovery file and the
/// UDS socket file if they were created, shuts down the BM25 supervisor if
/// one is configured, and only then propagates the serve result.
///
/// # Errors
///
/// Returns the error from `axum::serve`, if any. Failures to write the
/// discovery file or bind the UDS socket are logged and do not abort.
#[cfg(feature = "axum-server")]
pub async fn run_http_on(state: AppState, listener: tokio::net::TcpListener) -> Result<()> {
    use axum::routing::get;
    spawn_disk_size_ticker(state.clone());
    spawn_status_event_ticker(state.clone());
    let local = listener.local_addr().ok();
    // `Some(path)` only when the discovery file was actually written, so
    // shutdown removes exactly what this run created.
    let written_path = if let Some(a) = local {
        // `set` fails if an address was already recorded; that is ignored.
        let _ = state.bound_addr.set(a);
        info!("HTTP server listening on http://{a}");
        eprintln!("HTTP server listening on http://{a}");
        match http_addr_path() {
            Some(p) => match write_http_addr_file(&p, &a) {
                Ok(()) => {
                    info!("wrote daemon address to {}", p.display());
                    Some(p)
                }
                Err(e) => {
                    tracing::warn!("could not write {}: {e}", p.display());
                    None
                }
            },
            None => {
                // NOTE(review): `http_addr_path` returns `None` whenever the
                // data dir cannot be resolved; the "$HOME" wording assumes
                // that is the cause — confirm.
                tracing::warn!("no $HOME — skipping http_addr discovery file");
                None
            }
        }
    } else {
        None
    };
    let uds_sock_path = spawn_uds_listener(state.clone()).await;
    // Keep a handle to the supervisor: `state` is moved into the router below.
    let bm25_supervisor = state.bm25_supervisor.clone();
    let app = web::router()
        .route("/sse", get(sse_handler))
        .with_state(state);
    let serve_result = axum::serve(listener, app).await;
    // Cleanup runs regardless of how serving ended; the result is checked last.
    if let Some(p) = written_path.as_ref() {
        let _ = std::fs::remove_file(p);
    }
    if let Some(p) = uds_sock_path.as_ref() {
        let _ = std::fs::remove_file(p);
    }
    if let Some(supervisor) = bm25_supervisor {
        supervisor.shutdown().await;
    }
    serve_result?;
    Ok(())
}
/// Binds the Unix-domain-socket transport and starts its accept loop on a
/// background task.
///
/// Returns the socket path on success so the caller can remove it at
/// shutdown, or `None` when binding fails (logged at `warn`; the daemon
/// continues without UDS). A failure to write the UDS address file is
/// logged but does not prevent the listener from starting.
#[cfg(feature = "axum-server")]
async fn spawn_uds_listener(state: AppState) -> Option<PathBuf> {
    let sock_path = transport::uds::socket_path_for(&state.data_root);

    let listener = transport::uds::bind_uds(&sock_path)
        .await
        .map_err(|e| {
            tracing::warn!(
                "UDS bind at {} failed: {e:#}; continuing without UDS transport",
                sock_path.display()
            );
        })
        .ok()?;

    info!("UDS listener bound at {}", sock_path.display());
    eprintln!("UDS listener bound at {}", sock_path.display());

    match transport::uds::write_uds_addr_file(&state.data_root, &sock_path) {
        Ok(_) => {}
        Err(e) => tracing::warn!(
            "could not write {}/{}: {e:#}",
            state.data_root.display(),
            transport::uds::UDS_ADDR_FILE
        ),
    }

    tokio::spawn(async move {
        let served = transport::uds::run_uds(state, listener).await;
        if let Err(e) = served {
            tracing::error!("UDS accept loop exited: {e:#}");
        }
    });

    Some(sock_path)
}
/// Binds a TCP listener at `addr` and serves the HTTP API on it via
/// `run_http_on`.
///
/// # Errors
///
/// Fails when the address cannot be bound or when serving fails.
#[cfg(feature = "axum-server")]
pub async fn run_http(state: AppState, addr: std::net::SocketAddr) -> Result<()> {
    let bound = tokio::net::TcpListener::bind(addr).await;
    run_http_on(state, bound?).await
}
/// Serves the HTTP API on a listener obtained from `bind_dynamic_port`.
///
/// # Errors
///
/// Fails when no port can be bound or when serving fails.
#[cfg(feature = "axum-server")]
pub async fn run_http_dynamic(state: AppState) -> Result<()> {
    let bound = bind_dynamic_port().await;
    run_http_on(state, bound?).await
}
/// Spawns a background task that re-measures the size of `data_root` every
/// 10 seconds and stores it in `state.disk_bytes`.
///
/// The directory walk runs on the blocking thread pool; if that task cannot
/// be joined the stored value becomes `0`. The loop never exits on its own.
#[cfg(feature = "axum-server")]
fn spawn_disk_size_ticker(state: AppState) {
    let period = std::time::Duration::from_secs(10);
    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(period);
        loop {
            ticker.tick().await;
            let root = state.data_root.clone();
            let measured = tokio::task::spawn_blocking(move || {
                trusty_common::sys_metrics::dir_size_bytes(&root)
            })
            .await;
            let bytes = match measured {
                Ok(size) => size,
                Err(_) => 0,
            };
            state.disk_bytes.store(bytes, Ordering::Relaxed);
        }
    });
}
/// Seconds between two aggregate status events.
const STATUS_EVENT_TICK_SECS: u64 = 30;

/// Spawns a background task that, every `STATUS_EVENT_TICK_SECS` seconds,
/// builds the aggregate status event and emits it through `AppState::emit`.
/// The loop never exits on its own.
fn spawn_status_event_ticker(state: AppState) {
    let period = std::time::Duration::from_secs(STATUS_EVENT_TICK_SECS);
    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(period);
        loop {
            ticker.tick().await;
            let svc = service::MemoryService::new(state.clone());
            state.emit(svc.aggregate_status_event());
        }
    });
}
/// Server-sent-events endpoint streaming every [`DaemonEvent`] as JSON.
///
/// The stream starts with a `{"type":"connected"}` frame, followed by one
/// `data:` frame per broadcast event. When the subscriber falls behind, a
/// `{"type":"lag","skipped":N}` frame reports how many events were missed.
/// If an event cannot be serialized, a `{"type":"error","message":…}` frame
/// is sent instead; the message is encoded as a JSON string so quotes or
/// backslashes in the error text cannot corrupt the frame.
#[cfg(feature = "axum-server")]
pub(crate) async fn sse_handler(
    axum::extract::State(state): axum::extract::State<AppState>,
) -> impl axum::response::IntoResponse {
    use futures::StreamExt;
    use tokio_stream::wrappers::BroadcastStream;

    let rx = state.events.subscribe();
    let initial = futures::stream::once(async {
        Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(
            "data: {\"type\":\"connected\"}\n\n",
        ))
    });
    let events = BroadcastStream::new(rx).map(|res| {
        let frame = match res {
            Ok(event) => match serde_json::to_string(&event) {
                Ok(json) => format!("data: {json}\n\n"),
                Err(e) => {
                    // Encode the text as a JSON string literal (with its
                    // surrounding quotes) rather than pasting it in raw.
                    let message = serde_json::to_string(&e.to_string())
                        .unwrap_or_else(|_| String::from("\"event serialization failed\""));
                    format!("data: {{\"type\":\"error\",\"message\":{message}}}\n\n")
                }
            },
            Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => {
                format!("data: {{\"type\":\"lag\",\"skipped\":{n}}}\n\n")
            }
        };
        Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(frame))
    });
    let stream = initial.chain(events);

    axum::response::Response::builder()
        .header("Content-Type", "text/event-stream")
        .header("Cache-Control", "no-cache")
        // Ask buffering reverse proxies to pass frames through immediately.
        .header("X-Accel-Buffering", "no")
        .body(axum::body::Body::from_stream(stream))
        .expect("valid SSE response")
}
#[cfg(test)]
mod tests {
use super::*;
fn test_state() -> (AppState, tempfile::TempDir) {
let tmp = tempfile::tempdir().expect("tempdir");
let root = tmp.path().to_path_buf();
(AppState::new(root), tmp)
}
#[tokio::test]
async fn initialize_returns_protocol_version_and_capabilities() {
let (state, _tmp) = test_state();
let req = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test", "version": "0"}
}
});
let resp = handle_message(&state, req).await;
assert_eq!(resp["jsonrpc"], "2.0");
assert_eq!(resp["id"], 1);
assert_eq!(resp["result"]["protocolVersion"], "2024-11-05");
assert!(resp["result"]["capabilities"]["tools"].is_object());
assert_eq!(resp["result"]["serverInfo"]["name"], "trusty-memory");
}
#[tokio::test]
async fn initialized_notification_returns_null() {
let (state, _tmp) = test_state();
let req = json!({
"jsonrpc": "2.0",
"method": "notifications/initialized",
"params": {}
});
let resp = handle_message(&state, req).await;
assert!(resp.is_null());
}
#[tokio::test]
async fn tools_list_returns_all_tools() {
let (state, _tmp) = test_state();
let req = json!({"jsonrpc": "2.0", "id": 2, "method": "tools/list"});
let resp = handle_message(&state, req).await;
let tools = resp["result"]["tools"].as_array().expect("tools array");
assert_eq!(tools.len(), 23);
}
#[tokio::test]
async fn unknown_method_returns_error() {
let (state, _tmp) = test_state();
let req = json!({"jsonrpc": "2.0", "id": 4, "method": "wat"});
let resp = handle_message(&state, req).await;
assert_eq!(resp["error"]["code"], -32601);
}
#[tokio::test]
async fn ping_returns_empty_result() {
let (state, _tmp) = test_state();
let req = json!({"jsonrpc": "2.0", "id": 5, "method": "ping"});
let resp = handle_message(&state, req).await;
assert!(resp["result"].is_object());
}
#[tokio::test]
async fn app_state_default_constructs() {
let (s, _tmp) = test_state();
assert!(!s.version.is_empty());
assert!(s.registry.is_empty());
assert!(s.default_palace.is_none());
}
#[test]
#[cfg(unix)]
fn open_activity_log_with_fallback_returns_discard_when_unwritable() {
if unsafe { libc::geteuid() } == 0 {
eprintln!(
"skipping open_activity_log_with_fallback_returns_discard_when_unwritable: running as root"
);
return;
}
use std::os::unix::fs::PermissionsExt;
let outer = tempfile::tempdir().expect("outer tempdir");
let primary = outer.path().join("primary");
let tmpdir = outer.path().join("fake-tmp");
std::fs::create_dir(&primary).expect("create primary");
std::fs::create_dir(&tmpdir).expect("create tmpdir");
std::fs::set_permissions(&primary, std::fs::Permissions::from_mode(0o000))
.expect("chmod primary");
std::fs::set_permissions(&tmpdir, std::fs::Permissions::from_mode(0o000))
.expect("chmod tmpdir");
let prev_tmpdir = std::env::var_os("TMPDIR");
std::env::set_var("TMPDIR", &tmpdir);
let log = open_activity_log_with_fallback(&primary);
match prev_tmpdir {
Some(v) => std::env::set_var("TMPDIR", v),
None => std::env::remove_var("TMPDIR"),
}
let _ = std::fs::set_permissions(&primary, std::fs::Permissions::from_mode(0o700));
let _ = std::fs::set_permissions(&tmpdir, std::fs::Permissions::from_mode(0o700));
assert!(
log.is_discard(),
"expected ActivityLog::Discard when both data root and tempdir are unwritable"
);
let id = log
.append(
ActivitySource::Http,
None,
"drawer_added",
json!({"smoke": true}),
)
.expect("discard append must succeed");
assert_eq!(id, 0);
assert_eq!(log.count().expect("discard count"), 0);
assert!(log
.list(&ActivityFilter::default(), 10, 0)
.expect("discard list")
.is_empty());
}
#[tokio::test]
async fn default_palace_used_when_arg_omitted() {
let tmp = tempfile::tempdir().expect("tempdir");
let root = tmp.path().to_path_buf();
let registry = trusty_common::memory_core::PalaceRegistry::new();
let palace = trusty_common::memory_core::Palace {
id: trusty_common::memory_core::PalaceId::new("default-pal"),
name: "default-pal".to_string(),
description: None,
created_at: chrono::Utc::now(),
data_dir: root.join("default-pal"),
};
registry
.create_palace(&root, palace)
.expect("create_palace");
let state = AppState::new(root).with_default_palace(Some("default-pal".to_string()));
let init = handle_message(
&state,
json!({"jsonrpc": "2.0", "id": 1, "method": "initialize"}),
)
.await;
assert_eq!(
init["result"]["serverInfo"]["default_palace"], "default-pal",
"initialize must echo default_palace in serverInfo"
);
let list = handle_message(
&state,
json!({"jsonrpc": "2.0", "id": 2, "method": "tools/list"}),
)
.await;
let tools = list["result"]["tools"].as_array().expect("tools array");
let remember = tools
.iter()
.find(|t| t["name"] == "memory_remember")
.expect("memory_remember tool");
let required: Vec<&str> = remember["inputSchema"]["required"]
.as_array()
.expect("required array")
.iter()
.filter_map(|v| v.as_str())
.collect();
assert!(
!required.contains(&"palace"),
"palace must not be required when default is configured; got {required:?}"
);
assert!(required.contains(&"text"));
let call = handle_message(
&state,
json!({
"jsonrpc": "2.0",
"id": 3,
"method": "tools/call",
"params": {
"name": "memory_remember",
"arguments": {"text": "default palace test memory content with several tokens"},
},
}),
)
.await;
let text = call["result"]["content"][0]["text"]
.as_str()
.unwrap_or_else(|| panic!("expected success result, got {call}"));
let parsed: Value = serde_json::from_str(text).expect("parse content json");
assert_eq!(parsed["palace"], "default-pal");
assert_eq!(parsed["status"], "stored");
assert!(parsed["drawer_id"].as_str().is_some());
}
#[tokio::test]
async fn missing_palace_without_default_errors() {
let (state, _tmp) = test_state();
let resp = handle_message(
&state,
json!({
"jsonrpc": "2.0",
"id": 7,
"method": "tools/call",
"params": {
"name": "memory_recall",
"arguments": {"query": "anything"},
},
}),
)
.await;
assert_eq!(resp["error"]["code"], -32603);
let msg = resp["error"]["message"].as_str().unwrap_or("");
assert!(
msg.contains("missing 'palace'"),
"expected helpful error, got: {msg}"
);
}
#[tokio::test]
async fn load_palaces_from_disk_rehydrates_registry() {
use trusty_common::memory_core::{Palace, PalaceId, PalaceRegistry};
let tmp = tempfile::tempdir().expect("tempdir");
let root = tmp.path().to_path_buf();
{
let writer = PalaceRegistry::new();
for id in ["alpha", "beta"] {
let palace = Palace {
id: PalaceId::new(id),
name: id.to_string(),
description: None,
created_at: chrono::Utc::now(),
data_dir: root.join(id),
};
writer
.create_palace(&root, palace)
.expect("persist palace to disk");
}
}
std::fs::create_dir_all(root.join("not-a-palace")).expect("mkdir");
let state = AppState::new(root);
assert!(
state.registry.is_empty(),
"AppState::new must start with an empty registry"
);
let count = state
.load_palaces_from_disk()
.await
.expect("load_palaces_from_disk");
assert_eq!(count, 2, "both persisted palaces should be loaded");
assert_eq!(state.registry.len(), 2, "registry should hold both palaces");
let ids: Vec<String> = state.registry.list().into_iter().map(|p| p.0).collect();
assert!(ids.contains(&"alpha".to_string()));
assert!(ids.contains(&"beta".to_string()));
}
#[test]
fn resolve_palace_registry_dir_prefers_palaces_subdir() {
let tmp = tempfile::tempdir().expect("tempdir");
let data_dir = tmp.path().to_path_buf();
std::fs::create_dir_all(data_dir.join("palaces")).expect("mkdir palaces");
let resolved = resolve_palace_registry_dir(data_dir.clone());
assert_eq!(resolved, data_dir.join("palaces"));
}
#[test]
fn resolve_palace_registry_dir_falls_back_to_data_dir() {
let tmp = tempfile::tempdir().expect("tempdir");
let data_dir = tmp.path().to_path_buf();
let resolved = resolve_palace_registry_dir(data_dir.clone());
assert_eq!(resolved, data_dir);
}
#[tokio::test]
async fn load_palaces_from_disk_handles_palaces_subdir() {
use trusty_common::memory_core::{Palace, PalaceId, PalaceRegistry};
let tmp = tempfile::tempdir().expect("tempdir");
let root = tmp.path().to_path_buf();
let nested = root.join("palaces");
{
let writer = PalaceRegistry::new();
for id in ["cto", "engineering"] {
let palace = Palace {
id: PalaceId::new(id),
name: id.to_string(),
description: None,
created_at: chrono::Utc::now(),
data_dir: nested.join(id),
};
writer
.create_palace(&nested, palace)
.expect("persist palace under palaces/ subdir");
}
}
let registry_dir = resolve_palace_registry_dir(root);
assert_eq!(registry_dir, nested, "must resolve into palaces/ subdir");
let state = AppState::new(registry_dir);
let count = state
.load_palaces_from_disk()
.await
.expect("load_palaces_from_disk");
assert_eq!(count, 2, "both nested palaces should be loaded");
assert_eq!(state.registry.len(), 2);
let ids: Vec<String> = state.registry.list().into_iter().map(|p| p.0).collect();
assert!(ids.contains(&"cto".to_string()));
assert!(ids.contains(&"engineering".to_string()));
}
#[tokio::test]
async fn load_palaces_from_disk_empty_root_returns_zero() {
let (state, _tmp) = test_state();
let count = state
.load_palaces_from_disk()
.await
.expect("load_palaces_from_disk on empty root");
assert_eq!(count, 0);
assert!(state.registry.is_empty());
}
#[tokio::test]
async fn palace_name_cache_populated_after_hydration() {
use trusty_common::memory_core::{Palace, PalaceId, PalaceRegistry};
let tmp = tempfile::tempdir().expect("tempdir");
let root = tmp.path().to_path_buf();
{
let writer = PalaceRegistry::new();
for (id, name) in [("alpha", "Alpha Project"), ("beta", "Beta Project")] {
let palace = Palace {
id: PalaceId::new(id),
name: name.to_string(),
description: None,
created_at: chrono::Utc::now(),
data_dir: root.join(id),
};
writer.create_palace(&root, palace).expect("persist palace");
}
}
let state = AppState::new(root);
assert!(
state.palace_names.is_empty(),
"fresh AppState must start with an empty name cache"
);
state
.load_palaces_from_disk()
.await
.expect("load_palaces_from_disk");
assert_eq!(state.palace_names.len(), 2, "cache must hold both palaces");
assert_eq!(
state.palace_names.get("alpha").map(|e| e.value().clone()),
Some("Alpha Project".to_string()),
);
assert_eq!(
state.palace_names.get("beta").map(|e| e.value().clone()),
Some("Beta Project".to_string()),
);
}
#[tokio::test]
async fn palace_name_cache_updates_on_create() {
use serde_json::json;
let (state, _tmp) = test_state();
let _ = tools::dispatch_tool(&state, "palace_create", json!({"name": "gamma"}))
.await
.expect("palace_create");
assert_eq!(
state.palace_names.get("gamma").map(|e| e.value().clone()),
Some("gamma".to_string()),
"palace_create must populate the in-memory name cache so writes \
can resolve the friendly name without a disk walk"
);
}
#[tokio::test]
async fn initialize_without_default_palace_omits_field() {
let (state, _tmp) = test_state();
let init = handle_message(
&state,
json!({"jsonrpc": "2.0", "id": 1, "method": "initialize"}),
)
.await;
assert!(init["result"]["serverInfo"]["default_palace"].is_null());
}
#[tokio::test]
async fn http_addr_path_uses_resolve_data_dir() {
let _guard = crate::commands::env_test_lock().lock().await;
let tmp = tempfile::tempdir().unwrap();
unsafe {
std::env::set_var(trusty_common::DATA_DIR_OVERRIDE_ENV, tmp.path());
}
let result = http_addr_path();
unsafe {
std::env::remove_var(trusty_common::DATA_DIR_OVERRIDE_ENV);
}
let p = result.expect("http_addr_path must return Some when data dir is resolvable");
assert!(
p.ends_with("trusty-memory/http_addr"),
"unexpected http_addr path: {}",
p.display()
);
}
#[cfg(feature = "axum-server")]
#[test]
fn http_addr_file_round_trip_via_helpers() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("http_addr");
let addr: SocketAddr = "127.0.0.1:7073".parse().unwrap();
write_http_addr_file(&path, &addr).unwrap();
let raw = std::fs::read_to_string(&path).unwrap();
assert_eq!(raw.trim(), "127.0.0.1:7073");
assert!(raw.ends_with('\n'));
}
/// `bind_dynamic_port` must hand back a listener whose local address is
/// `127.0.0.1` with a non-zero port.
#[tokio::test]
async fn bind_dynamic_port_returns_listener() {
    // Keep the listener alive for the whole test, as the address is read
    // from it.
    let socket = bind_dynamic_port().await.expect("bind_dynamic_port");
    let local = socket.local_addr().expect("local_addr");

    let (ip_text, port) = (local.ip().to_string(), local.port());
    assert_eq!(ip_text, "127.0.0.1");
    assert!(port != 0, "port must be non-zero after bind");
}
/// The `initialize` reply must not list a `prompts` capability, and both
/// `prompts/list` and `prompts/get` must answer with JSON-RPC error code
/// `-32601`.
#[tokio::test]
async fn initialize_does_not_advertise_prompts_capability() {
    let (state, _tmp) = test_state();

    let init_request = json!({"jsonrpc": "2.0", "id": 1, "method": "initialize"});
    let init = handle_message(&state, init_request).await;
    let prompts_capability = &init["result"]["capabilities"]["prompts"];
    assert!(
        prompts_capability.is_null(),
        "initialize must NOT advertise the prompts capability; got {init}"
    );

    // Each prompts method is probed with its own request.
    let prompt_methods = ["prompts/list", "prompts/get"];
    for method in prompt_methods {
        let request = json!({"jsonrpc": "2.0", "id": 2, "method": method});
        let resp = handle_message(&state, request).await;
        assert_eq!(
            resp["error"]["code"], -32601,
            "{method} should return method-not-found; got {resp}"
        );
    }
}
/// A state fresh out of `test_state()` must not have `bound_addr` populated.
#[tokio::test]
async fn app_state_starts_with_empty_bound_addr() {
    let (state, _tmp) = test_state();
    let bound = state.bound_addr.get();
    assert!(bound.is_none());
}
/// For each `DaemonEvent` variant built below, the `"type"` field of its
/// serialised JSON must equal what `type_str()` returns for the same value.
#[test]
fn daemon_event_type_str_matches_sse_tag() {
    let palace_created = DaemonEvent::PalaceCreated {
        id: "p".into(),
        name: "p".into(),
        source: ActivitySource::Http,
    };
    let drawer_added = DaemonEvent::DrawerAdded {
        palace_id: "p".into(),
        palace_name: "p".into(),
        drawer_count: 1,
        timestamp: chrono::Utc::now(),
        content_preview: String::new(),
        source: ActivitySource::Mcp,
    };
    let drawer_deleted = DaemonEvent::DrawerDeleted {
        palace_id: "p".into(),
        drawer_count: 0,
        source: ActivitySource::Http,
    };
    let dream_completed = DaemonEvent::DreamCompleted {
        palace_id: None,
        merged: 0,
        pruned: 0,
        compacted: 0,
        closets_updated: 0,
        duration_ms: 0,
        source: ActivitySource::Http,
    };
    let status_changed = DaemonEvent::StatusChanged {
        total_drawers: 0,
        total_vectors: 0,
        total_kg_triples: 0,
    };
    let hook_fired = DaemonEvent::HookFired {
        palace_id: Some("p".into()),
        palace_name: Some("p".into()),
        hook_type: HookType::UserPromptSubmit,
        injection_kind: InjectionKind::PromptContext,
        injection_length: 12,
        trigger_prompt_excerpt: "hello".into(),
        timestamp: chrono::Utc::now(),
        duration_ms: 5,
        source: ActivitySource::Hook,
    };

    let events = [
        palace_created,
        drawer_added,
        drawer_deleted,
        dream_completed,
        status_changed,
        hook_fired,
    ];
    for event in events.iter() {
        let encoded = serde_json::to_value(event).unwrap();
        let tag = encoded["type"].as_str();
        assert_eq!(tag, Some(event.type_str()));
    }
}
/// Each `HookType` variant must serialise to the expected quoted JSON string,
/// deserialise back to itself, and report the same text (without quotes)
/// from `as_str()`.
#[test]
fn hook_type_serde_round_trips() {
    for (ht, expected) in [
        (HookType::UserPromptSubmit, "\"UserPromptSubmit\""),
        (HookType::SessionStart, "\"SessionStart\""),
    ] {
        let wire = serde_json::to_string(&ht).unwrap();
        assert_eq!(wire, expected, "{ht:?} should serialise to {expected}");

        let decoded = serde_json::from_str::<HookType>(&wire).unwrap();
        assert_eq!(decoded, ht);

        // `expected` is JSON text, so strip its quotes before comparing with
        // the bare tag.
        let bare_tag = expected.trim_matches('"');
        assert_eq!(ht.as_str(), bare_tag);
    }
}
/// Each `InjectionKind` variant must serialise to the expected quoted JSON
/// string, deserialise back to itself, and report the same text (without
/// quotes) from `as_str()`.
#[test]
fn injection_kind_serde_round_trips() {
    for (ik, expected) in [
        (InjectionKind::PromptContext, "\"prompt-context\""),
        (InjectionKind::InboxCheck, "\"inbox-check\""),
    ] {
        let wire = serde_json::to_string(&ik).unwrap();
        assert_eq!(wire, expected);

        let decoded = serde_json::from_str::<InjectionKind>(&wire).unwrap();
        assert_eq!(decoded, ik);

        // `expected` is JSON text, so strip its quotes before comparing with
        // the bare tag.
        let bare_tag = expected.trim_matches('"');
        assert_eq!(ik.as_str(), bare_tag);
    }
}
/// `hook_prompt_excerpt` must cut over-long prompts down to exactly
/// `HOOK_PROMPT_EXCERPT_CHARS` characters (ellipsis included), leave prompts
/// at or under the limit untouched, and count the limit in characters rather
/// than bytes.
#[test]
fn hook_excerpt_truncates_long_prompts() {
    let long = "x".repeat(200);
    let excerpt = hook_prompt_excerpt(&long);
    // Exact length rather than an upper bound, so over-truncation is caught.
    assert_eq!(excerpt.chars().count(), HOOK_PROMPT_EXCERPT_CHARS);
    assert!(excerpt.ends_with('…'));
    // Everything before the ellipsis must be an unmodified prefix of the input.
    assert!(long.starts_with(excerpt.trim_end_matches('…')));

    // Boundary: a prompt exactly at the limit is returned as-is.
    let exact = "x".repeat(HOOK_PROMPT_EXCERPT_CHARS);
    assert_eq!(hook_prompt_excerpt(&exact), exact);

    // Boundary: one character over the limit is already truncated.
    let over = "x".repeat(HOOK_PROMPT_EXCERPT_CHARS + 1);
    let clipped = hook_prompt_excerpt(&over);
    assert_eq!(clipped.chars().count(), HOOK_PROMPT_EXCERPT_CHARS);
    assert!(clipped.ends_with('…'));

    // Multi-byte input: the limit applies to chars, so the excerpt is still
    // exactly the limit long even though each char takes two bytes.
    let wide = "\u{e9}".repeat(200);
    let wide_excerpt = hook_prompt_excerpt(&wide);
    assert_eq!(wide_excerpt.chars().count(), HOOK_PROMPT_EXCERPT_CHARS);
    assert!(wide_excerpt.ends_with('…'));

    // Empty input stays empty (no stray ellipsis).
    assert_eq!(hook_prompt_excerpt(""), "");
}
/// Runs of newlines and tabs between words must come out of
/// `hook_prompt_excerpt` as single spaces.
#[test]
fn hook_excerpt_collapses_whitespace() {
    // Two newlines, then two tabs, separate the three words.
    let collapsed = hook_prompt_excerpt("hello\n\nworld\t\tfoo");
    assert_eq!(collapsed, "hello world foo");
}
/// Checks the `palace_id()` and `source()` accessors on three events:
/// `DrawerAdded` and `DreamCompleted` must expose both values, while
/// `StatusChanged` must return `None` for both.
#[test]
fn daemon_event_palace_id_and_source_extraction() {
    // Variant carrying a plain palace id and a source.
    let added = DaemonEvent::DrawerAdded {
        palace_id: "alpha".into(),
        palace_name: "alpha".into(),
        drawer_count: 1,
        timestamp: chrono::Utc::now(),
        content_preview: String::new(),
        source: ActivitySource::Mcp,
    };
    assert_eq!(added.palace_id(), Some("alpha"));
    assert_eq!(added.source(), Some(ActivitySource::Mcp));

    // Variant with neither field.
    let status_changed = DaemonEvent::StatusChanged {
        total_drawers: 1,
        total_vectors: 2,
        total_kg_triples: 3,
    };
    assert_eq!(status_changed.palace_id(), None);
    assert_eq!(status_changed.source(), None);

    // Variant whose palace id is optional and set here.
    let dreamed = DaemonEvent::DreamCompleted {
        palace_id: Some("p1".into()),
        merged: 0,
        pruned: 0,
        compacted: 0,
        closets_updated: 0,
        duration_ms: 10,
        source: ActivitySource::Http,
    };
    assert_eq!(dreamed.palace_id(), Some("p1"));
    assert_eq!(dreamed.source(), Some(ActivitySource::Http));
}
/// Emits `PalaceCreated`, `StatusChanged`, and `DrawerAdded` in that order,
/// flushes pending activity writes, and expects the activity log to hold
/// exactly two entries.
#[tokio::test]
async fn emit_persists_mutations_but_skips_status_changed() {
    let (state, _tmp) = test_state();

    let created = DaemonEvent::PalaceCreated {
        id: "p".into(),
        name: "p".into(),
        source: ActivitySource::Http,
    };
    state.emit(created);

    // Emitted between the two mutations; expected not to be counted.
    let status = DaemonEvent::StatusChanged {
        total_drawers: 1,
        total_vectors: 0,
        total_kg_triples: 0,
    };
    state.emit(status);

    let added = DaemonEvent::DrawerAdded {
        palace_id: "p".into(),
        palace_name: "p".into(),
        drawer_count: 1,
        timestamp: chrono::Utc::now(),
        content_preview: "x".into(),
        source: ActivitySource::Mcp,
    };
    state.emit(added);

    // Flush pending writes before reading the count back.
    state.flush_activity_writes().await;
    let persisted = state.activity_log.count().unwrap();
    assert_eq!(persisted, 2, "only PalaceCreated + DrawerAdded must persist");
}
/// With `TRUSTY_BM25_DAEMON` unset, `with_bm25_client_from_env()` must leave
/// both `bm25_client` and `bm25_supervisor` as `None`.
///
/// The variable's previous value is put back before any assertion runs, so a
/// failing assertion cannot leave the process environment altered.
#[tokio::test]
async fn bm25_client_disabled_by_default() {
    // Process environment is global; hold the shared env lock for the whole
    // test body.
    let _guard = crate::commands::env_test_lock().lock().await;

    // `var_os` keeps a non-UTF-8 previous value, which `var(..).ok()` would
    // silently drop and therefore never restore.
    let prev = std::env::var_os("TRUSTY_BM25_DAEMON");

    // SAFETY: mutating the environment requires that nothing else touches it
    // concurrently. `env_test_lock` is held above. NOTE(review): that only
    // excludes other code that takes the same lock.
    unsafe {
        std::env::remove_var("TRUSTY_BM25_DAEMON");
    }
    let (state, _tmp) = test_state();
    let state = state.with_bm25_client_from_env();

    // Restore first, assert afterwards.
    if let Some(v) = prev {
        // SAFETY: same reasoning as the `remove_var` above; the lock is
        // still held.
        unsafe {
            std::env::set_var("TRUSTY_BM25_DAEMON", v);
        }
    }

    assert!(
        state.bm25_client.is_none(),
        "bm25_client must be None when TRUSTY_BM25_DAEMON is unset"
    );
    assert!(
        state.bm25_supervisor.is_none(),
        "bm25_supervisor must be None when TRUSTY_BM25_DAEMON is unset"
    );
}
/// With `TRUSTY_BM25_DAEMON=1`, `with_bm25_client_from_env()` must populate
/// both `bm25_client` and `bm25_supervisor`.
///
/// The variable's previous state is put back before any assertion runs, so a
/// failing assertion cannot leave `TRUSTY_BM25_DAEMON=1` set for whatever
/// runs next in the same process.
#[tokio::test]
async fn bm25_client_enabled_when_env_set() {
    // Process environment is global; hold the shared env lock for the whole
    // test body.
    let _guard = crate::commands::env_test_lock().lock().await;

    // `var_os` keeps a non-UTF-8 previous value, which `var(..).ok()` would
    // silently drop and then treat as "was unset".
    let prev = std::env::var_os("TRUSTY_BM25_DAEMON");

    // SAFETY: mutating the environment requires that nothing else touches it
    // concurrently. `env_test_lock` is held above. NOTE(review): that only
    // excludes other code that takes the same lock.
    unsafe {
        std::env::set_var("TRUSTY_BM25_DAEMON", "1");
    }
    let (state, _tmp) = test_state();
    let state = state.with_bm25_client_from_env();

    // Restore first, assert afterwards.
    // SAFETY: same reasoning as the `set_var` above; the lock is still held.
    match prev {
        Some(v) => unsafe { std::env::set_var("TRUSTY_BM25_DAEMON", v) },
        None => unsafe { std::env::remove_var("TRUSTY_BM25_DAEMON") },
    }

    assert!(
        state.bm25_client.is_some(),
        "bm25_client must be Some when TRUSTY_BM25_DAEMON=1"
    );
    assert!(
        state.bm25_supervisor.is_some(),
        "bm25_supervisor must be Some when TRUSTY_BM25_DAEMON=1"
    );
}
}