Skip to main content

do_memory_mcp/server/tools/
core.rs

//! Core tool handlers for the MCP server.
//!
//! This module contains the core tool execution methods: `list_tools`, `get_tool`,
//! `query_memory`, `analyze_patterns`, and `get_episodes_by_ids` (the `bulk_episodes` tool).
4
5use crate::types::Tool;
6use anyhow::Result;
7use do_memory_core::{Episode, Pattern, TaskOutcome};
8use serde_json::json;
9use std::collections::HashMap;
10use tracing::debug;
11
12/// Calculate a success score for an episode (higher = more successful)
13fn outcome_score(episode: &Episode) -> u8 {
14    match &episode.outcome {
15        Some(TaskOutcome::Success { .. }) => 3,
16        Some(TaskOutcome::PartialSuccess { .. }) => 2,
17        Some(TaskOutcome::Failure { .. }) => 1,
18        None => 0,
19    }
20}
21
22impl crate::server::MemoryMCPServer {
23    /// List all available tools
24    ///
25    /// Returns tools based on progressive disclosure - commonly used tools
26    /// are returned first, advanced tools are shown after usage patterns indicate need.
27    ///
28    /// With lazy loading, this initially returns only core tools to significantly
29    /// reduce input token usage. Extended tools are loaded on-demand.
30    pub async fn list_tools(&self) -> Vec<Tool> {
31        // Get currently loaded tools (core + session-loaded extended)
32        let loaded_tools = self.tool_registry.get_loaded_tools();
33
34        debug!(
35            "Listed {} tools (core + session-loaded, extended tools available on-demand)",
36            loaded_tools.len()
37        );
38        loaded_tools
39    }
40
41    /// Get a specific tool by name
42    ///
43    /// Loads the tool on-demand from the registry if not already loaded.
44    pub async fn get_tool(&self, name: &str) -> Option<Tool> {
45        self.tool_registry.load_tool(name).await
46    }
47
48    /// Execute the query_memory tool
49    ///
50    /// # Arguments
51    ///
52    /// * `query` - Search query
53    /// * `domain` - Task domain
54    /// * `task_type` - Optional task type filter
55    /// * `limit` - Maximum results to return
56    /// * `sort` - Sort order (relevance, newest, oldest, duration, success)
57    ///
58    /// # Returns
59    ///
60    /// Returns a JSON array of relevant episodes
61    ///
62    /// # Field Selection
63    ///
64    /// Clients can request specific fields using the `fields` parameter:
65    /// ```json
66    /// {
67    ///   "query": "test",
68    ///   "domain": "web-api",
69    ///   "fields": ["episodes.id", "episodes.task_description", "patterns.success_rate"]
70    /// }
71    /// ```
72    pub async fn query_memory(
73        &self,
74        query: String,
75        domain: String,
76        task_type: Option<String>,
77        limit: usize,
78        sort: String,
79        fields: Option<Vec<String>>,
80    ) -> Result<serde_json::Value> {
81        self.track_tool_usage("query_memory").await;
82
83        // Start monitoring request
84        let request_id = format!(
85            "query_memory_{}",
86            std::time::SystemTime::now()
87                .duration_since(std::time::UNIX_EPOCH)
88                .unwrap_or_default()
89                .as_nanos()
90        );
91        self.monitoring
92            .start_request(request_id.clone(), "query_memory".to_string())
93            .await;
94
95        debug!(
96            "Querying memory: query='{}', domain='{}', limit={}",
97            query, domain, limit
98        );
99
100        let start = std::time::Instant::now();
101
102        // Build task context from parameters
103        let context = do_memory_core::TaskContext {
104            domain,
105            language: None,
106            framework: None,
107            complexity: do_memory_core::ComplexityLevel::Moderate,
108            tags: task_type
109                .as_ref()
110                .map(|t| vec![t.clone()])
111                .unwrap_or_default(),
112        };
113
114        // Query actual memory for relevant episodes (returns Vec<Arc<Episode>>)
115        let arc_episodes = self
116            .memory
117            .retrieve_relevant_context(query.clone(), context.clone(), limit)
118            .await;
119
120        // Strict filtering: only return episodes that actually contain the query.
121        // Dereference Arc<Episode> to access Episode fields
122        let query_lc = query.to_lowercase();
123        let mut episodes: Vec<_> = arc_episodes
124            .into_iter()
125            .filter(|arc_ep| {
126                let ep = arc_ep.as_ref();
127                if ep.task_description.to_lowercase().contains(&query_lc) {
128                    return true;
129                }
130                for step in &ep.steps {
131                    if step.action.to_lowercase().contains(&query_lc) {
132                        return true;
133                    }
134                    if step
135                        .parameters
136                        .to_string()
137                        .to_lowercase()
138                        .contains(&query_lc)
139                    {
140                        return true;
141                    }
142                    if let Some(result) = &step.result {
143                        if serde_json::to_string(result)
144                            .unwrap_or_default()
145                            .to_lowercase()
146                            .contains(&query_lc)
147                        {
148                            return true;
149                        }
150                    }
151                }
152                false
153            })
154            .map(|arc_ep| arc_ep.as_ref().clone())
155            .collect();
156
157        // Apply sorting
158        match sort.as_str() {
159            "newest" => {
160                episodes.sort_by(|a, b| b.start_time.cmp(&a.start_time));
161            }
162            "oldest" => {
163                episodes.sort_by(|a, b| a.start_time.cmp(&b.start_time));
164            }
165            "duration" => {
166                episodes.sort_by(|a, b| {
167                    let dur_a = a.end_time.map(|e| e - a.start_time);
168                    let dur_b = b.end_time.map(|e| e - b.start_time);
169                    dur_b.cmp(&dur_a)
170                });
171            }
172            "success" => {
173                episodes.sort_by(|a, b| {
174                    let score_a = outcome_score(a);
175                    let score_b = outcome_score(b);
176                    score_b.cmp(&score_a)
177                });
178            }
179            _ => {} // "relevance" - keep default order
180        }
181
182        // Also get relevant patterns
183        let patterns = self
184            .memory
185            .retrieve_relevant_patterns(&context, limit)
186            .await;
187
188        // Calculate insights from retrieved data
189        let success_count = episodes
190            .iter()
191            .filter(|e| e.reward.as_ref().is_some_and(|r| r.total > 0.7))
192            .count();
193
194        let avg_success_rate = if !episodes.is_empty() {
195            success_count as f32 / episodes.len() as f32
196        } else {
197            0.0
198        };
199
200        let duration_ms = start.elapsed().as_millis() as u64;
201
202        // End monitoring request
203        self.monitoring.end_request(&request_id, true, None).await;
204
205        debug!("Memory query completed in {}ms", duration_ms);
206
207        // Build result
208        let result = json!({
209            "episodes": episodes,
210            "patterns": patterns,
211            "insights": {
212                "total_episodes": episodes.len(),
213                "relevant_patterns": patterns.len(),
214                "success_rate": avg_success_rate
215            }
216        });
217
218        // Apply field projection if requested
219        if let Some(field_list) = fields {
220            use crate::server::tools::field_projection::FieldSelector;
221            let selector = FieldSelector::new(field_list.into_iter().collect());
222            return selector.apply(&result);
223        }
224
225        Ok(result)
226    }
227
228    /// Execute the analyze_patterns tool
229    ///
230    /// # Arguments
231    ///
232    /// * `task_type` - Type of task to analyze
233    /// * `min_success_rate` - Minimum success rate filter
234    /// * `limit` - Maximum patterns to return
235    ///
236    /// # Returns
237    ///
238    /// Returns a JSON array of patterns with statistics
239    ///
240    /// # Field Selection
241    ///
242    /// Clients can request specific fields:
243    /// ```json
244    /// {
245    ///   "task_type": "code_generation",
246    ///   "fields": ["patterns.tool_sequence", "statistics.most_common_tools"]
247    /// }
248    /// ```
249    pub async fn analyze_patterns(
250        &self,
251        task_type: String,
252        min_success_rate: f32,
253        limit: usize,
254        fields: Option<Vec<String>>,
255    ) -> Result<serde_json::Value> {
256        self.track_tool_usage("analyze_patterns").await;
257
258        debug!(
259            "Analyzing patterns: task_type='{}', min_success_rate={}, limit={}",
260            task_type, min_success_rate, limit
261        );
262
263        // Build context for pattern retrieval
264        let context = do_memory_core::TaskContext {
265            domain: task_type.clone(),
266            language: None,
267            framework: None,
268            complexity: do_memory_core::ComplexityLevel::Moderate,
269            tags: vec![task_type],
270        };
271
272        // Retrieve patterns from memory
273        let all_patterns = self
274            .memory
275            .retrieve_relevant_patterns(&context, limit * 2)
276            .await;
277
278        // Filter by success rate and limit
279        let filtered_patterns: Vec<_> = all_patterns
280            .into_iter()
281            .filter(|p| p.success_rate() >= min_success_rate)
282            .take(limit)
283            .collect();
284
285        // Calculate statistics
286        let total_patterns = filtered_patterns.len();
287        let avg_success_rate = min_success_rate;
288
289        // Extract most common tools from patterns
290        let mut tool_counts: HashMap<String, usize> = HashMap::new();
291        for pattern in &filtered_patterns {
292            match pattern {
293                Pattern::ToolSequence { tools, .. } => {
294                    for tool in tools {
295                        *tool_counts.entry(tool.clone()).or_insert(0) += 1;
296                    }
297                }
298                Pattern::DecisionPoint { action, .. } => {
299                    *tool_counts.entry(action.clone()).or_insert(0) += 1;
300                }
301                Pattern::ErrorRecovery { recovery_steps, .. } => {
302                    for step in recovery_steps {
303                        *tool_counts.entry(step.clone()).or_insert(0) += 1;
304                    }
305                }
306                Pattern::ContextPattern {
307                    recommended_approach,
308                    ..
309                } => {
310                    *tool_counts.entry(recommended_approach.clone()).or_insert(0) += 1;
311                }
312            }
313        }
314
315        let mut most_common_tools: Vec<_> = tool_counts.into_iter().collect();
316        most_common_tools.sort_by(|a, b| b.1.cmp(&a.1));
317        let most_common_tools: Vec<String> = most_common_tools
318            .into_iter()
319            .take(5)
320            .map(|(tool, _)| tool)
321            .collect();
322
323        // Build result
324        let result = json!({
325            "patterns": filtered_patterns,
326            "statistics": {
327                "total_patterns": total_patterns,
328                "avg_success_rate": avg_success_rate,
329                "most_common_tools": most_common_tools
330            }
331        });
332
333        // Apply field projection if requested
334        if let Some(field_list) = fields {
335            use crate::server::tools::field_projection::FieldSelector;
336            let selector = FieldSelector::new(field_list.into_iter().collect());
337            return selector.apply(&result);
338        }
339
340        Ok(result)
341    }
342
343    /// Execute the bulk_episodes tool
344    ///
345    /// # Arguments
346    ///
347    /// * `episode_ids` - List of episode UUIDs to retrieve
348    ///
349    /// # Returns
350    ///
351    /// Returns a result with requested count, found count, and episodes
352    pub async fn get_episodes_by_ids(
353        &self,
354        episode_ids: &[uuid::Uuid],
355    ) -> Result<Vec<do_memory_core::Episode>> {
356        self.track_tool_usage("bulk_episodes").await;
357
358        debug!("Bulk retrieving {} episodes", episode_ids.len());
359
360        let episodes = self.memory.get_episodes_by_ids(episode_ids).await?;
361
362        debug!(
363            "Found {} of {} requested episodes",
364            episodes.len(),
365            episode_ids.len()
366        );
367
368        Ok(episodes)
369    }
370}