pub mod config;
pub mod engine;
pub mod grouping;
pub mod prompt;
pub mod token_counter;
pub use config::{CompactConfig, CompactStrategy};
pub use engine::{CompressionConfig, CompressionStrategy, LayeredCompressor};
pub use grouping::{group_messages_by_api_round, groups_to_text, select_groups_to_compact};
pub use prompt::{
BASE_COMPACT_PROMPT, PARTIAL_COMPACT_PROMPT, generate_compact_prompt,
generate_partial_compact_prompt,
};
pub use token_counter::{ContextWindowManager, TokenCounter};
use rucora_core::provider::LlmProvider;
use rucora_core::provider::types::{ChatMessage, Role};
use std::result::Result;
/// Manages a conversation history and its token budget, compacting older
/// messages into an LLM-generated summary when the configured policy says
/// the history is approaching the model's context window.
pub struct ContextManager {
    // Conversation history, oldest first.
    messages: Vec<ChatMessage>,
    // Running token estimate for `messages`; kept in sync by
    // `add_message` / `recalculate_token_count`.
    token_count: u32,
    // Policy deciding when compaction should trigger.
    config: CompactConfig,
    // Estimator used to price messages in tokens.
    token_counter: TokenCounter,
    // Index of the summary boundary message after a compaction, if any.
    compact_boundary: Option<usize>,
}
impl ContextManager {
    /// Creates an empty manager governed by the given compaction policy.
    pub fn new(config: CompactConfig) -> Self {
        Self {
            messages: Vec::new(),
            token_count: 0,
            config,
            token_counter: TokenCounter::new(),
            compact_boundary: None,
        }
    }

    /// Appends a message and updates the running token estimate.
    pub fn add_message(&mut self, message: ChatMessage) {
        let tokens = self.estimate_message_tokens(&message);
        // saturating_add: a very long history must never panic on overflow.
        self.token_count = self.token_count.saturating_add(tokens);
        self.messages.push(message);
    }

    /// Current conversation history, oldest first.
    pub fn messages(&self) -> &[ChatMessage] {
        &self.messages
    }

    /// Current estimated token total for the whole history.
    pub fn token_count(&self) -> u32 {
        self.token_count
    }

    /// Estimates the token cost of a single message, including its role tag.
    fn estimate_message_tokens(&self, message: &ChatMessage) -> u32 {
        let role_str = match message.role {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
            Role::Tool => "tool",
        };
        self.token_counter
            .estimate_message(&message.content, role_str)
    }

    /// Whether the configured policy says the history should be compacted
    /// for the given model's context window.
    pub fn should_compact(&self, model: &str) -> bool {
        let context_window = get_context_window_for_model(model);
        self.config.should_compact(self.token_count, context_window)
    }

    /// Compacts the oldest part of the history into an LLM-generated summary.
    ///
    /// Returns the summary text, or an empty string when nothing qualified
    /// for compaction (history unchanged in that case).
    ///
    /// # Errors
    /// Propagates any error from the underlying provider call.
    pub async fn compact(
        &mut self,
        provider: &dyn LlmProvider,
        _model: &str,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        let groups = self.group_messages_by_api_round();
        let groups_to_compact = self.select_groups_to_compact(&groups);
        if groups_to_compact.is_empty() {
            return Ok(String::new());
        }
        let summary: String = self
            .generate_compact_summary(provider, &groups_to_compact)
            .await?;
        let boundary_message = self.create_compact_boundary(&summary);
        // Count the actual messages contained in the selected groups instead
        // of assuming two messages per API round: a round that includes tool
        // calls/results holds more than a user/assistant pair, and the old
        // `groups * 2` estimate would have removed the wrong messages.
        let messages_to_remove: usize = groups_to_compact.iter().map(Vec::len).sum();
        self.replace_compacted_messages(boundary_message, messages_to_remove);
        self.recalculate_token_count();
        Ok(summary)
    }

    /// Splits the history into per-API-round groups.
    fn group_messages_by_api_round(&self) -> Vec<Vec<ChatMessage>> {
        group_messages_by_api_round(&self.messages)
    }

