use async_trait::async_trait;
use crate::schema::Message;
use std::collections::HashMap;
/// Error type returned by the fallible [`BaseMemory`] operations.
///
/// Every variant carries a free-form detail string. `Display` prefixes that detail with a
/// (Chinese) description of which operation failed.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers and tests can duplicate errors and
/// compare them directly instead of matching on the rendered text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemoryError {
    /// Loading memory failed (rendered as `加载记忆失败: <detail>`).
    LoadError(String),
    /// Saving memory failed (rendered as `保存记忆失败: <detail>`).
    SaveError(String),
    /// Clearing memory failed (rendered as `清空记忆失败: <detail>`).
    ClearError(String),
    /// Any other memory-related failure (rendered as `Memory 错误: <detail>`).
    Other(String),
}

impl std::fmt::Display for MemoryError {
    /// Writes `"<operation prefix>: <detail>"`; the prefixes are fixed Chinese strings.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LoadError(msg) => write!(f, "加载记忆失败: {}", msg),
            Self::SaveError(msg) => write!(f, "保存记忆失败: {}", msg),
            Self::ClearError(msg) => write!(f, "清空记忆失败: {}", msg),
            Self::Other(msg) => write!(f, "Memory 错误: {}", msg),
        }
    }
}

/// No variant wraps an underlying error, so the default `source()` (returning `None`) is kept.
impl std::error::Error for MemoryError {}
/// Asynchronous interface for a memory backend.
///
/// Implementors must be `Send + Sync`. The `#[async_trait]` attribute is what lets the
/// `async fn` methods below be declared in this trait.
/// All fallible operations report failures as [`MemoryError`].
#[async_trait]
pub trait BaseMemory: Send + Sync {
    /// Names of the variables this memory exposes.
    ///
    /// The returned `&str`s borrow from `self` (elided lifetime).
    /// NOTE(review): presumably these are the keys of the map returned by
    /// `load_memory_variables` — confirm against implementors.
    fn memory_variables(&self) -> Vec<&str>;

    /// Loads memory variables for the given `inputs` (string keys to string values).
    ///
    /// Returns a map from `String` keys to `serde_json::Value`s.
    /// NOTE(review): keys are presumably the names from `memory_variables` — confirm.
    ///
    /// # Errors
    /// Returns a [`MemoryError`] when the implementation fails to load.
    async fn load_memory_variables(
        &self,
        inputs: &HashMap<String, String>,
    ) -> Result<HashMap<String, serde_json::Value>, MemoryError>;

    /// Stores one round of `inputs` and `outputs` (both string-to-string maps).
    ///
    /// Takes `&mut self`; presumably implementations append to their stored state — confirm.
    ///
    /// # Errors
    /// Returns a [`MemoryError`] when the implementation fails to save.
    async fn save_context(
        &mut self,
        inputs: &HashMap<String, String>,
        outputs: &HashMap<String, String>,
    ) -> Result<(), MemoryError>;

    /// Clears the stored memory. Takes `&mut self`.
    ///
    /// # Errors
    /// Returns a [`MemoryError`] when the implementation fails to clear.
    async fn clear(&mut self) -> Result<(), MemoryError>;
}
/// A [`BaseMemory`] that stores its state as a list of chat [`Message`]s.
///
/// Implementors supply storage access (`messages`) and insertion (`add_message`). The
/// convenience helpers below are provided: each builds a `Message` from plain text and
/// forwards it to `add_message`.
pub trait BaseChatMemory: BaseMemory {
    /// Borrows the stored messages.
    fn messages(&self) -> &Vec<Message>;

    /// Stores a single message.
    fn add_message(&mut self, message: Message);

    /// Builds a message with `Message::human` and forwards it to `add_message`.
    fn add_user_message(&mut self, content: &str) {
        let human_msg = Message::human(content);
        self.add_message(human_msg);
    }

    /// Builds a message with `Message::ai` and forwards it to `add_message`.
    fn add_ai_message(&mut self, content: &str) {
        let ai_msg = Message::ai(content);
        self.add_message(ai_msg);
    }
}
/// A simple in-memory list of chat [`Message`]s, kept in insertion order.
#[derive(Debug, Clone)]
pub struct ChatMessageHistory {
    /// Stored messages, in the order they were added.
    messages: Vec<Message>,
}

