use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use super::entry::MemoryEntry;
/// Relative weights of the four components that `MemoryRanker::score` combines.
///
/// Each field multiplies one component score; the default weights sum to 1.0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankingWeights {
/// Weight of the cosine similarity between the query and entry embeddings.
pub similarity: f32,
/// Weight of the entry's stored `importance`.
pub importance: f32,
/// Weight of the age-based recency component.
pub recency: f32,
/// Weight of the access-count-based frequency component.
pub frequency: f32,
}
impl Default for RankingWeights {
fn default() -> Self {
Self {
similarity: 0.4,
importance: 0.25,
recency: 0.2,
frequency: 0.15,
}
}
}
/// Scores and orders `MemoryEntry` values by a weighted mix of embedding
/// similarity, importance, recency, and access frequency.
pub struct MemoryRanker {
/// Per-component weights applied in `score`.
weights: RankingWeights,
/// Decay scale, in hours, for the recency component (see `recency_score`);
/// set to one week (168 h) by `new` and `with_weights`.
recency_half_life_hours: f32,
}
impl MemoryRanker {
    /// Creates a ranker with [`RankingWeights::default`] and a one-week (168 h)
    /// recency half-life.
    pub fn new() -> Self {
        Self::with_weights(RankingWeights::default())
    }

    /// Creates a ranker with custom `weights` and a one-week (168 h) recency half-life.
    pub fn with_weights(weights: RankingWeights) -> Self {
        Self {
            weights,
            recency_half_life_hours: 24.0 * 7.0,
        }
    }

    /// Sets the age, in hours, at which the recency component falls to 0.5.
    ///
    /// Should be positive and finite; a zero or negative value makes the recency
    /// component degenerate (0, NaN, or growing above 1).
    pub fn with_recency_half_life(mut self, hours: f32) -> Self {
        self.recency_half_life_hours = hours;
        self
    }

    /// Recency component in `[0, 1]`: `0.5^(age_hours / half_life)`.
    ///
    /// Age is measured at second resolution. A `created_at` in the future
    /// (e.g. clock skew) is treated as age zero, so the result never exceeds 1.
    fn recency_score(&self, created_at: DateTime<Utc>) -> f32 {
        let age_secs = (Utc::now() - created_at).num_seconds().max(0);
        let age_hours = age_secs as f32 / 3600.0;
        // exp2(-t / h) == 0.5^(t / h): exactly 0.5 at one half-life, 0.25 at two.
        // A plain `exp` would only reach 1/e (about 0.37) at the "half-life".
        (-age_hours / self.recency_half_life_hours).exp2()
    }

    /// Frequency component in `[0, 1]`: `ln(1 + access_count) / 10`, saturating at 1.0.
    fn frequency_score(&self, access_count: u32) -> f32 {
        // The log grows slowly with repeated access. Saturating keeps very hot entries
        // (more than about 22k accesses) on the same [0, 1] scale as the other components.
        ((access_count as f32).ln_1p() / 10.0).min(1.0)
    }

    /// Weighted relevance of `entry` for `query_embedding`.
    ///
    /// Combines the cosine similarity (0 when the entry has no embedding), the
    /// entry's stored importance, recency, and access frequency using `self.weights`.
    pub fn score(&self, entry: &MemoryEntry, query_embedding: &[f32]) -> f32 {
        let similarity = match &entry.embedding {
            Some(embedding) => cosine_similarity(query_embedding, embedding),
            None => 0.0,
        };
        let recency = self.recency_score(entry.created_at);
        let frequency = self.frequency_score(entry.access_count);
        self.weights.similarity * similarity
            + self.weights.importance * entry.importance
            + self.weights.recency * recency
            + self.weights.frequency * frequency
    }

    /// Returns `memories` ordered from highest to lowest [`MemoryRanker::score`].
    ///
    /// Ties keep their input order; NaN scores are placed last.
    pub fn rank(&self, query_embedding: &[f32], memories: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
        self.rank_with_scores(query_embedding, memories)
            .into_iter()
            .map(|(_, memory)| memory)
            .collect()
    }

    /// Like [`MemoryRanker::rank`], but keeps each entry's score alongside it.
    ///
    /// Ties keep their input order; NaN scores are placed last.
    pub fn rank_with_scores(
        &self,
        query_embedding: &[f32],
        memories: Vec<MemoryEntry>,
    ) -> Vec<(f32, MemoryEntry)> {
        let mut scored: Vec<(f32, MemoryEntry)> = memories
            .into_iter()
            .map(|m| (self.score(&m, query_embedding), m))
            .collect();
        // `partial_cmp(..).unwrap_or(Equal)` is not a total order when a score is NaN
        // (e.g. from NaN weights or importance). `sort_by` may then produce an
        // inconsistent order or panic. Sinking NaN and using `total_cmp` gives a
        // total, deterministic order.
        let sort_key = |score: f32| if score.is_nan() { f32::NEG_INFINITY } else { score };
        scored.sort_by(|(a, _), (b, _)| sort_key(*b).total_cmp(&sort_key(*a)));
        scored
    }
}
impl Default for MemoryRanker {
fn default() -> Self {
Self::new()
}
}
/// Settings for age-based decay of a memory's importance, applied by `DecayConfig::apply`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecayConfig {
/// Fraction of importance removed per day once the grace period has passed
/// (importance is multiplied by `1 - daily_rate` per day); 0.01 by default.
pub daily_rate: f32,
/// Importance below which `apply` returns `true`; 0.1 by default.
pub min_threshold: f32,
/// Days after `created_at` during which `apply` leaves importance unchanged; 7 by default.
pub grace_period_days: u32,
}
impl Default for DecayConfig {
fn default() -> Self {
Self {
daily_rate: 0.01,
min_threshold: 0.1,
grace_period_days: 7,
}
}
}
impl DecayConfig {
    /// Creates a config with the default values (same as [`DecayConfig::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the fraction of importance lost per day after the grace period,
    /// clamped to `[0.0, 1.0]`.
    pub fn with_rate(mut self, rate: f32) -> Self {
        self.daily_rate = rate.clamp(0.0, 1.0);
        self
    }