    /// Picks which groups to summarize.
    // NOTE(review): the `3` presumably means "keep the most recent 3 rounds
    // intact" — confirm against `grouping::select_groups_to_compact`.
    fn select_groups_to_compact(&self, groups: &[Vec<ChatMessage>]) -> Vec<Vec<ChatMessage>> {
        select_groups_to_compact(groups, 3)
    }

    /// Asks the provider to summarize the selected groups in one chat call.
    async fn generate_compact_summary(
        &self,
        provider: &dyn LlmProvider,
        messages: &[Vec<ChatMessage>],
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        let prompt = generate_compact_prompt(None);
        let context_text = groups_to_text(messages);
        let request = rucora_core::provider::types::ChatRequest::from_user_text(format!(
            "{prompt}\n\n{context_text}"
        ));
        let response = provider.chat(request).await?;
        Ok(response.message.content)
    }

    /// Wraps the summary into a system message that marks the compaction
    /// boundary inside the remaining history.
    fn create_compact_boundary(&self, summary: &str) -> ChatMessage {
        ChatMessage::system(format!(
            "<conversation_summary>\n{summary}\n</conversation_summary>\n\n\
            以上是之前对话的摘要。请基于此摘要继续对话。"
        ))
    }

    /// Replaces the first `messages_to_remove` messages with the boundary
    /// message. Deliberately a no-op when that would wipe the entire history,
    /// so the tail of the conversation is always preserved.
    fn replace_compacted_messages(
        &mut self,
        boundary_message: ChatMessage,
        messages_to_remove: usize,
    ) {
        if messages_to_remove < self.messages.len() {
            self.messages.drain(0..messages_to_remove);
            self.messages.insert(0, boundary_message);
            self.compact_boundary = Some(0);
        }
    }

    /// Re-derives `token_count` from scratch after the history changed.
    fn recalculate_token_count(&mut self) {
        // saturating fold, consistent with `add_message`: never panic on
        // overflow in debug builds for pathological histories.
        let total = self
            .messages
            .iter()
            .fold(0u32, |acc, m| acc.saturating_add(self.estimate_message_tokens(m)));
        self.token_count = total;
    }
}
/// Returns the context-window size (in tokens) for a model name.
///
/// The table is scanned in order, so more specific substrings (`gpt-4o`,
/// `gpt-4-turbo`) must precede the generic `gpt-4` entry. Unknown models
/// fall back to a conservative 32k window.
fn get_context_window_for_model(model: &str) -> u32 {
    const WINDOWS: &[(&str, u32)] = &[
        ("claude-3-5-sonnet", 200_000),
        ("claude-3-opus", 200_000),
        ("claude-3-sonnet", 200_000),
        ("claude-3-haiku", 200_000),
        ("gpt-4o", 128_000),
        ("gpt-4-turbo", 128_000),
        ("gpt-4", 8_192),
        ("gpt-3.5-turbo", 16_385),
    ];
    WINDOWS
        .iter()
        .find(|(pattern, _)| model.contains(pattern))
        .map_or(32_000, |&(_, window)| window)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed manager holds no messages and reports zero tokens.
    #[test]
    fn test_context_manager_creation() {
        let manager = ContextManager::new(CompactConfig::default());
        assert_eq!(manager.messages().len(), 0);
        assert_eq!(manager.token_count(), 0);
    }

    /// Adding a message grows both the history and the token estimate.
    #[test]
    fn test_add_message() {
        let mut manager = ContextManager::new(CompactConfig::default());
        manager.add_message(ChatMessage::user("你好"));
        assert!(manager.token_count() > 0);
        assert_eq!(manager.messages().len(), 1);
    }

    /// A long conversation against a small-window model must trigger compaction.
    #[test]
    fn test_should_compact() {
        let config = CompactConfig::default().with_buffer_tokens(1000);
        let mut manager = ContextManager::new(config);
        for i in 0..500 {
            let question = format!(
                "这是第 {i} 条测试消息,包含一些额外的内容来增加 token 数量"
            );
            let answer = format!(
                "这是第 {i} 条回复,同样包含一些额外的内容来增加 token 数量"
            );
            manager.add_message(ChatMessage::user(question));
            manager.add_message(ChatMessage::assistant(answer));
        }
        assert!(manager.should_compact("gpt-4"));
    }
}