impl ChatMessageHistory {
    /// Creates an empty history.
    pub fn new() -> Self {
        Self {
            messages: Vec::new(),
        }
    }

    /// Creates a history that takes ownership of `messages`, preserving their order.
    pub fn from_messages(messages: Vec<Message>) -> Self {
        Self { messages }
    }

    /// Appends `message` to the end of the history.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }

    /// Appends a message built with `Message::human(content)`.
    pub fn add_user_message(&mut self, content: &str) {
        self.add_message(Message::human(content));
    }

    /// Appends a message built with `Message::ai(content)`.
    pub fn add_ai_message(&mut self, content: &str) {
        self.add_message(Message::ai(content));
    }

    /// Appends a message built with `Message::system(content)`.
    pub fn add_system_message(&mut self, content: &str) {
        self.add_message(Message::system(content));
    }

    /// Borrows all stored messages in insertion order.
    pub fn messages(&self) -> &[Message] {
        &self.messages
    }

    /// Removes every stored message.
    pub fn clear(&mut self) {
        self.messages.clear();
    }

    /// Returns the number of stored messages.
    pub fn len(&self) -> usize {
        self.messages.len()
    }

    /// Returns `true` when no messages are stored.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
}

impl From<Vec<Message>> for ChatMessageHistory {
    /// Equivalent to [`ChatMessageHistory::from_messages`].
    fn from(messages: Vec<Message>) -> Self {
        Self::from_messages(messages)
    }
}

/// Renders one `"<Role>: <content>"` line per message, separated by `\n`, with no trailing
/// newline. An empty history renders as the empty string.
impl std::fmt::Display for ChatMessageHistory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Write straight into the formatter rather than building a Vec<String> and joining
        // it: identical output, without the per-message and joined intermediate allocations.
        for (index, msg) in self.messages.iter().enumerate() {
            if index > 0 {
                f.write_str("\n")?;
            }
            let role = match msg.message_type {
                crate::schema::MessageType::Human => "Human",
                crate::schema::MessageType::AI => "AI",
                crate::schema::MessageType::System => "System",
                crate::schema::MessageType::Tool { .. } => "Tool",
            };
            write!(f, "{}: {}", role, msg.content)?;
        }
        Ok(())
    }
}

impl Default for ChatMessageHistory {
    /// Same as [`ChatMessageHistory::new`]: an empty history.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_message_history() {
        let mut history = ChatMessageHistory::new();
        history.add_user_message("你好");
        history.add_ai_message("你好!有什么我可以帮助你的吗?");
        history.add_user_message("介绍一下自己");
        assert_eq!(history.len(), 3);
        assert_eq!(history.messages().len(), 3);
        assert!(!history.is_empty());
    }

    #[test]
    fn test_chat_message_history_to_string() {
        let mut history = ChatMessageHistory::new();
        history.add_user_message("你好");
        history.add_ai_message("你好!");
        // Pin the exact rendering (order, separator, no trailing newline); `contains` checks
        // would not catch a wrong separator or reordered lines.
        let rendered = history.to_string();
        assert_eq!(rendered, "Human: 你好\nAI: 你好!");
    }

    #[test]
    fn test_chat_message_history_system_message() {
        let mut history = ChatMessageHistory::new();
        history.add_system_message("规则");
        history.add_user_message("问题");
        assert_eq!(history.to_string(), "System: 规则\nHuman: 问题");
    }

    #[test]
    fn test_chat_message_history_from_messages() {
        let history =
            ChatMessageHistory::from_messages(vec![Message::human("a"), Message::ai("b")]);
        assert_eq!(history.len(), 2);
        assert_eq!(history.to_string(), "Human: a\nAI: b");
    }

    #[test]
    fn test_empty_history_renders_empty_string() {
        let history = ChatMessageHistory::default();
        assert!(history.is_empty());
        assert_eq!(history.to_string(), "");
    }

    #[test]
    fn test_chat_message_history_clear() {
        let mut history = ChatMessageHistory::new();
        history.add_user_message("测试");
        assert_eq!(history.len(), 1);
        history.clear();
        assert_eq!(history.len(), 0);
        assert!(history.is_empty());
    }
}