Skip to main content

agent_io/memory/
ranker.rs

1//! Memory ranking and importance decay
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5
6use super::entry::MemoryEntry;
7
8/// Ranking weights for memory relevance scoring
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct RankingWeights {
11    /// Weight for embedding similarity (0.0 - 1.0)
12    pub similarity: f32,
13    /// Weight for importance score (0.0 - 1.0)
14    pub importance: f32,
15    /// Weight for recency (0.0 - 1.0)
16    pub recency: f32,
17    /// Weight for access frequency (0.0 - 1.0)
18    pub frequency: f32,
19}
20
21impl Default for RankingWeights {
22    fn default() -> Self {
23        Self {
24            similarity: 0.4,
25            importance: 0.25,
26            recency: 0.2,
27            frequency: 0.15,
28        }
29    }
30}
31
32/// Memory ranker for scoring and sorting memories
33pub struct MemoryRanker {
34    weights: RankingWeights,
35    recency_half_life_hours: f32,
36}
37
38impl MemoryRanker {
39    /// Create a new memory ranker with default weights
40    pub fn new() -> Self {
41        Self {
42            weights: RankingWeights::default(),
43            recency_half_life_hours: 24.0 * 7.0, // 1 week
44        }
45    }
46
47    /// Create a ranker with custom weights
48    pub fn with_weights(weights: RankingWeights) -> Self {
49        Self {
50            weights,
51            recency_half_life_hours: 24.0 * 7.0,
52        }
53    }
54
55    /// Set recency half-life in hours
56    pub fn with_recency_half_life(mut self, hours: f32) -> Self {
57        self.recency_half_life_hours = hours;
58        self
59    }
60
61    /// Calculate recency score (exponential decay)
62    fn recency_score(&self, created_at: DateTime<Utc>) -> f32 {
63        let age_hours = (Utc::now() - created_at).num_hours() as f32;
64        (-age_hours / self.recency_half_life_hours).exp()
65    }
66
67    /// Calculate frequency score (logarithmic)
68    fn frequency_score(&self, access_count: u32) -> f32 {
69        if access_count == 0 {
70            0.0
71        } else {
72            (1.0 + access_count as f32).ln() / 10.0 // Normalize to roughly 0-1
73        }
74    }
75
76    /// Calculate composite relevance score
77    pub fn score(&self, entry: &MemoryEntry, query_embedding: &[f32]) -> f32 {
78        // Similarity score
79        let similarity = if let Some(ref embedding) = entry.embedding {
80            cosine_similarity(query_embedding, embedding)
81        } else {
82            0.0
83        };
84
85        // Recency score
86        let recency = self.recency_score(entry.created_at);
87
88        // Frequency score
89        let frequency = self.frequency_score(entry.access_count);
90
91        // Weighted combination
92        self.weights.similarity * similarity
93            + self.weights.importance * entry.importance
94            + self.weights.recency * recency
95            + self.weights.frequency * frequency
96    }
97
98    /// Rank memories by relevance to query
99    pub fn rank(&self, query_embedding: &[f32], memories: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
100        let mut scored: Vec<(f32, MemoryEntry)> = memories
101            .into_iter()
102            .map(|m| (self.score(&m, query_embedding), m))
103            .collect();
104
105        // Sort by score descending
106        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
107
108        scored.into_iter().map(|(_, m)| m).collect()
109    }
110
111    /// Rank memories and return with scores
112    pub fn rank_with_scores(
113        &self,
114        query_embedding: &[f32],
115        memories: Vec<MemoryEntry>,
116    ) -> Vec<(f32, MemoryEntry)> {
117        let mut scored: Vec<(f32, MemoryEntry)> = memories
118            .into_iter()
119            .map(|m| (self.score(&m, query_embedding), m))
120            .collect();
121
122        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
123        scored
124    }
125}
126
127impl Default for MemoryRanker {
128    fn default() -> Self {
129        Self::new()
130    }
131}
132
133/// Importance decay configuration
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct DecayConfig {
136    /// Decay rate per day (0.0 - 1.0)
137    pub daily_rate: f32,
138    /// Minimum importance threshold (memories below this are candidates for removal)
139    pub min_threshold: f32,
140    /// Age in days before decay starts
141    pub grace_period_days: u32,
142}
143
144impl Default for DecayConfig {
145    fn default() -> Self {
146        Self {
147            daily_rate: 0.01,
148            min_threshold: 0.1,
149            grace_period_days: 7,
150        }
151    }
152}
153
154impl DecayConfig {
155    /// Create new decay config
156    pub fn new() -> Self {
157        Self::default()
158    }
159
160    /// Set daily decay rate
161    pub fn with_rate(mut self, rate: f32) -> Self {
162        self.daily_rate = rate.clamp(0.0, 1.0);
163        self
164    }
165
166    /// Set minimum threshold
167    pub fn with_min_threshold(mut self, threshold: f32) -> Self {
168        self.min_threshold = threshold.clamp(0.0, 1.0);
169        self
170    }
171
172    /// Set grace period
173    pub fn with_grace_period(mut self, days: u32) -> Self {
174        self.grace_period_days = days;
175        self
176    }
177
178    /// Apply decay to a memory entry
179    pub fn apply(&self, entry: &mut MemoryEntry) -> bool {
180        let age_days = (Utc::now() - entry.created_at).num_days() as u32;
181
182        // Skip if in grace period
183        if age_days < self.grace_period_days {
184            return false;
185        }
186
187        // Apply exponential decay
188        let days_since_grace = age_days - self.grace_period_days;
189        let decay_factor = (1.0 - self.daily_rate).powi(days_since_grace as i32);
190        entry.importance *= decay_factor;
191
192        // Check if below threshold
193        entry.importance < self.min_threshold
194    }
195}
196
/// Cosine similarity of `a` and `b`, clamped to `[0.0, 1.0]`.
///
/// Returns 0.0 when the slices are empty, differ in length, or either has zero
/// magnitude. Negative similarity (vectors pointing apart) is clamped to 0.0.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let inner = |u: &[f32], v: &[f32]| -> f32 { u.iter().zip(v).map(|(p, q)| p * q).sum() };

    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let norm_a = inner(a, a).sqrt();
    let norm_b = inner(b, b).sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    (inner(a, b) / (norm_a * norm_b)).clamp(0.0, 1.0)
}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// The default weights should add up to (approximately) 1.0.
    #[test]
    fn test_ranking_weights() {
        let RankingWeights {
            similarity,
            importance,
            recency,
            frequency,
        } = RankingWeights::default();
        let total = similarity + importance + recency + frequency;
        assert!((total - 1.0).abs() < 0.01);
    }

    /// Fresh memories score high on recency; month-old ones score low.
    #[test]
    fn test_recency_score() {
        let ranker = MemoryRanker::new();
        let hours_ago = |hours: i64| Utc::now() - chrono::Duration::hours(hours);

        assert!(ranker.recency_score(hours_ago(1)) > 0.9);
        assert!(ranker.recency_score(hours_ago(24 * 30)) < 0.5);
    }

    /// Unaccessed memories score zero; more accesses score higher.
    #[test]
    fn test_frequency_score() {
        let ranker = MemoryRanker::new();
        let (never, once, ten_times) = (
            ranker.frequency_score(0),
            ranker.frequency_score(1),
            ranker.frequency_score(10),
        );

        assert_eq!(never, 0.0);
        assert!(ten_times > once);
    }

    /// The entry closest to the query (and more important) ranks first.
    #[test]
    fn test_rank_memories() {
        let ranker = MemoryRanker::new();

        let mut first = MemoryEntry::new("First");
        first.embedding = Some(vec![1.0, 0.0, 0.0]);
        first.importance = 0.9;

        let mut second = MemoryEntry::new("Second");
        second.embedding = Some(vec![0.0, 1.0, 0.0]);
        second.importance = 0.5;

        // The query points almost entirely along `first`'s embedding.
        let query: [f32; 3] = [0.9, 0.1, 0.0];
        let ranked = ranker.rank(&query, vec![first, second]);
        assert_eq!(ranked[0].content, "First");
    }

    /// Builder setters store the values they are given.
    #[test]
    fn test_decay_config() {
        let config = DecayConfig::new()
            .with_rate(0.1)
            .with_min_threshold(0.2)
            .with_grace_period(3);

        let DecayConfig {
            daily_rate,
            min_threshold,
            grace_period_days,
        } = config;
        assert_eq!(daily_rate, 0.1);
        assert_eq!(min_threshold, 0.2);
        assert_eq!(grace_period_days, 3);
    }

    /// A ten-day-old entry is past the default seven-day grace period, so its
    /// importance must drop.
    #[test]
    fn test_decay_apply() {
        let config = DecayConfig::default();

        let mut entry = MemoryEntry::new("Test");
        entry.importance = 0.5;
        entry.created_at = Utc::now() - chrono::Duration::days(10);

        let _ = config.apply(&mut entry);
        assert!(entry.importance < 0.5);
    }
}