Skip to main content

rust_memex/
tools.rs

1//! In-process helper functions for memory-oriented operations.
2//!
3//! These helpers are convenient for Rust callers embedding `rust-memex`
4//! directly. The authoritative MCP tool contract exposed over stdio and
5//! HTTP/SSE comes from `tool_definitions()`, which mirrors the shared runtime
6//! surface instead of maintaining a second, drifting schema list here.
7//!
8//! # Example
9//!
10//! ```rust,ignore
11//! use rust_memex::{MemexEngine, MemexConfig};
12//! use rust_memex::tools::{store_document, search_documents, tool_definitions, ToolResult};
13//! use serde_json::json;
14//!
15//! #[tokio::main]
16//! async fn main() -> anyhow::Result<()> {
17//!     let engine = MemexEngine::for_app("my-app", "documents").await?;
18//!
19//!     // Store a document using the local helper API
20//!     let result = store_document(
21//!         &engine,
22//!         "doc-1".to_string(),
23//!         "Important patient notes about feline diabetes".to_string(),
24//!         json!({"patient_id": "P-123", "doc_type": "notes"}),
25//!     ).await?;
26//!     assert!(result.success);
27//!
28//!     // Search for documents
29//!     let results = search_documents(&engine, "diabetes".to_string(), 5, None).await?;
30//!     println!("Found {} documents", results.len());
31//!
32//!     // Get the canonical MCP tool definitions exposed by rust-memex
33//!     let tools = tool_definitions();
34//!     println!("Available tools: {:?}", tools.iter().map(|t| &t.name).collect::<Vec<_>>());
35//!
36//!     Ok(())
37//! }
38//! ```
39
40use crate::engine::{BatchResult, MemexEngine, MetaFilter, StoreItem};
41use crate::rag::SearchResult;
42use anyhow::Result;
43use serde::{Deserialize, Serialize};
44use serde_json::{Value, json};
45
46// =============================================================================
47// TOOL RESULT
48// =============================================================================
49
/// Result type for tool operations.
///
/// Gives every helper in this module the same response shape, suitable for
/// MCP tool call responses. Construct values through [`ToolResult::ok`],
/// [`ToolResult::ok_with_data`] or [`ToolResult::error`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Whether the operation succeeded (`true`) or failed (`false`)
    pub success: bool,
    /// Human-readable message describing the result
    pub message: String,
    /// Optional data payload (operation-specific).
    ///
    /// When this is `None` the `data` key is left out of the serialized
    /// output entirely rather than being emitted as `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
