use anyhow::Error;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use serde_json::Value;
use std::pin::Pin;
use std::future::Future;
use log::info;
pub type MemoryVariables = HashMap<String, Value>;
pub trait BaseMemory: Send + Sync {
fn memory_variables(&self) -> Vec<String>;
fn load_memory_variables<'a>(&'a self, inputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>, Error>> + Send + 'a>>;
fn save_context<'a>(&'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>>;
fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>>;
fn clone_box(&self) -> Box<dyn BaseMemory>;
fn get_session_id(&self) -> Option<&str>;
fn set_session_id(&mut self, session_id: String);
fn get_token_count(&self) -> Result<usize, Error>;
fn as_any(&self) -> &dyn std::any::Any;
}
/// In-process conversation memory backed by a shared, async-locked map.
///
/// Cloning is cheap: the clone shares the *same* underlying map (only the `Arc` is cloned),
/// so writes through any handle are visible through every clone. `memory_key` and
/// `session_id` are copied per clone and can diverge afterwards.
#[derive(Debug, Clone)]
pub struct SimpleMemory {
    /// All stored variables; the chat history array lives under `memory_key`.
    memories: Arc<RwLock<HashMap<String, Value>>>,
    /// Key under which chat messages are appended.
    memory_key: String,
    /// Optional session identifier.
    session_id: Option<String>,
}
impl SimpleMemory {
    /// Key used for the chat history when the caller does not choose one.
    const DEFAULT_MEMORY_KEY: &'static str = "chat_history";

    /// Single construction path shared by the public constructors; no session id is set.
    fn from_parts(memories: HashMap<String, Value>, memory_key: String) -> Self {
        Self {
            memories: Arc::new(RwLock::new(memories)),
            memory_key,
            session_id: None,
        }
    }

    /// Creates an empty memory whose chat history is stored under `"chat_history"`.
    pub fn new() -> Self {
        Self::from_parts(HashMap::new(), Self::DEFAULT_MEMORY_KEY.to_string())
    }

    /// Creates an empty memory whose chat history is stored under `memory_key`.
    pub fn with_memory_key(memory_key: String) -> Self {
        Self::from_parts(HashMap::new(), memory_key)
    }

    /// Creates a memory pre-populated with `memories`, using the `"chat_history"` key.
    pub fn with_memories(memories: HashMap<String, Value>) -> Self {
        Self::from_parts(memories, Self::DEFAULT_MEMORY_KEY.to_string())
    }

    /// Appends `message` to the chat history under `memory_key`.
    ///
    /// If the key is missing it is created as an array. If it currently holds a non-array
    /// value, that value is discarded and replaced by `[message]`.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    pub async fn add_message(&self, message: Value) -> Result<(), Error> {
        let mut memories = self.memories.write().await;
        let chat_history = memories
            .entry(self.memory_key.clone())
            .or_insert_with(|| Value::Array(Vec::new()));
        match chat_history {
            Value::Array(arr) => arr.push(message),
            other => *other = Value::Array(vec![message]),
        }
        Ok(())
    }

    /// Returns an owned copy of the chat-history key.
    pub fn get_memory_key(&self) -> String {
        self.memory_key.clone()
    }
}
impl Default for SimpleMemory {
fn default() -> Self {
Self::new()
}
}
impl BaseMemory for SimpleMemory {
fn memory_variables(&self) -> Vec<String> {
vec![self.memory_key.clone()]
}
fn load_memory_variables<'a>(&'a self, _inputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>, Error>> + Send + 'a>> {
let memories = Arc::clone(&self.memories);
Box::pin(async move {
let memories = memories.read().await;
Ok(memories.clone())
})
}
fn save_context<'a>(&'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
let memories = Arc::clone(&self.memories);
let input_clone = inputs.clone();
let output_clone = outputs.clone();
let memory_key = self.memory_key.clone();
Box::pin(async move {
let mut memories = memories.write().await;
let chat_history = memories.entry(memory_key.clone()).or_insert_with(|| Value::Array(vec![]));
if !chat_history.is_array() {
*chat_history = Value::Array(vec![]);
}
if let Some(input_value) = input_clone.get("input") {
let user_message = serde_json::json!({
"role": "human",
"content": input_value
});
if let Value::Array(ref mut arr) = chat_history {
info!("Adding to chat history: {:?}", user_message);
arr.push(user_message);
}
}
if let Some(output_value) = output_clone.get("output") {
let ai_message = serde_json::json!({
"role": "ai",
"content": output_value
});
if let Value::Array(ref mut arr) = chat_history {
arr.push(ai_message);
}
}
Ok(())
})
}
fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
let memories = Arc::clone(&self.memories);
Box::pin(async move {
let mut memories = memories.write().await;
memories.clear();
Ok(())
})
}
fn clone_box(&self) -> Box<dyn BaseMemory> {
Box::new(self.clone())
}
fn get_session_id(&self) -> Option<&str> {
self.session_id.as_deref()
}
fn set_session_id(&mut self, session_id: String) {
self.session_id = Some(session_id);
}
fn get_token_count(&self) -> Result<usize, Error> {
let count = self.memory_key.len() + self.session_id.as_ref().map(|s| s.len()).unwrap_or(0);
Ok(count)
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
impl Clone for Box<dyn BaseMemory> {
fn clone(&self) -> Self {
self.as_ref().clone_box()
}
}