1use crate::{
7 error::Result,
8 graph::GraphManager,
9 index::{TemporalIndex, VectorIndex},
10 ingest::{get_causal_extractor, SlmMetadata},
11 query::{
12 aggregator::MultiTurnAggregator,
13 fusion::{FusedResult, FusionEngine},
14 intent::{IntentClassification, IntentClassifier},
15 profile_search::ProfileSearch,
16 },
17 storage::StorageEngine,
18 trace_begin, trace_record,
19 types::{Memory, MemoryId, MetadataFilter, Timestamp},
20};
21use std::collections::HashMap;
22use std::sync::{Arc, RwLock};
23
/// Plans and executes hybrid queries across every retrieval backend.
///
/// Owns handles to storage plus the vector, BM25, temporal and graph
/// indexes, and the classification/fusion machinery used to combine their
/// scores into a single ranked result list.
pub struct QueryPlanner {
    pub(crate) storage: Arc<StorageEngine>,
    pub(crate) vector_index: Arc<RwLock<VectorIndex>>,
    pub(crate) bm25_index: Arc<crate::index::BM25Index>,
    pub(crate) temporal_index: Arc<TemporalIndex>,
    pub(crate) graph_manager: Arc<RwLock<GraphManager>>,
    /// Pattern-based intent classifier; the fallback when SLM is off/unavailable.
    intent_classifier: IntentClassifier,
    /// Optional SLM-backed classifier; `None` when construction failed.
    #[cfg(feature = "slm")]
    slm_classifier: Option<Arc<std::sync::Mutex<crate::slm::SlmClassifier>>>,
    /// Whether SLM classification is attempted before the pattern fallback.
    slm_query_classification_enabled: bool,
    fusion_engine: FusionEngine,
    /// Minimum semantic score to survive pre-fusion filtering
    /// (candidates with entity evidence are exempt).
    semantic_prefilter_threshold: f32,
    profile_search: ProfileSearch,
    /// Relative-score cutoff for adaptive result truncation;
    /// active only when strictly inside (0, 1).
    adaptive_k_threshold: f32,
}
45
46impl QueryPlanner {
    /// Construct a `QueryPlanner` over the given storage and index handles.
    ///
    /// Fusion behavior is fixed up front from `fusion_strategy`, `rrf_k`
    /// and `fusion_semantic_threshold`. With the `slm` feature enabled and
    /// an `slm_config` supplied, SLM classifier construction is attempted;
    /// failure is logged as a warning and the planner degrades to
    /// pattern-based classification instead of returning an error.
    pub fn new(
        storage: Arc<StorageEngine>,
        vector_index: Arc<RwLock<VectorIndex>>,
        bm25_index: Arc<crate::index::BM25Index>,
        temporal_index: Arc<TemporalIndex>,
        graph_manager: Arc<RwLock<GraphManager>>,
        fusion_semantic_threshold: f32,
        semantic_prefilter_threshold: f32,
        fusion_strategy: crate::query::FusionStrategy,
        rrf_k: f32,
        #[cfg(feature = "slm")] slm_config: Option<crate::slm::SlmConfig>,
        slm_query_classification_enabled: bool,
        adaptive_k_threshold: f32,
    ) -> Result<Self> {
        #[cfg(feature = "slm")]
        let slm_classifier = if let Some(config) = slm_config {
            tracing::info!(
                "Initializing SLM classifier with model: {}",
                config.model_id
            );
            match crate::slm::SlmClassifier::new(config) {
                Ok(classifier) => {
                    tracing::info!("SLM classifier initialized successfully (lazy loading)");
                    Some(Arc::new(std::sync::Mutex::new(classifier)))
                }
                Err(e) => {
                    // Non-fatal: retrieval still works via pattern classification.
                    tracing::warn!(
                        "Failed to initialize SLM classifier: {}, falling back to patterns",
                        e
                    );
                    None
                }
            }
        } else {
            None
        };

        let profile_search = ProfileSearch::new(Arc::clone(&storage));

        Ok(Self {
            storage,
            vector_index,
            bm25_index,
            temporal_index,
            graph_manager,
            intent_classifier: IntentClassifier::new(),
            #[cfg(feature = "slm")]
            slm_classifier,
            slm_query_classification_enabled,
            fusion_engine: FusionEngine::new()
                .with_semantic_threshold(fusion_semantic_threshold)
                .with_strategy(fusion_strategy)
                .with_rrf_k(rrf_k),
            semantic_prefilter_threshold,
            profile_search,
            adaptive_k_threshold,
        })
    }
108
    /// Classify the query's intent.
    ///
    /// With the `slm` feature enabled and classification turned on, tries
    /// the SLM classifier first and returns its result on success; any SLM
    /// failure (classification error or lock acquisition failure) is logged
    /// and falls through to the pattern-based classifier.
    fn classify_intent(&self, query_text: &str) -> Result<IntentClassification> {
        #[cfg(feature = "slm")]
        if self.slm_query_classification_enabled {
            if let Some(slm_classifier) = &self.slm_classifier {
                match slm_classifier.lock() {
                    Ok(mut classifier) => match classifier.classify_intent(query_text) {
                        Ok(classification) => {
                            tracing::debug!(
                                "SLM classified query as {:?} (confidence: {:.2})",
                                classification.intent,
                                classification.confidence
                            );
                            return Ok(classification);
                        }
                        Err(e) => {
                            tracing::warn!(
                                "SLM classification failed: {}, falling back to patterns",
                                e
                            );
                        }
                    },
                    // A poisoned lock (panicked SLM thread) also falls back.
                    Err(e) => {
                        tracing::warn!(
                            "Failed to acquire SLM classifier lock: {}, falling back to patterns",
                            e
                        );
                    }
                }
            }
        }

        #[cfg(not(feature = "slm"))]
        {
            // Silence the unused-field warning when the feature is off.
            let _ = self.slm_query_classification_enabled;
        }

        Ok(self.intent_classifier.classify(query_text))
    }
164
    /// Run the full hybrid retrieval pipeline for a single query.
    ///
    /// Stages, in order: intent classification; an entity-focused fast path
    /// when the intent names a known entity; entity detection (including
    /// mapping first-person pronouns onto `user_entity`); semantic, BM25,
    /// temporal, causal and entity searches; profile-fact boosting;
    /// related-entity injection; conditional graph expansion; pre-fusion
    /// semantic filtering; score fusion; dimension-agreement confidence
    /// weighting; speaker re-ranking; aggregation expansion; dialog-neighbor
    /// bridging; MMR diversification; multi-turn aggregation; adaptive-k
    /// truncation.
    ///
    /// Returns the intent classification, the final fused results, and any
    /// matched profile facts.
    pub fn query(
        &self,
        query_text: &str,
        query_embedding: &[f32],
        limit: usize,
        namespace: Option<&str>,
        filters: Option<&[MetadataFilter]>,
        user_entity: Option<&str>,
        mut trace: Option<&mut crate::trace::TraceRecorder>,
    ) -> Result<(
        IntentClassification,
        Vec<FusedResult>,
        Vec<crate::query::profile_search::MatchedProfileFact>,
    )> {
        trace_begin!(trace, "intent_classification");
        let intent = self.classify_intent(query_text)?;
        trace_record!(trace, "intent", format!("{:?}", intent.intent));
        trace_record!(trace, "confidence", intent.confidence);
        trace_record!(trace, "entity_focus", intent.entity_focus.clone());

        // Fast path: a query focused on a single known entity bypasses the
        // full pipeline entirely.
        if let Some(entity_name) = intent.entity_focus.clone() {
            if self.storage.find_entity_by_name(&entity_name)?.is_some() {
                trace_begin!(trace, "entity_focused_path");
                trace_record!(trace, "entity_name", entity_name.clone());
                trace_record!(trace, "took_fast_path", true);
                return self.retrieve_entity_focused(
                    &entity_name,
                    query_text,
                    query_embedding,
                    limit,
                    namespace,
                    filters,
                    intent,
                );
            }
        }

        let mut query_entities = self
            .profile_search
            .detect_entities(query_text)
            .unwrap_or_default();

        // Map first-person pronouns onto the resolved user entity so "my"
        // style queries hit the user's own memories.
        if let Some(user_name) = user_entity {
            let user_lower = user_name.to_lowercase();
            if !query_entities.contains(&user_lower) {
                let query_lower = query_text.to_lowercase();
                const FIRST_PERSON: &[&str] = &[
                    "i", "me", "my", "mine", "myself", "i'm", "i've", "i'd", "i'll",
                ];
                let has_first_person = FIRST_PERSON
                    .iter()
                    .any(|pronoun| Self::contains_whole_word_static(&query_lower, pronoun));
                if has_first_person {
                    query_entities.push(user_lower);
                }
            }
        }
        // Exactly one detected entity is treated as the speaker of interest.
        let speaker_entity = if query_entities.len() == 1 {
            Some(query_entities[0].clone())
        } else {
            None
        };

        trace_begin!(trace, "entity_detection");
        trace_record!(trace, "query_entities", query_entities.clone());
        trace_record!(trace, "speaker_entity", speaker_entity.clone());
        trace_record!(
            trace,
            "user_entity_resolved",
            user_entity.map(|s| s.to_string())
        );

        // Over-fetch more aggressively when post-filters or entity matching
        // will cull candidates later.
        let needs_filtering =
            namespace.is_some() || (filters.is_some() && !filters.unwrap().is_empty());
        let fetch_multiplier = if needs_filtering || !query_entities.is_empty() {
            7
        } else {
            5
        };

        trace_begin!(trace, "semantic_search");
        let mut semantic_scores =
            self.semantic_search(query_embedding, limit * fetch_multiplier)?;
        trace_record!(trace, "candidate_count", semantic_scores.len());
        trace_record!(trace, "fetch_limit", limit * fetch_multiplier);
        {
            let mut top_scores: Vec<f32> = semantic_scores.values().copied().collect();
            top_scores.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
            top_scores.truncate(5);
            trace_record!(trace, "top_5_scores", top_scores);
        }

        trace_begin!(trace, "bm25_search");
        let bm25_scores = self.bm25_search(query_text, limit * fetch_multiplier)?;
        trace_record!(trace, "candidate_count", bm25_scores.len());

        trace_begin!(trace, "temporal_search");
        let temporal_scores = self.temporal_search(query_text, limit * fetch_multiplier)?;
        trace_record!(trace, "candidate_count", temporal_scores.len());

        trace_begin!(trace, "causal_search");
        let causal_scores = self.causal_search(query_text, limit * fetch_multiplier)?;
        trace_record!(trace, "candidate_count", causal_scores.len());
        let entity_fetch = limit * 3;
        // Re-key entity-search results on MemoryId, keeping the max score
        // when two raw ids collapse onto the same key.
        let mut entity_scores: HashMap<MemoryId, f32> = {
            let raw = self.entity_search(query_text, entity_fetch)?;
            let mut norm: HashMap<MemoryId, f32> = HashMap::with_capacity(raw.len());
            for (id, score) in raw {
                norm.entry(MemoryId::from_u64(id.to_u64()))
                    .and_modify(|s: &mut f32| *s = (*s).max(score))
                    .or_insert(score);
            }
            norm
        };

        trace_begin!(trace, "entity_injection");
        // Inject the source memories of each detected entity with a strong
        // fixed score (2.0) so they survive fusion, up to a cap.
        let entity_inject_cap = entity_fetch;
        let entity_count_before = entity_scores.len();
        for entity_name in &query_entities {
            if let Ok(Some(profile)) = self.storage.get_entity_profile(entity_name) {
                for (injected, source_id) in profile.source_memories.iter().enumerate() {
                    if injected >= entity_inject_cap {
                        break;
                    }
                    entity_scores
                        .entry(MemoryId::from_u64(source_id.to_u64()))
                        .and_modify(|s| *s = (*s).max(2.0))
                        .or_insert(2.0);
                }
            }
        }
        trace_record!(trace, "entities_injected", query_entities.len());
        trace_record!(
            trace,
            "entity_scores_added",
            entity_scores.len() - entity_count_before
        );
        trace_record!(trace, "inject_cap", entity_inject_cap);

        // Memories of graph-related entities contribute at a lower fixed score.
        const RELATED_ENTITY_SCORE: f32 = 1.0;
        const MAX_RELATED_MEMORIES_PER_ENTITY: usize = 10;

        if !query_entities.is_empty() {
            let graph = self.graph_manager.read().unwrap();
            // Lazily built entity-id -> lowercase-name map; only materialized
            // if at least one related entity is found.
            let mut id_to_name: Option<HashMap<String, String>> = None;

            for entity_name in &query_entities {
                if let Ok(Some(profile)) = self.storage.get_entity_profile(entity_name) {
                    let related = graph.get_related_entities(&profile.entity_id);
                    if related.is_empty() {
                        continue;
                    }
                    let map = id_to_name.get_or_insert_with(|| {
                        self.storage
                            .list_entity_profiles()
                            .unwrap_or_default()
                            .iter()
                            .map(|p| (p.entity_id.to_string(), p.name.to_lowercase()))
                            .collect()
                    });
                    for (related_id, _rel_type) in related {
                        if let Some(related_name) = map.get(&related_id.to_string()) {
                            // Skip entities already named in the query itself.
                            if query_entities.iter().any(|qe| qe == related_name) {
                                continue;
                            }
                            if let Ok(Some(related_profile)) =
                                self.storage.get_entity_profile(related_name)
                            {
                                for source_id in related_profile
                                    .source_memories
                                    .iter()
                                    .take(MAX_RELATED_MEMORIES_PER_ENTITY)
                                {
                                    entity_scores
                                        .entry(MemoryId::from_u64(source_id.to_u64()))
                                        .and_modify(|s| *s = (*s).max(RELATED_ENTITY_SCORE))
                                        .or_insert(RELATED_ENTITY_SCORE);
                                }
                            }
                        }
                    }
                }
            }
        }

        trace_begin!(trace, "profile_search");
        let profile_result =
            self.profile_search
                .search(query_text, query_embedding, limit * fetch_multiplier)?;
        let matched_facts = profile_result.matched_facts;
        trace_record!(trace, "matched_facts_count", matched_facts.len());
        if let Some(ref mut t) = trace {
            let fact_details: Vec<String> = matched_facts
                .iter()
                .take(10)
                .map(|f| {
                    format!(
                        "{}:{}={} (score={:.2})",
                        f.entity_name, f.fact_type, f.value, f.score
                    )
                })
                .collect();
            t.record("fact_details", fact_details);
        }
        // Memories backing matched profile facts get a boosted entity score.
        for (memory_id, profile_score) in profile_result.source_scores {
            let boosted = 2.0 + profile_score;
            entity_scores
                .entry(MemoryId::from_u64(memory_id.to_u64()))
                .and_modify(|s| {
                    *s = (*s).max(boosted);
                })
                .or_insert(boosted);
        }

        // Re-key the remaining score maps on MemoryId (max on collision);
        // these shadow the raw maps fetched above.
        let mut bm25_scores: HashMap<MemoryId, f32> = {
            let mut norm: HashMap<MemoryId, f32> = HashMap::with_capacity(bm25_scores.len());
            for (id, score) in bm25_scores {
                norm.entry(MemoryId::from_u64(id.to_u64()))
                    .and_modify(|s: &mut f32| *s = (*s).max(score))
                    .or_insert(score);
            }
            norm
        };
        let mut temporal_scores: HashMap<MemoryId, f32> = {
            let mut norm: HashMap<MemoryId, f32> = HashMap::with_capacity(temporal_scores.len());
            for (id, score) in temporal_scores {
                norm.entry(MemoryId::from_u64(id.to_u64()))
                    .and_modify(|s: &mut f32| *s = (*s).max(score))
                    .or_insert(score);
            }
            norm
        };
        let mut causal_scores: HashMap<MemoryId, f32> = {
            let mut norm: HashMap<MemoryId, f32> = HashMap::with_capacity(causal_scores.len());
            for (id, score) in causal_scores {
                norm.entry(MemoryId::from_u64(id.to_u64()))
                    .and_modify(|s: &mut f32| *s = (*s).max(score))
                    .or_insert(score);
            }
            norm
        };

        // Aggregation-style queries relax the semantic prefilter so broad
        // recall is possible.
        let effective_prefilter = if MultiTurnAggregator::is_aggregation(&query_text.to_lowercase())
        {
            self.semantic_prefilter_threshold * 0.5
        } else {
            self.semantic_prefilter_threshold
        };

        if let Some(ns) = namespace {
            self.filter_by_namespace(&mut semantic_scores, ns)?;
            self.filter_by_namespace(&mut bm25_scores, ns)?;
            self.filter_by_namespace(&mut temporal_scores, ns)?;
            self.filter_by_namespace(&mut causal_scores, ns)?;
            self.filter_by_namespace(&mut entity_scores, ns)?;
        }

        if let Some(filter_list) = filters {
            if !filter_list.is_empty() {
                self.filter_by_metadata(&mut semantic_scores, filter_list)?;
                self.filter_by_metadata(&mut bm25_scores, filter_list)?;
                self.filter_by_metadata(&mut temporal_scores, filter_list)?;
                self.filter_by_metadata(&mut causal_scores, filter_list)?;
                self.filter_by_metadata(&mut entity_scores, filter_list)?;
            }
        }

        // Seed pool for graph expansion: union of semantic and BM25 scores,
        // taking the max per memory.
        let mut seed_results: HashMap<MemoryId, f32> = semantic_scores.clone();
        for (id, score) in &bm25_scores {
            seed_results
                .entry(id.clone())
                .and_modify(|s: &mut f32| *s = (*s).max(*score))
                .or_insert(*score);
        }

        let avg_seed_quality = if !seed_results.is_empty() {
            seed_results.values().sum::<f32>() / seed_results.len() as f32
        } else {
            0.0
        };

        // Only expand through the graph when seeds are already decent —
        // expanding from weak seeds amplifies noise.
        const GRAPH_EXPANSION_THRESHOLD: f32 = 0.6;
        if avg_seed_quality >= GRAPH_EXPANSION_THRESHOLD {
            let graph_traversal = crate::query::graph_traversal::GraphTraversal::new(
                self.graph_manager.clone(),
                self.storage.clone(),
                crate::query::graph_traversal::TraversalConfig::default(),
            );

            let expanded_results =
                graph_traversal.expand(&seed_results, intent.intent, limit * 5)?;

            for (expanded_id, expansion_score) in expanded_results {
                let expanded_partial = MemoryId::from_u64(expanded_id.to_u64());
                // Seeds are already scored; only genuinely new memories count.
                if seed_results.contains_key(&expanded_partial) {
                    continue;
                }

                // Route the expansion score to the channel matching the intent.
                match intent.intent {
                    crate::query::QueryIntent::Causal => {
                        causal_scores.insert(expanded_partial, expansion_score);
                    }
                    crate::query::QueryIntent::Entity | crate::query::QueryIntent::Factual => {
                        entity_scores.insert(expanded_partial, expansion_score);
                    }
                    crate::query::QueryIntent::Temporal => {
                        temporal_scores
                            .entry(expanded_partial.clone())
                            .and_modify(|s| *s = (*s).max(expansion_score * 0.7))
                            .or_insert(expansion_score * 0.7);
                        entity_scores.insert(expanded_partial, expansion_score * 0.5);
                    }
                }
            }
        }

        trace_begin!(trace, "prefusion_filter");
        // Drop weak semantic candidates, but never those backed by entity
        // evidence.
        let sem_count_before = semantic_scores.len();
        if effective_prefilter > 0.0 {
            let entity_ids: std::collections::HashSet<&MemoryId> = entity_scores.keys().collect();
            semantic_scores
                .retain(|id, score| *score >= effective_prefilter || entity_ids.contains(id));
        }
        trace_record!(trace, "threshold", effective_prefilter);
        trace_record!(trace, "before_count", sem_count_before);
        trace_record!(trace, "after_count", semantic_scores.len());
        trace_record!(trace, "entity_exempt_count", entity_scores.len());

        trace_begin!(trace, "rrf_fusion");
        let mut fused_results = self.fusion_engine.fuse(
            intent.intent,
            &semantic_scores,
            &bm25_scores,
            &temporal_scores,
            &causal_scores,
            &entity_scores,
        );
        trace_record!(trace, "candidate_count", fused_results.len());
        if let Some(ref mut t) = trace {
            let top_scores: Vec<f32> = fused_results
                .iter()
                .take(10)
                .map(|r| r.fused_score)
                .collect();
            let top_ids: Vec<String> = fused_results
                .iter()
                .take(10)
                .map(|r| r.id.to_string())
                .collect();
            t.record("top_10_fused_scores", top_scores);
            t.record("top_10_ids", top_ids);
        }

        // Confidence weighting: agreement across more score dimensions means
        // higher confidence; single-dimension hits are trusted more for
        // factual/entity intents.
        for result in &mut fused_results {
            let mut dimension_count = 0;

            if result.semantic_score > 0.0 {
                dimension_count += 1;
            }
            if result.bm25_score > 0.0 {
                dimension_count += 1;
            }
            if result.temporal_score > 0.0 {
                dimension_count += 1;
            }
            if result.causal_score > 0.0 {
                dimension_count += 1;
            }
            if result.entity_score > 0.0 {
                dimension_count += 1;
            }

            let single_dim_confidence = match intent.intent {
                crate::query::QueryIntent::Factual => 0.7,
                crate::query::QueryIntent::Entity => 0.7,
                _ => 0.4,
            };

            result.confidence = match dimension_count {
                5 => 1.0,
                4 => 0.9,
                3 => 0.8,
                2 => 0.6,
                1 => single_dim_confidence,
                _ => 0.2,
            };

            result.fused_score *= result.confidence;
        }

        fused_results.sort_by(|a, b| {
            b.fused_score
                .partial_cmp(&a.fused_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        trace_begin!(trace, "confidence_scoring");
        if let Some(ref mut t) = trace {
            let scores: Vec<f32> = fused_results.iter().map(|r| r.fused_score).collect();
            let mean = if scores.is_empty() {
                0.0
            } else {
                scores.iter().sum::<f32>() / scores.len() as f32
            };
            // Sample standard deviation (n - 1 denominator).
            let std_dev = if scores.len() > 1 {
                let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f32>()
                    / (scores.len() - 1) as f32;
                variance.sqrt()
            } else {
                0.0
            };
            t.record("mean_score", mean);
            t.record("std_score", std_dev);
            t.record("result_count", fused_results.len());
        }

        trace_begin!(trace, "speaker_reranking");
        trace_record!(trace, "speaker_entity", speaker_entity.clone());
        // Demote results spoken by someone other than the speaker of
        // interest, but only when any speaker metadata exists at all.
        let mut speaker_reranked_count = 0usize;
        if let Some(ref speaker_target) = speaker_entity {
            let target_lower = speaker_target.to_lowercase();
            let mut any_speaker_data = false;
            for result in &mut fused_results {
                if let Ok(Some(memory)) = self.storage.get_memory_by_u64(result.id.to_u64()) {
                    if let Some(speaker) = memory.get_metadata("speaker") {
                        any_speaker_data = true;
                        if speaker.to_lowercase() != target_lower {
                            result.fused_score *= 0.2;
                            speaker_reranked_count += 1;
                        }
                    }
                }
            }
            if any_speaker_data {
                fused_results.sort_by(|a, b| {
                    b.fused_score
                        .partial_cmp(&a.fused_score)
                        .unwrap_or(std::cmp::Ordering::Equal)
                });
            }
        }

        trace_record!(trace, "reranked_count", speaker_reranked_count);

        trace_begin!(trace, "aggregation_expansion");
        // Aggregation queries pull in extra entity memories that are
        // dissimilar to what we already have, to improve coverage.
        let is_aggregation = MultiTurnAggregator::is_aggregation(&query_text.to_lowercase());
        trace_record!(trace, "is_aggregation", is_aggregation);
        if is_aggregation {
            if let Some(ref entity_name) = speaker_entity {
                if let Ok(Some(entity)) = self.storage.find_entity_by_name(entity_name) {
                    let graph = self.graph_manager.read().unwrap();
                    let entity_memories = graph.get_entity_memories(&entity.id);
                    drop(graph);

                    let existing_ids: std::collections::HashSet<_> =
                        fused_results.iter().map(|r| r.id.clone()).collect();

                    let top_k_embeddings: Vec<Vec<f32>> = fused_results
                        .iter()
                        .take(limit.min(10))
                        .filter_map(|r| {
                            self.storage
                                .get_memory_by_u64(r.id.to_u64())
                                .ok()
                                .flatten()
                                .map(|m| m.embedding.clone())
                        })
                        .filter(|e| !e.is_empty() && e.len() == query_embedding.len())
                        .collect();

                    let candidates: Vec<_> = entity_memories
                        .memories
                        .into_iter()
                        .filter(|m| !existing_ids.contains(m))
                        .collect();

                    let mut expanded = Vec::new();
                    let max_expansions = 5;

                    for candidate_id in candidates {
                        if expanded.len() >= max_expansions {
                            break;
                        }
                        if let Ok(Some(memory)) =
                            self.storage.get_memory_by_u64(candidate_id.to_u64())
                        {
                            if memory.embedding.is_empty()
                                || query_embedding.len() != memory.embedding.len()
                            {
                                continue;
                            }

                            // Only add candidates that are novel relative to
                            // the current top results.
                            let max_sim = top_k_embeddings
                                .iter()
                                .map(|e| cosine_similarity(&memory.embedding, e))
                                .fold(0.0f32, f32::max);

                            if max_sim < 0.8 {
                                let sem_score =
                                    cosine_similarity(query_embedding, &memory.embedding);
                                expanded.push(FusedResult {
                                    id: candidate_id,
                                    semantic_score: sem_score,
                                    bm25_score: 0.0,
                                    temporal_score: 0.0,
                                    causal_score: 0.0,
                                    entity_score: 0.5,
                                    fused_score: sem_score * 0.7,
                                    confidence: 0.7,
                                });
                            }
                        }
                    }

                    if !expanded.is_empty() {
                        fused_results.extend(expanded);
                        fused_results.sort_by(|a, b| {
                            b.fused_score
                                .partial_cmp(&a.fused_score)
                                .unwrap_or(std::cmp::Ordering::Equal)
                        });
                    }
                }
            }
        }

        trace_begin!(trace, "dialog_bridging");
        // Pull in adjacent dialog turns (offsets -2..=2, excluding 0) of the
        // top-3 parents so answers spanning multiple turns stay together.
        {
            let existing_ids: std::collections::HashSet<MemoryId> =
                fused_results.iter().map(|r| r.id.clone()).collect();
            let mut neighbors: Vec<FusedResult> = Vec::new();

            for parent in fused_results.iter().take(3) {
                let parent_score = parent.fused_score;
                // Weak parents don't earn neighbor injection.
                if parent_score < 0.05 {
                    continue;
                }
                if let Ok(Some(memory)) = self.storage.get_memory_by_u64(parent.id.to_u64()) {
                    if let Some(dialog_id) = memory.metadata.get("dialog_id") {
                        if let Some((conv, turn)) = Self::parse_dialog_id(dialog_id) {
                            for delta in [-2i32, -1, 1, 2] {
                                let neighbor_turn = turn + delta;
                                if neighbor_turn < 1 {
                                    continue;
                                }
                                let neighbor_did = format!("{}:{}", conv, neighbor_turn);
                                if let Ok(Some(nbr)) =
                                    self.storage.find_memory_by_dialog_id(&neighbor_did)
                                {
                                    let partial_id = MemoryId::from_u64(nbr.id.to_u64());
                                    if existing_ids.contains(&partial_id) {
                                        continue;
                                    }
                                    if nbr.embedding.is_empty()
                                        || nbr.embedding.len() != query_embedding.len()
                                    {
                                        continue;
                                    }
                                    let sem_score =
                                        cosine_similarity(query_embedding, &nbr.embedding);
                                    // Require minimal topical relevance.
                                    if sem_score < 0.2 {
                                        continue;
                                    }
                                    neighbors.push(FusedResult {
                                        id: partial_id,
                                        semantic_score: sem_score,
                                        bm25_score: 0.0,
                                        temporal_score: 0.0,
                                        causal_score: 0.0,
                                        entity_score: parent.entity_score * 0.8,
                                        fused_score: parent_score * 0.8,
                                        confidence: 0.6,
                                    });
                                }
                            }
                        }
                    }
                }
            }

            trace_record!(trace, "injected_count", neighbors.len());
            if !neighbors.is_empty() {
                fused_results.extend(neighbors);
                fused_results.sort_by(|a, b| {
                    b.fused_score
                        .partial_cmp(&a.fused_score)
                        .unwrap_or(std::cmp::Ordering::Equal)
                });
            }
        }

        trace_begin!(trace, "mmr_selection");
        trace_record!(trace, "pool_size", fused_results.len());
        // Maximal Marginal Relevance: trade a little relevance (lambda 0.9)
        // for diversity among the selected results.
        {
            const MMR_LAMBDA: f32 = 0.9;
            trace_record!(trace, "lambda", MMR_LAMBDA);
            if fused_results.len() > limit {
                let pool_size = fused_results.len().min(limit * 5);
                let pool = &fused_results[..pool_size];

                // Embeddings fetched once per pool entry; None when missing
                // or dimension-mismatched.
                let mut pool_embeddings: Vec<Option<Vec<f32>>> = Vec::with_capacity(pool_size);

                for r in pool.iter() {
                    if let Ok(Some(memory)) = self.storage.get_memory_by_u64(r.id.to_u64()) {
                        let emb = if !memory.embedding.is_empty()
                            && memory.embedding.len() == query_embedding.len()
                        {
                            Some(memory.embedding)
                        } else {
                            None
                        };
                        pool_embeddings.push(emb);
                    } else {
                        pool_embeddings.push(None);
                    }
                }

                let mut selected: Vec<usize> = Vec::with_capacity(limit);
                let mut remaining: Vec<usize> = (0..pool_size).collect();

                while selected.len() < limit && !remaining.is_empty() {
                    let mut best_idx_in_remaining = 0;
                    let mut best_mmr_score = f32::NEG_INFINITY;

                    for (ri, &cand_idx) in remaining.iter().enumerate() {
                        let relevance = pool[cand_idx].fused_score;

                        // Penalty: max similarity to anything already selected.
                        let max_sim = if selected.is_empty() {
                            0.0
                        } else if let Some(ref cand_emb) = pool_embeddings[cand_idx] {
                            selected
                                .iter()
                                .filter_map(|&sel_idx| {
                                    pool_embeddings[sel_idx]
                                        .as_ref()
                                        .map(|sel_emb| cosine_similarity(cand_emb, sel_emb))
                                })
                                .fold(0.0f32, f32::max)
                        } else {
                            0.0
                        };

                        let mmr_score = MMR_LAMBDA * relevance - (1.0 - MMR_LAMBDA) * max_sim;
                        if mmr_score > best_mmr_score {
                            best_mmr_score = mmr_score;
                            best_idx_in_remaining = ri;
                        }
                    }

                    let chosen = remaining.swap_remove(best_idx_in_remaining);
                    selected.push(chosen);
                }

                fused_results = selected.into_iter().map(|i| pool[i].clone()).collect();
            }
        }

        trace_record!(trace, "selected_count", fused_results.len());

        trace_begin!(trace, "multi_turn_aggregation");
        let aggregator = crate::query::aggregator::MultiTurnAggregator::default();
        let query_type = aggregator.classify_query(query_text);
        let pre_agg_count = fused_results.len();
        let mut final_results =
            aggregator.aggregate(query_type, query_text, fused_results, &self.storage, limit)?;
        trace_record!(trace, "query_type", format!("{:?}", query_type));
        trace_record!(trace, "before_count", pre_agg_count);
        trace_record!(trace, "after_count", final_results.len());

        trace_begin!(trace, "adaptive_k");
        // Cut the tail where scores drop off sharply; active only for
        // thresholds strictly inside (0, 1).
        let pre_adaptive_count = final_results.len();
        if self.adaptive_k_threshold > 0.0
            && self.adaptive_k_threshold < 1.0
            && final_results.len() > 1
        {
            let new_len = adaptive_k_select(&final_results, self.adaptive_k_threshold, limit);
            final_results.truncate(new_len);
        }
        trace_record!(trace, "threshold", self.adaptive_k_threshold);
        trace_record!(trace, "before_count", pre_adaptive_count);
        trace_record!(trace, "after_count", final_results.len());

        trace_begin!(trace, "result_summary");
        trace_record!(trace, "total_results", final_results.len());
        if let Some(ref mut t) = trace {
            let top_ids: Vec<String> = final_results
                .iter()
                .take(5)
                .map(|r| r.id.to_string())
                .collect();
            let top_scores: Vec<f32> = final_results
                .iter()
                .take(5)
                .map(|r| r.fused_score)
                .collect();
            let top_sem: Vec<f32> = final_results
                .iter()
                .take(5)
                .map(|r| r.semantic_score)
                .collect();
            t.record("top_5_ids", top_ids);
            t.record("top_5_fused_scores", top_scores);
            t.record("top_5_semantic_scores", top_sem);
        }

        Ok((intent, final_results, matched_facts))
    }
1063
1064 fn retrieve_entity_focused(
1074 &self,
1075 entity_name: &str,
1076 query_text: &str,
1077 query_embedding: &[f32],
1078 limit: usize,
1079 namespace: Option<&str>,
1080 filters: Option<&[MetadataFilter]>,
1081 intent: IntentClassification,
1082 ) -> Result<(
1083 IntentClassification,
1084 Vec<FusedResult>,
1085 Vec<crate::query::profile_search::MatchedProfileFact>,
1086 )> {
1087 let entity = self.storage.find_entity_by_name(entity_name)?;
1089
1090 if entity.is_none() {
1091 return Ok((intent, vec![], vec![]));
1093 }
1094
1095 let entity = entity.unwrap();
1096
1097 let graph = self.graph_manager.read().unwrap();
1099 let entity_result = graph.get_entity_memories(&entity.id);
1100 let mut memory_id_set: std::collections::HashSet<MemoryId> =
1101 entity_result.memories.into_iter().collect();
1102 drop(graph); if let Ok(Some(profile)) = self.storage.get_entity_profile(entity_name) {
1110 for source_id in &profile.source_memories {
1111 memory_id_set.insert(source_id.clone());
1112 }
1113 }
1114
1115 if memory_id_set.is_empty() {
1116 return Ok((intent, vec![], vec![]));
1117 }
1118
1119 let memory_ids: Vec<MemoryId> = memory_id_set.into_iter().collect();
1120
1121 let profile_boost_memories: HashMap<MemoryId, f32> = self
1125 .profile_search
1126 .compute_fact_boosts(entity_name, query_text, query_embedding)?;
1127
1128 let matched_facts = self
1130 .profile_search
1131 .search(query_text, query_embedding, limit)?
1132 .matched_facts;
1133
1134 let mut scored_results = Vec::new();
1136
1137 for memory_id in memory_ids {
1138 let memory = match self.storage.get_memory_by_u64(memory_id.to_u64())? {
1140 Some(m) => m,
1141 None => continue, };
1143
1144 if let Some(ns) = namespace {
1146 if memory.get_namespace() != ns {
1147 continue;
1148 }
1149 }
1150
1151 if let Some(filter_list) = filters {
1153 if !filter_list.is_empty() && !Self::memory_matches_filters(&memory, filter_list) {
1154 continue;
1155 }
1156 }
1157
1158 let semantic_score = if !memory.embedding.is_empty()
1160 && query_embedding.len() == memory.embedding.len()
1161 {
1162 cosine_similarity(query_embedding, &memory.embedding)
1163 } else {
1164 0.0
1165 };
1166
1167 let bm25_score = {
1169 let results = self.bm25_index.search(query_text, 1000)?;
1170 results
1171 .iter()
1172 .find(|r| r.memory_id == memory_id)
1173 .map(|r| r.score)
1174 .unwrap_or(0.0)
1175 };
1176
1177 let profile_boost = profile_boost_memories
1179 .get(&memory_id)
1180 .copied()
1181 .unwrap_or(0.0);
1182
1183 let speaker_factor = match memory.get_metadata("speaker") {
1188 Some(speaker) if speaker.to_lowercase() == entity_name.to_lowercase() => 1.0,
1189 Some(_) => 0.4, None => 0.7, };
1192
1193 let base_score = 0.7 * semantic_score + 0.3 * bm25_score;
1195 let combined_score = if profile_boost > 0.0 {
1196 (base_score + profile_boost * 0.5).min(1.0) * speaker_factor
1198 } else {
1199 base_score * speaker_factor
1200 };
1201
1202 let confidence = if profile_boost > 0.0 { 0.95 } else { 0.9 };
1204
1205 scored_results.push(FusedResult {
1206 id: memory_id.clone(),
1207 semantic_score,
1208 bm25_score,
1209 temporal_score: 0.0,
1210 causal_score: 0.0,
1211 entity_score: 1.0 + profile_boost, fused_score: combined_score,
1213 confidence,
1214 });
1215 }
1216
1217 scored_results.sort_by(|a, b| {
1219 b.fused_score
1220 .partial_cmp(&a.fused_score)
1221 .unwrap_or(std::cmp::Ordering::Equal)
1222 });
1223
1224 let final_results = scored_results.into_iter().take(limit).collect();
1225
1226 Ok((intent, final_results, matched_facts))
1227 }
1228
1229 #[allow(dead_code)]
1238 fn extract_query_entity(&self, query_text: &str) -> Option<String> {
1239 let entities = self.profile_search.detect_entities(query_text).ok()?;
1240 if entities.len() == 1 {
1241 Some(entities[0].clone())
1242 } else {
1243 None
1244 }
1245 }
1246
1247 fn semantic_search(
1249 &self,
1250 query_embedding: &[f32],
1251 limit: usize,
1252 ) -> Result<HashMap<MemoryId, f32>> {
1253 let index = self.vector_index.read().unwrap();
1254 let results = index.search(query_embedding, limit)?;
1255
1256 let mut scores = HashMap::new();
1257 for result in results {
1258 scores.insert(result.id, result.similarity);
1259 }
1260
1261 Ok(scores)
1262 }
1263
1264 fn bm25_search(&self, query_text: &str, limit: usize) -> Result<HashMap<MemoryId, f32>> {
1269 let results = self.bm25_index.search(query_text, limit)?;
1270
1271 let mut scores = HashMap::new();
1272 for result in results {
1273 scores.insert(result.memory_id, result.score);
1274 }
1275
1276 FusionEngine::normalize_scores(&mut scores);
1278
1279 Ok(scores)
1280 }
1281
    /// Score recently-ingested memories for temporal relevance to the query.
    ///
    /// Over-fetches from the recency index, scores each memory first via
    /// SLM-extracted temporal metadata and, failing that, via its stored
    /// temporal expressions; then normalizes and keeps only the top `limit`.
    fn temporal_search(&self, query_text: &str, limit: usize) -> Result<HashMap<MemoryId, f32>> {
        // Over-fetch: most recent memories carry no temporal signal, so cast
        // a wide net before scoring.
        let fetch_limit = limit * 20;
        let recent_memories = self.temporal_index.recent(fetch_limit)?;

        let mut scores = HashMap::new();

        for temporal_result in recent_memories {
            if let Some(memory) = self.storage.get_memory(&temporal_result.id)? {
                let mut temporal_score = 0.0;

                // Prefer SLM-extracted metadata when available.
                if let Some(slm_metadata) = Self::get_slm_metadata(&memory) {
                    temporal_score = Self::calculate_slm_temporal_score(query_text, &slm_metadata);
                }

                // Fallback: match stored temporal expressions against the query.
                if temporal_score == 0.0 {
                    if let Some(content_score) =
                        self.get_temporal_content_score(&memory, query_text)
                    {
                        temporal_score = content_score;
                    }
                }

                if temporal_score > 0.0 {
                    scores.insert(temporal_result.id.clone(), temporal_score);
                }
            }
        }

        if !scores.is_empty() {
            FusionEngine::normalize_scores(&mut scores);

            // Keep only the strongest `limit` temporal matches.
            let mut score_vec: Vec<_> = scores.into_iter().collect();
            score_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
            score_vec.truncate(limit);

            Ok(score_vec.into_iter().collect())
        } else {
            Ok(HashMap::new())
        }
    }
1347
    /// Score a memory's stored temporal expressions against the query text.
    ///
    /// Reads the `temporal_expressions` metadata (a JSON array of strings written at
    /// ingest time). Returns `None` when the metadata is absent/unparseable/empty or
    /// when nothing matched; otherwise a score in (0, 1].
    fn get_temporal_content_score(&self, memory: &Memory, query_text: &str) -> Option<f32> {
        let temporal_expressions_json = memory.get_metadata("temporal_expressions")?;
        let memory_expressions: Vec<String> =
            serde_json::from_str(temporal_expressions_json).ok()?;

        if memory_expressions.is_empty() {
            return None;
        }

        let query_lower = query_text.to_lowercase();

        // Keyword vocabulary for fuzzy matching: relative days, weekdays, months,
        // times of day, and ordering words. NOTE(review): "may" also matches the
        // modal verb "may" in queries — potential false positive; confirm intent.
        let temporal_keywords = [
            "yesterday",
            "today",
            "tomorrow",
            "last week",
            "next week",
            "last month",
            "next month",
            "last year",
            "next year",
            "monday",
            "tuesday",
            "wednesday",
            "thursday",
            "friday",
            "saturday",
            "sunday",
            "january",
            "february",
            "march",
            "april",
            "may",
            "june",
            "july",
            "august",
            "september",
            "october",
            "november",
            "december",
            "morning",
            "afternoon",
            "evening",
            "night",
            "recently",
            "earlier",
            "later",
            "before",
            "after",
            "first",
            "last",
            "initially",
            "finally",
            "eventually",
        ];

        let mut match_score: f32 = 0.0;

        for memory_expr in &memory_expressions {
            let expr_lower = memory_expr.to_lowercase();

            // Exact substring match of the stored expression is the strongest signal.
            if query_lower.contains(&expr_lower) {
                match_score += 1.0;
            } else {
                // Weaker signal: query and expression share a temporal keyword.
                for keyword in &temporal_keywords {
                    if query_lower.contains(keyword) && expr_lower.contains(keyword) {
                        match_score += 0.7;
                        break;
                    }
                }
            }
        }

        if match_score > 0.0 {
            // Average over the number of stored expressions, capped at 1.0.
            Some((match_score / memory_expressions.len() as f32).min(1.0))
        } else {
            None
        }
    }
1435
    #[allow(dead_code)]
    /// Recency-only fallback scoring: newer memories score higher, with an
    /// exponential half-life decay blended with a small rank component.
    ///
    /// The effective timestamp prefers an `event_date` metadata field (ISO-8601
    /// date) over the index timestamp. Scores are scaled into (0, 0.3] so this
    /// weak signal cannot dominate fused results; entries below 0.01 are dropped.
    fn temporal_search_recency_fallback(&self, limit: usize) -> Result<HashMap<MemoryId, f32>> {
        let results = self.temporal_index.recent(limit)?;

        let count = results.len();
        let mut scores = HashMap::new();
        let now_micros = crate::types::Timestamp::now().as_micros();

        // Score halves every 365 days of age.
        const HALF_LIFE_DAYS: f64 = 365.0;
        const MICROS_PER_DAY: f64 = 86_400.0 * 1_000_000.0;
        // Minimum fraction of the scale retained for arbitrarily old memories.
        const FLOOR: f64 = 0.05;

        for (i, result) in results.into_iter().enumerate() {
            let score = if count <= 1 {
                // Single candidate: no relative ranking possible, use mid-scale.
                0.3
            } else {
                // Prefer the explicit event date over the ingest timestamp.
                let effective_micros = self
                    .storage
                    .get_memory(&result.id)
                    .ok()
                    .flatten()
                    .and_then(|m| m.get_metadata("event_date").cloned())
                    .and_then(|d| Timestamp::from_iso8601_date(&d))
                    .map(|ts| ts.as_micros())
                    .unwrap_or(result.timestamp.as_micros());

                let age_micros = now_micros.saturating_sub(effective_micros);
                let age_days = age_micros as f64 / MICROS_PER_DAY;
                let half_life = 0.5_f64.powf(age_days / HALF_LIFE_DAYS);

                // Linear rank signal: 1.0 for the newest, 0.0 for the oldest.
                // `count > 1` here, so the division is safe.
                let rank_signal = 1.0 - (i as f64 / (count - 1) as f64);

                // Decay dominates; rank only breaks ties between same-age entries.
                let blended = 0.96 * half_life + 0.04 * rank_signal;

                (0.3 * (FLOOR + (1.0 - FLOOR) * blended)) as f32
            };
            if score > 0.01 {
                scores.insert(result.id, score);
            }
        }

        Ok(scores)
    }
1501
    /// Score memories by entity overlap with entities detected in the query.
    ///
    /// Primary path: resolve query entities through storage + the entity graph,
    /// collecting memories attached to the entity and to directly related
    /// entities. Fallback path (no graph hit at all): BM25 search per entity
    /// name. Returns at most `limit` normalized scores.
    fn entity_search(&self, query_text: &str, limit: usize) -> Result<HashMap<MemoryId, f32>> {
        let query_entities = self.profile_search.detect_entities(query_text)?;

        if query_entities.is_empty() {
            return Ok(HashMap::new());
        }

        // Lowercased set for case-insensitive overlap scoring below.
        let query_entity_set: std::collections::HashSet<String> =
            query_entities.iter().map(|e| e.to_lowercase()).collect();

        let mut scores = HashMap::new();
        let mut found_via_graph = false;

        {
            // Hold the graph read lock only while gathering candidates.
            let graph = self.graph_manager.read().unwrap();
            let mut candidate_ids = Vec::new();
            // Memories reached only via related entities (scored lower: 0.4).
            let mut relationship_candidates: std::collections::HashSet<MemoryId> =
                std::collections::HashSet::new();

            for query_entity in &query_entities {
                if let Ok(Some(entity)) = self.storage.find_entity_by_name(query_entity) {
                    found_via_graph = true;
                    // Memories directly attached to the matched entity.
                    let entity_result = graph.get_entity_memories(&entity.id);
                    for memory_id in entity_result.memories {
                        candidate_ids.push(memory_id);
                    }

                    // One-hop expansion: memories of related entities.
                    let related = graph.get_related_entities(&entity.id);
                    for (related_entity_id, _relation_type) in related {
                        let related_memories = graph.get_entity_memories(&related_entity_id);
                        for memory_id in related_memories.memories {
                            relationship_candidates.insert(memory_id.clone());
                            candidate_ids.push(memory_id);
                        }
                    }
                }
            }

            // Release the lock before the storage-heavy scoring loop.
            drop(graph); for memory_id in candidate_ids {
                if let Ok(Some(memory)) = self.storage.get_memory(&memory_id) {
                    let mut entity_score = 0.0;

                    // Preferred: score from SLM-extracted entity metadata.
                    if let Some(slm_metadata) = Self::get_slm_metadata(&memory) {
                        entity_score =
                            Self::calculate_slm_entity_score(&query_entity_set, &slm_metadata);
                    }

                    // Fallback: overlap with stored `entity_names` metadata.
                    if entity_score == 0.0 {
                        if let Some(entities_json) = memory.get_metadata("entity_names") {
                            if let Ok(memory_entities) =
                                serde_json::from_str::<Vec<String>>(entities_json)
                            {
                                entity_score = Self::calculate_entity_overlap(
                                    &query_entity_set,
                                    &memory_entities,
                                );
                            }
                        }
                    }

                    // Last resort: a flat score based on how the memory was reached
                    // (direct entity hit 0.5, related-entity hop 0.4).
                    if entity_score == 0.0 {
                        entity_score = if relationship_candidates.contains(&memory_id) {
                            0.4
                        } else {
                            0.5
                        };
                    }

                    // A memory can be reached via several entities; keep the best score.
                    scores
                        .entry(memory_id)
                        .and_modify(|s: &mut f32| *s = (*s).max(entity_score))
                        .or_insert(entity_score);
                }
            }
        }

        // No graph hit at all: fall back to lexical search on the entity names.
        if !found_via_graph {
            for query_entity in &query_entities {
                let bm25_results = self.bm25_index.search(query_entity, limit * 5)?;
                for result in bm25_results {
                    if result.score > 0.0 {
                        // Squash unbounded BM25 scores into (0, 1).
                        let normalized = result.score / (result.score + 1.0);
                        scores
                            .entry(result.memory_id)
                            .and_modify(|s: &mut f32| *s = (*s + normalized).min(1.0))
                            .or_insert(normalized);
                    }
                }
            }
        }

        if !scores.is_empty() {
            FusionEngine::normalize_scores(&mut scores);

            let mut score_vec: Vec<_> = scores.into_iter().collect();
            score_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
            score_vec.truncate(limit);

            Ok(score_vec.into_iter().collect())
        } else {
            Ok(HashMap::new())
        }
    }
1629
1630 fn calculate_entity_overlap(
1637 query_entity_set: &std::collections::HashSet<String>,
1638 memory_entities: &[String],
1639 ) -> f32 {
1640 if query_entity_set.is_empty() || memory_entities.is_empty() {
1641 return 0.0;
1642 }
1643
1644 let mut total_score = 0.0;
1645
1646 let memory_entity_set: std::collections::HashSet<String> =
1648 memory_entities.iter().map(|e| e.to_lowercase()).collect();
1649
1650 for query_entity in query_entity_set {
1651 if memory_entity_set.contains(query_entity) {
1653 total_score += 1.0;
1654 } else {
1655 for memory_entity in &memory_entity_set {
1657 if memory_entity.contains(query_entity) || query_entity.contains(memory_entity)
1658 {
1659 total_score += 0.5;
1660 break;
1661 }
1662 }
1663 }
1664 }
1665
1666 if !query_entity_set.is_empty() {
1668 total_score / query_entity_set.len() as f32
1669 } else {
1670 0.0
1671 }
1672 }
1673
    /// Score recent memories by causal relevance, only for queries that carry
    /// causal intent (e.g. "why"/"because"-style questions per the extractor).
    ///
    /// Prefers SLM causal metadata; falls back to the stored `causal_density`
    /// metadata combined with whether the memory has cause/effect edges in the
    /// graph. Returns at most `limit` normalized scores.
    fn causal_search(&self, query_text: &str, limit: usize) -> Result<HashMap<MemoryId, f32>> {
        let causal_extractor = get_causal_extractor();

        // Cheap gate: skip the whole scan for non-causal queries.
        if !causal_extractor.has_causal_intent(query_text) {
            return Ok(HashMap::new());
        }

        // Over-fetch so filtering still leaves enough candidates.
        let fetch_limit = limit * 20;
        let recent_memories = self.temporal_index.recent(fetch_limit)?;

        let mut scores = HashMap::new();

        for temporal_result in recent_memories {
            if let Some(memory) = self.storage.get_memory(&temporal_result.id)? {
                let mut causal_score = 0.0;

                // Primary signal: SLM-extracted causal metadata.
                if let Some(slm_metadata) = Self::get_slm_metadata(&memory) {
                    causal_score = Self::calculate_slm_causal_score(query_text, &slm_metadata);
                }

                // Fallback: stored causal density + graph connectivity.
                if causal_score == 0.0 {
                    if let Some(density_str) = memory.get_metadata("causal_density") {
                        if let Ok(causal_density) = density_str.parse::<f32>() {
                            if causal_density > 0.1 {
                                // Does this memory participate in any causal chain?
                                // Lock scope is limited to this check.
                                let has_graph_links = {
                                    let graph = self.graph_manager.read().unwrap();
                                    graph
                                        .get_causes(&temporal_result.id, 1)
                                        .ok()
                                        .is_some_and(|r| !r.paths.is_empty())
                                        || graph
                                            .get_effects(&temporal_result.id, 1)
                                            .ok()
                                            .is_some_and(|r| !r.paths.is_empty())
                                };

                                causal_score = causal_extractor
                                    .calculate_relevance_score(causal_density, has_graph_links);
                            }
                        }
                    }
                }

                if causal_score > 0.0 {
                    scores.insert(temporal_result.id.clone(), causal_score);
                }
            }
        }

        if !scores.is_empty() {
            FusionEngine::normalize_scores(&mut scores);

            // Keep only the top `limit` after normalization.
            let mut score_vec: Vec<_> = scores.into_iter().collect();
            score_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
            score_vec.truncate(limit);

            Ok(score_vec.into_iter().collect())
        } else {
            Ok(HashMap::new())
        }
    }
1758
    /// Return memories whose timestamps fall in `[start, end]`, newest-biased,
    /// optionally restricted to a namespace, capped at `limit` results.
    ///
    /// Scores are purely positional: the first surviving result gets 1.0, later
    /// ones decay linearly against `fetch_limit`. All other per-signal scores in
    /// the returned `FusedResult`s are zero.
    pub fn temporal_range_query(
        &self,
        start: Timestamp,
        end: Timestamp,
        limit: usize,
        namespace: Option<&str>,
    ) -> Result<Vec<FusedResult>> {
        // Over-fetch when a namespace filter will discard some of the range.
        let fetch_limit = if namespace.is_some() {
            limit * 3
        } else {
            limit
        };
        let results = self.temporal_index.range_query(start, end, fetch_limit)?;

        let mut fused: Vec<FusedResult> = Vec::new();
        // Rank among *surviving* (namespace-matching) results only.
        let mut position = 0;

        for result in results {
            if let Some(ns) = namespace {
                if let Ok(Some(memory)) = self.storage.get_memory(&result.id) {
                    if memory.get_namespace() != ns {
                        continue; }
                } else {
                    // Memory missing from storage: skip rather than fail.
                    continue; }
            }

            // Linear decay by surviving rank; degenerate fetch_limit gets 1.0.
            let temporal_score = if fetch_limit > 1 {
                1.0 - (position as f32 / (fetch_limit - 1) as f32)
            } else {
                1.0
            };

            fused.push(FusedResult {
                id: result.id,
                semantic_score: 0.0,
                bm25_score: 0.0,
                temporal_score,
                causal_score: 0.0,
                entity_score: 0.0,
                fused_score: temporal_score,
                confidence: 1.0,
            });

            position += 1;
            if fused.len() >= limit {
                break;
            }
        }

        Ok(fused)
    }
1826
1827 fn contains_whole_word_static(text: &str, word: &str) -> bool {
1830 for (idx, _) in text.match_indices(word) {
1831 let before_ok = idx == 0 || !text.as_bytes()[idx - 1].is_ascii_alphanumeric();
1832 let after_idx = idx + word.len();
1833 let after_ok =
1834 after_idx >= text.len() || !text.as_bytes()[after_idx].is_ascii_alphanumeric();
1835 if before_ok && after_ok {
1836 return true;
1837 }
1838 }
1839 false
1840 }
1841
1842 fn parse_dialog_id(dialog_id: &str) -> Option<(&str, i32)> {
1847 let colon = dialog_id.rfind(':')?;
1848 let conv = &dialog_id[..colon];
1849 let turn: i32 = dialog_id[colon + 1..].parse().ok()?;
1850 if conv.is_empty() {
1851 return None;
1852 }
1853 Some((conv, turn))
1854 }
1855
1856 fn filter_by_namespace(
1860 &self,
1861 scores: &mut HashMap<MemoryId, f32>,
1862 namespace: &str,
1863 ) -> Result<()> {
1864 let mut to_remove = Vec::new();
1866
1867 for memory_id in scores.keys() {
1868 if let Some(memory) = self.storage.get_memory_by_u64(memory_id.to_u64())? {
1870 if memory.get_namespace() != namespace {
1871 to_remove.push(memory_id.clone());
1872 }
1873 } else {
1874 to_remove.push(memory_id.clone());
1876 }
1877 }
1878
1879 for id in to_remove {
1881 scores.remove(&id);
1882 }
1883
1884 Ok(())
1885 }
1886
1887 fn filter_by_metadata(
1891 &self,
1892 scores: &mut HashMap<MemoryId, f32>,
1893 filters: &[MetadataFilter],
1894 ) -> Result<()> {
1895 if filters.is_empty() {
1896 return Ok(());
1897 }
1898
1899 let mut to_remove = Vec::new();
1901
1902 for memory_id in scores.keys() {
1903 if let Some(memory) = self.storage.get_memory_by_u64(memory_id.to_u64())? {
1905 if !Self::memory_matches_filters(&memory, filters) {
1907 to_remove.push(memory_id.clone());
1908 }
1909 } else {
1910 to_remove.push(memory_id.clone());
1912 }
1913 }
1914
1915 for id in to_remove {
1917 scores.remove(&id);
1918 }
1919
1920 Ok(())
1921 }
1922
1923 fn memory_matches_filters(memory: &Memory, filters: &[MetadataFilter]) -> bool {
1925 for filter in filters {
1926 let value = memory.metadata.get(&filter.field).map(|s| s.as_str());
1927 if !filter.matches(value) {
1928 return false; }
1930 }
1931 true }
1933
1934 fn get_slm_metadata(memory: &Memory) -> Option<SlmMetadata> {
1939 memory
1940 .get_metadata("slm_metadata")
1941 .and_then(|json| serde_json::from_str(json).ok())
1942 }
1943
1944 fn calculate_slm_entity_score(
1951 query_entities: &std::collections::HashSet<String>,
1952 slm_metadata: &SlmMetadata,
1953 ) -> f32 {
1954 if query_entities.is_empty() || slm_metadata.entities.is_empty() {
1955 return 0.0;
1956 }
1957
1958 let mut total_score = 0.0;
1959 let mut matches = 0;
1960
1961 for query_entity in query_entities {
1962 let query_lower = query_entity.to_lowercase();
1963
1964 for extracted in &slm_metadata.entities {
1965 if extracted.name.to_lowercase() == query_lower {
1967 total_score += 1.0;
1968 matches += 1;
1969 break;
1970 }
1971
1972 for mention in &extracted.mentions {
1974 if mention.to_lowercase() == query_lower {
1975 total_score += 0.9; matches += 1;
1977 break;
1978 }
1979 }
1980
1981 if extracted.name.to_lowercase().contains(&query_lower)
1983 || query_lower.contains(&extracted.name.to_lowercase())
1984 {
1985 total_score += 0.5;
1986 matches += 1;
1987 break;
1988 }
1989 }
1990 }
1991
1992 if matches > 0 {
1993 total_score / query_entities.len() as f32
1994 } else {
1995 0.0
1996 }
1997 }
1998
1999 fn calculate_slm_temporal_score(query_text: &str, slm_metadata: &SlmMetadata) -> f32 {
2007 let query_lower = query_text.to_lowercase();
2008 let temporal = &slm_metadata.temporal;
2009 let mut score: f32 = 0.0;
2010
2011 for marker in &temporal.markers {
2013 if query_lower.contains(&marker.to_lowercase()) {
2014 score += 0.8;
2015 break; }
2017 }
2018
2019 for date in &temporal.absolute_dates {
2021 if query_lower.contains(&date.to_lowercase()) {
2022 score += 0.9;
2023 break;
2024 }
2025 }
2026
2027 if let Some(ref sequence) = temporal.sequence {
2029 let seq_lower = sequence.to_lowercase();
2030 if (query_lower.contains("first")
2032 || query_lower.contains("initial")
2033 || query_lower.contains("begin"))
2034 && seq_lower == "early"
2035 {
2036 score += 0.7;
2037 }
2038 if (query_lower.contains("last")
2040 || query_lower.contains("final")
2041 || query_lower.contains("end"))
2042 && seq_lower == "late"
2043 {
2044 score += 0.7;
2045 }
2046 }
2047
2048 if let Some(ref relative) = temporal.relative_time {
2050 let rel_lower = relative.to_lowercase();
2051 if query_lower.contains("before") && rel_lower.contains("before") {
2052 score += 0.6;
2053 }
2054 if query_lower.contains("after") && rel_lower.contains("after") {
2055 score += 0.6;
2056 }
2057 if (query_lower.contains("recent") || query_lower.contains("latest"))
2058 && rel_lower.contains("current")
2059 {
2060 score += 0.5;
2061 }
2062 }
2063
2064 score.min(1.0) }
2066
2067 fn calculate_slm_causal_score(query_text: &str, slm_metadata: &SlmMetadata) -> f32 {
2074 let causal = &slm_metadata.causal;
2075 let query_lower = query_text.to_lowercase();
2076
2077 let mut score = causal.density * 0.5;
2079
2080 if !causal.explicit_markers.is_empty() {
2082 score += 0.2;
2083 }
2084
2085 if causal.has_implicit_causation {
2087 score += 0.3;
2088 }
2089
2090 for rel in &causal.relationships {
2092 let cause_lower = rel.cause.to_lowercase();
2093 let effect_lower = rel.effect.to_lowercase();
2094
2095 if query_lower.contains(&cause_lower) || cause_lower.contains(&query_lower) {
2097 score += 0.3 * rel.confidence;
2098 }
2099 if query_lower.contains(&effect_lower) || effect_lower.contains(&query_lower) {
2100 score += 0.3 * rel.confidence;
2101 }
2102 }
2103
2104 score.min(1.0) }
2106}
2107
/// Cosine similarity between two vectors, clamped to [0, 1].
///
/// Returns 0.0 for mismatched lengths, empty input, or a zero-magnitude vector.
/// Negative similarities are clamped to 0.0 (anti-correlation is treated as
/// irrelevance for retrieval scoring).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut norm_a_sq = 0.0f32;
    let mut norm_b_sq = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    let magnitude_a = norm_a_sq.sqrt();
    let magnitude_b = norm_b_sq.sqrt();
    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        return 0.0;
    }

    (dot / (magnitude_a * magnitude_b)).clamp(0.0, 1.0)
}
2126
2127fn adaptive_k_select(results: &[FusedResult], threshold: f32, limit: usize) -> usize {
2138 if results.len() <= 1 {
2139 return results.len();
2140 }
2141
2142 let min_k = (limit / 3).max(5).min(results.len());
2143
2144 let max_score = results
2146 .iter()
2147 .map(|r| r.fused_score)
2148 .fold(f32::NEG_INFINITY, f32::max);
2149
2150 let exp_scores: Vec<f32> = results
2151 .iter()
2152 .map(|r| (r.fused_score - max_score).exp())
2153 .collect();
2154
2155 let sum_exp: f32 = exp_scores.iter().sum();
2156
2157 if sum_exp <= 0.0 {
2158 return results.len();
2159 }
2160
2161 let mut cumulative = 0.0f32;
2162 for (i, &exp_s) in exp_scores.iter().enumerate() {
2163 cumulative += exp_s / sum_exp;
2164 if cumulative >= threshold {
2165 return (i + 1).max(min_k);
2166 }
2167 }
2168
2169 results.len()
2170}
2171
2172#[cfg(test)]
2173mod tests {
2174 use super::*;
2175 use crate::{
2176 graph::GraphManager,
2177 index::{VectorIndex, VectorIndexConfig},
2178 storage::StorageEngine,
2179 types::Memory,
2180 };
2181 use tempfile::tempdir;
2182
    /// Build a `QueryPlanner` backed by a fresh temp-dir database, with all
    /// fusion/SLM knobs zeroed/disabled. Returns the TempDir so callers keep
    /// the storage alive for the test's duration.
    fn create_test_planner() -> (QueryPlanner, tempfile::TempDir) {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.mfdb");

        let storage = Arc::new(StorageEngine::open(&path).unwrap());

        // Small HNSW parameters — fast enough for unit tests.
        let vector_config = VectorIndexConfig {
            dimension: 384,
            connectivity: 16,
            expansion_add: 128,
            expansion_search: 64,
        };
        let vector_index = Arc::new(RwLock::new(
            VectorIndex::new(vector_config, Arc::clone(&storage)).unwrap(),
        ));

        let bm25_index = Arc::new(crate::index::BM25Index::new(
            Arc::clone(&storage),
            crate::index::BM25Config::default(),
        ));

        let temporal_index = Arc::new(TemporalIndex::new(Arc::clone(&storage)));
        let graph_manager = Arc::new(RwLock::new(GraphManager::new()));

        // Thresholds zeroed, weighted fusion, rrf_k=60, no SLM classifier,
        // SLM query classification off, adaptive-k disabled.
        let planner = QueryPlanner::new(
            storage,
            vector_index,
            bm25_index,
            temporal_index,
            graph_manager,
            0.0, 0.0, crate::query::FusionStrategy::Weighted, 60.0, #[cfg(feature = "slm")]
            None, false, 0.0, )
        .expect("Failed to create QueryPlanner");

        (planner, dir)
    }
2228
    // Smoke test: planner construction over a fresh temp database succeeds.
    #[test]
    fn test_query_planner_creation() {
        let (_planner, _dir) = create_test_planner();
    }

    // Storing a memory and indexing its embedding makes it retrievable by
    // an identical query vector.
    #[test]
    fn test_semantic_search() {
        let (planner, _dir) = create_test_planner();

        let memory = Memory::new("Test content".to_string(), vec![0.1; 384]);
        let mem_id = memory.id.clone();
        planner.storage.store_memory(&memory).unwrap();
        {
            let mut index = planner.vector_index.write().unwrap();
            index.add(mem_id.clone(), &memory.embedding).unwrap();
        }

        // Query with the exact stored embedding — must hit.
        let scores = planner.semantic_search(&vec![0.1; 384], 10).unwrap();
        assert!(!scores.is_empty());
        assert!(scores.len() > 0);
    }

    // temporal_search only scores memories with temporal signals; a query and
    // memories with no temporal expressions yield an empty score map.
    #[test]
    fn test_temporal_search() {
        let (planner, _dir) = create_test_planner();

        let mem1 = Memory::new("Memory 1".to_string(), vec![0.1; 384]);
        let mem2 = Memory::new("Memory 2".to_string(), vec![0.2; 384]);

        planner.storage.store_memory(&mem1).unwrap();
        planner.storage.store_memory(&mem2).unwrap();

        planner
            .temporal_index
            .add(&mem1.id, mem1.created_at)
            .unwrap();
        planner
            .temporal_index
            .add(&mem2.id, mem2.created_at)
            .unwrap();

        let scores = planner.temporal_search("test query", 10).unwrap();
        assert_eq!(
            scores.len(),
            0,
            "No temporal expression match → empty scores"
        );
    }

    // With no entities ingested, entity_search finds nothing even when the
    // query names a person.
    #[test]
    fn test_entity_search() {
        let (planner, _dir) = create_test_planner();

        let scores = planner.entity_search("Show me Alice", 10).unwrap();
        assert_eq!(scores.len(), 0);
    }
2297
    // End-to-end query: a stored + indexed memory is classified (Factual) and
    // returned by the full fused query path.
    #[test]
    fn test_full_query() {
        let (planner, _dir) = create_test_planner();

        let memory = Memory::new("Test content".to_string(), vec![0.1; 384]);
        let mem_id = memory.id.clone();
        planner.storage.store_memory(&memory).unwrap();
        {
            let mut index = planner.vector_index.write().unwrap();
            index.add(mem_id.clone(), &memory.embedding).unwrap();
        }
        planner
            .temporal_index
            .add(&mem_id, memory.created_at)
            .unwrap();

        let (intent, results, _matched_facts) = planner
            .query("test query", &vec![0.1; 384], 10, None, None, None, None)
            .unwrap();

        // "test query" carries no temporal/causal/entity cues → Factual intent.
        assert_eq!(intent.intent, crate::query::intent::QueryIntent::Factual);

        assert!(!results.is_empty());
        assert!(results.iter().any(|r| r.id.to_u64() == mem_id.to_u64()));
    }

    // A memory timestamped yesterday falls inside a [now-2d, now] range query.
    #[test]
    fn test_temporal_range_query() {
        let (planner, _dir) = create_test_planner();

        let now = Timestamp::now();
        let mem1 = Memory::new_with_timestamp(
            "Memory 1".to_string(),
            vec![0.1; 384],
            now.subtract_days(1),
        );

        planner.storage.store_memory(&mem1).unwrap();
        planner
            .temporal_index
            .add(&mem1.id, mem1.created_at)
            .unwrap();

        let results = planner
            .temporal_range_query(now.subtract_days(2), now, 10, None)
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, mem1.id);
    }
2356
    // Namespace filtering on the full query path: each namespace (including the
    // default, requested via Some("")) only ever returns its own memories.
    #[test]
    fn test_query_with_namespace_filtering() {
        let (planner, _dir) = create_test_planner();

        let mut mem1 = Memory::new("NS1 memory".to_string(), vec![0.1; 384]);
        mem1.set_namespace("ns1");
        let mem1_id = mem1.id.clone();

        let mut mem2 = Memory::new("NS2 memory".to_string(), vec![0.15; 384]);
        mem2.set_namespace("ns2");
        let mem2_id = mem2.id.clone();

        // mem3 stays in the default (empty) namespace.
        let mem3 = Memory::new("Default memory".to_string(), vec![0.2; 384]);
        let mem3_id = mem3.id.clone();

        planner.storage.store_memory(&mem1).unwrap();
        planner.storage.store_memory(&mem2).unwrap();
        planner.storage.store_memory(&mem3).unwrap();

        {
            let mut index = planner.vector_index.write().unwrap();
            index.add(mem1_id.clone(), &mem1.embedding).unwrap();
            index.add(mem2_id.clone(), &mem2.embedding).unwrap();
            index.add(mem3_id.clone(), &mem3.embedding).unwrap();
        }
        planner
            .temporal_index
            .add(&mem1_id, mem1.created_at)
            .unwrap();
        planner
            .temporal_index
            .add(&mem2_id, mem2.created_at)
            .unwrap();
        planner
            .temporal_index
            .add(&mem3_id, mem3.created_at)
            .unwrap();

        // ns1 query → only mem1.
        let (_, results, _) = planner
            .query("test", &vec![0.1; 384], 10, Some("ns1"), None, None, None)
            .unwrap();

        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.id.to_u64() == mem1_id.to_u64()));

        // ns2 query → only mem2.
        let (_, results, _) = planner
            .query("test", &vec![0.1; 384], 10, Some("ns2"), None, None, None)
            .unwrap();

        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.id.to_u64() == mem2_id.to_u64()));

        // Empty-string namespace selects the default namespace → only mem3.
        let (_, results, _) = planner
            .query("test", &vec![0.1; 384], 10, Some(""), None, None, None)
            .unwrap();

        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.id.to_u64() == mem3_id.to_u64()));
    }

    // temporal_range_query's optional namespace filter: both memories in range
    // without a filter, only the ns1 memory with Some("ns1").
    #[test]
    fn test_temporal_range_query_with_namespace() {
        let (planner, _dir) = create_test_planner();

        let now = Timestamp::now();
        let ts1 = now.subtract_days(1);
        // ts2 is 60 seconds after ts1 so the two memories have distinct timestamps.
        let ts2 = Timestamp::from_unix_secs(ts1.as_unix_secs() + 60.0); let mut mem1 = Memory::new_with_timestamp("NS1 memory".to_string(), vec![0.1; 384], ts1);
        mem1.set_namespace("ns1");
        let mem1_id = mem1.id.clone();

        let mut mem2 = Memory::new_with_timestamp("NS2 memory".to_string(), vec![0.2; 384], ts2);
        mem2.set_namespace("ns2");
        let mem2_id = mem2.id.clone();

        planner.storage.store_memory(&mem1).unwrap();
        planner.storage.store_memory(&mem2).unwrap();
        planner
            .temporal_index
            .add(&mem1_id, mem1.created_at)
            .unwrap();
        planner
            .temporal_index
            .add(&mem2_id, mem2.created_at)
            .unwrap();

        // Sanity: the namespace round-trips through storage.
        let retrieved_mem1 = planner.storage.get_memory(&mem1_id).unwrap().unwrap();
        assert_eq!(retrieved_mem1.get_namespace(), "ns1");

        let all_results = planner
            .temporal_range_query(now.subtract_days(2), now, 10, None)
            .unwrap();
        assert_eq!(
            all_results.len(),
            2,
            "Should find both memories without filter"
        );

        let results = planner
            .temporal_range_query(now.subtract_days(2), now, 10, Some("ns1"))
            .unwrap();

        assert_eq!(results.len(), 1, "Should find exactly one memory in ns1");
        assert_eq!(results[0].id, mem1_id);
    }
