1use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
/// A scored excerpt of a file, identified by file name and line range.
///
/// Serializable with serde so chunks can be stored or exchanged as data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryChunk {
    /// File the chunk belongs to; used verbatim as the prefix of `location()`.
    pub file: String,

    /// First line of the chunk (numbering base is not established in this file).
    pub line_start: i32,

    /// Last line of the chunk; equal to `line_start` for a single-line chunk.
    pub line_end: i32,

    /// Text of the chunk.
    pub content: String,

    /// Ranking score. Rewritten in place by `apply_temporal_decay` and by MMR reranking.
    pub score: f64,

    /// Unix timestamp in seconds used as the chunk's age reference.
    ///
    /// Defaults to `0` when absent from deserialized input; values `<= 0` are treated as
    /// "unknown" by `apply_temporal_decay`, which then leaves the score untouched.
    #[serde(default)]
    pub updated_at: i64,
}
28
29impl MemoryChunk {
30 pub fn new(file: String, line_start: i32, line_end: i32, content: String, score: f64) -> Self {
32 Self {
33 file,
34 line_start,
35 line_end,
36 content,
37 score,
38 updated_at: 0,
39 }
40 }
41
42 pub fn with_timestamp(mut self, updated_at: i64) -> Self {
44 self.updated_at = updated_at;
45 self
46 }
47
48 pub fn apply_temporal_decay(&mut self, lambda: f64, now_unix: i64) -> f64 {
52 if lambda <= 0.0 || self.updated_at <= 0 {
53 return self.score;
54 }
55
56 let age_secs = (now_unix - self.updated_at).max(0) as f64;
57 let age_days = age_secs / (24.0 * 60.0 * 60.0);
58 let decay_factor = (-lambda * age_days).exp();
59
60 self.score *= decay_factor;
61 self.score
62 }
63
64 pub fn preview(&self, max_len: usize) -> String {
66 if self.content.len() <= max_len {
67 self.content.clone()
68 } else {
69 format!(
70 "{}...",
71 &self.content[..self.content.floor_char_boundary(max_len)]
72 )
73 }
74 }
75
76 pub fn location(&self) -> String {
78 if self.line_start == self.line_end {
79 format!("{}:{}", self.file, self.line_start)
80 } else {
81 format!("{}:{}-{}", self.file, self.line_start, self.line_end)
82 }
83 }
84}
85
/// Reranker that trades relevance against diversity (maximal marginal relevance).
///
/// Each candidate is scored as `lambda * relevance - (1 - lambda) * max_similarity`,
/// where `max_similarity` is its highest similarity to any chunk already selected.
// NOTE(review): `dead_code` is presumably silenced because not every build uses the
// reranker yet — confirm and drop the attribute once it is wired in.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MmrReranker {
    /// Weight of relevance versus diversity; `1.0` is pure relevance, `0.0` pure diversity.
    lambda: f64,
}
98
99impl Default for MmrReranker {
100 fn default() -> Self {
101 Self { lambda: 0.7 }
102 }
103}
104
105impl MmrReranker {
106 pub fn new(lambda: f64) -> Self {
108 Self {
109 lambda: lambda.clamp(0.0, 1.0),
110 }
111 }
112
113 pub fn rerank(&self, chunks: &mut [MemoryChunk]) {
121 if chunks.len() <= 1 {
122 return;
123 }
124
125 let token_sets: Vec<HashSet<String>> =
127 chunks.iter().map(|c| tokenize(&c.content)).collect();
128
129 let original_scores: Vec<f64> = chunks.iter().map(|c| c.score).collect();
131
132 let mut selected: Vec<usize> = Vec::with_capacity(chunks.len());
134 let mut remaining: Vec<usize> = (0..chunks.len()).collect();
135
136 if let Some((best_pos, _best_idx)) =
138 remaining.iter().enumerate().max_by(|(_, a), (_, b)| {
139 original_scores[**a]
140 .partial_cmp(&original_scores[**b])
141 .unwrap_or(std::cmp::Ordering::Equal)
142 })
143 {
144 selected.push(remaining.remove(best_pos));
145 }
146
147 while !remaining.is_empty() {
149 let best = remaining
150 .iter()
151 .enumerate()
152 .max_by(|(_pos_a, idx_a), (_pos_b, idx_b)| {
153 let mmr_a =
154 self.compute_mmr(**idx_a, original_scores[**idx_a], &selected, &token_sets);
155 let mmr_b =
156 self.compute_mmr(**idx_b, original_scores[**idx_b], &selected, &token_sets);
157 mmr_a
158 .partial_cmp(&mmr_b)
159 .unwrap_or(std::cmp::Ordering::Equal)
160 });
161
162 if let Some((best_pos, best_idx)) = best {
163 let mmr_score = self.compute_mmr(
165 *best_idx,
166 original_scores[*best_idx],
167 &selected,
168 &token_sets,
169 );
170 chunks[*best_idx].score = mmr_score;
171 selected.push(remaining.remove(best_pos));
172 }
173 }
174
175 let mut reordered: Vec<MemoryChunk> =
177 selected.into_iter().map(|i| chunks[i].clone()).collect();
178 chunks.swap_with_slice(&mut reordered);
179 }
180
181 fn compute_mmr(
183 &self,
184 candidate_idx: usize,
185 relevance: f64,
186 selected: &[usize],
187 token_sets: &[HashSet<String>],
188 ) -> f64 {
189 let max_sim = if selected.is_empty() {
190 0.0
191 } else {
192 selected
193 .iter()
194 .map(|&sel_idx| {
195 jaccard_similarity(&token_sets[candidate_idx], &token_sets[sel_idx])
196 })
197 .fold(0.0_f64, f64::max)
198 };
199
200 self.lambda * relevance - (1.0 - self.lambda) * max_sim
201 }
202}
203
/// Splits `text` into a set of lowercase word tokens.
///
/// Words are separated by whitespace, stripped of leading and trailing non-alphanumeric
/// characters, and dropped when fewer than two characters remain.
#[allow(dead_code)]
fn tokenize(text: &str) -> HashSet<String> {
    text.to_lowercase()
        .split_whitespace()
        .map(|word| word.trim_matches(|c: char| !c.is_alphanumeric()))
        // Count characters, not bytes: `len() > 1` would let a lone multi-byte letter
        // such as "é" through while rejecting "a". This also rejects empty words.
        .filter(|word| word.chars().nth(1).is_some())
        .map(str::to_owned)
        .collect()
}
214
/// Jaccard index `|A ∩ B| / |A ∪ B|` of two token sets, in `[0.0, 1.0]`.
///
/// Returns `0.0` when either set is empty.
#[allow(dead_code)]
fn jaccard_similarity(a: &HashSet<String>, b: &HashSet<String>) -> f64 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    let shared = a.intersection(b).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|; never zero here because both sets are non-empty.
    let total = a.len() + b.len() - shared;

    shared as f64 / total as f64
}
231
232#[allow(dead_code)]
236pub fn apply_mmr(chunks: &mut [MemoryChunk]) {
237 MmrReranker::default().rerank(chunks);
238}
239
240#[allow(dead_code)]
242pub fn apply_mmr_with_lambda(chunks: &mut [MemoryChunk], lambda: f64) {
243 MmrReranker::new(lambda).rerank(chunks);
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `MemoryChunk::new` that takes string slices.
    fn chunk(file: &str, line_start: i32, line_end: i32, content: &str, score: f64) -> MemoryChunk {
        MemoryChunk::new(
            file.to_string(),
            line_start,
            line_end,
            content.to_string(),
            score,
        )
    }

    /// Builds an owned token set from string slices.
    fn word_set(words: &[&str]) -> HashSet<String> {
        words.iter().map(|w| w.to_string()).collect()
    }

    #[test]
    fn test_memory_chunk_preview() {
        let sample = chunk(
            "test.md",
            1,
            5,
            "This is a long content string that should be truncated",
            0.9,
        );

        assert_eq!(sample.preview(20), "This is a long conte...");
        assert_eq!(sample.location(), "test.md:1-5");
    }

    #[test]
    fn test_memory_chunk_single_line_location() {
        let sample = chunk("test.md", 10, 10, "Single line", 0.5);

        assert_eq!(sample.location(), "test.md:10");
    }

    #[test]
    fn test_memory_chunk_preview_multibyte() {
        // Byte 8 falls inside the first emoji, so the cut must move back to byte 6.
        let sample = chunk("test.md", 1, 1, "Hello 🌍🌎🌏 world", 1.0);

        let shortened = sample.preview(8);
        assert!(shortened.ends_with("..."));
        assert_eq!(shortened, "Hello ...");
    }

    #[test]
    fn test_memory_chunk_preview_emdash() {
        // Byte 5 falls inside the first em dash, so the cut must move back to byte 3.
        let sample = chunk("test.md", 1, 1, "one—two—three—four—five", 1.0);

        let shortened = sample.preview(5);
        assert!(shortened.ends_with("..."));
        assert_eq!(shortened, "one...");
    }

    #[test]
    fn test_temporal_decay_no_decay() {
        // A zero rate must leave the score alone however old the chunk is.
        let mut sample = chunk("test.md", 1, 1, "content", 1.0).with_timestamp(1_700_000_000);

        let result = sample.apply_temporal_decay(0.0, 1_710_000_000);
        assert!((result - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_temporal_decay_seven_days() {
        let now = 1_710_000_000i64;
        let week_in_secs = 7 * 24 * 60 * 60;
        let mut sample = chunk("test.md", 1, 1, "content", 1.0).with_timestamp(now - week_in_secs);

        // exp(-0.1 * 7) is roughly 0.496.
        let result = sample.apply_temporal_decay(0.1, now);
        assert!((result - 0.496).abs() < 0.01);
    }

    #[test]
    fn test_temporal_decay_fresh() {
        let now = 1_710_000_000i64;
        let mut sample = chunk("test.md", 1, 1, "content", 1.0).with_timestamp(now);

        let result = sample.apply_temporal_decay(0.1, now);
        assert!((result - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_jaccard_similarity() {
        let left = word_set(&["apple", "banana", "cherry"]);
        let right = word_set(&["banana", "cherry", "date"]);

        // Two shared tokens out of four distinct ones.
        let similarity = jaccard_similarity(&left, &right);
        assert!((similarity - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_jaccard_similarity_empty() {
        let filled = word_set(&["apple"]);
        let empty: HashSet<String> = HashSet::new();

        assert_eq!(jaccard_similarity(&filled, &empty), 0.0);
        assert_eq!(jaccard_similarity(&empty, &filled), 0.0);
    }

    #[test]
    fn test_jaccard_similarity_identical() {
        let left = word_set(&["apple", "banana"]);
        let right = word_set(&["apple", "banana"]);

        assert!((jaccard_similarity(&left, &right) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_mmr_single_item() {
        let mut chunks = vec![chunk("test.md", 1, 1, "content", 0.9)];

        apply_mmr(&mut chunks);
        assert_eq!(chunks.len(), 1);
    }

    #[test]
    fn test_mmr_diverse_results() {
        let mut chunks = vec![
            chunk("a.md", 1, 1, "apple banana cherry", 0.9),
            chunk("b.md", 1, 1, "xray yacht zebra", 0.9),
        ];

        apply_mmr(&mut chunks);

        // Reranking may reorder chunks but must keep every one of them.
        assert_eq!(chunks.len(), 2);
        assert!(chunks.iter().any(|c| c.file == "a.md"));
        assert!(chunks.iter().any(|c| c.file == "b.md"));
    }

    #[test]
    fn test_mmr_similar_penalized() {
        let mut chunks = vec![
            chunk("similar1.md", 1, 1, "apple banana", 1.0),
            chunk("similar2.md", 1, 1, "apple banana cherry", 0.95),
            chunk("diverse.md", 1, 1, "xray yacht zebra", 0.8),
        ];

        apply_mmr(&mut chunks);

        assert_eq!(chunks[0].file, "similar1.md");

        let position_of = |name: &str| chunks.iter().position(|c| c.file == name).unwrap();
        let diverse_pos = position_of("diverse.md");
        let similar2_pos = position_of("similar2.md");

        assert!(
            diverse_pos < similar2_pos,
            "Diverse result should rank higher than similar duplicate"
        );
    }

    #[test]
    fn test_mmr_lambda_extremes() {
        // lambda = 1.0: only relevance counts.
        let mut relevance_only = vec![
            chunk("high.md", 1, 1, "unique alpha", 1.0),
            chunk("low.md", 1, 1, "unique alpha", 0.5),
        ];

        apply_mmr_with_lambda(&mut relevance_only, 1.0);
        assert_eq!(relevance_only[0].file, "high.md");

        // lambda = 0.0: the first pick is still made by relevance alone.
        let mut diversity_only = vec![
            chunk("high.md", 1, 1, "unique alpha", 1.0),
            chunk("low.md", 1, 1, "different beta", 0.5),
        ];

        apply_mmr_with_lambda(&mut diversity_only, 0.0);
        assert_eq!(diversity_only[0].file, "high.md");
    }

    #[test]
    fn test_tokenize() {
        let tokens = tokenize("Hello World! This is a test.");

        for expected in ["hello", "world", "this", "test"] {
            assert!(tokens.contains(expected));
        }
        assert!(!tokens.contains("a"));
    }
}