use rucora_core::provider::LlmProvider;
use rucora_core::provider::types::{ChatMessage, Role};
use serde::{Deserialize, Serialize};
use tracing::{debug, info};
/// How aggressively the compressor trims and summarizes conversation history.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum CompressionStrategy {
    /// Trims the most: fewest retained tool results, earliest trigger.
    Aggressive,
    /// Middle-ground defaults; the `Default` variant.
    #[default]
    Balanced,
    /// Trims the least: most retained tool results, latest trigger.
    Conservative,
}
/// Tuning knobs for layered conversation compression.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionConfig {
    /// Overall trimming/summarization aggressiveness.
    pub strategy: CompressionStrategy,
    /// Number of leading messages that are never compressed away.
    pub protect_head_count: usize,
    /// Estimated token budget for trailing messages kept uncompressed.
    pub protect_tail_tokens: usize,
    /// Context-usage ratio (0.0..=1.0) at which compression triggers.
    pub compression_threshold: f64,
    /// Desired usage ratio after compression.
    /// NOTE(review): not read by the code visible in this file — presumably
    /// consumed by a caller; confirm before removing.
    pub target_usage_ratio: f64,
    /// Maximum compression passes.
    /// NOTE(review): unused in the visible code — confirm against callers.
    pub max_iterations: usize,
    /// Minimum seconds between two summary generations (cooldown).
    pub summary_cooldown_seconds: u64,
}
impl Default for CompressionConfig {
    /// Balanced preset: trigger at 85% usage, target 60% afterwards,
    /// protect three head messages and a ~20k-token tail, with a
    /// ten-minute cooldown between summaries.
    fn default() -> Self {
        Self {
            strategy: CompressionStrategy::default(),
            protect_head_count: 3,
            protect_tail_tokens: 20_000,
            compression_threshold: 0.85,
            target_usage_ratio: 0.60,
            max_iterations: 3,
            summary_cooldown_seconds: 600,
        }
    }
}
impl CompressionConfig {
    /// Preset that compresses earlier and keeps less context than the defaults.
    pub fn aggressive() -> Self {
        let mut cfg = Self::default();
        cfg.strategy = CompressionStrategy::Aggressive;
        cfg.protect_head_count = 2;
        cfg.protect_tail_tokens = 15_000;
        cfg.compression_threshold = 0.80;
        cfg.target_usage_ratio = 0.50;
        cfg
    }

    /// Preset that compresses later and keeps more context than the defaults.
    pub fn conservative() -> Self {
        let mut cfg = Self::default();
        cfg.strategy = CompressionStrategy::Conservative;
        cfg.protect_head_count = 5;
        cfg.protect_tail_tokens = 25_000;
        cfg.compression_threshold = 0.90;
        cfg.target_usage_ratio = 0.70;
        cfg
    }
}
/// Prompt template (in Chinese, sent to the LLM at runtime) asking for a
/// structured, sectioned summary of the conversation — goal, constraints,
/// progress, decisions, open questions, files, remaining work, critical
/// context, and tool usage — so work can resume without losing context.
const STRUCTURED_SUMMARY_TEMPLATE: &str = r#"请对以下对话进行结构化摘要,以便后续继续工作而不丢失关键上下文。
## Goal — 用户试图完成什么
[描述用户的主要目标和任务]
## Constraints & Preferences — 用户偏好、编码风格
[记录用户的特殊要求、偏好、编码风格等]
## Progress — Done / In Progress / Blocked
- **Done**: [已完成的工作]
- **In Progress**: [正在进行的工作]
- **Blocked**: [阻塞的问题]
## Key Decisions — 重要技术决策
[记录重要的技术决策及其原因]
## Resolved Questions — 已回答的问题
[已解决的问题,防止重新回答]
## Pending User Asks — 未回答的问题
[用户提出但尚未回答的问题]
## Relevant Files — 读取/修改/创建的文件
[相关文件列表]
## Remaining Work — 剩余工作
[还需要完成的工作]
## Critical Context — 不能丢失的具体值
[重要的代码片段、配置值、URL 等]
## Tools & Patterns — 使用过的工具及有效用法
[使用过的工具和有效的工作模式]
---
请基于以上模板对对话进行摘要,保持简洁但完整。"#;
/// Compresses long chat histories in layers: drops stale tool results,
/// protects head/tail messages, and replaces the middle of the
/// conversation with an LLM-generated structured summary.
pub struct LayeredCompressor {
    /// Tuning parameters for the compression passes.
    config: CompressionConfig,
    /// Unix timestamp (seconds) of the last summary, for cooldown checks.
    last_summary_timestamp: Option<u64>,
    /// Most recent summary text, used to incrementally update the next one.
    last_summary_content: Option<String>,
}
impl LayeredCompressor {
    /// Creates a compressor with the given configuration and no summary history.
    pub fn new(config: CompressionConfig) -> Self {
        Self {
            config,
            last_summary_timestamp: None,
            last_summary_content: None,
        }
    }