    /// Sets the importance below which [`DecayConfig::apply`] returns `true`,
    /// clamped to `[0.0, 1.0]`.
    pub fn with_min_threshold(mut self, threshold: f32) -> Self {
        self.min_threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Sets the number of days after creation during which no decay is applied.
    pub fn with_grace_period(mut self, days: u32) -> Self {
        self.grace_period_days = days;
        self
    }

    /// Decays `entry.importance` by age and reports whether it fell below `min_threshold`.
    ///
    /// Entries younger than `grace_period_days` are left untouched and return `false`.
    /// Older entries have their importance multiplied by
    /// `(1 - daily_rate)^(age_days - grace_period_days)`.
    ///
    /// The factor is derived from `created_at` on every call and multiplied into the
    /// *current* importance, so calling `apply` repeatedly on the same entry compounds
    /// the decay. NOTE(review): confirm callers apply it once per entry (or reset
    /// importance first).
    pub fn apply(&self, entry: &mut MemoryEntry) -> bool {
        // A `created_at` in the future (e.g. clock skew) gives a negative day count.
        // Clamp it to zero; an `as u32` cast would wrap it to ~4 billion days of decay.
        let age_days = (Utc::now() - entry.created_at).num_days().max(0);
        let days_since_grace = age_days - i64::from(self.grace_period_days);
        if days_since_grace < 0 {
            return false;
        }
        // `daily_rate` is a public, deserializable field, so repeat `with_rate`'s clamp;
        // a rate above 1 would make the per-day factor negative.
        let retained_per_day = 1.0 - self.daily_rate.clamp(0.0, 1.0);
        // Saturate rather than wrap when converting the exponent for `powi`.
        let exponent = days_since_grace.min(i64::from(i32::MAX)) as i32;
        entry.importance *= retained_per_day.powi(exponent);
        entry.importance < self.min_threshold
    }
}
/// Cosine similarity between `a` and `b`, clamped to `[0.0, 1.0]`.
///
/// Returns `0.0` when the slices are empty, differ in length, or either has zero
/// magnitude. Anti-correlated vectors (negative cosine) also map to `0.0`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    (dot / (mag_a * mag_b)).clamp(0.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default weights should form a (near-)convex combination.
    #[test]
    fn test_ranking_weights() {
        let w = RankingWeights::default();
        let total = w.similarity + w.importance + w.recency + w.frequency;
        assert!((total - 1.0).abs() < 0.01);
    }

    /// Fresh entries score high on recency; month-old entries score low.
    #[test]
    fn test_recency_score() {
        let ranker = MemoryRanker::new();
        let an_hour_ago = Utc::now() - chrono::Duration::hours(1);
        assert!(ranker.recency_score(an_hour_ago) > 0.9);
        let a_month_ago = Utc::now() - chrono::Duration::hours(24 * 30);
        assert!(ranker.recency_score(a_month_ago) < 0.5);
    }

    /// Unaccessed entries score zero; more accesses score higher.
    #[test]
    fn test_frequency_score() {
        let ranker = MemoryRanker::new();
        assert_eq!(ranker.frequency_score(0), 0.0);
        let (once, ten_times) = (ranker.frequency_score(1), ranker.frequency_score(10));
        assert!(ten_times > once);
    }

    /// The entry aligned with the query (and more important) ranks first.
    #[test]
    fn test_rank_memories() {
        let ranker = MemoryRanker::new();

        let mut first = MemoryEntry::new("First");
        first.embedding = Some(vec![1.0, 0.0, 0.0]);
        first.importance = 0.9;

        let mut second = MemoryEntry::new("Second");
        second.embedding = Some(vec![0.0, 1.0, 0.0]);
        second.importance = 0.5;

        let ranked = ranker.rank(&[0.9, 0.1, 0.0], vec![first, second]);
        assert_eq!(ranked[0].content, "First");
    }

    /// Builder methods store in-range values unchanged.
    #[test]
    fn test_decay_config() {
        let config = DecayConfig::new()
            .with_rate(0.1)
            .with_min_threshold(0.2)
            .with_grace_period(3);
        assert_eq!(config.grace_period_days, 3);
        assert_eq!(config.min_threshold, 0.2);
        assert_eq!(config.daily_rate, 0.1);
    }

    /// An entry past the grace period loses importance.
    #[test]
    fn test_decay_apply() {
        let config = DecayConfig::default();
        let mut entry = MemoryEntry::new("Test");
        entry.importance = 0.5;
        entry.created_at = Utc::now() - chrono::Duration::days(10);
        let _ = config.apply(&mut entry);
        assert!(entry.importance < 0.5);
    }
}