use crate::mcp::registry::McpTool;
use crate::models::field_names;
use crate::{db, validate};
use schemars::JsonSchema;
use serde::Deserialize;
use serde_json::{Value, json};
// Argument schema for the `memory_kg_query` MCP tool; `KgQueryTool::input_schema`
// derives the tool's JSON schema from this struct.
//
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the derived schema's `description` fields, and the #985 parity test in
// this file checks those descriptions (`assert_descriptions_match`).
//
// NOTE(review): `dead_code` is allowed presumably because `handle_kg_query`
// reads the raw `params` JSON instead of deserializing into this struct —
// confirm no other caller constructs it.
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
#[allow(dead_code)]
pub struct KgQueryRequest {
    // Id of the memory to traverse from. Declared required here, although
    // `handle_kg_query` only reads it when `by_source_uri` is absent or blank.
    pub source_id: String,
    // Maximum number of hops. The handler falls back to 1 when this is absent,
    // negative or not an integer; the tool docs state a ceiling of 5.
    #[serde(default)]
    pub max_depth: Option<i64>,
    // Optional timestamp forwarded to the traversal; the handler trims it,
    // treats blank as absent and checks it with
    // `validate::validate_expires_at_format`.
    #[serde(default)]
    pub valid_at: Option<String>,
    // Agent-id filter; the handler drops blank/non-string entries and checks
    // each remaining id with `validate::validate_agent_id`.
    #[serde(default)]
    pub allowed_agents: Option<Vec<String>>,
    // Row cap used in both modes; the handler only honours non-negative
    // integers and otherwise passes no limit.
    #[serde(default)]
    pub limit: Option<i64>,
    // Forwarded to the traversal query; the handler defaults it to `false`.
    #[serde(default)]
    pub include_invalidated: Option<bool>,
    // A non-blank value switches the handler to source-URI listing mode.
    #[schemars(description = "#889 traverse by source_uri.")]
    #[serde(default)]
    pub by_source_uri: Option<String>,
    // Namespace filter; the handler only reads it in source-URI listing mode.
    #[serde(default)]
    pub namespace: Option<String>,
}
/// `McpTool` metadata for `memory_kg_query`: name, description, docs, input
/// schema and profile family.
// NOTE(review): the reason for the `dead_code` allow is not visible here —
// presumably the type is only reached through the `McpTool` registry.
#[allow(dead_code)]
pub struct KgQueryTool;
impl McpTool for KgQueryTool {
fn name() -> &'static str {
crate::mcp::registry::tool_names::MEMORY_KG_QUERY
}
fn description() -> &'static str {
"Outbound KG traversal from a source memory (<=5 hops)."
}
fn docs() -> &'static str {
"Pillar 2 / Stream C: BFS/CTE traversal with cycle detection. Each row carries valid_from/valid_until/observed_by + target title/namespace. Filters chain across every hop. max_depth ceiling 5."
}
fn input_schema() -> Value {
crate::mcp::registry::input_schema_for::<KgQueryRequest>()
}
fn family() -> &'static str {
crate::profile::Family::Graph.name()
}
}
/// Handler for `memory_kg_query` MCP calls.
///
/// The request selects one of two modes:
///
/// * **Source-URI listing**: used when `by_source_uri` is a non-blank string.
///   The trimmed URI is validated and passed to `db::list_by_source_uri`
///   together with the optional `namespace`, `limit` and `as_agent` params.
///   Response: `{ by_source_uri, memories, count }`, with every memory reported
///   at `depth` 0. Traversal params (`source_id`, `max_depth`, `valid_at`, ...)
///   are ignored in this mode.
/// * **Traversal**: used otherwise. `source_id` is required. `max_depth`
///   (default 1), `valid_at`, `allowed_agents`, `limit` and
///   `include_invalidated` are forwarded to `db::kg_query`.
///   Response: `{ source_id, max_depth, memories, paths, count }`.
///
/// # Errors
///
/// Returns the stringified error in three cases: a parameter fails
/// validation, `source_id` is missing or not a string in traversal mode, or
/// the database call fails.
pub fn handle_kg_query(conn: &rusqlite::Connection, params: &Value) -> Result<Value, String> {
    match non_blank_str(params, field_names::BY_SOURCE_URI) {
        Some(uri) => kg_query_by_source_uri(conn, params, uri),
        None => kg_query_traverse(conn, params),
    }
}

/// Reads `params[key]` as a trimmed string; `None` when absent, not a string, or blank.
fn non_blank_str<'a>(params: &'a Value, key: &str) -> Option<&'a str> {
    params[key].as_str().map(str::trim).filter(|s| !s.is_empty())
}

/// Reads `params["limit"]` as a non-negative integer that fits in `usize`;
/// absent, negative, fractional or non-numeric values yield `None`.
fn parse_limit(params: &Value) -> Option<usize> {
    params["limit"]
        .as_u64()
        .and_then(|n| usize::try_from(n).ok())
}

