use super::{Lens, LensContext, LensId, LensOutput};
use forge::budget::estimator::TokenEstimator;
use prism::extract;
use prism::parse::{Language, ParsedUnit};
/// Lens that replaces source text with a compact outline of its imports and
/// symbols (see its [`Lens`] impl).
///
/// The original text is kept when no language is detected, when no symbols are
/// extracted, or when the outline would not use fewer tokens.
pub struct RefractLens;
impl Lens for RefractLens {
    fn id(&self) -> LensId {
        LensId::Refract
    }

    /// Replaces `input` with a compact symbol outline when that saves tokens.
    ///
    /// The language is guessed with [`detect_language`]. The source is then run
    /// through `prism`'s extractor under a synthetic file name whose extension
    /// matches that language. The outline has two parts:
    /// - one `// import: <path>` line per import, followed by a blank line if
    ///   there were any imports;
    /// - one `[pub ]<kind> <name>; // line <n>` line per symbol, preceded by the
    ///   symbol's docstring with every line prefixed by `// `.
    ///
    /// The input is returned unchanged (`tokens_after == tokens_before`) in these
    /// cases, and `applied` records why:
    /// - `refract:passthrough`: no language was detected;
    /// - `refract:no-symbols`: extraction found no symbols;
    /// - `refract:no-gain`: the outline is not strictly smaller in tokens.
    ///
    /// On success, `applied` is `["refract"]`.
    fn apply(&self, input: &str, _ctx: &LensContext) -> LensOutput {
        use prism::extract::SymbolKind;
        use std::fmt::Write as _;

        let tokens_before = TokenEstimator::count_nonblocking(input);

        // Every fallback returns the input verbatim with no token change; only
        // the `applied` tag differs.
        let passthrough = |tag: &'static str| LensOutput {
            content: input.to_string(),
            tokens_before,
            tokens_after: tokens_before,
            applied: vec![tag.into()],
        };

        let lang = detect_language(input);
        if lang == Language::Unknown {
            return passthrough("refract:passthrough");
        }

        // Synthetic path whose extension matches the detected language.
        // NOTE(review): presumably `prism` picks its parser from this extension — confirm.
        let fake_path = match lang {
            Language::Rust => "__refract__.rs",
            Language::TypeScript => "__refract__.ts",
            Language::JavaScript => "__refract__.js",
            Language::Python => "__refract__.py",
            Language::Go => "__refract__.go",
            _ => "__refract__.txt",
        };
        let unit = ParsedUnit::new(fake_path.to_string(), input.to_string());
        let map = extract::extract(&unit);
        if map.symbols.is_empty() {
            return passthrough("refract:no-symbols");
        }

        // `fmt::Write` for `String` never returns `Err`, so write results are discarded.
        let mut out = String::new();
        for imp in &map.imports {
            let _ = writeln!(out, "// import: {}", imp.import_path);
        }
        if !map.imports.is_empty() {
            out.push('\n');
        }

        for sym in &map.symbols {
            let kind_tag = match sym.kind {
                SymbolKind::Function | SymbolKind::Method => "fn",
                SymbolKind::Struct => "struct",
                SymbolKind::Class => "class",
                SymbolKind::Trait => "trait",
                SymbolKind::Interface => "interface",
                SymbolKind::Enum => "enum",
                SymbolKind::Type => "type",
                SymbolKind::Constant => "const",
                SymbolKind::Module => "mod",
            };
            let pub_prefix = if sym.is_exported { "pub " } else { "" };
            if let Some(doc) = &sym.docstring {
                // Comment out every line. A single `// {doc}` would leave the
                // second and later lines of a multi-line docstring as raw text.
                for line in doc.lines() {
                    let _ = writeln!(out, "// {line}");
                }
            }
            let _ = writeln!(
                out,
                "{pub_prefix}{kind_tag} {}; // line {}",
                sym.name, sym.line
            );
        }