2479
    // filter_by_namespace removes score entries whose memory lives in a
    // different namespace.
    #[test]
    fn test_filter_by_namespace() {
        let (planner, _dir) = create_test_planner();

        let mut mem1 = Memory::new("NS1 memory".to_string(), vec![0.1; 384]);
        mem1.set_namespace("ns1");
        let mem1_id = mem1.id.clone();

        let mut mem2 = Memory::new("NS2 memory".to_string(), vec![0.2; 384]);
        mem2.set_namespace("ns2");
        let mem2_id = mem2.id.clone();

        planner.storage.store_memory(&mem1).unwrap();
        planner.storage.store_memory(&mem2).unwrap();

        let mut scores = HashMap::new();
        scores.insert(mem1_id.clone(), 0.9);
        scores.insert(mem2_id.clone(), 0.8);

        planner.filter_by_namespace(&mut scores, "ns1").unwrap();

        // Only the ns1 entry survives.
        assert_eq!(scores.len(), 1);
        assert!(scores.contains_key(&mem1_id));
        assert!(!scores.contains_key(&mem2_id));
    }

    // Exact-equality metadata filter keeps only matching memories.
    #[test]
    fn test_filter_by_metadata_exact_match() {
        let (planner, _dir) = create_test_planner();

        let mut mem1 = Memory::new("Event memory".to_string(), vec![0.1; 384]);
        mem1.metadata
            .insert("type".to_string(), "event".to_string());
        mem1.metadata
            .insert("priority".to_string(), "high".to_string());
        let mem1_id = mem1.id.clone();

        let mut mem2 = Memory::new("Task memory".to_string(), vec![0.2; 384]);
        mem2.metadata.insert("type".to_string(), "task".to_string());
        mem2.metadata
            .insert("priority".to_string(), "low".to_string());
        let mem2_id = mem2.id.clone();

        planner.storage.store_memory(&mem1).unwrap();
        planner.storage.store_memory(&mem2).unwrap();

        let mut scores = HashMap::new();
        scores.insert(mem1_id.clone(), 0.9);
        scores.insert(mem2_id.clone(), 0.8);

        let filters = vec![MetadataFilter::eq("type", "event")];
        planner.filter_by_metadata(&mut scores, &filters).unwrap();

        assert_eq!(scores.len(), 1);
        assert!(scores.contains_key(&mem1_id));
        assert!(!scores.contains_key(&mem2_id));
    }

    // Multiple filters are AND-ed: only the memory matching all of them survives.
    #[test]
    fn test_filter_by_metadata_multiple_filters() {
        let (planner, _dir) = create_test_planner();

        let mut mem1 = Memory::new("High priority event".to_string(), vec![0.1; 384]);
        mem1.metadata
            .insert("type".to_string(), "event".to_string());
        mem1.metadata
            .insert("priority".to_string(), "high".to_string());
        let mem1_id = mem1.id.clone();

        let mut mem2 = Memory::new("Low priority event".to_string(), vec![0.2; 384]);
        mem2.metadata
            .insert("type".to_string(), "event".to_string());
        mem2.metadata
            .insert("priority".to_string(), "low".to_string());
        let mem2_id = mem2.id.clone();

        let mut mem3 = Memory::new("High priority task".to_string(), vec![0.3; 384]);
        mem3.metadata.insert("type".to_string(), "task".to_string());
        mem3.metadata
            .insert("priority".to_string(), "high".to_string());
        let mem3_id = mem3.id.clone();

        planner.storage.store_memory(&mem1).unwrap();
        planner.storage.store_memory(&mem2).unwrap();
        planner.storage.store_memory(&mem3).unwrap();

        let mut scores = HashMap::new();
        scores.insert(mem1_id.clone(), 0.9);
        scores.insert(mem2_id.clone(), 0.8);
        scores.insert(mem3_id.clone(), 0.7);

        let filters = vec![
            MetadataFilter::eq("type", "event"),
            MetadataFilter::eq("priority", "high"),
        ];
        planner.filter_by_metadata(&mut scores, &filters).unwrap();

        // Only mem1 is both an event AND high priority.
        assert_eq!(scores.len(), 1);
        assert!(scores.contains_key(&mem1_id));
    }
