1use std::collections::HashMap;
2use std::sync::Mutex;
3
4use async_trait::async_trait;
5use nuro_core::{message::Message, Result};
6
7#[async_trait]
11pub trait MemoryStore: Send + Sync {
12 async fn add(&self, conversation_id: &str, message: Message) -> Result<()>;
14
15 async fn query(&self, _conversation_id: &str, _query: &str) -> Result<Vec<Message>> {
19 Ok(Vec::new())
20 }
21
22 async fn get_conversation(&self, conversation_id: &str) -> Result<Vec<Message>>;
24
25 async fn save_conversation(
27 &self,
28 conversation_id: &str,
29 messages: &[Message],
30 ) -> Result<()>;
31}
32
33#[derive(Default)]
39pub struct InMemoryMemoryStore {
40 inner: Mutex<HashMap<String, Vec<Message>>>,
41}
42
43impl InMemoryMemoryStore {
44 pub fn new() -> Self {
45 Self {
46 inner: Mutex::new(HashMap::new()),
47 }
48 }
49}
50
51#[async_trait]
52impl MemoryStore for InMemoryMemoryStore {
53 async fn add(&self, conversation_id: &str, message: Message) -> Result<()> {
54 let mut guard = self.inner.lock().unwrap();
55 guard
56 .entry(conversation_id.to_string())
57 .or_default()
58 .push(message);
59 Ok(())
60 }
61
62 async fn query(&self, conversation_id: &str, query: &str) -> Result<Vec<Message>> {
67 let guard = self.inner.lock().unwrap();
68 let messages = guard.get(conversation_id).cloned().unwrap_or_default();
69
70 if query.trim().is_empty() {
71 return Ok(messages);
72 }
73
74 let q = query.to_lowercase();
75 let filtered = messages
76 .into_iter()
77 .filter(|m| {
78 m.text_content()
79 .map(|t| t.to_lowercase().contains(&q))
80 .unwrap_or(false)
81 })
82 .collect();
83
84 Ok(filtered)
85 }
86
87 async fn get_conversation(&self, conversation_id: &str) -> Result<Vec<Message>> {
88 let guard = self.inner.lock().unwrap();
89 Ok(guard.get(conversation_id).cloned().unwrap_or_default())
90 }
91
92 async fn save_conversation(
93 &self,
94 conversation_id: &str,
95 messages: &[Message],
96 ) -> Result<()> {
97 let mut guard = self.inner.lock().unwrap();
98 guard.insert(conversation_id.to_string(), messages.to_vec());
99 Ok(())
100 }
101}