1use std::sync::{
4 Arc,
5 atomic::{AtomicU64, Ordering},
6};
7
8use crate::memory::{
9 domain::{HybridWeights, MemoryEntry, MemoryImportance, RecallHit, RecallQuery},
10 error::{MemoryServiceError, MemoryValidationError},
11 hybrid::simple_text_embedding,
12 ports::MemoryStorePort,
13};
14
/// Process-wide sequence number, starting at 1, that `MemoryService::persist_turn` appends to
/// each generated memory id.
static MEMORY_COUNTER: AtomicU64 = AtomicU64::new(1);
16
/// Service that persists conversation turns to a [`MemoryStorePort`] and recalls them from it.
pub struct MemoryService {
    /// Backing store used for schema initialisation, inserts and hybrid recall.
    store: Arc<dyn MemoryStorePort>,
    /// Value passed as `limit` on every recall query; the constructor rejects zero.
    recall_limit: usize,
    /// Weights handed to the store's `recall_hybrid` on every recall.
    weights: HybridWeights,
    /// Dimension count passed to `simple_text_embedding`; the constructor raises zero to one.
    vector_dimensions: usize,
    /// When `false`, no embeddings are computed for queries or for stored entries.
    enable_vector: bool,
}
25
26impl std::fmt::Debug for MemoryService {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("MemoryService")
29 .field("recall_limit", &self.recall_limit)
30 .field("weights", &self.weights)
31 .field("vector_dimensions", &self.vector_dimensions)
32 .field("enable_vector", &self.enable_vector)
33 .finish_non_exhaustive()
34 }
35}
36
37impl MemoryService {
38 pub fn new(
40 store: Arc<dyn MemoryStorePort>,
41 recall_limit: usize,
42 weights: HybridWeights,
43 vector_dimensions: usize,
44 enable_vector: bool,
45 ) -> Result<Self, MemoryServiceError> {
46 if recall_limit == 0 {
47 return Err(MemoryValidationError::InvalidRecallLimit.into());
48 }
49 store.init_schema()?;
50 Ok(Self {
51 store,
52 recall_limit,
53 weights,
54 vector_dimensions: vector_dimensions.max(1),
55 enable_vector,
56 })
57 }
58
59 pub fn recall_for_turn(
61 &self,
62 session_id: &str,
63 input: &str,
64 ) -> Result<Vec<RecallHit>, MemoryServiceError> {
65 let query_embedding =
66 self.enable_vector.then(|| simple_text_embedding(input, self.vector_dimensions));
67
68 let query = RecallQuery {
69 session_id: Some(session_id.to_string()),
70 text: input.to_string(),
71 query_embedding,
72 limit: self.recall_limit,
73 };
74
75 self.store.recall_hybrid(&query, self.weights).map_err(MemoryServiceError::from)
76 }
77
78 #[must_use]
80 pub fn render_recall_context(hits: &[RecallHit]) -> Option<String> {
81 if hits.is_empty() {
82 return None;
83 }
84
85 let mut output = String::from("Relevant prior memory:\n");
86 for (index, hit) in hits.iter().enumerate() {
87 let number = index + 1;
88 output.push_str(&format!(
89 "{number}. [{}] {}\n",
90 hit.entry.topic,
91 hit.entry.summary.trim()
92 ));
93 }
94
95 Some(output)
96 }
97
98 pub fn persist_turn(
100 &self,
101 session_id: &str,
102 user_input: &str,
103 assistant_output: &str,
104 ) -> Result<(), MemoryServiceError> {
105 let now_ms = current_time_millis();
106 let id = format!("mem-{now_ms}-{}", MEMORY_COUNTER.fetch_add(1, Ordering::Relaxed));
107
108 let summary = truncate(assistant_output.trim(), 300);
109 let raw_excerpt =
110 format!("user: {}\nassistant: {}", user_input.trim(), assistant_output.trim());
111
112 let embedding = self.enable_vector.then(|| {
113 simple_text_embedding(
114 &format!("{} {}", user_input.trim(), assistant_output.trim()),
115 self.vector_dimensions,
116 )
117 });
118
119 let entry = MemoryEntry {
120 id,
121 session_id: session_id.to_string(),
122 topic: session_id.to_string(),
123 summary,
124 raw_excerpt,
125 keywords: extract_keywords(user_input, assistant_output),
126 importance: MemoryImportance::Medium,
127 embedding,
128 created_at_epoch_ms: now_ms,
129 };
130
131 self.store.insert(&entry)?;
132 Ok(())
133 }
134}
135
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch, and saturates at
/// `i64::MAX` if the millisecond count does not fit in an `i64`.
fn current_time_millis() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| i64::try_from(elapsed.as_millis()).unwrap_or(i64::MAX))
}
145
/// Returns the first `max_chars` characters of `input` as a new `String`.
///
/// Counting is per `char`, so multi-byte characters are never split; inputs shorter than the
/// limit are returned whole.
fn truncate(input: &str, max_chars: usize) -> String {
    input.chars().take(max_chars).collect()
}
156
/// Builds a small, de-duplicated keyword list from one conversation turn.
///
/// Whitespace-separated tokens from `user_input`, followed by those from `assistant_output`,
/// are stripped of leading and trailing non-alphanumeric characters and lowercased. Tokens
/// shorter than four characters are dropped, duplicates are skipped, first-seen order is
/// kept, and at most twelve keywords are returned. When no token qualifies, the single
/// fallback keyword `"conversation"` is returned so the list is never empty.
fn extract_keywords(user_input: &str, assistant_output: &str) -> Vec<String> {
    const MIN_KEYWORD_CHARS: usize = 4;
    const MAX_KEYWORDS: usize = 12;
    const FALLBACK_KEYWORD: &str = "conversation";

    let tokens = user_input
        .split_whitespace()
        .chain(assistant_output.split_whitespace())
        // Trim with the Unicode-aware predicate: an ASCII-only check strips accented or
        // non-Latin letters from word edges ("résumé" -> "résum") while leaving the same
        // letters untouched in the middle of a word.
        .map(|token| token.trim_matches(|ch: char| !ch.is_alphanumeric()).to_lowercase())
        // Measure length in characters, not bytes, so multi-byte letters do not let
        // shorter words slip past the minimum.
        .filter(|token| token.chars().count() >= MIN_KEYWORD_CHARS);

    let mut keywords: Vec<String> = Vec::with_capacity(MAX_KEYWORDS);
    for token in tokens {
        if keywords.contains(&token) {
            continue;
        }
        keywords.push(token);
        if keywords.len() >= MAX_KEYWORDS {
            break;
        }
    }

    if keywords.is_empty() {
        keywords.push(FALLBACK_KEYWORD.to_string());
    }
    keywords
}
178
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use parking_lot::Mutex;

    use super::MemoryService;
    use crate::memory::{
        domain::{HybridWeights, MemoryEntry, MemoryImportance, RecallHit, RecallQuery},
        error::MemoryStoreError,
        ports::MemoryStorePort,
    };

    /// In-memory [`MemoryStorePort`] that keeps every inserted entry in a vector.
    #[derive(Debug, Default)]
    struct MockStore {
        rows: Mutex<Vec<MemoryEntry>>,
    }

    impl MemoryStorePort for MockStore {
        fn init_schema(&self) -> Result<(), MemoryStoreError> {
            Ok(())
        }

        fn insert(&self, entry: &MemoryEntry) -> Result<(), MemoryStoreError> {
            self.rows.lock().push(entry.clone());
            Ok(())
        }

        /// Returns every stored row for the queried session (all rows when the query has no
        /// session id) with fixed scores of 0.5; the weights, query text and limit are ignored.
        fn recall_hybrid(
            &self,
            query: &RecallQuery,
            _weights: HybridWeights,
        ) -> Result<Vec<RecallHit>, MemoryStoreError> {
            let rows = self
                .rows
                .lock()
                .iter()
                .filter(|row| {
                    query.session_id.as_ref().is_none_or(|session_id| &row.session_id == session_id)
                })
                .cloned()
                .collect::<Vec<_>>();

            Ok(rows
                .into_iter()
                .map(|entry| RecallHit {
                    entry,
                    bm25_score: 0.5,
                    vector_score: Some(0.5),
                    final_score: 0.5,
                })
                .collect())
        }
    }

    #[test]
    fn render_empty_hits_returns_none() {
        assert!(MemoryService::render_recall_context(&[]).is_none());
    }

    #[test]
    fn persist_then_recall_roundtrip() {
        let store: Arc<dyn MemoryStorePort> = Arc::new(MockStore::default());
        let service = MemoryService::new(store, 5, HybridWeights::default(), 128, false);
        assert!(service.is_ok(), "service construction should succeed");
        let Ok(service) = service else {
            return;
        };

        assert!(service.persist_turn("s1", "user asks", "assistant answers").is_ok());
        let hits = service.recall_for_turn("s1", "asks");
        assert!(hits.is_ok(), "recall should succeed");
        let Ok(hits) = hits else {
            return;
        };

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entry.importance, MemoryImportance::Medium);
    }

    #[test]
    fn recall_for_turn_uses_mock_store() {
        let mock = Arc::new(MockStore::default());
        let store: Arc<dyn MemoryStorePort> = Arc::clone(&mock) as _;
        let service = MemoryService::new(store, 3, HybridWeights::default(), 32, false);
        // Assert before unwrapping: a bare `let ... else { return }` would make the test pass
        // without checking anything when construction fails.
        assert!(service.is_ok(), "service construction should succeed");
        let Ok(service) = service else {
            return;
        };

        assert!(service.persist_turn("s-a", "hi", "hello").is_ok());
        assert!(service.persist_turn("s-b", "bye", "farewell").is_ok());

        let hits = service.recall_for_turn("s-a", "hi");
        assert!(hits.is_ok(), "recall should succeed");
        let Ok(hits) = hits else {
            return;
        };
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entry.session_id, "s-a");
    }

    #[test]
    fn render_recall_context_with_hits() {
        let entry = MemoryEntry {
            id: "m1".to_string(),
            session_id: "s1".to_string(),
            topic: "rust".to_string(),
            summary: "ownership rules".to_string(),
            raw_excerpt: String::new(),
            keywords: vec![],
            importance: MemoryImportance::Medium,
            embedding: None,
            created_at_epoch_ms: 0,
        };
        let hit = RecallHit { entry, bm25_score: 0.5, vector_score: Some(0.5), final_score: 0.5 };
        let rendered = MemoryService::render_recall_context(&[hit]);
        assert!(rendered.is_some(), "non-empty hits should render");
        let Some(text) = rendered else {
            return;
        };
        assert!(text.contains("1."));
        assert!(text.contains("[rust]"));
        assert!(text.contains("ownership rules"));
    }

    #[test]
    fn recall_limit_must_be_positive() {
        let store: Arc<dyn MemoryStorePort> = Arc::new(MockStore::default());
        let result = MemoryService::new(store, 0, HybridWeights::default(), 128, false);
        assert!(result.is_err());
    }
}