1use std::future::{Future, IntoFuture};
5use std::ops::Deref;
6use std::pin::Pin;
7
8use chrono::{DateTime, FixedOffset, Utc};
9
10use crate::embedding::EmbeddingModel;
11use crate::memory::{KindSelector, Memory, Scope};
12use crate::store::MemoryStore;
13use crate::vector::{FilterCondition, MemoryFilter, NumericRange, VectorIndex};
14
15use super::{Client, ClientError};
16
/// Number of memories a query returns when the builder's `limit` is never
/// called; `QueryBuilder::new` seeds its `limit` field with this value.
pub const DEFAULT_QUERY_LIMIT: usize = 10;

/// Weight given to cosine similarity by `RankingStrategy::default_hybrid`.
/// The hybrid score is `alpha * cosine + (1 - alpha) * recency`, so the
/// remaining `0.3` goes to the recency term.
pub const DEFAULT_HYBRID_ALPHA: f32 = 0.7;

/// Half-life, in days, of the exponential recency decay used by the default
/// hybrid strategy and by `RankingStrategy::blended`.
pub const DEFAULT_HYBRID_HALF_LIFE_DAYS: f32 = 7.0;
25
26#[non_exhaustive]
33#[derive(Debug, Clone, PartialEq)]
34pub enum DecayFn {
35 Exponential {
37 half_life: chrono::Duration,
39 },
40
41 Reciprocal {
43 scale: chrono::Duration,
45 },
46
47 Step {
51 thresholds: Vec<(chrono::Duration, f32)>,
53 },
54}
55
56impl DecayFn {
57 fn evaluate(&self, age: chrono::Duration) -> f32 {
58 let age_secs = age.num_seconds().max(0) as f32;
59 match self {
60 DecayFn::Exponential { half_life } => {
61 let hl = (half_life.num_seconds().max(1)) as f32;
62 (-std::f32::consts::LN_2 * age_secs / hl).exp()
63 }
64 DecayFn::Reciprocal { scale } => {
65 let s = (scale.num_seconds().max(1)) as f32;
66 1.0 / (1.0 + age_secs / s)
67 }
68 DecayFn::Step { thresholds } => {
69 for (boundary, value) in thresholds {
70 if age <= *boundary {
71 return *value;
72 }
73 }
74 thresholds.last().map(|(_, v)| *v).unwrap_or(0.0)
75 }
76 }
77 }
78}
79
/// Coefficients for the `Blended` ranking strategy.
///
/// The blended score is
/// `cosine * similarity + confidence * (confidence / 100) + recency * decay`,
/// plus `category_bonus` when the memory's category is one of
/// `preferred_categories`.
#[derive(Debug, Clone, PartialEq)]
pub struct BlendWeights {
    /// Multiplier applied to the raw cosine similarity.
    pub cosine: f32,
    /// Multiplier applied to the memory's confidence, scaled by `1 / 100`.
    pub confidence: f32,
    /// Multiplier applied to the recency decay value.
    pub recency: f32,
    /// Flat amount added when the memory's category is preferred.
    pub category_bonus: f32,
    /// Categories that earn `category_bonus`; empty means no memory does.
    pub preferred_categories: Vec<String>,
}

impl BlendWeights {
    /// Category bonus shared by every built-in preset.
    const PRESET_CATEGORY_BONUS: f32 = 0.05;

    /// Builds a preset from its three main coefficients, with the shared
    /// category bonus and no preferred categories.
    fn preset(cosine: f32, confidence: f32, recency: f32) -> Self {
        Self {
            cosine,
            confidence,
            recency,
            category_bonus: Self::PRESET_CATEGORY_BONUS,
            preferred_categories: Vec::new(),
        }
    }

    /// Preset dominated by similarity: cosine `0.7`, confidence `0.15`,
    /// recency `0.15`.
    #[must_use]
    pub fn relevance_first() -> Self {
        Self::preset(0.7, 0.15, 0.15)
    }

