Skip to main content

localgpt_core/memory/
search.rs

1//! Memory search types and utilities
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
6/// A chunk of memory content returned from search
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct MemoryChunk {
9    /// File path relative to workspace
10    pub file: String,
11
12    /// Starting line number (1-indexed)
13    pub line_start: i32,
14
15    /// Ending line number (1-indexed)
16    pub line_end: i32,
17
18    /// The actual content
19    pub content: String,
20
21    /// Relevance score (higher is better)
22    pub score: f64,
23
24    /// Unix timestamp when the chunk was last updated (for temporal decay)
25    #[serde(default)]
26    pub updated_at: i64,
27}
28
29impl MemoryChunk {
30    /// Create a new memory chunk
31    pub fn new(file: String, line_start: i32, line_end: i32, content: String, score: f64) -> Self {
32        Self {
33            file,
34            line_start,
35            line_end,
36            content,
37            score,
38            updated_at: 0,
39        }
40    }
41
42    /// Create a new memory chunk with timestamp
43    pub fn with_timestamp(mut self, updated_at: i64) -> Self {
44        self.updated_at = updated_at;
45        self
46    }
47
48    /// Apply temporal decay to the score based on age.
49    /// decay_factor = exp(-lambda * age_days)
50    /// Returns the decayed score.
51    pub fn apply_temporal_decay(&mut self, lambda: f64, now_unix: i64) -> f64 {
52        if lambda <= 0.0 || self.updated_at <= 0 {
53            return self.score;
54        }
55
56        let age_secs = (now_unix - self.updated_at).max(0) as f64;
57        let age_days = age_secs / (24.0 * 60.0 * 60.0);
58        let decay_factor = (-lambda * age_days).exp();
59
60        self.score *= decay_factor;
61        self.score
62    }
63
64    /// Get a preview of the content (first N characters)
65    pub fn preview(&self, max_len: usize) -> String {
66        if self.content.len() <= max_len {
67            self.content.clone()
68        } else {
69            format!(
70                "{}...",
71                &self.content[..self.content.floor_char_boundary(max_len)]
72            )
73        }
74    }
75
76    /// Get the location string (file:line)
77    pub fn location(&self) -> String {
78        if self.line_start == self.line_end {
79            format!("{}:{}", self.file, self.line_start)
80        } else {
81            format!("{}:{}-{}", self.file, self.line_start, self.line_end)
82        }
83    }
84}
85
/// MMR (Maximal Marginal Relevance) re-ranking for search results.
///
/// MMR diversifies results by balancing relevance with novelty:
/// `MMR = lambda * relevance - (1 - lambda) * max_similarity_to_selected`,
/// where similarity is the Jaccard similarity of the chunks' lowercase word
/// token sets.
///
/// This helps avoid showing multiple very similar chunks in results.
// NOTE(review): `dead_code` is presumably allowed because not every build of the
// crate calls the MMR helpers; confirm before removing.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
pub struct MmrReranker {
    /// Trade-off between relevance (1.0) and diversity (0.0).
    /// Default: 0.7 (slightly favor relevance).
    lambda: f64,
}
98
99impl Default for MmrReranker {
100    fn default() -> Self {
101        Self { lambda: 0.7 }
102    }
103}
104
105impl MmrReranker {
106    /// Create a new MMR reranker with custom lambda
107    pub fn new(lambda: f64) -> Self {
108        Self {
109            lambda: lambda.clamp(0.0, 1.0),
110        }
111    }
112
113    /// Re-rank search results using MMR algorithm.
114    ///
115    /// # Arguments
116    /// * `chunks` - Search results to re-rank (will be modified in place)
117    ///
118    /// # Returns
119    /// The re-ranked chunks in MMR order
120    pub fn rerank(&self, chunks: &mut [MemoryChunk]) {
121        if chunks.len() <= 1 {
122            return;
123        }
124
125        // Tokenize all chunks once
126        let token_sets: Vec<HashSet<String>> =
127            chunks.iter().map(|c| tokenize(&c.content)).collect();
128
129        // Track original scores
130        let original_scores: Vec<f64> = chunks.iter().map(|c| c.score).collect();
131
132        // Track which indices have been selected
133        let mut selected: Vec<usize> = Vec::with_capacity(chunks.len());
134        let mut remaining: Vec<usize> = (0..chunks.len()).collect();
135
136        // Select first item (highest relevance)
137        if let Some((best_pos, _best_idx)) =
138            remaining.iter().enumerate().max_by(|(_, a), (_, b)| {
139                original_scores[**a]
140                    .partial_cmp(&original_scores[**b])
141                    .unwrap_or(std::cmp::Ordering::Equal)
142            })
143        {
144            selected.push(remaining.remove(best_pos));
145        }
146
147        // Greedily select remaining items using MMR
148        while !remaining.is_empty() {
149            let best = remaining
150                .iter()
151                .enumerate()
152                .max_by(|(_pos_a, idx_a), (_pos_b, idx_b)| {
153                    let mmr_a =
154                        self.compute_mmr(**idx_a, original_scores[**idx_a], &selected, &token_sets);
155                    let mmr_b =
156                        self.compute_mmr(**idx_b, original_scores[**idx_b], &selected, &token_sets);
157                    mmr_a
158                        .partial_cmp(&mmr_b)
159                        .unwrap_or(std::cmp::Ordering::Equal)
160                });
161
162            if let Some((best_pos, best_idx)) = best {
163                // Update the score to the MMR value for transparency
164                let mmr_score = self.compute_mmr(
165                    *best_idx,
166                    original_scores[*best_idx],
167                    &selected,
168                    &token_sets,
169                );
170                chunks[*best_idx].score = mmr_score;
171                selected.push(remaining.remove(best_pos));
172            }
173        }
174
175        // Reorder chunks by selection order
176        let mut reordered: Vec<MemoryChunk> =
177            selected.into_iter().map(|i| chunks[i].clone()).collect();
178        chunks.swap_with_slice(&mut reordered);
179    }
180
181    /// Compute MMR score for a candidate
182    fn compute_mmr(
183        &self,
184        candidate_idx: usize,
185        relevance: f64,
186        selected: &[usize],
187        token_sets: &[HashSet<String>],
188    ) -> f64 {
189        let max_sim = if selected.is_empty() {
190            0.0
191        } else {
192            selected
193                .iter()
194                .map(|&sel_idx| {
195                    jaccard_similarity(&token_sets[candidate_idx], &token_sets[sel_idx])
196                })
197                .fold(0.0_f64, f64::max)
198        };
199
200        self.lambda * relevance - (1.0 - self.lambda) * max_sim
201    }
202}
203
/// Simple whitespace tokenizer with lowercase normalization.
///
/// Splits the lowercased text on whitespace and strips leading/trailing
/// non-alphanumeric characters from each word. Tokens shorter than two
/// characters are dropped; length is counted in Unicode scalar values, not
/// bytes.
// NOTE(review): `dead_code` is presumably allowed because the MMR path is not
// used in every build; confirm before removing.
#[allow(dead_code)]
fn tokenize(text: &str) -> HashSet<String> {
    text.to_lowercase()
        .split_whitespace()
        .map(|word| word.trim_matches(|c: char| !c.is_alphanumeric()))
        // Skip single chars. Count chars, not bytes: a byte-length check would
        // keep lone multi-byte characters such as "é" or "中".
        .filter(|word| word.chars().nth(1).is_some())
        .map(str::to_owned)
        .collect()
}
214
/// Compute the Jaccard similarity `|A ∩ B| / |A ∪ B|` between two token sets.
///
/// Returns `0.0` when either set is empty, otherwise a value in `[0.0, 1.0]`.
// NOTE(review): `dead_code` is presumably allowed because the MMR path is not
// used in every build; confirm before removing.
#[allow(dead_code)]
fn jaccard_similarity(a: &HashSet<String>, b: &HashSet<String>) -> f64 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    let intersection = a.intersection(b).count();
    // Inclusion–exclusion gives the union size without walking both sets.
    // Both sets are non-empty here, so `union >= 1` and the division is safe.
    let union = a.len() + b.len() - intersection;
    intersection as f64 / union as f64
}
231
232/// Apply MMR re-ranking to search results.
233///
234/// This is a convenience function that creates a reranker with default lambda (0.7).
235#[allow(dead_code)]
236pub fn apply_mmr(chunks: &mut [MemoryChunk]) {
237    MmrReranker::default().rerank(chunks);
238}
239
240/// Apply MMR re-ranking with custom lambda.
241#[allow(dead_code)]
242pub fn apply_mmr_with_lambda(chunks: &mut [MemoryChunk], lambda: f64) {
243    MmrReranker::new(lambda).rerank(chunks);
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_chunk_preview() {
        let chunk = MemoryChunk::new(
            "test.md".to_string(),
            1,
            5,
            "This is a long content string that should be truncated".to_string(),
            0.9,
        );

        assert_eq!(chunk.preview(20), "This is a long conte...");
        assert_eq!(chunk.location(), "test.md:1-5");
    }

    #[test]
    fn test_memory_chunk_single_line_location() {
        let chunk = MemoryChunk::new(
            "test.md".to_string(),
            10,
            10,
            "Single line".to_string(),
            0.5,
        );

        assert_eq!(chunk.location(), "test.md:10");
    }

    #[test]
    fn test_memory_chunk_preview_multibyte() {
        // Emoji are 4 bytes each in UTF-8
        let chunk = MemoryChunk::new(
            "test.md".to_string(),
            1,
            1,
            "Hello 🌍🌎🌏 world".to_string(),
            1.0,
        );

        // max_len=8 lands inside the first emoji (bytes 6..=9); must not panic.
        let preview = chunk.preview(8);
        assert!(preview.ends_with("..."));
        // Truncates back to "Hello " (6 bytes), the last char boundary before byte 8.
        assert_eq!(preview, "Hello ...");
    }

    #[test]
    fn test_memory_chunk_preview_emdash() {
        // Em-dash (—) is 3 bytes in UTF-8
        let chunk = MemoryChunk::new(
            "test.md".to_string(),
            1,
            1,
            "one—two—three—four—five".to_string(),
            1.0,
        );

        // "one—" is 3 + 3 = 6 bytes; max_len=5 lands mid-emdash
        let preview = chunk.preview(5);
        assert!(preview.ends_with("..."));
        assert_eq!(preview, "one...");
    }

    #[test]
    fn test_temporal_decay_no_decay() {
        // Lambda = 0 means no decay
        let mut chunk = MemoryChunk::new("test.md".to_string(), 1, 1, "content".to_string(), 1.0);
        chunk.updated_at = 1_700_000_000; // Some old timestamp

        let decayed = chunk.apply_temporal_decay(0.0, 1_710_000_000);
        assert!((decayed - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_temporal_decay_seven_days() {
        // Lambda = 0.1: a 7-day-old memory should get a ~50% penalty
        let mut chunk = MemoryChunk::new("test.md".to_string(), 1, 1, "content".to_string(), 1.0);
        let now = 1_710_000_000i64;
        chunk.updated_at = now - (7 * 24 * 60 * 60); // 7 days ago

        let decayed = chunk.apply_temporal_decay(0.1, now);
        // exp(-0.1 * 7) ≈ 0.496
        assert!((decayed - 0.496).abs() < 0.01);
    }

    #[test]
    fn test_temporal_decay_fresh() {
        // Fresh memory (just updated) should have no penalty
        let mut chunk = MemoryChunk::new("test.md".to_string(), 1, 1, "content".to_string(), 1.0);
        let now = 1_710_000_000i64;
        chunk.updated_at = now;

        let decayed = chunk.apply_temporal_decay(0.1, now);
        assert!((decayed - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_jaccard_similarity() {
        let a: HashSet<String> = ["apple", "banana", "cherry"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let b: HashSet<String> = ["banana", "cherry", "date"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        // Intersection: banana, cherry (2)
        // Union: apple, banana, cherry, date (4)
        let sim = jaccard_similarity(&a, &b);
        assert!((sim - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_jaccard_similarity_empty() {
        let a: HashSet<String> = ["apple"].iter().map(|s| s.to_string()).collect();
        let b: HashSet<String> = HashSet::new();

        assert_eq!(jaccard_similarity(&a, &b), 0.0);
        assert_eq!(jaccard_similarity(&b, &a), 0.0);
    }

    #[test]
    fn test_jaccard_similarity_identical() {
        let a: HashSet<String> = ["apple", "banana"].iter().map(|s| s.to_string()).collect();
        let b: HashSet<String> = ["apple", "banana"].iter().map(|s| s.to_string()).collect();

        assert!((jaccard_similarity(&a, &b) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_mmr_single_item() {
        let mut chunks = vec![MemoryChunk::new(
            "test.md".to_string(),
            1,
            1,
            "content".to_string(),
            0.9,
        )];

        apply_mmr(&mut chunks);
        assert_eq!(chunks.len(), 1);
    }

    #[test]
    fn test_mmr_diverse_results() {
        // Two very different chunks with same relevance
        let mut chunks = vec![
            MemoryChunk::new(
                "a.md".to_string(),
                1,
                1,
                "apple banana cherry".to_string(),
                0.9,
            ),
            MemoryChunk::new(
                "b.md".to_string(),
                1,
                1,
                "xray yacht zebra".to_string(),
                0.9,
            ),
        ];

        apply_mmr(&mut chunks);

        // Both should be kept (they're diverse); order is decided by MMR
        assert_eq!(chunks.len(), 2);
        // Files should still be present
        let files: Vec<_> = chunks.iter().map(|c| c.file.clone()).collect();
        assert!(files.contains(&"a.md".to_string()));
        assert!(files.contains(&"b.md".to_string()));
    }

    #[test]
    fn test_mmr_similar_penalized() {
        // High-relevance near-duplicate vs lower-relevance diverse chunk
        let mut chunks = vec![
            MemoryChunk::new(
                "similar1.md".to_string(),
                1,
                1,
                "apple banana".to_string(),
                1.0,
            ),
            MemoryChunk::new(
                "similar2.md".to_string(),
                1,
                1,
                "apple banana cherry".to_string(),
                0.95,
            ),
            MemoryChunk::new(
                "diverse.md".to_string(),
                1,
                1,
                "xray yacht zebra".to_string(),
                0.8,
            ),
        ];

        apply_mmr(&mut chunks);

        // First should be similar1 (highest relevance)
        assert_eq!(chunks[0].file, "similar1.md");

        let diverse_pos = chunks.iter().position(|c| c.file == "diverse.md").unwrap();
        let similar2_pos = chunks.iter().position(|c| c.file == "similar2.md").unwrap();

        // The diverse chunk should come before the near-duplicate
        assert!(
            diverse_pos < similar2_pos,
            "Diverse result should rank higher than similar duplicate"
        );
    }

    #[test]
    fn test_mmr_lambda_extremes() {
        // Identical content, different relevance
        let mut chunks = vec![
            MemoryChunk::new("high.md".to_string(), 1, 1, "unique alpha".to_string(), 1.0),
            MemoryChunk::new("low.md".to_string(), 1, 1, "unique alpha".to_string(), 0.5),
        ];

        // Lambda = 1.0: pure relevance; similarity carries no penalty, even for
        // identical content, so the second chunk's MMR score equals its relevance.
        apply_mmr_with_lambda(&mut chunks, 1.0);
        assert_eq!(chunks[0].file, "high.md");
        assert_eq!(chunks[1].file, "low.md");
        assert!((chunks[1].score - 0.5).abs() < 1e-12);

        // Lambda = 0.0: pure diversity, with non-overlapping content this time
        let mut chunks2 = vec![
            MemoryChunk::new("high.md".to_string(), 1, 1, "unique alpha".to_string(), 1.0),
            MemoryChunk::new(
                "low.md".to_string(),
                1,
                1,
                "different beta".to_string(),
                0.5,
            ),
        ];

        // The first pick is always the most relevant chunk. With lambda=0 the
        // second chunk's MMR score is -max_similarity, which is 0 here because
        // the two chunks share no tokens.
        apply_mmr_with_lambda(&mut chunks2, 0.0);
        assert_eq!(chunks2[0].file, "high.md");
        assert_eq!(chunks2[1].file, "low.md");
        assert!(chunks2[1].score.abs() < 1e-12);
    }

    #[test]
    fn test_tokenize() {
        let tokens = tokenize("Hello World! This is a test.");
        assert!(tokens.contains("hello"));
        assert!(tokens.contains("world"));
        assert!(tokens.contains("this"));
        assert!(tokens.contains("test"));
        // Single char 'a' should be filtered
        assert!(!tokens.contains("a"));
    }
}