use std::future::{Future, IntoFuture};
use std::ops::Deref;
use std::pin::Pin;
use chrono::{DateTime, FixedOffset, Utc};
use crate::embedding::EmbeddingModel;
use crate::memory::{KindSelector, Memory, Scope};
use crate::store::MemoryStore;
use crate::vector::{FilterCondition, MemoryFilter, NumericRange, VectorIndex};
use super::{Client, ClientError};
/// Result limit a freshly constructed `QueryBuilder` starts with (see `QueryBuilder::new`).
pub const DEFAULT_QUERY_LIMIT: usize = 10;
/// `alpha` used by `RankingStrategy::default_hybrid`. In hybrid ranking the score is
/// `alpha * cosine + (1 - alpha) * recency`, so this is the weight given to cosine similarity.
pub const DEFAULT_HYBRID_ALPHA: f32 = 0.7;
/// Half-life, in days, of the exponential decay built by `RankingStrategy::default_hybrid`
/// and `RankingStrategy::blended`. Both use sites cast it to whole days with `as i64`.
pub const DEFAULT_HYBRID_HALF_LIFE_DAYS: f32 = 7.0;
/// Curve that turns the age of a memory into a recency score.
///
/// The score is computed by `DecayFn::evaluate`; the per-variant notes below describe
/// what that method does for each shape.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum DecayFn {
/// `exp(-ln 2 * age / half_life)`: 1.0 at age zero and 0.5 once `age == half_life`.
Exponential {
/// Age at which the score has halved. Evaluated in whole seconds, floored at one second.
half_life: chrono::Duration,
},
/// `1 / (1 + age / scale)`: 1.0 at age zero and 0.5 once `age == scale`.
Reciprocal {
/// Age at which the score reaches 0.5. Evaluated in whole seconds, floored at one second.
scale: chrono::Duration,
},
/// Piecewise-constant score looked up from a list of `(boundary, value)` buckets.
Step {
/// Buckets are scanned in the order given; the first one whose boundary is `>= age`
/// supplies the score. When no boundary matches, the last bucket's value is used,
/// or 0.0 if the list is empty. The scan does not sort the list.
thresholds: Vec<(chrono::Duration, f32)>,
},
}
impl DecayFn {
    /// Scores `age` on this curve.
    ///
    /// The continuous curves work in whole seconds: negative ages are treated as zero
    /// and the curve parameter is floored at one second so the division is always
    /// defined. The step curve compares `age` against each boundary directly.
    fn evaluate(&self, age: chrono::Duration) -> f32 {
        let elapsed = age.num_seconds().max(0) as f32;
        // Curve parameter in seconds, never below one.
        let positive_secs = |span: &chrono::Duration| span.num_seconds().max(1) as f32;

        match self {
            Self::Exponential { half_life } => {
                let exponent = -std::f32::consts::LN_2 * elapsed / positive_secs(half_life);
                exponent.exp()
            }
            Self::Reciprocal { scale } => 1.0 / (1.0 + elapsed / positive_secs(scale)),
            Self::Step { thresholds } => thresholds
                .iter()
                // First bucket whose boundary covers the age wins ...
                .find(|bucket| age <= bucket.0)
                // ... otherwise fall back to the final bucket, if there is one.
                .or_else(|| thresholds.last())
                .map_or(0.0, |bucket| bucket.1),
        }
    }
}
/// Weights for the blended ranking score.
///
/// The blended score is `cosine * similarity + confidence * confidence_fraction +
/// recency * recency_score`, plus `category_bonus` when the memory's category is
/// listed in `preferred_categories`.
#[derive(Debug, Clone, PartialEq)]
pub struct BlendWeights {
    /// Multiplier applied to the cosine similarity.
    pub cosine: f32,
    /// Multiplier applied to the memory's confidence.
    pub confidence: f32,
    /// Multiplier applied to the recency score.
    pub recency: f32,
    /// Flat amount added when the memory's category is preferred.
    pub category_bonus: f32,
    /// Categories that earn `category_bonus`; empty means no memory earns it.
    pub preferred_categories: Vec<String>,
}

impl BlendWeights {
    /// Category bonus shared by every preset.
    const PRESET_CATEGORY_BONUS: f32 = 0.05;

