use std::path::PathBuf;
use anyhow::{Error, Result};
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use serde_json::{json, Value};
use std::pin::Pin;
use std::future::Future;
use log::info;
use uuid;
use chrono;
use crate::ChatMessage;
use std::sync::Arc;
use crate::memory::message_history::{FileChatMessageHistory, ChatMessageRecord, MessageHistoryMemory};
use crate::memory::utils::estimate_text_tokens;
use crate::{ChatModel, OpenAIChatModel, ModelChatMessage, ChatMessageContent};
/// Summary state for one session, stored on disk as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SummaryData {
/// Identifier of the session this record belongs to.
pub session_id: String,
/// Sequence number saved together with the summary. `check_and_generate_summary`
/// only looks at messages whose sequence number is greater than this value.
pub sequence_number: u64,
/// The summary text, or `None` when no summary has been stored.
pub summary: Option<String>,
/// Result of `estimate_text_tokens` on the summary text at save time.
pub token_count: usize,
/// RFC 3339 timestamp taken when this record was created.
pub last_updated: String,
}
/// The empty record: blank session id, sequence number zero, no summary text,
/// a token count of zero, and a `last_updated` stamp of the current UTC time.
impl Default for SummaryData {
    fn default() -> Self {
        let created_at = chrono::Utc::now().to_rfc3339();
        Self {
            last_updated: created_at,
            summary: None,
            token_count: 0,
            sequence_number: 0,
            session_id: String::new(),
        }
    }
}
/// Conversation memory that pairs a persisted summary with recent chat messages.
///
/// `Clone` is derived: every field is cloned, and the optional shared history is
/// an `Arc`, so clones keep pointing at the same `MessageHistoryMemory`.
#[derive(Debug, Clone)]
pub struct SummaryMemory {
    /// Session identifier; used to build the `{session_id}_summary.json` and
    /// `{session_id}_history.jsonl` file names.
    session_id: String,
    /// Directory that holds the summary and history files.
    data_dir: PathBuf,
    /// Estimated token count that new messages must exceed before
    /// `check_and_generate_summary` produces a summary.
    summary_threshold: usize,
    /// Prompt template; the `{chat_history}` placeholder is replaced with the
    /// transcript when a summary is generated.
    summary_prompt_template: String,
    /// Number of most recent messages to load and to keep after summarizing.
    recent_messages_count: usize,
    /// Optional shared message history. When `None`, `save_context` stores
    /// nothing and `check_and_generate_summary` returns `false`.
    message_history: Option<Arc<MessageHistoryMemory>>,
}
impl SummaryMemory {
/// Creates a memory for `session_id` rooted at `data_dir`, without a shared
/// message history.
///
/// The directory (including missing parents) is created first. The prompt
/// template starts as the built-in default and the recent-message count is
/// taken from `get_recent_messages_count_from_env`.
///
/// # Errors
/// Returns an error if the data directory cannot be created.
pub async fn new(session_id: String, data_dir: PathBuf, summary_threshold: usize) -> Result<Self> {
    tokio::fs::create_dir_all(&data_dir).await?;

    let summary_prompt_template = String::from(
        "Please provide a concise summary of the following conversation. Focus on the main topics discussed, key decisions made, and any important outcomes.\n\nConversation:\n{chat_history}\n\nSummary:",
    );
    let recent_messages_count = crate::memory::utils::get_recent_messages_count_from_env();

    Ok(Self {
        session_id,
        data_dir,
        summary_threshold,
        summary_prompt_template,
        recent_messages_count,
        message_history: None,
    })
}
pub async fn new_with_shared_history(
session_id: String,
data_dir: PathBuf,
summary_threshold: usize,
message_history: Arc<MessageHistoryMemory>
) -> Result<Self> {
tokio::fs::create_dir_all(&data_dir).await?;
Ok(Self {
session_id,
data_dir,
summary_threshold,
summary_prompt_template: "Please provide a concise summary of the following conversation. Focus on the main topics discussed, key decisions made, and any important outcomes.\n\nConversation:\n{chat_history}\n\nSummary:".to_string(),
recent_messages_count: crate::memory::utils::get_recent_messages_count_from_env(),
message_history: Some(message_history),
})
}
pub fn with_summary_prompt_template(mut self, template: String) -> Self {
self.summary_prompt_template = template;
self
}
pub fn with_recent_messages_count(mut self, count: usize) -> Self {
self.recent_messages_count = count;
self
}
/// Path of this session's summary file: `<data_dir>/<session_id>_summary.json`.
fn get_summary_file_path(&self) -> PathBuf {
    let file_name = format!("{}_summary.json", self.session_id);
    self.data_dir.join(file_name)
}
/// Builds the textual context for this session: the stored summary (if any)
/// followed by the most recent messages, each rendered as `"<role>: <content>"`.
///
/// Messages come from the shared history when one is attached, otherwise from
/// the session's `{session_id}_history.jsonl` file. In both cases at most
/// `recent_messages_count` messages are returned; the file fallback previously
/// returned the entire file, which was inconsistent with the shared-history
/// branch and with `load_memory_variables`.
///
/// # Errors
/// Propagates failures from loading the summary or reading the history.
pub async fn load_context(&self) -> Result<Vec<String>> {
    let summary_data = self.load_summary().await?;

    // `skip` is the number of leading (oldest) messages to drop.
    let (messages, skip) = if let Some(ref history) = self.message_history {
        // The shared history already limits the result to the requested count.
        (history.get_recent_messages(self.recent_messages_count).await?, 0)
    } else {
        let file_path = self.data_dir.join(format!("{}_history.jsonl", self.session_id));
        let chat_history = FileChatMessageHistory::new(self.session_id.clone(), file_path).await?;
        let all_messages = chat_history.get_messages().await?;
        let skip = all_messages.len().saturating_sub(self.recent_messages_count);
        (all_messages, skip)
    };

    let mut context = Vec::new();
    if let Some(summary) = summary_data.summary {
        context.push(format!("Previous conversation summary: {}", summary));
    }
    context.extend(
        messages
            .into_iter()
            .skip(skip)
            .map(|msg| format!("{}: {}", msg.role, msg.content)),
    );
    Ok(context)
}
pub async fn load_summary(&self) -> Result<SummaryData> {
let file_path = self.get_summary_file_path();
if !tokio::fs::metadata(&file_path).await.is_ok() {
return Ok(SummaryData {
session_id: self.session_id.clone(),
sequence_number: 0,
summary: None,
token_count: 0,
last_updated: chrono::Utc::now().to_rfc3339(),
});
}
let contents = tokio::fs::read_to_string(&file_path).await?;
let summary_data: SummaryData = serde_json::from_str(&contents)?;
Ok(summary_data)
}
/// Writes `summary` and `sequence_number` to this session's summary file as
/// JSON, replacing whatever was there.
///
/// The stored record also carries the summary's estimated token count and the
/// current UTC time.
///
/// # Errors
/// Returns an error if serialization or the file write fails.
async fn save_summary(&self, summary: &str, sequence_number: u64) -> Result<()> {
    let record = SummaryData {
        session_id: self.session_id.clone(),
        sequence_number,
        summary: Some(summary.to_owned()),
        token_count: estimate_text_tokens(summary),
        last_updated: chrono::Utc::now().to_rfc3339(),
    };
    let serialized = serde_json::to_string(&record)?;
    tokio::fs::write(self.get_summary_file_path(), serialized).await?;
    Ok(())
}
async fn generate_summary(&self, messages: &[ChatMessageRecord]) -> Result<(String, u64)> {
info!("Generating summary for {} messages", messages.len());
let mut chat_text = String::new();
for msg in messages {
let role = if msg.role == "user" { "User" } else { "Assistant" };
chat_text.push_str(&format!("{}: {}\n", role, msg.content));
}
let summary_prompt = self.summary_prompt_template.replace("{chat_history}", &chat_text);
let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "OPENAI_API_KEY".to_string());
let base_url = std::env::var("OPENAI_API_URL").ok();
let model = crate::OpenAIChatModel::new(api_key.clone(), base_url)
.with_model(std::env::var("OPENAI_API_MODEL").unwrap_or_else(|_| "gpt-3.5-turbo".to_string()))
.with_temperature(0.3)
.with_max_tokens(1024);
let model_messages = vec![
crate::ModelChatMessage::System(crate::ChatMessageContent {
content: "You are a helpful assistant that creates concise summaries of conversations.".to_string(),
name: None,
additional_kwargs: std::collections::HashMap::new(),
}),
crate::ModelChatMessage::Human(crate::ChatMessageContent {
content: summary_prompt,
name: None,
additional_kwargs: std::collections::HashMap::new(),
}),
];
let response = model.invoke(model_messages).await?;
let summary = match response.message {
crate::ModelChatMessage::AIMessage(content) => content.content,
_ => return Err(anyhow::anyhow!("Expected AI message response")),
};
let last_sequence_number = messages.last()
.map(|msg| msg.sequence_number)
.unwrap_or(0);
self.save_summary(&summary, last_sequence_number).await?;
Ok((summary, last_sequence_number))
}
pub async fn check_and_generate_summary(&self) -> Result<bool> {
let summary_data = self.load_summary().await?;
let last_summary_sequence = summary_data.sequence_number;
let messages = if let Some(ref message_history) = self.message_history {
message_history.get_recent_messages(usize::MAX).await?
} else {
return Ok(false);
};
if messages.is_empty() {
return Ok(false);
}
let messages_to_summarize: Vec<ChatMessageRecord> = messages
.into_iter()
.filter(|msg| msg.sequence_number > last_summary_sequence)
.collect();
if messages_to_summarize.is_empty() {
return Ok(false);
}
let mut chat_text = String::new();
for msg in &messages_to_summarize {
chat_text.push_str(&format!("{}: {}\n", msg.role, msg.content));
}
let total_tokens = estimate_text_tokens(&chat_text);
if total_tokens > self.summary_threshold {
info!("[SummaryMemory] Generating summary... ({} new messages, {} tokens)", messages_to_summarize.len(), total_tokens);
let (summary, _) = self.generate_summary(&messages_to_summarize).await?;
let last_sequence = messages_to_summarize.last().map(|m| m.sequence_number).unwrap_or(0);
self.save_summary(&summary, last_sequence).await?;
if let Some(ref message_history) = self.message_history {
message_history.keep_recent_messages(self.recent_messages_count).await?;
}
Ok(true)
} else {
Ok(false)
}
}
/// Borrows the session identifier this memory was created with.
pub fn get_session_id(&self) -> &str {
    self.session_id.as_str()
}
/// Reports statistics for this session as a JSON object.
///
/// The object holds the configured threshold and recent-message count, the
/// number of messages in the `{session_id}_history.jsonl` file with an
/// estimated token count for them, and the stored summary's presence, token
/// count and timestamp. Messages are always read from the history file here,
/// even when a shared history is attached.
///
/// # Errors
/// Propagates failures from loading the summary or reading the history file.
pub async fn get_memory_stats(&self) -> Result<Value> {
    let summary_data = self.load_summary().await?;

    let history_path = self.data_dir.join(format!("{}_history.jsonl", self.session_id));
    let chat_history = FileChatMessageHistory::new(self.session_id.clone(), history_path).await?;
    let messages = chat_history.get_messages().await?;

    // Render the transcript only to estimate its token count.
    let transcript: String = messages
        .iter()
        .map(|msg| format!("{}: {}\n", msg.role, msg.content))
        .collect();

    Ok(json!({
        "session_id": self.session_id,
        "summary_threshold": self.summary_threshold,
        "recent_messages_count": self.recent_messages_count,
        "message_count": messages.len(),
        "token_count": estimate_text_tokens(&transcript),
        "has_summary": summary_data.summary.is_some(),
        "summary_token_count": summary_data.token_count,
        "last_updated": summary_data.last_updated
    }))
}
}
use crate::memory::base::BaseMemory;
impl BaseMemory for SummaryMemory {
/// Names of the variables this memory provides: only `chat_history`.
fn memory_variables(&self) -> Vec<String> {
    vec![String::from("chat_history")]
}
fn load_memory_variables<'a>(&'a self, _inputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>, Error>> + Send + 'a>> {
let session_id = self.session_id.clone();
let data_dir = self.data_dir.clone();
let summary_threshold = self.summary_threshold;
let recent_messages_count = self.recent_messages_count;
let use_shared_history = self.message_history.is_some();
Box::pin(async move {
let summary_memory = SummaryMemory {
session_id: session_id.clone(),
data_dir: data_dir.clone(),
summary_threshold,
summary_prompt_template: String::new(),
recent_messages_count,
message_history: None, };
let summary_data = summary_memory.load_summary().await?;
let messages = if use_shared_history {
let file_path = data_dir.join(format!("{}_history.jsonl", session_id.clone()));
let chat_history = FileChatMessageHistory::new(session_id.clone(), file_path).await?;
chat_history.get_messages().await?
} else {
let file_path = data_dir.join(format!("{}_history.jsonl", session_id.clone()));
let chat_history = FileChatMessageHistory::new(session_id.clone(), file_path).await?;
chat_history.get_messages().await?
};
let mut history_array = Vec::new();
let mut system_prompt_parts = Vec::new();
system_prompt_parts.push("You are a helpful assistant that provides accurate and concise answers.".to_string());
if let Some(user_system_prompt) = std::env::var("USER_SYSTEM_PROMPT").ok() {
system_prompt_parts.push(user_system_prompt);
}
if let Some(summary) = summary_data.summary {
system_prompt_parts.push(format!("Previous conversation summary: {}", summary));
}
let combined_system_prompt = system_prompt_parts.join("\n\n");
let mut system_msg_obj = serde_json::Map::new();
system_msg_obj.insert("role".to_string(), serde_json::Value::String("system".to_string()));
system_msg_obj.insert("content".to_string(), serde_json::Value::String(combined_system_prompt));
history_array.push(serde_json::Value::Object(system_msg_obj));
let len = messages.len();
let start = if len > recent_messages_count {
len - recent_messages_count
} else {
0
};
for msg in &messages[start..] {
let mut msg_obj = serde_json::Map::new();
msg_obj.insert("role".to_string(), serde_json::Value::String(msg.role.clone()));
msg_obj.insert("content".to_string(), serde_json::Value::String(msg.content.clone()));
if let Some(name) = &msg.name {
msg_obj.insert("name".to_string(), serde_json::Value::String(name.clone()));
}
if let Some(kwargs) = &msg.additional_kwargs {
for (k, v) in kwargs {
msg_obj.insert(k.clone(), v.clone());
}
}
history_array.push(serde_json::Value::Object(msg_obj));
}
let mut result = HashMap::new();
result.insert("chat_history".to_string(), serde_json::Value::Array(history_array));
Ok(result)
})
}
fn save_context<'a>(&'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
Box::pin(async move {
let mut user_message: Option<String> = None;
let mut assistant_message: Option<String> = None;
if let Some(input_value) = inputs.get("input") {
if let Some(s) = input_value.as_str() {
user_message = Some(s.to_string());
}
}
if let Some(output_value) = outputs.get("output") {
if let Some(s) = output_value.as_str() {
assistant_message = Some(s.to_string());
}
}
if let Some(ref message_history) = self.message_history {
if let Some(user_msg) = user_message {
let chat_msg = ChatMessage {
id: uuid::Uuid::new_v4().to_string(),
role: "user".to_string(),
content: user_msg,
timestamp: chrono::Utc::now().to_rfc3339(),
metadata: None,
};
message_history.add_message(&chat_msg).await?;
}
if let Some(assistant_msg) = assistant_message {
let chat_msg = ChatMessage {
id: uuid::Uuid::new_v4().to_string(),
role: "assistant".to_string(),
content: assistant_msg,
timestamp: chrono::Utc::now().to_rfc3339(),
metadata: None,
};
message_history.add_message(&chat_msg).await?;
}
info!("save_context");
}
Ok(())
})
}
fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
let session_id = self.session_id.clone();
let data_dir = self.data_dir.clone();
Box::pin(async move {
let file_path = data_dir.join(format!("{}_history.jsonl", session_id.clone()));
let chat_history = FileChatMessageHistory::new(session_id.clone(), file_path).await?;
chat_history.clear().await?;
let summary_path = data_dir.join(format!("{}_summary.json", session_id.clone()));
if tokio::fs::metadata(&summary_path).await.is_ok() {
tokio::fs::remove_file(&summary_path).await?;
}
Ok(())
})
}
fn clone_box(&self) -> Box<dyn BaseMemory> {
Box::new(self.clone())
}
/// Returns the session identifier; this implementation always has one.
fn get_session_id(&self) -> Option<&str> {
    Some(self.session_id.as_str())
}
/// Replaces the session identifier. File names derived from it (summary and
/// history files) are computed from the new value on subsequent calls; existing
/// files are not renamed or moved here.
fn set_session_id(&mut self, session_id: String) {
self.session_id = session_id;
}
/// Returns the estimated token count of the string `"<session_id>:<data_dir>"`.
///
/// NOTE(review): this measures the memory's identifying metadata, not the
/// stored summary or messages — confirm this is what callers expect.
fn get_token_count(&self) -> Result<usize, Error> {
    let data_dir = self.data_dir.to_string_lossy();
    let descriptor = format!("{}:{}", self.session_id, data_dir);
    let estimate = estimate_text_tokens(&descriptor);
    Ok(estimate)
}
/// Exposes this memory as `&dyn Any`.
fn as_any(&self) -> &dyn std::any::Any {
self
}
}