Skip to main content

memvid_core/
graph_search.rs

1//! Graph-aware search combining MemoryCards/Logic-Mesh with vector search.
2//!
3//! This module provides hybrid retrieval that can:
4//! 1. Parse natural language queries for relational patterns
5//! 2. Match patterns against entity state (`MemoryCards`) or graph (Logic-Mesh)
6//! 3. Combine graph-filtered candidates with vector ranking
7
8use std::collections::{HashMap, HashSet};
9
10use crate::types::{
11    GraphMatchResult, GraphPattern, HybridSearchHit, PatternTerm, QueryPlan, SearchRequest,
12    TriplePattern,
13};
14use crate::{FrameId, Memvid, Result};
15
16/// Query planner that analyzes queries and creates execution plans.
17#[derive(Debug, Default)]
18pub struct QueryPlanner {
19    /// Patterns for detecting relational queries
20    entity_patterns: Vec<EntityPattern>,
21}
22
/// One keyword rule used to recognise an entity-related query.
#[derive(Clone, Debug)]
struct EntityPattern {
    /// Phrases whose presence in the query activates this rule.
    keywords: Vec<&'static str>,
    /// Name of the slot the rule queries.
    slot: &'static str,
    /// `true` when the rule only applies if a value follows the keyword.
    needs_value: bool,
}
33
34impl QueryPlanner {
35    /// Create a new query planner.
36    #[must_use]
37    pub fn new() -> Self {
38        let mut planner = Self::default();
39        planner.init_patterns();
40        planner
41    }
42
43    fn init_patterns(&mut self) {
44        // Location patterns
45        self.entity_patterns.push(EntityPattern {
46            keywords: vec![
47                "who lives in",
48                "people in",
49                "users in",
50                "from",
51                "located in",
52                "based in",
53            ],
54            slot: "location",
55            needs_value: true,
56        });
57
58        // Employer/workplace patterns
59        self.entity_patterns.push(EntityPattern {
60            keywords: vec![
61                "who works at",
62                "employees of",
63                "people at",
64                "works for",
65                "employed by",
66            ],
67            slot: "workplace", // OpenAI enrichment uses "workplace" not "employer"
68            needs_value: true,
69        });
70
71        // Preference patterns
72        self.entity_patterns.push(EntityPattern {
73            keywords: vec![
74                "who likes",
75                "who loves",
76                "fans of",
77                "people who like",
78                "people who love",
79            ],
80            slot: "preference",
81            needs_value: true,
82        });
83
84        // Entity state patterns
85        self.entity_patterns.push(EntityPattern {
86            keywords: vec!["what is", "where does", "who is", "what does"],
87            slot: "",
88            needs_value: false,
89        });
90    }
91
92    /// Analyze a query and produce an execution plan.
93    #[must_use]
94    pub fn plan(&self, query: &str, top_k: usize) -> QueryPlan {
95        let query_lower = query.to_lowercase();
96
97        // Try to detect relational patterns
98        if let Some(pattern) = self.detect_pattern(&query_lower, query) {
99            if pattern.triples.is_empty() {
100                // No specific pattern found, use vector search
101                QueryPlan::vector_only(Some(query.to_string()), None, top_k)
102            } else {
103                // Found relational pattern - use hybrid search
104                QueryPlan::hybrid(pattern, Some(query.to_string()), None, top_k)
105            }
106        } else {
107            // Default to vector-only search
108            QueryPlan::vector_only(Some(query.to_string()), None, top_k)
109        }
110    }
111
112    fn detect_pattern(&self, query_lower: &str, _original: &str) -> Option<GraphPattern> {
113        let mut pattern = GraphPattern::new();
114
115        for ep in &self.entity_patterns {
116            for keyword in &ep.keywords {
117                if query_lower.contains(keyword) {
118                    // Extract the value after the keyword
119                    if let Some(pos) = query_lower.find(keyword) {
120                        let after = &query_lower[pos + keyword.len()..];
121                        let value = extract_value(after);
122
123                        if !value.is_empty() && ep.needs_value {
124                            // Create pattern: ?entity :slot "value"
125                            pattern.add(TriplePattern::any_slot_value("entity", ep.slot, &value));
126                            return Some(pattern);
127                        }
128                    }
129                }
130            }
131        }
132
133        // Check for entity-specific queries like "alice's employer" or "what is alice's job"
134        if let Some((entity, slot)) = extract_possessive_query(query_lower) {
135            pattern.add(TriplePattern::entity_slot_any(&entity, &slot, "value"));
136            return Some(pattern);
137        }
138
139        Some(pattern)
140    }
141}
142
/// Extract a value from text after a keyword.
///
/// Collects up to three whitespace-separated words, stripping surrounding
/// punctuation (hyphens are kept). Collection stops early at a stop word
/// ("and", "or", "who", "what", "that", compared case-insensitively) or at
/// the first question mark, which ends the question. Returns the collected
/// words joined by single spaces, or an empty string when there are none.
fn extract_value(text: &str) -> String {
    const STOP_WORDS: [&str; 5] = ["and", "or", "who", "what", "that"];
    const MAX_WORDS: usize = 3;

    let mut words: Vec<&str> = Vec::with_capacity(MAX_WORDS);

    for word in text.split_whitespace() {
        // Anything after a '?' belongs to the next sentence, so only the part
        // before it may contribute to the value.
        let (head, question_ended) = match word.find('?') {
            Some(idx) => (&word[..idx], true),
            None => (word, false),
        };

        let clean = head.trim_matches(|c: char| !c.is_alphanumeric() && c != '-');
        if STOP_WORDS.iter().any(|stop| clean.eq_ignore_ascii_case(stop)) {
            break;
        }
        if !clean.is_empty() {
            words.push(clean);
        }
        // Stop at the end of the question or after a few words
        if question_ended || words.len() >= MAX_WORDS {
            break;
        }
    }

    words.join(" ")
}
166
/// Extract entity and slot from possessive queries like "alice's employer".
///
/// The entity is the word immediately before the first `'s ` and the slot is
/// the word immediately after it. Punctuation around either word is stripped,
/// so "what is alice's job?" resolves the slot `job` rather than `job?`.
/// Common slot aliases are mapped to canonical slot names; unknown slots are
/// passed through unchanged. Returns `None` when there is no `'s ` in the
/// query or when the entity or slot is missing.
fn extract_possessive_query(query: &str) -> Option<(String, String)> {
    /// Remove leading/trailing punctuation while keeping `-` and `_`.
    fn strip_punctuation(word: &str) -> &str {
        word.trim_matches(|c: char| !c.is_alphanumeric() && c != '-' && c != '_')
    }

    // Pattern: "X's Y" or "X's Y is"
    let pos = query.find("'s ")?;
    let entity = strip_punctuation(query[..pos].split_whitespace().last()?);
    let slot = strip_punctuation(query[pos + 3..].split_whitespace().next()?);

    if entity.is_empty() || slot.is_empty() {
        return None;
    }

    // Map common slot aliases
    let slot = match slot {
        "job" | "work" | "employer" | "role" | "company" => "workplace",
        "home" | "city" | "address" => "location",
        "favorite" => "preference",
        "wife" | "husband" | "spouse" | "partner" => "spouse",
        other => other,
    };

    Some((entity.to_string(), slot.to_string()))
}
188
189/// Graph matcher that executes patterns against `MemoryCards`.
190pub struct GraphMatcher<'a> {
191    memvid: &'a Memvid,
192}
193
194impl<'a> GraphMatcher<'a> {
195    /// Create a new graph matcher.
196    #[must_use]
197    pub fn new(memvid: &'a Memvid) -> Self {
198        Self { memvid }
199    }
200
201    /// Execute a graph pattern and return matching results.
202    #[must_use]
203    pub fn execute(&self, pattern: &GraphPattern) -> Vec<GraphMatchResult> {
204        let mut results = Vec::new();
205
206        for triple in &pattern.triples {
207            let matches = self.match_triple(triple);
208            results.extend(matches);
209        }
210
211        // Deduplicate by entity
212        let mut seen = HashSet::new();
213        results.retain(|r| seen.insert(r.entity.clone()));
214
215        results
216    }
217
218    fn match_triple(&self, triple: &TriplePattern) -> Vec<GraphMatchResult> {
219        let mut results = Vec::new();
220
221        match (&triple.subject, &triple.predicate, &triple.object) {
222            // Pattern: ?entity :slot "value" - find entities with this slot value
223            (
224                PatternTerm::Variable(var),
225                PatternTerm::Literal(slot),
226                PatternTerm::Literal(value),
227            ) => {
228                // Iterate all entities and check for matching slot value
229                for entity in self.memvid.memory_entities() {
230                    let cards = self.memvid.get_entity_memories(&entity);
231                    for card in cards {
232                        if card.slot.to_lowercase() == *slot
233                            && card.value.to_lowercase().contains(&value.to_lowercase())
234                        {
235                            let mut result = GraphMatchResult::new(
236                                entity.clone(),
237                                vec![card.source_frame_id],
238                                1.0,
239                            );
240                            result.bind(var, entity.clone());
241                            results.push(result);
242                            break; // One match per entity
243                        }
244                    }
245                }
246            }
247
248            // Pattern: "entity" :slot ?value - get entity's slot value
249            (
250                PatternTerm::Literal(entity),
251                PatternTerm::Literal(slot),
252                PatternTerm::Variable(var),
253            ) => {
254                if let Some(card) = self.memvid.get_current_memory(entity, slot) {
255                    let mut result =
256                        GraphMatchResult::new(entity.clone(), vec![card.source_frame_id], 1.0);
257                    result.bind(var, card.value.clone());
258                    results.push(result);
259                }
260            }
261
262            // Pattern: "entity" :slot "value" - check if entity has this exact value
263            (
264                PatternTerm::Literal(entity),
265                PatternTerm::Literal(slot),
266                PatternTerm::Literal(value),
267            ) => {
268                if let Some(card) = self.memvid.get_current_memory(entity, slot) {
269                    if card.value.to_lowercase().contains(&value.to_lowercase()) {
270                        let result =
271                            GraphMatchResult::new(entity.clone(), vec![card.source_frame_id], 1.0);
272                        results.push(result);
273                    }
274                }
275            }
276
277            _ => {
278                // Other patterns not yet implemented
279            }
280        }
281
282        results
283    }
284
285    /// Get frame IDs from graph matches for use in vector search filtering.
286    #[must_use]
287    pub fn get_candidate_frames(&self, matches: &[GraphMatchResult]) -> Vec<FrameId> {
288        let mut frame_ids: Vec<FrameId> = matches
289            .iter()
290            .flat_map(|m| m.frame_ids.iter().copied())
291            .collect();
292        frame_ids.sort_unstable();
293        frame_ids.dedup();
294        frame_ids
295    }
296
297    /// Get matched entities for context.
298    #[must_use]
299    pub fn get_matched_entities(&self, matches: &[GraphMatchResult]) -> HashMap<FrameId, String> {
300        let mut map = HashMap::new();
301        for m in matches {
302            for &fid in &m.frame_ids {
303                map.insert(fid, m.entity.clone());
304            }
305        }
306        map
307    }
308}
309
310/// Execute a hybrid search: graph filter + vector ranking.
311pub fn hybrid_search(memvid: &mut Memvid, plan: &QueryPlan) -> Result<Vec<HybridSearchHit>> {
312    match plan {
313        QueryPlan::VectorOnly {
314            query_text, top_k, ..
315        } => {
316            // Fall back to regular lexical search
317            let query = query_text.as_deref().unwrap_or("");
318            let request = SearchRequest {
319                query: query.to_string(),
320                top_k: *top_k,
321                snippet_chars: 200,
322                uri: None,
323                scope: None,
324                cursor: None,
325                #[cfg(feature = "temporal_track")]
326                temporal: None,
327                as_of_frame: None,
328                as_of_ts: None,
329                no_sketch: false,
330                acl_context: None,
331                acl_enforcement_mode: crate::types::AclEnforcementMode::Audit,
332            };
333            let response = memvid.search(request)?;
334            Ok(response
335                .hits
336                .iter()
337                .map(|h| {
338                    let score = h.score.unwrap_or(0.0);
339                    HybridSearchHit {
340                        frame_id: h.frame_id,
341                        score,
342                        graph_score: 0.0,
343                        vector_score: score,
344                        matched_entity: None,
345                        preview: Some(h.text.clone()),
346                    }
347                })
348                .collect())
349        }
350
351        QueryPlan::GraphOnly { pattern, limit } => {
352            let matcher = GraphMatcher::new(memvid);
353            let matches = matcher.execute(pattern);
354
355            Ok(matches
356                .into_iter()
357                .take(*limit)
358                .map(|m| HybridSearchHit {
359                    frame_id: m.frame_ids.first().copied().unwrap_or(0),
360                    score: m.confidence,
361                    graph_score: m.confidence,
362                    vector_score: 0.0,
363                    matched_entity: Some(m.entity),
364                    preview: None,
365                })
366                .collect())
367        }
368
369        QueryPlan::Hybrid {
370            graph_filter,
371            query_text,
372            top_k,
373            ..
374        } => {
375            // Step 1: Execute graph pattern to get candidate frames
376            let matcher = GraphMatcher::new(memvid);
377            let matches = matcher.execute(graph_filter);
378            let entity_map = matcher.get_matched_entities(&matches);
379            let candidate_frames = matcher.get_candidate_frames(&matches);
380
381            if candidate_frames.is_empty() {
382                // No graph matches - fall back to lexical search
383                let query = query_text.as_deref().unwrap_or("");
384                let request = SearchRequest {
385                    query: query.to_string(),
386                    top_k: *top_k,
387                    snippet_chars: 200,
388                    uri: None,
389                    scope: None,
390                    cursor: None,
391                    #[cfg(feature = "temporal_track")]
392                    temporal: None,
393                    as_of_frame: None,
394                    as_of_ts: None,
395                    no_sketch: false,
396                    acl_context: None,
397                    acl_enforcement_mode: crate::types::AclEnforcementMode::Audit,
398                };
399                let response = memvid.search(request)?;
400                return Ok(response
401                    .hits
402                    .iter()
403                    .map(|h| {
404                        let score = h.score.unwrap_or(0.0);
405                        HybridSearchHit {
406                            frame_id: h.frame_id,
407                            score,
408                            graph_score: 0.0,
409                            vector_score: score,
410                            matched_entity: None,
411                            preview: Some(h.text.clone()),
412                        }
413                    })
414                    .collect());
415            }
416
417            // Step 2: Return graph matches directly with frame previews
418            // Graph matching already found the relevant frames - return them
419            let mut hybrid_hits: Vec<HybridSearchHit> = Vec::new();
420
421            for &frame_id in &candidate_frames {
422                let matched_entity = entity_map.get(&frame_id).cloned();
423
424                // Get frame preview if possible
425                let preview = memvid.frame_preview_by_id(frame_id).ok();
426
427                hybrid_hits.push(HybridSearchHit {
428                    frame_id,
429                    score: 1.0, // Graph match score
430                    graph_score: 1.0,
431                    vector_score: 0.0,
432                    matched_entity,
433                    preview,
434                });
435            }
436
437            Ok(hybrid_hits.into_iter().take(*top_k).collect())
438        }
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// "who lives in ..." must plan as a hybrid query on the `location` slot.
    #[test]
    fn test_query_planner_detects_location() {
        let plan = QueryPlanner::new().plan("who lives in San Francisco", 10);

        if let QueryPlan::Hybrid { graph_filter, .. } = plan {
            assert!(!graph_filter.is_empty());
            let first = &graph_filter.triples[0];
            assert!(matches!(&first.predicate, PatternTerm::Literal(s) if s == "location"));
        } else {
            panic!("Expected hybrid plan for location query");
        }
    }

    /// "who works at ..." must plan as a hybrid query on the `workplace` slot.
    #[test]
    fn test_query_planner_detects_workplace() {
        let plan = QueryPlanner::new().plan("who works at Google", 10);

        if let QueryPlan::Hybrid { graph_filter, .. } = plan {
            assert!(!graph_filter.is_empty());
            let first = &graph_filter.triples[0];
            assert!(matches!(&first.predicate, PatternTerm::Literal(s) if s == "workplace"));
        } else {
            panic!("Expected hybrid plan for workplace query");
        }
    }

    /// Possessive phrasing must bind the entity as the triple's subject.
    #[test]
    fn test_query_planner_possessive() {
        let plan = QueryPlanner::new().plan("what is alice's employer", 10);

        if let QueryPlan::Hybrid { graph_filter, .. } = plan {
            assert!(!graph_filter.is_empty());
            let first = &graph_filter.triples[0];
            assert!(matches!(&first.subject, PatternTerm::Literal(s) if s == "alice"));
        } else {
            panic!("Expected hybrid plan for possessive query");
        }
    }

    /// Values end at stop words and are capped at three words.
    #[test]
    fn test_extract_value() {
        let cases = [
            ("San Francisco and", "San Francisco"),
            ("Google who", "Google"),
            ("New York City", "New York City"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(extract_value(input), *expected);
        }
    }

    /// Possessive queries yield the entity and the canonical slot name.
    #[test]
    fn test_extract_possessive() {
        let alice = extract_possessive_query("what is alice's job");
        assert_eq!(alice, Some(("alice".to_string(), "workplace".to_string())));

        let bob = extract_possessive_query("bob's location");
        assert_eq!(bob, Some(("bob".to_string(), "location".to_string())));
    }
}