    /// Builds a preset with the shared bonus and no preferred categories.
    fn preset(cosine: f32, confidence: f32, recency: f32) -> Self {
        Self {
            cosine,
            confidence,
            recency,
            category_bonus: Self::PRESET_CATEGORY_BONUS,
            preferred_categories: Vec::new(),
        }
    }

    /// Preset dominated by cosine similarity (0.7 / 0.15 / 0.15).
    #[must_use]
    pub fn relevance_first() -> Self {
        Self::preset(0.7, 0.15, 0.15)
    }

    /// Preset that weights confidence highest (0.4 / 0.45 / 0.15).
    #[must_use]
    pub fn trust_first() -> Self {
        Self::preset(0.4, 0.45, 0.15)
    }

    /// Preset with confidence and recency weighted equally (0.4 / 0.3 / 0.3).
    #[must_use]
    pub fn balanced() -> Self {
        Self::preset(0.4, 0.3, 0.3)
    }

    /// Replaces the preferred-category list with `categories`, keeping every weight.
    #[must_use]
    pub fn prefer_categories(self, categories: impl IntoIterator<Item = String>) -> Self {
        Self {
            preferred_categories: categories.into_iter().collect(),
            ..self
        }
    }
}
/// How candidate memories are scored before the final sort (see `rank_score`).
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum RankingStrategy {
/// Two-term mix: `alpha * cosine + (1 - alpha) * recency`.
Hybrid {
/// Weight on cosine similarity; the remainder `1 - alpha` goes to recency.
alpha: f32,
/// Curve that converts a memory's age into the recency term.
decay: DecayFn,
},
/// Weighted sum of cosine similarity, confidence and recency, plus an optional
/// flat bonus for preferred categories.
Blended {
/// Per-term weights and the preferred-category list.
weights: BlendWeights,
/// Curve that converts a memory's age into the recency term.
decay: DecayFn,
},
}
impl RankingStrategy {
pub fn default_hybrid() -> Self {
Self::Hybrid {
alpha: DEFAULT_HYBRID_ALPHA,
decay: DecayFn::Exponential {
half_life: chrono::Duration::days(DEFAULT_HYBRID_HALF_LIFE_DAYS as i64),
},
}
}
#[must_use]
pub fn blended(weights: BlendWeights) -> Self {
Self::Blended {
weights,
decay: DecayFn::Exponential {
half_life: chrono::Duration::days(DEFAULT_HYBRID_HALF_LIFE_DAYS as i64),
},
}
}
}
/// Result of a query: the ranked memories plus the context needed to present them.
///
/// Dereferences to `[Memory]`, and its `Display` impl renders the system prompt
/// followed by one dated bullet line per memory.
#[derive(Debug, Clone)]
pub struct MemoryContext {
// Memories in the order they were handed to `MemoryContext::new`.
memories: Vec<Memory>,
// Optional prompt text printed ahead of the memories by `Display`.
system_prompt: Option<String>,
// Strategy that was used to rank `memories`.
strategy: RankingStrategy,
// Graph context; stays at its default unless replaced via `with_graph_context`.
graph: crate::graph::GraphContext,
}
impl MemoryContext {
    /// Wraps ranked `memories` together with the prompt and the strategy that ranked
    /// them. The graph context starts out as its default value.
    pub(super) fn new(
        memories: Vec<Memory>,
        system_prompt: Option<String>,
        strategy: RankingStrategy,
    ) -> Self {
        let graph = crate::graph::GraphContext::default();
        Self {
            memories,
            system_prompt,
            strategy,
            graph,
        }
    }

    /// Returns this context with its graph context replaced by `graph`.
    #[cfg(feature = "knowledge-graph")]
    #[must_use]
    pub(super) fn with_graph_context(self, graph: crate::graph::GraphContext) -> Self {
        Self { graph, ..self }
    }

    /// The memories held by this context, in stored order.
    #[must_use]
    pub fn memories(&self) -> &[Memory] {
        self.memories.as_slice()
    }

    /// The ranking strategy recorded when this context was built.
    #[must_use]
    pub fn strategy_used(&self) -> &RankingStrategy {
        &self.strategy
    }

