memvid_core/
graph_search.rs

1//! Graph-aware search combining MemoryCards/Logic-Mesh with vector search.
2//!
3//! This module provides hybrid retrieval that can:
4//! 1. Parse natural language queries for relational patterns
5//! 2. Match patterns against entity state (MemoryCards) or graph (Logic-Mesh)
//! 3. Return graph-matched frames, falling back to regular search when the graph yields no candidates
7
8use std::collections::{HashMap, HashSet};
9
10use crate::types::{
11    GraphMatchResult, GraphPattern, HybridSearchHit, PatternTerm, QueryPlan, TriplePattern,
12    SearchRequest,
13};
14use crate::{FrameId, Memvid, Result};
15
16/// Query planner that analyzes queries and creates execution plans.
17#[derive(Debug, Default)]
18pub struct QueryPlanner {
19    /// Patterns for detecting relational queries
20    entity_patterns: Vec<EntityPattern>,
21}
22
/// A keyword rule mapping a style of question onto a MemoryCard slot.
#[derive(Debug, Clone)]
struct EntityPattern {
    /// Lower-case phrases whose presence in the (lower-cased) query triggers the rule.
    keywords: Vec<&'static str>,
    /// Name of the slot the extracted value is matched against (may be empty).
    slot: &'static str,
    /// When `true`, the words after the keyword are taken as the slot value;
    /// when `false`, the rule never produces a pattern by itself.
    needs_value: bool,
}
33
34impl QueryPlanner {
35    /// Create a new query planner.
36    #[must_use]
37    pub fn new() -> Self {
38        let mut planner = Self::default();
39        planner.init_patterns();
40        planner
41    }
42
43    fn init_patterns(&mut self) {
44        // Location patterns
45        self.entity_patterns.push(EntityPattern {
46            keywords: vec![
47                "who lives in",
48                "people in",
49                "users in",
50                "from",
51                "located in",
52                "based in",
53            ],
54            slot: "location",
55            needs_value: true,
56        });
57
58        // Employer/workplace patterns
59        self.entity_patterns.push(EntityPattern {
60            keywords: vec![
61                "who works at",
62                "employees of",
63                "people at",
64                "works for",
65                "employed by",
66            ],
67            slot: "workplace", // OpenAI enrichment uses "workplace" not "employer"
68            needs_value: true,
69        });
70
71        // Preference patterns
72        self.entity_patterns.push(EntityPattern {
73            keywords: vec![
74                "who likes",
75                "who loves",
76                "fans of",
77                "people who like",
78                "people who love",
79            ],
80            slot: "preference",
81            needs_value: true,
82        });
83
84        // Entity state patterns
85        self.entity_patterns.push(EntityPattern {
86            keywords: vec![
87                "what is",
88                "where does",
89                "who is",
90                "what does",
91            ],
92            slot: "",
93            needs_value: false,
94        });
95    }
96
97    /// Analyze a query and produce an execution plan.
98    #[must_use]
99    pub fn plan(&self, query: &str, top_k: usize) -> QueryPlan {
100        let query_lower = query.to_lowercase();
101
102        // Try to detect relational patterns
103        if let Some(pattern) = self.detect_pattern(&query_lower, query) {
104            if pattern.triples.is_empty() {
105                // No specific pattern found, use vector search
106                QueryPlan::vector_only(Some(query.to_string()), None, top_k)
107            } else {
108                // Found relational pattern - use hybrid search
109                QueryPlan::hybrid(pattern, Some(query.to_string()), None, top_k)
110            }
111        } else {
112            // Default to vector-only search
113            QueryPlan::vector_only(Some(query.to_string()), None, top_k)
114        }
115    }
116
117    fn detect_pattern(&self, query_lower: &str, _original: &str) -> Option<GraphPattern> {
118        let mut pattern = GraphPattern::new();
119
120        for ep in &self.entity_patterns {
121            for keyword in &ep.keywords {
122                if query_lower.contains(keyword) {
123                    // Extract the value after the keyword
124                    if let Some(pos) = query_lower.find(keyword) {
125                        let after = &query_lower[pos + keyword.len()..];
126                        let value = extract_value(after);
127
128                        if !value.is_empty() && ep.needs_value {
129                            // Create pattern: ?entity :slot "value"
130                            pattern.add(TriplePattern::any_slot_value("entity", ep.slot, &value));
131                            return Some(pattern);
132                        }
133                    }
134                }
135            }
136        }
137
138        // Check for entity-specific queries like "alice's employer" or "what is alice's job"
139        if let Some((entity, slot)) = extract_possessive_query(query_lower) {
140            pattern.add(TriplePattern::entity_slot_any(&entity, &slot, "value"));
141            return Some(pattern);
142        }
143
144        Some(pattern)
145    }
146}
147
/// Extract the value phrase that follows a keyword, e.g. `"San Francisco"` from
/// `" San Francisco and ..."`.
///
/// Words are taken left to right with surrounding punctuation stripped (hyphens
/// are kept). Extraction stops:
/// * at a stop word (`and`, `or`, `who`, `what`, `that`; case-insensitive), which is excluded;
/// * after the first word carrying a `?`, which ends the clause;
/// * once three words have been collected.
///
/// Returns an empty string when no word qualifies.
fn extract_value(text: &str) -> String {
    const STOP_WORDS: [&str; 5] = ["and", "or", "who", "what", "that"];
    const MAX_WORDS: usize = 3;

    let mut words: Vec<&str> = Vec::with_capacity(MAX_WORDS);
    for raw in text.split_whitespace() {
        let word = raw.trim_matches(|c: char| !c.is_alphanumeric() && c != '-');
        if STOP_WORDS.iter().any(|stop| word.eq_ignore_ascii_case(stop)) {
            break;
        }
        if !word.is_empty() {
            words.push(word);
        }
        // A question mark ends the clause: keep the word it is attached to, then
        // stop. (Checked on the raw word because stripping removes the '?'.)
        if raw.contains('?') || words.len() >= MAX_WORDS {
            break;
        }
    }

    words.join(" ")
}
171
/// Extract `(entity, slot)` from a possessive query like "alice's employer".
///
/// Every `'s ` occurrence is examined in order:
/// * both the ASCII (`'`) and typographic (`’`) apostrophe are accepted;
/// * surrounding punctuation is stripped from the entity and slot words, so
///   "what is alice's job?" yields `("alice", "workplace")`;
/// * contractions such as "what's" or "it's" are skipped rather than treated
///   as possessives.
///
/// Common slot aliases are mapped onto canonical slot names (e.g. "job" ->
/// "workplace"); other slot words are returned as-is. Returns `None` when no
/// usable possessive is found.
fn extract_possessive_query(query: &str) -> Option<(String, String)> {
    // Words whose "'s" is a contraction ("what's" = "what is"), not a possessive.
    const CONTRACTION_HOSTS: &[&str] = &[
        "he", "here", "how", "it", "let", "she", "that", "there", "what", "when", "where",
        "who", "why",
    ];

    // Strip surrounding punctuation, keeping hyphens (as `extract_value` does).
    fn trim_punct(word: &str) -> &str {
        word.trim_matches(|c: char| !c.is_alphanumeric() && c != '-')
    }

    // Accept the typographic apostrophe (U+2019) as well as the ASCII one.
    let normalized = query.replace('\u{2019}', "'");

    // Pattern: "X's Y" or "X's Y is"
    for (pos, marker) in normalized.match_indices("'s ") {
        let entity = match normalized[..pos].split_whitespace().last() {
            Some(word) => trim_punct(word),
            None => continue,
        };
        if entity.is_empty() || CONTRACTION_HOSTS.contains(&entity) {
            continue;
        }

        let slot = match normalized[pos + marker.len()..].split_whitespace().next() {
            Some(word) => trim_punct(word),
            None => continue,
        };
        if slot.is_empty() {
            continue;
        }

        // Map common slot aliases
        let slot = match slot {
            "job" | "work" | "employer" | "role" | "company" => "workplace",
            "home" | "city" | "address" => "location",
            "favorite" => "preference",
            "wife" | "husband" | "spouse" | "partner" => "spouse",
            other => other,
        };

        return Some((entity.to_string(), slot.to_string()));
    }

    None
}
193
194/// Graph matcher that executes patterns against MemoryCards.
195pub struct GraphMatcher<'a> {
196    memvid: &'a Memvid,
197}
198
199impl<'a> GraphMatcher<'a> {
200    /// Create a new graph matcher.
201    pub fn new(memvid: &'a Memvid) -> Self {
202        Self { memvid }
203    }
204
205    /// Execute a graph pattern and return matching results.
206    pub fn execute(&self, pattern: &GraphPattern) -> Vec<GraphMatchResult> {
207        let mut results = Vec::new();
208
209        for triple in &pattern.triples {
210            let matches = self.match_triple(triple);
211            results.extend(matches);
212        }
213
214        // Deduplicate by entity
215        let mut seen = HashSet::new();
216        results.retain(|r| seen.insert(r.entity.clone()));
217
218        results
219    }
220
221    fn match_triple(&self, triple: &TriplePattern) -> Vec<GraphMatchResult> {
222        let mut results = Vec::new();
223
224        match (&triple.subject, &triple.predicate, &triple.object) {
225            // Pattern: ?entity :slot "value" - find entities with this slot value
226            (PatternTerm::Variable(var), PatternTerm::Literal(slot), PatternTerm::Literal(value)) => {
227                // Iterate all entities and check for matching slot value
228                for entity in self.memvid.memory_entities() {
229                    let cards = self.memvid.get_entity_memories(&entity);
230                    for card in cards {
231                        if card.slot.to_lowercase() == *slot
232                            && card.value.to_lowercase().contains(&value.to_lowercase())
233                        {
234                            let mut result =
235                                GraphMatchResult::new(entity.clone(), vec![card.source_frame_id], 1.0);
236                            result.bind(var, entity.clone());
237                            results.push(result);
238                            break; // One match per entity
239                        }
240                    }
241                }
242            }
243
244            // Pattern: "entity" :slot ?value - get entity's slot value
245            (PatternTerm::Literal(entity), PatternTerm::Literal(slot), PatternTerm::Variable(var)) => {
246                if let Some(card) = self.memvid.get_current_memory(entity, slot) {
247                    let mut result =
248                        GraphMatchResult::new(entity.clone(), vec![card.source_frame_id], 1.0);
249                    result.bind(var, card.value.clone());
250                    results.push(result);
251                }
252            }
253
254            // Pattern: "entity" :slot "value" - check if entity has this exact value
255            (PatternTerm::Literal(entity), PatternTerm::Literal(slot), PatternTerm::Literal(value)) => {
256                if let Some(card) = self.memvid.get_current_memory(entity, slot) {
257                    if card.value.to_lowercase().contains(&value.to_lowercase()) {
258                        let result =
259                            GraphMatchResult::new(entity.clone(), vec![card.source_frame_id], 1.0);
260                        results.push(result);
261                    }
262                }
263            }
264
265            _ => {
266                // Other patterns not yet implemented
267            }
268        }
269
270        results
271    }
272
273    /// Get frame IDs from graph matches for use in vector search filtering.
274    #[must_use]
275    pub fn get_candidate_frames(&self, matches: &[GraphMatchResult]) -> Vec<FrameId> {
276        let mut frame_ids: Vec<FrameId> = matches
277            .iter()
278            .flat_map(|m| m.frame_ids.iter().copied())
279            .collect();
280        frame_ids.sort_unstable();
281        frame_ids.dedup();
282        frame_ids
283    }
284
285    /// Get matched entities for context.
286    #[must_use]
287    pub fn get_matched_entities(&self, matches: &[GraphMatchResult]) -> HashMap<FrameId, String> {
288        let mut map = HashMap::new();
289        for m in matches {
290            for &fid in &m.frame_ids {
291                map.insert(fid, m.entity.clone());
292            }
293        }
294        map
295    }
296}
297
298/// Execute a hybrid search: graph filter + vector ranking.
299pub fn hybrid_search(
300    memvid: &mut Memvid,
301    plan: &QueryPlan,
302) -> Result<Vec<HybridSearchHit>> {
303    match plan {
304        QueryPlan::VectorOnly { query_text, top_k, .. } => {
305            // Fall back to regular lexical search
306            let query = query_text.as_deref().unwrap_or("");
307            let request = SearchRequest {
308                query: query.to_string(),
309                top_k: *top_k,
310                snippet_chars: 200,
311                uri: None,
312                scope: None,
313                cursor: None,
314                #[cfg(feature = "temporal_track")]
315                temporal: None,
316                as_of_frame: None,
317                as_of_ts: None,
318            };
319            let response = memvid.search(request)?;
320            Ok(response.hits
321                .iter()
322                .map(|h| {
323                    let score = h.score.unwrap_or(0.0);
324                    HybridSearchHit {
325                        frame_id: h.frame_id,
326                        score,
327                        graph_score: 0.0,
328                        vector_score: score,
329                        matched_entity: None,
330                        preview: Some(h.text.clone()),
331                    }
332                })
333                .collect())
334        }
335
336        QueryPlan::GraphOnly { pattern, limit } => {
337            let matcher = GraphMatcher::new(memvid);
338            let matches = matcher.execute(pattern);
339
340            Ok(matches
341                .into_iter()
342                .take(*limit)
343                .map(|m| HybridSearchHit {
344                    frame_id: m.frame_ids.first().copied().unwrap_or(0),
345                    score: m.confidence,
346                    graph_score: m.confidence,
347                    vector_score: 0.0,
348                    matched_entity: Some(m.entity),
349                    preview: None,
350                })
351                .collect())
352        }
353
354        QueryPlan::Hybrid {
355            graph_filter,
356            query_text,
357            top_k,
358            ..
359        } => {
360            // Step 1: Execute graph pattern to get candidate frames
361            let matcher = GraphMatcher::new(memvid);
362            let matches = matcher.execute(graph_filter);
363            let entity_map = matcher.get_matched_entities(&matches);
364            let candidate_frames = matcher.get_candidate_frames(&matches);
365
366            if candidate_frames.is_empty() {
367                // No graph matches - fall back to lexical search
368                let query = query_text.as_deref().unwrap_or("");
369                let request = SearchRequest {
370                    query: query.to_string(),
371                    top_k: *top_k,
372                    snippet_chars: 200,
373                    uri: None,
374                    scope: None,
375                    cursor: None,
376                    #[cfg(feature = "temporal_track")]
377                    temporal: None,
378                    as_of_frame: None,
379                    as_of_ts: None,
380                };
381                let response = memvid.search(request)?;
382                return Ok(response.hits
383                    .iter()
384                    .map(|h| {
385                        let score = h.score.unwrap_or(0.0);
386                        HybridSearchHit {
387                            frame_id: h.frame_id,
388                            score,
389                            graph_score: 0.0,
390                            vector_score: score,
391                            matched_entity: None,
392                            preview: Some(h.text.clone()),
393                        }
394                    })
395                    .collect());
396            }
397
398            // Step 2: Return graph matches directly with frame previews
399            // Graph matching already found the relevant frames - return them
400            let mut hybrid_hits: Vec<HybridSearchHit> = Vec::new();
401
402            for &frame_id in &candidate_frames {
403                let matched_entity = entity_map.get(&frame_id).cloned();
404
405                // Get frame preview if possible
406                let preview = memvid.frame_preview_by_id(frame_id).ok();
407
408                hybrid_hits.push(HybridSearchHit {
409                    frame_id,
410                    score: 1.0, // Graph match score
411                    graph_score: 1.0,
412                    vector_score: 0.0,
413                    matched_entity,
414                    preview,
415                });
416            }
417
418            Ok(hybrid_hits.into_iter().take(*top_k).collect())
419        }
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Return the graph filter of a `Hybrid` plan, or panic with `msg`.
    fn expect_hybrid(plan: QueryPlan, msg: &str) -> GraphPattern {
        match plan {
            QueryPlan::Hybrid { graph_filter, .. } => graph_filter,
            _ => panic!("{}", msg),
        }
    }

    #[test]
    fn test_query_planner_detects_location() {
        let plan = QueryPlanner::new().plan("who lives in San Francisco", 10);
        let filter = expect_hybrid(plan, "Expected hybrid plan for location query");

        assert!(!filter.is_empty());
        assert!(matches!(
            &filter.triples[0].predicate,
            PatternTerm::Literal(s) if s == "location"
        ));
    }

    #[test]
    fn test_query_planner_detects_workplace() {
        let plan = QueryPlanner::new().plan("who works at Google", 10);
        let filter = expect_hybrid(plan, "Expected hybrid plan for workplace query");

        assert!(!filter.is_empty());
        assert!(matches!(
            &filter.triples[0].predicate,
            PatternTerm::Literal(s) if s == "workplace"
        ));
    }

    #[test]
    fn test_query_planner_possessive() {
        let plan = QueryPlanner::new().plan("what is alice's employer", 10);
        let filter = expect_hybrid(plan, "Expected hybrid plan for possessive query");

        assert!(!filter.is_empty());
        assert!(matches!(
            &filter.triples[0].subject,
            PatternTerm::Literal(s) if s == "alice"
        ));
    }

    #[test]
    fn test_extract_value() {
        let cases = [
            ("San Francisco and", "San Francisco"),
            ("Google who", "Google"),
            ("New York City", "New York City"),
        ];
        for (input, expected) in &cases {
            assert_eq!(extract_value(input), *expected);
        }
    }

    #[test]
    fn test_extract_possessive() {
        let expected = |entity: &str, slot: &str| Some((entity.to_string(), slot.to_string()));

        assert_eq!(
            extract_possessive_query("what is alice's job"),
            expected("alice", "workplace")
        );
        assert_eq!(
            extract_possessive_query("bob's location"),
            expected("bob", "location")
        );
    }
}