context_mcp/
tools.rs

1//! MCP tool implementations for context management
2//!
3//! Provides tools for storing, retrieving, and querying contexts
4//! with temporal reasoning and RAG support.
5
6use serde_json::{json, Value};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::context::{Context, ContextDomain, ContextQuery, ScreeningStatus};
11use crate::protocol::{CallToolResult, InputSchema, PropertySchema, Tool};
12use crate::rag::{RagProcessor, RetrievalQuery};
13use crate::storage::ContextStore;
14use crate::temporal::TemporalQuery;
15
/// Tool registry managing all available tools.
///
/// Holds shared handles to the context store and the RAG processor; the tool
/// implementations on this type call into one or the other.
pub struct ToolRegistry {
    /// Shared context store used by the storage, lookup, query, screening,
    /// statistics and cleanup tools.
    store: Arc<ContextStore>,
    /// Shared RAG processor used by the `retrieve_contexts` tool.
    rag: Arc<RagProcessor>,
}
21
22impl ToolRegistry {
23    /// Create a new tool registry
24    pub fn new(store: Arc<ContextStore>, rag: Arc<RagProcessor>) -> Self {
25        Self { store, rag }
26    }
27
28    /// Get all available tools
29    pub fn list_tools(&self) -> Vec<Tool> {
30        vec![
31            self.store_context_tool(),
32            self.get_context_tool(),
33            self.delete_context_tool(),
34            self.query_contexts_tool(),
35            self.retrieve_contexts_tool(),
36            self.update_screening_tool(),
37            self.get_temporal_stats_tool(),
38            self.get_storage_stats_tool(),
39            self.cleanup_expired_tool(),
40        ]
41    }
42
43    /// Execute a tool by name
44    pub async fn execute(&self, name: &str, args: HashMap<String, Value>) -> CallToolResult {
45        match name {
46            "store_context" => self.store_context(args).await,
47            "get_context" => self.get_context(args).await,
48            "delete_context" => self.delete_context(args).await,
49            "query_contexts" => self.query_contexts(args).await,
50            "retrieve_contexts" => self.retrieve_contexts(args).await,
51            "update_screening" => self.update_screening(args).await,
52            "get_temporal_stats" => self.get_temporal_stats(args).await,
53            "get_storage_stats" => self.get_storage_stats(args).await,
54            "cleanup_expired" => self.cleanup_expired(args).await,
55            _ => CallToolResult::error(format!("Unknown tool: {}", name)),
56        }
57    }
58
59    // Tool definitions
60
61    fn store_context_tool(&self) -> Tool {
62        Tool {
63            name: "store_context".to_string(),
64            description: Some("Store a new context with metadata and optional TTL".to_string()),
65            input_schema: InputSchema::object()
66                .with_required("content", PropertySchema::string("The context content"))
67                .with_property(
68                    "domain",
69                    PropertySchema::string("Context domain").with_enum(vec![
70                        "General",
71                        "Code",
72                        "Documentation",
73                        "Conversation",
74                        "Filesystem",
75                        "WebSearch",
76                        "Dataset",
77                        "Research",
78                    ]),
79                )
80                .with_property("source", PropertySchema::string("Source of the context"))
81                .with_property("tags", PropertySchema::array("Tags for categorization"))
82                .with_property(
83                    "importance",
84                    PropertySchema::number("Importance 0.0-1.0").with_default(json!(0.5)),
85                )
86                .with_property("ttl_hours", PropertySchema::number("Time to live in hours")),
87        }
88    }
89
90    fn get_context_tool(&self) -> Tool {
91        Tool {
92            name: "get_context".to_string(),
93            description: Some("Retrieve a context by ID".to_string()),
94            input_schema: InputSchema::object()
95                .with_required("id", PropertySchema::string("Context ID")),
96        }
97    }
98
99    fn delete_context_tool(&self) -> Tool {
100        Tool {
101            name: "delete_context".to_string(),
102            description: Some("Delete a context by ID".to_string()),
103            input_schema: InputSchema::object()
104                .with_required("id", PropertySchema::string("Context ID")),
105        }
106    }
107
108    fn query_contexts_tool(&self) -> Tool {
109        Tool {
110            name: "query_contexts".to_string(),
111            description: Some("Query contexts with filters".to_string()),
112            input_schema: InputSchema::object()
113                .with_property("domain", PropertySchema::string("Filter by domain"))
114                .with_property("tags", PropertySchema::array("Filter by tags"))
115                .with_property(
116                    "min_importance",
117                    PropertySchema::number("Minimum importance threshold"),
118                )
119                .with_property(
120                    "max_age_hours",
121                    PropertySchema::number("Maximum age in hours"),
122                )
123                .with_property(
124                    "verified_only",
125                    PropertySchema::boolean("Only return verified contexts"),
126                )
127                .with_property(
128                    "limit",
129                    PropertySchema::number("Maximum results").with_default(json!(10)),
130                ),
131        }
132    }
133
134    fn retrieve_contexts_tool(&self) -> Tool {
135        Tool {
136            name: "retrieve_contexts".to_string(),
137            description: Some("Retrieve contexts using RAG with scoring".to_string()),
138            input_schema: InputSchema::object()
139                .with_property("text", PropertySchema::string("Text query"))
140                .with_property("domain", PropertySchema::string("Domain filter"))
141                .with_property("tags", PropertySchema::array("Tag filters"))
142                .with_property(
143                    "min_importance",
144                    PropertySchema::number("Minimum importance"),
145                )
146                .with_property(
147                    "max_age_hours",
148                    PropertySchema::number("Maximum age for temporal filtering"),
149                )
150                .with_property(
151                    "max_results",
152                    PropertySchema::number("Maximum results").with_default(json!(10)),
153                ),
154        }
155    }
156
157    fn update_screening_tool(&self) -> Tool {
158        Tool {
159            name: "update_screening".to_string(),
160            description: Some("Update screening status of a context".to_string()),
161            input_schema: InputSchema::object()
162                .with_required("id", PropertySchema::string("Context ID"))
163                .with_required(
164                    "status",
165                    PropertySchema::string("New screening status")
166                        .with_enum(vec!["Safe", "Flagged", "Blocked"]),
167                )
168                .with_property("reason", PropertySchema::string("Reason for status change")),
169        }
170    }
171
172    fn get_temporal_stats_tool(&self) -> Tool {
173        Tool {
174            name: "get_temporal_stats".to_string(),
175            description: Some("Get temporal statistics for stored contexts".to_string()),
176            input_schema: InputSchema::object()
177                .with_property("domain", PropertySchema::string("Filter by domain")),
178        }
179    }
180
181    fn get_storage_stats_tool(&self) -> Tool {
182        Tool {
183            name: "get_storage_stats".to_string(),
184            description: Some("Get storage statistics".to_string()),
185            input_schema: InputSchema::object(),
186        }
187    }
188
189    fn cleanup_expired_tool(&self) -> Tool {
190        Tool {
191            name: "cleanup_expired".to_string(),
192            description: Some("Remove expired contexts".to_string()),
193            input_schema: InputSchema::object(),
194        }
195    }
196
197    // Tool implementations
198
199    async fn store_context(&self, args: HashMap<String, Value>) -> CallToolResult {
200        let content = match args.get("content").and_then(|v| v.as_str()) {
201            Some(c) => c.to_string(),
202            None => return CallToolResult::error("Missing required parameter: content"),
203        };
204
205        let domain = args
206            .get("domain")
207            .and_then(|v| v.as_str())
208            .map(parse_domain)
209            .unwrap_or(ContextDomain::General);
210
211        let mut ctx = Context::new(content, domain);
212
213        // Set metadata
214        if let Some(source) = args.get("source").and_then(|v| v.as_str()) {
215            ctx.metadata.source = source.to_string();
216        }
217
218        if let Some(tags) = args.get("tags").and_then(|v| v.as_array()) {
219            ctx.metadata.tags = tags
220                .iter()
221                .filter_map(|v| v.as_str().map(|s| s.to_string()))
222                .collect();
223        }
224
225        if let Some(importance) = args.get("importance").and_then(|v| v.as_f64()) {
226            ctx.metadata.importance = importance.clamp(0.0, 1.0) as f32;
227        }
228
229        if let Some(ttl) = args.get("ttl_hours").and_then(|v| v.as_i64()) {
230            ctx = ctx.with_ttl(std::time::Duration::from_secs(ttl as u64 * 3600));
231        }
232
233        let id = ctx.id.clone();
234        match self.store.store(ctx).await {
235            Ok(_stored_id) => CallToolResult::json(json!({
236                "success": true,
237                "id": id.to_string(),
238                "message": "Context stored successfully"
239            })),
240            Err(e) => CallToolResult::error(format!("Failed to store context: {}", e)),
241        }
242    }
243
244    async fn get_context(&self, args: HashMap<String, Value>) -> CallToolResult {
245        let id_str = match args.get("id").and_then(|v| v.as_str()) {
246            Some(id) => id,
247            None => return CallToolResult::error("Missing required parameter: id"),
248        };
249
250        let id = crate::context::ContextId::from_string(id_str.to_string());
251
252        match self.store.get(&id).await {
253            Ok(Some(ctx)) => CallToolResult::json(json!({
254                "id": ctx.id.to_string(),
255                "content": ctx.content,
256                "domain": format!("{:?}", ctx.domain),
257                "created_at": ctx.created_at.to_rfc3339(),
258                "accessed_at": ctx.accessed_at.to_rfc3339(),
259                "metadata": {
260                    "source": ctx.metadata.source,
261                    "tags": ctx.metadata.tags,
262                    "importance": ctx.metadata.importance,
263                    "verified": ctx.metadata.verified,
264                    "screening_status": format!("{:?}", ctx.metadata.screening_status)
265                },
266                "age_hours": ctx.age_hours()
267            })),
268            Ok(None) => CallToolResult::error(format!("Context not found: {}", id_str)),
269            Err(e) => CallToolResult::error(format!("Error retrieving context: {}", e)),
270        }
271    }
272
273    async fn delete_context(&self, args: HashMap<String, Value>) -> CallToolResult {
274        let id_str = match args.get("id").and_then(|v| v.as_str()) {
275            Some(id) => id,
276            None => return CallToolResult::error("Missing required parameter: id"),
277        };
278
279        let id = crate::context::ContextId::from_string(id_str.to_string());
280
281        match self.store.delete(&id).await {
282            Ok(true) => CallToolResult::json(json!({
283                "success": true,
284                "message": "Context deleted"
285            })),
286            Ok(false) => CallToolResult::error(format!("Context not found: {}", id_str)),
287            Err(e) => CallToolResult::error(format!("Error deleting context: {}", e)),
288        }
289    }
290
291    async fn query_contexts(&self, args: HashMap<String, Value>) -> CallToolResult {
292        let mut query = ContextQuery::new();
293
294        if let Some(domain) = args.get("domain").and_then(|v| v.as_str()) {
295            query = query.with_domain(parse_domain(domain));
296        }
297
298        if let Some(tags) = args.get("tags").and_then(|v| v.as_array()) {
299            for tag in tags.iter().filter_map(|v| v.as_str()) {
300                query = query.with_tag(tag.to_string());
301            }
302        }
303
304        if let Some(min_importance) = args.get("min_importance").and_then(|v| v.as_f64()) {
305            query = query.with_min_importance(min_importance as f32);
306        }
307
308        if let Some(max_age) = args.get("max_age_hours").and_then(|v| v.as_i64()) {
309            query = query.with_max_age_hours(max_age);
310        }
311
312        if let Some(verified) = args.get("verified_only").and_then(|v| v.as_bool()) {
313            if verified {
314                query = query.verified_only();
315            }
316        }
317
318        if let Some(limit) = args.get("limit").and_then(|v| v.as_u64()) {
319            query = query.with_limit(limit as usize);
320        }
321
322        match self.store.query(&query).await {
323            Ok(contexts) => {
324                let results: Vec<Value> = contexts
325                    .iter()
326                    .map(|ctx| {
327                        json!({
328                            "id": ctx.id.to_string(),
329                            "content_preview": ctx.content.chars().take(100).collect::<String>(),
330                            "domain": format!("{:?}", ctx.domain),
331                            "importance": ctx.metadata.importance,
332                            "age_hours": ctx.age_hours(),
333                            "tags": ctx.metadata.tags
334                        })
335                    })
336                    .collect();
337
338                CallToolResult::json(json!({
339                    "count": results.len(),
340                    "contexts": results
341                }))
342            }
343            Err(e) => CallToolResult::error(format!("Query failed: {}", e)),
344        }
345    }
346
347    async fn retrieve_contexts(&self, args: HashMap<String, Value>) -> CallToolResult {
348        let mut query = RetrievalQuery::new();
349
350        if let Some(text) = args.get("text").and_then(|v| v.as_str()) {
351            query.text = Some(text.to_string());
352        }
353
354        if let Some(domain) = args.get("domain").and_then(|v| v.as_str()) {
355            query = query.with_domain(parse_domain(domain));
356        }
357
358        if let Some(tags) = args.get("tags").and_then(|v| v.as_array()) {
359            for tag in tags.iter().filter_map(|v| v.as_str()) {
360                query = query.with_tag(tag.to_string());
361            }
362        }
363
364        if let Some(min_importance) = args.get("min_importance").and_then(|v| v.as_f64()) {
365            query = query.with_min_importance(min_importance as f32);
366        }
367
368        if let Some(max_age) = args.get("max_age_hours").and_then(|v| v.as_i64()) {
369            query = query.with_temporal(TemporalQuery::recent(max_age));
370        }
371
372        match self.rag.retrieve(&query).await {
373            Ok(result) => {
374                let contexts: Vec<Value> = result
375                    .contexts
376                    .iter()
377                    .map(|sc| {
378                        json!({
379                            "id": sc.context.id.to_string(),
380                            "content": sc.context.content,
381                            "domain": format!("{:?}", sc.context.domain),
382                            "score": sc.score,
383                            "score_breakdown": {
384                                "temporal": sc.score_breakdown.temporal,
385                                "importance": sc.score_breakdown.importance,
386                                "domain_match": sc.score_breakdown.domain_match,
387                                "tag_match": sc.score_breakdown.tag_match
388                            },
389                            "age_hours": sc.context.age_hours(),
390                            "tags": sc.context.metadata.tags
391                        })
392                    })
393                    .collect();
394
395                CallToolResult::json(json!({
396                    "count": contexts.len(),
397                    "candidates_considered": result.candidates_considered,
398                    "processing_time_ms": result.processing_time_ms,
399                    "temporal_stats": {
400                        "count": result.temporal_stats.count,
401                        "avg_age_hours": result.temporal_stats.avg_age_hours,
402                        "distribution": result.temporal_stats.distribution
403                    },
404                    "contexts": contexts
405                }))
406            }
407            Err(e) => CallToolResult::error(format!("Retrieval failed: {}", e)),
408        }
409    }
410
411    async fn update_screening(&self, args: HashMap<String, Value>) -> CallToolResult {
412        let id_str = match args.get("id").and_then(|v| v.as_str()) {
413            Some(id) => id,
414            None => return CallToolResult::error("Missing required parameter: id"),
415        };
416
417        let status_str = match args.get("status").and_then(|v| v.as_str()) {
418            Some(s) => s,
419            None => return CallToolResult::error("Missing required parameter: status"),
420        };
421
422        let status = match status_str.to_lowercase().as_str() {
423            "safe" => ScreeningStatus::Safe,
424            "flagged" => ScreeningStatus::Flagged,
425            "blocked" => ScreeningStatus::Blocked,
426            _ => return CallToolResult::error(format!("Invalid status: {}", status_str)),
427        };
428
429        let id = crate::context::ContextId::from_string(id_str.to_string());
430
431        match self.store.get(&id).await {
432            Ok(Some(mut ctx)) => {
433                ctx.metadata.screening_status = status.clone();
434                match self.store.store(ctx).await {
435                    Ok(_) => CallToolResult::json(json!({
436                        "success": true,
437                        "id": id_str,
438                        "new_status": format!("{:?}", status)
439                    })),
440                    Err(e) => CallToolResult::error(format!("Failed to update: {}", e)),
441                }
442            }
443            Ok(None) => CallToolResult::error(format!("Context not found: {}", id_str)),
444            Err(e) => CallToolResult::error(format!("Error: {}", e)),
445        }
446    }
447
448    async fn get_temporal_stats(&self, args: HashMap<String, Value>) -> CallToolResult {
449        let mut query = ContextQuery::new();
450
451        if let Some(domain) = args.get("domain").and_then(|v| v.as_str()) {
452            query = query.with_domain(parse_domain(domain));
453        }
454
455        match self.store.query(&query).await {
456            Ok(contexts) => {
457                let stats = crate::temporal::TemporalStats::from_contexts(&contexts);
458                CallToolResult::json(json!({
459                    "count": stats.count,
460                    "oldest": stats.oldest.map(|t| t.to_rfc3339()),
461                    "newest": stats.newest.map(|t| t.to_rfc3339()),
462                    "avg_age_hours": stats.avg_age_hours,
463                    "distribution": {
464                        "last_hour": stats.distribution.last_hour,
465                        "last_day": stats.distribution.last_day,
466                        "last_week": stats.distribution.last_week,
467                        "last_month": stats.distribution.last_month,
468                        "older": stats.distribution.older
469                    }
470                }))
471            }
472            Err(e) => CallToolResult::error(format!("Failed to get stats: {}", e)),
473        }
474    }
475
476    async fn get_storage_stats(&self, _args: HashMap<String, Value>) -> CallToolResult {
477        let stats = self.store.stats().await;
478        CallToolResult::json(json!({
479            "memory_count": stats.memory_count,
480            "disk_count": stats.disk_count,
481            "cache_capacity": stats.cache_capacity
482        }))
483    }
484
485    async fn cleanup_expired(&self, _args: HashMap<String, Value>) -> CallToolResult {
486        match self.store.cleanup_expired().await {
487            Ok(count) => CallToolResult::json(json!({
488                "success": true,
489                "removed_count": count
490            })),
491            Err(e) => CallToolResult::error(format!("Cleanup failed: {}", e)),
492        }
493    }
494}
495
496/// Parse domain string to enum
497fn parse_domain(s: &str) -> ContextDomain {
498    match s.to_lowercase().as_str() {
499        "code" => ContextDomain::Code,
500        "documentation" | "docs" => ContextDomain::Documentation,
501        "conversation" | "chat" => ContextDomain::Conversation,
502        "filesystem" | "files" => ContextDomain::Filesystem,
503        "websearch" | "web" => ContextDomain::WebSearch,
504        "dataset" | "data" => ContextDomain::Dataset,
505        "research" => ContextDomain::Research,
506        _ => ContextDomain::General,
507    }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks a canonical name with mixed case, a short alias, and the
    /// fallback for an unrecognised name.
    #[test]
    fn test_parse_domain() {
        let cases = vec![
            ("Code", ContextDomain::Code),
            ("docs", ContextDomain::Documentation),
            ("unknown", ContextDomain::General),
        ];

        for (input, expected) in cases {
            assert_eq!(parse_domain(input), expected);
        }
    }
}