    /// The system prompt, if one was supplied.
    #[must_use]
    pub fn system_prompt(&self) -> Option<&str> {
        match &self.system_prompt {
            Some(prompt) => Some(prompt.as_str()),
            None => None,
        }
    }

    /// The graph context attached to this result.
    #[must_use]
    pub fn graph(&self) -> &crate::graph::GraphContext {
        &self.graph
    }
}
/// Lets a `MemoryContext` be used wherever a slice of memories is expected
/// (iteration, `len()`, indexing).
impl Deref for MemoryContext {
    type Target = [Memory];

    fn deref(&self) -> &Self::Target {
        self.memories.as_slice()
    }
}
impl std::fmt::Display for MemoryContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(prompt) = &self.system_prompt {
writeln!(f, "{prompt}")?;
}
let now = Utc::now().with_timezone(&chrono::FixedOffset::east_opt(0).unwrap());
for memory in &self.memories {
let anchor = memory.event_at.unwrap_or(memory.created_at);
let date = anchor.format("%Y-%m-%d");
let relative = relative_label(now - anchor);
writeln!(f, "- [{date}, {relative}] {}", memory.content)?;
}
Ok(())
}
}
/// Renders a time delta as a short English phrase such as `"5 minutes ago"`.
///
/// Buckets, checked in order:
/// - negative delta: `"in the future"`
/// - under 60 seconds: `"just now"`
/// - under 60 minutes / 24 hours / 30 days: whole minutes / hours / days
/// - under 365 days: whole 30-day months (so 360–364 days reads "12 months ago")
/// - otherwise: whole 365-day years
///
/// The month bucket is bounded by days rather than by month count so that spans
/// of 360–364 days are never reported as "0 years ago".
fn relative_label(delta: chrono::Duration) -> String {
    /// Formats `count` with `unit`, pluralised unless the count is exactly one.
    fn ago(count: i64, unit: &str) -> String {
        let plural = if count == 1 { "" } else { "s" };
        format!("{count} {unit}{plural} ago")
    }

    let secs = delta.num_seconds();
    if secs < 0 {
        return "in the future".to_string();
    }
    if secs < 60 {
        return "just now".to_string();
    }
    let mins = delta.num_minutes();
    if mins < 60 {
        return ago(mins, "minute");
    }
    let hours = delta.num_hours();
    if hours < 24 {
        return ago(hours, "hour");
    }
    let days = delta.num_days();
    if days < 30 {
        return ago(days, "day");
    }
    if days < 365 {
        return ago(days / 30, "month");
    }
    ago(days / 365, "year")
}
/// Builder for a memory query. Configure it with the chained setters, then `.await`
/// it (via `IntoFuture`) to run the query and obtain a `MemoryContext`.
#[must_use = "query(..) returns a builder that must be awaited"]
pub struct QueryBuilder<'a> {
// Client whose embedder, index and store execute the query.
client: &'a Client,
// Query text that gets embedded for the vector search.
query: String,
// Scope passed to the vector search (and to graph enrichment when enabled).
scope: Scope,
// Maximum number of memories returned; starts at `DEFAULT_QUERY_LIMIT`.
limit: usize,
// Kind flags; when both are false the default `KindSelector` is used.
episodic: bool,
semantic: bool,
// Caller-supplied filter; the time ranges below are appended to it as `must` conditions.
metadata_filter: Option<MemoryFilter>,
// Optional similarity threshold forwarded to the vector search.
min_similarity: Option<f32>,
// Bounds on `created_at`, stored as epoch milliseconds.
created_at_range: NumericRange,
// Bounds on `event_at`, stored as epoch milliseconds.
event_at_range: NumericRange,
// Ranking override; `None` means `RankingStrategy::default_hybrid()`.
ranking: Option<RankingStrategy>,
// Graph enrichment depth; `None` means no enrichment is requested.
#[cfg(feature = "knowledge-graph")]
graph_depth: Option<usize>,
}
impl<'a> QueryBuilder<'a> {
    /// Starts a query for `query` within `scope`: default limit, no kind restriction,
    /// no filters, no ranking override and no graph enrichment.
    pub(super) fn new(client: &'a Client, query: String, scope: Scope) -> Self {
        Self {
            client,
            query,
            scope,
            limit: DEFAULT_QUERY_LIMIT,
            episodic: false,
            semantic: false,
            metadata_filter: None,
            min_similarity: None,
            created_at_range: NumericRange::default(),
            event_at_range: NumericRange::default(),
            ranking: None,
            #[cfg(feature = "knowledge-graph")]
            graph_depth: None,
        }
    }