2595
    // Numeric comparison filter (gte): priorities 8 and 5 pass ">= 5", 3 does not.
    #[test]
    fn test_filter_by_metadata_comparison_operators() {
        let (planner, _dir) = create_test_planner();

        let mut mem1 = Memory::new("Priority 8".to_string(), vec![0.1; 384]);
        mem1.metadata
            .insert("priority".to_string(), "8".to_string());
        let mem1_id = mem1.id.clone();

        let mut mem2 = Memory::new("Priority 5".to_string(), vec![0.2; 384]);
        mem2.metadata
            .insert("priority".to_string(), "5".to_string());
        let mem2_id = mem2.id.clone();

        let mut mem3 = Memory::new("Priority 3".to_string(), vec![0.3; 384]);
        mem3.metadata
            .insert("priority".to_string(), "3".to_string());
        let mem3_id = mem3.id.clone();

        planner.storage.store_memory(&mem1).unwrap();
        planner.storage.store_memory(&mem2).unwrap();
        planner.storage.store_memory(&mem3).unwrap();

        let mut scores = HashMap::new();
        scores.insert(mem1_id.clone(), 0.9);
        scores.insert(mem2_id.clone(), 0.8);
        scores.insert(mem3_id.clone(), 0.7);

        let filters = vec![MetadataFilter::gte("priority", "5")];
        planner.filter_by_metadata(&mut scores, &filters).unwrap();

        assert_eq!(scores.len(), 2);
        assert!(scores.contains_key(&mem1_id));
        assert!(scores.contains_key(&mem2_id));
        assert!(!scores.contains_key(&mem3_id));
    }

    // Set-membership filter (in_list): "food" and "travel" pass, "work" is dropped.
    #[test]
    fn test_filter_by_metadata_in_operator() {
        let (planner, _dir) = create_test_planner();

        let mut mem1 = Memory::new("Food memory".to_string(), vec![0.1; 384]);
        mem1.metadata
            .insert("category".to_string(), "food".to_string());
        let mem1_id = mem1.id.clone();

        let mut mem2 = Memory::new("Travel memory".to_string(), vec![0.2; 384]);
        mem2.metadata
            .insert("category".to_string(), "travel".to_string());
        let mem2_id = mem2.id.clone();

        let mut mem3 = Memory::new("Work memory".to_string(), vec![0.3; 384]);
        mem3.metadata
            .insert("category".to_string(), "work".to_string());
        let mem3_id = mem3.id.clone();

        planner.storage.store_memory(&mem1).unwrap();
        planner.storage.store_memory(&mem2).unwrap();
        planner.storage.store_memory(&mem3).unwrap();

        let mut scores = HashMap::new();
        scores.insert(mem1_id.clone(), 0.9);
        scores.insert(mem2_id.clone(), 0.8);
        scores.insert(mem3_id.clone(), 0.7);

        let filters = vec![MetadataFilter::in_list(
            "category",
            vec!["food".to_string(), "travel".to_string()],
        )];
        planner.filter_by_metadata(&mut scores, &filters).unwrap();

        assert_eq!(scores.len(), 2);
        assert!(scores.contains_key(&mem1_id));
        assert!(scores.contains_key(&mem2_id));
        assert!(!scores.contains_key(&mem3_id));
    }

    // memory_matches_filters: match, mismatch, missing field, and AND-ed combos.
    #[test]
    fn test_memory_matches_filters() {
        let mut memory = Memory::new("Test".to_string(), vec![0.1; 384]);
        memory
            .metadata
            .insert("type".to_string(), "event".to_string());

        let filters = vec![MetadataFilter::eq("type", "event")];
        assert!(QueryPlanner::memory_matches_filters(&memory, &filters));

        let filters = vec![MetadataFilter::eq("type", "task")];
        assert!(!QueryPlanner::memory_matches_filters(&memory, &filters));

        // Filtering on a field the memory does not have fails the match.
        let filters = vec![MetadataFilter::eq("priority", "high")];
        assert!(!QueryPlanner::memory_matches_filters(&memory, &filters));

        memory
            .metadata
            .insert("priority".to_string(), "high".to_string());
        let filters = vec![
            MetadataFilter::eq("type", "event"),
            MetadataFilter::eq("priority", "high"),
        ];
        assert!(QueryPlanner::memory_matches_filters(&memory, &filters));

        // Any single failing filter fails the whole conjunction.
        let filters = vec![
            MetadataFilter::eq("type", "event"),
            MetadataFilter::eq("priority", "low"),
        ];
        assert!(!QueryPlanner::memory_matches_filters(&memory, &filters));
    }
