use crate::truncate::truncate_chars;
use anyhow::Result;
use serde::Deserialize;
use super::config::*;
use super::keywords_config::KeywordsConfig;
use super::types::{AutoMemory, MemoryCategory, MemoryEntry};
/// Source of long-term memory entries extracted from conversation text.
#[async_trait::async_trait]
pub trait MemoryExtractor: Send + Sync {
    /// Extracts memory entries from `text`.
    ///
    /// `session_id`, when present, identifies where the text came from; the
    /// implementation in this file forwards it to `MemoryEntry::new` for every
    /// entry it creates.
    async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>>;

    /// Name of the model this extractor reports itself as using.
    fn model_name(&self) -> &str;
}
/// [`MemoryExtractor`] that asks an LLM chat provider to pull memories out of text.
///
/// Within this file, `model` is only read by [`MemoryExtractor::model_name`];
/// it is not sent with the chat request.
pub struct AiMemoryExtractor {
    provider: Box<dyn crate::providers::Provider>,
    model: String,
}

impl AiMemoryExtractor {
    /// Builds an extractor around `provider`, remembering `model` as its model name.
    pub fn new(provider: Box<dyn crate::providers::Provider>, model: String) -> Self {
        AiMemoryExtractor { model, provider }
    }
}
/// System prompt used by [`AiMemoryExtractor`] (the text itself is Chinese).
///
/// In summary, it asks the model to:
/// - pick each memory's category from `decision`, `preference`, `solution`,
///   `finding`, `technical`, `structure`;
/// - reply with strict JSON shaped like
///   `{"memories": [{"category", "content", "importance", "keywords", "tags"}]}`;
/// - give 3-5 keywords (both Chinese and English) and 1-3 classification tags;
/// - return only the JSON, with no explanation.
///
/// NOTE(review): the prompt never states the scale of `importance` (the
/// example uses 85); `parse_memory_response` clamps positive values to 0-100.
const MEMORY_EXTRACT_SYSTEM_PROMPT: &str = r#"你是一个记忆提取助手。从对话中提取值得长期记忆的关键信息。
记忆类型:
- decision: 项目或技术选型的决定
- preference: 用户习惯或偏好
- solution: 解决问题的具体方法
- finding: 重要发现或信息
- technical: 技术栈或框架信息
- structure: 项目结构信息
输出格式(严格 JSON):
{
"memories": [
{
"category": "decision",
"content": "采用 PostgreSQL 作为主数据库",
"importance": 85,
"keywords": ["PostgreSQL", "数据库", "database"],
"tags": ["backend", "storage"]
}
]
}
关键词提取要求:
- 提取 3-5 个核心关键词(技术名词、项目名、关键概念)
- 中英文关键词都提取
- 用于后续记忆检索匹配
标签提取要求:
- 提取 1-3 个分类标签(如 backend、frontend、config、auth 等)
- 用于记忆分类筛选
只返回 JSON,不要其他解释。"#;
#[async_trait::async_trait]
impl MemoryExtractor for AiMemoryExtractor {
    /// Sends `text` (shortened with `truncate_chars(text, 4000)`) to the provider
    /// together with `MEMORY_EXTRACT_SYSTEM_PROMPT`, concatenates the text blocks
    /// of the reply, and parses them with `parse_memory_response`.
    ///
    /// # Errors
    /// Propagates provider failures and JSON parse failures.
    async fn extract(&self, text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
        use crate::providers::{ChatRequest, ContentBlock, Message, MessageContent, Role};

        let excerpt = truncate_chars(text, 4000);
        let user_prompt = format!(
            "请从以下对话中提取值得记忆的关键信息:\n\n{}",
            excerpt
        );
        let user_message = Message {
            role: Role::User,
            content: MessageContent::Text(user_prompt),
        };
        let request = ChatRequest {
            system: Some(MEMORY_EXTRACT_SYSTEM_PROMPT.to_string()),
            messages: vec![user_message],
            tools: vec![],
            server_tools: vec![],
            think: false,
            max_tokens: 512,
            enable_caching: false,
        };

        let reply = self.provider.chat(request).await?;
        // Only text blocks carry the JSON payload; any other block kind is skipped.
        let reply_text: String = reply
            .content
            .iter()
            .filter_map(|block| match block {
                ContentBlock::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .collect();

        parse_memory_response(&reply_text, session_id)
    }

    /// Returns the model name supplied to [`AiMemoryExtractor::new`].
    fn model_name(&self) -> &str {
        self.model.as_str()
    }
}
/// Parses the model's reply into [`MemoryEntry`] values.
///
/// Parses the span from the first `{` to the last `}`. This tolerates Markdown
/// code fences and stray prose the model adds around the JSON object. For each
/// item:
/// - An unknown `category` (compared case-insensitively after trimming) is skipped.
/// - `content` is trimmed, and items shorter than `MIN_MEMORY_CONTENT_LENGTH`
///   bytes are skipped.
/// - A positive `importance` is clamped to `0..=100`; otherwise the value set
///   by `MemoryEntry::new` is kept.
/// - `keywords` and then `tags` are appended to `entry.tags`, keeping only the
///   first occurrence of each string.
///
/// The surviving entries are passed through [`deduplicate_entries`].
///
/// # Errors
/// Returns an error if the reply does not contain a JSON object of the
/// expected shape.
fn parse_memory_response(json_text: &str, session_id: Option<&str>) -> Result<Vec<MemoryEntry>> {
    #[derive(Deserialize)]
    struct MemoryResponse {
        memories: Vec<MemoryItem>,
    }

    #[derive(Deserialize)]
    struct MemoryItem {
        category: String,
        content: String,
        #[serde(default)]
        importance: f64,
        #[serde(default)]
        keywords: Vec<String>,
        #[serde(default)]
        tags: Vec<String>,
    }

    let trimmed = json_text.trim();
    // `{` and `}` are ASCII, so these byte indices are always char boundaries.
    let payload = match (trimmed.find('{'), trimmed.rfind('}')) {
        (Some(open), Some(close)) if open < close => &trimmed[open..=close],
        _ => trimmed,
    };
    let parsed: MemoryResponse = serde_json::from_str(payload)?;

    let entries = parsed
        .memories
        .into_iter()
        .filter_map(|item| {
            let category = match item.category.trim().to_lowercase().as_str() {
                "decision" => MemoryCategory::Decision,
                "preference" => MemoryCategory::Preference,
                "solution" => MemoryCategory::Solution,
                "finding" => MemoryCategory::Finding,
                "technical" => MemoryCategory::Technical,
                "structure" => MemoryCategory::Structure,
                _ => return None,
            };
            let content = item.content.trim();
            if content.len() < MIN_MEMORY_CONTENT_LENGTH {
                return None;
            }
            let mut entry = MemoryEntry::new(
                category,
                content.to_string(),
                session_id.map(|s| s.to_string()),
            );
            if item.importance > 0.0 {
                entry.importance = item.importance.clamp(0.0, 100.0);
            }
            entry.tags.extend(item.keywords);
            entry.tags.extend(item.tags);
            // `Vec::dedup` only drops *adjacent* repeats, and keywords/tags often
            // overlap non-adjacently, so keep the first occurrence of each tag.
            let mut seen = std::collections::HashSet::new();
            entry.tags.retain(|tag| seen.insert(tag.clone()));
            Some(entry)
        })
        .collect();
    Ok(deduplicate_entries(entries))
}
/// Removes near-duplicate entries and caps the result at `MAX_DETECTED_ENTRIES`.
///
/// Entries are visited in input order. An entry is dropped when
/// `AutoMemory::calculate_similarity(kept, candidate)` reaches
/// `SIMILARITY_THRESHOLD` for any already-kept entry. Both strings are
/// lowercased content. Scanning stops as soon as the cap is reached.
fn deduplicate_entries(entries: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
    let mut kept: Vec<MemoryEntry> = Vec::new();
    let mut kept_lowered: Vec<String> = Vec::new();

    for entry in entries {
        if kept.len() >= MAX_DETECTED_ENTRIES {
            break;
        }
        let lowered = entry.content.to_lowercase();
        let is_duplicate = kept_lowered.iter().any(|prev| {
            AutoMemory::calculate_similarity(prev, &lowered) >= SIMILARITY_THRESHOLD
        });
        if !is_duplicate {
            kept_lowered.push(lowered);
            kept.push(entry);
        }
    }

    kept
}
/// Rule-based memory detection that needs no AI model.
///
/// Loads [`KeywordsConfig`] and, for each category key (`"decision"`,
/// `"preference"`, `"solution"`, `"finding"`, `"technical"`, `"structure"`),
/// checks every configured pattern against `text` case-insensitively. Each hit
/// contributes its surrounding sentence, taken from `extract_memory_content`,
/// as a candidate. The sentence must be non-empty and at least
/// `MIN_MEMORY_CONTENT_LENGTH` bytes long. Candidates are then filtered through
/// [`deduplicate_entries`].
pub fn detect_memories_fallback(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
    let config = KeywordsConfig::load();
    let haystack = text.to_lowercase();
    let category_keys = [
        (MemoryCategory::Decision, "decision"),
        (MemoryCategory::Preference, "preference"),
        (MemoryCategory::Solution, "solution"),
        (MemoryCategory::Finding, "finding"),
        (MemoryCategory::Technical, "technical"),
        (MemoryCategory::Structure, "structure"),
    ];

    let mut candidates = Vec::new();
    for (category, key) in category_keys {
        // A category without configured patterns simply contributes nothing.
        let Some(patterns) = config.patterns.get(key).map(|v| v.as_slice()) else {
            continue;
        };
        for pattern in patterns {
            if !haystack.contains(&pattern.to_lowercase()) {
                continue;
            }
            let content = extract_memory_content(text, pattern);
            if content.is_empty() || content.len() < MIN_MEMORY_CONTENT_LENGTH {
                continue;
            }
            candidates.push(MemoryEntry::new(
                category,
                content,
                session_id.map(str::to_owned),
            ));
        }
    }

    deduplicate_entries(candidates)
}
/// Synchronous, rule-based memory detection.
///
/// Delegates directly to [`detect_memories_fallback`]; no AI model is consulted.
pub fn detect_memories_from_text(text: &str, session_id: Option<&str>) -> Vec<MemoryEntry> {
    detect_memories_fallback(text, session_id)
}
/// Detects memories in `text`, preferring the AI extractor and falling back to
/// keyword rules.
///
/// The AI path is attempted only when all of these hold:
/// - `AiDetectionMode::from_env()` is not `AiDetectionMode::Never`;
/// - `text` is longer than 200 bytes;
/// - an `extractor` was supplied.
///
/// When the AI call succeeds, its deduplicated result is returned as-is, even
/// if it is empty. When the call fails, the error is logged and the result of
/// [`detect_memories_fallback`] is returned instead.
pub async fn detect_memories_smart(
    text: &str,
    session_id: Option<&str>,
    extractor: Option<&AiMemoryExtractor>,
) -> Vec<MemoryEntry> {
    // `str::len` counts bytes: CJK text (3 bytes per char) qualifies at ~67 chars.
    const MIN_AI_TEXT_BYTES: usize = 200;

    let mode = AiDetectionMode::from_env();
    if mode != AiDetectionMode::Never
        && text.len() > MIN_AI_TEXT_BYTES
        && let Some(ex) = extractor
    {
        match ex.extract(text, session_id).await {
            Ok(ai_entries) => return deduplicate_entries(ai_entries),
            // Include the cause so provider/parse failures are diagnosable.
            Err(err) => log::warn!(
                "AI memory extraction failed, falling back to rule-based: {err:#}"
            ),
        }
    }
    detect_memories_fallback(text, session_id)
}
fn extract_memory_content(text: &str, keyword: &str) -> String {
let text_lower = text.to_lowercase();
let keyword_lower = keyword.to_lowercase();
let pos = match text_lower.find(&keyword_lower) {
Some(p) => p,
None => return String::new(),
};
let start = text[..pos]
.rfind(['.', '。', '\n'])
.map(|i| i + 1)
.unwrap_or(0);
let end = text[pos..]
.find(['.', '。', '\n'])
.map(|i| pos + i + 1)
.unwrap_or(text.len());
let sentence = text[start..end].trim();
if sentence.len() > MAX_MEMORY_CONTENT_LENGTH {
sentence[..MAX_MEMORY_CONTENT_LENGTH].to_string()
} else {
sentence.to_string()
}
}
/// Guesses a [`MemoryCategory`] for `content` from simple keyword hints.
///
/// Rules are checked in priority order: decision, preference, solution,
/// finding, technical, structure. The first category with any hint occurring
/// as a substring of the lowercased content wins. Content that matches no
/// hint defaults to [`MemoryCategory::Finding`].
pub fn infer_category_from_content(content: &str) -> MemoryCategory {
    let haystack = content.to_lowercase();
    let rules: [(MemoryCategory, [&str; 4]); 6] = [
        (MemoryCategory::Decision, ["决定", "选择", "采用", "decided"]),
        (MemoryCategory::Preference, ["喜欢", "偏好", "习惯", "prefer"]),
        (MemoryCategory::Solution, ["解决", "修复", "搞定", "fixed"]),
        (MemoryCategory::Finding, ["发现", "原因", "原来", "found"]),
        (MemoryCategory::Technical, ["技术", "框架", "库", "tech"]),
        (MemoryCategory::Structure, ["文件", "目录", "入口", "file"]),
    ];

    for (category, hints) in rules {
        if hints.iter().any(|&hint| haystack.contains(hint)) {
            return category;
        }
    }
    MemoryCategory::Finding
}