/// Source-URI listing mode: memories recorded against `uri`.
///
/// Each row is reported at `depth` 0 and uses the same `target_id` / `title` /
/// `target_namespace` keys as the traversal rows.
fn kg_query_by_source_uri(
    conn: &rusqlite::Connection,
    params: &Value,
    uri: &str,
) -> Result<Value, String> {
    validate::validate_source_uri(uri).map_err(|e| e.to_string())?;
    let namespace = params["namespace"].as_str();
    // Validate the namespace filter the same way as `as_agent` below, so a
    // malformed value is rejected instead of being passed straight to the query.
    if let Some(ns) = namespace {
        validate::validate_namespace(ns).map_err(|e| e.to_string())?;
    }
    let limit = parse_limit(params);
    // `as_agent` is read from the raw params; it is not declared on
    // `KgQueryRequest`.
    let as_agent = params["as_agent"].as_str();
    if let Some(a) = as_agent {
        validate::validate_namespace(a).map_err(|e| e.to_string())?;
    }
    let roots = db::list_by_source_uri(conn, uri, namespace, limit, as_agent)
        .map_err(|e| e.to_string())?;
    let memories_json: Vec<Value> = roots
        .iter()
        .map(|m| {
            json!({
                "target_id": m.id,
                "title": m.title,
                (field_names::TARGET_NAMESPACE): m.namespace,
                "depth": 0,
                (field_names::SOURCE_URI): m.source_uri,
            })
        })
        .collect();
    Ok(json!({
        (field_names::BY_SOURCE_URI): uri,
        "memories": memories_json,
        "count": roots.len(),
    }))
}

/// Traversal mode: outbound walk from `source_id` via `db::kg_query`.
fn kg_query_traverse(conn: &rusqlite::Connection, params: &Value) -> Result<Value, String> {
    let source_id = params["source_id"]
        .as_str()
        .ok_or(crate::errors::msg::SOURCE_ID_REQUIRED)?;
    validate::validate_id(source_id).map_err(|e| e.to_string())?;
    // Absent, negative or non-integer values fall back to a single hop. No
    // upper clamp is applied here; the documented ceiling of 5 is presumably
    // enforced by `db::kg_query` — TODO confirm.
    let max_depth = params["max_depth"]
        .as_u64()
        .and_then(|n| usize::try_from(n).ok())
        .unwrap_or(1);
    let valid_at = non_blank_str(params, "valid_at");
    if let Some(t) = valid_at {
        validate::validate_expires_at_format(t).map_err(|e| e.to_string())?;
    }
    // Blank and non-string entries are dropped. An array that ends up empty is
    // still forwarded as an empty filter rather than `None`.
    let allowed_agents: Option<Vec<String>> = params["allowed_agents"].as_array().map(|arr| {
        arr.iter()
            .filter_map(|v| v.as_str().map(str::trim).filter(|s| !s.is_empty()))
            .map(str::to_string)
            .collect()
    });
    if let Some(agents) = allowed_agents.as_ref() {
        for a in agents {
            validate::validate_agent_id(a).map_err(|e| e.to_string())?;
        }
    }
    let limit = parse_limit(params);
    let include_invalidated = params[field_names::INCLUDE_INVALIDATED]
        .as_bool()
        .unwrap_or(false);
    let nodes = db::kg_query(
        conn,
        source_id,
        max_depth,
        valid_at,
        allowed_agents.as_deref(),
        limit,
        include_invalidated,
    )
    .map_err(|e| e.to_string())?;
    let memories_json: Vec<Value> = nodes
        .iter()
        .map(|n| {
            json!({
                "target_id": n.target_id,
                "relation": n.relation,
                (field_names::VALID_FROM): n.valid_from,
                (field_names::VALID_UNTIL): n.valid_until,
                (field_names::OBSERVED_BY): n.observed_by,
                "title": n.title,
                (field_names::TARGET_NAMESPACE): n.target_namespace,
                "depth": n.depth,
                "path": n.path,
            })
        })
        .collect();
    let paths_json: Vec<&str> = nodes.iter().map(|n| n.path.as_str()).collect();
    Ok(json!({
        "source_id": source_id,
        "max_depth": max_depth,
        "memories": memories_json,
        "paths": paths_json,
        "count": nodes.len(),
    }))
}
#[cfg(test)]
mod d1_4_985_tests {
    // #985 checks for `memory_kg_query`: the schema derived from
    // `KgQueryRequest` is run through the shared parity helpers (property set
    // and descriptions), and the tool metadata is pinned to its expected
    // name and family.
    use super::*;
    use crate::mcp::d1_4_985_helpers::{
        assert_descriptions_match, assert_property_set_parity, derived_props_for,
    };

    /// Name every assertion in this module expects for the tool.
    const TOOL: &str = "memory_kg_query";

    #[test]
    fn memory_kg_query_parity_985() {
        let props = derived_props_for::<KgQueryRequest>();
        assert_property_set_parity(TOOL, &props);
        assert_descriptions_match(TOOL, &props);
    }

    #[test]
    fn memory_kg_query_tool_metadata_985() {
        assert_eq!(KgQueryTool::name(), TOOL);
        assert_eq!(KgQueryTool::family(), "graph");
    }
}