use crate::types::api_types::{Message, MessageRole};
use serde::{Deserialize, Serialize};
/// Strategy used to decide which message groups survive when the history must shrink.
///
/// Serialized in lowercase (`"head"`, `"tail"`, `"smart"`). As implemented in this module:
/// `Head` keeps the oldest groups that fit the budget and `Tail` keeps the newest.
/// `Smart` (the default) is first passed through `get_recommended_direction`; if it is still
/// `Smart` afterwards, it is handled exactly like `Head`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum CompactDirection {
    /// Keep the oldest non-boundary groups that fit the token budget.
    Head,
    /// Keep the newest non-boundary groups that fit the token budget.
    Tail,
    /// Let `get_recommended_direction` pick; this is the default.
    #[default]
    Smart,
}
/// Outcome of a `compact_messages` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactResult {
    /// Set to `true` on every return path of `compact_messages` in this module.
    pub success: bool,
    /// Number of messages dropped (summed over all compacted groups).
    pub messages_removed: usize,
    /// Estimated token count of the input history (via `estimate_tokens_for_message`).
    pub tokens_before: u64,
    /// Estimated token count of `messages_to_keep`.
    pub tokens_after: u64,
    /// Direction actually applied. When compaction ran with `Smart` requested, this is the
    /// value returned by `get_recommended_direction`.
    pub direction: CompactDirection,
    /// Human-readable description of the removed messages; empty when nothing was removed.
    pub summary: String,
    /// Retained messages, in their original conversation order.
    pub messages_to_keep: Vec<Message>,
    /// Error description; `compact_messages` in this module always sets it to `None`.
    pub error: Option<String>,
}
/// Options controlling `compact_messages`.
#[derive(Debug, Clone, Default)]
pub struct CompactOptions {
    /// Token budget for the retained messages. `None` means "the current total", i.e. the
    /// history is returned unchanged.
    pub max_tokens: Option<u64>,
    /// Requested strategy; defaults to `CompactDirection::Smart`.
    pub direction: CompactDirection,
    /// NOTE(review): not read by `compact_messages` in this module; confirm whether any caller
    /// relies on it.
    pub create_boundary: bool,
    /// NOTE(review): not read by `compact_messages` in this module; confirm whether any caller
    /// relies on it.
    pub system_prompt: Option<String>,
}
/// A contiguous run of messages that is kept or compacted as a single unit.
///
/// `group_messages` starts a new group at each user message. Anything before the first user
/// message forms its own leading group.
#[derive(Debug, Clone)]
struct MessageGroup {
    /// Index of this group's first message in the original message slice.
    start_index: usize,
    /// The group's messages, in conversation order.
    messages: Vec<Message>,
    /// Sum of `estimate_tokens_for_message` over `messages`.
    token_count: u64,
    /// Set only on the final group by `group_messages`; boundary groups are always kept by
    /// `select_groups_to_compact`.
    is_boundary: bool,
}
/// Splits a conversation into turn-based groups.
///
/// Every `User` message opens a new group, unless the group being built is still empty.
/// The assistant, tool and system messages that follow it are appended to the same group.
/// Messages that appear before the first user message (for example a leading system message)
/// form their own group starting at index 0. The last group produced is flagged as the
/// boundary so it is never compacted.
fn group_messages(messages: &[Message]) -> Vec<MessageGroup> {
    // Creates an empty group whose first message will sit at `start_index`.
    let fresh = |start_index: usize| MessageGroup {
        start_index,
        messages: Vec::new(),
        token_count: 0,
        is_boundary: false,
    };

    let mut groups: Vec<MessageGroup> = Vec::new();
    let mut open = fresh(0);

    for (idx, message) in messages.iter().enumerate() {
        let opens_turn = matches!(message.role, MessageRole::User);
        if opens_turn && !open.messages.is_empty() {
            groups.push(std::mem::replace(&mut open, fresh(idx)));
        }
        open.token_count += estimate_tokens_for_message(message);
        open.messages.push(message.clone());
    }

    if !open.messages.is_empty() {
        groups.push(open);
    }
    if let Some(tail) = groups.last_mut() {
        tail.is_boundary = true;
    }
    groups
}
fn estimate_tokens_for_message(msg: &Message) -> u64 {
let content_tokens = (msg.content.len() as u64 + 3) / 4;
let tool_call_tokens = msg
.tool_calls
.as_ref()
.map(|calls| {
calls
.iter()
.map(|tc| {
let name_tokens = (tc.name.len() as u64 + 3) / 4;
let args_tokens = (tc.arguments.to_string().len() as u64 + 3) / 4;
name_tokens + args_tokens + 2 })
.sum::<u64>()
})
.unwrap_or(0);
let role_overhead: u64 = 4;
content_tokens + tool_call_tokens + role_overhead
}
/// Shrinks `messages` so that their estimated token count fits `options.max_tokens`.
///
/// The history is split into user turns (see `group_messages`). The final turn is always
/// kept; the other turns are kept or dropped according to the resolved `CompactDirection`.
/// A requested `Smart` direction is resolved with `get_recommended_direction`. If the history
/// is empty or already within budget, it is returned unchanged. Otherwise the result carries
/// the retained messages plus a textual summary of what was removed.
///
/// # Errors
///
/// The current implementation never returns `Err`.
pub async fn compact_messages(
    messages: &[Message],
    options: CompactOptions,
) -> Result<CompactResult, String> {
    let requested = options.direction;

    // Result used by the paths where nothing has to be removed.
    let unchanged = |token_total: u64, kept: Vec<Message>| CompactResult {
        success: true,
        messages_removed: 0,
        tokens_before: token_total,
        tokens_after: token_total,
        direction: requested,
        summary: String::new(),
        messages_to_keep: kept,
        error: None,
    };

    if messages.is_empty() {
        return Ok(unchanged(0, Vec::new()));
    }

    let tokens_before = messages
        .iter()
        .map(estimate_tokens_for_message)
        .sum::<u64>();
    // Without an explicit budget the current size is the budget, so nothing is removed.
    let target_tokens = match options.max_tokens {
        Some(limit) => limit,
        None => tokens_before,
    };
    if tokens_before <= target_tokens {
        return Ok(unchanged(tokens_before, messages.to_vec()));
    }

    let direction = match requested {
        CompactDirection::Smart => {
            get_recommended_direction(messages.len(), tokens_before, target_tokens)
        }
        chosen => chosen,
    };
    let groups = group_messages(messages);
    let (kept_groups, compacted_groups) =
        select_groups_to_compact(&groups, target_tokens, direction);

    let messages_to_keep: Vec<Message> = kept_groups
        .iter()
        .flat_map(|group| group.messages.iter().cloned())
        .collect();
    let messages_removed = compacted_groups
        .iter()
        .map(|group| group.messages.len())
        .sum::<usize>();
    let summary = create_compact_summary(&compacted_groups);
    let tokens_after = messages_to_keep
        .iter()
        .map(estimate_tokens_for_message)
        .sum::<u64>();

    log::info!(
        "[compact] Compacted {} messages: {} -> {} tokens (direction: {:?})",
        messages_removed,
        tokens_before,
        tokens_after,
        direction
    );

    Ok(CompactResult {
        success: true,
        messages_removed,
        tokens_before,
        tokens_after,
        direction,
        summary,
        messages_to_keep,
        error: None,
    })
}
/// Splits `groups` into the groups to keep and the groups to compact away.
///
/// Boundary groups (`is_boundary`) are always kept, and their tokens are charged against
/// `target_tokens` first. The remaining groups are then walked in an order set by `direction`:
/// oldest-first for `Head` and `Smart`, newest-first for `Tail`. Each group is kept while it
/// fits the leftover budget. The first group that does not fit is compacted, together with
/// every group after it in walk order, so the kept history stays contiguous at the chosen end.
///
/// Both returned vectors are ordered by `start_index`, i.e. original conversation order.
fn select_groups_to_compact(
    groups: &[MessageGroup],
    target_tokens: u64,
    direction: CompactDirection,
) -> (Vec<&MessageGroup>, Vec<&MessageGroup>) {
    let (mut kept, mut candidates): (Vec<&MessageGroup>, Vec<&MessageGroup>) =
        groups.iter().partition(|g| g.is_boundary);
    let boundary_tokens: u64 = kept.iter().map(|g| g.token_count).sum();
    let mut remaining_budget = target_tokens.saturating_sub(boundary_tokens);

    match direction {
        // Tail keeps the newest groups, so consider candidates newest-first.
        CompactDirection::Tail => candidates.reverse(),
        // Head keeps the oldest groups. NOTE(review): Smart shares the Head strategy; it can
        // still arrive here when `get_recommended_direction` returns Smart.
        CompactDirection::Head | CompactDirection::Smart => {}
    }

    let mut compacted = Vec::new();
    let mut walk = candidates.into_iter();
    for group in walk.by_ref() {
        if group.token_count <= remaining_budget {
            remaining_budget -= group.token_count;
            kept.push(group);
        } else {
            // Stop at the first group that does not fit so the kept groups stay contiguous.
            compacted.push(group);
            break;
        }
    }
    compacted.extend(walk);

    kept.sort_by_key(|g| g.start_index);
    compacted.sort_by_key(|g| g.start_index);
    (kept, compacted)
}
/// Builds a human-readable description of the groups that were compacted away.
///
/// Returns an empty string when `compacted_groups` is empty. Otherwise the first line gives
/// the total number of removed messages and their estimated tokens. A second line breaks the
/// removed messages down by role, but only when at least one user or assistant message was
/// removed. System messages count toward the total but not toward the breakdown.
fn create_compact_summary(compacted_groups: &[&MessageGroup]) -> String {
    if compacted_groups.is_empty() {
        return String::new();
    }

    let message_total: usize = compacted_groups.iter().map(|g| g.messages.len()).sum();
    let token_total: u64 = compacted_groups.iter().map(|g| g.token_count).sum();

    let (users, assistants, tools) = compacted_groups
        .iter()
        .flat_map(|group| group.messages.iter())
        .fold((0usize, 0usize, 0usize), |(u, a, t), msg| match &msg.role {
            MessageRole::User => (u + 1, a, t),
            MessageRole::Assistant => (u, a + 1, t),
            MessageRole::Tool => (u, a, t + 1),
            MessageRole::System => (u, a, t),
        });

    let mut summary = format!(
        "Compacted {} messages (~{} tokens) from the conversation history.\n\n",
        message_total, token_total
    );
    if users == 0 && assistants == 0 {
        return summary;
    }

    summary += &format!(
        "The compacted section contained {} user messages and {} assistant responses",
        users, assistants
    );
    if tools > 0 {
        summary += &format!(" with {} tool results", tools);
    }
    summary.push_str(".\n");
    summary
}
/// Suggests a compaction direction for a conversation.
///
/// Returns [`CompactDirection::Head`] only when the conversation is over budget
/// (`total_tokens > max_tokens`) and has more than 10 messages. In every other case it
/// returns [`CompactDirection::Smart`].
pub fn get_recommended_direction(
    message_count: usize,
    total_tokens: u64,
    max_tokens: u64,
) -> CompactDirection {
    // Conversations longer than this (and over budget) get the Head strategy.
    const LONG_CONVERSATION: usize = 10;

    let over_budget = total_tokens > max_tokens;
    if over_budget && message_count > LONG_CONVERSATION {
        CompactDirection::Head
    } else {
        CompactDirection::Smart
    }
}
/// Estimates how many average-sized messages must be removed to bring `current_tokens` down
/// to `target_tokens`.
///
/// The result is rounded up, so removing that many messages of `avg_tokens_per_message`
/// tokens each is enough to reach the target. Returns 0 when already within budget.
/// An `avg_tokens_per_message` of 0 is treated as 1 instead of dividing by zero. The result
/// saturates at `usize::MAX` on targets where `usize` is narrower than `u64`.
pub fn calculate_messages_to_remove(
    current_tokens: u64,
    target_tokens: u64,
    avg_tokens_per_message: u64,
) -> usize {
    if current_tokens <= target_tokens {
        return 0;
    }
    let excess = current_tokens - target_tokens;
    // Guard against a zero average, which would otherwise panic on division.
    let per_message = avg_tokens_per_message.max(1);

    // Ceiling division: a partial message's worth of excess still needs one more removal.
    // `whole + 1` cannot overflow because a remainder implies `per_message >= 2`.
    let whole = excess / per_message;
    let needed = if excess % per_message == 0 { whole } else { whole + 1 };

    usize::try_from(needed).unwrap_or(usize::MAX)
}
/// Approximates the token count of `text` as one token per 4 bytes, rounded up.
///
/// Length is measured in UTF-8 bytes, not characters, so multi-byte text is counted by its
/// encoded size.
pub fn rough_token_estimation(text: &str) -> u64 {
    let byte_len = text.len() as u64;
    // One extra token for any trailing partial group of fewer than 4 bytes.
    let partial = u64::from(byte_len % 4 != 0);
    byte_len / 4 + partial
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message with the given role and content, defaulting every other field.
    fn message(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            ..Default::default()
        }
    }

    /// Start indices of the given groups, in order.
    fn starts(groups: &[&MessageGroup]) -> Vec<usize> {
        groups.iter().map(|g| g.start_index).collect()
    }

    #[test]
    fn test_compact_direction_default() {
        let options = CompactOptions::default();
        assert_eq!(options.direction, CompactDirection::Smart);
    }

    #[test]
    fn test_get_recommended_direction_no_compact() {
        let dir = get_recommended_direction(5, 1000, 2000);
        assert_eq!(dir, CompactDirection::Smart);
    }

    #[test]
    fn test_get_recommended_direction_long_conversation() {
        assert_eq!(get_recommended_direction(11, 3000, 2000), CompactDirection::Head);
        assert_eq!(get_recommended_direction(10, 3000, 2000), CompactDirection::Smart);
    }

    #[test]
    fn test_calculate_messages_to_remove() {
        let count = calculate_messages_to_remove(5000, 2000, 500);
        assert_eq!(count, 6);
    }

    #[test]
    fn test_calculate_messages_to_remove_no_need() {
        let count = calculate_messages_to_remove(1000, 2000, 500);
        assert_eq!(count, 0);
    }

    #[test]
    fn test_rough_token_estimation() {
        let text = "Hello, this is a test message with some content.";
        let tokens = rough_token_estimation(text);
        assert!(tokens > 0);
        assert!(tokens <= (text.len() as u64 + 3) / 4 + 1);
        assert_eq!(rough_token_estimation(""), 0);
        assert_eq!(rough_token_estimation("abcd"), 1);
        assert_eq!(rough_token_estimation("abcde"), 2);
    }

    #[test]
    fn test_estimate_tokens_for_message() {
        let msg = message(MessageRole::User, "Hello, how are you?");
        // 19 bytes -> 5 content tokens, plus 4 tokens of per-message overhead.
        assert_eq!(estimate_tokens_for_message(&msg), 9);
    }

    #[test]
    fn test_group_messages_basic() {
        let messages = vec![
            message(MessageRole::User, "Question 1"),
            message(MessageRole::Assistant, "Answer 1"),
            message(MessageRole::User, "Question 2"),
            message(MessageRole::Assistant, "Answer 2"),
        ];
        let groups = group_messages(&messages);
        assert_eq!(groups.len(), 2);
        assert!(!groups[0].is_boundary);
        assert!(groups[1].is_boundary);
    }

    #[test]
    fn test_group_messages_leading_system_message() {
        let messages = vec![
            message(MessageRole::System, "You are helpful."),
            message(MessageRole::User, "Question"),
            message(MessageRole::Assistant, "Answer"),
        ];
        let groups = group_messages(&messages);
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0].start_index, 0);
        assert_eq!(groups[0].messages.len(), 1);
        assert_eq!(groups[1].start_index, 1);
        assert!(groups[1].is_boundary);
    }

    #[test]
    fn test_select_groups_to_compact_directions() {
        let make = |start_index: usize, is_boundary: bool| MessageGroup {
            start_index,
            messages: vec![message(MessageRole::User, "q")],
            token_count: 10,
            is_boundary,
        };
        let groups = vec![make(0, false), make(1, false), make(2, false), make(3, true)];

        let (kept, compacted) = select_groups_to_compact(&groups, 30, CompactDirection::Head);
        assert_eq!(starts(&kept), vec![0, 1, 3]);
        assert_eq!(compacted.len(), 1);

        let (kept, compacted) = select_groups_to_compact(&groups, 30, CompactDirection::Tail);
        assert_eq!(starts(&kept), vec![1, 2, 3]);
        assert_eq!(compacted.len(), 1);
    }

    #[tokio::test]
    async fn test_compact_messages_empty() {
        let result = compact_messages(&[], CompactOptions::default())
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.messages_removed, 0);
    }

    #[tokio::test]
    async fn test_compact_messages_within_budget() {
        let messages = vec![message(MessageRole::User, "Short message")];
        let options = CompactOptions {
            max_tokens: Some(1000000),
            ..Default::default()
        };
        let result = compact_messages(&messages, options).await.unwrap();
        assert!(result.success);
        assert_eq!(result.messages_removed, 0);
    }

    #[tokio::test]
    async fn test_compact_messages_over_budget_keeps_last_turn() {
        // Each message: 400 bytes -> 100 tokens + 4 overhead = 104; each turn = 208.
        let long = "x".repeat(400);
        let messages = vec![
            message(MessageRole::User, &long),
            message(MessageRole::Assistant, &long),
            message(MessageRole::User, &long),
            message(MessageRole::Assistant, &long),
            message(MessageRole::User, &long),
            message(MessageRole::Assistant, &long),
        ];
        let options = CompactOptions {
            max_tokens: Some(250),
            ..Default::default()
        };
        let result = compact_messages(&messages, options).await.unwrap();
        assert!(result.success);
        assert_eq!(result.tokens_before, 624);
        assert_eq!(result.messages_removed, 4);
        assert_eq!(result.messages_to_keep.len(), 2);
        assert_eq!(result.tokens_after, 208);
        assert!(result.summary.contains("4 messages"));
    }

    // Plain #[test]: `create_compact_summary` is synchronous, so no async runtime is needed.
    #[test]
    fn test_create_compact_summary() {
        let g1 = MessageGroup {
            start_index: 0,
            messages: vec![message(MessageRole::User, "Hello")],
            token_count: 10,
            is_boundary: false,
        };
        let g2 = MessageGroup {
            start_index: 1,
            messages: vec![message(MessageRole::Assistant, "Hi there")],
            token_count: 10,
            is_boundary: false,
        };
        let groups = vec![&g1, &g2];
        let summary = create_compact_summary(&groups);
        assert!(summary.contains("2 messages"));
        assert!(summary.contains("1 user messages and 1 assistant responses"));
    }
}