// lean_ctx/core/gamma_cover.rs
use std::collections::{HashMap, HashSet};

/// Shannon entropy, in bits, of the empirical token distribution.
///
/// Each distinct token is weighted by its relative frequency `p`, and the
/// result is `-Σ p·log2(p)`. An empty slice has zero entropy.
fn entropy_bits(tokens: &[String]) -> f64 {
    let total = tokens.len();
    if total == 0 {
        return 0.0;
    }

    let mut counts: HashMap<&str, usize> = HashMap::new();
    tokens
        .iter()
        .for_each(|tok| *counts.entry(tok.as_str()).or_default() += 1);

    let total = total as f64;
    let mut bits = 0.0_f64;
    for &count in counts.values() {
        let p = count as f64 / total;
        bits -= p * p.log2();
    }
    bits
}

/// Builds one sparse TF-IDF vector per chunk, keyed by a shared term id.
///
/// Term ids are assigned in order of first appearance across all chunks.
/// Term frequency is normalised by the chunk's most frequent term
/// (`tf / max_tf`), and IDF uses the smoothed form
/// `ln((N + 1) / (df + 1)) + 1`. Here `N` is the number of chunks and `df` is
/// the number of chunks containing the term. A chunk without tokens maps to an
/// empty vector.
fn tf_idf_vectors(chunks: &[(String, Vec<String>)]) -> Vec<HashMap<usize, f64>> {
    let doc_count = chunks.len() as f64;

    // Document frequency: each chunk contributes at most once per term.
    let mut doc_freq: HashMap<&str, usize> = HashMap::new();
    for (_, tokens) in chunks {
        let distinct: HashSet<&str> = tokens.iter().map(String::as_str).collect();
        for term in distinct {
            *doc_freq.entry(term).or_default() += 1;
        }
    }

    // Term ids follow first occurrence; the current map size is the next free id.
    let mut term_ids: HashMap<&str, usize> = HashMap::new();
    for tok in chunks.iter().flat_map(|(_, tokens)| tokens) {
        let fresh = term_ids.len();
        term_ids.entry(tok.as_str()).or_insert(fresh);
    }

    let mut out = Vec::with_capacity(chunks.len());
    for (_, tokens) in chunks {
        let mut counts: HashMap<&str, usize> = HashMap::new();
        for tok in tokens {
            *counts.entry(tok.as_str()).or_default() += 1;
        }
        let peak = counts.values().copied().max().unwrap_or(1);

        let weights: HashMap<usize, f64> = counts
            .iter()
            .filter_map(|(term, &count)| {
                let id = *term_ids.get(term)?;
                let tf = count as f64 / peak as f64;
                let df = doc_freq.get(term).copied().unwrap_or(1).max(1);
                let idf = ((doc_count + 1.0) / (df as f64 + 1.0)).ln() + 1.0;
                Some((id, tf * idf))
            })
            .collect();
        out.push(weights);
    }
    out
}

/// Cosine similarity of two sparse vectors, clamped to `[0, 1]`.
///
/// Returns `0.0` when either vector is empty or has a (near-)zero norm.
fn cosine_sparse(a: &HashMap<usize, f64>, b: &HashMap<usize, f64>) -> f64 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    // Iterate the smaller map and probe the larger one.
    let (probe, table) = if a.len() > b.len() { (b, a) } else { (a, b) };
    let dot = probe
        .iter()
        .filter_map(|(k, x)| table.get(k).map(|y| x * y))
        .fold(0.0_f64, |acc, term| acc + term);

    let norm = |m: &HashMap<usize, f64>| m.values().map(|x| x * x).sum::<f64>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a <= f64::EPSILON || norm_b <= f64::EPSILON {
        return 0.0;
    }
    (dot / (norm_a * norm_b)).clamp(0.0, 1.0)
}