2717
    // Ingested memories with temporal expressions ("yesterday", a date) are
    // scored by temporal_search for a temporal query; a non-temporal memory
    // scores lower (or not at all).
    #[test]
    fn test_temporal_content_matching() {
        use crate::ingest::IngestionPipeline;

        let (planner, _dir) = create_test_planner();

        // Ingest through the pipeline so temporal_expressions metadata is
        // extracted and stored (SLM disabled via `false`).
        let pipeline = IngestionPipeline::new(
            Arc::clone(&planner.storage),
            Arc::clone(&planner.vector_index),
            Arc::clone(&planner.bm25_index),
            Arc::clone(&planner.temporal_index),
            Arc::clone(&planner.graph_manager),
            false, );

        let mem1 = Memory::new(
            "We had a meeting yesterday about the project".to_string(),
            vec![0.1; 384],
        );
        let mem1_id = mem1.id.clone();
        pipeline.add(mem1).unwrap();

        let mem2 = Memory::new(
            "The conference was on June 15th, 2023".to_string(),
            vec![0.2; 384],
        );
        let mem2_id = mem2.id.clone();
        pipeline.add(mem2).unwrap();

        // Control memory with no temporal content at all.
        let mem3 = Memory::new(
            "Machine learning is a fascinating field".to_string(), vec![0.3; 384],
        );
        let mem3_id = mem3.id.clone();
        pipeline.add(mem3).unwrap();

        let scores = planner
            .temporal_search("What happened yesterday?", 10)
            .unwrap();

        assert!(
            scores.contains_key(&mem1_id),
            "Should find memory with 'yesterday'"
        );

        // Missing entries default to 0.0 for the comparison.
        let score1 = scores.get(&mem1_id).unwrap_or(&0.0);
        let score3 = scores.get(&mem3_id).unwrap_or(&0.0);
        assert!(
            score1 > score3,
            "Memory with matching temporal expression should score higher than non-temporal memory"
        );
    }
