use crate::client::{OriginClient, OriginError};
use crate::types::*;
use rmcp::{
handler::server::router::tool::ToolRouter,
handler::server::wrapper::Parameters,
model::{CallToolResult, Content, Implementation, InitializeResult, ServerCapabilities},
service::{NotificationContext, RoleServer},
tool, tool_handler, tool_router, ErrorData as McpError, ServerHandler,
};
use serde::{Deserialize, Deserializer};
/// Deserialize an `Option<usize>` that also accepts stringified numbers (e.g. `"10"`).
/// MCP clients like Claude Desktop sometimes send numeric params as strings.
fn deserialize_optional_usize_lenient<'de, D>(deserializer: D) -> Result<Option<usize>, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum StringOrNumber {
Number(usize),
Str(String),
}
match Option::<StringOrNumber>::deserialize(deserializer)? {
None => Ok(None),
Some(StringOrNumber::Number(n)) => Ok(Some(n)),
Some(StringOrNumber::Str(s)) => s
.parse::<usize>()
.map(Some)
.map_err(serde::de::Error::custom),
}
}
/// Controls which operations are allowed based on transport.
///
/// The enum is fieldless, so it is `Copy` and `Eq`. It can be passed and
/// compared by value without cloning.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TransportMode {
    /// Local stdio — full access, all tools.
    Stdio,
    /// Remote HTTP — delete operations are refused, and a configured user_id
    /// takes precedence over a per-call one.
    Http,
}
/// MCP server that exposes Origin's memory tools and forwards each call to the
/// Origin daemon through an [`OriginClient`].
///
/// The client-name slot lives behind an `Arc`, so every clone of the server
/// shares the name captured during the MCP initialize handshake.
#[derive(Clone)]
pub struct OriginMcpServer {
    /// Transport this instance serves; gates which tools may run.
    transport: TransportMode,
    /// HTTP client for the Origin daemon.
    client: OriginClient,
    /// Fallback `source_agent` for writes (the configured `--agent-name`).
    agent_name: String,
    /// Client name from MCP initialize handshake (e.g., "Claude Code", "Claude Desktop").
    /// Shared across clones and `None` until the handshake reports a name.
    client_name: std::sync::Arc<std::sync::Mutex<Option<String>>>,
    /// Configured user id; used locally only and never sent to the daemon.
    user_id: Option<String>,
    // Only read by macro-generated tool dispatch (presumably `#[tool_handler]`),
    // which the dead-code lint may not count as a use; hence the allow.
    #[allow(dead_code)]
    tool_router: ToolRouter<Self>,
}
// ===== Parameter Structs =====
// --- Primary tool params ---
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct RememberParams {
#[schemars(
description = "The memory content. Write as a complete statement with context and reasoning, not shorthand. One idea per memory."
)]
pub content: String,
#[schemars(
description = "\"profile\" (about the user) or \"knowledge\" (about the world) — or precise: \"identity\", \"preference\", \"goal\", \"fact\", \"decision\" — auto-classified if omitted"
)]
pub memory_type: Option<String>,
#[schemars(
description = "Topic scope (e.g. 'rust', 'work', 'health', 'origin'). Auto-detected if omitted."
)]
pub domain: Option<String>,
#[schemars(
description = "Person, project, or tool name to anchor to (e.g. 'Alice', 'Origin', 'PostgreSQL'). Helps build the knowledge graph."
)]
pub entity: Option<String>,
#[schemars(
description = "0.0-1.0. Leave unset for auto-calculation based on type and trust level. Set low (0.3-0.5) for uncertain info, high (0.8-1.0) for user-stated facts."
)]
pub confidence: Option<f32>,
#[schemars(
description = "source_id of a memory this replaces. Use when correcting or updating an existing memory — get the ID from recall first."
)]
pub supersedes: Option<String>,
#[schemars(
description = "Pre-extracted structured fields as a JSON object. Auto-extracted by backend; only supply if you have high-quality structured data already."
)]
pub structured_fields: Option<serde_json::Map<String, serde_json::Value>>,
#[schemars(
description = "A question this memory answers, for search matching. Auto-generated by backend; only supply to override."
)]
pub retrieval_cue: Option<String>,
}
// Arguments for the `recall` tool (targeted memory search).
// Plain `//` comments are used on purpose: `///` doc comments are read by the
// `schemars::JsonSchema` derive and would change the published tool schema.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct RecallParams {
    #[schemars(
        description = "Natural language search. Be specific: 'Alice database preference' finds more than 'database stuff'."
    )]
    pub query: String,
    #[schemars(
        description = "Max results, default 10. Use 3-5 for quick lookups, 10-20 for exploration."
    )]
    // Accepts numbers or stringified numbers; `recall_impl` substitutes 10 when unset.
    #[serde(default, deserialize_with = "deserialize_optional_usize_lenient")]
    pub limit: Option<usize>,
    #[schemars(
        description = "Filter by type. Two-level filter: \"profile\" (user-facing) or \"knowledge\" (world-facing), or precise: identity, preference, goal, fact, decision."
    )]
    pub memory_type: Option<String>,
    #[schemars(description = "Filter by topic scope.")]
    pub domain: Option<String>,
}
// Arguments for the `context` tool (session orientation). Every field is optional.
// Plain `//` comments are used on purpose: `///` doc comments are read by the
// `schemars::JsonSchema` derive and would change the published tool schema.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct ContextParams {
    #[schemars(
        description = "Topic or conversation summary to focus context retrieval. Omit at session start for general orientation; provide when shifting topics."
    )]
    // Forwarded by `context_impl` as the request's `conversation_id`.
    pub topic: Option<String>,
    #[schemars(
        description = "Max context chunks, default 20. Increase for complex topics, decrease for quick check-ins."
    )]
    // Accepts numbers or stringified numbers; `context_impl` substitutes 20 when unset.
    #[serde(default, deserialize_with = "deserialize_optional_usize_lenient")]
    pub limit: Option<usize>,
    #[schemars(
        description = "Scope context to a domain/space (e.g. 'work', 'personal'). Auto-detected from conversation if omitted."
    )]
    pub domain: Option<String>,
}
// Arguments for the `forget` tool (delete one memory by id).
// Plain `//` comments are used on purpose: `///` doc comments are read by the
// `schemars::JsonSchema` derive and would change the published tool schema.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct ForgetParams {
    #[schemars(
        description = "The source_id of the memory to delete. Get this from recall results first."
    )]
    // Required: deserialization fails when the field is missing.
    pub memory_id: String,
}
// ===== Internal Implementations =====
/// Build the terse reply text for a successful `remember` call.
///
/// The first line is always `Stored <source_id>`. Backend warnings, when
/// present, follow under a `Warnings:` header as one ` - ` bullet per line.
fn format_remember_success(resp: &StoreMemoryResponse) -> String {
    let headline = format!("Stored {}", resp.source_id);
    if resp.warnings.is_empty() {
        return headline;
    }
    let bullets: String = resp
        .warnings
        .iter()
        .map(|w| format!("\n - {}", w))
        .collect();
    format!("{}\nWarnings:{}", headline, bullets)
}
/// Convert a backend error into a tool-level error result (isError: true)
/// with an actionable message. This keeps the MCP transport healthy
/// (no protocol-level McpError) while telling the caller what happened.
fn tool_error(e: OriginError, verb: &str) -> CallToolResult {
let msg = match &e {
OriginError::Unreachable(_) => format!(
"Origin daemon is not reachable (retried 3x over ~6s). \
The {verb} was NOT completed. Try again after the daemon is running."
),
OriginError::Api { status, body } => format!(
"Origin daemon returned HTTP {status}: {body}. The {verb} may not have completed."
),
OriginError::Deserialize(detail) => format!(
"Failed to parse daemon response: {detail}. \
This may indicate a version mismatch between origin-mcp and the daemon."
),
};
CallToolResult::error(vec![Content::text(msg)])
}
impl OriginMcpServer {
/// Resolve the source_agent for a write operation.
/// Priority: explicit param > MCP client name (from initialize) > configured agent_name.
fn resolve_source_agent(&self, param_agent: Option<String>) -> Option<String> {
// 1. Explicit param from tool call
if let Some(ref agent) = param_agent {
if !agent.is_empty() {
return param_agent;
}
}
// 2. Client name captured from MCP initialize handshake
if let Ok(guard) = self.client_name.lock() {
if let Some(ref name) = *guard {
return Some(name.clone());
}
}
// 3. Configured --agent-name flag
Some(self.agent_name.clone())
}
/// Resolve a local user_id for logging or future use.
/// This value is intentionally not sent on the wire (D4).
fn resolve_user_id(&self, param_user_id: Option<String>) -> Option<String> {
if self.transport == TransportMode::Http {
self.user_id.clone().or(param_user_id)
} else {
param_user_id
}
}
pub async fn remember_impl(&self, params: RememberParams) -> Result<CallToolResult, McpError> {
let source_agent = self.resolve_source_agent(None);
if let Some(uid) = self.resolve_user_id(None) {
tracing::debug!(user_id = %uid, "remember invoked");
}
let req = StoreMemoryRequest {
content: params.content,
memory_type: params.memory_type,
domain: params.domain,
source_agent,
title: None,
confidence: params.confidence,
supersedes: params.supersedes,
entity: params.entity,
entity_id: None,
structured_fields: params.structured_fields.map(serde_json::Value::Object),
retrieval_cue: params.retrieval_cue,
};
let resp: StoreMemoryResponse = match self.client.post("/api/memory/store", &req).await {
Ok(r) => r,
Err(e) => return Ok(tool_error(e, "memory store")),
};
Ok(CallToolResult::success(vec![Content::text(
format_remember_success(&resp),
)]))
}
pub async fn recall_impl(&self, params: RecallParams) -> Result<CallToolResult, McpError> {
let req = SearchMemoryRequest {
query: params.query,
limit: params.limit.unwrap_or(10),
memory_type: params.memory_type,
domain: params.domain,
source_agent: self.resolve_source_agent(None),
};
let resp: SearchMemoryResponse = match self.client.post("/api/memory/search", &req).await {
Ok(r) => r,
Err(e) => return Ok(tool_error(e, "search")),
};
let json = serde_json::to_string_pretty(&resp.results)
.map_err(|e| McpError::internal_error(e.to_string(), None))?;
Ok(CallToolResult::success(vec![Content::text(format!(
"{} results ({:.1}ms)\n{}",
resp.results.len(),
resp.took_ms,
json
))]))
}
pub async fn context_impl(&self, params: ContextParams) -> Result<CallToolResult, McpError> {
let req = ChatContextRequest {
query: None,
conversation_id: params.topic,
max_chunks: params.limit.unwrap_or(20),
relevance_threshold: None,
include_goals: true,
domain: params.domain,
};
// Extract only the `context` string field from the response.
//
// The full ChatContextResponse embeds Vec<SearchResult> which may
// contain fields added after the published origin-types version.
// Since context_impl only uses `resp.context`, we parse the raw
// JSON and pull that field directly — this makes the tool forward-
// compatible with any new fields the daemon might add.
let raw: serde_json::Value = match self.client.post("/api/chat-context", &req).await {
Ok(r) => r,
Err(e) => return Ok(tool_error(e, "context load")),
};
let context = raw
.get("context")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
if context.is_empty() {
Ok(CallToolResult::success(vec![Content::text(
"No relevant context found".to_string(),
)]))
} else {
Ok(CallToolResult::success(vec![Content::text(context)]))
}
}
pub async fn forget_impl(&self, memory_id: &str) -> Result<CallToolResult, McpError> {
if self.transport == TransportMode::Http {
return Ok(CallToolResult::error(vec![Content::text(
"Delete operations are not available over remote connections. \
Use the Origin desktop app to delete memories."
.to_string(),
)]));
}
let resp: DeleteResponse = match self
.client
.delete(&format!("/api/memory/delete/{}", memory_id))
.await
{
Ok(r) => r,
Err(e) => return Ok(tool_error(e, "delete")),
};
Ok(CallToolResult::success(vec![Content::text(
if resp.deleted {
"Memory deleted"
} else {
"Memory not found"
}
.to_string(),
)]))
}
}
// ===== Tool Registrations =====
#[tool_router]
impl OriginMcpServer {
    /// Build a server for the given daemon client and transport.
    ///
    /// `agent_name` is the last-resort `source_agent` for writes. `user_id` is
    /// kept locally and preferred over per-call values on HTTP transport. The
    /// MCP client name starts unset until `on_initialized` records it.
    pub fn new(
        client: OriginClient,
        transport: TransportMode,
        agent_name: String,
        user_id: Option<String>,
    ) -> Self {
        Self {
            tool_router: Self::tool_router(),
            client,
            transport,
            agent_name,
            client_name: std::sync::Arc::new(std::sync::Mutex::new(None)),
            user_id,
        }
    }

