use super::{Lens, LensContext, LensId, LensOutput};
use forge::budget::estimator::TokenEstimator;
use std::collections::HashMap;
/// Relevance-based chunk-pruning lens.
///
/// Splits the input into chunks with `split_chunks` and scores them:
/// - with BM25 against `ctx.task_hint` when a hint is present and matches at least one chunk;
/// - otherwise with the mean token self-information (`entropy_score`).
///
/// It keeps roughly the top 60 % of chunks (rounded up, at least two), always including
/// the first and last chunk. Survivors stay in their original order.
pub struct HarmonicLens;

impl Lens for HarmonicLens {
    fn id(&self) -> LensId {
        LensId::Harmonic
    }

    /// Prunes low-scoring chunks from `input`.
    ///
    /// Outcomes:
    /// - `harmonic:passthrough`: the input splits into at most one chunk and is returned unchanged.
    /// - `harmonic:no-gain`: pruning does not lower the estimated token count, so the input is
    ///   returned unchanged.
    /// - `harmonic`: the retained chunks are joined with a blank line (`\n\n`) and returned.
    ///
    /// NOTE(review): `ctx.budget` is not consulted; the keep ratio is fixed.
    /// NOTE(review): when `split_chunks` fell back to fixed line groups, joining with `\n\n`
    /// inserts blank lines that were not in the input.
    fn apply(&self, input: &str, ctx: &LensContext) -> LensOutput {
        /// Fraction of chunks retained (rounded up).
        const KEEP_RATIO: f64 = 0.6;
        /// Lower bound on the number of retained chunks.
        const MIN_KEEP: usize = 2;

        let tokens_before = TokenEstimator::count_nonblocking(input);
        let unchanged = |tag: &'static str| LensOutput {
            content: input.to_string(),
            tokens_before,
            tokens_after: tokens_before,
            applied: vec![tag.into()],
        };

        let chunks: Vec<String> = split_chunks(input);
        if chunks.len() <= 1 {
            return unchanged("harmonic:passthrough");
        }

        let mut ranked = match &ctx.task_hint {
            Some(hint) => {
                let scores = bm25_score(&chunks, hint);
                // All-zero BM25 means the hint shares no terms with any chunk.
                // Ranking by it would just keep chunks by position, so use
                // entropy instead.
                if scores.iter().all(|&(_, s)| s == 0.0) {
                    entropy_score(&chunks)
                } else {
                    scores
                }
            }
            None => entropy_score(&chunks),
        };
        // Descending by score; the sort is stable, so ties keep document order.
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let keep_count = ((chunks.len() as f64 * KEEP_RATIO).ceil() as usize).max(MIN_KEEP);
        let mut keep = vec![false; chunks.len()];
        for &(i, _) in ranked.iter().take(keep_count) {
            keep[i] = true;
        }
        // The opening and closing chunks usually carry framing; always retain them.
        keep[0] = true;
        keep[chunks.len() - 1] = true;

        let content = chunks
            .iter()
            .zip(&keep)
            .filter_map(|(chunk, &kept)| kept.then_some(chunk.as_str()))
            .collect::<Vec<&str>>()
            .join("\n\n");

