Skip to main content

do_memory_mcp/mcp/tools/embeddings/tool/execute/
query.rs

1//! Query semantic memory tool implementation.
2
3use super::super::definitions::EmbeddingTools;
4use crate::mcp::tools::embeddings::types::{
5    QuerySemanticMemoryInput, QuerySemanticMemoryOutput, SemanticResult,
6};
7use anyhow::Result;
8use do_memory_core::{ComplexityLevel, TaskContext, TaskOutcome};
9use tracing::{debug, info, instrument, warn};
10
11impl EmbeddingTools {
12    /// Execute the query_semantic_memory tool
13    #[instrument(skip(self, input), fields(query = %input.query))]
14    pub async fn execute_query_semantic_memory(
15        &self,
16        input: QuerySemanticMemoryInput,
17    ) -> Result<QuerySemanticMemoryOutput> {
18        let start_time = std::time::Instant::now();
19
20        info!("Executing semantic memory query: '{}'", input.query);
21
22        // Clone domain once to avoid ownership issues
23        let domain = input
24            .domain
25            .clone()
26            .unwrap_or_else(|| "general".to_string());
27
28        // Check if semantic_service is available
29        if let Some(semantic_service) = self.memory.semantic_service() {
30            let context = TaskContext {
31                domain: domain.clone(),
32                language: None,
33                framework: None,
34                complexity: ComplexityLevel::Moderate,
35                tags: input
36                    .task_type
37                    .as_ref()
38                    .map(|t| vec![t.clone()])
39                    .unwrap_or_default(),
40            };
41
42            let limit = input.limit.unwrap_or(10);
43
44            // Use the semantic service to find similar episodes
45            let similar_episodes = match semantic_service
46                .find_similar_episodes(&input.query, &context, limit)
47                .await
48            {
49                Ok(episodes) => episodes,
50                Err(e) => {
51                    warn!("Semantic search failed: {}, using fallback", e);
52                    // Fallback to standard retrieval
53                    let fallback_context = TaskContext {
54                        domain,
55                        language: None,
56                        framework: None,
57                        complexity: ComplexityLevel::Moderate,
58                        tags: input
59                            .task_type
60                            .as_ref()
61                            .map(|t| vec![t.clone()])
62                            .unwrap_or_default(),
63                    };
64                    self.memory
65                        .retrieve_relevant_context(input.query.clone(), fallback_context, limit)
66                        .await
67                        .into_iter()
68                        .map(|arc_ep| {
69                            // Dereference Arc<Episode> to Episode
70                            let episode = arc_ep.as_ref().clone();
71                            do_memory_core::embeddings::SimilaritySearchResult {
72                                item: episode,
73                                similarity: 0.5,
74                                metadata: do_memory_core::embeddings::SimilarityMetadata::default(),
75                            }
76                        })
77                        .collect()
78                }
79            };
80
81            // Convert to semantic results with actual similarity scores
82            let results: Vec<SemanticResult> = similar_episodes
83                .into_iter()
84                .map(|result| {
85                    let episode = result.item;
86                    let outcome = episode.outcome.as_ref().map(|o| match o {
87                        TaskOutcome::Success { verdict, .. } => {
88                            format!("Success: {}", verdict)
89                        }
90                        TaskOutcome::PartialSuccess { verdict, .. } => {
91                            format!("Partial: {}", verdict)
92                        }
93                        TaskOutcome::Failure { reason, .. } => {
94                            format!("Failure: {}", reason)
95                        }
96                    });
97
98                    SemanticResult {
99                        episode_id: episode.episode_id.to_string(),
100                        similarity_score: result.similarity,
101                        task_description: episode.task_description.clone(),
102                        domain: episode.context.domain.clone(),
103                        task_type: format!("{:?}", episode.task_type),
104                        outcome,
105                        timestamp: episode.start_time.timestamp(),
106                    }
107                })
108                .collect();
109
110            let query_time_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
111
112            debug!(
113                "Semantic query completed in {}ms, found {} results",
114                query_time_ms,
115                results.len()
116            );
117
118            let config = semantic_service.config();
119            return Ok(QuerySemanticMemoryOutput {
120                results_found: results.len(),
121                results,
122                embedding_dimension: config.provider.effective_dimension(),
123                query_time_ms,
124                provider: format!("{:?}", config.provider),
125            });
126        }
127
128        // Fallback if no semantic service configured
129        warn!("Semantic service not available, using standard retrieval as fallback");
130
131        let context = TaskContext {
132            domain: input.domain.unwrap_or_else(|| "general".to_string()),
133            language: None,
134            framework: None,
135            complexity: ComplexityLevel::Moderate,
136            tags: input
137                .task_type
138                .as_ref()
139                .map(|t| vec![t.clone()])
140                .unwrap_or_default(),
141        };
142
143        let limit = input.limit.unwrap_or(10);
144
145        let arc_episodes = self
146            .memory
147            .retrieve_relevant_context(input.query.clone(), context, limit)
148            .await;
149
150        // Convert Arc<Episode> episodes to semantic results with simulated scores
151        let results: Vec<SemanticResult> = arc_episodes
152            .into_iter()
153            .enumerate()
154            .map(|(idx, arc_ep)| {
155                // Dereference Arc<Episode> to access Episode
156                let episode = arc_ep.as_ref();
157                // Simulate similarity score (decreasing with rank)
158                let similarity_score = 0.95 - (idx as f32 * 0.05);
159
160                let outcome = episode.outcome.as_ref().map(|o| match o {
161                    TaskOutcome::Success { verdict, .. } => {
162                        format!("Success: {}", verdict)
163                    }
164                    TaskOutcome::PartialSuccess { verdict, .. } => {
165                        format!("Partial: {}", verdict)
166                    }
167                    TaskOutcome::Failure { reason, .. } => {
168                        format!("Failure: {}", reason)
169                    }
170                });
171
172                SemanticResult {
173                    episode_id: episode.episode_id.to_string(),
174                    similarity_score,
175                    task_description: episode.task_description.clone(),
176                    domain: episode.context.domain.clone(),
177                    task_type: format!("{:?}", episode.task_type),
178                    outcome,
179                    timestamp: episode.start_time.timestamp(),
180                }
181            })
182            .collect();
183
184        let query_time_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
185
186        debug!(
187            "Semantic query completed in {}ms, found {} results",
188            query_time_ms,
189            results.len()
190        );
191
192        Ok(QuerySemanticMemoryOutput {
193            results_found: results.len(),
194            results,
195            embedding_dimension: 384, // Default dimension
196            query_time_ms,
197            provider: "fallback-standard-retrieval".to_string(),
198        })
199    }
200}