Skip to main content

do_memory_mcp/mcp/tools/embeddings/tool/definitions.rs

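//! Embedding tools implementation: the `EmbeddingTools` handle and the tool definitions (names, descriptions and JSON input schemas) for the embedding-related tools.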
2
3use crate::types::Tool;
4use do_memory_core::SelfLearningMemory;
5use serde_json::json;
6use std::sync::Arc;
7
8/// Embedding tools implementation
9pub struct EmbeddingTools {
10    pub memory: Arc<SelfLearningMemory>,
11}
12
13impl EmbeddingTools {
14    pub fn new(memory: Arc<SelfLearningMemory>) -> Self {
15        Self { memory }
16    }
17}
18
19/// Get the tool definition for generate_embedding
20pub fn generate_embedding_tool() -> Tool {
21    Tool::new(
22        "generate_embedding".to_string(),
23        "Generate an embedding vector for text using the configured embedding provider."
24            .to_string(),
25        json!({
26            "type": "object",
27            "properties": {
28                "text": {
29                    "type": "string",
30                    "description": "Text to generate embedding for"
31                },
32                "normalize": {
33                    "type": "boolean",
34                    "default": true,
35                    "description": "Whether to normalize the embedding vector to unit length"
36                }
37            },
38            "required": ["text"]
39        }),
40    )
41}
42
43/// Get the tool definition for search_by_embedding
44pub fn search_by_embedding_tool() -> Tool {
45    Tool::new(
46        "search_by_embedding".to_string(),
47        "Search episodes by embedding similarity using a pre-computed embedding vector."
48            .to_string(),
49        json!({
50            "type": "object",
51            "properties": {
52                "embedding": {
53                    "type": "array",
54                    "items": {"type": "number"},
55                    "description": "Embedding vector to search with"
56                },
57                "limit": {
58                    "type": "integer",
59                    "minimum": 1,
60                    "maximum": 100,
61                    "default": 10,
62                    "description": "Maximum number of results"
63                },
64                "similarity_threshold": {
65                    "type": "number",
66                    "minimum": 0.0,
67                    "maximum": 1.0,
68                    "default": 0.7,
69                    "description": "Minimum similarity score"
70                },
71                "domain": {"type": "string", "description": "Filter by domain"},
72                "task_type": {"type": "string", "description": "Filter by task type"}
73            },
74            "required": ["embedding"]
75        }),
76    )
77}
78
79/// Get the tool definition for embedding_provider_status
80pub fn embedding_provider_status_tool() -> Tool {
81    Tool::new(
82        "embedding_provider_status".to_string(),
83        "Get detailed status information about the configured embedding provider.".to_string(),
84        json!({
85            "type": "object",
86            "properties": {
87                "test_connectivity": {
88                    "type": "boolean",
89                    "default": false,
90                    "description": "Whether to perform a test embedding to verify connectivity"
91                }
92            },
93            "additionalProperties": false
94        }),
95    )
96}
97
98/// Get the tool definition for configure_embeddings
99pub fn configure_embeddings_tool() -> Tool {
100    Tool::new(
101        "configure_embeddings".to_string(),
102        "Configure semantic embedding provider for enhanced memory retrieval.".to_string(),
103        json!({
104            "type": "object",
105            "properties": {
106                "provider": {
107                    "type": "string",
108                    "enum": ["openai", "local", "mistral", "azure", "cohere"],
109                    "description": "Embedding provider to use"
110                },
111                "model": {"type": "string", "description": "Model name"},
112                "api_key_env": {"type": "string", "description": "API key env var"},
113                "similarity_threshold": {
114                    "type": "number", "minimum": 0.0, "maximum": 1.0, "default": 0.7,
115                    "description": "Min similarity score"
116                },
117                "batch_size": {
118                    "type": "integer", "minimum": 1, "maximum": 2048, "default": 32,
119                    "description": "Batch size"
120                },
121                "base_url": {"type": "string", "description": "Custom base URL"},
122                "api_version": {"type": "string", "description": "API version"},
123                "resource_name": {"type": "string", "description": "Azure resource"},
124                "deployment_name": {"type": "string", "description": "Azure deployment"}
125            },
126            "required": ["provider"]
127        }),
128    )
129}
130
131/// Get the tool definition for query_semantic_memory
132pub fn query_semantic_memory_tool() -> Tool {
133    Tool::new(
134        "query_semantic_memory".to_string(),
135        "Search episodic memory using semantic similarity with embeddings.".to_string(),
136        json!({
137            "type": "object",
138            "properties": {
139                "query": {"type": "string", "description": "Search query"},
140                "limit": {
141                    "type": "integer", "minimum": 1, "maximum": 100, "default": 10,
142                    "description": "Max results"
143                },
144                "similarity_threshold": {
145                    "type": "number", "minimum": 0.0, "maximum": 1.0, "default": 0.7,
146                    "description": "Min similarity"
147                },
148                "domain": {"type": "string", "description": "Filter by domain"},
149                "task_type": {"type": "string", "description": "Filter by task type"}
150            },
151            "required": ["query"]
152        }),
153    )
154}
155
156/// Get the tool definition for test_embeddings
157pub fn test_embeddings_tool() -> Tool {
158    Tool::new(
159        "test_embeddings".to_string(),
160        "Test embedding provider connectivity.".to_string(),
161        json!({"type": "object", "properties": {}, "additionalProperties": false}),
162    )
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_embedding_tool_definition() {
        let tool = generate_embedding_tool();
        assert_eq!(tool.name, "generate_embedding");
    }

    #[test]
    fn test_search_by_embedding_tool_definition() {
        let tool = search_by_embedding_tool();
        assert_eq!(tool.name, "search_by_embedding");
    }

    #[test]
    fn test_embedding_provider_status_tool_definition() {
        let tool = embedding_provider_status_tool();
        assert_eq!(tool.name, "embedding_provider_status");
    }

    #[test]
    fn test_configure_embeddings_tool_definition() {
        let tool = configure_embeddings_tool();
        assert_eq!(tool.name, "configure_embeddings");
    }

    #[test]
    fn test_query_semantic_memory_tool_definition() {
        let tool = query_semantic_memory_tool();
        assert_eq!(tool.name, "query_semantic_memory");
    }

    #[test]
    fn test_test_embeddings_tool_definition() {
        let tool = test_embeddings_tool();
        assert_eq!(tool.name, "test_embeddings");
    }

    /// Tools are looked up by name, so two definitions sharing a name would
    /// shadow each other; every definition in this module must be distinct.
    #[test]
    fn test_tool_names_are_unique() {
        let tools = [
            generate_embedding_tool(),
            search_by_embedding_tool(),
            embedding_provider_status_tool(),
            configure_embeddings_tool(),
            query_semantic_memory_tool(),
            test_embeddings_tool(),
        ];
        for (i, first) in tools.iter().enumerate() {
            for second in &tools[i + 1..] {
                assert_ne!(first.name, second.name);
            }
        }
    }
}