Skip to main content

nuro_memory/
store.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3
4use async_trait::async_trait;
5use nuro_core::{message::Message, Result};
6
7/// 抽象的记忆存储接口。
8///
9/// 为了保持 MVP 简洁,所有方法的语义都尽量宽松,仅用于占位和后续扩展。
10#[async_trait]
11pub trait MemoryStore: Send + Sync {
12    /// 追加一条消息到某个会话中。
13    async fn add(&self, conversation_id: &str, message: Message) -> Result<()>;
14
15    /// 按简单字符串 query 查询相关消息。
16    ///
17    /// 约定:实现可以根据 `conversation_id` 做范围限定,也可以忽略它并全局搜索。
18    async fn query(&self, _conversation_id: &str, _query: &str) -> Result<Vec<Message>> {
19        Ok(Vec::new())
20    }
21
22    /// 获取某个会话的完整消息列表。
23    async fn get_conversation(&self, conversation_id: &str) -> Result<Vec<Message>>;
24
25    /// 覆盖保存整个会话的消息列表。
26    async fn save_conversation(
27        &self,
28        conversation_id: &str,
29        messages: &[Message],
30    ) -> Result<()>;
31}
32
33/// 纯内存版的实现:
34///
35/// - 使用 `HashMap<conversation_id, Vec<Message>>` 保存消息;
36/// - 不做容量控制与持久化,仅用于开发与测试;
37/// - 线程安全,但不保证高并发场景下的性能。
38#[derive(Default)]
39pub struct InMemoryMemoryStore {
40    inner: Mutex<HashMap<String, Vec<Message>>>,
41}
42
43impl InMemoryMemoryStore {
44    pub fn new() -> Self {
45        Self {
46            inner: Mutex::new(HashMap::new()),
47        }
48    }
49}
50
51#[async_trait]
52impl MemoryStore for InMemoryMemoryStore {
53    async fn add(&self, conversation_id: &str, message: Message) -> Result<()> {
54        let mut guard = self.inner.lock().unwrap();
55        guard
56            .entry(conversation_id.to_string())
57            .or_default()
58            .push(message);
59        Ok(())
60    }
61
62    /// 一个极简的包含匹配查询实现:
63    /// - 仅在指定 `conversation_id` 的会话内搜索;
64    /// - 使用不区分大小写的子串匹配;
65    /// - 若 query 为空字符串,则返回整个会话的消息列表。
66    async fn query(&self, conversation_id: &str, query: &str) -> Result<Vec<Message>> {
67        let guard = self.inner.lock().unwrap();
68        let messages = guard.get(conversation_id).cloned().unwrap_or_default();
69
70        if query.trim().is_empty() {
71            return Ok(messages);
72        }
73
74        let q = query.to_lowercase();
75        let filtered = messages
76            .into_iter()
77            .filter(|m| {
78                m.text_content()
79                    .map(|t| t.to_lowercase().contains(&q))
80                    .unwrap_or(false)
81            })
82            .collect();
83
84        Ok(filtered)
85    }
86
87    async fn get_conversation(&self, conversation_id: &str) -> Result<Vec<Message>> {
88        let guard = self.inner.lock().unwrap();
89        Ok(guard.get(conversation_id).cloned().unwrap_or_default())
90    }
91
92    async fn save_conversation(
93        &self,
94        conversation_id: &str,
95        messages: &[Message],
96    ) -> Result<()> {
97        let mut guard = self.inner.lock().unwrap();
98        guard.insert(conversation_id.to_string(), messages.to_vec());
99        Ok(())
100    }
101}