        let tokens_after = TokenEstimator::count_nonblocking(&content);
        if tokens_after < tokens_before {
            LensOutput {
                content,
                tokens_before,
                tokens_after,
                applied: vec!["harmonic".into()],
            }
        } else {
            unchanged("harmonic:no-gain")
        }
    }
}
/// Breaks `text` into scoring units.
///
/// Paragraphs are preferred: maximal runs of lines that are not whitespace-only, with
/// their lines re-joined by `\n`. If there are at least four paragraphs they are returned.
/// Otherwise the non-blank lines are regrouped into consecutive blocks of eight (the last
/// block may be shorter). With fewer than eight non-blank lines, the untouched `text` is
/// returned as the single chunk.
fn split_chunks(text: &str) -> Vec<String> {
    const MIN_PARAGRAPHS: usize = 4;
    const LINES_PER_GROUP: usize = 8;

    let mut paragraphs: Vec<String> = Vec::new();
    let mut run: Vec<&str> = Vec::new();
    for line in text.lines() {
        if !line.trim().is_empty() {
            run.push(line);
        } else if !run.is_empty() {
            paragraphs.push(run.join("\n"));
            run.clear();
        }
    }
    if !run.is_empty() {
        paragraphs.push(run.join("\n"));
    }

    if paragraphs.len() >= MIN_PARAGRAPHS {
        return paragraphs;
    }

    let content_lines: Vec<&str> = text.lines().filter(|l| !l.trim().is_empty()).collect();
    if content_lines.len() >= LINES_PER_GROUP {
        content_lines
            .chunks(LINES_PER_GROUP)
            .map(|group| group.join("\n"))
            .collect()
    } else {
        vec![text.to_string()]
    }
}
/// Scores each chunk against `query` with Okapi BM25 (`k1 = 1.5`, `b = 0.75`).
///
/// Uses the non-negative IDF variant `ln((N - df + 0.5) / (df + 0.5) + 1)`. A chunk that
/// shares no terms with the query therefore scores exactly `0.0`. Repeated query terms
/// contribute once per occurrence.
///
/// Returns `(chunk_index, score)` pairs in chunk order; empty `chunks` gives an empty vector.
fn bm25_score(chunks: &[String], query: &str) -> Vec<(usize, f64)> {
    const K1: f64 = 1.5;
    const B: f64 = 0.75;

    let q_terms = tokenise(query);
    let n = chunks.len() as f64;

    // Tokenise every chunk exactly once. Previously each chunk was re-tokenised
    // for the length average, once per query term for df, and again for scoring.
    let tokenised: Vec<Vec<String>> = chunks.iter().map(|c| tokenise(c)).collect();
    let tf_maps: Vec<HashMap<&str, usize>> = tokenised
        .iter()
        .map(|terms| {
            let mut tf = HashMap::with_capacity(terms.len());
            for t in terms {
                *tf.entry(t.as_str()).or_insert(0usize) += 1;
            }
            tf
        })
        .collect();

    let avg_len = if chunks.is_empty() {
        1.0
    } else {
        tokenised.iter().map(Vec::len).sum::<usize>() as f64 / n
    };

    // Document frequency: how many chunks contain each query term.
    let df: HashMap<&str, usize> = q_terms
        .iter()
        .map(|t| {
            let t = t.as_str();
            (t, tf_maps.iter().filter(|tf| tf.contains_key(t)).count())
        })
        .collect();

    tokenised
        .iter()
        .zip(&tf_maps)
        .enumerate()
        .map(|(i, (terms, tf_map))| {
            let dl = terms.len() as f64;
            // Only read when some query term occurs somewhere, which implies avg_len > 0.
            let length_norm = K1 * (1.0 - B + B * dl / avg_len);
            let score: f64 = q_terms
                .iter()
                .map(|t| {
                    let d = df.get(t.as_str()).copied().unwrap_or(0) as f64;
                    if d == 0.0 {
                        return 0.0;
                    }
                    let tf = tf_map.get(t.as_str()).copied().unwrap_or(0) as f64;
                    let idf = ((n - d + 0.5) / (d + 0.5) + 1.0).ln();
                    idf * (tf * (K1 + 1.0) / (tf + length_norm))
                })
                .sum();
            (i, score)
        })
        .collect()
}
fn entropy_score(chunks: &[String]) -> Vec<(usize, f64)> {
let mut global_freq: HashMap<String, usize> = HashMap::new();
for chunk in chunks {
for t in tokenise(chunk) {
*global_freq.entry(t).or_insert(0) += 1;
}
}
let total_tokens: usize = global_freq.values().sum();
if total_tokens == 0 {
return chunks.iter().enumerate().map(|(i, _)| (i, 0.0)).collect();
}
chunks
.iter()
.enumerate()
.map(|(i, chunk)| {
let terms = tokenise(chunk);
if terms.is_empty() {
return (i, 0.0);
}
let score: f64 = terms
.iter()
.map(|t| {
let freq = *global_freq.get(t).unwrap_or(&1);
let p = freq as f64 / total_tokens as f64;
-p.log2()
})
.sum::<f64>()
/ terms.len() as f64;
(i, score)
})
.collect()
}
/// Splits `text` into lowercase word tokens for scoring.
///
/// A token is a maximal run of alphanumeric characters or `_`; every other character
/// separates tokens. Tokens shorter than two characters are discarded.
fn tokenise(text: &str) -> Vec<String> {
    text.split(|c: char| !(c.is_alphanumeric() || c == '_'))
        // Count characters, not bytes. `t.len() > 1` kept a lone multi-byte
        // character such as "é" or "中" while dropping "e".
        .filter(|t| t.chars().nth(1).is_some())
        .map(str::to_lowercase)
        .collect()
}