    /// Converts a timestamp into the epoch-millisecond value stored in range bounds.
    fn epoch_millis(at: impl Into<DateTime<FixedOffset>>) -> f64 {
        let instant: DateTime<FixedOffset> = at.into();
        instant.timestamp_millis() as f64
    }

    /// Sets the maximum number of memories returned.
    pub fn limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }

    /// Includes episodic memories in the kind selection.
    pub fn episodic(self) -> Self {
        Self {
            episodic: true,
            ..self
        }
    }

    /// Includes semantic memories in the kind selection.
    pub fn semantic(self) -> Self {
        Self {
            semantic: true,
            ..self
        }
    }

    /// Sets the metadata filter, replacing any previously set one.
    pub fn metadata_filter(self, filter: MemoryFilter) -> Self {
        Self {
            metadata_filter: Some(filter),
            ..self
        }
    }

    /// Sets the similarity threshold forwarded to the vector search.
    pub fn min_similarity(self, threshold: f32) -> Self {
        Self {
            min_similarity: Some(threshold),
            ..self
        }
    }

    /// Sets the `gte` bound of the `created_at` range to `at`.
    pub fn created_after(mut self, at: impl Into<DateTime<FixedOffset>>) -> Self {
        let bound = Self::epoch_millis(at);
        self.created_at_range.gte = Some(bound);
        self
    }

    /// Sets the `lt` bound of the `created_at` range to `at`.
    pub fn created_before(mut self, at: impl Into<DateTime<FixedOffset>>) -> Self {
        let bound = Self::epoch_millis(at);
        self.created_at_range.lt = Some(bound);
        self
    }

    /// Sets the `gte` bound of the `event_at` range to `at`.
    pub fn event_at_after(mut self, at: impl Into<DateTime<FixedOffset>>) -> Self {
        let bound = Self::epoch_millis(at);
        self.event_at_range.gte = Some(bound);
        self
    }

    /// Sets the `lt` bound of the `event_at` range to `at`.
    pub fn event_at_before(mut self, at: impl Into<DateTime<FixedOffset>>) -> Self {
        let bound = Self::epoch_millis(at);
        self.event_at_range.lt = Some(bound);
        self
    }

    /// Overrides the ranking strategy used to order results.
    pub fn ranking(self, strategy: RankingStrategy) -> Self {
        Self {
            ranking: Some(strategy),
            ..self
        }
    }

    /// Requests graph enrichment at `DEFAULT_ENRICHMENT_DEPTH`.
    #[cfg(feature = "knowledge-graph")]
    pub fn with_graph(self) -> Self {
        Self {
            graph_depth: Some(crate::graph::DEFAULT_ENRICHMENT_DEPTH),
            ..self
        }
    }

    /// Requests graph enrichment at `depth`, clamped to `1..=MAX_ENRICHMENT_DEPTH`.
    #[cfg(feature = "knowledge-graph")]
    pub fn with_graph_depth(self, depth: usize) -> Self {
        let bounded = depth.clamp(1, crate::graph::MAX_ENRICHMENT_DEPTH);
        Self {
            graph_depth: Some(bounded),
            ..self
        }
    }
}
/// Translates the builder's kind flags into a `KindSelector`.
///
/// With neither flag set the selector's default is used; otherwise the selector
/// carries exactly the flags given.
fn kind_selector(episodic: bool, semantic: bool) -> KindSelector {
    let unrestricted = !episodic && !semantic;
    if unrestricted {
        KindSelector::default()
    } else {
        KindSelector { episodic, semantic }
    }
}
fn combine_filter(
metadata_filter: Option<MemoryFilter>,
created_at: NumericRange,
event_at: NumericRange,
) -> Option<MemoryFilter> {
if metadata_filter.is_none() && created_at.is_unbounded() && event_at.is_unbounded() {
return None;
}
let mut combined = metadata_filter.unwrap_or_default();
if !created_at.is_unbounded() {
combined.must.push(FilterCondition::Range {
field: "created_at".to_string(),
range: created_at,
});
}
if !event_at.is_unbounded() {
combined.must.push(FilterCondition::Range {
field: "event_at".to_string(),
range: event_at,
});
}
Some(combined)
}
/// Computes the final ranking score for one memory.
///
/// `cosine` is the raw similarity from the vector search and `now` is the reference
/// instant for ageing. A memory's age is measured from `event_at`, falling back to
/// `created_at`, and fed through the strategy's decay curve to get the recency term.
///
/// - `Hybrid`: `alpha * cosine + (1 - alpha) * recency`.
/// - `Blended`: weighted sum of cosine, confidence (scaled by 1/100) and recency,
///   plus the category bonus when the memory's category is in the preferred list.
fn rank_score(strategy: &RankingStrategy, cosine: f32, memory: &Memory, now: DateTime<FixedOffset>) -> f32 {
    // Both strategies age the memory the same way; only the decay curve differs.
    let age = now - memory.event_at.unwrap_or(memory.created_at);

    match strategy {
        RankingStrategy::Hybrid { alpha, decay } => {
            let recency = decay.evaluate(age);
            alpha * cosine + (1.0 - alpha) * recency
        }
        RankingStrategy::Blended { weights, decay } => {
            let recency = decay.evaluate(age);
            let confidence = f32::from(memory.confidence.get()) / 100.0;
            let is_preferred = memory
                .category
                .as_ref()
                .map_or(false, |category| weights.preferred_categories.contains(category));
            let category_bonus = if is_preferred { weights.category_bonus } else { 0.0 };

            weights.cosine * cosine
                + weights.confidence * confidence
                + weights.recency * recency
                + category_bonus
        }
    }
}
/// Makes the builder awaitable: `.await`ing a `QueryBuilder` runs the query.
impl<'a> IntoFuture for QueryBuilder<'a> {
    type Output = Result<MemoryContext, ClientError>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

