use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::pin::Pin;
use anyhow::{Error, Result};
use serde::{Serialize, Deserialize};
use serde_json::{json, Value};
use tokio::sync::RwLock;
use async_trait::async_trait;
use log::{info, warn, error};
use std::future::Future;
use crate::memory::base::{BaseMemory, MemoryVariables};
use crate::memory::message_history::{MessageHistoryMemory, ChatMessage};
use crate::memory::summary::SummaryMemory;
use crate::memory::utils::{
ensure_data_dir_exists, get_data_dir_from_env, get_summary_threshold_from_env,
get_recent_messages_count_from_env, generate_session_id
};
/// Configuration for [`CompositeMemory`].
///
/// The [`Default`] implementation takes `data_dir`, `summary_threshold` and
/// `recent_messages_count` from the environment-backed helpers in `crate::memory::utils`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompositeMemoryConfig {
    /// Directory holding the history and summary data; passed to `ensure_data_dir_exists`
    /// and to both underlying stores when a `CompositeMemory` is constructed.
    pub data_dir: PathBuf,
    /// Session identifier used for the underlying stores; when `None`, one is generated
    /// with `generate_session_id` at construction time.
    pub session_id: Option<String>,
    /// Threshold handed to `SummaryMemory`; its exact meaning is defined there
    /// (presumably a message count that triggers summarisation — confirm).
    pub summary_threshold: usize,
    /// Number of recent messages loaded into `chat_history`, and the number kept when
    /// history is trimmed.
    pub recent_messages_count: usize,
    /// When true, every added message triggers a summary check followed by trimming the
    /// history down to `recent_messages_count`.
    pub auto_generate_summary: bool,
}
impl Default for CompositeMemoryConfig {
fn default() -> Self {
Self {
data_dir: get_data_dir_from_env(),
session_id: None, summary_threshold: get_summary_threshold_from_env(),
recent_messages_count: get_recent_messages_count_from_env(),
auto_generate_summary: true,
}
}
}
/// Memory that combines a persisted message history with a summary memory built on that
/// same history, plus a cache of the most recently loaded/saved memory variables.
///
/// Cloning is cheap: every component sits behind an `Arc`, so clones share the same
/// history, summary and variable cache.
#[derive(Debug, Clone)]
pub struct CompositeMemory {
    /// Settings this instance was built with.
    config: CompositeMemoryConfig,
    /// Persisted chat history; the constructors in this file always set it to `Some`.
    message_history: Option<Arc<MessageHistoryMemory>>,
    /// Summary memory created via `new_with_shared_history` on `message_history`; the
    /// constructors in this file always set it to `Some`.
    summary_memory: Option<Arc<SummaryMemory>>,
    /// Variables produced by the last `load_memory_variables` call, updated with the
    /// latest input/output by `save_context` and emptied by `clear`.
    memory_variables: Arc<RwLock<MemoryVariables>>,
}
impl CompositeMemory {
    /// Creates a composite memory from [`CompositeMemoryConfig::default`].
    ///
    /// # Errors
    /// Propagates failures from creating the data directory or the underlying stores.
    pub async fn new() -> Result<Self> {
        Self::with_config(CompositeMemoryConfig::default()).await
    }

    /// Creates a composite memory with explicit storage location and limits.
    ///
    /// A fresh session id is generated and automatic summarisation is enabled.
    ///
    /// # Errors
    /// Propagates failures from creating the data directory or the underlying stores.
    pub async fn with_basic_params(
        data_dir: PathBuf,
        summary_threshold: usize,
        recent_messages_count: usize,
    ) -> Result<Self> {
        let config = CompositeMemoryConfig {
            data_dir,
            session_id: None,
            summary_threshold,
            recent_messages_count,
            auto_generate_summary: true,
        };
        Self::with_config(config).await
    }