    // --- Primary Tools ---
    // Each tool is a thin wrapper that unpacks `Parameters` and delegates to
    // the matching `*_impl` method. Descriptions below are sent to clients
    // verbatim.

    #[tool(
        description = "Store a memory. Call PROACTIVELY when you learn something durable about the user — preferences, decisions, corrections, or facts about people/projects/tools they care about. Don't wait for the user to say 'remember this' — that's a floor, not a trigger.\n\nWrite content as a complete, self-contained statement — someone reading it months later with no conversation context should understand it. Include the WHY, not just the WHAT. Name people, projects, and tools explicitly.\n\nThe backend auto-classifies type, extracts structured fields, detects entities, and links to the knowledge graph. You don't need to set memory_type or structured_fields unless you're confident — omitting them gets better results than guessing wrong.\n\nDo NOT store: system prompts, boot logs, heartbeat/health checks, transient task state ('currently working on...'), tool output/responses, architecture dumps, single-word acknowledgments, or content you have already stored. Focus on durable facts, preferences, decisions, goals, and identity information. Each call is one atomic idea — \"prefers TDD\" and \"uses pytest\" are two calls, not one.",
        annotations(
            title = "Remember",
            read_only_hint = false,
            destructive_hint = false,
            idempotent_hint = false,
            open_world_hint = false
        )
    )]
    async fn remember(
        &self,
        Parameters(params): Parameters<RememberParams>,
    ) -> Result<CallToolResult, McpError> {
        self.remember_impl(params).await
    }

    #[tool(
        description = "Search memories by query. Use when the user asks 'do you remember', 'what do you know about', 'look up', or when you need a specific fact before acting.\n\nWrite queries as natural language — the search engine handles semantic matching. For precision, use filters (memory_type, domain) to narrow results. If you get too many results, add filters rather than making the query longer.\n\nThis is for targeted lookups. For broad session orientation, use context instead.",
        annotations(title = "Recall", read_only_hint = true, open_world_hint = false)
    )]
    async fn recall(
        &self,
        Parameters(params): Parameters<RecallParams>,
    ) -> Result<CallToolResult, McpError> {
        self.recall_impl(params).await
    }

    #[tool(
        description = "Load session context — identity, preferences, goals, and topic-relevant memories. Call this FIRST at the start of every session before doing anything else. Also call on major topic shifts or when the user says 'catch me up' or 'what's the background on'.\n\nThis returns a curated blend of who the user is and what's relevant. For specific factual lookups, use recall instead. Use the result to model how the user thinks, not just to look things up — their preferences and corrections tell you how they want to be helped.",
        annotations(title = "Context", read_only_hint = true, open_world_hint = false)
    )]
    async fn context(
        &self,
        Parameters(params): Parameters<ContextParams>,
    ) -> Result<CallToolResult, McpError> {
        self.context_impl(params).await
    }

    #[tool(
        description = "Delete a memory by ID. Use when the user says 'forget this', 'delete that', 'that's wrong and should be removed'. Requires the source_id — get it from recall first.\n\nThis is destructive and cannot be undone. For corrections, prefer storing a new memory with the supersedes param pointing to the old one — this preserves history.",
        annotations(
            title = "Forget",
            read_only_hint = false,
            destructive_hint = true,
            idempotent_hint = true,
            open_world_hint = false
        )
    )]
    async fn forget(
        &self,
        Parameters(params): Parameters<ForgetParams>,
    ) -> Result<CallToolResult, McpError> {
        // Borrow the id; the original text had this `&params` mis-encoded as `¶ms`.
        self.forget_impl(&params.memory_id).await
    }
}
// ===== ServerHandler =====
#[tool_handler]
impl ServerHandler for OriginMcpServer {
    /// Record the connecting client's self-reported name once the MCP
    /// handshake completes.
    ///
    /// The stored name (e.g. "Claude Code") is what `resolve_source_agent`
    /// falls back to for writes. The stored value is left untouched when:
    /// - no peer info is available;
    /// - the reported name is empty;
    /// - the lock is poisoned.
    async fn on_initialized(&self, context: NotificationContext<RoleServer>) {
        let Some(peer) = context.peer.peer_info() else {
            return;
        };
        let reported = &peer.client_info.name;
        if reported.is_empty() {
            return;
        }
        let Ok(mut slot) = self.client_name.lock() else {
            return;
        };
        tracing::info!("MCP client identified: {}", reported);
        *slot = Some(reported.clone());
    }

