// lean_ctx/core/codebook.rs

1use std::collections::HashMap;
2
3/// Cross-file semantic deduplication via TF-IDF codebook.
4///
5/// Identifies patterns that appear frequently across files (high TF, low IDF)
6/// and creates short references for them. This avoids sending the same
7/// boilerplate to the LLM multiple times across different file reads.
8
9#[derive(Debug, Clone)]
10pub struct CodebookEntry {
11    pub id: String,
12    pub pattern: String,
13    pub frequency: usize,
14    pub idf: f64,
15}
16
17#[derive(Debug, Default)]
18pub struct Codebook {
19    entries: Vec<CodebookEntry>,
20    pattern_to_id: HashMap<String, String>,
21    next_id: usize,
22}
23
24impl Codebook {
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    /// Build codebook from multiple file contents.
30    /// Identifies lines that appear in 3+ files and creates short references.
31    pub fn build_from_files(&mut self, files: &[(String, String)]) {
32        let total_docs = files.len() as f64;
33        if total_docs < 2.0 {
34            return;
35        }
36
37        // Count document frequency for each normalized line
38        let mut doc_freq: HashMap<String, usize> = HashMap::new();
39        let mut term_freq: HashMap<String, usize> = HashMap::new();
40
41        for (_, content) in files {
42            let mut seen_in_doc: std::collections::HashSet<String> =
43                std::collections::HashSet::new();
44            for line in content.lines() {
45                let normalized = normalize_line(line);
46                if normalized.len() < 10 {
47                    continue;
48                }
49
50                *term_freq.entry(normalized.clone()).or_insert(0) += 1;
51
52                if seen_in_doc.insert(normalized.clone()) {
53                    *doc_freq.entry(normalized).or_insert(0) += 1;
54                }
55            }
56        }
57
58        // Select patterns with high DF (appear in many files) — these are boilerplate
59        let mut candidates: Vec<(String, usize, f64)> = doc_freq
60            .into_iter()
61            .filter(|(_, df)| *df >= 3) // appears in 3+ files
62            .map(|(pattern, df)| {
63                let idf = (total_docs / df as f64).ln();
64                let tf = *term_freq.get(&pattern).unwrap_or(&0);
65                (pattern, tf, idf)
66            })
67            .collect();
68
69        // Sort by frequency descending (most common boilerplate first)
70        candidates.sort_by_key(|x| std::cmp::Reverse(x.1));
71
72        // Take top 50 patterns to keep codebook compact
73        for (pattern, freq, idf) in candidates.into_iter().take(50) {
74            let id = format!("§{}", self.next_id);
75            self.next_id += 1;
76            self.pattern_to_id.insert(pattern.clone(), id.clone());
77            self.entries.push(CodebookEntry {
78                id,
79                pattern,
80                frequency: freq,
81                idf,
82            });
83        }
84    }
85
86    /// Apply codebook to content: replace known patterns with short references.
87    /// Returns (compressed content, references used).
88    pub fn compress(&self, content: &str) -> (String, Vec<String>) {
89        if self.entries.is_empty() {
90            return (content.to_string(), vec![]);
91        }
92
93        let mut result = Vec::new();
94        let mut refs_used = Vec::new();
95
96        for line in content.lines() {
97            let normalized = normalize_line(line);
98            if let Some(id) = self.pattern_to_id.get(&normalized) {
99                if !refs_used.contains(id) {
100                    refs_used.push(id.clone());
101                }
102                result.push(format!("[{id}]"));
103            } else {
104                result.push(line.to_string());
105            }
106        }
107
108        (result.join("\n"), refs_used)
109    }
110
111    /// Format the codebook legend for lines that were referenced.
112    pub fn format_legend(&self, refs_used: &[String]) -> String {
113        if refs_used.is_empty() {
114            return String::new();
115        }
116
117        let mut lines = vec!["§CODEBOOK:".to_string()];
118        for entry in &self.entries {
119            if refs_used.contains(&entry.id) {
120                let short = if entry.pattern.len() > 60 {
121                    format!("{}...", &entry.pattern[..57])
122                } else {
123                    entry.pattern.clone()
124                };
125                lines.push(format!("  {}={}", entry.id, short));
126            }
127        }
128        lines.join("\n")
129    }
130
131    pub fn len(&self) -> usize {
132        self.entries.len()
133    }
134
135    pub fn is_empty(&self) -> bool {
136        self.entries.is_empty()
137    }
138}
139
140/// Cosine similarity between two documents using TF-IDF vectors.
141/// IDF is computed over the two-document corpus to down-weight common terms
142/// like `fn`, `let`, `return` and up-weight domain-specific identifiers.
143pub fn tfidf_cosine_similarity(doc_a: &str, doc_b: &str) -> f64 {
144    tfidf_cosine_similarity_with_corpus(&[doc_a, doc_b], doc_a, doc_b)
145}
146
147/// TF-IDF cosine similarity with IDF computed over a larger corpus.
148pub fn tfidf_cosine_similarity_with_corpus(corpus: &[&str], doc_a: &str, doc_b: &str) -> f64 {
149    let idf = compute_idf(corpus);
150    let tfidf_a = tfidf_vector(doc_a, &idf);
151    let tfidf_b = tfidf_vector(doc_b, &idf);
152
153    let all_terms: std::collections::HashSet<&str> =
154        tfidf_a.keys().chain(tfidf_b.keys()).copied().collect();
155    if all_terms.is_empty() {
156        return 0.0;
157    }
158
159    let mut dot = 0.0;
160    let mut mag_a = 0.0;
161    let mut mag_b = 0.0;
162
163    for term in &all_terms {
164        let a = *tfidf_a.get(term).unwrap_or(&0.0);
165        let b = *tfidf_b.get(term).unwrap_or(&0.0);
166        dot += a * b;
167        mag_a += a * a;
168        mag_b += b * b;
169    }
170
171    let magnitude = (mag_a * mag_b).sqrt();
172    if magnitude < f64::EPSILON {
173        return 0.0;
174    }
175
176    dot / magnitude
177}
178
179/// Identify semantically duplicate blocks across files.
180/// IDF is computed over the full file corpus for accurate weighting.
181pub fn find_semantic_duplicates(
182    files: &[(String, String)],
183    threshold: f64,
184) -> Vec<(String, String, f64)> {
185    let corpus: Vec<&str> = files.iter().map(|(_, c)| c.as_str()).collect();
186    let idf = compute_idf(&corpus);
187    let vectors: Vec<HashMap<&str, f64>> =
188        files.iter().map(|(_, c)| tfidf_vector(c, &idf)).collect();
189
190    let mut duplicates = Vec::new();
191
192    for i in 0..files.len() {
193        for j in (i + 1)..files.len() {
194            let sim = cosine_from_vectors(&vectors[i], &vectors[j]);
195            if sim >= threshold {
196                duplicates.push((files[i].0.clone(), files[j].0.clone(), sim));
197            }
198        }
199    }
200
201    duplicates.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
202    duplicates
203}
204
/// Smoothed inverse document frequency for each whitespace-separated term.
///
/// Each term maps to `ln(n / (1 + df)) + 1`. Here `n` is the number of
/// documents in `corpus` and `df` is the number of documents containing the
/// term at least once. An empty corpus yields an empty map.
fn compute_idf<'a>(corpus: &[&'a str]) -> HashMap<&'a str, f64> {
    if corpus.is_empty() {
        return HashMap::new();
    }
    let n = corpus.len() as f64;

    let mut doc_freq: HashMap<&'a str, usize> = HashMap::new();
    for &doc in corpus {
        // Count each term at most once per document.
        let unique_terms: std::collections::HashSet<&'a str> = doc.split_whitespace().collect();
        for term in unique_terms {
            *doc_freq.entry(term).or_default() += 1;
        }
    }

    let mut idf = HashMap::with_capacity(doc_freq.len());
    for (term, df) in doc_freq {
        idf.insert(term, (n / (1.0 + df as f64)).ln() + 1.0);
    }
    idf
}
226
/// TF-IDF weights for the whitespace-separated terms of `doc`.
///
/// Term frequency is a term's count divided by the total number of terms in
/// `doc`. It is multiplied by the term's entry in `idf`, or by 1.0 when the
/// term is missing from `idf`. An empty or whitespace-only document yields an
/// empty map.
fn tfidf_vector<'a>(doc: &'a str, idf: &HashMap<&str, f64>) -> HashMap<&'a str, f64> {
    let mut counts: HashMap<&'a str, usize> = HashMap::new();
    let mut total = 0usize;
    for term in doc.split_whitespace() {
        *counts.entry(term).or_insert(0) += 1;
        total += 1;
    }
    if total == 0 {
        return HashMap::new();
    }

    let total = total as f64;
    counts
        .into_iter()
        .map(|(term, count)| {
            let weight = idf.get(term).copied().unwrap_or(1.0);
            (term, (count as f64 / total) * weight)
        })
        .collect()
}
249
/// Cosine similarity of two sparse term-weight vectors.
///
/// A term missing from one vector counts as 0.0 in that vector. Returns 0.0
/// when both vectors are empty, or when the product of their magnitudes is
/// below `f64::EPSILON`.
fn cosine_from_vectors(a: &HashMap<&str, f64>, b: &HashMap<&str, f64>) -> f64 {
    let union: std::collections::HashSet<&&str> = a.keys().chain(b.keys()).collect();
    if union.is_empty() {
        return 0.0;
    }

    let (dot, norm_a_sq, norm_b_sq) =
        union
            .iter()
            .fold((0.0_f64, 0.0_f64, 0.0_f64), |(dot, na, nb), &&term| {
                let x = a.get(term).copied().unwrap_or(0.0);
                let y = b.get(term).copied().unwrap_or(0.0);
                (dot + x * y, na + x * x, nb + y * y)
            });

    let magnitude = (norm_a_sq * norm_b_sq).sqrt();
    if magnitude < f64::EPSILON {
        0.0
    } else {
        dot / magnitude
    }
}
275
/// Collapse every run of whitespace in `line` to one space and trim both ends.
fn normalize_line(line: &str) -> String {
    let mut normalized = String::with_capacity(line.len());
    for word in line.split_whitespace() {
        // `split_whitespace` never yields empty words, so a non-empty buffer
        // means a previous word was written.
        if !normalized.is_empty() {
            normalized.push(' ');
        }
        normalized.push_str(word);
    }
    normalized
}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `(path, content)` pair in the shape the corpus APIs take.
    fn file(path: &str, content: &str) -> (String, String) {
        (path.to_string(), content.to_string())
    }

    #[test]
    fn codebook_identifies_common_patterns() {
        let files = [
            file("a.rs", "use std::io;\nuse std::collections::HashMap;\nfn main() {}\n"),
            file("b.rs", "use std::io;\nuse std::collections::HashMap;\nfn helper() {}\n"),
            file("c.rs", "use std::io;\nuse std::collections::HashMap;\nfn other() {}\n"),
            file("d.rs", "use std::io;\nfn unique() {}\n"),
        ];

        let mut codebook = Codebook::new();
        codebook.build_from_files(&files);
        assert!(!codebook.is_empty(), "should find common patterns");
    }

    #[test]
    fn cosine_identical_is_one() {
        let similarity = tfidf_cosine_similarity("hello world foo", "hello world foo");
        assert!((similarity - 1.0).abs() < 0.01);
    }

    #[test]
    fn cosine_disjoint_is_zero() {
        let similarity = tfidf_cosine_similarity("alpha beta gamma", "delta epsilon zeta");
        assert!(similarity < 0.01);
    }

    #[test]
    fn cosine_partial_overlap() {
        let similarity = tfidf_cosine_similarity("hello world foo bar", "hello world baz qux");
        assert!(similarity > 0.0 && similarity < 1.0);
    }

    #[test]
    fn find_duplicates_detects_similar_files() {
        let snippet = "fn main() { let x = 1; let y = 2; println!(x + y); }";
        let files = vec![
            file("a.rs", snippet),
            file("b.rs", snippet),
            file("c.rs", "completely different content here with no overlap at all"),
        ];

        let duplicates = find_semantic_duplicates(&files, 0.8);
        assert_eq!(duplicates.len(), 1);
        assert!(duplicates[0].2 > 0.99);
    }
}
349}