        let tokens_after = TokenEstimator::count_nonblocking(&out);
        if tokens_after < tokens_before {
            LensOutput {
                content: out,
                tokens_before,
                tokens_after,
                applied: vec!["refract".into()],
            }
        } else {
            passthrough("refract:no-gain")
        }
    }
}
/// Guesses the source language of `src` from keyword heuristics.
///
/// Only the first 2048 bytes are inspected. The cut is moved down to a UTF-8
/// character boundary so the slice never splits a multi-byte character.
///
/// Each candidate scores one point per distinct marker substring found in
/// that sample, and the highest score wins. On a tie, the candidate listed
/// later in the table wins, because `max_by_key` keeps the last maximum. A
/// best score below 2 is treated as too weak and yields [`Language::Unknown`].
fn detect_language(src: &str) -> Language {
    const SAMPLE_BYTES: usize = 2048;
    const MIN_SCORE: usize = 2;

    // Offset 0 is always a char boundary, so the fallback is never taken.
    let cap = src.len().min(SAMPLE_BYTES);
    let cut = (0..=cap)
        .rev()
        .find(|&i| src.is_char_boundary(i))
        .unwrap_or(0);
    let sample = &src[..cut];

    // Table order is significant: it decides ties between equal scores.
    let candidates: [(Language, &[&str]); 5] = [
        (
            Language::Rust,
            &[
                "fn ",
                "struct ",
                "impl ",
                "pub ",
                "use crate",
                "mod ",
                "-> Result",
                "anyhow",
            ],
        ),
        (
            Language::TypeScript,
            &[
                "interface ",
                "const ",
                ": string",
                ": number",
                "export ",
                "import {",
                "tsx",
                ": void",
            ],
        ),
        (
            Language::Python,
            &[
                "def ",
                "class ",
                "import ",
                " pass",
                "self.",
                "elif ",
                "__init__",
                "async def",
            ],
        ),
        (
            Language::Go,
            &[
                "func ",
                "package ",
                "import (\n",
                "var ",
                "type ",
                ":=",
                "chan ",
                "go func",
            ],
        ),
        (
            Language::JavaScript,
            &[
                "function ",
                "const ",
                "var ",
                "let ",
                "require(",
                "module.exports",
                "=>",
            ],
        ),
    ];

    candidates
        .into_iter()
        .map(|(lang, markers)| (count_hits(sample, markers), lang))
        .max_by_key(|&(score, _)| score)
        .filter(|&(score, _)| score >= MIN_SCORE)
        .map(|(_, lang)| lang)
        .unwrap_or(Language::Unknown)
}
/// Counts how many entries of `patterns` occur at least once in `src`.
///
/// A pattern adds at most one, however often it appears in `src`. Duplicate
/// entries in `patterns` are counted separately.
fn count_hits(src: &str, patterns: &[&str]) -> usize {
    patterns
        .iter()
        .fold(0, |hits, pattern| hits + usize::from(src.contains(pattern)))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn handles_multibyte_char_straddling_sample_boundary() {
        // 2047 ASCII bytes followed by the 3-byte '─' puts byte offset 2048
        // inside a character, so the 2048-byte sample must back off to 2047.
        let mut input = "a".repeat(2047);
        input.push('─');
        input.push_str("\nfn main() {}\n");
        let ctx = LensContext::new(2000);
        let out = RefractLens.apply(&input, &ctx);
        // The truncated sample is all 'a's, so no language is detected and the
        // input must come back untouched.
        assert_eq!(out.content, input);
        assert_eq!(out.applied.len(), 1);
        assert!(out.applied[0] == "refract:passthrough");
    }

    #[test]
    fn detects_rust_from_strong_signals() {
        let src = "use crate::x;\npub fn main() -> Result<(), anyhow::Error> {}\n";
        assert!(detect_language(src) == Language::Rust);
    }

    #[test]
    fn detects_python_at_threshold() {
        // Exactly two Python markers ("def " and " pass"); no other language scores.
        let src = "def foo(self):\n    pass\n";
        assert!(detect_language(src) == Language::Python);
    }

    #[test]
    fn weak_signals_fall_back_to_unknown() {
        assert!(detect_language("") == Language::Unknown);
        assert!(detect_language("hello world") == Language::Unknown);
    }

    #[test]
    fn count_hits_counts_distinct_pattern_matches() {
        assert_eq!(count_hits("fn main() {}", &["fn ", "struct ", "main"]), 2);
        assert_eq!(count_hits("aaaa", &["a"]), 1);
        assert_eq!(count_hits("", &[]), 0);
    }
}