    /// Advertise this server: tool capability, name/version and the usage
    /// instructions given to connecting agents.
    fn get_info(&self) -> InitializeResult {
        const INSTRUCTIONS: &str =
            "Origin is your personal memory layer — a local knowledge base that persists across sessions and tools.\n\
            Think of yourself as a curator, not a logger. Store insights, not conversation artifacts.\n\n\
            Origin is self-evolving — each memory you store contributes to a knowledge structure that grows over time. \
            It's also shared across all the user's tools: what you write, other agents (Claude Desktop, Claude Code, \
            ChatGPT, Cursor, etc.) will read later. Write for any future reader, not just this conversation.\n\n\
            FIRST THING EVERY SESSION: Call context to load the user's identity, preferences, goals, and\n\
            topic-relevant memories. This is how you know who you're talking to. Use the result to model how the \
            user thinks — their preferences, corrections, and past decisions tell you how they want to be helped, \
            not just what they already know.\n\n\
            STORE PROACTIVELY — don't wait for the user to ask.\n\
            - The user states a preference (\"I use X because...\", \"I prefer Y over Z\")\n\
            - The user makes a decision (\"going with approach A\", \"switching to B\")\n\
            - The user corrects you or prior info (\"actually, it's C, not D\") — store the correction so it sticks\n\
            - The user shares a durable fact about themselves, their work, or people/projects/tools they care about — \
            anchor it to the entity\n\n\
            If the user asks explicitly (\"remember this\", \"save this\", \"don't forget\"), that's a floor — you \
            should have already stored it.\n\n\
            WHEN NOT TO STORE:\n\
            - Conversation filler (\"ok\", \"thanks\", \"let's move on\")\n\
            - Things the user can trivially re-derive (file paths, recent git history)\n\
            - Anything already stored — recall first if unsure\n\
            - Tool output or command results (file contents, git history, build logs) — these are derivable\n\
            - General world facts or documentation that aren't personal to this user (e.g., \"Rust has a borrow \
            checker\", \"PostgreSQL supports JSONB\") — those are not memory material.\n\
            - Your own inferences about the user that they didn't express. Store what they said; infer from that \
            when responding.\n\n\
            CONTENT QUALITY — this is where you make the biggest difference:\n\
            - Specific beats vague: \"prefers Rust for CLI tools because of compile-time safety\" > \"likes Rust\"\n\
            - Include the WHY: the backend can classify \"dark mode\" as a preference, but only you know\n\
            \"switched to dark mode because of migraines from bright screens\"\n\
            - Name the entities: mention people, projects, tools by name — this powers the knowledge graph\n\
            - Atomic: one idea per memory — \"prefers TDD\" and \"uses pytest\" should be two memories, not one\n\
            - Declarative, not narrative: \"User prefers X because Y\" — not \"User said today they prefer X\". \
            Memories outlive the conversation that produced them.\n\n\
            MEMORY TYPES — omit and trust the backend.\n\n\
            By default, do NOT set memory_type. The backend auto-classifies into identity / preference / goal / \
            fact / decision with more context than you have. Agents that over-specify types tend to pick wrong.\n\n\
            Opt-in specification:\n\
            - \"profile\" — you're sure it's about the user (identity/preference/goal)\n\
            - \"knowledge\" — you're sure it's about the world (fact/decision)\n\
            - Precise type — only if you're confident and the distinction matters.\n\n\
            EXCEPTION — decisions carry structured fields (alternatives considered, reversibility, domain) \
            that power the Decision Log view. Set memory_type=\"decision\" explicitly ONLY when the user \
            articulated alternatives weighed AND the reasoning for the choice. A bare \"I'm switching to Cursor\" \
            is just a preference change — omit the type. \"Switching to Cursor over VSCode because of better \
            Claude integration, and we can always go back\" — that's a decision.\n\n\
            RECALL vs CONTEXT:\n\
            - context: broad orientation, session start, topic shifts, \"catch me up\"\n\
            - recall: specific lookup (\"what's Alice's role?\", \"database preferences\", \"our auth decision\")\n\n\
            The backend handles classification, entity extraction, structured fields, quality scoring,\n\
            and dedup — you don't need to replicate that logic. Focus on what only you know:\n\
            the conversational context, why something matters, and what the user actually cares about.";

        let capabilities = ServerCapabilities::builder().enable_tools().build();
        let identity = Implementation::new("origin-mcp", env!("CARGO_PKG_VERSION"));
        InitializeResult::new(capabilities)
            .with_server_info(identity)
            .with_instructions(INSTRUCTIONS)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::client::OriginClient;
use crate::types::{
ChatContextRequest, ChatContextResponse, SearchMemoryRequest, SearchResult,
StoreMemoryRequest, StoreMemoryResponse,
};
fn make_server(
transport: TransportMode,
agent_name: &str,
user_id: Option<&str>,
) -> OriginMcpServer {
let client = OriginClient::new("http://127.0.0.1:19999".into());
OriginMcpServer::new(
client,
transport,
agent_name.into(),
user_id.map(String::from),
)
}
// ===== Transport resolution (existing) =====
#[test]
fn test_http_mode_prefers_param_over_agent_name() {
let server = make_server(TransportMode::Http, "claude.ai", None);
// Explicit param has highest priority
let result = server.resolve_source_agent(Some("user-provided".into()));
assert_eq!(result, Some("user-provided".into()));
}
#[test]
fn test_http_mode_sets_source_agent_when_none() {
let server = make_server(TransportMode::Http, "chatgpt", None);
let result = server.resolve_source_agent(None);
assert_eq!(result, Some("chatgpt".into()));
}
#[test]
fn test_stdio_mode_passes_through_source_agent() {
let server = make_server(TransportMode::Stdio, "ignored", None);
let result = server.resolve_source_agent(Some("user-provided".into()));
assert_eq!(result, Some("user-provided".into()));
}
#[test]
fn test_stdio_mode_falls_back_to_agent_name() {
let server = make_server(TransportMode::Stdio, "fallback", None);
// No param, no client_name → falls back to configured agent_name
let result = server.resolve_source_agent(None);
assert_eq!(result, Some("fallback".into()));
}
#[test]
fn test_http_mode_resolves_configured_user_id_for_local_use() {
let server = make_server(TransportMode::Http, "agent", Some("lucian"));
let result = server.resolve_user_id(None);
assert_eq!(result, Some("lucian".into()));
}
#[test]
fn test_transport_mode_equality() {
assert_eq!(TransportMode::Stdio, TransportMode::Stdio);
assert_eq!(TransportMode::Http, TransportMode::Http);
assert_ne!(TransportMode::Stdio, TransportMode::Http);
}
// ===== Param deserialization: RememberParams =====
#[test]
fn test_remember_params_minimal() {
let json = r#"{"content": "Lucian prefers dark mode"}"#;
let params: RememberParams = serde_json::from_str(json).unwrap();
assert_eq!(params.content, "Lucian prefers dark mode");
assert!(params.memory_type.is_none());
assert!(params.domain.is_none());
assert!(params.entity.is_none());
assert!(params.confidence.is_none());
assert!(params.supersedes.is_none());
}
#[test]
fn test_remember_params_full() {
let json = r#"{
"content": "We chose PostgreSQL over MongoDB",
"memory_type": "decision",
"domain": "origin",
"entity": "PostgreSQL",
"confidence": 0.95,
"supersedes": "mem_abc123"
}"#;
let params: RememberParams = serde_json::from_str(json).unwrap();
assert_eq!(params.content, "We chose PostgreSQL over MongoDB");
assert_eq!(params.memory_type.as_deref(), Some("decision"));
assert_eq!(params.domain.as_deref(), Some("origin"));
assert_eq!(params.entity.as_deref(), Some("PostgreSQL"));
assert_eq!(params.confidence, Some(0.95));
assert_eq!(params.supersedes.as_deref(), Some("mem_abc123"));
}
#[test]
fn test_remember_params_missing_content_fails() {
let json = r#"{"memory_type": "fact"}"#;
let result = serde_json::from_str::<RememberParams>(json);
assert!(result.is_err());
}
// ===== Param deserialization: RecallParams =====
#[test]
fn test_recall_params_minimal() {
let json = r#"{"query": "what does Alice work on?"}"#;
let params: RecallParams = serde_json::from_str(json).unwrap();
assert_eq!(params.query, "what does Alice work on?");
assert!(params.limit.is_none());
}
#[test]
fn test_recall_params_full() {
let json = r#"{
"query": "database preferences",
"limit": 5,
"memory_type": "decision",
"domain": "origin"
}"#;
let params: RecallParams = serde_json::from_str(json).unwrap();
assert_eq!(params.query, "database preferences");
assert_eq!(params.limit, Some(5));
assert_eq!(params.memory_type.as_deref(), Some("decision"));
assert_eq!(params.domain.as_deref(), Some("origin"));
}
#[test]
fn test_recall_params_limit_as_string() {
let json = r#"{"query": "test", "limit": "10"}"#;
let params: RecallParams = serde_json::from_str(json).unwrap();
assert_eq!(params.limit, Some(10));
}
#[test]
fn test_recall_params_missing_query_fails() {
let json = r#"{"limit": 5}"#;
let result = serde_json::from_str::<RecallParams>(json);
assert!(result.is_err());
}
// ===== Param deserialization: ContextParams =====
#[test]
fn test_context_params_empty() {
let json = r#"{}"#;
let params: ContextParams = serde_json::from_str(json).unwrap();
assert!(params.topic.is_none());
assert!(params.limit.is_none());
assert!(params.domain.is_none());
}
#[test]
fn test_context_params_full() {
let json = r#"{"topic": "project Origin architecture", "limit": 30, "domain": "work"}"#;
let params: ContextParams = serde_json::from_str(json).unwrap();
assert_eq!(params.topic.as_deref(), Some("project Origin architecture"));
assert_eq!(params.limit, Some(30));
assert_eq!(params.domain.as_deref(), Some("work"));
}
#[test]
fn test_context_params_limit_as_string() {
let json = r#"{"limit": "20"}"#;
let params: ContextParams = serde_json::from_str(json).unwrap();
assert_eq!(params.limit, Some(20));
}
#[test]
fn store_memory_request_serialization_excludes_user_id() {
let req = StoreMemoryRequest {
content: "test content".into(),
memory_type: None,
domain: None,
source_agent: Some("test-agent".into()),
title: None,
confidence: None,
supersedes: None,
entity: None,
entity_id: None,
structured_fields: None,
retrieval_cue: None,
};
let json = serde_json::to_value(&req).unwrap();
let obj = json.as_object().unwrap();
assert!(
!obj.contains_key("user_id"),
"user_id must not be on the wire; got: {:?}",
obj.keys().collect::<Vec<_>>()
);
}
#[test]
fn remember_success_message_is_terse() {
let resp = StoreMemoryResponse {
source_id: "mem_abc".into(),
chunks_created: 3,
memory_type: "fact".into(),
entity_id: Some("ent_xyz".into()),
quality: Some("high".into()),
warnings: vec![],
extraction_method: "llm".into(),
};
let msg = format_remember_success(&resp);
assert_eq!(msg, "Stored mem_abc");
assert!(!msg.contains("chunks"));
assert!(!msg.contains("quality"));
assert!(!msg.contains("entity"));
}
#[test]
fn remember_success_message_surfaces_warnings() {
let resp = StoreMemoryResponse {
source_id: "mem_abc".into(),
chunks_created: 1,
memory_type: "decision".into(),
entity_id: None,
quality: None,
warnings: vec!["decision memory missing required 'claim' field".into()],
extraction_method: "agent".into(),
};
let msg = format_remember_success(&resp);
assert!(msg.starts_with("Stored mem_abc"));
assert!(msg.contains("Warnings:"));
assert!(msg.contains("decision memory missing required 'claim' field"));
}
#[test]
fn search_memory_request_serialization_excludes_entity() {
let req = SearchMemoryRequest {
query: "test".into(),
limit: 10,
memory_type: None,
domain: None,
source_agent: None,
};
let json = serde_json::to_value(&req).unwrap();
let obj = json.as_object().unwrap();
assert!(
!obj.contains_key("entity"),
"entity must not be on the wire; got keys: {:?}",
obj.keys().collect::<Vec<_>>()
);
}
#[test]
fn chat_context_request_serialization_includes_domain() {
let req = ChatContextRequest {
query: None,
conversation_id: Some("topic".into()),
max_chunks: 20,
relevance_threshold: None,
include_goals: true,
domain: Some("work".into()),
};
let json = serde_json::to_value(&req).unwrap();
assert_eq!(json["domain"], serde_json::json!("work"));
assert_eq!(json["conversation_id"], serde_json::json!("topic"));
}
#[test]
fn chat_context_response_deserializes_with_profile_and_knowledge() {
let json = r#"{
"context": "user is Lucian, prefers Rust",
"profile": {
"narrative": "n",
"identity": ["rust"],
"preferences": [],
"goals": []
},
"knowledge": {
"concepts": [],
"decisions": [],
"relevant_memories": [],
"graph_context": []
},
"took_ms": 42.0,
"token_estimates": {
"tier1_identity": 10,
"tier2_project": 20,
"tier3_relevant": 30,
"total": 60
}
}"#;
let parsed: ChatContextResponse = serde_json::from_str(json).unwrap();
assert_eq!(parsed.context, "user is Lucian, prefers Rust");
assert_eq!(parsed.profile.identity, vec!["rust"]);
assert_eq!(parsed.token_estimates.total, 60);
}
#[test]
fn remember_params_structured_fields_schema_is_object() {
use schemars::schema_for;
let schema = schema_for!(RememberParams);
let json = serde_json::to_value(&schema).unwrap();
let sf_schema = json
.pointer("/properties/structured_fields")
.expect("structured_fields property in schema");
let type_val = sf_schema
.pointer("/type")
.unwrap_or(&serde_json::Value::Null);
let type_str = match type_val {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(arr) => arr
.iter()
.filter_map(|v| v.as_str())
.collect::<Vec<_>>()
.join(","),
other => panic!(
"structured_fields schema lacks type constraint; got: {:?}",
other
),
};
assert!(
type_str.contains("object"),
"expected object type, got: {}",
type_str
);
}
// ===== Param deserialization: ForgetParams =====
#[test]
fn test_forget_params() {
let json = r#"{"memory_id": "mem_abc123"}"#;
let params: ForgetParams = serde_json::from_str(json).unwrap();
assert_eq!(params.memory_id, "mem_abc123");
}
#[test]
fn test_forget_params_missing_id_fails() {
let json = r#"{}"#;
let result = serde_json::from_str::<ForgetParams>(json);
assert!(result.is_err());
}
// ===== Request serialization: StoreMemoryRequest =====
#[test]
fn test_store_request_includes_new_fields() {
let req = StoreMemoryRequest {
content: "test".into(),
memory_type: Some("decision".into()),
domain: None,
source_agent: Some("claude".into()),
title: None,
confidence: Some(0.9),
supersedes: Some("old_id".into()),
entity: Some("PostgreSQL".into()),
entity_id: None,
structured_fields: None,
retrieval_cue: None,
};
let json = serde_json::to_value(&req).unwrap();
assert_eq!(json["entity"], "PostgreSQL");
assert_eq!(json["supersedes"], "old_id");
assert!(json["confidence"].as_f64().unwrap() > 0.89);
assert_eq!(json["source_agent"], "claude");
assert!(json.get("user_id").is_none());
}
#[test]
fn test_store_request_minimal() {
let req = StoreMemoryRequest {
content: "hello".into(),
memory_type: Some("fact".into()),
domain: None,
source_agent: None,
title: None,
confidence: None,
supersedes: None,
entity: None,
entity_id: None,
structured_fields: None,
retrieval_cue: None,
};
let json = serde_json::to_value(&req).unwrap();
assert_eq!(json["content"], "hello");
assert_eq!(json["memory_type"], "fact");
assert!(json.get("user_id").is_none());
}
// ===== Response deserialization: StoreMemoryResponse =====
#[test]
fn test_store_response_with_new_fields() {
let json = r#"{
"source_id": "mem_xyz",
"chunks_created": 2,
"memory_type": "fact",
"entity_id": "ent_abc",
"quality": "high",
"warnings": ["decision memory missing claim"],
"extraction_method": "agent"
}"#;
let resp: StoreMemoryResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.source_id, "mem_xyz");
assert_eq!(resp.chunks_created, 2);
assert_eq!(resp.memory_type, "fact");
assert_eq!(resp.entity_id.as_deref(), Some("ent_abc"));
assert_eq!(resp.quality.as_deref(), Some("high"));
assert_eq!(resp.warnings, vec!["decision memory missing claim"]);
assert_eq!(resp.extraction_method, "agent");
}
#[test]
fn test_store_response_backward_compat_no_new_fields() {
// Old backend response without warnings/extraction_method
let json = r#"{
"source_id": "mem_old",
"chunks_created": 1,
"memory_type": "fact"
}"#;
let resp: StoreMemoryResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.source_id, "mem_old");
assert_eq!(resp.chunks_created, 1);
assert_eq!(resp.memory_type, "fact");
assert!(resp.entity_id.is_none());
assert!(resp.quality.is_none());
assert!(resp.warnings.is_empty());
assert_eq!(resp.extraction_method, "unknown");
}
#[test]
fn test_store_response_with_warnings_and_extraction_method() {
let json = r#"{
"source_id": "mem_xyz",
"chunks_created": 1,
"memory_type": "decision",
"warnings": ["decision memory missing required 'claim' field"],
"extraction_method": "llm"
}"#;
let resp: StoreMemoryResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.memory_type, "decision");
assert_eq!(
resp.warnings,
vec!["decision memory missing required 'claim' field"]
);
assert_eq!(resp.extraction_method, "llm");
}
// ===== Response deserialization: SearchResult =====
/// A fully populated memory hit must deserialize every metadata field into
/// the matching `SearchResult` field.
#[test]
fn test_search_result_with_new_fields() {
    let payload = r#"{
"id": "1",
"content": "We chose Postgres",
"source": "memory",
"source_id": "mem_1",
"title": "DB decision",
"url": null,
"chunk_index": 0,
"last_modified": 1711000000,
"score": 0.95,
"chunk_type": "memory",
"language": "en",
"semantic_unit": "sentence",
"memory_type": "decision",
"domain": "origin",
"source_agent": "claude",
"confidence": 0.9,
"confirmed": true,
"stability": "standard",
"supersedes": "mem_0",
"summary": "DB choice",
"entity_id": "ent_pg",
"entity_name": "PostgreSQL",
"quality": "high",
"is_archived": false,
"is_recap": false,
"source_text": "We chose Postgres",
"raw_score": 0.42
}"#;
    let parsed: SearchResult = serde_json::from_str(payload).unwrap();

    // Optional string metadata, checked as (actual, expected) pairs.
    let string_fields = [
        (parsed.chunk_type.as_deref(), "memory"),
        (parsed.language.as_deref(), "en"),
        (parsed.semantic_unit.as_deref(), "sentence"),
        (parsed.stability.as_deref(), "standard"),
        (parsed.supersedes.as_deref(), "mem_0"),
        (parsed.summary.as_deref(), "DB choice"),
        (parsed.entity_id.as_deref(), "ent_pg"),
        (parsed.entity_name.as_deref(), "PostgreSQL"),
        (parsed.quality.as_deref(), "high"),
        (parsed.source_text.as_deref(), "We chose Postgres"),
    ];
    for (actual, expected) in string_fields {
        assert_eq!(actual, Some(expected));
    }

    // Explicit `false` flags stay false.
    assert!(!parsed.is_archived);
    assert!(!parsed.is_recap);

    // raw_score is compared against f32::EPSILON, so it is checked with a
    // tolerance rather than exact equality.
    assert!((parsed.raw_score - 0.42).abs() < f32::EPSILON);
}
/// A search hit from a backend that predates entity, quality, archive and
/// recap metadata must still deserialize, with every newer field defaulted.
#[test]
fn test_search_result_backward_compat_no_new_fields() {
    let legacy_payload = r#"{
"id": "1",
"content": "test",
"source": "memory",
"source_id": "mem_1",
"title": "test",
"url": null,
"chunk_index": 0,
"last_modified": 1711000000,
"score": 0.8,
"memory_type": "fact",
"domain": null,
"source_agent": null,
"confidence": null,
"confirmed": null
}"#;
    let parsed: SearchResult = serde_json::from_str(legacy_payload).unwrap();

    // None of the newer optional string fields should be populated.
    let absent = [
        parsed.entity_id.as_deref(),
        parsed.entity_name.as_deref(),
        parsed.quality.as_deref(),
    ];
    for field in absent {
        assert!(field.is_none());
    }

    // Missing flags default to false.
    assert!(!parsed.is_archived);
    assert!(!parsed.is_recap);

    assert!(parsed.structured_fields.is_none());
    assert!(parsed.retrieval_cue.is_none());

    // A missing raw_score defaults to zero.
    assert_eq!(parsed.raw_score, 0.0);
}
/// `structured_fields` (a JSON-encoded string) and `retrieval_cue` are
/// carried through verbatim. Absent flags and raw_score keep their defaults.
#[test]
fn test_search_result_with_structured_fields_and_retrieval_cue() {
    let payload = r#"{
"id": "1",
"content": "Lucian prefers dark mode",
"source": "memory",
"source_id": "mem_1",
"title": "Dark mode preference",
"url": null,
"chunk_index": 0,
"last_modified": 1711000000,
"score": 0.92,
"memory_type": "preference",
"domain": null,
"source_agent": null,
"confidence": null,
"confirmed": null,
"structured_fields": "{\"theme\":\"dark\",\"applies_to\":\"all_apps\"}",
"retrieval_cue": "What UI theme does Lucian prefer?"
}"#;
    let parsed: SearchResult = serde_json::from_str(payload).unwrap();

    let expected_fields = "{\"theme\":\"dark\",\"applies_to\":\"all_apps\"}";
    let expected_cue = "What UI theme does Lucian prefer?";
    assert_eq!(parsed.structured_fields.as_deref(), Some(expected_fields));
    assert_eq!(parsed.retrieval_cue.as_deref(), Some(expected_cue));

    // Fields not present in the payload default.
    assert!(!parsed.is_archived);
    assert!(!parsed.is_recap);
    assert_eq!(parsed.raw_score, 0.0);
}
/// Entity-boosted observations from the knowledge graph arrive with
/// `source: "knowledge_graph"` and carry the matched entity's id and name.
#[test]
fn test_search_result_knowledge_graph_source() {
    let observation_payload = r#"{
"id": "obs_1",
"content": "Prefers Rust over Go",
"source": "knowledge_graph",
"source_id": "ent_lucian",
"title": "Lucian",
"url": null,
"chunk_index": 0,
"last_modified": 1711000000,
"score": 1.14,
"memory_type": null,
"domain": null,
"source_agent": null,
"confidence": null,
"confirmed": null,
"entity_id": "ent_lucian",
"entity_name": "Lucian"
}"#;
    let hit: SearchResult = serde_json::from_str(observation_payload).unwrap();

    assert_eq!(hit.source, "knowledge_graph");
    assert_eq!(hit.entity_id.as_deref(), Some("ent_lucian"));
    assert_eq!(hit.entity_name.as_deref(), Some("Lucian"));

    // Fields absent from the payload keep their defaults.
    assert!(!hit.is_archived);
    assert!(!hit.is_recap);
    assert_eq!(hit.raw_score, 0.0);
}
// ===== Transport security: forget blocks on HTTP =====
/// Over the HTTP transport, `forget_impl` must refuse the delete. The refusal
/// is returned as text content in an `Ok` result, not as an `Err`.
#[tokio::test]
async fn test_forget_blocked_on_http_transport() {
    let remote = make_server(TransportMode::Http, "agent", None);
    // `unwrap` here asserts the refusal is not surfaced as an `Err`.
    let outcome = remote.forget_impl("mem_123").await.unwrap();
    let first = &outcome.content[0];
    if let rmcp::model::RawContent::Text(ref text_content) = first.raw {
        assert!(text_content
            .text
            .contains("not available over remote connections"));
    } else {
        panic!("expected text content");
    }
}
/// On stdio the transport check must let `forget_impl` through to the HTTP
/// call. With no backend running, that call fails. The failure arrives as a
/// tool-level `CallToolResult` with `is_error` set, rather than as a
/// protocol-level `McpError`, which proves the transport block was not hit.
///
/// NOTE(review): this relies on nothing answering at the client's backend
/// address during tests.
#[tokio::test]
async fn test_forget_allowed_on_stdio_transport() {
    let local = make_server(TransportMode::Stdio, "agent", None);
    let outcome = local.forget_impl("mem_123").await.unwrap();
    let flagged_as_error = matches!(outcome.is_error, Some(true));
    assert!(
        flagged_as_error,
        "should fail with connection error, not transport block"
    );
}
// ===== Context default limit =====
/// When `limit` is omitted, the context request asks for 20 chunks.
///
/// NOTE(review): this rebuilds the request by hand instead of calling the
/// production code path, so it only pins the literal default used here.
#[test]
fn test_context_request_default_limit() {
    let ContextParams {
        topic,
        limit,
        domain,
    } = ContextParams {
        topic: Some("test".into()),
        limit: None,
        domain: None,
    };
    let request = ChatContextRequest {
        query: None,
        conversation_id: topic,
        max_chunks: limit.unwrap_or(20),
        relevance_threshold: None,
        include_goals: true,
        domain,
    };
    assert_eq!(request.max_chunks, 20);
}
/// An explicit `limit` overrides the default chunk count, and `domain` is
/// forwarded unchanged.
#[test]
fn test_context_request_custom_limit() {
    let ContextParams {
        topic,
        limit,
        domain,
    } = ContextParams {
        topic: None,
        limit: Some(5),
        domain: Some("work".into()),
    };
    let request = ChatContextRequest {
        query: None,
        conversation_id: topic,
        max_chunks: limit.unwrap_or(20),
        relevance_threshold: None,
        include_goals: true,
        domain,
    };
    assert_eq!(request.max_chunks, 5);
    assert_eq!(request.domain.as_deref(), Some("work"));
}
/// The `context` tool's `topic` parameter is sent as the request's
/// `conversation_id`.
#[test]
fn test_context_maps_topic_to_conversation_id() {
    let params = ContextParams {
        topic: Some("project Origin".into()),
        limit: None,
        domain: None,
    };
    let req = ChatContextRequest {
        query: None,
        // Move the field out directly, as the sibling context tests do.
        // `params.topic` is not read again, so cloning it was unnecessary.
        conversation_id: params.topic,
        max_chunks: params.limit.unwrap_or(20),
        relevance_threshold: None,
        include_goals: true,
        domain: params.domain,
    };
    assert_eq!(req.conversation_id.as_deref(), Some("project Origin"));
}
// ===== Remember request construction =====
/// Builds a `StoreMemoryRequest` the way `remember_impl` does and checks its
/// serialized JSON, including the entity and the fallback `source_agent`.
///
/// NOTE(review): the construction below replicates `remember_impl` rather
/// than calling it, so it can drift if `remember_impl` changes.
#[test]
fn test_remember_constructs_store_request_with_entity() {
    let server = make_server(TransportMode::Stdio, "claude", None);
    let params = RememberParams {
        content: "Alice manages the frontend team".into(),
        memory_type: Some("fact".into()),
        domain: Some("work".into()),
        entity: Some("Alice".into()),
        confidence: Some(0.9),
        supersedes: None,
        structured_fields: None,
        retrieval_cue: None,
    };
    // Replicate remember_impl's request construction
    let source_agent = server.resolve_source_agent(None);
    let req = StoreMemoryRequest {
        content: params.content,
        memory_type: params.memory_type,
        domain: params.domain,
        source_agent,
        title: None,
        confidence: params.confidence,
        supersedes: params.supersedes,
        entity: params.entity,
        entity_id: None,
        structured_fields: params.structured_fields.map(serde_json::Value::Object),
        retrieval_cue: params.retrieval_cue,
    };
    let json = serde_json::to_value(&req).unwrap();
    assert_eq!(json["content"], "Alice manages the frontend team");
    assert_eq!(json["memory_type"], "fact");
    assert_eq!(json["domain"], "work");
    assert_eq!(json["entity"], "Alice");

    let confidence = json["confidence"]
        .as_f64()
        .expect("confidence serializes as a number");
    // `confidence` may be an f32 widened to f64 (0.89999997...). Check a
    // tolerance on both sides: the old `> 0.89` check accepted any larger
    // value.
    assert!(
        (confidence - 0.9).abs() < 1e-6,
        "expected confidence ~0.9, got {confidence}"
    );

    // stdio mode: no param, no client_name → falls back to agent_name "claude"
    assert_eq!(json["source_agent"], "claude");
}
/// Over HTTP, with no agent passed explicitly, `resolve_source_agent` yields
/// the server's configured agent name ("claude.ai" here).
#[test]
fn test_remember_http_mode_injects_agent() {
    let http_server = make_server(TransportMode::Http, "claude.ai", Some("lucian"));
    let expected = Some("claude.ai".into());
    assert_eq!(http_server.resolve_source_agent(None), expected);
}
// ===== Recall request construction =====
/// Recall parameters map onto a `SearchMemoryRequest`. The explicit limit
/// wins over the default of 10, there is no `entity` key, and unset
/// optionals read back as JSON null.
#[test]
fn test_recall_constructs_search_request() {
    let RecallParams {
        query,
        limit,
        memory_type,
        domain,
    } = RecallParams {
        query: "database choices".into(),
        limit: Some(5),
        memory_type: Some("decision".into()),
        domain: None,
    };
    let request = SearchMemoryRequest {
        query,
        limit: limit.unwrap_or(10),
        memory_type,
        domain,
        source_agent: None,
    };
    let body = serde_json::to_value(&request).unwrap();

    // Populated fields serialize with their values.
    assert_eq!(body["query"], "database choices");
    assert_eq!(body["limit"], 5);
    assert_eq!(body["memory_type"], "decision");

    // No entity key at all. Indexing a missing or null key yields Null.
    assert!(body.get("entity").is_none());
    assert!(body["domain"].is_null());
    assert!(body["source_agent"].is_null());
}
// ===== Memory type backward compat =====
/// Each of the five precise memory types is stored on `RememberParams`
/// exactly as given.
#[test]
fn test_remember_passes_through_all_5_types() {
    let kinds = ["identity", "preference", "fact", "decision", "goal"];
    for kind in kinds {
        let params = RememberParams {
            content: "test".into(),
            memory_type: Some(kind.to_owned()),
            domain: None,
            entity: None,
            confidence: None,
            supersedes: None,
            structured_fields: None,
            retrieval_cue: None,
        };
        assert_eq!(params.memory_type.as_deref(), Some(kind));
    }
}
// ===== Structured fields in remember params =====
/// `RememberParams` accepts `structured_fields` as a JSON object and
/// `retrieval_cue` as a string. The other fields may be omitted.
#[test]
fn test_remember_params_with_structured_fields_and_cue() {
    let payload = r#"{
"content": "Lucian prefers dark mode",
"structured_fields": {"theme":"dark"},
"retrieval_cue": "What theme does Lucian prefer?"
}"#;
    let RememberParams {
        structured_fields,
        retrieval_cue,
        ..
    } = serde_json::from_str::<RememberParams>(payload).unwrap();

    let fields = structured_fields.expect("structured_fields");
    let expected_theme = serde_json::Value::String("dark".into());
    assert_eq!(fields.get("theme"), Some(&expected_theme));
    assert_eq!(
        retrieval_cue.as_deref(),
        Some("What theme does Lucian prefer?")
    );
}
/// `structured_fields` and `retrieval_cue` on a `StoreMemoryRequest`
/// serialize unchanged.
#[test]
fn test_store_request_with_structured_fields() {
    let expected_fields = serde_json::json!({"key":"val"});
    let request = StoreMemoryRequest {
        content: "test".into(),
        memory_type: Some("fact".into()),
        structured_fields: Some(expected_fields.clone()),
        retrieval_cue: Some("What is the key?".into()),
        domain: None,
        source_agent: None,
        title: None,
        confidence: None,
        supersedes: None,
        entity: None,
        entity_id: None,
    };
    let body = serde_json::to_value(&request).unwrap();
    assert_eq!(body["structured_fields"], expected_fields);
    assert_eq!(body["retrieval_cue"], "What is the key?");
}
// ===== ChatContextResponse deserialization =====
/// A populated context response deserializes. The checks cover the rendered
/// context, an empty identity list, the timing, and the token total.
#[test]
fn test_chat_context_response() {
    let payload = r#"{
"context": "User prefers dark mode. Works on Origin project.",
"profile": {
"narrative": "narrative",
"identity": [],
"preferences": [],
"goals": []
},
"knowledge": {
"concepts": [],
"decisions": [],
"relevant_memories": [],
"graph_context": []
},
"took_ms": 12.5,
"token_estimates": {
"tier1_identity": 1,
"tier2_project": 2,
"tier3_relevant": 3,
"total": 6
}
}"#;
    let parsed: ChatContextResponse = serde_json::from_str(payload).unwrap();

    assert!(!parsed.context.is_empty());
    assert!(parsed.profile.identity.is_empty());
    assert_eq!(parsed.took_ms, 12.5);
    assert_eq!(parsed.token_estimates.total, 6);
}
/// A context response with every section empty still deserializes, with an
/// empty `context` string.
#[test]
fn test_chat_context_response_empty() {
    let empty_payload = r#"{
"context": "",
"profile": {
"narrative": "",
"identity": [],
"preferences": [],
"goals": []
},
"knowledge": {
"concepts": [],
"decisions": [],
"relevant_memories": [],
"graph_context": []
},
"took_ms": 1.0,
"token_estimates": {
"tier1_identity": 0,
"tier2_project": 0,
"tier3_relevant": 0,
"total": 0
}
}"#;
    let ChatContextResponse { context, .. } =
        serde_json::from_str::<ChatContextResponse>(empty_payload).unwrap();
    assert!(context.is_empty());
}
// ===== with_instructions content assertions =====
// These tests lock in the refined agent-facing guidance. If any
// assertion fails, either the rule was intentionally changed
// (update the test) or the refinement was accidentally dropped
// (restore the rule).
/// Returns the `instructions` text a stdio server advertises in its info.
/// Panics if the server ships without instructions.
fn server_instructions() -> String {
    let info = make_server(TransportMode::Stdio, "test", None).get_info();
    info.instructions
        .expect("server must ship with_instructions")
}
/// The instructions must describe the knowledge store as self-evolving.
#[test]
fn instructions_mention_self_evolving_knowledge() {
    let instructions = server_instructions();
    let mentions_self_evolving = instructions.contains("self-evolving");
    assert!(
        mentions_self_evolving,
        "with_instructions must describe Origin as self-evolving"
    );
}
/// The instructions must say the memory store is shared across tools.
#[test]
fn instructions_mention_shared_across_tools() {
    let instructions = server_instructions();
    let says_shared = instructions.contains("shared across all");
    assert!(
        says_shared,
        "with_instructions must tell agents the store is shared across tools"
    );
}
/// The instructions must frame context as modeling how the user thinks.
#[test]
fn instructions_mention_how_user_thinks() {
    let instructions = server_instructions();
    let frames_user_model = instructions.contains("how the user thinks");
    assert!(
        frames_user_model,
        "with_instructions must frame context as modeling how the user thinks"
    );
}
/// The instructions must use the "STORE PROACTIVELY" heading.
#[test]
fn instructions_use_proactive_framing() {
    let instructions = server_instructions();
    let is_proactive = instructions.contains("STORE PROACTIVELY");
    assert!(
        is_proactive,
        "with_instructions must use STORE PROACTIVELY framing (not passive WHEN TO STORE)"
    );
}
/// The instructions must exclude tool/command output from storage.
#[test]
fn instructions_ban_tool_output_storage() {
    let instructions = server_instructions();
    let bans_tool_output = instructions.contains("Tool output or command results");
    assert!(
        bans_tool_output,
        "with_instructions must explicitly rule out tool output as storage material"
    );
}
/// The instructions must exclude the agent's own unexpressed inferences.
#[test]
fn instructions_ban_ghost_inferences() {
    let instructions = server_instructions();
    let bans_inferences = instructions.contains("Your own inferences");
    assert!(
        bans_inferences,
        "with_instructions must rule out storing agent's own inferences user didn't express"
    );
}
/// The instructions must name the atomic one-idea-per-memory rule.
#[test]
fn instructions_call_out_atomic_memory() {
    let instructions = server_instructions();
    let names_atomic_rule = instructions.contains("Atomic: one idea per memory");
    assert!(
        names_atomic_rule,
        "with_instructions must call out the atomic-memory rule explicitly by name"
    );
}
/// The instructions must require a declarative writing style.
#[test]
fn instructions_specify_declarative_writing() {
    let instructions = server_instructions();
    let wants_declarative = instructions.contains("Declarative, not narrative");
    assert!(
        wants_declarative,
        "with_instructions must require declarative (not narrative) writing style"
    );
}
/// The instructions must tell agents to omit `memory_type` by default and
/// let the backend classify.
#[test]
fn instructions_default_to_omit_memory_type() {
    let instructions = server_instructions();
    let required_phrases = [
        (
            "omit and trust the backend",
            "with_instructions must default agents to omitting memory_type",
        ),
        (
            "do NOT set memory_type",
            "with_instructions must explicitly say do NOT set memory_type by default",
        ),
    ];
    for (phrase, why) in required_phrases {
        assert!(instructions.contains(phrase), "{why}");
    }
}
/// The instructions must make an exception for decisions: agents set
/// `memory_type="decision"` explicitly so the Decision Log picks them up.
#[test]
fn instructions_carve_out_decisions_for_decision_log() {
    let instructions = server_instructions();
    let required_phrases = [
        (
            "Decision Log",
            "with_instructions must name the Decision Log as the reason for explicit decision typing",
        ),
        (
            "memory_type=\"decision\"",
            "with_instructions must tell agents to set memory_type=\"decision\" explicitly for decisions",
        ),
    ];
    for (phrase, why) in required_phrases {
        assert!(instructions.contains(phrase), "{why}");
    }
}
// ===== tool-level and param-level description assertions =====
/// Maps each registered tool's name to its description, skipping tools that
/// have no description.
fn tool_descriptions() -> std::collections::HashMap<String, String> {
    let server = make_server(TransportMode::Stdio, "test", None);
    let mut by_name = std::collections::HashMap::new();
    for tool in server.tool_router.list_all() {
        if let Some(description) = tool.description.as_ref() {
            by_name.insert(tool.name.to_string(), description.to_string());
        }
    }
    by_name
}
/// The `remember` tool description must state one atomic idea per call.
#[test]
fn remember_description_calls_out_atomic() {
    let remember = tool_descriptions()
        .remove("remember")
        .expect("remember tool exists");
    let is_atomic = remember.contains("Each call is one atomic idea");
    assert!(
        is_atomic,
        "remember description must call out atomic-per-call explicitly, got: {remember}"
    );
}
/// The `context` tool description must frame its result as modeling how the
/// user thinks.
#[test]
fn context_description_frames_modeling_user() {
    let ctx = tool_descriptions()
        .remove("context")
        .expect("context tool exists");
    let frames_user_model = ctx.contains("how the user thinks");
    assert!(
        frames_user_model,
        "context description must frame the result as modeling how the user thinks, got: {ctx}"
    );
}
/// The JSON schema for `RecallParams` must advertise the two-level
/// `memory_type` filter, including the `profile` and `knowledge` aliases.
#[test]
fn recall_memory_type_param_lists_two_level_filter() {
    let schema = schemars::schema_for!(RecallParams);
    let params_schema =
        serde_json::to_string(&schema).expect("RecallParams schema serializes");
    assert!(
        params_schema.contains("Two-level filter"),
        "RecallParams.memory_type must advertise the two-level filter, got schema: {params_schema}"
    );
    let aliases = [
        ("profile", "RecallParams.memory_type must mention profile alias"),
        ("knowledge", "RecallParams.memory_type must mention knowledge alias"),
    ];
    for (alias, why) in aliases {
        assert!(params_schema.contains(alias), "{why}");
    }
}
}