use crate::embeddings::{Embed, Embedder};
use crate::mcp::param_names;
use crate::mcp::registry::McpTool;
use crate::models::Memory;
use crate::{db, validate};
use schemars::JsonSchema;
use serde::Deserialize;
use serde_json::{Value, json};
// Input shape of the `memory_load_family` tool; `LoadFamilyTool::input_schema`
// derives the advertised JSON schema from this struct.
//
// NOTE: plain `//` comments are used on purpose. schemars copies `///` doc
// comments into the generated schema as `description`, which would change the
// schema that the parity tests in this file compare against the legacy catalog.
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
// `handle_load_family` reads the raw JSON params rather than deserializing
// into this struct, so the fields are never read directly in this file.
#[allow(dead_code)]
pub struct LoadFamilyRequest {
// Family to load. Required; the handler parses it with `Family::from_str`
// and rejects the request when it is missing or unparseable.
pub family: String,
// Optional namespace filter. When absent, memories from every namespace match.
#[serde(default)]
pub namespace: Option<String>,
// Optional result cap. The handler reads it with `Value::as_u64`, defaults to
// 20 and clamps to 1..=100, so a negative value falls back to the default.
#[serde(default)]
pub k: Option<i64>,
}
/// Registry marker type for the `memory_load_family` tool. It carries no
/// state; all metadata is exposed through the `McpTool` associated functions.
#[allow(dead_code)]
pub struct LoadFamilyTool;
impl McpTool for LoadFamilyTool {
/// Tool identifier, taken from the shared `tool_names` constants.
fn name() -> &'static str {
crate::mcp::registry::tool_names::MEMORY_LOAD_FAMILY
}
/// One-line summary of the tool.
fn description() -> &'static str {
"Load top-k recent + high-priority memories from a Family."
}
/// Long-form documentation.
///
/// The `memory_kind taxonomy (...)` parenthetical is parsed by
/// `issue_1589_tests`, which requires every example in it to be a valid
/// `MemoryKind`; keep that in mind when editing the text.
fn docs() -> &'static str {
"B1: top-k by metadata.family. Always-on; alternative to memory_recall when family is known. \
Issue #864 — `family` here is the MCP tool family (8 groups: \
core/lifecycle/graph/governance/power/meta/archive/other), \
NOT the memory_kind taxonomy (Observation/Reflection/Decision/Event/etc)."
}
/// JSON schema for the tool input, derived from `LoadFamilyRequest`.
fn input_schema() -> Value {
crate::mcp::registry::input_schema_for::<LoadFamilyRequest>()
}
/// Family this tool itself is grouped under (`Family::Core`).
fn family() -> &'static str {
crate::profile::Family::Core.name()
}
}
// Input shape of the `memory_smart_load` tool; `SmartLoadTool::input_schema`
// derives the advertised JSON schema from this struct.
//
// NOTE: plain `//` comments are used on purpose. schemars copies `///` doc
// comments into the generated schema as `description`, which would change the
// schema that the parity tests in this file compare against the legacy catalog.
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
// `handle_smart_load` reads the raw JSON params rather than deserializing
// into this struct, so the fields are never read directly in this file.
#[allow(dead_code)]
pub struct SmartLoadRequest {
// Free-text description of what the caller wants. Required; the handler
// trims it, and an all-whitespace intent routes to the Core family.
pub intent: String,
// Optional namespace filter, forwarded unchanged to `handle_load_family`.
#[serde(default)]
pub namespace: Option<String>,
// Optional result cap, forwarded to `handle_load_family` only when it is a
// non-negative integer (it is read with `Value::as_u64`).
#[serde(default)]
pub k: Option<i64>,
}
/// Registry marker type for the `memory_smart_load` tool. It carries no
/// state; all metadata is exposed through the `McpTool` associated functions.
#[allow(dead_code)]
pub struct SmartLoadTool;
impl McpTool for SmartLoadTool {
/// Tool identifier, taken from the shared `tool_names` constants.
fn name() -> &'static str {
crate::mcp::registry::tool_names::MEMORY_SMART_LOAD
}
/// One-line summary of the tool.
fn description() -> &'static str {
"Intent-routed loader: free-text intent picks the best Family."
}
/// Long-form documentation.
///
/// The `memory_kind taxonomy (...)` parenthetical is parsed by
/// `issue_1589_tests`, which requires every example in it to be a valid
/// `MemoryKind`; keep that in mind when editing the text.
fn docs() -> &'static str {
"B2: pick best Family from free-text intent, then forward to memory_load_family. \
Issue #864 — `Family` here is the MCP tool family (8 groups: \
core/lifecycle/graph/governance/power/meta/archive/other), \
NOT the memory_kind taxonomy (Observation/Reflection/Decision/Event/etc)."
}
/// JSON schema for the tool input, derived from `SmartLoadRequest`.
fn input_schema() -> Value {
crate::mcp::registry::input_schema_for::<SmartLoadRequest>()
}
/// Family this tool itself is grouped under (`Family::Core`).
fn family() -> &'static str {
crate::profile::Family::Core.name()
}
}
// Parity tests: the schemas derived from the request structs must expose the
// same properties (and the same descriptions, where the legacy entry has one)
// as the entries returned by `tool_definitions()`.
#[cfg(test)]
mod d1_3_984_tests {
use super::*;
// Returns the `inputSchema.properties` object of `tool_name` from the legacy
// catalog. Panics when the tool or its properties object is missing.
fn legacy_props(tool_name: &str) -> serde_json::Map<String, Value> {
let defs = crate::mcp::registry::tool_definitions();
let tools = defs
.get("tools")
.and_then(Value::as_array)
.expect("tool_definitions emits `tools` array");
let entry = tools
.iter()
.find(|t| t.get("name").and_then(Value::as_str) == Some(tool_name))
.unwrap_or_else(|| panic!("{tool_name} must be in legacy catalog"));
entry
.pointer("/inputSchema/properties")
.and_then(Value::as_object)
.unwrap_or_else(|| panic!("{tool_name}.inputSchema.properties must be object"))
.clone()
}
// Returns the `properties` object of the schemars schema for `T`. Looks at the
// top level first, then at `/definitions/<short type name>/properties`, and
// panics when neither location holds an object.
fn derived_props_for<T: schemars::JsonSchema>() -> serde_json::Map<String, Value> {
let schema = schemars::schema_for!(T);
let v = serde_json::to_value(schema).expect("schema → value");
v.get("properties")
.and_then(Value::as_object)
.or_else(|| {
v.pointer(&format!(
"/definitions/{}/properties",
std::any::type_name::<T>().rsplit("::").next().unwrap_or("")
))
.and_then(Value::as_object)
})
.cloned()
.expect("schemars schema must have properties at a known path")
}
// Asserts that legacy and derived schemas declare exactly the same property
// names; the failure message lists the names present on only one side.
fn assert_property_set_parity(tool_name: &str, derived: &serde_json::Map<String, Value>) {
let legacy = legacy_props(tool_name);
let legacy_keys: std::collections::BTreeSet<&str> =
legacy.keys().map(String::as_str).collect();
let derived_keys: std::collections::BTreeSet<&str> =
derived.keys().map(String::as_str).collect();
assert_eq!(
legacy_keys,
derived_keys,
"{tool_name}: property set drift; diff = {:?}",
legacy_keys
.symmetric_difference(&derived_keys)
.collect::<Vec<_>>()
);
}
// For every legacy property that carries a `description`, asserts the derived
// property has the identical string. The check is one-directional: a
// description that exists only on the derived side is not flagged.
fn assert_descriptions_match(tool_name: &str, derived: &serde_json::Map<String, Value>) {
let legacy = legacy_props(tool_name);
for (name, legacy_prop) in &legacy {
if let Some(want) = legacy_prop.get("description").and_then(Value::as_str) {
let got = derived
.get(name)
.and_then(|p| p.get("description"))
.and_then(Value::as_str);
assert_eq!(
got,
Some(want),
"{tool_name}.{name}: description must match legacy byte-for-byte"
);
}
}
}
// `LoadFamilyRequest` must stay in sync with the legacy `memory_load_family` entry.
#[test]
fn load_family_parity_984() {
let derived = derived_props_for::<LoadFamilyRequest>();
assert_property_set_parity("memory_load_family", &derived);
assert_descriptions_match("memory_load_family", &derived);
}
// `SmartLoadRequest` must stay in sync with the legacy `memory_smart_load` entry.
#[test]
fn smart_load_parity_984() {
let derived = derived_props_for::<SmartLoadRequest>();
assert_property_set_parity("memory_smart_load", &derived);
assert_descriptions_match("memory_smart_load", &derived);
}
// Pins the public tool names and the family both loaders are registered under.
#[test]
fn load_family_tool_metadata_984() {
assert_eq!(LoadFamilyTool::name(), "memory_load_family");
assert_eq!(LoadFamilyTool::family(), "core");
assert_eq!(SmartLoadTool::name(), "memory_smart_load");
assert_eq!(SmartLoadTool::family(), "core");
}
}
// Guards the loader docstrings: every kind named in the
// `memory_kind taxonomy (...)` parenthetical must be a real `MemoryKind`.
#[cfg(test)]
mod issue_1589_tests {
use super::*;
use crate::models::MemoryKind;
// Marker that opens the list of example kinds inside a docstring.
const TAXONOMY_PARENTHETICAL_OPEN: &str = "memory_kind taxonomy (";
// Extracts the text between the marker and the next `)`, splits it on `/`,
// drops empty entries and the literal `etc`, and asserts that at least one
// example remains and that each one, lowercased, is accepted by
// `MemoryKind::from_str`.
fn assert_taxonomy_examples_valid(docs: &str) {
let start = docs
.find(TAXONOMY_PARENTHETICAL_OPEN)
.expect("loader docs must carry the memory_kind taxonomy disambiguation");
let after = &docs[start + TAXONOMY_PARENTHETICAL_OPEN.len()..];
let inner = &after[..after.find(')').expect("parenthetical must close")];
let examples: Vec<&str> = inner
.split('/')
.map(str::trim)
.filter(|t| !t.is_empty() && *t != "etc")
.collect();
assert!(
!examples.is_empty(),
"taxonomy parenthetical must cite at least one kind example"
);
for ex in examples {
assert!(
MemoryKind::from_str(&ex.to_ascii_lowercase()).is_some(),
"docstring cites {ex:?}, which is not a valid MemoryKind; \
valid kinds = {:?}",
MemoryKind::all()
.iter()
.map(MemoryKind::as_str)
.collect::<Vec<_>>()
);
}
}
// Applies the check to the `memory_load_family` docs.
#[test]
fn load_family_docs_taxonomy_examples_are_valid_kinds_1589() {
assert_taxonomy_examples_valid(LoadFamilyTool::docs());
}
// Applies the check to the `memory_smart_load` docs.
#[test]
fn smart_load_docs_taxonomy_examples_are_valid_kinds_1589() {
assert_taxonomy_examples_valid(SmartLoadTool::docs());
}
}
/// Handles `memory_load_family`: returns up to `k` unexpired memories whose
/// `metadata.family` equals the requested family, ordered by priority and then
/// by most recent update.
///
/// # Parameters
/// - `params["family"]`: required string, parsed with `Family::from_str`.
/// - `params[namespace]`: optional string, validated before use; when absent
///   every namespace matches.
/// - `params[k]`: optional non-negative integer, default 20, clamped to
///   `1..=100`. Any value `as_u64` rejects (negative, float, non-number) is
///   treated as absent.
/// - `caller`: when `Some`, rows the caller may not see are dropped.
///
/// # Errors
/// Returns a message when `family` is missing or unparseable, when the
/// namespace fails validation, or when preparing, running or reading the
/// query fails.
pub fn handle_load_family(
    conn: &rusqlite::Connection,
    params: &Value,
    caller: Option<&str>,
) -> Result<Value, String> {
    use crate::profile::Family;
    use std::str::FromStr;

    const LOAD_FAMILY_SQL: &str =
        "SELECT id, tier, namespace, title, content, tags, priority, confidence, source, \
         access_count, created_at, updated_at, last_accessed_at, expires_at, metadata \
         FROM memories \
         WHERE (?1 IS NULL OR namespace = ?1) \
         AND json_extract(metadata, '$.family') = ?2 \
         AND (expires_at IS NULL OR expires_at > ?3) \
         ORDER BY priority DESC, updated_at DESC \
         LIMIT ?4";

    // Arguments are checked family-first so the reported error stays stable.
    let Some(requested_family) = params["family"].as_str() else {
        return Err("family is required".into());
    };
    let family = match Family::from_str(requested_family) {
        Ok(parsed) => parsed,
        Err(err) => return Err(err.to_string()),
    };
    let family_name = family.name();

    let namespace = params.get(param_names::NAMESPACE).and_then(Value::as_str);
    if let Some(candidate) = namespace {
        validate::validate_namespace(candidate).map_err(|e| e.to_string())?;
    }

    // A value too large for `usize` saturates to the upper bound of the clamp.
    let requested_k = params
        .get(param_names::K)
        .and_then(Value::as_u64)
        .unwrap_or(20);
    let k = usize::try_from(requested_k).map_or(100, |n| n.clamp(1, 100));

    let now = chrono::Utc::now().to_rfc3339();
    let mut stmt = conn
        .prepare(LOAD_FAMILY_SQL)
        .map_err(|e| format!("prepare memory_load_family failed: {e}"))?;
    let rows = stmt
        .query_map(
            rusqlite::params![namespace, family_name, now, k],
            db::row_to_memory,
        )
        .map_err(|e| format!("query memory_load_family failed: {e}"))?;

    let mut memories: Vec<Memory> = Vec::new();
    for row in rows {
        let memory = row.map_err(|e| format!("collect memory_load_family rows failed: {e}"))?;
        memories.push(memory);
    }

    // NOTE(review): visibility is applied after the SQL LIMIT, so a restricted
    // caller can receive fewer than `k` rows even when more visible rows exist.
    if let Some(who) = caller {
        memories.retain(|m| crate::visibility::is_visible_to_caller(m, who));
    }

    Ok(json!({
        "family": family_name,
        "namespace": namespace,
        "k": k,
        "count": memories.len(),
        "memories": memories,
    }))
}
/// Handles `memory_smart_load`: chooses a family from the free-text `intent`
/// and forwards the request to the family loader.
///
/// Routing rules:
/// - A blank intent goes straight to `Family::Core` with source `"fallback"`.
/// - Otherwise the keyword router always runs. When an embedder is supplied
///   and yields a pick, that pick is used unless the keyword router found a
///   genuine `"keyword"` match for a different family, in which case the
///   keyword pick wins.
/// - Without an embedder, or when it yields nothing, the keyword pick is used.
///
/// # Errors
/// Returns a message when `intent` is missing or not a string, or when the
/// forwarded family load fails.
pub fn handle_smart_load(
    conn: &rusqlite::Connection,
    params: &Value,
    embedder: Option<&dyn Embed>,
    caller: Option<&str>,
) -> Result<Value, String> {
    let Some(raw_intent) = params["intent"].as_str() else {
        return Err("intent is required".into());
    };
    let intent = raw_intent.trim();

    if intent.is_empty() {
        return forward_to_load_family(
            conn,
            crate::profile::Family::Core,
            0.0,
            "fallback",
            intent,
            params,
            caller,
        );
    }

    let keyword_pick = fallback_via_keywords(intent);
    let embedder_pick = embedder.and_then(|emb| best_family_via_embedder(emb, intent));

    let (family, score, source) = match embedder_pick {
        // The embedder pick stands when the keyword router only fell back, or
        // when both routers agree on the family.
        Some((emb_family, emb_score))
            if keyword_pick.2 != "keyword" || keyword_pick.0 == emb_family =>
        {
            (emb_family, emb_score, "embedder")
        }
        _ => keyword_pick,
    };

    forward_to_load_family(conn, family, score, source, intent, params, caller)
}
/// Logs the routing decision, runs `handle_load_family` for the chosen family
/// and wraps its result in the `memory_smart_load` response shape.
///
/// Only `namespace` (when a string) and `k` (when a non-negative integer) are
/// copied from `params` into the forwarded request. The response echoes the
/// loader's `namespace`, `k`, `count` and `memories`, and adds the chosen
/// family, the score rounded for display, the routing source and the intent.
///
/// # Errors
/// Propagates any error returned by `handle_load_family`.
fn forward_to_load_family(
    conn: &rusqlite::Connection,
    family: crate::profile::Family,
    score: f32,
    source: &str,
    intent: &str,
    params: &Value,
    caller: Option<&str>,
) -> Result<Value, String> {
    let family_name = family.name();
    tracing::info!(
        target: "memory_smart_load",
        chosen_family = family_name,
        score = score,
        source = source,
        intent_len = intent.len(),
        "smart_load routed intent to family"
    );

    let mut request = json!({ "family": family_name });
    let forwarded_namespace = params.get(param_names::NAMESPACE).and_then(Value::as_str);
    if let Some(ns) = forwarded_namespace {
        request["namespace"] = json!(ns);
    }
    let forwarded_k = params.get(param_names::K).and_then(Value::as_u64);
    if let Some(limit) = forwarded_k {
        request["k"] = json!(limit);
    }

    let loaded = handle_load_family(conn, &request, caller)?;
    let field = |key: &str| loaded.get(key).cloned();
    let memories = field("memories").unwrap_or_else(|| json!([]));
    let count = field("count").unwrap_or_else(|| json!(0));
    let k = field(param_names::K).unwrap_or_else(|| json!(20));
    let namespace = field(param_names::NAMESPACE).unwrap_or(Value::Null);

    let round_factor = crate::SCORE_DISPLAY_ROUND_FACTOR;
    let score_rounded = (f64::from(score) * round_factor).round() / round_factor;

    Ok(json!({
        "chosen_family": family_name,
        "score": score_rounded,
        "chosen_family_source": source,
        "intent": intent,
        "namespace": namespace,
        "k": k,
        "count": count,
        "memories": memories,
    }))
}
/// Picks the family whose descriptor is most similar to `intent` under the
/// supplied embedder.
///
/// Returns the best `(family, cosine score)` pair. Returns `None` when the
/// intent cannot be embedded, or when no family produced a usable score.
/// Families whose descriptor fails to embed are skipped. On equal scores the
/// family that comes first in `Family::all()` wins.
fn best_family_via_embedder(
    emb: &dyn Embed,
    intent: &str,
) -> Option<(crate::profile::Family, f32)> {
    use crate::profile::Family;
    let intent_vec = emb.embed_query(intent).ok()?;
    let mut best: Option<(Family, f32)> = None;
    for family in Family::all() {
        let Ok(desc_vec) = emb.embed(family_descriptor(*family)) else {
            continue;
        };
        let score = Embedder::cosine_similarity(&intent_vec, &desc_vec);
        // Every comparison against NaN is false, so a NaN stored as the
        // running best could never be displaced by a later real score.
        // Skip non-finite scores so one degenerate vector cannot pin the result.
        if !score.is_finite() {
            continue;
        }
        if best.is_none_or(|(_, top)| score > top) {
            best = Some((*family, score));
        }
    }
    best
}
/// Picks a family for `intent` by lexical overlap alone, without an embedder.
///
/// The intent is lowercased and split into ASCII-alphanumeric tokens. Each
/// family is then scored from four signals:
/// - `desc_overlap`: intent tokens that appear verbatim in the family descriptor;
/// - `tool_distinct_sum`: per tool, intent tokens matching a segment of the
///   tool name (the `memory` segment is ignored), summed over all tools;
/// - `tool_distinct_max`: the largest such per-tool count;
/// - `full_id_hits`: tools whose complete identifier appears in the intent.
///
/// `score = (2*desc_overlap + tool_distinct_sum + 2*tool_distinct_max
///          + 4*full_id_hits) / intent token count`
///
/// Returns `(family, score, "keyword")` for the highest-scoring family; on
/// equal scores the family that comes first in `Family::all()` wins. Returns
/// `(Family::Core, 0.0, "fallback")` when the intent has no tokens or no
/// family matched anything.
fn fallback_via_keywords(intent: &str) -> (crate::profile::Family, f32, &'static str) {
    use crate::profile::Family;

    // Two tokens match when they are equal or share their first five bytes,
    // which lets inflected forms such as "consolidate" / "consolidation" match.
    // Byte slices are compared so a non-ASCII segment cannot panic on a
    // char-boundary violation.
    fn token_matches(a: &str, b: &str) -> bool {
        a == b || (a.len() >= 5 && b.len() >= 5 && a.as_bytes()[..5] == b.as_bytes()[..5])
    }

    let intent_tokens: Vec<String> = intent
        .split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|s| !s.is_empty())
        .map(str::to_ascii_lowercase)
        .collect();
    if intent_tokens.is_empty() {
        return (Family::Core, 0.0, "fallback");
    }

    // Identifier-shaped tokens keep their underscores. The word tokens above
    // are split on `_`, so on their own they can never equal a tool id such
    // as `memory_recall`, and the full-id signal would stay at zero.
    let intent_ids: Vec<String> = intent
        .split(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
        .map(|s| s.trim_matches('_'))
        .filter(|s| s.contains('_'))
        .map(str::to_ascii_lowercase)
        .collect();

    let mut best: Option<(Family, f32)> = None;
    for family in Family::all() {
        let descriptor = family_descriptor(*family).to_ascii_lowercase();
        let desc_tokens: Vec<&str> = descriptor
            .split(|c: char| !c.is_ascii_alphanumeric())
            .filter(|s| !s.is_empty())
            .collect();

        let mut tool_distinct_sum: usize = 0;
        let mut tool_distinct_max: usize = 0;
        let mut full_id_hits: usize = 0;
        for tool_name in family.tool_names() {
            let lower = tool_name.to_ascii_lowercase();
            let segments: Vec<&str> = lower
                .split('_')
                .filter(|s| !s.is_empty() && *s != "memory")
                .collect();
            let distinct = intent_tokens
                .iter()
                .filter(|t| segments.iter().any(|seg| token_matches(seg, t.as_str())))
                .count();
            tool_distinct_sum += distinct;
            tool_distinct_max = tool_distinct_max.max(distinct);
            // Word tokens cover ids without underscores; `intent_ids` covers
            // the rest.
            let named_in_full = intent_tokens
                .iter()
                .chain(intent_ids.iter())
                .any(|t| t.as_str() == lower);
            if named_in_full {
                full_id_hits += 1;
            }
        }

        let desc_overlap = intent_tokens
            .iter()
            .filter(|t| desc_tokens.iter().any(|d| *d == t.as_str()))
            .count();
        if desc_overlap == 0 && tool_distinct_sum == 0 && full_id_hits == 0 {
            continue;
        }

        // Counts are bounded by the intent length and the tool list, far
        // below the range where `usize -> f32` loses precision.
        #[allow(clippy::cast_precision_loss)]
        let score = (2.0 * desc_overlap as f32
            + tool_distinct_sum as f32
            + 2.0 * tool_distinct_max as f32
            + 4.0 * full_id_hits as f32)
            / (intent_tokens.len() as f32);
        if best.is_none_or(|(_, top)| score > top) {
            best = Some((*family, score));
        }
    }
    best.map_or((Family::Core, 0.0, "fallback"), |(f, s)| (f, s, "keyword"))
}
/// Returns the space-separated keyword descriptor for `family`.
///
/// Both routers in this file consume it: `best_family_via_embedder` embeds the
/// whole string, and `fallback_via_keywords` splits it on non-alphanumeric
/// characters and compares the pieces with the intent tokens. An entry such as
/// `find_paths` therefore reaches the keyword router as `find` and `paths`.
///
/// Some words appear in more than one descriptor (for example `purge`,
/// `subscribe`, `replay`, `query`, `inbox`, `subscription`), so on their own
/// they do not separate those families in the keyword router.
fn family_descriptor(family: crate::profile::Family) -> &'static str {
use crate::profile::Family;
// The match has no catch-all arm, so adding a `Family` variant fails to
// compile until a descriptor is written for it.
match family {
Family::Core => {
"store remember save record memory note write recall fetch get \
search find list browse read load family core baseline"
}
Family::Lifecycle => {
"update edit modify change delete remove forget purge garbage \
collect promote upgrade downgrade migrate refresh rotate"
}
Family::Graph => {
"graph link relation entity knowledge kg query timeline replay \
verify path traverse find_paths connect taxonomy alias debug \
flaky test investigate trace ancestry"
}
Family::Governance => {
"approve reject pending policy permission rule namespace \
standard subscribe unsubscribe governance review audit"
}
Family::Power => {
"consolidate merge contradiction duplicate auto tag expand \
query inbox subscription replay dlq dead letter retry power \
llm augment"
}
Family::Meta => {
"capabilities agent register session start stats meta info \
discovery introspection bootstrap"
}
Family::Archive => {
"archive backup restore purge old historical retention cold \
storage"
}
Family::Other => {
"subscription notify subscribe webhook event other miscellaneous \
notification message send dm direct another recipient inbox"
}
}
}