// agent_io/memory/ranker.rs
use std::cmp::Ordering;

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};

use super::entry::MemoryEntry;

8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct RankingWeights {
11 pub similarity: f32,
13 pub importance: f32,
15 pub recency: f32,
17 pub frequency: f32,
19}
20
21impl Default for RankingWeights {
22 fn default() -> Self {
23 Self {
24 similarity: 0.4,
25 importance: 0.25,
26 recency: 0.2,
27 frequency: 0.15,
28 }
29 }
30}
31
32pub struct MemoryRanker {
34 weights: RankingWeights,
35 recency_half_life_hours: f32,
36}
37
38impl MemoryRanker {
39 pub fn new() -> Self {
41 Self {
42 weights: RankingWeights::default(),
43 recency_half_life_hours: 24.0 * 7.0, }
45 }
46
47 pub fn with_weights(weights: RankingWeights) -> Self {
49 Self {
50 weights,
51 recency_half_life_hours: 24.0 * 7.0,
52 }
53 }
54
55 pub fn with_recency_half_life(mut self, hours: f32) -> Self {
57 self.recency_half_life_hours = hours;
58 self
59 }
60
61 fn recency_score(&self, created_at: DateTime<Utc>) -> f32 {
63 let age_hours = (Utc::now() - created_at).num_hours() as f32;
64 (-age_hours / self.recency_half_life_hours).exp()
65 }
66
67 fn frequency_score(&self, access_count: u32) -> f32 {
69 if access_count == 0 {
70 0.0
71 } else {
72 (1.0 + access_count as f32).ln() / 10.0 }
74 }
75
76 pub fn score(&self, entry: &MemoryEntry, query_embedding: &[f32]) -> f32 {
78 let similarity = if let Some(ref embedding) = entry.embedding {
80 cosine_similarity(query_embedding, embedding)
81 } else {
82 0.0
83 };
84
85 let recency = self.recency_score(entry.created_at);
87
88 let frequency = self.frequency_score(entry.access_count);
90
91 self.weights.similarity * similarity
93 + self.weights.importance * entry.importance
94 + self.weights.recency * recency
95 + self.weights.frequency * frequency
96 }
97
98 pub fn rank(&self, query_embedding: &[f32], memories: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
100 let mut scored: Vec<(f32, MemoryEntry)> = memories
101 .into_iter()
102 .map(|m| (self.score(&m, query_embedding), m))
103 .collect();
104
105 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
107
108 scored.into_iter().map(|(_, m)| m).collect()
109 }
110
111 pub fn rank_with_scores(
113 &self,
114 query_embedding: &[f32],
115 memories: Vec<MemoryEntry>,
116 ) -> Vec<(f32, MemoryEntry)> {
117 let mut scored: Vec<(f32, MemoryEntry)> = memories
118 .into_iter()
119 .map(|m| (self.score(&m, query_embedding), m))
120 .collect();
121
122 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
123 scored
124 }
125}
126
127impl Default for MemoryRanker {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct DecayConfig {
136 pub daily_rate: f32,
138 pub min_threshold: f32,
140 pub grace_period_days: u32,
142}
143
144impl Default for DecayConfig {
145 fn default() -> Self {
146 Self {
147 daily_rate: 0.01,
148 min_threshold: 0.1,
149 grace_period_days: 7,
150 }
151 }
152}
153
154impl DecayConfig {
155 pub fn new() -> Self {
157 Self::default()
158 }
159
160 pub fn with_rate(mut self, rate: f32) -> Self {
162 self.daily_rate = rate.clamp(0.0, 1.0);
163 self
164 }
165
166 pub fn with_min_threshold(mut self, threshold: f32) -> Self {
168 self.min_threshold = threshold.clamp(0.0, 1.0);
169 self
170 }
171
172 pub fn with_grace_period(mut self, days: u32) -> Self {
174 self.grace_period_days = days;
175 self
176 }
177
178 pub fn apply(&self, entry: &mut MemoryEntry) -> bool {
180 let age_days = (Utc::now() - entry.created_at).num_days() as u32;
181
182 if age_days < self.grace_period_days {
184 return false;
185 }
186
187 let days_since_grace = age_days - self.grace_period_days;
189 let decay_factor = (1.0 - self.daily_rate).powi(days_since_grace as i32);
190 entry.importance *= decay_factor;
191
192 entry.importance < self.min_threshold
194 }
195}
196
/// Cosine similarity of `a` and `b`, clamped to `[0, 1]`.
///
/// Returns 0.0 when the slices differ in length, are empty, or either has zero
/// magnitude. Negative (opposing) similarities are reported as 0.0.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // One pass accumulates the dot product and both squared magnitudes.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32, 0.0_f32), |(dot, sq_a, sq_b), (&x, &y)| {
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        });
    let (norm_a, norm_b) = (sq_a.sqrt(), sq_b.sqrt());

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    (dot / (norm_a * norm_b)).clamp(0.0, 1.0)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ranking_weights() {
        let weights = RankingWeights::default();
        let total = weights.similarity + weights.importance + weights.recency + weights.frequency;
        assert!((total - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_recency_score() {
        let ranker = MemoryRanker::new();

        // One hour old: barely decayed.
        let recent = Utc::now() - chrono::Duration::hours(1);
        assert!(ranker.recency_score(recent) > 0.9);

        // Thirty days old: well past the one-week default decay scale.
        let old = Utc::now() - chrono::Duration::hours(24 * 30);
        assert!(ranker.recency_score(old) < 0.5);

        assert!(ranker.recency_score(recent) > ranker.recency_score(old));
    }

    #[test]
    fn test_frequency_score() {
        let ranker = MemoryRanker::new();

        assert_eq!(ranker.frequency_score(0), 0.0);
        assert!(ranker.frequency_score(1) > 0.0);
        assert!(ranker.frequency_score(10) > ranker.frequency_score(1));
    }

    #[test]
    fn test_rank_memories() {
        let ranker = MemoryRanker::new();

        let mut entry1 = MemoryEntry::new("First");
        entry1.embedding = Some(vec![1.0, 0.0, 0.0]);
        entry1.importance = 0.9;

        let mut entry2 = MemoryEntry::new("Second");
        entry2.embedding = Some(vec![0.0, 1.0, 0.0]);
        entry2.importance = 0.5;

        // Supply the entries worst-first so the sort actually has to reorder.
        let ranked = ranker.rank(&[0.9, 0.1, 0.0], vec![entry2, entry1]);
        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].content, "First");
        assert_eq!(ranked[1].content, "Second");
    }

    #[test]
    fn test_rank_with_scores_descending() {
        let ranker = MemoryRanker::new();
        let memories: Vec<MemoryEntry> = [0.1_f32, 0.9, 0.5]
            .iter()
            .map(|&importance| {
                let mut entry = MemoryEntry::new("m");
                entry.importance = importance;
                entry
            })
            .collect();

        let ranked = ranker.rank_with_scores(&[1.0, 0.0], memories);
        assert_eq!(ranked.len(), 3);
        assert!(ranked.windows(2).all(|pair| pair[0].0 >= pair[1].0));
        assert_eq!(ranked[0].1.importance, 0.9);
    }

    #[test]
    fn test_cosine_similarity_edge_cases() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 0.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]), 0.0);
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]), 1.0);
    }

    #[test]
    fn test_decay_config() {
        let config = DecayConfig::new()
            .with_rate(0.1)
            .with_min_threshold(0.2)
            .with_grace_period(3);

        assert_eq!(config.daily_rate, 0.1);
        assert_eq!(config.min_threshold, 0.2);
        assert_eq!(config.grace_period_days, 3);
    }

    #[test]
    fn test_decay_apply() {
        let config = DecayConfig::default();
        let mut entry = MemoryEntry::new("Test");
        entry.importance = 0.5;
        entry.created_at = Utc::now() - chrono::Duration::days(10);

        // 10 days old with a 7-day grace period: 3 days of 1% decay.
        let below_threshold = config.apply(&mut entry);
        assert!(entry.importance < 0.5);
        assert!((entry.importance - 0.5 * 0.99_f32.powi(3)).abs() < 1e-5);
        assert!(!below_threshold);
    }

    #[test]
    fn test_decay_grace_period() {
        let config = DecayConfig::default();
        let mut entry = MemoryEntry::new("Fresh");
        entry.importance = 0.05;
        entry.created_at = Utc::now() - chrono::Duration::days(2);

        // Inside the grace period nothing decays and nothing is reported,
        // even though importance is already below the threshold.
        assert!(!config.apply(&mut entry));
        assert_eq!(entry.importance, 0.05);
    }
}