1use serde_json::{json, Value};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::context::{Context, ContextDomain, ContextQuery, ScreeningStatus};
11use crate::protocol::{CallToolResult, InputSchema, PropertySchema, Tool};
12use crate::rag::{RagProcessor, RetrievalQuery};
13use crate::storage::ContextStore;
14use crate::temporal::TemporalQuery;
15
16pub struct ToolRegistry {
18 store: Arc<ContextStore>,
19 rag: Arc<RagProcessor>,
20}
21
22impl ToolRegistry {
23 pub fn new(store: Arc<ContextStore>, rag: Arc<RagProcessor>) -> Self {
25 Self { store, rag }
26 }
27
28 pub fn list_tools(&self) -> Vec<Tool> {
30 vec![
31 self.store_context_tool(),
32 self.get_context_tool(),
33 self.delete_context_tool(),
34 self.query_contexts_tool(),
35 self.retrieve_contexts_tool(),
36 self.update_screening_tool(),
37 self.get_temporal_stats_tool(),
38 self.get_storage_stats_tool(),
39 self.cleanup_expired_tool(),
40 ]
41 }
42
43 pub async fn execute(&self, name: &str, args: HashMap<String, Value>) -> CallToolResult {
45 match name {
46 "store_context" => self.store_context(args).await,
47 "get_context" => self.get_context(args).await,
48 "delete_context" => self.delete_context(args).await,
49 "query_contexts" => self.query_contexts(args).await,
50 "retrieve_contexts" => self.retrieve_contexts(args).await,
51 "update_screening" => self.update_screening(args).await,
52 "get_temporal_stats" => self.get_temporal_stats(args).await,
53 "get_storage_stats" => self.get_storage_stats(args).await,
54 "cleanup_expired" => self.cleanup_expired(args).await,
55 _ => CallToolResult::error(format!("Unknown tool: {}", name)),
56 }
57 }
58
59 fn store_context_tool(&self) -> Tool {
62 Tool {
63 name: "store_context".to_string(),
64 description: Some("Store a new context with metadata and optional TTL".to_string()),
65 input_schema: InputSchema::object()
66 .with_required("content", PropertySchema::string("The context content"))
67 .with_property(
68 "domain",
69 PropertySchema::string("Context domain").with_enum(vec![
70 "General",
71 "Code",
72 "Documentation",
73 "Conversation",
74 "Filesystem",
75 "WebSearch",
76 "Dataset",
77 "Research",
78 ]),
79 )
80 .with_property("source", PropertySchema::string("Source of the context"))
81 .with_property("tags", PropertySchema::array("Tags for categorization"))
82 .with_property(
83 "importance",
84 PropertySchema::number("Importance 0.0-1.0").with_default(json!(0.5)),
85 )
86 .with_property("ttl_hours", PropertySchema::number("Time to live in hours")),
87 }
88 }
89
90 fn get_context_tool(&self) -> Tool {
91 Tool {
92 name: "get_context".to_string(),
93 description: Some("Retrieve a context by ID".to_string()),
94 input_schema: InputSchema::object()
95 .with_required("id", PropertySchema::string("Context ID")),
96 }
97 }
98
99 fn delete_context_tool(&self) -> Tool {
100 Tool {
101 name: "delete_context".to_string(),
102 description: Some("Delete a context by ID".to_string()),
103 input_schema: InputSchema::object()
104 .with_required("id", PropertySchema::string("Context ID")),
105 }
106 }
107
108 fn query_contexts_tool(&self) -> Tool {
109 Tool {
110 name: "query_contexts".to_string(),
111 description: Some("Query contexts with filters".to_string()),
112 input_schema: InputSchema::object()
113 .with_property("domain", PropertySchema::string("Filter by domain"))
114 .with_property("tags", PropertySchema::array("Filter by tags"))
115 .with_property(
116 "min_importance",
117 PropertySchema::number("Minimum importance threshold"),
118 )
119 .with_property(
120 "max_age_hours",
121 PropertySchema::number("Maximum age in hours"),
122 )
123 .with_property(
124 "verified_only",
125 PropertySchema::boolean("Only return verified contexts"),
126 )
127 .with_property(
128 "limit",
129 PropertySchema::number("Maximum results").with_default(json!(10)),
130 ),
131 }
132 }
133
134 fn retrieve_contexts_tool(&self) -> Tool {
135 Tool {
136 name: "retrieve_contexts".to_string(),
137 description: Some("Retrieve contexts using RAG with scoring".to_string()),
138 input_schema: InputSchema::object()
139 .with_property("text", PropertySchema::string("Text query"))
140 .with_property("domain", PropertySchema::string("Domain filter"))
141 .with_property("tags", PropertySchema::array("Tag filters"))
142 .with_property(
143 "min_importance",
144 PropertySchema::number("Minimum importance"),
145 )
146 .with_property(
147 "max_age_hours",
148 PropertySchema::number("Maximum age for temporal filtering"),
149 )
150 .with_property(
151 "max_results",
152 PropertySchema::number("Maximum results").with_default(json!(10)),
153 ),
154 }
155 }
156
157 fn update_screening_tool(&self) -> Tool {
158 Tool {
159 name: "update_screening".to_string(),
160 description: Some("Update screening status of a context".to_string()),
161 input_schema: InputSchema::object()
162 .with_required("id", PropertySchema::string("Context ID"))
163 .with_required(
164 "status",
165 PropertySchema::string("New screening status")
166 .with_enum(vec!["Safe", "Flagged", "Blocked"]),
167 )
168 .with_property("reason", PropertySchema::string("Reason for status change")),
169 }
170 }
171
172 fn get_temporal_stats_tool(&self) -> Tool {
173 Tool {
174 name: "get_temporal_stats".to_string(),
175 description: Some("Get temporal statistics for stored contexts".to_string()),
176 input_schema: InputSchema::object()
177 .with_property("domain", PropertySchema::string("Filter by domain")),
178 }
179 }
180
181 fn get_storage_stats_tool(&self) -> Tool {
182 Tool {
183 name: "get_storage_stats".to_string(),
184 description: Some("Get storage statistics".to_string()),
185 input_schema: InputSchema::object(),
186 }
187 }
188
189 fn cleanup_expired_tool(&self) -> Tool {
190 Tool {
191 name: "cleanup_expired".to_string(),
192 description: Some("Remove expired contexts".to_string()),
193 input_schema: InputSchema::object(),
194 }
195 }
196
197 async fn store_context(&self, args: HashMap<String, Value>) -> CallToolResult {
200 let content = match args.get("content").and_then(|v| v.as_str()) {
201 Some(c) => c.to_string(),
202 None => return CallToolResult::error("Missing required parameter: content"),
203 };
204
205 let domain = args
206 .get("domain")
207 .and_then(|v| v.as_str())
208 .map(parse_domain)
209 .unwrap_or(ContextDomain::General);
210
211 let mut ctx = Context::new(content, domain);
212
213 if let Some(source) = args.get("source").and_then(|v| v.as_str()) {
215 ctx.metadata.source = source.to_string();
216 }
217
218 if let Some(tags) = args.get("tags").and_then(|v| v.as_array()) {
219 ctx.metadata.tags = tags
220 .iter()
221 .filter_map(|v| v.as_str().map(|s| s.to_string()))
222 .collect();
223 }
224
225 if let Some(importance) = args.get("importance").and_then(|v| v.as_f64()) {
226 ctx.metadata.importance = importance.clamp(0.0, 1.0) as f32;
227 }
228
229 if let Some(ttl) = args.get("ttl_hours").and_then(|v| v.as_i64()) {
230 ctx = ctx.with_ttl(std::time::Duration::from_secs(ttl as u64 * 3600));
231 }
232
233 let id = ctx.id.clone();
234 match self.store.store(ctx).await {
235 Ok(_stored_id) => CallToolResult::json(json!({
236 "success": true,
237 "id": id.to_string(),
238 "message": "Context stored successfully"
239 })),
240 Err(e) => CallToolResult::error(format!("Failed to store context: {}", e)),
241 }
242 }
243
244 async fn get_context(&self, args: HashMap<String, Value>) -> CallToolResult {
245 let id_str = match args.get("id").and_then(|v| v.as_str()) {
246 Some(id) => id,
247 None => return CallToolResult::error("Missing required parameter: id"),
248 };
249
250 let id = crate::context::ContextId::from_string(id_str.to_string());
251
252 match self.store.get(&id).await {
253 Ok(Some(ctx)) => CallToolResult::json(json!({
254 "id": ctx.id.to_string(),
255 "content": ctx.content,
256 "domain": format!("{:?}", ctx.domain),
257 "created_at": ctx.created_at.to_rfc3339(),
258 "accessed_at": ctx.accessed_at.to_rfc3339(),
259 "metadata": {
260 "source": ctx.metadata.source,
261 "tags": ctx.metadata.tags,
262 "importance": ctx.metadata.importance,
263 "verified": ctx.metadata.verified,
264 "screening_status": format!("{:?}", ctx.metadata.screening_status)
265 },
266 "age_hours": ctx.age_hours()
267 })),
268 Ok(None) => CallToolResult::error(format!("Context not found: {}", id_str)),
269 Err(e) => CallToolResult::error(format!("Error retrieving context: {}", e)),
270 }
271 }
272
273 async fn delete_context(&self, args: HashMap<String, Value>) -> CallToolResult {
274 let id_str = match args.get("id").and_then(|v| v.as_str()) {
275 Some(id) => id,
276 None => return CallToolResult::error("Missing required parameter: id"),
277 };
278
279 let id = crate::context::ContextId::from_string(id_str.to_string());
280
281 match self.store.delete(&id).await {
282 Ok(true) => CallToolResult::json(json!({
283 "success": true,
284 "message": "Context deleted"
285 })),
286 Ok(false) => CallToolResult::error(format!("Context not found: {}", id_str)),
287 Err(e) => CallToolResult::error(format!("Error deleting context: {}", e)),
288 }
289 }
290
291 async fn query_contexts(&self, args: HashMap<String, Value>) -> CallToolResult {
292 let mut query = ContextQuery::new();
293
294 if let Some(domain) = args.get("domain").and_then(|v| v.as_str()) {
295 query = query.with_domain(parse_domain(domain));
296 }
297
298 if let Some(tags) = args.get("tags").and_then(|v| v.as_array()) {
299 for tag in tags.iter().filter_map(|v| v.as_str()) {
300 query = query.with_tag(tag.to_string());
301 }
302 }
303
304 if let Some(min_importance) = args.get("min_importance").and_then(|v| v.as_f64()) {
305 query = query.with_min_importance(min_importance as f32);
306 }
307
308 if let Some(max_age) = args.get("max_age_hours").and_then(|v| v.as_i64()) {
309 query = query.with_max_age_hours(max_age);
310 }
311
312 if let Some(verified) = args.get("verified_only").and_then(|v| v.as_bool()) {
313 if verified {
314 query = query.verified_only();
315 }
316 }
317
318 if let Some(limit) = args.get("limit").and_then(|v| v.as_u64()) {
319 query = query.with_limit(limit as usize);
320 }
321
322 match self.store.query(&query).await {
323 Ok(contexts) => {
324 let results: Vec<Value> = contexts
325 .iter()
326 .map(|ctx| {
327 json!({
328 "id": ctx.id.to_string(),
329 "content_preview": ctx.content.chars().take(100).collect::<String>(),
330 "domain": format!("{:?}", ctx.domain),
331 "importance": ctx.metadata.importance,
332 "age_hours": ctx.age_hours(),
333 "tags": ctx.metadata.tags
334 })
335 })
336 .collect();
337
338 CallToolResult::json(json!({
339 "count": results.len(),
340 "contexts": results
341 }))
342 }
343 Err(e) => CallToolResult::error(format!("Query failed: {}", e)),
344 }
345 }
346
347 async fn retrieve_contexts(&self, args: HashMap<String, Value>) -> CallToolResult {
348 let mut query = RetrievalQuery::new();
349
350 if let Some(text) = args.get("text").and_then(|v| v.as_str()) {
351 query.text = Some(text.to_string());
352 }
353
354 if let Some(domain) = args.get("domain").and_then(|v| v.as_str()) {
355 query = query.with_domain(parse_domain(domain));
356 }
357
358 if let Some(tags) = args.get("tags").and_then(|v| v.as_array()) {
359 for tag in tags.iter().filter_map(|v| v.as_str()) {
360 query = query.with_tag(tag.to_string());
361 }
362 }
363
364 if let Some(min_importance) = args.get("min_importance").and_then(|v| v.as_f64()) {
365 query = query.with_min_importance(min_importance as f32);
366 }
367
368 if let Some(max_age) = args.get("max_age_hours").and_then(|v| v.as_i64()) {
369 query = query.with_temporal(TemporalQuery::recent(max_age));
370 }
371
372 match self.rag.retrieve(&query).await {
373 Ok(result) => {
374 let contexts: Vec<Value> = result
375 .contexts
376 .iter()
377 .map(|sc| {
378 json!({
379 "id": sc.context.id.to_string(),
380 "content": sc.context.content,
381 "domain": format!("{:?}", sc.context.domain),
382 "score": sc.score,
383 "score_breakdown": {
384 "temporal": sc.score_breakdown.temporal,
385 "importance": sc.score_breakdown.importance,
386 "domain_match": sc.score_breakdown.domain_match,
387 "tag_match": sc.score_breakdown.tag_match
388 },
389 "age_hours": sc.context.age_hours(),
390 "tags": sc.context.metadata.tags
391 })
392 })
393 .collect();
394
395 CallToolResult::json(json!({
396 "count": contexts.len(),
397 "candidates_considered": result.candidates_considered,
398 "processing_time_ms": result.processing_time_ms,
399 "temporal_stats": {
400 "count": result.temporal_stats.count,
401 "avg_age_hours": result.temporal_stats.avg_age_hours,
402 "distribution": result.temporal_stats.distribution
403 },
404 "contexts": contexts
405 }))
406 }
407 Err(e) => CallToolResult::error(format!("Retrieval failed: {}", e)),
408 }
409 }
410
411 async fn update_screening(&self, args: HashMap<String, Value>) -> CallToolResult {
412 let id_str = match args.get("id").and_then(|v| v.as_str()) {
413 Some(id) => id,
414 None => return CallToolResult::error("Missing required parameter: id"),
415 };
416
417 let status_str = match args.get("status").and_then(|v| v.as_str()) {
418 Some(s) => s,
419 None => return CallToolResult::error("Missing required parameter: status"),
420 };
421
422 let status = match status_str.to_lowercase().as_str() {
423 "safe" => ScreeningStatus::Safe,
424 "flagged" => ScreeningStatus::Flagged,
425 "blocked" => ScreeningStatus::Blocked,
426 _ => return CallToolResult::error(format!("Invalid status: {}", status_str)),
427 };
428
429 let id = crate::context::ContextId::from_string(id_str.to_string());
430
431 match self.store.get(&id).await {
432 Ok(Some(mut ctx)) => {
433 ctx.metadata.screening_status = status.clone();
434 match self.store.store(ctx).await {
435 Ok(_) => CallToolResult::json(json!({
436 "success": true,
437 "id": id_str,
438 "new_status": format!("{:?}", status)
439 })),
440 Err(e) => CallToolResult::error(format!("Failed to update: {}", e)),
441 }
442 }
443 Ok(None) => CallToolResult::error(format!("Context not found: {}", id_str)),
444 Err(e) => CallToolResult::error(format!("Error: {}", e)),
445 }
446 }
447
448 async fn get_temporal_stats(&self, args: HashMap<String, Value>) -> CallToolResult {
449 let mut query = ContextQuery::new();
450
451 if let Some(domain) = args.get("domain").and_then(|v| v.as_str()) {
452 query = query.with_domain(parse_domain(domain));
453 }
454
455 match self.store.query(&query).await {
456 Ok(contexts) => {
457 let stats = crate::temporal::TemporalStats::from_contexts(&contexts);
458 CallToolResult::json(json!({
459 "count": stats.count,
460 "oldest": stats.oldest.map(|t| t.to_rfc3339()),
461 "newest": stats.newest.map(|t| t.to_rfc3339()),
462 "avg_age_hours": stats.avg_age_hours,
463 "distribution": {
464 "last_hour": stats.distribution.last_hour,
465 "last_day": stats.distribution.last_day,
466 "last_week": stats.distribution.last_week,
467 "last_month": stats.distribution.last_month,
468 "older": stats.distribution.older
469 }
470 }))
471 }
472 Err(e) => CallToolResult::error(format!("Failed to get stats: {}", e)),
473 }
474 }
475
476 async fn get_storage_stats(&self, _args: HashMap<String, Value>) -> CallToolResult {
477 let stats = self.store.stats().await;
478 CallToolResult::json(json!({
479 "memory_count": stats.memory_count,
480 "disk_count": stats.disk_count,
481 "cache_capacity": stats.cache_capacity
482 }))
483 }
484
485 async fn cleanup_expired(&self, _args: HashMap<String, Value>) -> CallToolResult {
486 match self.store.cleanup_expired().await {
487 Ok(count) => CallToolResult::json(json!({
488 "success": true,
489 "removed_count": count
490 })),
491 Err(e) => CallToolResult::error(format!("Cleanup failed: {}", e)),
492 }
493 }
494}
495
496fn parse_domain(s: &str) -> ContextDomain {
498 match s.to_lowercase().as_str() {
499 "code" => ContextDomain::Code,
500 "documentation" | "docs" => ContextDomain::Documentation,
501 "conversation" | "chat" => ContextDomain::Conversation,
502 "filesystem" | "files" => ContextDomain::Filesystem,
503 "websearch" | "web" => ContextDomain::WebSearch,
504 "dataset" | "data" => ContextDomain::Dataset,
505 "research" => ContextDomain::Research,
506 _ => ContextDomain::General,
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Original spot checks: canonical name, short alias, unknown fallback.
    #[test]
    fn test_parse_domain() {
        assert_eq!(parse_domain("Code"), ContextDomain::Code);
        assert_eq!(parse_domain("docs"), ContextDomain::Documentation);
        assert_eq!(parse_domain("unknown"), ContextDomain::General);
    }

    /// Every value advertised in the `store_context` domain enum must map to
    /// its own variant; otherwise a schema-valid request would silently be
    /// stored or filtered as `General`.
    #[test]
    fn test_parse_domain_schema_values() {
        assert_eq!(parse_domain("General"), ContextDomain::General);
        assert_eq!(parse_domain("Code"), ContextDomain::Code);
        assert_eq!(parse_domain("Documentation"), ContextDomain::Documentation);
        assert_eq!(parse_domain("Conversation"), ContextDomain::Conversation);
        assert_eq!(parse_domain("Filesystem"), ContextDomain::Filesystem);
        assert_eq!(parse_domain("WebSearch"), ContextDomain::WebSearch);
        assert_eq!(parse_domain("Dataset"), ContextDomain::Dataset);
        assert_eq!(parse_domain("Research"), ContextDomain::Research);
    }

    /// Matching ignores case, and every short alias resolves to its variant.
    #[test]
    fn test_parse_domain_aliases_and_case() {
        assert_eq!(parse_domain("CODE"), ContextDomain::Code);
        assert_eq!(parse_domain("websearch"), ContextDomain::WebSearch);
        assert_eq!(parse_domain("chat"), ContextDomain::Conversation);
        assert_eq!(parse_domain("files"), ContextDomain::Filesystem);
        assert_eq!(parse_domain("Web"), ContextDomain::WebSearch);
        assert_eq!(parse_domain("DATA"), ContextDomain::Dataset);
        assert_eq!(parse_domain(""), ContextDomain::General);
    }
}