64
65impl ToolResult {
66    /// Create a success result with just a message
67    pub fn ok(message: impl Into<String>) -> Self {
68        Self {
69            success: true,
70            message: message.into(),
71            data: None,
72        }
73    }
74
75    /// Create a success result with data
76    pub fn ok_with_data(message: impl Into<String>, data: Value) -> Self {
77        Self {
78            success: true,
79            message: message.into(),
80            data: Some(data),
81        }
82    }
83
84    /// Create an error result
85    pub fn error(message: impl Into<String>) -> Self {
86        Self {
87            success: false,
88            message: message.into(),
89            data: None,
90        }
91    }
92}
93
94// =============================================================================
95// TOOL FUNCTIONS
96// =============================================================================
97
98/// Store text in memory with automatic embedding generation.
99///
100/// # Arguments
101/// * `engine` - The MemexEngine instance
102/// * `id` - Unique document identifier
103/// * `text` - Text content to embed and store
104/// * `metadata` - Additional metadata (JSON object)
105///
106/// # Returns
107/// `ToolResult` indicating success or failure
108///
109/// # Example
110///
111/// ```rust,ignore
112/// let result = store_document(
113///     &engine,
114///     "visit-123".to_string(),
115///     "Patient presented with lethargy...".to_string(),
116///     json!({"patient_id": "P-456", "visit_type": "checkup"}),
117/// ).await?;
118/// ```
119pub async fn store_document(
120    engine: &MemexEngine,
121    id: String,
122    text: String,
123    metadata: Value,
124) -> Result<ToolResult> {
125    match engine.store(&id, &text, metadata).await {
126        Ok(()) => Ok(ToolResult::ok(format!(
127            "Document '{}' stored successfully",
128            id
129        ))),
130        Err(e) => Ok(ToolResult::error(format!(
131            "Failed to store document '{}': {}",
132            id, e
133        ))),
134    }
135}
136
137/// Search memory semantically using vector similarity.
138///
139/// # Arguments
140/// * `engine` - The MemexEngine instance
141/// * `query` - Search query text
142/// * `limit` - Maximum number of results to return
143/// * `filter` - Optional metadata filter for narrowing results
144///
145/// # Returns
146/// Vector of `SearchResult` ordered by relevance (highest score first)
147///
148/// # Example
149///
150/// ```rust,ignore
151/// // Simple search
152/// let results = search_documents(&engine, "diabetes symptoms".to_string(), 10, None).await?;
153///
154/// // Filtered search
155/// let filter = MetaFilter::for_patient("P-456");
156/// let results = search_documents(&engine, "diabetes".to_string(), 10, Some(filter)).await?;
157/// ```
158pub async fn search_documents(
159    engine: &MemexEngine,
160    query: String,
161    limit: usize,
162    filter: Option<MetaFilter>,
163) -> Result<Vec<SearchResult>> {
164    match filter {
165        Some(f) => engine.search_filtered(&query, f, limit).await,
166        None => engine.search(&query, limit).await,
167    }
168}
169
170/// Get a document by ID.
171///
172/// # Arguments
173/// * `engine` - The MemexEngine instance
174/// * `id` - Document identifier to retrieve
175///
176/// # Returns
177/// `Option<SearchResult>` - The document if found, None otherwise
178///
179/// # Example
180///
181/// ```rust,ignore
182/// if let Some(doc) = get_document(&engine, "visit-123".to_string()).await? {
183///     println!("Found: {}", doc.text);
184/// }
185/// ```
186pub async fn get_document(engine: &MemexEngine, id: String) -> Result<Option<SearchResult>> {
187    engine.get(&id).await
188}
189
190/// Delete a document by ID.
191///
192/// # Arguments
193/// * `engine` - The MemexEngine instance
194/// * `id` - Document identifier to delete
195///
196/// # Returns
197/// `ToolResult` indicating success or failure, with deletion status
198///
199/// # Example
200///
201/// ```rust,ignore
202/// let result = delete_document(&engine, "visit-123".to_string()).await?;
203/// if result.success {
204///     println!("Document deleted");
205/// }
206/// ```
207pub async fn delete_document(engine: &MemexEngine, id: String) -> Result<ToolResult> {
208    match engine.delete(&id).await {
209        Ok(true) => Ok(ToolResult::ok(format!(
210            "Document '{}' deleted successfully",
211            id
212        ))),
213        Ok(false) => Ok(ToolResult::ok_with_data(
214            format!("Document '{}' not found", id),
215            json!({"deleted": false}),
216        )),
217        Err(e) => Ok(ToolResult::error(format!(
218            "Failed to delete document '{}': {}",
219            id, e
220        ))),
221    }
222}
223
224/// Batch store multiple documents efficiently.
225///
226/// More efficient than calling `store_document()` multiple times as embeddings
227/// are generated in batches.
228///
229/// # Arguments
230/// * `engine` - The MemexEngine instance
231/// * `items` - Vector of items to store
232///
233/// # Returns
234/// `BatchResult` with success/failure counts
235///
236/// # Example
237///
238/// ```rust,ignore
239/// let items = vec![
240///     StoreItem::new("doc-1", "First document").with_metadata(json!({"type": "note"})),
241///     StoreItem::new("doc-2", "Second document").with_metadata(json!({"type": "note"})),
242/// ];
243/// let result = store_documents_batch(&engine, items).await?;
244/// println!("Stored {} documents", result.success_count);
245/// ```
246pub async fn store_documents_batch(
247    engine: &MemexEngine,
248    items: Vec<StoreItem>,
249) -> Result<BatchResult> {
250    engine.store_batch(items).await
251}
252
253/// Delete all documents matching a metadata filter.
254///
255/// This is the primary method for GDPR-compliant data deletion.
256///
257/// # Arguments
258/// * `engine` - The MemexEngine instance
259/// * `filter` - Metadata filter specifying which documents to delete
260///
261/// # Returns
262/// `ToolResult` with count of deleted documents
263///
264/// # Example
265///
266/// ```rust,ignore
267/// // GDPR request - delete all patient data
268/// let filter = MetaFilter::for_patient("P-456");
269/// let result = delete_documents_by_filter(&engine, filter).await?;
270/// if let Some(data) = result.data {
271///     println!("Deleted {} documents", data["deleted_count"]);
272/// }
273/// ```
274pub async fn delete_documents_by_filter(
275    engine: &MemexEngine,
276    filter: MetaFilter,
277) -> Result<ToolResult> {
278    match engine.delete_by_filter(filter.clone()).await {
279        Ok(count) => Ok(ToolResult::ok_with_data(
280            format!("Deleted {} documents matching filter", count),
281            json!({
282                "deleted_count": count,
283                "filter": filter,
284            }),
285        )),
286        Err(e) => Ok(ToolResult::error(format!(
287            "Failed to delete by filter: {}",
288            e
289        ))),
290    }
291}
292
293// =============================================================================
294// TOOL DEFINITIONS FOR MCP
295// =============================================================================
296
/// MCP tool definition schema.
///
/// This structure mirrors the canonical MCP tool metadata emitted by the shared
/// rust-memex transport layer: a tool name, a description, and the JSON Schema
/// of its input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name (used for invocation)
    pub name: String,
    /// Human-readable description of what the tool does
    pub description: String,
    /// JSON Schema for the tool's input parameters.
    ///
    /// Serialized under the MCP field name `inputSchema`; deserialization
    /// accepts both `inputSchema` and the snake_case alias `input_schema`.
    #[serde(rename = "inputSchema", alias = "input_schema")]
    pub input_schema: Value,
}
311
312impl ToolDefinition {
313    /// Create a new tool definition
314    pub fn new(
315        name: impl Into<String>,
316        description: impl Into<String>,
317        input_schema: Value,
318    ) -> Self {
319        Self {
320            name: name.into(),
321            description: description.into(),
322            input_schema,
323        }
324    }
325}
326
327/// Get the canonical MCP tool definitions exposed by rust-memex transports.
328///
329/// This is derived from the shared transport contract so stdio, HTTP/SSE, and
330/// library consumers all see the same tool metadata.
331///
332/// # Example
333///
334/// ```rust,ignore
335/// let tools = tool_definitions();
336/// for tool in &tools {
337///     println!("Tool: {} - {}", tool.name, tool.description);
338/// }
339/// ```
340pub fn tool_definitions() -> Vec<ToolDefinition> {
341    crate::mcp_protocol::shared_tools_list_result()["tools"]
342        .as_array()
343        .expect("shared_tools_list_result().tools must be an array")
344        .iter()
345        .map(|tool| {
346            ToolDefinition::new(
347                tool["name"]
348                    .as_str()
349                    .expect("shared MCP tool definition missing name"),
350                tool["description"]
351                    .as_str()
352                    .expect("shared MCP tool definition missing description"),
353                tool.get("inputSchema")
354                    .cloned()
355                    .expect("shared MCP tool definition missing inputSchema"),
356            )
357        })
358        .collect()
359}
360
361// =============================================================================
362// TESTS
363// =============================================================================
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a throwaway definition around the given schema.
    fn sample_tool(schema: Value) -> ToolDefinition {
        ToolDefinition::new("test_tool", "A test tool", schema)
    }

    // -- ToolResult constructors ---------------------------------------------

    #[test]
    fn test_tool_result_ok() {
        let ToolResult {
            success,
            message,
            data,
        } = ToolResult::ok("Success");

        assert!(success);
        assert_eq!(message, "Success");
        assert!(data.is_none());
    }

    #[test]
    fn test_tool_result_ok_with_data() {
        let result = ToolResult::ok_with_data("Success", json!({"count": 42}));

        assert!(result.success);
        assert_eq!(result.message, "Success");

        let payload = result.data.unwrap();
        assert_eq!(payload["count"], 42);
    }

    #[test]
    fn test_tool_result_error() {
        let ToolResult {
            success,
            message,
            data,
        } = ToolResult::error("Something went wrong");

        assert!(!success);
        assert_eq!(message, "Something went wrong");
        assert!(data.is_none());
    }

    // -- Canonical tool definitions ------------------------------------------

    #[test]
    fn test_tool_definitions_count() {
        assert_eq!(tool_definitions().len(), 12);
    }

    #[test]
    fn test_tool_definitions_names() {
        const EXPECTED: [&str; 12] = [
            "health",
            "rag_index",
            "memory_upsert",
            "memory_search",
            "memory_get",
            "memory_delete",
            "memory_purge_namespace",
            "namespace_create_token",
            "namespace_revoke_token",
            "namespace_list_protected",
            "namespace_security_status",
            "dive",
        ];

        let tools = tool_definitions();
        for expected in EXPECTED {
            assert!(
                tools.iter().any(|tool| tool.name == expected),
                "missing tool definition: {}",
                expected
            );
        }
    }

    #[test]
    fn test_tool_definitions_have_required_fields() {
        for tool in tool_definitions() {
            let schema = &tool.input_schema;

            assert!(!tool.name.is_empty(), "Tool name should not be empty");
            assert!(
                !tool.description.is_empty(),
                "Tool description should not be empty"
            );
            assert!(schema.is_object(), "Input schema should be an object");
            assert!(
                schema.get("type").is_some(),
                "Input schema should have a type field"
            );
            assert!(
                schema.get("properties").is_some(),
                "Input schema should have properties"
            );
        }
    }

    #[test]
    fn test_tool_definitions_match_shared_mcp_contract() {
        // Round-tripping through `ToolDefinition` must not lose or rename
        // anything relative to the shared transport contract.
        let shared = crate::mcp_protocol::shared_tools_list_result();
        let serialized = serde_json::to_value(tool_definitions()).unwrap();

        assert_eq!(serialized, shared["tools"]);
    }

    // -- Serialization -------------------------------------------------------

    #[test]
    fn test_tool_result_serialization() {
        let result = ToolResult::ok_with_data("Success", json!({"id": "doc-1"}));
        let json_str = serde_json::to_string(&result).unwrap();

        for fragment in ["\"success\":true", "\"message\":\"Success\"", "\"data\""] {
            assert!(
                json_str.contains(fragment),
                "serialized result is missing {}",
                fragment
            );
        }
    }

    #[test]
    fn test_tool_definition_creation() {
        let tool = sample_tool(json!({
            "type": "object",
            "properties": {
                "input": { "type": "string" }
            }
        }));

        assert_eq!(tool.name, "test_tool");
        assert_eq!(tool.description, "A test tool");
        assert!(tool.input_schema["properties"]["input"].is_object());
    }

    #[test]
    fn test_tool_definition_serializes_with_mcp_field_name() {
        let tool = sample_tool(json!({
            "type": "object",
            "properties": {}
        }));

        let value = serde_json::to_value(tool).unwrap();

        // Only the camelCase MCP key may appear on the wire.
        assert!(value.get("inputSchema").is_some());
        assert!(value.get("input_schema").is_none());
    }
}