use async_trait::async_trait;
use rucora_core::{
error::{MemoryError, ToolError},
memory::{Memory, MemoryItem, MemoryQuery},
tool::{Tool, ToolCategory},
};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// In-process, non-persistent [`Memory`] backend.
///
/// This is what `MemoryStoreTool::new` / `MemoryRecallTool::new` fall back to
/// when no backend is supplied. Records live in a `HashMap` keyed by the
/// item's `id`, guarded by a `std::sync::Mutex`.
struct SimpleMemory {
    records: Mutex<HashMap<String, MemoryItem>>,
}
impl SimpleMemory {
    /// Creates an empty store.
    fn new() -> Self {
        let records = Mutex::default();
        Self { records }
    }
}
/// [`Memory`] implementation for the in-process key-value store.
///
/// Because the store is a plain map keyed by record id, `query` is an exact
/// lookup: `query.text` is treated as a record id and at most one item is
/// returned (none when `query.limit` is zero).
#[async_trait]
impl Memory for SimpleMemory {
    /// Inserts `item`, replacing any existing record with the same `id`.
    async fn add(&self, item: MemoryItem) -> Result<(), MemoryError> {
        // A poisoned lock only means another caller panicked while holding it;
        // the map is still usable, so recover the guard instead of panicking again.
        self.records
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .insert(item.id.clone(), item);
        Ok(())
    }

    /// Returns the record whose id equals `query.text`, if any, honoring `query.limit`.
    async fn query(&self, query: MemoryQuery) -> Result<Vec<MemoryItem>, MemoryError> {
        let records = self
            .records
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // Returning unrelated records here would make a key lookup (e.g. from
        // `MemoryRecallTool`) receive an arbitrary item from the map.
        let hit = records
            .get(&query.text)
            .filter(|_| query.limit > 0)
            .cloned();
        Ok(hit.into_iter().collect())
    }
}
/// Tool that writes a fact, preference, or note into a [`Memory`] backend.
///
/// Use [`MemoryRecallTool::from_store`] to build a recall tool that reads from
/// the same backend.
pub struct MemoryStoreTool {
    memory: Arc<dyn Memory>,
}
impl MemoryStoreTool {
    /// Builds the tool on top of a fresh, private in-process store.
    pub fn new() -> Self {
        Self::default()
    }
    /// Builds the tool on top of an existing, possibly shared, backend.
    pub fn from_memory(memory: Arc<dyn Memory>) -> Self {
        Self { memory }
    }
}
impl Default for MemoryStoreTool {
    fn default() -> Self {
        Self::from_memory(Arc::new(SimpleMemory::new()))
    }
}
#[async_trait]
impl Tool for MemoryStoreTool {
fn name(&self) -> &str {
"memory_store"
}
fn description(&self) -> Option<&str> {
Some(
"存储事实、偏好或笔记到长期记忆。使用 category 'core' 表示永久记忆,'daily' 表示会话笔记,'conversation' 表示对话上下文",
)
}
fn categories(&self) -> &'static [ToolCategory] {
&[ToolCategory::Memory]
}
fn input_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"key": {
"type": "string",
"description": "记忆的唯一键(如 'user_lang', 'project_stack')"
},
"content": {
"type": "string",
"description": "要记忆的信息"
},
"category": {
"type": "string",
"description": "记忆类别: 'core' (永久), 'daily' (会话), 'conversation' (对话), 或自定义类别。默认为 'core'"
}
},
"required": ["key", "content"]
})
}
async fn call(&self, input: Value) -> Result<Value, ToolError> {
let key = input
.get("key")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::Message("缺少必需的 'key' 字段".to_string()))?;
let content = input
.get("content")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::Message("缺少必需的 'content' 字段".to_string()))?;
let category = input
.get("category")
.and_then(|v| v.as_str())
.unwrap_or("core");
let full_key = format!("{category}:{key}");
self.memory
.add(MemoryItem {
id: full_key,
content: content.to_string(),
metadata: None,
})
.await
.map_err(|e| ToolError::Message(e.to_string()))?;
Ok(json!({
"success": true,
"key": key,
"category": category,
"message": format!("已存储记忆: {}", key)
}))
}
}
/// Tool that reads back records written by [`MemoryStoreTool`].
///
/// Build it with [`MemoryRecallTool::from_store`] (or pass the same backend to
/// [`MemoryRecallTool::from_memory`]) so it can see what the store tool wrote.
/// [`MemoryRecallTool::new`] starts from its own empty in-process store.
pub struct MemoryRecallTool {
    memory: Arc<dyn Memory>,
}
impl MemoryRecallTool {
    /// Builds the tool on top of a fresh, private in-process store.
    pub fn new() -> Self {
        Self::default()
    }
    /// Shares `store`'s backend, so recalls see everything that tool stores.
    pub fn from_store(store: &MemoryStoreTool) -> Self {
        Self::from_memory(Arc::clone(&store.memory))
    }
    /// Builds the tool on top of an existing, possibly shared, backend.
    pub fn from_memory(memory: Arc<dyn Memory>) -> Self {
        Self { memory }
    }
}
impl Default for MemoryRecallTool {
    fn default() -> Self {
        Self::from_memory(Arc::new(SimpleMemory::new()))
    }
}
/// Exposes [`MemoryRecallTool`] as the `memory_recall` tool.
///
/// Looks up the record stored under `"{category}:{key}"`. This is the same id
/// format `memory_store` writes, and `category` again defaults to `"core"`.
#[async_trait]
impl Tool for MemoryRecallTool {
    fn name(&self) -> &str {
        "memory_recall"
    }

    fn description(&self) -> Option<&str> {
        // "Retrieve stored information from long-term memory."
        Some("从长期记忆中检索存储的信息")
    }

    fn categories(&self) -> &'static [ToolCategory] {
        &[ToolCategory::Memory]
    }

    fn input_schema(&self) -> Value {
        // key: memory key to retrieve; category: memory category, defaults to 'core'.
        json!({
            "type": "object",
            "properties": {
                "key": {
                    "type": "string",
                    "description": "要检索的记忆键"
                },
                "category": {
                    "type": "string",
                    "description": "记忆类别,默认为 'core'"
                }
            },
            "required": ["key"]
        })
    }

    /// Recalls the record for `key`/`category`.
    ///
    /// Returns `{found, key, category, content}`. `content` is `null` when
    /// nothing was found.
    ///
    /// # Errors
    /// Returns `ToolError::Message` when `key` is missing or not a string, or
    /// when the backend's `query` fails.
    async fn call(&self, input: Value) -> Result<Value, ToolError> {
        let key = input
            .get("key")
            .and_then(|v| v.as_str())
            .ok_or_else(|| ToolError::Message("缺少必需的 'key' 字段".to_string()))?;
        let category = input
            .get("category")
            .and_then(|v| v.as_str())
            .unwrap_or("core");
        // Must match the id format used by `memory_store`.
        let full_key = format!("{category}:{key}");
        let mut results = self
            .memory
            .query(MemoryQuery {
                text: full_key.clone(),
                limit: 1,
            })
            .await
            .map_err(|e| ToolError::Message(e.to_string()))?;
        // Don't rely on the backend honoring `limit`. Prefer the record stored
        // under exactly this id; otherwise fall back to the first (best-ranked)
        // result rather than the last one.
        let exact = results.iter().position(|item| item.id == full_key);
        let hit = match exact {
            Some(index) => Some(results.swap_remove(index)),
            None => results.into_iter().next(),
        };
        let content = hit.map(|item| item.content);
        Ok(json!({
            "found": content.is_some(),
            "key": key,
            "category": category,
            "content": content
        }))
    }
}