2775
2776 #[test]
2777 fn test_temporal_fallback_to_recency() {
2778 use crate::ingest::IngestionPipeline;
2784
2785 let (planner, _dir) = create_test_planner();
2786
2787 let pipeline = IngestionPipeline::new(
2788 Arc::clone(&planner.storage),
2789 Arc::clone(&planner.vector_index),
2790 Arc::clone(&planner.bm25_index),
2791 Arc::clone(&planner.temporal_index),
2792 Arc::clone(&planner.graph_manager),
2793 false,
2794 );
2795
2796 let mem1 = Memory::new("Machine learning techniques".to_string(), vec![0.1; 384]);
2798 let mem1_id = mem1.id.clone();
2799 pipeline.add(mem1).unwrap();
2800
2801 std::thread::sleep(std::time::Duration::from_millis(10));
2802
2803 let mem2 = Memory::new("Deep learning models".to_string(), vec![0.2; 384]);
2804 let mem2_id = mem2.id.clone();
2805 pipeline.add(mem2).unwrap();
2806
2807 let scores = planner.temporal_search("Tell me about AI", 10).unwrap();
2809 assert!(
2810 scores.is_empty(),
2811 "No temporal expressions → empty from temporal_search()"
2812 );
2813
2814 let fallback_scores = planner.temporal_search_recency_fallback(10).unwrap();
2816 assert!(fallback_scores.contains_key(&mem1_id));
2817 assert!(fallback_scores.contains_key(&mem2_id));
2818
2819 let score1 = fallback_scores.get(&mem1_id).unwrap();
2821 let score2 = fallback_scores.get(&mem2_id).unwrap();
2822 assert!(
2823 *score1 <= 0.31 && *score2 <= 0.31,
2824 "Fallback scores should be weak (≤ 0.3)"
2825 );
2826
2827 assert!(
2829 score2 >= score1,
2830 "More recent memory should have higher fallback score"
2831 );
2832 }
2833
2834 #[test]
2835 fn test_temporal_fallback_spread_timestamps() {
2836 use crate::ingest::IngestionPipeline;
2840 use crate::types::Timestamp;
2841
2842 let (planner, _dir) = create_test_planner();
2843
2844 let pipeline = IngestionPipeline::new(
2845 Arc::clone(&planner.storage),
2846 Arc::clone(&planner.vector_index),
2847 Arc::clone(&planner.bm25_index),
2848 Arc::clone(&planner.temporal_index),
2849 Arc::clone(&planner.graph_manager),
2850 false,
2851 );
2852
2853 let old_ts = Timestamp::now().subtract_days(730);
2855 let mem_old = Memory::new_with_timestamp(
2856 "Old conversation about cooking".to_string(),
2857 vec![0.1; 384],
2858 old_ts,
2859 );
2860 let old_id = mem_old.id.clone();
2861 pipeline.add(mem_old).unwrap();
2862
2863 let mid_ts = Timestamp::now().subtract_days(30);
2865 let mem_mid = Memory::new_with_timestamp(
2866 "Recent discussion about travel".to_string(),
2867 vec![0.2; 384],
2868 mid_ts,
2869 );
2870 let mid_id = mem_mid.id.clone();
2871 pipeline.add(mem_mid).unwrap();
2872
2873 let mem_new = Memory::new("Just talked about music".to_string(), vec![0.3; 384]);
2875 let new_id = mem_new.id.clone();
2876 pipeline.add(mem_new).unwrap();
2877
2878 let scores = planner.temporal_search_recency_fallback(10).unwrap();
2881
2882 let score_old = *scores.get(&old_id).unwrap();
2883 let score_mid = *scores.get(&mid_id).unwrap();
2884 let score_new = *scores.get(&new_id).unwrap();
2885
2886 assert!(
2888 score_new > score_mid,
2889 "Today's memory ({}) should score higher than 30-day old ({})",
2890 score_new,
2891 score_mid
2892 );
2893 assert!(
2894 score_mid > score_old,
2895 "30-day old memory ({}) should score higher than 2-year old ({})",
2896 score_mid,
2897 score_old
2898 );
2899
2900 assert!(
2904 score_new > score_old * 2.0,
2905 "New memory ({}) should be at least 2x the 2-year old ({})",
2906 score_new,
2907 score_old
2908 );
2909
2910 assert!(
2912 score_new <= 0.31,
2913 "Scores should stay in 0-0.3 range, got {}",
2914 score_new
2915 );
2916 }
2917
2918 #[test]
2919 fn test_temporal_fallback_same_timestamps_rank_tiebreaker() {
2920 use crate::ingest::IngestionPipeline;
2924 use crate::types::Timestamp;
2925
2926 let (planner, _dir) = create_test_planner();
2927
2928 let pipeline = IngestionPipeline::new(
2929 Arc::clone(&planner.storage),
2930 Arc::clone(&planner.vector_index),
2931 Arc::clone(&planner.bm25_index),
2932 Arc::clone(&planner.temporal_index),
2933 Arc::clone(&planner.graph_manager),
2934 false,
2935 );
2936
2937 let base_ts = Timestamp::now().subtract_days(100);
2939 for i in 0..5 {
2940 let ts = Timestamp::from_micros(base_ts.as_micros() + (i as u64 * 1000));
2942 let mem = Memory::new_with_timestamp(
2943 format!("Memory number {}", i),
2944 vec![0.1 * (i as f32 + 1.0); 384],
2945 ts,
2946 );
2947 pipeline.add(mem).unwrap();
2948 }
2949
2950 let scores = planner.temporal_search_recency_fallback(10).unwrap();
2953
2954 assert_eq!(scores.len(), 5, "All 5 memories should have scores");
2956
2957 let mut score_vals: Vec<f32> = scores.values().copied().collect();
2959 score_vals.sort_by(|a, b| b.partial_cmp(a).unwrap());
2960
2961 for i in 0..score_vals.len() - 1 {
2963 assert!(
2964 score_vals[i] > score_vals[i + 1],
2965 "Score {} ({}) should be > score {} ({}): rank tiebreaker should differentiate",
2966 i,
2967 score_vals[i],
2968 i + 1,
2969 score_vals[i + 1]
2970 );
2971 }
2972
2973 let spread = score_vals[0] - score_vals[score_vals.len() - 1];
2976 assert!(
2977 spread > 0.001,
2978 "Score spread should be > 0.001 from rank tiebreaker, got {}",
2979 spread
2980 );
2981 assert!(
2982 spread < 0.1,
2983 "Score spread should be modest (< 0.1), got {}",
2984 spread
2985 );
2986 }
2987
2988 #[test]
2989 fn test_temporal_content_matching_absolute_dates() {
2990 use crate::ingest::IngestionPipeline;
2991
2992 let (planner, _dir) = create_test_planner();
2993
2994 let pipeline = IngestionPipeline::new(
2995 Arc::clone(&planner.storage),
2996 Arc::clone(&planner.vector_index),
2997 Arc::clone(&planner.bm25_index),
2998 Arc::clone(&planner.temporal_index),
2999 Arc::clone(&planner.graph_manager),
3000 false,
3001 );
3002
3003 let mem1 = Memory::new(
3005 "The conference was scheduled for June 15th, 2023".to_string(),
3006 vec![0.1; 384],
3007 );
3008 let mem1_id = mem1.id.clone();
3009 pipeline.add(mem1).unwrap();
3010
3011 let mem2 = Memory::new(
3012 "We met in May 2024 to discuss plans".to_string(),
3013 vec![0.2; 384],
3014 );
3015 let mem2_id = mem2.id.clone();
3016 pipeline.add(mem2).unwrap();
3017
3018 let scores = planner
3020 .temporal_search("When was the June conference?", 10)
3021 .unwrap();
3022
3023 assert!(
3025 scores.contains_key(&mem1_id),
3026 "Should find memory with June date"
3027 );
3028
3029 let score1 = scores.get(&mem1_id).unwrap_or(&0.0);
3031 let score2 = scores.get(&mem2_id).unwrap_or(&0.0);
3032 assert!(
3033 score1 >= score2,
3034 "Memory with matching month should score higher or equal to non-matching month"
3035 );
3036 }
3037
3038 #[test]
3039 fn test_entity_content_matching() {
3040 use crate::ingest::IngestionPipeline;
3041
3042 let (planner, _dir) = create_test_planner();
3043
3044 planner
3046 .storage
3047 .store_entity_profile(&crate::types::EntityProfile::new(
3048 crate::types::EntityId::new(),
3049 "Alice".into(),
3050 "person".into(),
3051 ))
3052 .unwrap();
3053
3054 let pipeline = IngestionPipeline::new(
3056 Arc::clone(&planner.storage),
3057 Arc::clone(&planner.vector_index),
3058 Arc::clone(&planner.bm25_index),
3059 Arc::clone(&planner.temporal_index),
3060 Arc::clone(&planner.graph_manager),
3061 true, );
3063
3064 let mem1 = Memory::new(
3066 "Alice presented Project Alpha at the conference".to_string(),
3067 vec![0.1; 384],
3068 );
3069 let mem1_id = mem1.id.clone();
3070 pipeline.add(mem1).unwrap();
3071
3072 let mem2 = Memory::new(
3073 "Bob worked on Project Beta for Acme Corp".to_string(),
3074 vec![0.2; 384],
3075 );
3076 let mem2_id = mem2.id.clone();
3077 pipeline.add(mem2).unwrap();
3078
3079 let mem3 = Memory::new(
3080 "Machine learning is fascinating".to_string(), vec![0.3; 384],
3082 );
3083 let mem3_id = mem3.id.clone();
3084 pipeline.add(mem3).unwrap();
3085
3086 let scores = planner
3088 .entity_search("Tell me about Alice and Project Alpha", 10)
3089 .unwrap();
3090
3091 assert!(
3093 scores.contains_key(&mem1_id),
3094 "Should find memory with matching entities"
3095 );
3096
3097 let score1 = scores.get(&mem1_id).unwrap_or(&0.0);
3099 let score3 = scores.get(&mem3_id).unwrap_or(&0.0);
3100 assert!(
3101 score1 > score3,
3102 "Memory with matching entities should score higher than non-entity memory"
3103 );
3104 }
3105
3106 #[test]
3107 fn test_entity_search_filters_stop_words() {
3108 use crate::ingest::IngestionPipeline;
3109
3110 let (planner, _dir) = create_test_planner();
3111
3112 planner
3114 .storage
3115 .store_entity_profile(&crate::types::EntityProfile::new(
3116 crate::types::EntityId::new(),
3117 "Alice".into(),
3118 "person".into(),
3119 ))
3120 .unwrap();
3121
3122 let pipeline = IngestionPipeline::new(
3123 Arc::clone(&planner.storage),
3124 Arc::clone(&planner.vector_index),
3125 Arc::clone(&planner.bm25_index),
3126 Arc::clone(&planner.temporal_index),
3127 Arc::clone(&planner.graph_manager),
3128 true,
3129 );
3130
3131 let mem1 = Memory::new("Alice met Bob at Building C".to_string(), vec![0.1; 384]);
3133 let mem1_id = mem1.id.clone();
3134 pipeline.add(mem1).unwrap();
3135
3136 let scores = planner
3139 .entity_search("What about Alice and The meeting?", 10)
3140 .unwrap();
3141
3142 assert!(
3144 scores.contains_key(&mem1_id),
3145 "Should find memory with Alice, ignoring stop words"
3146 );
3147 }
3148
3149 #[test]
3150 fn test_entity_search_empty_when_no_entities() {
3151 let (planner, _dir) = create_test_planner();
3152
3153 let scores = planner
3155 .entity_search("tell me about machine learning", 10)
3156 .unwrap();
3157
3158 assert_eq!(
3160 scores.len(),
3161 0,
3162 "Should return empty when query has no entities"
3163 );
3164 }
3165
3166 #[test]
3167 fn test_entity_search_multi_word_entity() {
3168 use crate::ingest::IngestionPipeline;
3169 use crate::types::{EntityFact, EntityId, EntityProfile};
3170
3171 let (planner, _dir) = create_test_planner();
3172
3173 let pipeline = IngestionPipeline::new(
3174 Arc::clone(&planner.storage),
3175 Arc::clone(&planner.vector_index),
3176 Arc::clone(&planner.bm25_index),
3177 Arc::clone(&planner.temporal_index),
3178 Arc::clone(&planner.graph_manager),
3179 true,
3180 );
3181
3182 let mem1 = Memory::new(
3184 "Project Alpha is managed by the team".to_string(),
3185 vec![0.1; 384],
3186 );
3187 let mem1_id = mem1.id.clone();
3188 pipeline.add(mem1).unwrap();
3189
3190 let profile = EntityProfile {
3192 entity_id: EntityId::new(),
3193 name: "Project Alpha".to_string(),
3194 entity_type: "project".to_string(),
3195 facts: std::collections::HashMap::new(),
3196 source_memories: vec![mem1_id.clone()],
3197 updated_at: crate::types::Timestamp::now(),
3198 summary: None,
3199 };
3200 planner.storage.store_entity_profile(&profile).unwrap();
3201
3202 let scores = planner
3204 .entity_search("What is Project Alpha doing?", 10)
3205 .unwrap();
3206
3207 assert!(
3209 scores.contains_key(&mem1_id),
3210 "Should find memory with multi-word entity match"
3211 );
3212 }
3213
3214 #[test]
3215 fn test_entity_overlap_calculation() {
3216 let query_set: std::collections::HashSet<String> =
3218 vec!["alice".to_string(), "bob".to_string()]
3219 .into_iter()
3220 .collect();
3221
3222 let memory_entities = vec!["Alice".to_string(), "Bob".to_string()];
3223
3224 let score = QueryPlanner::calculate_entity_overlap(&query_set, &memory_entities);
3225
3226 assert!((score - 1.0).abs() < 0.01, "Exact match should score 1.0");
3228
3229 let query_set2: std::collections::HashSet<String> =
3231 vec!["alice".to_string()].into_iter().collect();
3232
3233 let memory_entities2 = vec!["Bob".to_string()];
3234
3235 let score2 = QueryPlanner::calculate_entity_overlap(&query_set2, &memory_entities2);
3236
3237 assert_eq!(score2, 0.0, "No match should score 0.0");
3239
3240 let query_set3: std::collections::HashSet<String> =
3242 vec!["project".to_string()].into_iter().collect();
3243
3244 let memory_entities3 = vec!["Project Alpha".to_string()];
3245
3246 let score3 = QueryPlanner::calculate_entity_overlap(&query_set3, &memory_entities3);
3247
3248 assert!(
3250 (score3 - 0.5).abs() < 0.01,
3251 "Substring match should score 0.5"
3252 );
3253 }
3254
3255 #[test]
3256 fn test_causal_content_matching() {
3257 use crate::ingest::IngestionPipeline;
3258
3259 let (planner, _dir) = create_test_planner();
3260
3261 let pipeline = IngestionPipeline::new(
3262 Arc::clone(&planner.storage),
3263 Arc::clone(&planner.vector_index),
3264 Arc::clone(&planner.bm25_index),
3265 Arc::clone(&planner.temporal_index),
3266 Arc::clone(&planner.graph_manager),
3267 false,
3268 );
3269
3270 let mem1 = Memory::new(
3272 "The meeting was cancelled because Alice was sick".to_string(),
3273 vec![0.1; 384],
3274 );
3275 let mem1_id = mem1.id.clone();
3276 pipeline.add(mem1).unwrap();
3277
3278 let mem2 = Memory::new(
3280 "The bug was caused by a race condition which led to crashes".to_string(),
3281 vec![0.2; 384],
3282 );
3283 let mem2_id = mem2.id.clone();
3284 pipeline.add(mem2).unwrap();
3285
3286 let mem3 = Memory::new("We had a nice lunch today".to_string(), vec![0.3; 384]);
3288 let mem3_id = mem3.id.clone();
3289 pipeline.add(mem3).unwrap();
3290
3291 let scores = planner
3293 .causal_search("Why was the meeting cancelled?", 10)
3294 .unwrap();
3295
3296 assert!(
3298 scores.contains_key(&mem1_id) || scores.contains_key(&mem2_id),
3299 "Should find memories with causal language"
3300 );
3301
3302 assert!(
3304 !scores.contains_key(&mem3_id) || scores.get(&mem3_id).unwrap_or(&1.0) < &0.01,
3305 "Should not score non-causal memory highly"
3306 );
3307 }
3308
3309 #[test]
3310 fn test_causal_search_empty_when_no_intent() {
3311 let (planner, _dir) = create_test_planner();
3312
3313 let scores = planner
3315 .causal_search("Tell me about machine learning", 10)
3316 .unwrap();
3317
3318 assert_eq!(
3320 scores.len(),
3321 0,
3322 "Should return empty when query has no causal intent"
3323 );
3324 }
3325
3326 #[test]
3327 fn test_causal_search_with_causal_intent() {
3328 use crate::ingest::IngestionPipeline;
3329
3330 let (planner, _dir) = create_test_planner();
3331
3332 let pipeline = IngestionPipeline::new(
3333 Arc::clone(&planner.storage),
3334 Arc::clone(&planner.vector_index),
3335 Arc::clone(&planner.bm25_index),
3336 Arc::clone(&planner.temporal_index),
3337 Arc::clone(&planner.graph_manager),
3338 false,
3339 );
3340
3341 let mem1 = Memory::new(
3343 "The server crashed because of a memory leak. This was caused by unclosed connections."
3344 .to_string(),
3345 vec![0.1; 384],
3346 );
3347 let mem1_id = mem1.id.clone();
3348 pipeline.add(mem1).unwrap();
3349
3350 let scores = planner
3352 .causal_search("Why did the server crash?", 10)
3353 .unwrap();
3354
3355 assert!(
3357 scores.contains_key(&mem1_id),
3358 "Should find memory with causal explanation for 'why' question"
3359 );
3360
3361 let score = scores.get(&mem1_id).unwrap();
3363 assert!(*score > 0.0, "Causal memory should have positive score");
3364 }
3365
3366 #[test]
3367 fn test_causal_search_density_threshold() {
3368 use crate::ingest::IngestionPipeline;
3369
3370 let (planner, _dir) = create_test_planner();
3371
3372 let pipeline = IngestionPipeline::new(
3373 Arc::clone(&planner.storage),
3374 Arc::clone(&planner.vector_index),
3375 Arc::clone(&planner.bm25_index),
3376 Arc::clone(&planner.temporal_index),
3377 Arc::clone(&planner.graph_manager),
3378 false,
3379 );
3380
3381 let mem1 = Memory::new(
3383 "This is a very long memory with lots of words that do not have causal markers except for this one because word.".to_string(),
3384 vec![0.1; 384],
3385 );
3386 let mem1_id = mem1.id.clone();
3387 pipeline.add(mem1).unwrap();
3388
3389 let scores = planner.causal_search("Why did this happen?", 10).unwrap();
3391
3392 assert!(
3394 !scores.contains_key(&mem1_id),
3395 "Should filter out memories with very low causal density"
3396 );
3397 }
3398
3399 #[test]
3401 fn test_entity_focused_retrieval() {
3402 use crate::ingest::IngestionPipeline;
3403 use crate::Entity;
3404
3405 let (planner, _dir) = create_test_planner();
3406
3407 let pipeline = IngestionPipeline::new(
3408 Arc::clone(&planner.storage),
3409 Arc::clone(&planner.vector_index),
3410 Arc::clone(&planner.bm25_index),
3411 Arc::clone(&planner.temporal_index),
3412 Arc::clone(&planner.graph_manager),
3413 true, );
3415
3416 let entity = Entity::new("Alice");
3418 planner.storage.store_entity(&entity).unwrap();
3419
3420 let mem1 = Memory::new(
3422 "Alice likes playing tennis on weekends".to_string(),
3423 vec![0.1; 384],
3424 );
3425 let mem1_id = mem1.id.clone();
3426 pipeline.add(mem1).unwrap();
3427
3428 let mem2 = Memory::new(
3429 "Alice enjoys reading science fiction books".to_string(),
3430 vec![0.2; 384],
3431 );
3432 let mem2_id = mem2.id.clone();
3433 pipeline.add(mem2).unwrap();
3434
3435 let mem3 = Memory::new("Alice prefers coffee over tea".to_string(), vec![0.3; 384]);
3436 let mem3_id = mem3.id.clone();
3437 pipeline.add(mem3).unwrap();
3438
3439 let mem4 = Memory::new(
3441 "Bob works on machine learning projects".to_string(),
3442 vec![0.4; 384],
3443 );
3444 pipeline.add(mem4).unwrap();
3445
3446 {
3448 let mut graph = planner.graph_manager.write().unwrap();
3449 graph.link_memory_to_entity(&mem1_id, &entity.id);
3450 graph.link_memory_to_entity(&mem2_id, &entity.id);
3451 graph.link_memory_to_entity(&mem3_id, &entity.id);
3452 }
3453
3454 let query_embedding = vec![0.15; 384];
3456 let (intent, results, _matched_facts) = planner
3457 .query(
3458 "What does Alice like?",
3459 &query_embedding,
3460 10,
3461 None,
3462 None,
3463 None,
3464 None,
3465 )
3466 .unwrap();
3467
3468 assert_eq!(intent.intent, crate::query::QueryIntent::Entity);
3470
3471 assert_eq!(intent.entity_focus, Some("Alice".to_string()));
3473
3474 assert!(
3476 results.len() >= 3,
3477 "Should retrieve all memories mentioning Alice"
3478 );
3479
3480 let result_ids: Vec<_> = results.iter().map(|r| &r.id).collect();
3482 assert!(
3483 result_ids.contains(&&mem1_id),
3484 "Should include mem1 (tennis)"
3485 );
3486 assert!(
3487 result_ids.contains(&&mem2_id),
3488 "Should include mem2 (books)"
3489 );
3490 assert!(
3491 result_ids.contains(&&mem3_id),
3492 "Should include mem3 (coffee)"
3493 );
3494
3495 for result in &results {
3497 if result_ids.contains(&&result.id) {
3498 assert_eq!(
3499 result.entity_score, 1.0,
3500 "All entity-focused results should have entity_score=1.0"
3501 );
3502 }
3503 }
3504 }
3505
3506 #[test]
3507 fn test_entity_focused_retrieval_not_found() {
3508 let (planner, _dir) = create_test_planner();
3509
3510 let query_embedding = vec![0.1; 384];
3512 let (_intent, results, _matched_facts) = planner
3513 .query(
3514 "What does NonExistentEntity like?",
3515 &query_embedding,
3516 10,
3517 None,
3518 None,
3519 None,
3520 None,
3521 )
3522 .unwrap();
3523
3524 assert_eq!(
3526 results.len(),
3527 0,
3528 "Should return empty for non-existent entity"
3529 );
3530 }
3531
3532 #[test]
3535 fn test_extract_query_entity_single_name() {
3536 let (planner, _dir) = create_test_planner();
3537
3538 planner
3540 .storage
3541 .store_entity_profile(&crate::types::EntityProfile::new(
3542 crate::types::EntityId::new(),
3543 "Melanie".into(),
3544 "person".into(),
3545 ))
3546 .unwrap();
3547 planner
3548 .storage
3549 .store_entity_profile(&crate::types::EntityProfile::new(
3550 crate::types::EntityId::new(),
3551 "Caroline".into(),
3552 "person".into(),
3553 ))
3554 .unwrap();
3555
3556 assert_eq!(
3557 planner.extract_query_entity("What books has Melanie read?"),
3558 Some("melanie".to_string())
3559 );
3560 assert_eq!(
3561 planner.extract_query_entity("What did Caroline research?"),
3562 Some("caroline".to_string())
3563 );
3564 }
3565
3566 #[test]
3567 fn test_extract_query_entity_possessive() {
3568 let (planner, _dir) = create_test_planner();
3569 planner
3570 .storage
3571 .store_entity_profile(&crate::types::EntityProfile::new(
3572 crate::types::EntityId::new(),
3573 "Caroline".into(),
3574 "person".into(),
3575 ))
3576 .unwrap();
3577
3578 assert_eq!(
3579 planner.extract_query_entity("What is Caroline's relationship status?"),
3580 Some("caroline".to_string())
3581 );
3582 }
3583
3584 #[test]
3585 fn test_extract_query_entity_two_names_returns_none() {
3586 let (planner, _dir) = create_test_planner();
3587 planner
3588 .storage
3589 .store_entity_profile(&crate::types::EntityProfile::new(
3590 crate::types::EntityId::new(),
3591 "Caroline".into(),
3592 "person".into(),
3593 ))
3594 .unwrap();
3595 planner
3596 .storage
3597 .store_entity_profile(&crate::types::EntityProfile::new(
3598 crate::types::EntityId::new(),
3599 "Melanie".into(),
3600 "person".into(),
3601 ))
3602 .unwrap();
3603
3604 assert_eq!(
3606 planner.extract_query_entity(
3607 "When did Caroline and Melanie go to a pride festival together?"
3608 ),
3609 None
3610 );
3611 }
3612
3613 #[test]
3614 fn test_extract_query_entity_no_entity() {
3615 let (planner, _dir) = create_test_planner();
3616 assert_eq!(
3618 planner.extract_query_entity("What happened yesterday?"),
3619 None
3620 );
3621 }
3622
3623 #[test]
3624 fn test_extract_query_entity_lowercase_query() {
3625 let (planner, _dir) = create_test_planner();
3626 planner
3627 .storage
3628 .store_entity_profile(&crate::types::EntityProfile::new(
3629 crate::types::EntityId::new(),
3630 "Caroline".into(),
3631 "person".into(),
3632 ))
3633 .unwrap();
3634
3635 assert_eq!(
3637 planner.extract_query_entity("what does caroline like?"),
3638 Some("caroline".to_string())
3639 );
3640 }
3641
3642 fn make_fused_result(score: f32) -> FusedResult {
3645 FusedResult {
3646 id: crate::types::MemoryId::new(),
3647 semantic_score: score,
3648 bm25_score: 0.0,
3649 temporal_score: 0.0,
3650 causal_score: 0.0,
3651 entity_score: 0.0,
3652 fused_score: score,
3653 confidence: 1.0,
3654 }
3655 }
3656
3657 #[test]
3658 fn test_adaptive_k_concentrated_scores() {
3659 let mut results = vec![make_fused_result(5.0)];
3661 for _ in 0..24 {
3662 results.push(make_fused_result(0.1));
3663 }
3664 let k = adaptive_k_select(&results, 0.7, 25);
3665 assert!(k <= 10, "Expected few results but got {}", k);
3668 assert!(k >= 5, "Expected at least min_k=5 but got {}", k); }
3670
3671 #[test]
3672 fn test_adaptive_k_uniform_scores() {
3673 let results: Vec<FusedResult> = (0..25).map(|_| make_fused_result(1.0)).collect();
3675 let k = adaptive_k_select(&results, 0.7, 25);
3676 assert!(k >= 17, "Expected ~18 results for uniform but got {}", k);
3678 assert!(k <= 20, "Expected ~18 results for uniform but got {}", k);
3679 }
3680
3681 #[test]
3682 fn test_adaptive_k_respects_min_k() {
3683 let mut results = vec![make_fused_result(100.0)];
3685 for _ in 0..24 {
3686 results.push(make_fused_result(0.001));
3687 }
3688 let k = adaptive_k_select(&results, 0.7, 25);
3689 assert!(k >= 8, "Expected at least min_k=8 but got {}", k);
3691 }
3692
3693 #[test]
3694 fn test_adaptive_k_single_result() {
3695 let results = vec![make_fused_result(1.0)];
3696 let k = adaptive_k_select(&results, 0.7, 25);
3697 assert_eq!(k, 1);
3698 }
3699
3700 #[test]
3701 fn test_adaptive_k_empty_results() {
3702 let results: Vec<FusedResult> = vec![];
3703 let k = adaptive_k_select(&results, 0.7, 25);
3704 assert_eq!(k, 0);
3705 }
3706
3707 #[test]
3710 fn test_parse_dialog_id_basic() {
3711 assert_eq!(QueryPlanner::parse_dialog_id("D7:5"), Some(("D7", 5)));
3712 assert_eq!(QueryPlanner::parse_dialog_id("D1:1"), Some(("D1", 1)));
3713 assert_eq!(QueryPlanner::parse_dialog_id("D25:27"), Some(("D25", 27)));
3714 }
3715
3716 #[test]
3717 fn test_parse_dialog_id_invalid() {
3718 assert_eq!(QueryPlanner::parse_dialog_id("D7"), None); assert_eq!(QueryPlanner::parse_dialog_id(":5"), None); assert_eq!(QueryPlanner::parse_dialog_id("D7:abc"), None); }
3722
3723 #[test]
3724 fn test_parse_dialog_id_high_turn() {
3725 assert_eq!(QueryPlanner::parse_dialog_id("D10:100"), Some(("D10", 100)));
3726 }
3727
3728 #[test]
3729 fn test_contains_whole_word_static_pronouns() {
3730 assert!(QueryPlanner::contains_whole_word_static(
3732 "what do i need",
3733 "i"
3734 ));
3735 assert!(QueryPlanner::contains_whole_word_static(
3736 "give me the answer",
3737 "me"
3738 ));
3739 assert!(QueryPlanner::contains_whole_word_static(
3740 "my favorite color",
3741 "my"
3742 ));
3743 assert!(QueryPlanner::contains_whole_word_static(
3744 "that is mine",
3745 "mine"
3746 ));
3747 assert!(QueryPlanner::contains_whole_word_static(
3748 "i'm going home",
3749 "i'm"
3750 ));
3751 assert!(QueryPlanner::contains_whole_word_static(
3752 "i've been there",
3753 "i've"
3754 ));
3755 assert!(!QueryPlanner::contains_whole_word_static(
3757 "imagine this",
3758 "i"
3759 ));
3760 assert!(!QueryPlanner::contains_whole_word_static(
3761 "time flies",
3762 "me"
3763 ));
3764 assert!(!QueryPlanner::contains_whole_word_static(
3765 "myth or fact",
3766 "my"
3767 ));
3768 assert!(!QueryPlanner::contains_whole_word_static(
3769 "undermine the case",
3770 "mine"
3771 ));
3772 }
3773
3774 #[test]
3775 fn test_adaptive_k_gradual_dropoff() {
3776 let results: Vec<FusedResult> = (0..10)
3778 .map(|i| make_fused_result(1.0 - i as f32 * 0.1))
3779 .collect();
3780 let k = adaptive_k_select(&results, 0.7, 25);
3781 assert!(
3784 k >= 5,
3785 "Expected reasonable k for gradual dropoff but got {}",
3786 k
3787 );
3788 }
3789
3790 #[test]
3793 fn test_related_entity_expansion_basic() {
3794 use crate::types::{EntityId, EntityProfile};
3795
3796 let (planner, _dir) = create_test_planner();
3797
3798 let alice_id = EntityId::new();
3800 let bob_id = EntityId::new();
3801
3802 let mem_alice = Memory::new("Alice likes hiking".to_string(), vec![0.1; 384]);
3803 let mem_alice_id = mem_alice.id.clone();
3804 planner.storage.store_memory(&mem_alice).unwrap();
3805
3806 let mem_bob = Memory::new("Bob won a marathon".to_string(), vec![0.2; 384]);
3807 let mem_bob_id = mem_bob.id.clone();
3808 planner.storage.store_memory(&mem_bob).unwrap();
3809
3810 {
3812 let mut idx = planner.vector_index.write().unwrap();
3813 idx.add(mem_alice_id.clone(), &mem_alice.embedding).unwrap();
3814 idx.add(mem_bob_id.clone(), &mem_bob.embedding).unwrap();
3815 }
3816
3817 let alice_profile = EntityProfile {
3818 entity_id: alice_id.clone(),
3819 name: "Alice".to_string(),
3820 entity_type: "person".to_string(),
3821 facts: std::collections::HashMap::new(),
3822 source_memories: vec![mem_alice_id.clone()],
3823 updated_at: crate::types::Timestamp::now(),
3824 summary: None,
3825 };
3826 let bob_profile = EntityProfile {
3827 entity_id: bob_id.clone(),
3828 name: "Bob".to_string(),
3829 entity_type: "person".to_string(),
3830 facts: std::collections::HashMap::new(),
3831 source_memories: vec![mem_bob_id.clone()],
3832 updated_at: crate::types::Timestamp::now(),
3833 summary: None,
3834 };
3835 planner
3836 .storage
3837 .store_entity_profile(&alice_profile)
3838 .unwrap();
3839 planner.storage.store_entity_profile(&bob_profile).unwrap();
3840
3841 {
3843 let mut graph = planner.graph_manager.write().unwrap();
3844 graph.link_entity_to_entity(&alice_id, &bob_id, "friend");
3845 }
3846
3847 let (_intent, results, _facts) = planner
3849 .query(
3850 "What did Alice's friends achieve?",
3851 &vec![0.15; 384],
3852 10,
3853 None,
3854 None,
3855 None,
3856 None,
3857 )
3858 .unwrap();
3859
3860 let has_bob_memory = results.iter().any(|r| r.id.to_u64() == mem_bob_id.to_u64());
3862 assert!(
3863 has_bob_memory,
3864 "Bob's memory should be discovered via Alice→Bob KG relationship"
3865 );
3866 }
3867
3868 #[test]
3869 fn test_related_entity_expansion_no_downgrade() {
3870 use crate::types::{EntityId, EntityProfile};
3871
3872 let (planner, _dir) = create_test_planner();
3873
3874 let alice_id = EntityId::new();
3876 let bob_id = EntityId::new();
3877
3878 let mem_bob = Memory::new("Bob likes cooking".to_string(), vec![0.2; 384]);
3879 let mem_bob_id = mem_bob.id.clone();
3880 planner.storage.store_memory(&mem_bob).unwrap();
3881 {
3882 let mut idx = planner.vector_index.write().unwrap();
3883 idx.add(mem_bob_id.clone(), &mem_bob.embedding).unwrap();
3884 }
3885
3886 let alice_profile = EntityProfile {
3887 entity_id: alice_id.clone(),
3888 name: "Alice".to_string(),
3889 entity_type: "person".to_string(),
3890 facts: std::collections::HashMap::new(),
3891 source_memories: vec![],
3892 updated_at: crate::types::Timestamp::now(),
3893 summary: None,
3894 };
3895 let bob_profile = EntityProfile {
3896 entity_id: bob_id.clone(),
3897 name: "Bob".to_string(),
3898 entity_type: "person".to_string(),
3899 facts: std::collections::HashMap::new(),
3900 source_memories: vec![mem_bob_id.clone()],
3901 updated_at: crate::types::Timestamp::now(),
3902 summary: None,
3903 };
3904 planner
3905 .storage
3906 .store_entity_profile(&alice_profile)
3907 .unwrap();
3908 planner.storage.store_entity_profile(&bob_profile).unwrap();
3909
3910 {
3912 let mut graph = planner.graph_manager.write().unwrap();
3913 graph.link_entity_to_entity(&alice_id, &bob_id, "friend");
3914 }
3915
3916 let (_intent, results, _facts) = planner
3919 .query(
3920 "What do Alice and Bob enjoy?",
3921 &vec![0.15; 384],
3922 10,
3923 None,
3924 None,
3925 None,
3926 None,
3927 )
3928 .unwrap();
3929
3930 let bob_result = results
3932 .iter()
3933 .find(|r| r.id.to_u64() == mem_bob_id.to_u64());
3934 assert!(bob_result.is_some(), "Bob's memory should be in results");
3935
3936 let bob_entity_score = bob_result.unwrap().entity_score;
3938 assert!(
3939 bob_entity_score >= 0.9, "Bob's entity score should reflect Step 2.1 (2.0), not be downgraded by 2.1c. Got {}",
3941 bob_entity_score
3942 );
3943 }
3944
3945 #[test]
3946 fn test_related_entity_expansion_cap() {
3947 use crate::types::{EntityId, EntityProfile};
3948
3949 let (planner, _dir) = create_test_planner();
3950
3951 let alice_id = EntityId::new();
3952 let bob_id = EntityId::new();
3953
3954 let mut bob_source_memories = Vec::new();
3956 for i in 0..15 {
3957 let mem = Memory::new(
3958 format!("Bob memory {}", i),
3959 vec![0.1 + i as f32 * 0.01; 384],
3960 );
3961 let mid = mem.id.clone();
3962 planner.storage.store_memory(&mem).unwrap();
3963 {
3964 let mut idx = planner.vector_index.write().unwrap();
3965 idx.add(mid.clone(), &mem.embedding).unwrap();
3966 }
3967 bob_source_memories.push(mid);
3968 }
3969
3970 let alice_profile = EntityProfile {
3971 entity_id: alice_id.clone(),
3972 name: "Alice".to_string(),
3973 entity_type: "person".to_string(),
3974 facts: std::collections::HashMap::new(),
3975 source_memories: vec![],
3976 updated_at: crate::types::Timestamp::now(),
3977 summary: None,
3978 };
3979 let bob_profile = EntityProfile {
3980 entity_id: bob_id.clone(),
3981 name: "Bob".to_string(),
3982 entity_type: "person".to_string(),
3983 facts: std::collections::HashMap::new(),
3984 source_memories: bob_source_memories.clone(),
3985 updated_at: crate::types::Timestamp::now(),
3986 summary: None,
3987 };
3988 planner
3989 .storage
3990 .store_entity_profile(&alice_profile)
3991 .unwrap();
3992 planner.storage.store_entity_profile(&bob_profile).unwrap();
3993
3994 {
3996 let mut graph = planner.graph_manager.write().unwrap();
3997 graph.link_entity_to_entity(&alice_id, &bob_id, "friend");
3998 }
3999
4000 let (_intent, results, _facts) = planner
4002 .query(
4003 "Tell me about Alice's friends",
4004 &vec![0.15; 384],
4005 10,
4006 None,
4007 None,
4008 None,
4009 None,
4010 )
4011 .unwrap();
4012
4013 let bob_memories_in_results: Vec<_> = results
4016 .iter()
4017 .filter(|r| {
4018 bob_source_memories
4019 .iter()
4020 .any(|bm| bm.to_u64() == r.id.to_u64())
4021 })
4022 .collect();
4023
4024 assert!(
4028 bob_memories_in_results.len() <= 10,
4029 "Should cap at 10 related entity memories, got {}",
4030 bob_memories_in_results.len()
4031 );
4032 }
4033}