Skip to main content

lean_ctx/core/
codebook.rs

1use std::collections::HashMap;
2
/// Cross-file semantic deduplication via TF-IDF codebook.
///
/// Identifies patterns that appear frequently across files (high TF, low IDF)
/// and creates short references for them. This avoids sending the same
/// boilerplate to the LLM multiple times across different file reads.
///
/// A `CodebookEntry` is one such pattern together with the statistics that
/// were computed for it when the codebook was built.

#[derive(Debug, Clone)]
pub struct CodebookEntry {
    /// Short reference token for the pattern. `Codebook::build_from_files`
    /// generates these as `§` followed by a running number.
    pub id: String,
    /// Whitespace-normalized text of the line this entry stands for.
    pub pattern: String,
    /// Total number of occurrences of the pattern across all files seen during
    /// the build, counting repeats inside a single file.
    pub frequency: usize,
    /// Inverse document frequency, `ln(total_files / files_containing_pattern)`.
    /// Values close to zero mean the pattern shows up in almost every file.
    pub idf: f64,
}
16
/// Collection of boilerplate patterns and the short ids that replace them.
///
/// Starts empty (via `Default`); entries are added by `build_from_files`.
#[derive(Debug, Default)]
pub struct Codebook {
    /// Entries in the order they were assigned ids.
    entries: Vec<CodebookEntry>,
    /// Lookup from normalized line text to the id that replaces it.
    pattern_to_id: HashMap<String, String>,
    /// Numeric suffix that the next generated id will use.
    next_id: usize,
}
23
24impl Codebook {
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    /// Build codebook from multiple file contents (borrows, no cloning).
30    /// Identifies lines that appear in 3+ files and creates short references.
31    /// Skips codebook phase entirely if total line count exceeds 50,000
32    /// to prevent memory spikes on large projects.
33    pub fn build_from_files(&mut self, files: &[(&str, &str)]) {
34        let total_docs = files.len() as f64;
35        if total_docs < 2.0 {
36            return;
37        }
38
39        let total_lines: usize = files.iter().map(|(_, c)| c.lines().count()).sum();
40        if total_lines > 50_000 {
41            return;
42        }
43
44        let mut doc_freq: HashMap<String, usize> = HashMap::new();
45        let mut term_freq: HashMap<String, usize> = HashMap::new();
46
47        for (_, content) in files {
48            let mut seen_in_doc: std::collections::HashSet<String> =
49                std::collections::HashSet::new();
50            for line in content.lines() {
51                let normalized = normalize_line(line);
52                if normalized.len() < 10 {
53                    continue;
54                }
55
56                *term_freq.entry(normalized.clone()).or_insert(0) += 1;
57
58                if seen_in_doc.insert(normalized.clone()) {
59                    *doc_freq.entry(normalized).or_insert(0) += 1;
60                }
61            }
62        }
63
64        // Select patterns with high DF (appear in many files) — these are boilerplate
65        let mut candidates: Vec<(String, usize, f64)> = doc_freq
66            .into_iter()
67            .filter(|(_, df)| *df >= 3) // appears in 3+ files
68            .map(|(pattern, df)| {
69                let idf = (total_docs / df as f64).ln();
70                let tf = *term_freq.get(&pattern).unwrap_or(&0);
71                (pattern, tf, idf)
72            })
73            .collect();
74
75        // Sort by frequency descending (most common boilerplate first)
76        candidates.sort_by_key(|x| std::cmp::Reverse(x.1));
77
78        // Take top 50 patterns to keep codebook compact
79        for (pattern, freq, idf) in candidates.into_iter().take(50) {
80            let id = format!("§{}", self.next_id);
81            self.next_id += 1;
82            self.pattern_to_id.insert(pattern.clone(), id.clone());
83            self.entries.push(CodebookEntry {
84                id,
85                pattern,
86                frequency: freq,
87                idf,
88            });
89        }
90    }
91
92    /// Apply codebook to content: replace known patterns with short references.
93    /// Returns (compressed content, references used).
94    pub fn compress(&self, content: &str) -> (String, Vec<String>) {
95        if self.entries.is_empty() {
96            return (content.to_string(), vec![]);
97        }
98
99        let mut result = Vec::new();
100        let mut refs_used = Vec::new();
101
102        for line in content.lines() {
103            let normalized = normalize_line(line);
104            if let Some(id) = self.pattern_to_id.get(&normalized) {
105                if !refs_used.contains(id) {
106                    refs_used.push(id.clone());
107                }
108                result.push(format!("[{id}]"));
109            } else {
110                result.push(line.to_string());
111            }
112        }
113
114        (result.join("\n"), refs_used)
115    }
116
117    /// Format the codebook legend for lines that were referenced.
118    pub fn format_legend(&self, refs_used: &[String]) -> String {
119        if refs_used.is_empty() {
120            return String::new();
121        }
122
123        let mut lines = vec!["§CODEBOOK:".to_string()];
124        for entry in &self.entries {
125            if refs_used.contains(&entry.id) {
126                let short = if entry.pattern.len() > 60 {
127                    format!(
128                        "{}...",
129                        &entry.pattern[..entry.pattern.floor_char_boundary(57)]
130                    )
131                } else {
132                    entry.pattern.clone()
133                };
134                lines.push(format!("  {}={}", entry.id, short));
135            }
136        }
137        lines.join("\n")
138    }
139
140    pub fn len(&self) -> usize {
141        self.entries.len()
142    }
143
144    pub fn is_empty(&self) -> bool {
145        self.entries.is_empty()
146    }
147}
148
149/// Cosine similarity between two documents using TF-IDF vectors.
150/// IDF is computed over the two-document corpus to down-weight common terms
151/// like `fn`, `let`, `return` and up-weight domain-specific identifiers.
152pub fn tfidf_cosine_similarity(doc_a: &str, doc_b: &str) -> f64 {
153    tfidf_cosine_similarity_with_corpus(&[doc_a, doc_b], doc_a, doc_b)
154}
155
156/// TF-IDF cosine similarity with IDF computed over a larger corpus.
157pub fn tfidf_cosine_similarity_with_corpus(corpus: &[&str], doc_a: &str, doc_b: &str) -> f64 {
158    let idf = compute_idf(corpus);
159    let tfidf_a = tfidf_vector(doc_a, &idf);
160    let tfidf_b = tfidf_vector(doc_b, &idf);
161
162    let all_terms: std::collections::HashSet<&str> =
163        tfidf_a.keys().chain(tfidf_b.keys()).copied().collect();
164    if all_terms.is_empty() {
165        return 0.0;
166    }
167
168    let mut dot = 0.0;
169    let mut mag_a = 0.0;
170    let mut mag_b = 0.0;
171
172    for term in &all_terms {
173        let a = *tfidf_a.get(term).unwrap_or(&0.0);
174        let b = *tfidf_b.get(term).unwrap_or(&0.0);
175        dot += a * b;
176        mag_a += a * a;
177        mag_b += b * b;
178    }
179
180    let magnitude = (mag_a * mag_b).sqrt();
181    if magnitude < f64::EPSILON {
182        return 0.0;
183    }
184
185    dot / magnitude
186}
187
188/// Identify semantically duplicate blocks across files.
189/// IDF is computed over the full file corpus for accurate weighting.
190pub fn find_semantic_duplicates(
191    files: &[(String, String)],
192    threshold: f64,
193) -> Vec<(String, String, f64)> {
194    let corpus: Vec<&str> = files.iter().map(|(_, c)| c.as_str()).collect();
195    let idf = compute_idf(&corpus);
196    let vectors: Vec<HashMap<&str, f64>> =
197        files.iter().map(|(_, c)| tfidf_vector(c, &idf)).collect();
198
199    let mut duplicates = Vec::new();
200
201    for i in 0..files.len() {
202        for j in (i + 1)..files.len() {
203            let sim = cosine_from_vectors(&vectors[i], &vectors[j]);
204            if sim >= threshold {
205                duplicates.push((files[i].0.clone(), files[j].0.clone(), sim));
206            }
207        }
208    }
209
210    duplicates.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
211    duplicates
212}
213
/// Compute an IDF weight for every whitespace-separated token in `corpus`.
///
/// The weight is `ln(n / (1 + df)) + 1`, where `n` is the number of documents
/// and `df` the number of documents containing the token. An empty corpus
/// yields an empty map.
fn compute_idf<'a>(corpus: &[&'a str]) -> HashMap<&'a str, f64> {
    if corpus.is_empty() {
        return HashMap::new();
    }
    let doc_count = corpus.len() as f64;

    let mut containing_docs: HashMap<&'a str, usize> = HashMap::new();
    for text in corpus {
        // Collapse repeats first so each document contributes at most one
        // count per token.
        let distinct: std::collections::HashSet<&'a str> = text.split_whitespace().collect();
        for token in distinct {
            *containing_docs.entry(token).or_default() += 1;
        }
    }

    let mut weights: HashMap<&'a str, f64> = HashMap::with_capacity(containing_docs.len());
    for (token, df) in containing_docs {
        weights.insert(token, (doc_count / (1.0 + df as f64)).ln() + 1.0);
    }
    weights
}
235
/// Build the TF-IDF vector of `doc`.
///
/// Term frequency is the token's count divided by the total number of
/// whitespace-separated tokens; it is multiplied by the token's weight in
/// `idf`, or by `1.0` when the token has no weight there. A document without
/// tokens yields an empty map.
fn tfidf_vector<'a>(doc: &'a str, idf: &HashMap<&str, f64>) -> HashMap<&'a str, f64> {
    let mut counts: HashMap<&'a str, usize> = HashMap::new();
    let mut token_total: usize = 0;
    for token in doc.split_whitespace() {
        *counts.entry(token).or_default() += 1;
        token_total += 1;
    }

    if token_total == 0 {
        return HashMap::new();
    }
    let total = token_total as f64;

    let mut weighted: HashMap<&'a str, f64> = HashMap::with_capacity(counts.len());
    for (token, count) in counts {
        let term_freq = count as f64 / total;
        let idf_weight = idf.get(token).copied().unwrap_or(1.0);
        weighted.insert(token, term_freq * idf_weight);
    }
    weighted
}
258
/// Cosine similarity of two sparse vectors keyed by term.
///
/// Terms missing from a vector count as `0.0`. Returns `0.0` when either
/// vector is empty or when `sqrt(|a|² · |b|²)` is below `f64::EPSILON`.
fn cosine_from_vectors(a: &HashMap<&str, f64>, b: &HashMap<&str, f64>) -> f64 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    // Only terms present in both vectors contribute to the dot product, so
    // walk the smaller map and probe the larger one instead of materializing
    // the union of all keys.
    let (small, large) = if a.len() <= b.len() { (a, b) } else { (b, a) };
    let dot: f64 = small
        .iter()
        .filter_map(|(term, weight)| large.get(*term).map(|other| weight * other))
        .sum();

    // Squared norms depend on one vector each.
    let mag_a: f64 = a.values().map(|w| w * w).sum();
    let mag_b: f64 = b.values().map(|w| w * w).sum();

    let magnitude = (mag_a * mag_b).sqrt();
    if magnitude < f64::EPSILON {
        return 0.0;
    }

    dot / magnitude
}
284
/// Trim `line` and collapse every run of whitespace into a single space.
fn normalize_line(line: &str) -> String {
    let mut collapsed = String::with_capacity(line.len());
    for token in line.split_whitespace() {
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(token);
    }
    collapsed
}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// A line shared by at least three files must end up in the codebook.
    #[test]
    fn codebook_identifies_common_patterns() {
        let files = [
            (
                "a.rs",
                "use std::io;\nuse std::collections::HashMap;\nfn main() {}\n",
            ),
            (
                "b.rs",
                "use std::io;\nuse std::collections::HashMap;\nfn helper() {}\n",
            ),
            (
                "c.rs",
                "use std::io;\nuse std::collections::HashMap;\nfn other() {}\n",
            ),
            ("d.rs", "use std::io;\nfn unique() {}\n"),
        ];

        let mut codebook = Codebook::new();
        codebook.build_from_files(&files);
        assert!(!codebook.is_empty(), "should find common patterns");
    }

    /// Identical documents score (approximately) one.
    #[test]
    fn cosine_identical_is_one() {
        let text = "hello world foo";
        let score = tfidf_cosine_similarity(text, text);
        assert!((score - 1.0).abs() < 0.01);
    }

    /// Documents without any shared token score (approximately) zero.
    #[test]
    fn cosine_disjoint_is_zero() {
        let score = tfidf_cosine_similarity("alpha beta gamma", "delta epsilon zeta");
        assert!(score < 0.01);
    }

    /// Partially overlapping documents land strictly between zero and one.
    #[test]
    fn cosine_partial_overlap() {
        let score = tfidf_cosine_similarity("hello world foo bar", "hello world baz qux");
        assert!(score > 0.0 && score < 1.0);
    }

    /// Two identical files are reported as the only duplicate pair.
    #[test]
    fn find_duplicates_detects_similar_files() {
        let shared_body = "fn main() { let x = 1; let y = 2; println!(x + y); }";
        let unrelated_body = "completely different content here with no overlap at all";
        let files: Vec<(String, String)> = [
            ("a.rs", shared_body),
            ("b.rs", shared_body),
            ("c.rs", unrelated_body),
        ]
        .iter()
        .map(|(name, body)| (name.to_string(), body.to_string()))
        .collect();

        let found = find_semantic_duplicates(&files, 0.8);
        assert_eq!(found.len(), 1);
        assert!(found[0].2 > 0.99);
    }
}