    /// Preset that leans on confidence: cosine `0.4`, confidence `0.45`,
    /// recency `0.15`.
    #[must_use]
    pub fn trust_first() -> Self {
        Self::preset(0.4, 0.45, 0.15)
    }

    /// Even-handed preset: cosine `0.4`, confidence `0.3`, recency `0.3`.
    #[must_use]
    pub fn balanced() -> Self {
        Self::preset(0.4, 0.3, 0.3)
    }

    /// Replaces the preferred-category list with `categories` and returns
    /// the updated weights. Any previously set list is discarded.
    #[must_use]
    pub fn prefer_categories(self, categories: impl IntoIterator<Item = String>) -> Self {
        Self {
            preferred_categories: categories.into_iter().collect(),
            ..self
        }
    }
}
164
165#[non_exhaustive]
180#[derive(Debug, Clone, PartialEq)]
181pub enum RankingStrategy {
182 Hybrid {
189 alpha: f32,
191 decay: DecayFn,
193 },
194
195 Blended {
204 weights: BlendWeights,
206 decay: DecayFn,
208 },
209}
210
211impl RankingStrategy {
212 pub fn default_hybrid() -> Self {
215 Self::Hybrid {
216 alpha: DEFAULT_HYBRID_ALPHA,
217 decay: DecayFn::Exponential {
218 half_life: chrono::Duration::days(DEFAULT_HYBRID_HALF_LIFE_DAYS as i64),
219 },
220 }
221 }
222
223 #[must_use]
230 pub fn blended(weights: BlendWeights) -> Self {
231 Self::Blended {
232 weights,
233 decay: DecayFn::Exponential {
234 half_life: chrono::Duration::days(DEFAULT_HYBRID_HALF_LIFE_DAYS as i64),
235 },
236 }
237 }
238}
239
240#[derive(Debug, Clone)]
261pub struct MemoryContext {
262 memories: Vec<Memory>,
263 system_prompt: Option<String>,
264 strategy: RankingStrategy,
265 graph: crate::graph::GraphContext,
266}
267
268impl MemoryContext {
269 pub(super) fn new(
270 memories: Vec<Memory>,
271 system_prompt: Option<String>,
272 strategy: RankingStrategy,
273 ) -> Self {
274 Self {
275 memories,
276 system_prompt,
277 strategy,
278 graph: crate::graph::GraphContext::default(),
279 }
280 }
281
282 #[cfg(feature = "knowledge-graph")]
284 #[must_use]
285 pub(super) fn with_graph_context(mut self, graph: crate::graph::GraphContext) -> Self {
286 self.graph = graph;
287 self
288 }
289
290 #[must_use]
292 pub fn memories(&self) -> &[Memory] {
293 &self.memories
294 }
295
296 #[must_use]
298 pub fn strategy_used(&self) -> &RankingStrategy {
299 &self.strategy
300 }
301
302 #[must_use]
304 pub fn system_prompt(&self) -> Option<&str> {
305 self.system_prompt.as_deref()
306 }
307
308 #[must_use]
315 pub fn graph(&self) -> &crate::graph::GraphContext {
316 &self.graph
317 }
318}
319
320impl Deref for MemoryContext {
321 type Target = [Memory];
322
323 fn deref(&self) -> &[Memory] {
324 &self.memories
325 }
326}
327
328impl std::fmt::Display for MemoryContext {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 if let Some(prompt) = &self.system_prompt {
331 writeln!(f, "{prompt}")?;
332 }
333 let now = Utc::now().with_timezone(&chrono::FixedOffset::east_opt(0).unwrap());
334 for memory in &self.memories {
335 let anchor = memory.event_at.unwrap_or(memory.created_at);
336 let date = anchor.format("%Y-%m-%d");
337 let relative = relative_label(now - anchor);
338 writeln!(f, "- [{date}, {relative}] {}", memory.content)?;
339 }
340 Ok(())
341 }
342}
343
344fn relative_label(delta: chrono::Duration) -> String {
345 let secs = delta.num_seconds();
346 if secs < 0 {
347 return "in the future".to_string();
348 }
349 if secs < 60 {
350 return "just now".to_string();
351 }
352 let mins = delta.num_minutes();
353 if mins < 60 {
354 return format!("{mins} minute{} ago", if mins == 1 { "" } else { "s" });
355 }
356 let hours = delta.num_hours();
357 if hours < 24 {
358 return format!("{hours} hour{} ago", if hours == 1 { "" } else { "s" });
359 }
360 let days = delta.num_days();
361 if days < 30 {
362 return format!("{days} day{} ago", if days == 1 { "" } else { "s" });
363 }
364 let months = days / 30;
365 if months < 12 {
366 return format!("{months} month{} ago", if months == 1 { "" } else { "s" });
367 }
368 let years = days / 365;
369 format!("{years} year{} ago", if years == 1 { "" } else { "s" })
370}
371
372#[must_use = "query(..) returns a builder that must be awaited"]
415pub struct QueryBuilder<'a> {
416 client: &'a Client,
417 query: String,
418 scope: Scope,
419 limit: usize,
420 episodic: bool,
421 semantic: bool,
422 metadata_filter: Option<MemoryFilter>,
423 min_similarity: Option<f32>,
424 created_at_range: NumericRange,
425 event_at_range: NumericRange,
426 ranking: Option<RankingStrategy>,
427 #[cfg(feature = "knowledge-graph")]
428 graph_depth: Option<usize>,
429}
430
431impl<'a> QueryBuilder<'a> {
432 pub(super) fn new(client: &'a Client, query: String, scope: Scope) -> Self {
433 Self {
434 client,
435 query,
436 scope,
437 limit: DEFAULT_QUERY_LIMIT,
438 episodic: false,
439 semantic: false,
440 metadata_filter: None,
441 min_similarity: None,
442 created_at_range: NumericRange::default(),
443 event_at_range: NumericRange::default(),
444 ranking: None,
445 #[cfg(feature = "knowledge-graph")]
446 graph_depth: None,
447 }
448 }
449
450 pub fn limit(mut self, limit: usize) -> Self {
452 self.limit = limit;
453 self
454 }
455
456 pub fn episodic(mut self) -> Self {
458 self.episodic = true;
459 self
460 }
461
462 pub fn semantic(mut self) -> Self {
464 self.semantic = true;
465 self
466 }
467
468 pub fn metadata_filter(mut self, filter: MemoryFilter) -> Self {
470 self.metadata_filter = Some(filter);
471 self
472 }
473
474 pub fn min_similarity(mut self, threshold: f32) -> Self {
476 self.min_similarity = Some(threshold);
477 self
478 }
479
480 pub fn created_after(mut self, at: impl Into<DateTime<FixedOffset>>) -> Self {
482 self.created_at_range.gte = Some(at.into().timestamp_millis() as f64);
483 self
484 }
485
486 pub fn created_before(mut self, at: impl Into<DateTime<FixedOffset>>) -> Self {
488 self.created_at_range.lt = Some(at.into().timestamp_millis() as f64);
489 self
490 }
491
492 pub fn event_at_after(mut self, at: impl Into<DateTime<FixedOffset>>) -> Self {
494 self.event_at_range.gte = Some(at.into().timestamp_millis() as f64);
495 self
496 }
497
498 pub fn event_at_before(mut self, at: impl Into<DateTime<FixedOffset>>) -> Self {
500 self.event_at_range.lt = Some(at.into().timestamp_millis() as f64);
501 self
502 }
503
504 pub fn ranking(mut self, strategy: RankingStrategy) -> Self {
506 self.ranking = Some(strategy);
507 self
508 }
509
510 #[cfg(feature = "knowledge-graph")]
518 pub fn with_graph(mut self) -> Self {
519 self.graph_depth = Some(crate::graph::DEFAULT_ENRICHMENT_DEPTH);
520 self
521 }
522
523 #[cfg(feature = "knowledge-graph")]
529 pub fn with_graph_depth(mut self, depth: usize) -> Self {
530 self.graph_depth = Some(depth.clamp(1, crate::graph::MAX_ENRICHMENT_DEPTH));
531 self
532 }
533}
534
535fn kind_selector(episodic: bool, semantic: bool) -> KindSelector {
536 match (episodic, semantic) {
537 (false, false) => KindSelector::default(),
538 (episodic, semantic) => KindSelector { episodic, semantic },
539 }
540}
541
542fn combine_filter(
543 metadata_filter: Option<MemoryFilter>,
544 created_at: NumericRange,
545 event_at: NumericRange,
546) -> Option<MemoryFilter> {
547 if metadata_filter.is_none() && created_at.is_unbounded() && event_at.is_unbounded() {
548 return None;
549 }
550 let mut combined = metadata_filter.unwrap_or_default();
551 if !created_at.is_unbounded() {
552 combined.must.push(FilterCondition::Range {
553 field: "created_at".to_string(),
554 range: created_at,
555 });
556 }
557 if !event_at.is_unbounded() {
558 combined.must.push(FilterCondition::Range {
559 field: "event_at".to_string(),
560 range: event_at,
561 });
562 }
563 Some(combined)
564}
565
566fn rank_score(strategy: &RankingStrategy, cosine: f32, memory: &Memory, now: DateTime<FixedOffset>) -> f32 {
567 match strategy {
568 RankingStrategy::Hybrid { alpha, decay } => {
569 let anchor = memory.event_at.unwrap_or(memory.created_at);
570 let age = now - anchor;
571 let recency = decay.evaluate(age);
572 alpha * cosine + (1.0 - alpha) * recency
573 }
574 RankingStrategy::Blended { weights, decay } => {
575 let anchor = memory.event_at.unwrap_or(memory.created_at);
576 let recency = decay.evaluate(now - anchor);
577 let confidence = f32::from(memory.confidence.get()) / 100.0;
581 let category_bonus = match &memory.category {
582 Some(category) if weights.preferred_categories.iter().any(|c| c == category) => {
583 weights.category_bonus
584 }
585 _ => 0.0,
586 };
587 weights.cosine * cosine
588 + weights.confidence * confidence
589 + weights.recency * recency
590 + category_bonus
591 }
592 }
593}
594
595impl<'a> IntoFuture for QueryBuilder<'a> {
596 type Output = Result<MemoryContext, ClientError>;
597 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
598
599 fn into_future(self) -> Self::IntoFuture {
600 Box::pin(execute(self))
601 }
602}
603
604async fn execute(builder: QueryBuilder<'_>) -> Result<MemoryContext, ClientError> {
605 let kinds = kind_selector(builder.episodic, builder.semantic);
606 let strategy = builder.ranking.unwrap_or_else(RankingStrategy::default_hybrid);
607 #[cfg(feature = "knowledge-graph")]
608 let graph_depth = builder.graph_depth;
609 let QueryBuilder {
610 client,
611 query,
612 scope,
613 limit,
614 metadata_filter,
615 min_similarity,
616 created_at_range,
617 event_at_range,
618 ..
619 } = builder;
620
621 #[cfg(feature = "knowledge-graph")]
623 let graph_scope = scope.clone();
624
625 let combined_filter = combine_filter(metadata_filter, created_at_range, event_at_range);
626 let candidate_limit = limit.saturating_mul(3).max(limit);
627 let inner = client.inner.clone();
628
629 let query_vector = inner.embedder.embed(&query).await?;
630 let hits = inner
631 .index
632 .search(scope, query_vector, candidate_limit, kinds, combined_filter, min_similarity)
633 .await?;
634
635 let pids: Vec<&str> = hits.iter().map(|(pid, _)| pid.as_str()).collect();
636 let mut rows = inner.store.find_by_pids(&pids).await?;
637
638 let cosine: std::collections::HashMap<&str, f32> = hits
639 .iter()
640 .map(|(pid, score)| (pid.as_str(), *score))
641 .collect();
642
643 let now: DateTime<FixedOffset> = Utc::now().into();
644 let mut scored: Vec<(f32, Memory)> = rows
645 .drain(..)
646 .filter_map(|m| {
647 let raw = *cosine.get(m.pid.as_str())?;
648 let score = rank_score(&strategy, raw, &m, now);
649 Some((score, m))
650 })
651 .collect();
652
653 scored.sort_by(|(a, _), (b, _)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
654 scored.truncate(limit);
655
656 let memories: Vec<Memory> = scored
657 .into_iter()
658 .map(|(score, mut m)| {
659 m.score = Some(score);
660 m
661 })
662 .collect();
663
664 let context = MemoryContext::new(memories, inner.system_prompt.clone(), strategy);
665
666 #[cfg(feature = "knowledge-graph")]
669 if let Some(depth) = graph_depth {
670 if let Some(graph) = inner.graph.as_deref() {
671 use crate::graph::GraphStore;
672 let seed_pids: Vec<&str> = context.memories().iter().map(|m| m.pid.as_str()).collect();
673 let graph_context = graph.neighbors(&seed_pids, &graph_scope, depth).await?;
674 return Ok(context.with_graph_context(graph_context));
675 }
676 }
677
678 Ok(context)
679}
680
#[cfg(test)]
mod tests {
    use super::*;

    /// Cosine value shared by every ranking test so only the varied input
    /// moves the score.
    const FIXED_COSINE: f32 = 0.8;

    #[test]
    fn should_default_hybrid_use_documented_alpha_and_decay() {
        match RankingStrategy::default_hybrid() {
            RankingStrategy::Hybrid { alpha, decay } => {
                assert!((alpha - DEFAULT_HYBRID_ALPHA).abs() < f32::EPSILON);
                let expected = DecayFn::Exponential {
                    half_life: chrono::Duration::days(DEFAULT_HYBRID_HALF_LIFE_DAYS as i64),
                };
                assert_eq!(decay, expected);
            }
            other => panic!("default_hybrid must return the Hybrid variant; got {other:?}"),
        }
    }

    #[test]
    fn should_exponential_decay_be_half_at_half_life() {
        let half_life = chrono::Duration::days(7);
        let v = DecayFn::Exponential { half_life }.evaluate(half_life);
        assert!((v - 0.5).abs() < 1e-3, "exp decay at half-life should be ~0.5, got {v}");
    }

    #[test]
    fn should_reciprocal_decay_be_half_at_scale() {
        let scale = chrono::Duration::days(7);
        let v = DecayFn::Reciprocal { scale }.evaluate(scale);
        assert!((v - 0.5).abs() < 1e-3, "reciprocal decay at scale should be 0.5, got {v}");
    }

    #[test]
    fn should_step_decay_apply_first_matching_bucket() {
        let decay = DecayFn::Step {
            thresholds: vec![
                (chrono::Duration::hours(1), 1.0),
                (chrono::Duration::days(1), 0.5),
                (chrono::Duration::days(7), 0.1),
            ],
        };
        // (age, expected bucket value); the last case lies past every boundary.
        let cases = [
            (chrono::Duration::minutes(30), 1.0),
            (chrono::Duration::hours(12), 0.5),
            (chrono::Duration::days(3), 0.1),
            (chrono::Duration::days(30), 0.1),
        ];
        for &(age, expected) in &cases {
            assert_eq!(decay.evaluate(age), expected);
        }
    }

    #[test]
    fn should_relative_label_render_minutes_and_days() {
        let cases = [
            (chrono::Duration::seconds(30), "just now"),
            (chrono::Duration::minutes(5), "5 minutes ago"),
            (chrono::Duration::minutes(1), "1 minute ago"),
            (chrono::Duration::hours(3), "3 hours ago"),
            (chrono::Duration::days(2), "2 days ago"),
        ];
        for &(delta, expected) in &cases {
            assert_eq!(relative_label(delta), expected);
        }
    }

    /// Minimal semantic memory created at `now`, with the given confidence
    /// and category and every optional field left empty.
    fn scored_fixture(now: DateTime<FixedOffset>, confidence: i8, category: Option<&str>) -> Memory {
        let scope = Scope {
            agent_id: "a".into(),
            org_id: "o".into(),
            user_id: "u".into(),
        };
        Memory {
            pid: "p".into(),
            scope,
            content: "c".into(),
            metadata: serde_json::json!({}),
            kind: crate::memory::MemoryKind::Semantic,
            source_pid: None,
            supersession: None,
            created_at: now,
            updated_at: now,
            event_at: None,
            score: None,
            status: crate::store::IndexStatus::Indexed,
            confidence: crate::memory::Confidence::new(confidence),
            category: category.map(str::to_string),
            retirement: None,
        }
    }

    /// Blended strategy with the balanced preset and no preferred categories.
    fn balanced_blend() -> RankingStrategy {
        RankingStrategy::blended(BlendWeights::balanced())
    }

    #[test]
    fn should_rank_high_confidence_above_low_at_equal_cosine() {
        let now: DateTime<FixedOffset> = Utc::now().into();
        let strategy = balanced_blend();
        let score_for =
            |confidence: i8| rank_score(&strategy, FIXED_COSINE, &scored_fixture(now, confidence, None), now);
        let high = score_for(95);
        let low = score_for(10);
        assert!(high > low, "high confidence ({high}) must outrank low ({low}) at equal cosine");
    }

    #[test]
    fn should_keep_recency_moving_ranking_at_equal_cosine_and_confidence() {
        let now: DateTime<FixedOffset> = Utc::now().into();
        let strategy = balanced_blend();

        let recent = scored_fixture(now, 80, None);
        let old = {
            let mut aged = scored_fixture(now, 80, None);
            aged.created_at = now - chrono::Duration::days(60);
            aged
        };

        let recent_score = rank_score(&strategy, FIXED_COSINE, &recent, now);
        let old_score = rank_score(&strategy, FIXED_COSINE, &old, now);
        assert!(
            recent_score > old_score,
            "recent ({recent_score}) must outrank old ({old_score}) at equal cosine+confidence"
        );
    }

    #[test]
    fn should_apply_category_bonus_only_to_preferred_categories() {
        let now: DateTime<FixedOffset> = Utc::now().into();
        let weights = BlendWeights::balanced().prefer_categories(["preference".to_string()]);
        let strategy = RankingStrategy::blended(weights);
        let score_for = |category: Option<&str>| {
            rank_score(&strategy, FIXED_COSINE, &scored_fixture(now, 80, category), now)
        };

        let preferred = score_for(Some("preference"));
        let other = score_for(Some("transient"));
        let uncategorized = score_for(None);

        assert!(preferred > other, "preferred category must earn the bonus");
        assert!(
            (other - uncategorized).abs() < f32::EPSILON,
            "non-preferred and uncategorized rows must score identically (no bonus)"
        );
    }

    #[test]
    fn should_blend_be_inert_on_category_when_no_preference_set() {
        let now: DateTime<FixedOffset> = Utc::now().into();
        let strategy = balanced_blend();
        let score_for = |category: Option<&str>| {
            rank_score(&strategy, FIXED_COSINE, &scored_fixture(now, 80, category), now)
        };
        let with_cat = score_for(Some("preference"));
        let without = score_for(None);
        assert!((with_cat - without).abs() < f32::EPSILON);
    }

    #[test]
    fn should_preset_weights_differ_in_confidence_emphasis() {
        let trust = BlendWeights::trust_first();
        let relevance = BlendWeights::relevance_first();
        assert!(trust.confidence > relevance.confidence);
        assert!(relevance.cosine > trust.cosine);
    }
}