use crate::context::token_counter::TokenCounter;
use anyhow::Result;
/// Selects how over-budget content is shortened by `CodeTruncator::truncate`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TruncationStrategy {
    /// Keep the leading signature/comment lines, then as many following body
    /// lines as fit, and end with a truncation marker.
    PreserveSignature,
    /// Keep only leading lines and end with a truncation marker.
    Head,
    /// Keep leading and trailing lines and replace the middle with a
    /// truncation marker.
    HeadAndTail,
}
/// Outcome of a truncation request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TruncationResult {
    /// The resulting text, including the truncation marker when one was added.
    pub content: String,
    /// Token count reported for `content`.
    pub tokens: usize,
    /// `true` when at least one line of the input was left out.
    pub truncated: bool,
    /// Number of input lines that do not appear in `content`.
    pub lines_removed: usize,
}
/// Shortens source text so that it fits within a token budget.
pub struct CodeTruncator {
// Token counter used to measure the original text and every candidate output.
counter: TokenCounter,
}
impl CodeTruncator {
pub fn new() -> Self {
Self {
counter: TokenCounter::new(),
}
}
pub fn truncate(
&self,
content: &str,
budget: usize,
strategy: TruncationStrategy,
) -> Result<TruncationResult> {
let current_tokens = self.counter.count(content)?;
if current_tokens <= budget {
return Ok(TruncationResult {
content: content.to_string(),
tokens: current_tokens,
truncated: false,
lines_removed: 0,
});
}
match strategy {
TruncationStrategy::PreserveSignature => {
self.truncate_preserve_signature(content, budget)
}
TruncationStrategy::Head => self.truncate_head(content, budget),
TruncationStrategy::HeadAndTail => self.truncate_head_and_tail(content, budget),
}
}
fn truncate_preserve_signature(
&self,
content: &str,
budget: usize,
) -> Result<TruncationResult> {
let lines: Vec<&str> = content.lines().collect();
if lines.is_empty() {
return Ok(TruncationResult {
content: String::new(),
tokens: 0,
truncated: false,
lines_removed: 0,
});
}
let signature_end = self.find_signature_end(&lines);
let docstring_end = self.find_docstring_end(&lines, signature_end);
let mut kept_lines = Vec::new();
for line in lines
.iter()
.take(docstring_end.min(lines.len().saturating_sub(1)) + 1)
{
kept_lines.push(*line);
}
let marker = "\n // ... [truncated] ...\n";
let marker_tokens = self.counter.count(marker)?;
let mut current_content = kept_lines.join("\n");
let mut current_tokens = self.counter.count(¤t_content)?;
let available = budget.saturating_sub(marker_tokens);
let mut body_start = docstring_end + 1;
let mut lines_added = 0;
while body_start < lines.len() && current_tokens < available {
let next_line = lines[body_start];
let test_content = format!("{}\n{}", current_content, next_line);
let test_tokens = self.counter.count(&test_content)?;
if test_tokens <= available {
current_content = test_content;
current_tokens = test_tokens;
body_start += 1;
lines_added += 1;
} else {
break;
}
}
let lines_removed = lines.len().saturating_sub(docstring_end + 1 + lines_added);
if lines_removed > 0 {
current_content.push_str(marker);
current_tokens = self.counter.count(¤t_content)?;
}
Ok(TruncationResult {
content: current_content,
tokens: current_tokens,
truncated: lines_removed > 0,
lines_removed,
})
}
fn truncate_head(&self, content: &str, budget: usize) -> Result<TruncationResult> {
let lines: Vec<&str> = content.lines().collect();
if lines.is_empty() {
return Ok(TruncationResult {
content: String::new(),
tokens: 0,
truncated: false,
lines_removed: 0,
});
}
let marker = "\n// ... [truncated] ...";
let marker_tokens = self.counter.count(marker)?;
let available = budget.saturating_sub(marker_tokens);
let mut kept_lines = Vec::new();
let mut current_content = String::new();
for (i, line) in lines.iter().enumerate() {
let test_content = if i == 0 {
line.to_string()
} else {
format!("{}\n{}", current_content, line)
};
let test_tokens = self.counter.count(&test_content)?;
if test_tokens <= available {
current_content = test_content;
kept_lines.push(*line);
} else {
break;
}
}
let lines_removed = lines.len().saturating_sub(kept_lines.len());
if lines_removed > 0 {
current_content.push_str(marker);
}
let tokens = self.counter.count(¤t_content)?;
Ok(TruncationResult {
content: current_content,
tokens,
truncated: lines_removed > 0,
lines_removed,
})
}
fn truncate_head_and_tail(&self, content: &str, budget: usize) -> Result<TruncationResult> {
let lines: Vec<&str> = content.lines().collect();
if lines.is_empty() {
return Ok(TruncationResult {
content: String::new(),
tokens: 0,
truncated: false,
lines_removed: 0,
});
}
let marker = "\n// ... [truncated] ...\n";
let marker_tokens = self.counter.count(marker)?;
let available = budget.saturating_sub(marker_tokens);
let head_budget = (available * 60) / 100;
let tail_budget = available - head_budget;
let mut head_lines = Vec::new();
let mut head_content = String::new();
for (i, line) in lines.iter().enumerate() {
let test_content = if i == 0 {
line.to_string()
} else {
format!("{}\n{}", head_content, line)
};
let test_tokens = self.counter.count(&test_content)?;
if test_tokens <= head_budget {
head_content = test_content;
head_lines.push(*line);
} else {
break;
}
}
let mut tail_lines = Vec::new();
let mut tail_content = String::new();
for line in lines.iter().rev() {
let test_content = if tail_content.is_empty() {
line.to_string()
} else {
format!("{}\n{}", line, tail_content)
};
let test_tokens = self.counter.count(&test_content)?;
if test_tokens <= tail_budget {
tail_content = test_content;
tail_lines.insert(0, *line);
} else {
break;
}
}
if head_lines.len() + tail_lines.len() >= lines.len() {
return Ok(TruncationResult {
content: content.to_string(),
tokens: self.counter.count(content)?,
truncated: false,
lines_removed: 0,
});
}
let final_content = format!("{}{}{}", head_content, marker, tail_content);
let tokens = self.counter.count(&final_content)?;
let lines_removed = lines.len() - head_lines.len() - tail_lines.len();
Ok(TruncationResult {
content: final_content,
tokens,
truncated: lines_removed > 0,
lines_removed,
})
}
fn find_signature_end(&self, lines: &[&str]) -> usize {
for (i, line) in lines.iter().enumerate() {
let trimmed = line.trim();
if trimmed.contains('{') || trimmed.contains("=>") || trimmed.ends_with(':') {
return i;
}
if !trimmed.is_empty()
&& !trimmed.starts_with("//")
&& !trimmed.starts_with("/*")
&& !trimmed.starts_with('*')
&& !trimmed.starts_with('#')
{
return i;
}
}
0
}
fn find_docstring_end(&self, lines: &[&str], start: usize) -> usize {
let mut end = start;
let mut i = start + 1;
while i < lines.len() {
let trimmed = lines[i].trim();
if trimmed.starts_with("//")
|| trimmed.starts_with("/*")
|| trimmed.starts_with('*')
|| trimmed.starts_with('#')
|| trimmed.starts_with("///")
{
end = i;
i += 1;
} else if trimmed.is_empty() {
if end > start {
end = i;
}
i += 1;
} else {
break;
}
}
end
}
}
impl Default for CodeTruncator {
fn default() -> Self {
Self::new()
}
}
// Unit tests for `CodeTruncator`: the public `truncate` entry point under each
// strategy, plus the private signature/docstring line-detection helpers.
#[cfg(test)]
mod tests {
use super::*;
// Content already within budget must be returned unchanged.
#[test]
fn test_no_truncation_needed() {
let truncator = CodeTruncator::new();
let code = "fn test() { println!(\"hello\"); }";
let result = truncator
.truncate(code, 1000, TruncationStrategy::PreserveSignature)
.unwrap();
assert!(!result.truncated);
assert_eq!(result.content, code);
assert_eq!(result.lines_removed, 0);
}
// A 100-statement body must be cut down, keeping the signature and a marker.
#[test]
fn test_truncate_long_function() {
let truncator = CodeTruncator::new();
let mut code = String::from("fn long_function() {\n");
for i in 0..100 {
code.push_str(&format!(" let x{} = {};\n", i, i));
}
code.push_str("}");
let result = truncator
.truncate(&code, 100, TruncationStrategy::PreserveSignature)
.unwrap();
assert!(result.truncated);
assert!(result.tokens <= 100);
assert!(result.content.contains("fn long_function()"));
assert!(result.content.contains("[truncated]"));
assert!(result.lines_removed > 0);
}
// Comment lines directly after the signature are kept along with it.
#[test]
fn test_preserve_signature_with_docstring() {
let truncator = CodeTruncator::new();
let code = r#"fn documented() {
// This is a docstring
// explaining the function
let x = 1;
let y = 2;
let z = 3;
// ... many more lines
}"#;
let result = truncator
.truncate(code, 50, TruncationStrategy::PreserveSignature)
.unwrap();
assert!(result.tokens <= 50);
assert!(result.content.contains("fn documented()"));
assert!(result.content.contains("docstring"));
}
// Head strategy stays within budget; the marker is only required when
// truncation actually happened.
#[test]
fn test_truncate_head_strategy() {
let truncator = CodeTruncator::new();
let code = "line 1\nline 2\nline 3\nline 4\nline 5";
let result = truncator
.truncate(code, 20, TruncationStrategy::Head)
.unwrap();
assert!(result.tokens <= 20);
if result.truncated {
assert!(result.content.contains("[truncated]"));
}
}
// Head-and-tail strategy stays within budget and, when truncating, keeps the
// first line and inserts the marker.
#[test]
fn test_truncate_head_and_tail_strategy() {
let truncator = CodeTruncator::new();
let mut code = String::new();
for i in 0..50 {
code.push_str(&format!("line {}\n", i));
}
let result = truncator
.truncate(&code, 50, TruncationStrategy::HeadAndTail)
.unwrap();
assert!(result.tokens <= 50);
if result.truncated {
assert!(result.content.contains("[truncated]"));
assert!(result.content.contains("line 0"));
}
}
// Empty input yields an empty, untruncated result with zero tokens.
#[test]
fn test_empty_content() {
let truncator = CodeTruncator::new();
let result = truncator
.truncate("", 100, TruncationStrategy::PreserveSignature)
.unwrap();
assert!(!result.truncated);
assert_eq!(result.content, "");
assert_eq!(result.tokens, 0);
}
// A single short line that fits is returned verbatim.
#[test]
fn test_single_line() {
let truncator = CodeTruncator::new();
let code = "fn test();";
let result = truncator
.truncate(code, 100, TruncationStrategy::PreserveSignature)
.unwrap();
assert!(!result.truncated);
assert_eq!(result.content, code);
}
// With a very small budget the output must still fit and still contain the
// start of the signature.
#[test]
fn test_very_small_budget() {
let truncator = CodeTruncator::new();
let code = "fn test() {\n println!(\"hello\");\n}";
let result = truncator
.truncate(code, 10, TruncationStrategy::PreserveSignature)
.unwrap();
assert!(result.tokens <= 10);
assert!(result.content.contains("fn"));
}
// The signature is reported on line 0 for brace-on-same-line, brace-on-next-line,
// arrow-function and colon-terminated (Python-style) headers.
#[test]
fn test_find_signature_end() {
let truncator = CodeTruncator::new();
let lines1 = vec!["fn test() {", " body", "}"];
assert_eq!(truncator.find_signature_end(&lines1), 0);
let lines2 = vec!["fn test()", "{", " body"];
assert_eq!(truncator.find_signature_end(&lines2), 0);
let lines3 = vec!["const func = () => {", " body"];
assert_eq!(truncator.find_signature_end(&lines3), 0);
let lines4 = vec!["def func():", " body"];
assert_eq!(truncator.find_signature_end(&lines4), 0);
}
// Comment lines after the signature extend the docstring range; the assertion
// only requires that at least the first two comment lines are covered.
#[test]
fn test_find_docstring_end() {
let truncator = CodeTruncator::new();
let lines = vec![
"fn test() {",
" // Comment 1",
" // Comment 2",
" ",
" // Comment 3",
" let x = 1;",
];
let end = truncator.find_docstring_end(&lines, 0);
assert!(end >= 2); }
// Rust-style function: the signature survives and the output fits the budget.
#[test]
fn test_preserve_signature_rust_function() {
let truncator = CodeTruncator::new();
let code = r#"pub fn calculate(x: i32, y: i32) -> i32 {
// Calculate the sum
let sum = x + y;
let doubled = sum * 2;
let tripled = doubled * 3;
let result = tripled * 4;
result
}"#;
let result = truncator
.truncate(code, 50, TruncationStrategy::PreserveSignature)
.unwrap();
assert!(result.tokens <= 50);
assert!(result.content.contains("pub fn calculate"));
}
// TypeScript-style function: the signature survives and the output fits the budget.
#[test]
fn test_preserve_signature_typescript_function() {
let truncator = CodeTruncator::new();
let code = r#"function processData(data: string[]): Result {
// Process each item
const results = data.map(item => transform(item));
const filtered = results.filter(r => r.valid);
const sorted = filtered.sort((a, b) => a.score - b.score);
return { data: sorted };
}"#;
let result = truncator
.truncate(code, 60, TruncationStrategy::PreserveSignature)
.unwrap();
assert!(result.tokens <= 60);
assert!(result.content.contains("function processData"));
}
// All three strategies must respect the same budget on the same input; when both
// truncate, head-and-tail output is at least half as long as head output.
#[test]
fn test_truncation_strategies_comparison() {
let truncator = CodeTruncator::new();
let mut code = String::new();
for i in 0..30 {
code.push_str(&format!("line {}\n", i));
}
let budget = 40;
let preserve = truncator
.truncate(&code, budget, TruncationStrategy::PreserveSignature)
.unwrap();
let head = truncator
.truncate(&code, budget, TruncationStrategy::Head)
.unwrap();
let head_tail = truncator
.truncate(&code, budget, TruncationStrategy::HeadAndTail)
.unwrap();
assert!(preserve.tokens <= budget);
assert!(head.tokens <= budget);
assert!(head_tail.tokens <= budget);
if head_tail.truncated && head.truncated {
assert!(head_tail.content.len() >= head.content.len() / 2);
}
}
}