use rucora_core::provider::LlmProvider;
use rucora_core::provider::types::{ChatMessage, Role};
use serde::{Deserialize, Serialize};
use crate::compact::generate_compact_prompt;
use crate::compact::{CompactConfig, TokenCounter};
use crate::compact::{group_messages_by_api_round, groups_to_text, select_groups_to_compact};
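/// Manages a chat conversation: an optional system prompt, the message
/// history, message/token limits, and (auto-)compaction of older rounds.
///
/// Configure with the builder-style `with_*` methods, e.g.
/// `ConversationManager::new().with_system_prompt("...").with_max_tokens(8_000)`.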
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationManager {
    /// Prepended as a `Role::System` message when the first message is added.
    system_prompt: Option<String>,
    messages: Vec<ChatMessage>,
    /// Maximum number of retained messages; `0` disables the limit.
    max_messages: usize,
    /// Token budget for the whole history; `0` disables the limit.
    max_tokens: usize,
    auto_compress: bool,
    compact_config: CompactConfig,
    token_counter: TokenCounter,
    /// Running token estimate, kept in sync as messages are added and removed.
    token_count: u32,
    /// Index of the compaction boundary message, if one has been inserted.
    compact_boundary: Option<usize>,
}
impl Default for ConversationManager {
fn default() -> Self {
Self::new()
}
}
impl ConversationManager {
pub fn new() -> Self {
Self {
system_prompt: None,
messages: Vec::new(),
max_messages: 0,
max_tokens: 0,
auto_compress: false,
compact_config: CompactConfig::default(),
token_counter: TokenCounter::new(),
token_count: 0,
compact_boundary: None,
}
}
pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = Some(prompt.into());
self
}
pub fn with_max_messages(mut self, max: usize) -> Self {
self.max_messages = max;
self
}
pub fn with_max_tokens(mut self, max: usize) -> Self {
self.max_tokens = max;
self
}
pub fn with_auto_compress(mut self, enable: bool) -> Self {
self.auto_compress = enable;
self
}
pub fn with_compact_config(mut self, config: CompactConfig) -> Self {
self.compact_config = config;
self
}
pub fn with_auto_compact(mut self, enabled: bool) -> Self {
self.compact_config.auto_compact_enabled = enabled;
self
}
pub fn with_compact_buffer_tokens(mut self, tokens: u32) -> Self {
self.compact_config.auto_compact_buffer_tokens = tokens;
self
}
pub fn ensure_system_prompt(&mut self, prompt: impl Into<String>) {
if self.system_prompt.is_none() {
self.system_prompt = Some(prompt.into());
}
}
    pub fn add_message(&mut self, message: ChatMessage) {
        // Lazily prepend the system prompt when the first message arrives.
        if self.messages.is_empty()
            && let Some(prompt) = &self.system_prompt
        {
let system_message = ChatMessage {
role: Role::System,
content: prompt.clone(),
name: None,
};
let system_tokens = self.estimate_message_tokens(&system_message);
self.token_count = self.token_count.saturating_add(system_tokens);
self.messages.push(system_message);
}
let tokens = self.estimate_message_tokens(&message);
self.token_count = self.token_count.saturating_add(tokens);
self.messages.push(message);
self.enforce_limits();
}
fn estimate_message_tokens(&self, message: &ChatMessage) -> u32 {
let role_str = match message.role {
Role::User => "user",
Role::Assistant => "assistant",
Role::System => "system",
Role::Tool => "tool",
};
self.token_counter
.estimate_message(&message.content, role_str)
}
pub fn add_user_message(&mut self, content: impl Into<String>) {
self.add_message(ChatMessage {
role: Role::User,
content: content.into(),
name: None,
});
}
pub fn add_assistant_message(&mut self, content: impl Into<String>) {
self.add_message(ChatMessage {
role: Role::Assistant,
content: content.into(),
name: None,
});
}
pub fn add_tool_result(&mut self, tool_call_id: impl Into<String>, content: impl Into<String>) {
self.add_message(ChatMessage {
role: Role::Tool,
content: content.into(),
name: Some(tool_call_id.into()),
});
}
pub fn get_messages(&self) -> &[ChatMessage] {
&self.messages
}
pub fn get_recent_messages(&self, limit: usize) -> &[ChatMessage] {
if limit >= self.messages.len() {
&self.messages
} else {
&self.messages[self.messages.len() - limit..]
}
}
pub fn len(&self) -> usize {
self.messages.len()
}
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
pub fn clear(&mut self) {
self.messages.clear();
self.token_count = 0;
if let Some(prompt) = &self.system_prompt {
let system_message = ChatMessage {
role: Role::System,
content: prompt.clone(),
name: None,
};
self.token_count = self.estimate_message_tokens(&system_message);
self.messages.push(system_message);
}
}
    fn enforce_limits(&mut self) {
        // Trim the oldest non-system messages when over the message cap.
        if self.max_messages > 0 && self.messages.len() > self.max_messages {
            let has_system = self
                .messages
                .first()
                .is_some_and(|m| m.role == Role::System);
            let skip = usize::from(has_system);
            let drain_count = self.messages.len() - self.max_messages;
            self.messages.drain(skip..skip + drain_count);
            self.recalculate_token_count();
        }
        // Trim the oldest non-system messages while over the token budget.
        if self.max_tokens > 0 {
            let preserve_system = self
                .messages
                .first()
                .is_some_and(|m| m.role == Role::System);
            let min_len = usize::from(preserve_system);
            while self.token_count as usize > self.max_tokens && self.messages.len() > min_len {
                let removed = self.messages.remove(min_len);
                let removed_tokens = self.estimate_message_tokens(&removed);
                self.token_count = self.token_count.saturating_sub(removed_tokens);
            }
        }
    }
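    /// Manually compresses the history: keeps the system prompt (if any),
    /// inserts `summary` as a system message, and retains only the two most
    /// recent messages for immediate context.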
    pub fn compress(&mut self, summary: impl Into<String>) {
        let has_system = self
            .messages
            .first()
            .is_some_and(|m| m.role == Role::System);
        let summary_message = ChatMessage {
            role: Role::System,
            content: format!("Conversation history summary: {}", summary.into()),
            name: None,
        };
        let mut new_messages = Vec::new();
        if has_system {
            new_messages.push(self.messages[0].clone());
        }
        new_messages.push(summary_message);
        // Keep the two most recent messages so the model retains immediate context.
        if self.messages.len() > 2 {
            new_messages.extend_from_slice(&self.messages[self.messages.len() - 2..]);
        }
        self.messages = new_messages;
        self.recalculate_token_count();
    }
pub fn to_json(&self) -> Result<String, serde_json::Error> {
serde_json::to_string_pretty(&self.messages)
}
pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
let messages: Vec<ChatMessage> = serde_json::from_str(json)?;
let mut manager = Self {
messages,
..Default::default()
};
manager.recalculate_token_count();
Ok(manager)
}
pub fn token_count(&self) -> u32 {
self.token_count
}
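    /// Returns `true` when the running token estimate crosses the compaction
    /// threshold that `CompactConfig` derives from `model`'s context window.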
pub fn should_compact(&self, model: &str) -> bool {
let context_window = get_context_window_for_model(model);
self.compact_config
.should_compact(self.token_count, context_window)
}
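    /// LLM-assisted compaction: groups the history by API round, asks
    /// `provider` to summarize the older rounds, and replaces them with a
    /// single boundary message containing the summary. Returns the summary
    /// (empty if nothing was compacted).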
    pub async fn compact(
        &mut self,
        provider: &dyn LlmProvider,
        _model: &str,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        let groups = group_messages_by_api_round(&self.messages);
        // Select the older rounds to summarize; recent rounds are preserved.
        let groups_to_compact = select_groups_to_compact(&groups, 3);
        if groups_to_compact.is_empty() {
            return Ok(String::new());
        }
        let summary = self
            .generate_compact_summary(provider, &groups_to_compact)
            .await?;
        let boundary_message = self.create_compact_boundary(&summary);
        // Count the actual messages covered by the compacted groups rather than
        // assuming two per round, since a round may also contain tool messages.
        let messages_to_remove: usize = groups_to_compact.iter().map(|g| g.len()).sum();
        self.replace_compacted_messages(boundary_message, messages_to_remove);
        self.recalculate_token_count();
        Ok(summary)
    }
async fn generate_compact_summary(
&self,
provider: &dyn LlmProvider,
messages: &[Vec<ChatMessage>],
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let prompt = generate_compact_prompt(None);
let context_text = groups_to_text(messages);
let request = rucora_core::provider::types::ChatRequest::from_user_text(format!(
"{prompt}\n\n{context_text}"
));
let response = provider.chat(request).await?;
Ok(response.message.content)
}
    fn create_compact_boundary(&self, summary: &str) -> ChatMessage {
        ChatMessage::system(format!(
            "<conversation_summary>\n{summary}\n</conversation_summary>\n\n\
             The above is a summary of the earlier conversation. \
             Please continue the conversation based on this summary."
        ))
    }
    fn replace_compacted_messages(
        &mut self,
        boundary_message: ChatMessage,
        messages_to_remove: usize,
    ) {
        // Never drain the entire history; the most recent messages must survive.
        if messages_to_remove < self.messages.len() {
            self.messages.drain(0..messages_to_remove);
            self.messages.insert(0, boundary_message);
            self.compact_boundary = Some(0);
        }
    }
fn recalculate_token_count(&mut self) {
self.token_count = self
.messages
.iter()
.map(|m| self.estimate_message_tokens(m))
.sum();
}
}
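/// Maps a model name to its context window in tokens. The values reflect the
/// providers' published limits at the time of writing; unrecognized models
/// fall back to a conservative 32k.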
fn get_context_window_for_model(model: &str) -> u32 {
match model {
m if m.contains("claude-3-5-sonnet") => 200_000,
m if m.contains("claude-3-opus") => 200_000,
m if m.contains("claude-3-sonnet") => 200_000,
m if m.contains("claude-3-haiku") => 200_000,
m if m.contains("gpt-4o") => 128_000,
m if m.contains("gpt-4-turbo") => 128_000,
m if m.contains("gpt-4") => 8_192,
m if m.contains("gpt-3.5-turbo") => 16_385,
_ => 32_000,
}
}
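/// Very rough standalone estimate of roughly two characters per token. This
/// over-counts dense English text and under-counts some CJK text; the limit
/// enforcement in `ConversationManager` uses `TokenCounter` instead.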
pub fn estimate_tokens(text: &str) -> usize {
let chars = text.chars().count();
chars / 2 + 1
}
pub fn estimate_messages_tokens(messages: &[ChatMessage]) -> usize {
messages.iter().map(|m| estimate_tokens(&m.content)).sum()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_conversation_manager_basic() {
let mut manager = ConversationManager::new();
manager.add_user_message("你好");
manager.add_assistant_message("你好!有什么可以帮助你的?");
assert_eq!(manager.len(), 2);
assert!(!manager.is_empty());
}
#[test]
fn test_conversation_manager_system_prompt() {
        let mut manager = ConversationManager::new().with_system_prompt("You are an assistant");
manager.add_user_message("你好");
assert_eq!(manager.len(), 2);
assert_eq!(manager.messages[0].role, Role::System);
}
#[test]
fn test_conversation_manager_max_messages() {
let mut manager = ConversationManager::new()
            .with_system_prompt("System")
.with_max_messages(5);
for i in 0..10 {
            manager.add_user_message(format!("Message {i}"));
}
assert_eq!(manager.len(), 5);
assert_eq!(manager.messages[0].role, Role::System);
}
#[test]
fn test_conversation_manager_clear() {
        let mut manager = ConversationManager::new().with_system_prompt("System");
        manager.add_user_message("Hello");
        manager.clear();
        assert_eq!(manager.len(), 1);
        assert_eq!(manager.messages[0].content, "System");
}
#[test]
fn test_conversation_manager_max_tokens() {
        let mut manager = ConversationManager::new()
            .with_system_prompt("System prompt")
            .with_max_tokens(12);
        manager.add_user_message("The first rather long user message");
        manager.add_assistant_message("The first rather long assistant reply");
        manager.add_user_message("The second rather long user message");
        // Trimming preserves the system prompt and drops the oldest messages
        // until the budget is met (or only the system message remains).
        assert_eq!(manager.messages[0].role, Role::System);
        assert!(manager.token_count() as usize <= 12 || manager.len() == 1);
        assert!(manager.len() <= 2);
}
#[test]
fn test_estimate_tokens() {
assert!(estimate_tokens("Hello World") > 0);
assert!(estimate_tokens("你好世界") > 0);
}
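    // Illustrative additional tests; they rely only on what this module defines.
    #[test]
    fn test_conversation_manager_compress() {
        let mut manager = ConversationManager::new().with_system_prompt("System");
        manager.add_user_message("First question");
        manager.add_assistant_message("First answer");
        manager.add_user_message("Second question");
        manager.add_assistant_message("Second answer");
        manager.compress("test summary");
        // System prompt + summary + the two most recent messages.
        assert_eq!(manager.len(), 4);
        assert_eq!(manager.messages[1].role, Role::System);
        assert!(manager.messages[1].content.contains("test summary"));
        assert!(manager.token_count() > 0);
    }
    #[test]
    fn test_json_round_trip() {
        let mut manager = ConversationManager::new();
        manager.add_user_message("Hello");
        manager.add_assistant_message("Hi");
        let json = manager.to_json().unwrap();
        // Only the message list round-trips; limits and config reset to defaults.
        let restored = ConversationManager::from_json(&json).unwrap();
        assert_eq!(restored.len(), 2);
        assert!(restored.token_count() > 0);
    }
    #[test]
    fn test_get_recent_messages() {
        let mut manager = ConversationManager::new();
        for i in 0..5 {
            manager.add_user_message(format!("Message {i}"));
        }
        assert_eq!(manager.get_recent_messages(2).len(), 2);
        assert_eq!(manager.get_recent_messages(10).len(), 5);
    }
    #[test]
    fn test_get_context_window_for_model() {
        assert_eq!(get_context_window_for_model("gpt-4o"), 128_000);
        assert_eq!(get_context_window_for_model("gpt-4"), 8_192);
        assert_eq!(get_context_window_for_model("claude-3-opus-20240229"), 200_000);
        assert_eq!(get_context_window_for_model("some-unknown-model"), 32_000);
    }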
}