    /// Consumes the builder and returns the boxed future that executes the query.
    fn into_future(self) -> Self::IntoFuture {
        let pending = execute(self);
        Box::pin(pending)
    }
}
/// Runs a configured query end to end.
///
/// Steps:
/// 1. Embed the query text.
/// 2. Search the vector index for up to three times `limit` candidates, applying the
///    scope, kind selection, combined filter and similarity threshold.
/// 3. Load the matching rows from the store and score each with the ranking strategy
///    (the builder's override, or `RankingStrategy::default_hybrid()`).
/// 4. Sort by score, highest first, keep the top `limit`, and record each score on
///    its memory.
/// 5. With the `knowledge-graph` feature, a requested depth and a configured graph
///    store, attach the neighbours of the returned memories.
///
/// # Errors
///
/// Propagates any error from the embedder, the vector index, the store lookup or
/// the graph neighbour query.
async fn execute(builder: QueryBuilder<'_>) -> Result<MemoryContext, ClientError> {
    /// Sort key for a score: NaN is mapped below every real score so the comparator
    /// stays a total order (`sort_by` may panic otherwise) and NaN rows sort last.
    fn sort_key(score: f32) -> f32 {
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score
        }
    }

    let kinds = kind_selector(builder.episodic, builder.semantic);
    let strategy = builder.ranking.unwrap_or_else(RankingStrategy::default_hybrid);
    #[cfg(feature = "knowledge-graph")]
    let graph_depth = builder.graph_depth;
    let QueryBuilder {
        client,
        query,
        scope,
        limit,
        metadata_filter,
        min_similarity,
        created_at_range,
        event_at_range,
        ..
    } = builder;
    #[cfg(feature = "knowledge-graph")]
    let graph_scope = scope.clone();
    let combined_filter = combine_filter(metadata_filter, created_at_range, event_at_range);
    // Fetch more candidates than requested so re-ranking has something to choose
    // from. The saturating multiply can never fall below `limit`.
    let candidate_limit = limit.saturating_mul(3);
    let inner = client.inner.clone();
    let query_vector = inner.embedder.embed(&query).await?;
    let hits = inner
        .index
        .search(scope, query_vector, candidate_limit, kinds, combined_filter, min_similarity)
        .await?;
    let pids: Vec<&str> = hits.iter().map(|(pid, _)| pid.as_str()).collect();
    let rows = inner.store.find_by_pids(&pids).await?;
    // Raw similarity per pid, used to score the rows that came back from the store.
    let cosine: std::collections::HashMap<&str, f32> = hits
        .iter()
        .map(|(pid, score)| (pid.as_str(), *score))
        .collect();
    let now: DateTime<FixedOffset> = Utc::now().into();
    let mut scored: Vec<(f32, Memory)> = rows
        .into_iter()
        .filter_map(|m| {
            // Rows whose pid was not among the hits are dropped.
            let raw = *cosine.get(m.pid.as_str())?;
            let score = rank_score(&strategy, raw, &m, now);
            Some((score, m))
        })
        .collect();
    // Highest score first; the stable sort keeps equal scores in row order.
    scored.sort_by(|(a, _), (b, _)| sort_key(*b).total_cmp(&sort_key(*a)));
    scored.truncate(limit);
    let memories: Vec<Memory> = scored
        .into_iter()
        .map(|(score, mut m)| {
            m.score = Some(score);
            m
        })
        .collect();
    let context = MemoryContext::new(memories, inner.system_prompt.clone(), strategy);
    #[cfg(feature = "knowledge-graph")]
    if let Some(depth) = graph_depth {
        if let Some(graph) = inner.graph.as_deref() {
            use crate::graph::GraphStore;
            let seed_pids: Vec<&str> = context.memories().iter().map(|m| m.pid.as_str()).collect();
            let graph_context = graph.neighbors(&seed_pids, &graph_scope, depth).await?;
            return Ok(context.with_graph_context(graph_context));
        }
    }
    Ok(context)
}
// Unit tests for the ranking defaults, the decay curves, relative-age labels and
// the blended scoring function.
#[cfg(test)]
mod tests {
use super::*;
// The default strategy must be Hybrid with the documented alpha and half-life.
#[test]
fn should_default_hybrid_use_documented_alpha_and_decay() {
let strategy = RankingStrategy::default_hybrid();
let RankingStrategy::Hybrid { alpha, decay } = strategy else {
panic!("default_hybrid must return the Hybrid variant; got {strategy:?}");
};
assert!((alpha - DEFAULT_HYBRID_ALPHA).abs() < f32::EPSILON);
assert_eq!(
decay,
DecayFn::Exponential {
half_life: chrono::Duration::days(DEFAULT_HYBRID_HALF_LIFE_DAYS as i64)
}
);
}
// Exponential decay evaluates to ~0.5 when age equals the half-life.
#[test]
fn should_exponential_decay_be_half_at_half_life() {
let decay = DecayFn::Exponential {
half_life: chrono::Duration::days(7),
};
let v = decay.evaluate(chrono::Duration::days(7));
assert!((v - 0.5).abs() < 1e-3, "exp decay at half-life should be ~0.5, got {v}");
}
// Reciprocal decay evaluates to 0.5 when age equals the scale.
#[test]
fn should_reciprocal_decay_be_half_at_scale() {
let decay = DecayFn::Reciprocal {
scale: chrono::Duration::days(7),
};
let v = decay.evaluate(chrono::Duration::days(7));
assert!((v - 0.5).abs() < 1e-3, "reciprocal decay at scale should be 0.5, got {v}");
}
// Step decay picks the first bucket covering the age and falls back to the last
// bucket's value for ages beyond every boundary.
#[test]
fn should_step_decay_apply_first_matching_bucket() {
let decay = DecayFn::Step {
thresholds: vec![
(chrono::Duration::hours(1), 1.0),
(chrono::Duration::days(1), 0.5),
(chrono::Duration::days(7), 0.1),
],
};
assert_eq!(decay.evaluate(chrono::Duration::minutes(30)), 1.0);
assert_eq!(decay.evaluate(chrono::Duration::hours(12)), 0.5);
assert_eq!(decay.evaluate(chrono::Duration::days(3)), 0.1);
assert_eq!(decay.evaluate(chrono::Duration::days(30)), 0.1);
}
// Relative labels cover the "just now", minute, hour and day buckets, including
// the singular form.
#[test]
fn should_relative_label_render_minutes_and_days() {
assert_eq!(relative_label(chrono::Duration::seconds(30)), "just now");
assert_eq!(relative_label(chrono::Duration::minutes(5)), "5 minutes ago");
assert_eq!(relative_label(chrono::Duration::minutes(1)), "1 minute ago");
assert_eq!(relative_label(chrono::Duration::hours(3)), "3 hours ago");
assert_eq!(relative_label(chrono::Duration::days(2)), "2 days ago");
}
// Builds a minimal semantic memory created at `now`, with the given confidence and
// category and no `event_at`, for use with `rank_score`.
fn scored_fixture(now: DateTime<FixedOffset>, confidence: i8, category: Option<&str>) -> Memory {
Memory {
pid: "p".into(),
scope: Scope {
agent_id: "a".into(),
org_id: "o".into(),
user_id: "u".into(),
},
content: "c".into(),
metadata: serde_json::json!({}),
kind: crate::memory::MemoryKind::Semantic,
source_pid: None,
supersession: None,
created_at: now,
updated_at: now,
event_at: None,
score: None,
status: crate::store::IndexStatus::Indexed,
confidence: crate::memory::Confidence::new(confidence),
category: category.map(str::to_string),
retirement: None,
}
}
// Blended strategy with the balanced preset and the default decay.
fn balanced_blend() -> RankingStrategy {
RankingStrategy::blended(BlendWeights::balanced())
}
// With cosine and age equal, higher confidence must produce a higher score.
#[test]
fn should_rank_high_confidence_above_low_at_equal_cosine() {
let now = Utc::now().into();
let strategy = balanced_blend();
let high = rank_score(&strategy, 0.8, &scored_fixture(now, 95, None), now);
let low = rank_score(&strategy, 0.8, &scored_fixture(now, 10, None), now);
assert!(high > low, "high confidence ({high}) must outrank low ({low}) at equal cosine");
}
// With cosine and confidence equal, the more recent memory must score higher.
#[test]
fn should_keep_recency_moving_ranking_at_equal_cosine_and_confidence() {
let now: DateTime<FixedOffset> = Utc::now().into();
let strategy = balanced_blend();
let mut old = scored_fixture(now, 80, None);
old.created_at = now - chrono::Duration::days(60);
let recent = scored_fixture(now, 80, None);
let recent_score = rank_score(&strategy, 0.8, &recent, now);
let old_score = rank_score(&strategy, 0.8, &old, now);
assert!(
recent_score > old_score,
"recent ({recent_score}) must outrank old ({old_score}) at equal cosine+confidence"
);
}
// Only a memory whose category is in the preferred list earns the bonus.
#[test]
fn should_apply_category_bonus_only_to_preferred_categories() {
let now: DateTime<FixedOffset> = Utc::now().into();
let strategy = RankingStrategy::blended(BlendWeights::balanced().prefer_categories(["preference".to_string()]));
let preferred = rank_score(&strategy, 0.8, &scored_fixture(now, 80, Some("preference")), now);
let other = rank_score(&strategy, 0.8, &scored_fixture(now, 80, Some("transient")), now);
let uncategorized = rank_score(&strategy, 0.8, &scored_fixture(now, 80, None), now);
assert!(preferred > other, "preferred category must earn the bonus");
assert!(
(other - uncategorized).abs() < f32::EPSILON,
"non-preferred and uncategorized rows must score identically (no bonus)"
);
}
// With no preferred categories configured, the category has no effect on the score.
#[test]
fn should_blend_be_inert_on_category_when_no_preference_set() {
let now: DateTime<FixedOffset> = Utc::now().into();
let strategy = balanced_blend();
let with_cat = rank_score(&strategy, 0.8, &scored_fixture(now, 80, Some("preference")), now);
let without = rank_score(&strategy, 0.8, &scored_fixture(now, 80, None), now);
assert!((with_cat - without).abs() < f32::EPSILON);
}
// trust_first emphasises confidence; relevance_first emphasises cosine.
#[test]
fn should_preset_weights_differ_in_confidence_emphasis() {
assert!(BlendWeights::trust_first().confidence > BlendWeights::relevance_first().confidence);
assert!(BlendWeights::relevance_first().cosine > BlendWeights::trust_first().cosine);
}
}