    /// Convenience constructor using the default [`CompressionConfig`].
    pub fn default_engine() -> Self {
        Self::new(CompressionConfig::default())
    }

    /// Current Unix time in whole seconds; 0 if the clock reads before the epoch.
    fn unix_now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_secs())
    }

    /// Returns `true` when compression should run: the usage ratio has
    /// reached `compression_threshold` and the summary cooldown has elapsed.
    ///
    /// A zero-sized context window never triggers compression.
    pub fn should_compress(&self, current_tokens: usize, context_window: usize) -> bool {
        if context_window == 0 {
            return false;
        }
        let usage_ratio = current_tokens as f64 / context_window as f64;
        if usage_ratio < self.config.compression_threshold {
            return false;
        }
        if let Some(last_ts) = self.last_summary_timestamp {
            // saturating_sub guards against clock skew (now < last_ts):
            // plain subtraction would panic in debug builds or wrap to a
            // huge value in release, silently bypassing the cooldown.
            let elapsed = Self::unix_now_secs().saturating_sub(last_ts);
            if elapsed < self.config.summary_cooldown_seconds {
                debug!(
                    elapsed = elapsed,
                    cooldown = self.config.summary_cooldown_seconds,
                    "压缩冷却期,跳过压缩"
                );
                return false;
            }
        }
        true
    }

    /// Runs one layered compression pass over `messages`.
    ///
    /// Pipeline: drop stale tool results, split into protected head /
    /// compressible middle / protected tail, then replace the middle with
    /// an LLM-generated structured summary injected as a system message.
    /// If there is no middle to compress, the (trimmed) messages are
    /// returned unchanged.
    ///
    /// # Errors
    /// Propagates any error from the summary-generating `provider` call.
    pub async fn compress(
        &mut self,
        provider: &dyn LlmProvider,
        messages: Vec<ChatMessage>,
        context_window: usize,
    ) -> Result<Vec<ChatMessage>, Box<dyn std::error::Error + Send + Sync>> {
        info!(
            original_count = messages.len(),
            context_window = context_window,
            "开始分层压缩"
        );
        let original_count = messages.len();
        let messages = self.trim_old_tool_results(messages);
        let (head, middle, tail) = self.split_messages(messages);
        debug!(
            head_count = head.len(),
            middle_count = middle.len(),
            tail_count = tail.len(),
            "消息分层完成"
        );
        if middle.is_empty() {
            info!("无中间消息,跳过压缩");
            return Ok([head, middle, tail].concat());
        }
        let summary = self.generate_structured_summary(provider, &middle).await?;
        // Record the summary for cooldown tracking and incremental updates.
        self.last_summary_timestamp = Some(Self::unix_now_secs());
        self.last_summary_content = Some(summary.clone());
        let summary_message = ChatMessage::system(format!(
            "<conversation-summary>\n{summary}\n</conversation-summary>\n\n\
             以上是之前对话的结构化摘要。请基于此摘要和后续对话继续工作。"
        ));
        let compressed = [head, vec![summary_message], tail].concat();
        info!(
            compressed_count = compressed.len(),
            compression_ratio = format!(
                "{:.1}%",
                (1.0 - compressed.len() as f64 / original_count as f64) * 100.0
            ),
            "压缩完成"
        );
        Ok(compressed)
    }

    /// Keeps only the newest tool-result messages (count set by strategy);
    /// all non-tool messages are preserved. Relative order is unchanged.
    fn trim_old_tool_results(&self, messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
        let max_tool_results = match self.config.strategy {
            CompressionStrategy::Aggressive => 2,
            CompressionStrategy::Balanced => 4,
            CompressionStrategy::Conservative => 6,
        };
        let mut trimmed = Vec::with_capacity(messages.len());
        let mut tool_result_count = 0;
        // Walk newest-first so the retained tool results are the most recent.
        for msg in messages.into_iter().rev() {
            if msg.role == Role::Tool {
                if tool_result_count < max_tool_results {
                    trimmed.push(msg);
                    tool_result_count += 1;
                }
            } else {
                trimmed.push(msg);
            }
        }
        trimmed.reverse();
        trimmed
    }

    /// Splits messages into `(head, middle, tail)`: the first
    /// `protect_head_count` messages, a compressible middle, and a trailing
    /// run whose estimated tokens fit within `protect_tail_tokens`.
    /// The tail never overlaps the head.
    #[allow(clippy::needless_pass_by_value)]
    fn split_messages(
        &self,
        messages: Vec<ChatMessage>,
    ) -> (Vec<ChatMessage>, Vec<ChatMessage>, Vec<ChatMessage>) {
        let head_count = self.config.protect_head_count.min(messages.len());
        let head: Vec<ChatMessage> = messages[..head_count].to_vec();
        // Count trailing messages (newest first) until the token budget is hit.
        let token_counter = TokenCounter::new();
        let mut tail_count = 0;
        let mut tail_tokens = 0;
        for msg in messages.iter().rev() {
            let role_str = match msg.role {
                Role::User => "user",
                Role::Assistant => "assistant",
                Role::System => "system",
                Role::Tool => "tool",
            };
            let tokens = token_counter.estimate_message(&msg.content, role_str);
            if tail_tokens + tokens > self.config.protect_tail_tokens {
                break;
            }
            tail_tokens += tokens;
            tail_count += 1;
        }
        // Clamp so the tail starts no earlier than the end of the head.
        let tail_start = messages.len().saturating_sub(tail_count).max(head_count);
        let tail: Vec<ChatMessage> = messages[tail_start..].to_vec();
        let middle: Vec<ChatMessage> = messages[head_count..tail_start].to_vec();
        (head, middle, tail)
    }

    /// Asks `provider` for a structured summary of `messages`. When a
    /// previous summary exists it is updated with the new content instead
    /// of being regenerated from the template.
    ///
    /// # Errors
    /// Returns any error produced by the provider's chat call.
    async fn generate_structured_summary(
        &self,
        provider: &dyn LlmProvider,
        messages: &[ChatMessage],
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        let context_text = messages
            .iter()
            .map(|m| format!("[{}]: {}", Self::role_name(&m.role), m.content))
            .collect::<Vec<_>>()
            .join("\n\n");
        let prompt = if let Some(previous_summary) = &self.last_summary_content {
            format!(
                "这是之前的对话摘要:\n{previous_summary}\n\n---\n\n这是新的对话内容:\n{context_text}\n\n\
                 请更新之前的摘要以反映新的进展,保持结构化格式。"
            )
        } else {
            format!("{STRUCTURED_SUMMARY_TEMPLATE}\n\n---\n\n对话内容:\n{context_text}")
        };
        let request = rucora_core::provider::types::ChatRequest::from_user_text(prompt);
        let response = provider.chat(request).await?;
        Ok(response.message.content)
    }

    /// Human-readable (Chinese) label for a role, used in the summary prompt.
    fn role_name(role: &Role) -> &'static str {
        match role {
            Role::User => "用户",
            Role::Assistant => "助手",
            Role::System => "系统",
            Role::Tool => "工具",
        }
    }

    /// Most recently generated summary, if any.
    pub fn last_summary(&self) -> Option<&String> {
        self.last_summary_content.as_ref()
    }
}
/// Heuristic token estimator: assumes a fixed average number of
/// characters per token rather than running a real tokenizer.
struct TokenCounter {
    /// Average characters per token assumed by the estimate.
    avg_chars_per_token: f64,
}
impl TokenCounter {
    /// Builds a counter with a heuristic of ~4 characters per token.
    fn new() -> Self {
        Self {
            avg_chars_per_token: 4.0,
        }
    }

