1use std::collections::{HashMap, HashSet};
9
10use crate::types::{
11 GraphMatchResult, GraphPattern, HybridSearchHit, PatternTerm, QueryPlan, SearchRequest,
12 TriplePattern,
13};
14use crate::{FrameId, Memvid, Result};
15
16#[derive(Debug, Default)]
18pub struct QueryPlanner {
19 entity_patterns: Vec<EntityPattern>,
21}
22
/// A family of trigger phrases that all map onto the same memory slot.
#[derive(Clone, Debug)]
struct EntityPattern {
    /// Trigger phrases, matched against the lowercased query.
    keywords: Vec<&'static str>,
    /// Memory slot that the extracted value is matched against (empty when unused).
    slot: &'static str,
    /// Whether the text after the keyword must be extracted as a value; patterns with
    /// `false` here never produce a triple on their own.
    needs_value: bool,
}
33
34impl QueryPlanner {
35 #[must_use]
37 pub fn new() -> Self {
38 let mut planner = Self::default();
39 planner.init_patterns();
40 planner
41 }
42
43 fn init_patterns(&mut self) {
44 self.entity_patterns.push(EntityPattern {
46 keywords: vec![
47 "who lives in",
48 "people in",
49 "users in",
50 "from",
51 "located in",
52 "based in",
53 ],
54 slot: "location",
55 needs_value: true,
56 });
57
58 self.entity_patterns.push(EntityPattern {
60 keywords: vec![
61 "who works at",
62 "employees of",
63 "people at",
64 "works for",
65 "employed by",
66 ],
67 slot: "workplace", needs_value: true,
69 });
70
71 self.entity_patterns.push(EntityPattern {
73 keywords: vec![
74 "who likes",
75 "who loves",
76 "fans of",
77 "people who like",
78 "people who love",
79 ],
80 slot: "preference",
81 needs_value: true,
82 });
83
84 self.entity_patterns.push(EntityPattern {
86 keywords: vec!["what is", "where does", "who is", "what does"],
87 slot: "",
88 needs_value: false,
89 });
90 }
91
92 #[must_use]
94 pub fn plan(&self, query: &str, top_k: usize) -> QueryPlan {
95 let query_lower = query.to_lowercase();
96
97 if let Some(pattern) = self.detect_pattern(&query_lower, query) {
99 if pattern.triples.is_empty() {
100 QueryPlan::vector_only(Some(query.to_string()), None, top_k)
102 } else {
103 QueryPlan::hybrid(pattern, Some(query.to_string()), None, top_k)
105 }
106 } else {
107 QueryPlan::vector_only(Some(query.to_string()), None, top_k)
109 }
110 }
111
112 fn detect_pattern(&self, query_lower: &str, _original: &str) -> Option<GraphPattern> {
113 let mut pattern = GraphPattern::new();
114
115 for ep in &self.entity_patterns {
116 for keyword in &ep.keywords {
117 if query_lower.contains(keyword) {
118 if let Some(pos) = query_lower.find(keyword) {
120 let after = &query_lower[pos + keyword.len()..];
121 let value = extract_value(after);
122
123 if !value.is_empty() && ep.needs_value {
124 pattern.add(TriplePattern::any_slot_value("entity", ep.slot, &value));
126 return Some(pattern);
127 }
128 }
129 }
130 }
131 }
132
133 if let Some((entity, slot)) = extract_possessive_query(query_lower) {
135 pattern.add(TriplePattern::entity_slot_any(&entity, &slot, "value"));
136 return Some(pattern);
137 }
138
139 Some(pattern)
140 }
141}
142
/// Pulls the value phrase that follows a trigger keyword, e.g. `"san francisco"` out of
/// `" san francisco and ..."`.
///
/// Words are split on whitespace. Characters that are neither alphanumeric nor `-` are
/// trimmed from each word's edges, and words left empty are skipped. Collection stops at
/// the first stop word (compared case-insensitively) or after three words, whichever
/// comes first. The surviving words keep their original case and are joined by single
/// spaces.
fn extract_value(text: &str) -> String {
    // NOTE: the "?" entry can never match, because '?' is trimmed from every word first.
    const STOP_WORDS: [&str; 6] = ["and", "or", "who", "what", "that", "?"];
    const MAX_WORDS: usize = 3;

    let kept: Vec<&str> = text
        .split_whitespace()
        .map(|raw| raw.trim_matches(|c: char| !c.is_alphanumeric() && c != '-'))
        .take_while(|word| !STOP_WORDS.contains(&word.to_lowercase().as_str()))
        .filter(|word| !word.is_empty())
        .take(MAX_WORDS)
        .collect();

    kept.join(" ")
}
166
/// Parses a possessive question such as `"what is alice's job"` into an `(entity, slot)`
/// pair, here `("alice", "workplace")`.
///
/// For each `'s ` in `query`, the word just before it is taken as the entity and the word
/// just after it as the slot. Edge punctuation (anything that is neither alphanumeric nor
/// `-`) is trimmed from both words, so `"alice's job?"` is understood like `"alice's job"`.
/// Occurrences whose "entity" is a question-word contraction (`what's`, `who's`, ...) are
/// skipped, so `"what's alice's job"` resolves to `alice`. Common synonyms are mapped to
/// canonical slot names; any other slot word is returned unchanged.
///
/// Returns `None` when no occurrence yields a non-empty entity and slot.
fn extract_possessive_query(query: &str) -> Option<(String, String)> {
    /// Strips punctuation from the edges of a word, keeping interior characters.
    fn trim_word(word: &str) -> &str {
        word.trim_matches(|c: char| !c.is_alphanumeric() && c != '-')
    }

    /// Maps slot synonyms onto the canonical slot names used by memory cards.
    fn canonical_slot(word: &str) -> &str {
        match word {
            "job" | "work" | "employer" | "role" | "company" => "workplace",
            "home" | "city" | "address" => "location",
            "favorite" => "preference",
            "wife" | "husband" | "spouse" | "partner" => "spouse",
            other => other,
        }
    }

    const MARKER: &str = "'s ";
    // Words whose "'s" is a contraction of "is"/"has", not a possessive.
    const CONTRACTIONS: [&str; 11] = [
        "what", "who", "where", "when", "how", "that", "it", "there", "here", "he", "she",
    ];

    query.match_indices(MARKER).find_map(|(pos, _)| {
        let entity = trim_word(query[..pos].split_whitespace().last()?);
        if entity.is_empty() || CONTRACTIONS.iter().any(|c| c.eq_ignore_ascii_case(entity)) {
            return None;
        }

        let slot = trim_word(query[pos + MARKER.len()..].split_whitespace().next()?);
        if slot.is_empty() {
            return None;
        }

        Some((entity.to_string(), canonical_slot(slot).to_string()))
    })
}
188
189pub struct GraphMatcher<'a> {
191 memvid: &'a Memvid,
192}
193
194impl<'a> GraphMatcher<'a> {
195 #[must_use]
197 pub fn new(memvid: &'a Memvid) -> Self {
198 Self { memvid }
199 }
200
201 #[must_use]
203 pub fn execute(&self, pattern: &GraphPattern) -> Vec<GraphMatchResult> {
204 let mut results = Vec::new();
205
206 for triple in &pattern.triples {
207 let matches = self.match_triple(triple);
208 results.extend(matches);
209 }
210
211 let mut seen = HashSet::new();
213 results.retain(|r| seen.insert(r.entity.clone()));
214
215 results
216 }
217
218 fn match_triple(&self, triple: &TriplePattern) -> Vec<GraphMatchResult> {
219 let mut results = Vec::new();
220
221 match (&triple.subject, &triple.predicate, &triple.object) {
222 (
224 PatternTerm::Variable(var),
225 PatternTerm::Literal(slot),
226 PatternTerm::Literal(value),
227 ) => {
228 for entity in self.memvid.memory_entities() {
230 let cards = self.memvid.get_entity_memories(&entity);
231 for card in cards {
232 if card.slot.to_lowercase() == *slot
233 && card.value.to_lowercase().contains(&value.to_lowercase())
234 {
235 let mut result = GraphMatchResult::new(
236 entity.clone(),
237 vec![card.source_frame_id],
238 1.0,
239 );
240 result.bind(var, entity.clone());
241 results.push(result);
242 break; }
244 }
245 }
246 }
247
248 (
250 PatternTerm::Literal(entity),
251 PatternTerm::Literal(slot),
252 PatternTerm::Variable(var),
253 ) => {
254 if let Some(card) = self.memvid.get_current_memory(entity, slot) {
255 let mut result =
256 GraphMatchResult::new(entity.clone(), vec![card.source_frame_id], 1.0);
257 result.bind(var, card.value.clone());
258 results.push(result);
259 }
260 }
261
262 (
264 PatternTerm::Literal(entity),
265 PatternTerm::Literal(slot),
266 PatternTerm::Literal(value),
267 ) => {
268 if let Some(card) = self.memvid.get_current_memory(entity, slot) {
269 if card.value.to_lowercase().contains(&value.to_lowercase()) {
270 let result =
271 GraphMatchResult::new(entity.clone(), vec![card.source_frame_id], 1.0);
272 results.push(result);
273 }
274 }
275 }
276
277 _ => {
278 }
280 }
281
282 results
283 }
284
285 #[must_use]
287 pub fn get_candidate_frames(&self, matches: &[GraphMatchResult]) -> Vec<FrameId> {
288 let mut frame_ids: Vec<FrameId> = matches
289 .iter()
290 .flat_map(|m| m.frame_ids.iter().copied())
291 .collect();
292 frame_ids.sort_unstable();
293 frame_ids.dedup();
294 frame_ids
295 }
296
297 #[must_use]
299 pub fn get_matched_entities(&self, matches: &[GraphMatchResult]) -> HashMap<FrameId, String> {
300 let mut map = HashMap::new();
301 for m in matches {
302 for &fid in &m.frame_ids {
303 map.insert(fid, m.entity.clone());
304 }
305 }
306 map
307 }
308}
309
310pub fn hybrid_search(memvid: &mut Memvid, plan: &QueryPlan) -> Result<Vec<HybridSearchHit>> {
312 match plan {
313 QueryPlan::VectorOnly {
314 query_text, top_k, ..
315 } => {
316 let query = query_text.as_deref().unwrap_or("");
318 let request = SearchRequest {
319 query: query.to_string(),
320 top_k: *top_k,
321 snippet_chars: 200,
322 uri: None,
323 scope: None,
324 cursor: None,
325 #[cfg(feature = "temporal_track")]
326 temporal: None,
327 as_of_frame: None,
328 as_of_ts: None,
329 no_sketch: false,
330 };
331 let response = memvid.search(request)?;
332 Ok(response
333 .hits
334 .iter()
335 .map(|h| {
336 let score = h.score.unwrap_or(0.0);
337 HybridSearchHit {
338 frame_id: h.frame_id,
339 score,
340 graph_score: 0.0,
341 vector_score: score,
342 matched_entity: None,
343 preview: Some(h.text.clone()),
344 }
345 })
346 .collect())
347 }
348
349 QueryPlan::GraphOnly { pattern, limit } => {
350 let matcher = GraphMatcher::new(memvid);
351 let matches = matcher.execute(pattern);
352
353 Ok(matches
354 .into_iter()
355 .take(*limit)
356 .map(|m| HybridSearchHit {
357 frame_id: m.frame_ids.first().copied().unwrap_or(0),
358 score: m.confidence,
359 graph_score: m.confidence,
360 vector_score: 0.0,
361 matched_entity: Some(m.entity),
362 preview: None,
363 })
364 .collect())
365 }
366
367 QueryPlan::Hybrid {
368 graph_filter,
369 query_text,
370 top_k,
371 ..
372 } => {
373 let matcher = GraphMatcher::new(memvid);
375 let matches = matcher.execute(graph_filter);
376 let entity_map = matcher.get_matched_entities(&matches);
377 let candidate_frames = matcher.get_candidate_frames(&matches);
378
379 if candidate_frames.is_empty() {
380 let query = query_text.as_deref().unwrap_or("");
382 let request = SearchRequest {
383 query: query.to_string(),
384 top_k: *top_k,
385 snippet_chars: 200,
386 uri: None,
387 scope: None,
388 cursor: None,
389 #[cfg(feature = "temporal_track")]
390 temporal: None,
391 as_of_frame: None,
392 as_of_ts: None,
393 no_sketch: false,
394 };
395 let response = memvid.search(request)?;
396 return Ok(response
397 .hits
398 .iter()
399 .map(|h| {
400 let score = h.score.unwrap_or(0.0);
401 HybridSearchHit {
402 frame_id: h.frame_id,
403 score,
404 graph_score: 0.0,
405 vector_score: score,
406 matched_entity: None,
407 preview: Some(h.text.clone()),
408 }
409 })
410 .collect());
411 }
412
413 let mut hybrid_hits: Vec<HybridSearchHit> = Vec::new();
416
417 for &frame_id in &candidate_frames {
418 let matched_entity = entity_map.get(&frame_id).cloned();
419
420 let preview = memvid.frame_preview_by_id(frame_id).ok();
422
423 hybrid_hits.push(HybridSearchHit {
424 frame_id,
425 score: 1.0, graph_score: 1.0,
427 vector_score: 0.0,
428 matched_entity,
429 preview,
430 });
431 }
432
433 Ok(hybrid_hits.into_iter().take(*top_k).collect())
434 }
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_planner_detects_location() {
        let plan = QueryPlanner::new().plan("who lives in San Francisco", 10);
        if let QueryPlan::Hybrid { graph_filter, .. } = plan {
            assert!(!graph_filter.is_empty());
            let first = &graph_filter.triples[0];
            assert!(matches!(&first.predicate, PatternTerm::Literal(s) if s == "location"));
        } else {
            panic!("Expected hybrid plan for location query");
        }
    }

    #[test]
    fn test_query_planner_detects_workplace() {
        let plan = QueryPlanner::new().plan("who works at Google", 10);
        if let QueryPlan::Hybrid { graph_filter, .. } = plan {
            assert!(!graph_filter.is_empty());
            let first = &graph_filter.triples[0];
            assert!(matches!(&first.predicate, PatternTerm::Literal(s) if s == "workplace"));
        } else {
            panic!("Expected hybrid plan for workplace query");
        }
    }

    #[test]
    fn test_query_planner_possessive() {
        let plan = QueryPlanner::new().plan("what is alice's employer", 10);
        if let QueryPlan::Hybrid { graph_filter, .. } = plan {
            assert!(!graph_filter.is_empty());
            let first = &graph_filter.triples[0];
            assert!(matches!(&first.subject, PatternTerm::Literal(s) if s == "alice"));
        } else {
            panic!("Expected hybrid plan for possessive query");
        }
    }

    #[test]
    fn test_extract_value() {
        let cases = [
            ("San Francisco and", "San Francisco"),
            ("Google who", "Google"),
            ("New York City", "New York City"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(extract_value(input), expected);
        }
    }

    #[test]
    fn test_extract_possessive() {
        let cases = [
            ("what is alice's job", ("alice", "workplace")),
            ("bob's location", ("bob", "location")),
        ];
        for &(query, (entity, slot)) in &cases {
            assert_eq!(
                extract_possessive_query(query),
                Some((entity.to_string(), slot.to_string()))
            );
        }
    }
}