// do_memory_mcp/mcp/tools/embeddings/tool/execute/generate.rs
1//! Generate embedding and search by embedding tool implementations.
2
3use super::super::definitions::EmbeddingTools;
4use crate::mcp::tools::embeddings::types::{
5    EmbeddingSearchResult, GenerateEmbeddingInput, GenerateEmbeddingOutput, SearchByEmbeddingInput,
6    SearchByEmbeddingOutput,
7};
8use anyhow::{Result, anyhow};
9use do_memory_core::TaskOutcome;
10use tracing::{debug, info, instrument, warn};
11
12impl EmbeddingTools {
13    /// Execute the generate_embedding tool
14    #[instrument(skip(self, input), fields(text_len = input.text.len()))]
15    pub async fn execute_generate_embedding(
16        &self,
17        input: GenerateEmbeddingInput,
18    ) -> Result<GenerateEmbeddingOutput> {
19        let start_time = std::time::Instant::now();
20
21        info!("Generating embedding for text ({} chars)", input.text.len());
22
23        // Check if semantic_service is available
24        if let Some(semantic_service) = self.memory.semantic_service() {
25            // Generate the embedding
26            let mut embedding = semantic_service
27                .provider
28                .embed_text(&input.text)
29                .await
30                .map_err(|e| anyhow!("Failed to generate embedding: {}", e))?;
31
32            let config = semantic_service.config();
33            let model_name = config.provider.model_name();
34            let dimension = config.provider.effective_dimension();
35            let provider = format!("{:?}", config.provider);
36
37            // Normalize if requested
38            let normalized = input.normalize;
39            if normalized {
40                embedding = do_memory_core::embeddings::normalize_vector(embedding);
41            }
42
43            let generation_time_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
44
45            debug!(
46                "Generated {}-dimensional embedding in {}ms",
47                dimension, generation_time_ms
48            );
49
50            return Ok(GenerateEmbeddingOutput {
51                embedding,
52                dimension,
53                model: model_name,
54                provider,
55                generation_time_ms,
56                normalized,
57                token_count: None, // Would need tokenizer integration
58            });
59        }
60
61        // No semantic service configured
62        warn!("Semantic service not available, cannot generate embedding");
63        Err(anyhow!(
64            "Semantic embeddings not configured. Use configure_embeddings first."
65        ))
66    }
67
68    /// Execute the search_by_embedding tool
69    #[instrument(skip(self, input), fields(embedding_dim = input.embedding.len()))]
70    pub async fn execute_search_by_embedding(
71        &self,
72        input: SearchByEmbeddingInput,
73    ) -> Result<SearchByEmbeddingOutput> {
74        let start_time = std::time::Instant::now();
75
76        info!(
77            "Searching by embedding (dimension: {}, limit: {}, threshold: {})",
78            input.embedding.len(),
79            input.limit,
80            input.similarity_threshold
81        );
82
83        // Validate embedding dimension
84        let expected_dimension = if let Some(semantic_service) = self.memory.semantic_service() {
85            semantic_service.config().provider.effective_dimension()
86        } else {
87            384 // Default dimension
88        };
89
90        if input.embedding.len() != expected_dimension {
91            return Err(anyhow!(
92                "Embedding dimension mismatch: got {}, expected {}. Use the same model that generated your embeddings.",
93                input.embedding.len(),
94                expected_dimension
95            ));
96        }
97
98        // Check if semantic_service is available
99        if let Some(semantic_service) = self.memory.semantic_service() {
100            let config = semantic_service.config();
101            let provider = format!("{:?}", config.provider);
102
103            // Search for similar episodes using the embedding directly
104            let similar_episodes = semantic_service
105                .find_episodes_by_embedding(
106                    input.embedding.clone(),
107                    input.limit,
108                    input.similarity_threshold,
109                )
110                .await
111                .map_err(|e| anyhow!("Failed to search by embedding: {}", e))?;
112
113            // Convert to search results
114            let results: Vec<EmbeddingSearchResult> = similar_episodes
115                .into_iter()
116                .map(|result| {
117                    let episode = result.item;
118                    let outcome = episode.outcome.as_ref().map(|o| match o {
119                        TaskOutcome::Success { verdict, .. } => {
120                            format!("Success: {}", verdict)
121                        }
122                        TaskOutcome::PartialSuccess { verdict, .. } => {
123                            format!("Partial: {}", verdict)
124                        }
125                        TaskOutcome::Failure { reason, .. } => {
126                            format!("Failure: {}", reason)
127                        }
128                    });
129
130                    EmbeddingSearchResult {
131                        episode_id: episode.episode_id.to_string(),
132                        similarity_score: result.similarity,
133                        task_description: episode.task_description.clone(),
134                        domain: episode.context.domain.clone(),
135                        task_type: format!("{:?}", episode.task_type),
136                        outcome,
137                        timestamp: episode.start_time.timestamp(),
138                    }
139                })
140                .collect();
141
142            let search_time_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
143
144            debug!(
145                "Embedding search completed in {}ms, found {} results",
146                search_time_ms,
147                results.len()
148            );
149
150            return Ok(SearchByEmbeddingOutput {
151                results_found: results.len(),
152                results,
153                embedding_dimension: expected_dimension,
154                search_time_ms,
155                provider,
156            });
157        }
158
159        // No semantic service configured - fallback to standard retrieval with warning
160        warn!("Semantic service not available, cannot search by embedding");
161        Err(anyhow!(
162            "Semantic embeddings not configured. Use configure_embeddings first to enable embedding-based search."
163        ))
164    }
165}