    /// Creates a composite memory from a full configuration.
    ///
    /// If `config.session_id` is `None`, a session id is generated and written back into
    /// the stored config. This lets `get_session_id` and the serialized `config` variable
    /// report the session the history is actually stored under.
    ///
    /// # Errors
    /// Propagates failures from `ensure_data_dir_exists`,
    /// `MessageHistoryMemory::new_with_recent_count` and
    /// `SummaryMemory::new_with_shared_history`.
    pub async fn with_config(mut config: CompositeMemoryConfig) -> Result<Self> {
        ensure_data_dir_exists(&config.data_dir).await?;

        // Record the generated id instead of discarding it, so callers can resume this session.
        let session_id = config
            .session_id
            .get_or_insert_with(|| generate_session_id())
            .clone();

        let message_history = Arc::new(
            MessageHistoryMemory::new_with_recent_count(
                session_id.clone(),
                config.data_dir.clone(),
                config.recent_messages_count,
            )
            .await?,
        );

        // The summary memory works on the very same history instance.
        let summary_memory = Arc::new(
            SummaryMemory::new_with_shared_history(
                session_id,
                config.data_dir.clone(),
                config.summary_threshold,
                Arc::clone(&message_history),
            )
            .await?,
        );

        Ok(Self {
            config,
            message_history: Some(message_history),
            summary_memory: Some(summary_memory),
            memory_variables: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    /// Creates a composite memory with default settings bound to `session_id`.
    ///
    /// # Errors
    /// Propagates failures from [`CompositeMemory::with_config`].
    pub async fn with_session_id(session_id: String) -> Result<Self> {
        let config = CompositeMemoryConfig {
            session_id: Some(session_id),
            ..CompositeMemoryConfig::default()
        };
        Self::with_config(config).await
    }

    /// Appends `message` to the history, then runs the auto-summary step if it is enabled.
    ///
    /// # Errors
    /// Propagates storage and summarisation errors.
    pub async fn add_message(&self, message: ChatMessage) -> Result<()> {
        if let Some(history) = &self.message_history {
            history.add_message(&message).await?;
        }
        self.run_auto_summary().await
    }

    /// Does nothing when `auto_generate_summary` is off. Otherwise it asks the summary
    /// memory to check/generate a summary, then trims history to `recent_messages_count`.
    ///
    /// NOTE(review): trimming runs after every check, whether or not a summary was actually
    /// produced. If `summary_threshold` exceeds `recent_messages_count`, the threshold may
    /// never be reached — confirm against `SummaryMemory::check_and_generate_summary`.
    async fn run_auto_summary(&self) -> Result<()> {
        if !self.config.auto_generate_summary {
            return Ok(());
        }
        info!("Checking if summary generation is needed...");
        if let Some(summary) = &self.summary_memory {
            summary.check_and_generate_summary().await?;
            if let Some(history) = &self.message_history {
                history
                    .keep_recent_messages(self.config.recent_messages_count)
                    .await?;
            }
        }
        Ok(())
    }

    /// Returns the number of stored messages, or `0` when there is no history store.
    pub async fn get_message_count(&self) -> Result<usize> {
        match &self.message_history {
            Some(history) => history.get_message_count().await,
            None => Ok(0),
        }
    }

    /// Returns up to `count` recent messages, as provided by the history store.
    pub async fn get_recent_messages(&self, count: usize) -> Result<Vec<ChatMessage>> {
        match &self.message_history {
            Some(history) => history.get_recent_chat_messages(count).await,
            None => Ok(Vec::new()),
        }
    }

    /// Trims the history down to the configured `recent_messages_count`.
    pub async fn cleanup_old_messages(&self) -> Result<()> {
        if let Some(history) = &self.message_history {
            history
                .keep_recent_messages(self.config.recent_messages_count)
                .await?;
        }
        Ok(())
    }

    /// Returns a JSON object describing the configuration and the state of each store.
    ///
    /// The `message_history` and `summary_memory` sections appear only when the
    /// corresponding store is present.
    pub async fn get_memory_stats(&self) -> Result<Value> {
        let mut stats = json!({
            "config": {
                "summary_threshold": self.config.summary_threshold,
                "recent_messages_count": self.config.recent_messages_count,
                "auto_generate_summary": self.config.auto_generate_summary,
            }
        });

        if let Some(history) = &self.message_history {
            let message_count: usize = history.get_message_count().await?;
            stats["message_history"] = json!({
                "enabled": true,
                "message_count": message_count,
            });
        }

        if let Some(summary) = &self.summary_memory {
            let summary_data = summary.load_summary().await?;
            stats["summary_memory"] = json!({
                "enabled": true,
                "has_summary": summary_data.summary.is_some(),
                "token_count": summary_data.token_count,
                "last_updated": summary_data.last_updated,
            });
        }

        Ok(stats)
    }

    /// Returns the current summary text, or `None` if there is no summary (or no summary store).
    pub async fn get_summary(&self) -> Result<Option<String>> {
        match &self.summary_memory {
            Some(summary) => Ok(summary.load_summary().await?.summary),
            None => Ok(None),
        }
    }
}
impl CompositeMemory {
    /// Exposes this value as `&dyn Any`, so code holding a concrete reference can downcast it.
    pub fn as_any(&self) -> &dyn std::any::Any {
        let erased: &dyn std::any::Any = self;
        erased
    }
}
#[async_trait]
impl BaseMemory for CompositeMemory {
fn memory_variables(&self) -> Vec<String> {
let mut vars = Vec::new();
vars.extend_from_slice(&["chat_history".to_string(), "summary".to_string(), "input".to_string(), "output".to_string()]);
vars.push("config".to_string());
vars
}
fn load_memory_variables<'a>(&'a self, inputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>>> + Send + 'a>> {
Box::pin(async move {
let mut result = HashMap::new();
if let Some(ref history) = self.message_history {
let messages = history.get_recent_chat_messages(
self.config.recent_messages_count
).await?;
let history_json = serde_json::to_value(&messages)?;
result.insert("chat_history".to_string(), history_json);
}
if let Some(ref summary) = self.summary_memory {
let summary_data = summary.load_summary().await?;
if let Some(summary_text) = summary_data.summary {
result.insert("summary".to_string(), json!(summary_text));
}
}
if let Some(input) = inputs.get("input") {
result.insert("input".to_string(), input.clone());
}
if let Some(output) = inputs.get("output") {
result.insert("output".to_string(), output.clone());
}
result.insert("config".to_string(), serde_json::to_value(&self.config)?);
*self.memory_variables.write().await = result.clone();
Ok(result)
})
}
fn save_context<'a>(&'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
Box::pin(async move {
let input = inputs.get("input")
.and_then(|v| v.as_str())
.unwrap_or("");
let output = outputs.get("output")
.and_then(|v| v.as_str())
.unwrap_or("");
if !input.is_empty() {
let user_message = ChatMessage {
id: uuid::Uuid::new_v4().to_string(),
role: "user".to_string(),
content: input.to_string(),
timestamp: chrono::Utc::now().to_rfc3339(),
metadata: None,
};
if let Some(ref history) = self.message_history {
history.add_message(&user_message).await?;
}
}
if !output.is_empty() {
let assistant_message = ChatMessage {
id: uuid::Uuid::new_v4().to_string(),
role: "assistant".to_string(),
content: output.to_string(),
timestamp: chrono::Utc::now().to_rfc3339(),
metadata: None,
};
if let Some(ref history) = self.message_history {
history.add_message(&assistant_message).await?;
}
}
if self.config.auto_generate_summary {
info!("Checking if summary generation is needed...");
if let Some(ref summary) = self.summary_memory {
summary.check_and_generate_summary().await?;
if let Some(ref history) = self.message_history {
let keep_count = self.config.recent_messages_count;
history.keep_recent_messages(keep_count).await?;
}
}
}
let mut memory_vars = self.memory_variables.write().await;
if let Some(input_val) = inputs.get("input") {
memory_vars.insert("input".to_string(), input_val.clone());
}
if let Some(output_val) = outputs.get("output") {
memory_vars.insert("output".to_string(), output_val.clone());
}
Ok(())
})
}
fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
Box::pin(async move {
if let Some(ref history) = self.message_history {
history.clear().await?;
}
if let Some(ref summary) = self.summary_memory {
summary.clear().await?;
}
self.memory_variables.write().await.clear();
Ok(())
})
}
fn clone_box(&self) -> Box<dyn BaseMemory> {
Box::new(self.clone())
}
fn get_session_id(&self) -> Option<&str> {
self.config.session_id.as_deref()
}
fn set_session_id(&mut self, session_id: String) {
self.config.session_id = Some(session_id);
}
fn get_token_count(&self) -> Result<usize, Error> {
let mut count = 0;
if let Ok(config_json) = serde_json::to_value(&self.config) {
count += crate::memory::utils::estimate_json_token_count(&config_json);
}
if let Ok(memory_vars) = self.memory_variables.try_read() {
if let Ok(vars_json) = serde_json::to_value(&*memory_vars) {
count += crate::memory::utils::estimate_json_token_count(&vars_json);
}
}
Ok(count)
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::message_history::ChatMessage;
    use tempfile::TempDir;

    /// Builds a memory rooted in a fresh temporary directory with auto-summary disabled.
    ///
    /// The `TempDir` is returned alongside the memory because dropping it deletes the
    /// directory. Bind it to a named variable (e.g. `_temp_dir`), not `_`, so it stays alive.
    async fn isolated_memory() -> (TempDir, CompositeMemory) {
        let temp_dir = TempDir::new().unwrap();
        let config = CompositeMemoryConfig {
            data_dir: temp_dir.path().to_path_buf(),
            auto_generate_summary: false,
            ..CompositeMemoryConfig::default()
        };
        let memory = CompositeMemory::with_config(config).await.unwrap();
        (temp_dir, memory)
    }

    /// One user/assistant exchange in the shape `save_context` expects.
    fn sample_turn() -> (HashMap<String, Value>, HashMap<String, Value>) {
        let mut inputs = HashMap::new();
        inputs.insert("input".to_string(), json!("Hello"));
        let mut outputs = HashMap::new();
        outputs.insert("output".to_string(), json!("Hi there!"));
        (inputs, outputs)
    }

    // NOTE: these two tests go through the env-derived default `data_dir`, so they write
    // outside a temporary directory.
    #[tokio::test]
    async fn test_composite_memory_new() {
        let memory = CompositeMemory::new().await;
        assert!(memory.is_ok());
    }

    #[tokio::test]
    async fn test_composite_memory_with_session_id() {
        let session_id = "test_session";
        let memory = CompositeMemory::with_session_id(session_id.to_string()).await;
        assert!(memory.is_ok());
        let memory = memory.unwrap();
        assert_eq!(memory.get_session_id(), Some(session_id));
    }

    #[tokio::test]
    async fn test_add_message() {
        let (_temp_dir, memory) = isolated_memory().await;
        let message = ChatMessage {
            id: "test_id".to_string(),
            role: "user".to_string(),
            content: "Hello, world!".to_string(),
            timestamp: chrono::Utc::now().to_rfc3339(),
            metadata: None,
        };
        let result = memory.add_message(message).await;
        assert!(result.is_ok());
        let count = memory.get_message_count().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_save_context() {
        let (_temp_dir, memory) = isolated_memory().await;
        let (inputs, outputs) = sample_turn();
        let result = memory.save_context(&inputs, &outputs).await;
        assert!(result.is_ok());
        let count = memory.get_message_count().await.unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn test_clear() {
        let (_temp_dir, memory) = isolated_memory().await;
        let (inputs, outputs) = sample_turn();
        memory.save_context(&inputs, &outputs).await.unwrap();
        let count = memory.get_message_count().await.unwrap();
        assert_eq!(count, 2);
        assert!(memory.memory_variables.read().await.contains_key("input"));

        let result = memory.clear().await;
        assert!(result.is_ok());
        let count = memory.get_message_count().await.unwrap();
        assert_eq!(count, 0);
        assert!(memory.memory_variables.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_memory_variables_lists_all_keys() {
        let (_temp_dir, memory) = isolated_memory().await;
        assert_eq!(
            memory.memory_variables(),
            vec!["chat_history", "summary", "input", "output", "config"]
        );
    }

    #[tokio::test]
    async fn test_load_memory_variables_passes_input_through() {
        let (_temp_dir, memory) = isolated_memory().await;
        let (inputs, _outputs) = sample_turn();
        let vars = memory.load_memory_variables(&inputs).await.unwrap();
        assert_eq!(vars.get("input"), Some(&json!("Hello")));
        assert!(vars.contains_key("chat_history"));
        assert!(vars.contains_key("config"));
    }
}