Skip to main content

agentic_memory_mcp/tools/
memory_query.rs

1//! Tool: memory_query — Pattern query for matching nodes.
2
3use std::sync::Arc;
4use tokio::sync::Mutex;
5
6use serde::Deserialize;
7use serde_json::{json, Value};
8
9use agentic_memory::{EventType, PatternParams, PatternSort};
10
11use crate::session::SessionManager;
12use crate::types::{McpError, McpResult, ToolCallResult, ToolDefinition};
13
14#[derive(Debug, Deserialize)]
15struct QueryParams {
16    #[serde(default)]
17    event_types: Vec<String>,
18    min_confidence: Option<f32>,
19    max_confidence: Option<f32>,
20    #[serde(default)]
21    session_ids: Vec<u32>,
22    created_after: Option<u64>,
23    created_before: Option<u64>,
24    #[serde(default = "default_max_results")]
25    max_results: usize,
26    #[serde(default = "default_sort")]
27    sort_by: String,
28}
29
/// Result cap used when the caller omits `max_results`.
fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 20;
    DEFAULT_MAX_RESULTS
}
33
/// Ordering name used when the caller omits `sort_by`.
fn default_sort() -> String {
    String::from("most_recent")
}
37
38/// Return the tool definition for memory_query.
39pub fn definition() -> ToolDefinition {
40    ToolDefinition {
41        name: "memory_query".to_string(),
42        description: Some("Find memories matching conditions (pattern query)".to_string()),
43        input_schema: json!({
44            "type": "object",
45            "properties": {
46                "event_types": { "type": "array", "items": { "type": "string" } },
47                "min_confidence": { "type": "number" },
48                "max_confidence": { "type": "number" },
49                "session_ids": { "type": "array", "items": { "type": "integer" } },
50                "created_after": { "type": "integer" },
51                "created_before": { "type": "integer" },
52                "max_results": { "type": "integer", "default": 20 },
53                "sort_by": {
54                    "type": "string",
55                    "enum": ["most_recent", "highest_confidence", "most_accessed", "most_important"],
56                    "default": "most_recent"
57                }
58            }
59        }),
60    }
61}
62
63/// Execute the memory_query tool.
64pub async fn execute(
65    args: Value,
66    session: &Arc<Mutex<SessionManager>>,
67) -> McpResult<ToolCallResult> {
68    let params: QueryParams =
69        serde_json::from_value(args).map_err(|e| McpError::InvalidParams(e.to_string()))?;
70
71    let event_types: Vec<EventType> = params
72        .event_types
73        .iter()
74        .filter_map(|name| EventType::from_name(name))
75        .collect();
76
77    let sort_by = match params.sort_by.as_str() {
78        "highest_confidence" => PatternSort::HighestConfidence,
79        "most_accessed" => PatternSort::MostAccessed,
80        "most_important" => PatternSort::MostImportant,
81        _ => PatternSort::MostRecent,
82    };
83
84    let pattern = PatternParams {
85        event_types,
86        min_confidence: params.min_confidence,
87        max_confidence: params.max_confidence,
88        session_ids: params.session_ids,
89        created_after: params.created_after,
90        created_before: params.created_before,
91        min_decay_score: None,
92        max_results: params.max_results,
93        sort_by,
94    };
95
96    let session = session.lock().await;
97    let results = session
98        .query_engine()
99        .pattern(session.graph(), pattern)
100        .map_err(|e| McpError::AgenticMemory(format!("Pattern query failed: {e}")))?;
101
102    let nodes: Vec<Value> = results
103        .iter()
104        .map(|event| {
105            json!({
106                "id": event.id,
107                "event_type": event.event_type.name(),
108                "content": event.content,
109                "confidence": event.confidence,
110                "session_id": event.session_id,
111                "created_at": event.created_at,
112                "decay_score": event.decay_score,
113                "access_count": event.access_count,
114            })
115        })
116        .collect();
117
118    Ok(ToolCallResult::json(&json!({
119        "count": nodes.len(),
120        "nodes": nodes
121    })))
122}