93fn covers(vecs: &[HashMap<usize, f64>], entropy: &[f64], i: usize, j: usize, gamma: f64) -> bool {
94 if i == j {
95 return true;
96 }
97 let h = entropy[j];
98 if h <= f64::EPSILON {
99 return true;
100 }
101 let sim = cosine_sparse(&vecs[i], &vecs[j]);
102 let residual = (1.0 - sim) * h;
103 residual <= gamma + 1e-12
104}
105
106pub fn compute_cover(chunks: &[(String, Vec<String>)], gamma: f64) -> Vec<usize> {
109 let n = chunks.len();
110 if n == 0 {
111 return Vec::new();
112 }
113
114 let vecs = tf_idf_vectors(chunks);
115 let entropy: Vec<f64> = chunks.iter().map(|(_, t)| entropy_bits(t)).collect();
116
117 let mut picked = Vec::new();
118 let mut covered = vec![false; n];
119
120 loop {
121 if covered.iter().all(|&c| c) {
122 break;
123 }
124
125 let mut best_i = usize::MAX;
126 let mut best_gain = 0usize;
127
128 for i in 0..n {
129 let gain = (0..n)
130 .filter(|&j| !covered[j] && covers(&vecs, &entropy, i, j, gamma))
131 .count();
132 if gain > best_gain {
133 best_gain = gain;
134 best_i = i;
135 } else if gain == best_gain && gain > 0 && i < best_i {
136 best_i = i;
137 }
138 }
139
140 if best_gain == 0 {
141 if let Some(j) = (0..n).find(|&j| !covered[j]) {
143 picked.push(j);
144 covered[j] = true;
145 } else {
146 break;
147 }
148 continue;
149 }
150
151 picked.push(best_i);
152 for (j, cov) in covered.iter_mut().enumerate().take(n) {
153 if covers(&vecs, &entropy, best_i, j, gamma) {
154 *cov = true;
155 }
156 }
157 }
158
159 picked
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `(content, tokens)` chunk from borrowed string slices.
    fn chunk(content: &str, tokens: &[&str]) -> (String, Vec<String>) {
        (content.into(), tokens.iter().map(|s| (*s).into()).collect())
    }

    #[test]
    fn empty_input() {
        assert!(compute_cover(&[], 1.0).is_empty());
    }

    #[test]
    fn single_chunk_covers_itself() {
        // Even with a zero budget, a chunk always covers itself.
        let c = vec![chunk("solo", &["a", "b"])];
        assert_eq!(compute_cover(&c, 0.0), vec![0]);
    }

    #[test]
    fn duplicate_chunks_one_covers_other() {
        let c = vec![
            chunk(
                "alpha beta gamma delta",
                &["alpha", "beta", "gamma", "delta"],
            ),
            chunk(
                "alpha beta gamma delta",
                &["alpha", "beta", "gamma", "delta"],
            ),
        ];
        // Identical chunks have cosine 1, so the first one covers both.
        assert_eq!(compute_cover(&c, 0.01), vec![0]);
    }

    #[test]
    fn hub_chunk_covers_spokes() {
        let hub_toks = vec!["fn", "parse", "emit", "error", "ok", "ctx"];
        let mut chunks = vec![chunk("hub impl", &hub_toks)];
        for i in 0..4 {
            let mut toks = hub_toks.clone();
            toks.push(["extra_a", "extra_b", "noise_z"][i % 3]);
            chunks.push(chunk("spoke", &toks));
        }
        // Each spoke differs from the hub by one token, which leaves well under
        // 2.5 bits of residual. The hub covers everything and, being index 0,
        // wins any tie.
        let cov = compute_cover(&chunks, 2.5);
        assert_eq!(cov, vec![0]);
    }

    #[test]
    fn orthogonal_chunks_need_multiple_picks() {
        let c = vec![
            chunk("u", &["u1", "u2", "u3"]),
            chunk("v", &["v1", "v2", "v3"]),
            chunk("w", &["w1", "w2", "w3"]),
        ];
        // No shared terms: every chunk keeps log2(3) bits, so each needs its own pick.
        assert_eq!(compute_cover(&c, 0.01), vec![0, 1, 2]);
    }

    #[test]
    fn zero_entropy_chunk_is_covered_by_any_pick() {
        // "flat" has a single repeated token (0 bits), so picking "mixed" covers it.
        let c = vec![chunk("mixed", &["x", "y"]), chunk("flat", &["z", "z"])];
        assert_eq!(compute_cover(&c, 0.0), vec![0]);
    }
}