Skip to main content

alice_core/memory/
service.rs

1//! Application service wrapping memory recall and persistence workflows.
2
3use std::sync::{
4    Arc,
5    atomic::{AtomicU64, Ordering},
6};
7
8use crate::memory::{
9    domain::{HybridWeights, MemoryEntry, MemoryImportance, RecallHit, RecallQuery},
10    error::{MemoryServiceError, MemoryValidationError},
11    hybrid::simple_text_embedding,
12    ports::MemoryStorePort,
13};
14
/// Process-wide sequence number appended to generated memory ids.
///
/// `persist_turn` combines it with the current millisecond timestamp, so two turns persisted
/// within the same millisecond by this process still receive distinct ids.
static MEMORY_COUNTER: AtomicU64 = AtomicU64::new(1);
16
17/// High-level memory use-cases for Alice runtime integration.
18pub struct MemoryService {
19    store: Arc<dyn MemoryStorePort>,
20    recall_limit: usize,
21    weights: HybridWeights,
22    vector_dimensions: usize,
23    enable_vector: bool,
24}
25
26impl std::fmt::Debug for MemoryService {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("MemoryService")
29            .field("recall_limit", &self.recall_limit)
30            .field("weights", &self.weights)
31            .field("vector_dimensions", &self.vector_dimensions)
32            .field("enable_vector", &self.enable_vector)
33            .finish_non_exhaustive()
34    }
35}
36
37impl MemoryService {
38    /// Construct a memory service and initialize store schema.
39    pub fn new(
40        store: Arc<dyn MemoryStorePort>,
41        recall_limit: usize,
42        weights: HybridWeights,
43        vector_dimensions: usize,
44        enable_vector: bool,
45    ) -> Result<Self, MemoryServiceError> {
46        if recall_limit == 0 {
47            return Err(MemoryValidationError::InvalidRecallLimit.into());
48        }
49        store.init_schema()?;
50        Ok(Self {
51            store,
52            recall_limit,
53            weights,
54            vector_dimensions: vector_dimensions.max(1),
55            enable_vector,
56        })
57    }
58
59    /// Recall memory hits relevant to current input.
60    pub fn recall_for_turn(
61        &self,
62        session_id: &str,
63        input: &str,
64    ) -> Result<Vec<RecallHit>, MemoryServiceError> {
65        let query_embedding =
66            self.enable_vector.then(|| simple_text_embedding(input, self.vector_dimensions));
67
68        let query = RecallQuery {
69            session_id: Some(session_id.to_string()),
70            text: input.to_string(),
71            query_embedding,
72            limit: self.recall_limit,
73        };
74
75        self.store.recall_hybrid(&query, self.weights).map_err(MemoryServiceError::from)
76    }
77
78    /// Render recalled memory as prompt context.
79    #[must_use]
80    pub fn render_recall_context(hits: &[RecallHit]) -> Option<String> {
81        if hits.is_empty() {
82            return None;
83        }
84
85        let mut output = String::from("Relevant prior memory:\n");
86        for (index, hit) in hits.iter().enumerate() {
87            let number = index + 1;
88            output.push_str(&format!(
89                "{number}. [{}] {}\n",
90                hit.entry.topic,
91                hit.entry.summary.trim()
92            ));
93        }
94
95        Some(output)
96    }
97
98    /// Persist one turn as a memory entry.
99    pub fn persist_turn(
100        &self,
101        session_id: &str,
102        user_input: &str,
103        assistant_output: &str,
104    ) -> Result<(), MemoryServiceError> {
105        let now_ms = current_time_millis();
106        let id = format!("mem-{now_ms}-{}", MEMORY_COUNTER.fetch_add(1, Ordering::Relaxed));
107
108        let summary = truncate(assistant_output.trim(), 300);
109        let raw_excerpt =
110            format!("user: {}\nassistant: {}", user_input.trim(), assistant_output.trim());
111
112        let embedding = self.enable_vector.then(|| {
113            simple_text_embedding(
114                &format!("{} {}", user_input.trim(), assistant_output.trim()),
115                self.vector_dimensions,
116            )
117        });
118
119        let entry = MemoryEntry {
120            id,
121            session_id: session_id.to_string(),
122            topic: session_id.to_string(),
123            summary,
124            raw_excerpt,
125            keywords: extract_keywords(user_input, assistant_output),
126            importance: MemoryImportance::Medium,
127            embedding,
128            created_at_epoch_ms: now_ms,
129        };
130
131        self.store.insert(&entry)?;
132        Ok(())
133    }
134}
135
/// Milliseconds elapsed since the Unix epoch, used to stamp memory entries.
///
/// Yields `0` if the system clock reports a time before the epoch, and saturates at
/// `i64::MAX` if the millisecond count does not fit in an `i64`.
fn current_time_millis() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| i64::try_from(elapsed.as_millis()).unwrap_or(i64::MAX))
}
145
/// Return at most `max_chars` leading characters of `input` as an owned string.
///
/// Counts Unicode scalar values (`char`s), so the cut always lands on a UTF-8 boundary. A
/// grapheme cluster made of several `char`s may still be split.
fn truncate(input: &str, max_chars: usize) -> String {
    // `nth(max_chars)` gives the byte offset of the first char past the limit, if there is
    // one. Slicing there copies the kept prefix in a single allocation instead of pushing
    // char by char.
    let end = input.char_indices().nth(max_chars).map_or(input.len(), |(offset, _)| offset);
    input[..end].to_owned()
}
156
/// Extract up to 12 distinct, lowercased keywords from one conversation turn.
///
/// Whitespace-separated tokens are taken from the user input first, then the assistant
/// output. Each token is stripped of leading/trailing non-alphanumeric characters, lowercased,
/// and kept when it is at least 4 characters long; first occurrences win.
///
/// Returns `["conversation"]` when no token qualifies, so every entry carries a keyword.
fn extract_keywords(user_input: &str, assistant_output: &str) -> Vec<String> {
    const MIN_KEYWORD_CHARS: usize = 4;
    const MAX_KEYWORDS: usize = 12;

    let tokens = user_input
        .split_whitespace()
        .chain(assistant_output.split_whitespace())
        // Unicode-aware trimming: an ASCII-only predicate would strip accented or non-Latin
        // letters from word edges ("café" -> "caf") and erase CJK tokens entirely.
        .map(|token| token.trim_matches(|ch: char| !ch.is_alphanumeric()).to_lowercase())
        // Count characters rather than UTF-8 bytes so the threshold is script-independent.
        .filter(|token| token.chars().count() >= MIN_KEYWORD_CHARS);

    let mut keywords: Vec<String> = Vec::with_capacity(MAX_KEYWORDS);
    for token in tokens {
        if keywords.contains(&token) {
            continue;
        }
        keywords.push(token);
        if keywords.len() >= MAX_KEYWORDS {
            break;
        }
    }
    if keywords.is_empty() {
        keywords.push("conversation".to_string());
    }
    keywords
}
178
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use parking_lot::Mutex;

    use super::{MemoryService, extract_keywords, truncate};
    use crate::memory::{
        domain::{HybridWeights, MemoryEntry, MemoryImportance, RecallHit, RecallQuery},
        error::MemoryStoreError,
        ports::MemoryStorePort,
    };

    /// In-memory store double: records inserted entries and returns every row belonging to the
    /// queried session with fixed scores.
    #[derive(Debug, Default)]
    struct MockStore {
        rows: Mutex<Vec<MemoryEntry>>,
    }

    impl MemoryStorePort for MockStore {
        fn init_schema(&self) -> Result<(), MemoryStoreError> {
            Ok(())
        }

        fn insert(&self, entry: &MemoryEntry) -> Result<(), MemoryStoreError> {
            self.rows.lock().push(entry.clone());
            Ok(())
        }

        fn recall_hybrid(
            &self,
            query: &RecallQuery,
            _weights: HybridWeights,
        ) -> Result<Vec<RecallHit>, MemoryStoreError> {
            let rows = self
                .rows
                .lock()
                .iter()
                .filter(|row| {
                    query.session_id.as_ref().is_none_or(|session_id| &row.session_id == session_id)
                })
                .cloned()
                .collect::<Vec<_>>();

            Ok(rows
                .into_iter()
                .map(|entry| RecallHit {
                    entry,
                    bm25_score: 0.5,
                    vector_score: Some(0.5),
                    final_score: 0.5,
                })
                .collect())
        }
    }

    /// Build a minimal entry whose topic and summary are the fields rendering reads.
    fn sample_entry(topic: &str, summary: &str) -> MemoryEntry {
        MemoryEntry {
            id: "m1".to_string(),
            session_id: "s1".to_string(),
            topic: topic.to_string(),
            summary: summary.to_string(),
            raw_excerpt: String::new(),
            keywords: vec![],
            importance: MemoryImportance::Medium,
            embedding: None,
            created_at_epoch_ms: 0,
        }
    }

    /// Wrap an entry in a recall hit with fixed scores.
    fn sample_hit(entry: MemoryEntry) -> RecallHit {
        RecallHit { entry, bm25_score: 0.5, vector_score: Some(0.5), final_score: 0.5 }
    }

    #[test]
    fn render_empty_hits_returns_none() {
        assert!(MemoryService::render_recall_context(&[]).is_none());
    }

    #[test]
    fn persist_then_recall_roundtrip() {
        let store: Arc<dyn MemoryStorePort> = Arc::new(MockStore::default());
        let service = MemoryService::new(store, 5, HybridWeights::default(), 128, false);
        assert!(service.is_ok(), "service construction should succeed");
        let Ok(service) = service else {
            return;
        };

        assert!(service.persist_turn("s1", "user asks", "assistant answers").is_ok());
        let hits = service.recall_for_turn("s1", "asks");
        assert!(hits.is_ok(), "recall should succeed");
        let Ok(hits) = hits else {
            return;
        };

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entry.importance, MemoryImportance::Medium);
    }

    /// `persist_turn` trims both sides and derives id, topic, summary, excerpt and keywords.
    #[test]
    fn persist_turn_builds_trimmed_entry() {
        let mock = Arc::new(MockStore::default());
        let store: Arc<dyn MemoryStorePort> = Arc::clone(&mock) as _;
        let service = MemoryService::new(store, 5, HybridWeights::default(), 16, false);
        assert!(service.is_ok(), "service construction should succeed");
        let Ok(service) = service else {
            return;
        };

        let persisted =
            service.persist_turn("s1", "  What is ownership?  ", "\n Borrowing rules. \n");
        assert!(persisted.is_ok(), "persist should succeed");

        let rows = mock.rows.lock();
        assert_eq!(rows.len(), 1);
        let Some(stored) = rows.first() else {
            return;
        };
        assert!(stored.id.starts_with("mem-"));
        assert_eq!(stored.session_id, "s1");
        assert_eq!(stored.topic, "s1");
        assert_eq!(stored.summary, "Borrowing rules.");
        assert_eq!(stored.raw_excerpt, "user: What is ownership?\nassistant: Borrowing rules.");
        assert_eq!(stored.keywords, vec!["what", "ownership", "borrowing", "rules"]);
        assert_eq!(stored.importance, MemoryImportance::Medium);
        assert!(stored.embedding.is_none(), "vectors are disabled");
    }

    /// With vectors enabled, persisted entries carry an embedding and recall still succeeds.
    #[test]
    fn persist_turn_embeds_when_vectors_enabled() {
        let mock = Arc::new(MockStore::default());
        let store: Arc<dyn MemoryStorePort> = Arc::clone(&mock) as _;
        let service = MemoryService::new(store, 5, HybridWeights::default(), 8, true);
        assert!(service.is_ok(), "service construction should succeed");
        let Ok(service) = service else {
            return;
        };

        assert!(service.persist_turn("s1", "hello there", "general kenobi").is_ok());
        assert!(service.recall_for_turn("s1", "hello").is_ok());

        let rows = mock.rows.lock();
        assert_eq!(rows.len(), 1);
        assert!(rows.iter().all(|row| row.embedding.is_some()));
    }

    /// `recall_for_turn` scopes recall to the requested session.
    #[test]
    fn recall_for_turn_uses_mock_store() {
        let mock = Arc::new(MockStore::default());
        let store: Arc<dyn MemoryStorePort> = Arc::clone(&mock) as _;
        let service = MemoryService::new(store, 3, HybridWeights::default(), 32, false);
        // Assert before destructuring so a failure fails the test instead of returning early.
        assert!(service.is_ok(), "service construction should succeed");
        let Ok(service) = service else {
            return;
        };

        // Insert two entries for different sessions.
        assert!(service.persist_turn("s-a", "hi", "hello").is_ok());
        assert!(service.persist_turn("s-b", "bye", "farewell").is_ok());
        assert_eq!(mock.rows.lock().len(), 2);

        let hits = service.recall_for_turn("s-a", "hi");
        assert!(hits.is_ok(), "recall should succeed");
        let Ok(hits) = hits else {
            return;
        };
        // Only the s-a entry should match.
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entry.session_id, "s-a");
    }

    /// `render_recall_context` emits a header plus one numbered `[topic] summary` line per hit.
    #[test]
    fn render_recall_context_with_hits() {
        let hits = [
            sample_hit(sample_entry("rust", "ownership rules")),
            sample_hit(sample_entry("tokio", "  async runtime\n")),
        ];
        assert_eq!(
            MemoryService::render_recall_context(&hits).as_deref(),
            Some("Relevant prior memory:\n1. [rust] ownership rules\n2. [tokio] async runtime\n")
        );
    }

    /// Service rejects a zero `recall_limit`.
    #[test]
    fn recall_limit_must_be_positive() {
        let store: Arc<dyn MemoryStorePort> = Arc::new(MockStore::default());
        let result = MemoryService::new(store, 0, HybridWeights::default(), 128, false);
        assert!(result.is_err());
    }

    /// `truncate` keeps whole characters and never exceeds the limit.
    #[test]
    fn truncate_counts_chars_not_bytes() {
        assert_eq!(truncate("hello", 3), "hel");
        assert_eq!(truncate("hello", 10), "hello");
        assert_eq!(truncate("héllo", 2), "hé");
        assert_eq!(truncate("", 5), "");
    }

    /// `extract_keywords` dedupes case-insensitively, caps at 12 and falls back when empty.
    #[test]
    fn extract_keywords_dedupes_caps_and_falls_back() {
        assert_eq!(extract_keywords("Rust, rust! RUST?", "borrow"), vec!["rust", "borrow"]);
        assert_eq!(extract_keywords("hi", "ok"), vec!["conversation"]);

        let many = (0..20).map(|i| format!("word{i:02}")).collect::<Vec<_>>().join(" ");
        let keywords = extract_keywords(&many, "");
        assert_eq!(keywords.len(), 12);
        assert_eq!(keywords.first().map(String::as_str), Some("word00"));
        assert_eq!(keywords.last().map(String::as_str), Some("word11"));
    }
}