use crate::intelligence::context_grouper::MemoryGroup;
use serde::{Deserialize, Serialize};
/// Selects how [`TruncationEngine::truncate_to_budget`] shortens over-budget text.
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
pub enum TruncationStrategy {
    /// Keep the leading characters of the text and append `...`.
    Simple,
    /// Keep whole `.`-separated segments from the start of the text; fall back to a
    /// character cut ending in `...` when not even the first segment fits.
    Smart,
    /// Keep the trailing portion of the text, prefixed with `...`.
    PreserveRecent,
}
/// Settings consumed by [`TruncationEngine::with_config`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TruncationConfig {
    /// NOTE(review): not read by `TruncationEngine` in this file; presumably an overall
    /// token cap — confirm with callers.
    pub max_tokens: usize,
    /// NOTE(review): not read by `TruncationEngine` in this file; presumably the amount of
    /// budget reserved for recent content — confirm with callers.
    pub preserve_recent: usize,
    /// Strategy applied by `TruncationEngine::truncate_to_budget`.
    pub strategy: TruncationStrategy,
}
impl Default for TruncationConfig {
fn default() -> Self {
Self {
max_tokens: 2000,
preserve_recent: 1000,
strategy: TruncationStrategy::Smart,
}
}
}
/// Cuts text and [`MemoryGroup`] lists down to an approximate token budget.
///
/// Token budgets are converted to a length allowance at four units per token, and
/// lengths are compared using `str::len` (bytes). Cut points taken with
/// `chars().take(..)` are counted in characters. As a result, non-ASCII output from the
/// `Simple`/`Smart` strategies and from partially kept group summaries can be longer in
/// bytes than the allowance.
pub struct TruncationEngine {
    config: TruncationConfig,
}

impl TruncationEngine {
    /// Length units (bytes, as measured by `str::len`) assumed per token.
    const CHARS_PER_TOKEN: usize = 4;
    /// Fixed allowance added to each group's `topic` + `summary` length.
    const GROUP_OVERHEAD: usize = 10;
    /// A group that does not fully fit is only shortened when more room than this is left.
    const MIN_PARTIAL_ROOM: usize = 50;
    /// Length of the `"..."` marker appended to cut text.
    const ELLIPSIS_LEN: usize = 3;

    /// Creates an engine from `config`. Only `config.strategy` is consulted by the methods
    /// of this type.
    pub fn with_config(config: TruncationConfig) -> Self {
        Self { config }
    }

    /// Truncates `content` so that it roughly fits in `budget_tokens` tokens.
    ///
    /// The allowance is `budget_tokens * 4` (saturating). Content whose byte length is
    /// within the allowance is returned unchanged. Otherwise the configured strategy
    /// decides what is kept:
    ///
    /// * `Simple`: the first `allowance` characters followed by `"..."`. The result can
    ///   be up to three characters over the allowance.
    /// * `Smart`: as many whole `.`-terminated segments from the start as fit (measured in
    ///   bytes). If not even the first segment fits, the result is the first
    ///   `allowance - 3` characters followed by `"..."`. A zero allowance yields just
    ///   `"..."`.
    /// * `PreserveRecent`: `"..."` followed by the longest suffix of at most `allowance`
    ///   bytes that starts on a character boundary.
    pub fn truncate_to_budget(&self, content: &str, budget_tokens: usize) -> String {
        let budget_chars = Self::token_budget_to_chars(budget_tokens);
        if content.len() <= budget_chars {
            return content.to_string();
        }
        match self.config.strategy {
            TruncationStrategy::Simple => Self::prefix_with_ellipsis(content, budget_chars),
            TruncationStrategy::Smart => Self::truncate_smart(content, budget_chars),
            TruncationStrategy::PreserveRecent => Self::truncate_keep_tail(content, budget_chars),
        }
    }

    /// Keeps groups in order while their cumulative size fits in `budget_tokens` tokens.
    ///
    /// Each group costs `topic.len() + summary.len() + 10`. Processing stops at the first
    /// group that does not fit. If more than 50 units of the allowance are still free,
    /// that group is included with its summary cut to fit and ended with `"..."`. The
    /// shortened group is skipped if not even the topic plus `"..."` fits.
    pub fn truncate_groups(
        &self,
        groups: &[MemoryGroup],
        budget_tokens: usize,
    ) -> Vec<MemoryGroup> {
        let budget_chars = Self::token_budget_to_chars(budget_tokens);
        let mut result = Vec::new();
        let mut used_chars = 0;
        for group in groups {
            let group_chars = group.topic.len() + group.summary.len() + Self::GROUP_OVERHEAD;
            if used_chars + group_chars <= budget_chars {
                result.push(group.clone());
                used_chars += group_chars;
                continue;
            }
            let remaining = budget_chars.saturating_sub(used_chars);
            if remaining > Self::MIN_PARTIAL_ROOM {
                if let Some(partial) = Self::shrink_group(group, remaining) {
                    result.push(partial);
                }
            }
            break;
        }
        result
    }

    /// Estimates the token count of `text` as its byte length divided by four, rounded down.
    pub fn estimate_tokens(&self, text: &str) -> usize {
        text.len() / Self::CHARS_PER_TOKEN
    }