    /// Rough token estimate for one message: character count divided by
    /// the average chars-per-token (truncated), plus a fixed per-message
    /// overhead of 4 tokens. The role is currently ignored.
    fn estimate_message(&self, content: &str, _role: &str) -> usize {
        let chars = content.chars().count() as f64;
        (chars / self.avg_chars_per_token) as usize + 4
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Below-threshold usage must not trigger compression; above it must.
    #[test]
    fn test_should_compress() {
        let compressor = LayeredCompressor::default_engine();
        assert!(!compressor.should_compress(10_000, 128_000));
        assert!(compressor.should_compress(110_000, 128_000));
    }

    /// The aggressive strategy keeps at most two (the newest) tool results.
    #[test]
    fn test_trim_tool_results() {
        let compressor = LayeredCompressor::new(CompressionConfig::aggressive());
        let mut history = vec![ChatMessage::user("Hello"), ChatMessage::assistant("Hi")];
        history.push(ChatMessage::tool("tool1".to_string(), "result1".to_string()));
        history.push(ChatMessage::assistant("Done"));
        history.push(ChatMessage::tool("tool2".to_string(), "result2".to_string()));
        history.push(ChatMessage::assistant("Done2"));
        history.push(ChatMessage::tool("tool3".to_string(), "result3".to_string()));
        history.push(ChatMessage::tool("tool4".to_string(), "result4".to_string()));
        history.push(ChatMessage::tool("tool5".to_string(), "result5".to_string()));
        let kept = compressor.trim_old_tool_results(history);
        let remaining_tools = kept.iter().filter(|m| m.role == Role::Tool).count();
        assert!(remaining_tools <= 2);
    }

    /// Splitting protects the configured head count and yields a non-empty tail.
    #[test]
    fn test_split_messages() {
        let compressor = LayeredCompressor::default_engine();
        let history: Vec<ChatMessage> = (0..20)
            .map(|i| match i % 2 {
                0 => ChatMessage::user(format!("User message {i}")),
                _ => ChatMessage::assistant(format!("Assistant message {i}")),
            })
            .collect();
        let (head, middle, tail) = compressor.split_messages(history);
        assert_eq!(head.len(), 3);
        assert!(!tail.is_empty());
        assert!(middle.len() < 17);
    }
}