use serde::{Deserialize, Serialize};
/// Strategy used to decide which part of a message history is compacted.
///
/// Serialized by `serde` as the lowercase variant name (`"head"`, `"tail"`,
/// `"smart"`). [`CompactDirection::Smart`] is the [`Default`] value.
///
/// Variant order is intentionally left unchanged: index-based serde formats
/// encode variants by position.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CompactDirection {
    /// Compact from the head of the history.
    /// NOTE(review): presumably the oldest messages — confirm with the implementation.
    Head,
    /// Compact from the tail of the history.
    /// NOTE(review): presumably the newest messages — confirm with the implementation.
    Tail,
    /// Automatic selection; this is the default direction.
    /// NOTE(review): the concrete strategy is not defined in this file.
    #[default]
    Smart,
}
/// Outcome reported by [`compact_messages`].
///
/// Fields are serialized by `serde` in declaration order, so that order is
/// kept stable.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompactResult {
    /// Whether the compaction reported success.
    pub success: bool,
    /// Number of messages dropped from the history.
    pub messages_removed: usize,
    /// Token count reported before compaction.
    pub tokens_before: u64,
    /// Token count reported after compaction.
    pub tokens_after: u64,
    /// Direction the compaction was performed with.
    pub direction: CompactDirection,
    /// Optional error description accompanying the result.
    pub error: Option<String>,
}
/// Caller-supplied settings for [`compact_messages`].
///
/// The derived [`Default`] gives `max_tokens: None`,
/// `direction: CompactDirection::Smart`, `create_boundary: false` and
/// `system_prompt: None`.
#[derive(Clone, Debug, Default)]
pub struct CompactOptions {
    /// Optional token budget to compact towards.
    /// NOTE(review): not yet read by `compact_messages`.
    pub max_tokens: Option<u64>,
    /// Which part of the history to compact.
    pub direction: CompactDirection,
    /// Whether a boundary should be created during compaction.
    /// NOTE(review): exact meaning is not defined in this file — confirm.
    pub create_boundary: bool,
    /// Optional system prompt associated with the compaction.
    /// NOTE(review): how it is used is not defined in this file — confirm.
    pub system_prompt: Option<String>,
}
pub async fn compact_messages(
_messages: &[impl AsRef<dyn std::any::Any>],
options: CompactOptions,
) -> Result<CompactResult, String> {
Ok(CompactResult {
success: true,
messages_removed: 0,
tokens_before: 0,
tokens_after: 0,
direction: options.direction,
error: None,
})
}
/// Suggests a [`CompactDirection`] for a history of `message_count` messages
/// totalling `total_tokens`, given a budget of `max_tokens`.
///
/// [`CompactDirection::Head`] is returned only when the history is strictly
/// over budget and contains more than ten messages. Every other case,
/// including being exactly at the budget, yields [`CompactDirection::Smart`].
pub fn get_recommended_direction(
    message_count: usize,
    total_tokens: u64,
    max_tokens: u64,
) -> CompactDirection {
    // Over-budget histories with more messages than this are given `Head`.
    const HEAD_THRESHOLD: usize = 10;

    let over_budget = total_tokens > max_tokens;
    let long_history = message_count > HEAD_THRESHOLD;

    match (over_budget, long_history) {
        (true, true) => CompactDirection::Head,
        (true, false) | (false, _) => CompactDirection::Smart,
    }
}
/// Estimates how many messages must be dropped to bring `current_tokens`
/// down to at most `target_tokens`. Each message is assumed to cost about
/// `avg_tokens_per_message` tokens.
///
/// Returns `0` when `current_tokens` is already at or below the target.
///
/// Otherwise the token excess is divided by the average and rounded **up**.
/// Removing that many average-sized messages is then enough to reach the
/// target. Rounding down would under-count: an excess of 100 tokens with a
/// 500-token average would report `0` while the history is still over budget.
///
/// An `avg_tokens_per_message` of `0` is treated as `1` instead of panicking
/// on division by zero. In that case the result is the token excess itself,
/// which is an upper bound.
///
/// The result saturates at `usize::MAX` on targets where `usize` is narrower
/// than `u64`, rather than silently truncating.
pub fn calculate_messages_to_remove(
    current_tokens: u64,
    target_tokens: u64,
    avg_tokens_per_message: u64,
) -> usize {
    if current_tokens <= target_tokens {
        return 0;
    }
    let excess = current_tokens - target_tokens;

    // A zero average would panic on division; assume at least one token per message.
    let per_message = avg_tokens_per_message.max(1);

    // Ceiling division without computing `excess + per_message - 1`, which could overflow.
    let messages = excess / per_message + u64::from(excess % per_message != 0);

    // Clamp before casting so 32-bit targets saturate instead of truncating.
    messages.min(usize::MAX as u64) as usize
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compact_direction_default() {
        let options = CompactOptions::default();
        assert_eq!(options.direction, CompactDirection::Smart);
        // The remaining derived defaults are part of the public contract too.
        assert_eq!(options.max_tokens, None);
        assert!(!options.create_boundary);
        assert_eq!(options.system_prompt, None);
    }

    #[test]
    fn test_get_recommended_direction_no_compact() {
        let dir = get_recommended_direction(5, 1000, 2000);
        assert_eq!(dir, CompactDirection::Smart);
    }

    #[test]
    fn test_get_recommended_direction_at_budget() {
        // Exactly at the budget counts as "no compaction needed", even for long histories.
        let dir = get_recommended_direction(50, 2000, 2000);
        assert_eq!(dir, CompactDirection::Smart);
    }

    #[test]
    fn test_get_recommended_direction_long_history_over_budget() {
        let dir = get_recommended_direction(11, 3000, 2000);
        assert_eq!(dir, CompactDirection::Head);
    }

    #[test]
    fn test_get_recommended_direction_short_history_over_budget() {
        // Ten messages is not "more than ten", so the head branch is not taken.
        let dir = get_recommended_direction(10, 3000, 2000);
        assert_eq!(dir, CompactDirection::Smart);
    }

    #[test]
    fn test_calculate_messages_to_remove() {
        let count = calculate_messages_to_remove(5000, 2000, 500);
        assert_eq!(count, 6);
    }

    #[test]
    fn test_calculate_messages_to_remove_no_need() {
        let count = calculate_messages_to_remove(1000, 2000, 500);
        assert_eq!(count, 0);
    }

    #[test]
    fn test_calculate_messages_to_remove_at_target() {
        let count = calculate_messages_to_remove(2000, 2000, 500);
        assert_eq!(count, 0);
    }
}