    /// Converts a token budget into a length allowance. It saturates instead of
    /// overflowing for huge budgets.
    fn token_budget_to_chars(budget_tokens: usize) -> usize {
        budget_tokens.saturating_mul(Self::CHARS_PER_TOKEN)
    }

    /// Returns the first `max_chars` characters of `text` followed by `"..."`.
    fn prefix_with_ellipsis(text: &str, max_chars: usize) -> String {
        let mut out: String = text.chars().take(max_chars).collect();
        out.push_str("...");
        out
    }

    /// `Smart` strategy: keep leading `.`-terminated segments that fit in `budget_chars`.
    /// The caller guarantees `content.len() > budget_chars`.
    fn truncate_smart(content: &str, budget_chars: usize) -> String {
        let mut result = String::new();
        for sentence in content.split('.') {
            // Keeping every piece would take content.len() + 1 bytes, which is over budget.
            // So any piece that fits was followed by '.' in `content`, and re-adding the
            // period reproduces the source text.
            if result.len() + sentence.len() + 1 > budget_chars {
                break;
            }
            result.push_str(sentence);
            result.push('.');
        }
        if result.is_empty() {
            // Saturate: a zero budget must not underflow here.
            Self::prefix_with_ellipsis(content, budget_chars.saturating_sub(Self::ELLIPSIS_LEN))
        } else {
            result
        }
    }

    /// `PreserveRecent` strategy: `"..."` plus the last `budget_chars` bytes, with the start
    /// advanced to the next character boundary.
    fn truncate_keep_tail(content: &str, budget_chars: usize) -> String {
        let mut start = content.len().saturating_sub(budget_chars);
        // `content.len()` is always a char boundary, so this loop terminates.
        while !content.is_char_boundary(start) {
            start += 1;
        }
        format!("...{}", &content[start..])
    }

    /// Returns a copy of `group` whose summary is cut (ending in `"..."`) so the group fits
    /// in `room`. Returns `None` when not even the topic, the overhead and `"..."` fit.
    fn shrink_group(group: &MemoryGroup, room: usize) -> Option<MemoryGroup> {
        let max_summary_chars = room.saturating_sub(group.topic.len() + Self::GROUP_OVERHEAD);
        if group.summary.len() <= max_summary_chars {
            return Some(group.clone());
        }
        // Fewer than 3 units left for the summary means the group cannot fit at all.
        let keep = max_summary_chars.checked_sub(Self::ELLIPSIS_LEN)?;
        let mut shrunk = group.clone();
        shrunk.summary = Self::prefix_with_ellipsis(&group.summary, keep);
        Some(shrunk)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructs a `TruncationEngine` from the three `TruncationConfig` fields.
    fn engine_with(
        max_tokens: usize,
        preserve_recent: usize,
        strategy: TruncationStrategy,
    ) -> TruncationEngine {
        TruncationEngine::with_config(TruncationConfig {
            max_tokens,
            preserve_recent,
            strategy,
        })
    }

    #[test]
    fn test_simple_truncation() {
        let engine = engine_with(10, 5, TruncationStrategy::Simple);
        let input = "This is a long piece of content that needs to be truncated";
        let output = engine.truncate_to_budget(input, 10);
        // 10 tokens allow 40 characters; the "..." marker may add 3 more.
        assert!(output.len() <= 43);
        assert!(output.ends_with("..."));
    }

    #[test]
    fn test_smart_truncation() {
        let engine = engine_with(20, 10, TruncationStrategy::Smart);
        let input = "First sentence. Second sentence. Third sentence. Fourth sentence.";
        let output = engine.truncate_to_budget(input, 20);
        assert!(output.contains("First sentence"));
    }

    #[test]
    fn preserve_recent_keeps_tail() {
        let engine = engine_with(5, 5, TruncationStrategy::PreserveRecent);
        let input = "ABCDEFGHIJ".repeat(5);
        let output = engine.truncate_to_budget(&input, 5);
        assert!(output.starts_with("..."));
        assert!(output.ends_with("ABCDEFGHIJ"));
        assert!(output.len() <= input.len() + 3);
    }

    #[test]
    fn preserve_recent_handles_multibyte_boundary() {
        let engine = engine_with(2, 2, TruncationStrategy::PreserveRecent);
        let input = "café résumé naïve crème brûlée";
        let output = engine.truncate_to_budget(input, 2);
        assert!(output.starts_with("..."));
        // Everything after the 3-byte "..." must be a genuine suffix of the input.
        let tail = &output[3..];
        assert!(input.ends_with(tail));
    }

    #[test]
    fn preserve_recent_returns_full_content_when_under_budget() {
        let engine = engine_with(100, 100, TruncationStrategy::PreserveRecent);
        assert_eq!(engine.truncate_to_budget("short", 100), "short");
    }

    #[test]
    fn test_estimate_tokens() {
        let engine = TruncationEngine::with_config(TruncationConfig::default());
        // "Hello world" is 11 bytes; 11 / 4 rounds down to 2.
        assert_eq!(engine.estimate_tokens("Hello world"), 2);
    }
}