1use std::collections::{HashMap, HashSet};
9
10use crate::types::{
11 GraphMatchResult, GraphPattern, HybridSearchHit, PatternTerm, QueryPlan, TriplePattern,
12 SearchRequest,
13};
14use crate::{FrameId, Memvid, Result};
15
/// Heuristic planner that turns a natural-language query into a [`QueryPlan`].
///
/// Note: `QueryPlanner::default()` produces a planner with an *empty* pattern list;
/// call `QueryPlanner::new()` to get the built-in keyword patterns.
#[derive(Default, Debug)]
pub struct QueryPlanner {
    /// Keyword pattern families, consulted in insertion order.
    entity_patterns: Vec<EntityPattern>,
}

/// A family of trigger phrases that all map onto the same memory slot.
#[derive(Clone, Debug)]
struct EntityPattern {
    /// Trigger phrases, written in lower case.
    keywords: Vec<&'static str>,
    /// Memory slot the text following a trigger phrase is matched against.
    slot: &'static str,
    /// Whether a non-empty value must follow the phrase for the pattern to apply.
    needs_value: bool,
}
33
34impl QueryPlanner {
35 #[must_use]
37 pub fn new() -> Self {
38 let mut planner = Self::default();
39 planner.init_patterns();
40 planner
41 }
42
43 fn init_patterns(&mut self) {
44 self.entity_patterns.push(EntityPattern {
46 keywords: vec![
47 "who lives in",
48 "people in",
49 "users in",
50 "from",
51 "located in",
52 "based in",
53 ],
54 slot: "location",
55 needs_value: true,
56 });
57
58 self.entity_patterns.push(EntityPattern {
60 keywords: vec![
61 "who works at",
62 "employees of",
63 "people at",
64 "works for",
65 "employed by",
66 ],
67 slot: "workplace", needs_value: true,
69 });
70
71 self.entity_patterns.push(EntityPattern {
73 keywords: vec![
74 "who likes",
75 "who loves",
76 "fans of",
77 "people who like",
78 "people who love",
79 ],
80 slot: "preference",
81 needs_value: true,
82 });
83
84 self.entity_patterns.push(EntityPattern {
86 keywords: vec![
87 "what is",
88 "where does",
89 "who is",
90 "what does",
91 ],
92 slot: "",
93 needs_value: false,
94 });
95 }
96
97 #[must_use]
99 pub fn plan(&self, query: &str, top_k: usize) -> QueryPlan {
100 let query_lower = query.to_lowercase();
101
102 if let Some(pattern) = self.detect_pattern(&query_lower, query) {
104 if pattern.triples.is_empty() {
105 QueryPlan::vector_only(Some(query.to_string()), None, top_k)
107 } else {
108 QueryPlan::hybrid(pattern, Some(query.to_string()), None, top_k)
110 }
111 } else {
112 QueryPlan::vector_only(Some(query.to_string()), None, top_k)
114 }
115 }
116
117 fn detect_pattern(&self, query_lower: &str, _original: &str) -> Option<GraphPattern> {
118 let mut pattern = GraphPattern::new();
119
120 for ep in &self.entity_patterns {
121 for keyword in &ep.keywords {
122 if query_lower.contains(keyword) {
123 if let Some(pos) = query_lower.find(keyword) {
125 let after = &query_lower[pos + keyword.len()..];
126 let value = extract_value(after);
127
128 if !value.is_empty() && ep.needs_value {
129 pattern.add(TriplePattern::any_slot_value("entity", ep.slot, &value));
131 return Some(pattern);
132 }
133 }
134 }
135 }
136 }
137
138 if let Some((entity, slot)) = extract_possessive_query(query_lower) {
140 pattern.add(TriplePattern::entity_slot_any(&entity, &slot, "value"));
141 return Some(pattern);
142 }
143
144 Some(pattern)
145 }
146}
147
/// Extracts the value phrase that follows a trigger keyword, e.g. the
/// "San Francisco" in "who lives in San Francisco and likes jazz".
///
/// Collection stops at:
/// - the first stop word (case-insensitive);
/// - after `MAX_WORDS` words;
/// - after the first token whose trailing punctuation ends a clause (`?`, `!`, `,`, `;`).
///
/// Surrounding punctuation is stripped from each word (hyphens are kept), and the words
/// are joined with single spaces. Returns an empty string when no value is found.
fn extract_value(text: &str) -> String {
    const STOP_WORDS: [&str; 5] = ["and", "or", "who", "what", "that"];
    // '.' is deliberately excluded so abbreviations like "St. Louis" survive.
    const CLAUSE_END: [char; 4] = ['?', '!', ',', ';'];
    const MAX_WORDS: usize = 3;

    let is_edge = |c: char| !c.is_alphanumeric() && c != '-';

    let mut words: Vec<&str> = Vec::new();
    for raw in text.split_whitespace() {
        let body = raw.trim_end_matches(is_edge);
        let clean = body.trim_start_matches(is_edge);

        // Stop words are ASCII, so an ASCII-insensitive compare avoids allocating.
        if STOP_WORDS.iter().any(|stop| stop.eq_ignore_ascii_case(clean)) {
            break;
        }
        if !clean.is_empty() {
            words.push(clean);
        }

        // Punctuation trimmed off the end of this token (e.g. the "?" in "Paris?").
        let ends_clause = raw[body.len()..].contains(|c: char| CLAUSE_END.contains(&c));
        if words.len() >= MAX_WORDS || ends_clause {
            break;
        }
    }

    words.join(" ")
}
171
/// Parses a possessive question such as "what is alice's job" into `(entity, slot)`.
///
/// Returns `("alice", "workplace")` for that example. Both `'` and the typographic
/// apostrophe `’` are accepted, and the apostrophe must be followed by `"s "`.
/// Common slot words are normalised ("job" → "workplace", "home" → "location",
/// "wife" → "spouse", …); any other slot word is returned as-is. Punctuation around
/// the entity and slot words is stripped, so "alice's job?" still yields "workplace".
///
/// Words whose `'s` is always a contraction ("it's", "what's", "let's") are skipped,
/// and scanning continues to the next apostrophe. Returns `None` when no possessive
/// with a non-empty entity and slot is found.
fn extract_possessive_query(query: &str) -> Option<(String, String)> {
    // In English these words followed by 's are contractions, never possessives.
    const CONTRACTIONS: [&str; 13] = [
        "what", "who", "where", "when", "why", "how", "it", "that", "there", "here", "he",
        "she", "let",
    ];

    /// Strips surrounding punctuation (keeping hyphens), e.g. "job?" → "job".
    fn trim_word(word: &str) -> &str {
        word.trim_matches(|c: char| !c.is_alphanumeric() && c != '-')
    }

    for (pos, apostrophe) in query.char_indices() {
        if apostrophe != '\'' && apostrophe != '\u{2019}' {
            continue;
        }
        let rest = match query[pos + apostrophe.len_utf8()..].strip_prefix("s ") {
            Some(rest) => rest,
            None => continue,
        };

        let entity = match query[..pos].split_whitespace().last().map(trim_word) {
            Some(entity) if !entity.is_empty() && !CONTRACTIONS.contains(&entity) => entity,
            _ => continue,
        };
        let slot = match rest.split_whitespace().next().map(trim_word) {
            Some(slot) if !slot.is_empty() => slot,
            _ => continue,
        };

        let slot = match slot {
            "job" | "work" | "employer" | "role" | "company" => "workplace",
            "home" | "city" | "address" => "location",
            "favorite" => "preference",
            "wife" | "husband" | "spouse" | "partner" => "spouse",
            other => other,
        };

        return Some((entity.to_string(), slot.to_string()));
    }

    None
}
193
194pub struct GraphMatcher<'a> {
196 memvid: &'a Memvid,
197}
198
199impl<'a> GraphMatcher<'a> {
200 pub fn new(memvid: &'a Memvid) -> Self {
202 Self { memvid }
203 }
204
205 pub fn execute(&self, pattern: &GraphPattern) -> Vec<GraphMatchResult> {
207 let mut results = Vec::new();
208
209 for triple in &pattern.triples {
210 let matches = self.match_triple(triple);
211 results.extend(matches);
212 }
213
214 let mut seen = HashSet::new();
216 results.retain(|r| seen.insert(r.entity.clone()));
217
218 results
219 }
220
221 fn match_triple(&self, triple: &TriplePattern) -> Vec<GraphMatchResult> {
222 let mut results = Vec::new();
223
224 match (&triple.subject, &triple.predicate, &triple.object) {
225 (PatternTerm::Variable(var), PatternTerm::Literal(slot), PatternTerm::Literal(value)) => {
227 for entity in self.memvid.memory_entities() {
229 let cards = self.memvid.get_entity_memories(&entity);
230 for card in cards {
231 if card.slot.to_lowercase() == *slot
232 && card.value.to_lowercase().contains(&value.to_lowercase())
233 {
234 let mut result =
235 GraphMatchResult::new(entity.clone(), vec![card.source_frame_id], 1.0);
236 result.bind(var, entity.clone());
237 results.push(result);
238 break; }
240 }
241 }
242 }
243
244 (PatternTerm::Literal(entity), PatternTerm::Literal(slot), PatternTerm::Variable(var)) => {
246 if let Some(card) = self.memvid.get_current_memory(entity, slot) {
247 let mut result =
248 GraphMatchResult::new(entity.clone(), vec![card.source_frame_id], 1.0);
249 result.bind(var, card.value.clone());
250 results.push(result);
251 }
252 }
253
254 (PatternTerm::Literal(entity), PatternTerm::Literal(slot), PatternTerm::Literal(value)) => {
256 if let Some(card) = self.memvid.get_current_memory(entity, slot) {
257 if card.value.to_lowercase().contains(&value.to_lowercase()) {
258 let result =
259 GraphMatchResult::new(entity.clone(), vec![card.source_frame_id], 1.0);
260 results.push(result);
261 }
262 }
263 }
264
265 _ => {
266 }
268 }
269
270 results
271 }
272
273 #[must_use]
275 pub fn get_candidate_frames(&self, matches: &[GraphMatchResult]) -> Vec<FrameId> {
276 let mut frame_ids: Vec<FrameId> = matches
277 .iter()
278 .flat_map(|m| m.frame_ids.iter().copied())
279 .collect();
280 frame_ids.sort_unstable();
281 frame_ids.dedup();
282 frame_ids
283 }
284
285 #[must_use]
287 pub fn get_matched_entities(&self, matches: &[GraphMatchResult]) -> HashMap<FrameId, String> {
288 let mut map = HashMap::new();
289 for m in matches {
290 for &fid in &m.frame_ids {
291 map.insert(fid, m.entity.clone());
292 }
293 }
294 map
295 }
296}
297
298pub fn hybrid_search(
300 memvid: &mut Memvid,
301 plan: &QueryPlan,
302) -> Result<Vec<HybridSearchHit>> {
303 match plan {
304 QueryPlan::VectorOnly { query_text, top_k, .. } => {
305 let query = query_text.as_deref().unwrap_or("");
307 let request = SearchRequest {
308 query: query.to_string(),
309 top_k: *top_k,
310 snippet_chars: 200,
311 uri: None,
312 scope: None,
313 cursor: None,
314 #[cfg(feature = "temporal_track")]
315 temporal: None,
316 as_of_frame: None,
317 as_of_ts: None,
318 };
319 let response = memvid.search(request)?;
320 Ok(response.hits
321 .iter()
322 .map(|h| {
323 let score = h.score.unwrap_or(0.0);
324 HybridSearchHit {
325 frame_id: h.frame_id,
326 score,
327 graph_score: 0.0,
328 vector_score: score,
329 matched_entity: None,
330 preview: Some(h.text.clone()),
331 }
332 })
333 .collect())
334 }
335
336 QueryPlan::GraphOnly { pattern, limit } => {
337 let matcher = GraphMatcher::new(memvid);
338 let matches = matcher.execute(pattern);
339
340 Ok(matches
341 .into_iter()
342 .take(*limit)
343 .map(|m| HybridSearchHit {
344 frame_id: m.frame_ids.first().copied().unwrap_or(0),
345 score: m.confidence,
346 graph_score: m.confidence,
347 vector_score: 0.0,
348 matched_entity: Some(m.entity),
349 preview: None,
350 })
351 .collect())
352 }
353
354 QueryPlan::Hybrid {
355 graph_filter,
356 query_text,
357 top_k,
358 ..
359 } => {
360 let matcher = GraphMatcher::new(memvid);
362 let matches = matcher.execute(graph_filter);
363 let entity_map = matcher.get_matched_entities(&matches);
364 let candidate_frames = matcher.get_candidate_frames(&matches);
365
366 if candidate_frames.is_empty() {
367 let query = query_text.as_deref().unwrap_or("");
369 let request = SearchRequest {
370 query: query.to_string(),
371 top_k: *top_k,
372 snippet_chars: 200,
373 uri: None,
374 scope: None,
375 cursor: None,
376 #[cfg(feature = "temporal_track")]
377 temporal: None,
378 as_of_frame: None,
379 as_of_ts: None,
380 };
381 let response = memvid.search(request)?;
382 return Ok(response.hits
383 .iter()
384 .map(|h| {
385 let score = h.score.unwrap_or(0.0);
386 HybridSearchHit {
387 frame_id: h.frame_id,
388 score,
389 graph_score: 0.0,
390 vector_score: score,
391 matched_entity: None,
392 preview: Some(h.text.clone()),
393 }
394 })
395 .collect());
396 }
397
398 let mut hybrid_hits: Vec<HybridSearchHit> = Vec::new();
401
402 for &frame_id in &candidate_frames {
403 let matched_entity = entity_map.get(&frame_id).cloned();
404
405 let preview = memvid.frame_preview_by_id(frame_id).ok();
407
408 hybrid_hits.push(HybridSearchHit {
409 frame_id,
410 score: 1.0, graph_score: 1.0,
412 vector_score: 0.0,
413 matched_entity,
414 preview,
415 });
416 }
417
418 Ok(hybrid_hits.into_iter().take(*top_k).collect())
419 }
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `plan` is a non-empty hybrid plan and returns its first triple.
    /// Panics with `failure` for any other plan shape.
    fn first_hybrid_triple<'p>(plan: &'p QueryPlan, failure: &str) -> &'p TriplePattern {
        match plan {
            QueryPlan::Hybrid { graph_filter, .. } => {
                assert!(!graph_filter.is_empty());
                &graph_filter.triples[0]
            }
            _ => panic!("{}", failure),
        }
    }

    #[test]
    fn test_query_planner_detects_location() {
        let plan = QueryPlanner::new().plan("who lives in San Francisco", 10);
        let triple = first_hybrid_triple(&plan, "Expected hybrid plan for location query");
        assert!(matches!(&triple.predicate, PatternTerm::Literal(s) if s == "location"));
    }

    #[test]
    fn test_query_planner_detects_workplace() {
        let plan = QueryPlanner::new().plan("who works at Google", 10);
        let triple = first_hybrid_triple(&plan, "Expected hybrid plan for workplace query");
        assert!(matches!(&triple.predicate, PatternTerm::Literal(s) if s == "workplace"));
    }

    #[test]
    fn test_query_planner_possessive() {
        let plan = QueryPlanner::new().plan("what is alice's employer", 10);
        let triple = first_hybrid_triple(&plan, "Expected hybrid plan for possessive query");
        assert!(matches!(&triple.subject, PatternTerm::Literal(s) if s == "alice"));
    }

    #[test]
    fn test_extract_value() {
        let cases = [
            ("San Francisco and", "San Francisco"),
            ("Google who", "Google"),
            ("New York City", "New York City"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(extract_value(input), expected);
        }
    }

    #[test]
    fn test_extract_possessive() {
        let cases = [
            ("what is alice's job", "alice", "workplace"),
            ("bob's location", "bob", "location"),
        ];
        for &(query, entity, slot) in &cases {
            assert_eq!(
                extract_possessive_query(query),
                Some((entity.to_string(), slot.to_string()))
            );
        }
    }
}