Skip to main content

rig_memvid/
inmem.rs

1//! In-memory append-only episode store with lexical retrieval.
2//!
3//! [`InMemoryStore`] is the no-disk companion to [`crate::MemvidStore`]:
4//! it keeps episodes in a process-local `Vec` and ranks them with a
5//! deterministic token-overlap score. It is intended for tests,
6//! examples, and offline modes that don't want to spin up a memvid
7//! `.mv2` archive but still need an [`Episode`]-shaped retrieval
8//! surface.
9//!
10//! Compared to [`crate::MemvidStore`], this store
11//!
12//! - has no I/O, no file lock, and no embedder;
13//! - returns at most `k` hits sorted by descending lexical score; and
14//! - is generic over a user-defined [`Episode`] payload, so callers
15//!   keep their own domain types (alert envelopes, hunter findings,
16//!   etc.) without forcing them through memvid's serialisation.
17//!
18//! ```no_run
19//! use rig_memvid::inmem::{Episode, InMemoryStore};
20//!
21//! #[derive(Clone)]
22//! struct MyEpisode {
23//!     summary: String,
24//! }
25//!
26//! impl Episode for MyEpisode {
27//!     fn summary(&self) -> &str { &self.summary }
28//! }
29//!
30//! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
31//! let store = InMemoryStore::<MyEpisode>::new();
32//! let key = store
33//!     .append(MyEpisode { summary: "scheduled maintenance".into() })
34//!     .await?;
35//! let hits = store.retrieve_similar("maintenance", 5).await?;
36//! assert_eq!(hits.first().map(|h| h.episode.summary.as_str()),
37//!            Some("scheduled maintenance"));
38//! let _ = (key, hits);
39//! # Ok(()) }
40//! ```
41
42use std::collections::HashSet;
43use std::sync::{Arc, Mutex};
44
45/// Errors returned by [`InMemoryStore`] operations.
46#[derive(Debug, thiserror::Error)]
47pub enum InMemoryError {
48    /// Direct lookup against an unknown key.
49    #[error("episode not found: {0}")]
50    NotFound(String),
51
52    /// The store's mutex was poisoned by a previous panic.
53    #[error("in-memory store mutex poisoned")]
54    Poisoned,
55}
56
/// Caller-owned episode payload that exposes a natural-language summary
/// for retrieval.
///
/// The store never looks inside the payload beyond [`Episode::summary`],
/// so implementors are free to keep whatever domain shape they need.
pub trait Episode: Clone + Send + Sync + 'static {
    /// Text that queries are ranked against.
    ///
    /// The scorer splits the summary on whitespace, trims leading and
    /// trailing non-alphanumeric characters from each token, and
    /// lowercases it. A hit's score is the fraction of distinct query
    /// tokens that also occur in the summary, so a summary covering more
    /// of the query's vocabulary ranks higher.
    fn summary(&self) -> &str;
}
68
/// One ranked result produced by [`InMemoryStore::retrieve_similar`].
#[derive(Debug, Clone)]
pub struct InMemoryHit<E> {
    /// Clone of the episode that matched the query.
    pub episode: E,
    /// Token-overlap score in `[0, 1]`; larger means a closer match.
    /// Hits returned by `retrieve_similar` always have a score above
    /// `0.0`, because zero-score episodes are skipped.
    pub score: f32,
    /// Storage key that [`InMemoryStore::append`] returned for this
    /// episode.
    pub key: String,
}
79
80/// Append-only in-memory episode store with deterministic lexical
81/// retrieval.
82///
83/// `E` is the caller's episode payload; only [`Episode::summary`] is
84/// inspected during retrieval.
85#[derive(Debug)]
86pub struct InMemoryStore<E: Episode> {
87    inner: Mutex<Inner<E>>,
88}
89
90#[derive(Debug)]
91struct Inner<E: Episode> {
92    next_key: u64,
93    /// Insertion-ordered (key, episode) pairs. Order is preserved so
94    /// retrieval ties break deterministically by insertion time.
95    episodes: Vec<(String, E)>,
96}
97
98impl<E: Episode> Default for InMemoryStore<E> {
99    fn default() -> Self {
100        Self {
101            inner: Mutex::new(Inner {
102                next_key: 0,
103                episodes: Vec::new(),
104            }),
105        }
106    }
107}
108
109impl<E: Episode> InMemoryStore<E> {
110    /// Create a fresh store wrapped in [`Arc`] for cheap cloning across
111    /// hunter / agent tasks.
112    pub fn new() -> Arc<Self> {
113        Arc::new(Self::default())
114    }
115
116    /// Append a new episode and return its storage key.
117    ///
118    /// Keys are stable across the lifetime of the store and follow the
119    /// `ep-{:016x}` template so they sort lexicographically by insertion
120    /// order.
121    pub async fn append(&self, episode: E) -> Result<String, InMemoryError> {
122        let mut inner = self.inner.lock().map_err(|_| InMemoryError::Poisoned)?;
123        // L5: `saturating_add` is correct but worth a comment — once
124        // `next_key` reaches `u64::MAX` every subsequent append produces
125        // the same key, and `episodes.push` happily accepts duplicates.
126        // The in-memory store is a test fake (see the module doc), so
127        // hitting that ceiling means the test loop has run >= 2^64
128        // turns and the duplicate-key behaviour is the smallest of the
129        // caller's problems. The real `MemvidStore` uses memvid's
130        // frame-id allocator, which has its own overflow handling.
131        inner.next_key = inner.next_key.saturating_add(1);
132        let key = format!("ep-{:016x}", inner.next_key);
133        inner.episodes.push((key.clone(), episode));
134        Ok(key)
135    }
136
137    /// Return up to `k` episodes most similar to `query`, sorted by
138    /// descending lexical score. Hits with score `0.0` are skipped.
139    pub async fn retrieve_similar(
140        &self,
141        query: &str,
142        k: usize,
143    ) -> Result<Vec<InMemoryHit<E>>, InMemoryError> {
144        if k == 0 {
145            return Ok(Vec::new());
146        }
147        let snapshot: Vec<(String, E)> = {
148            let inner = self.inner.lock().map_err(|_| InMemoryError::Poisoned)?;
149            inner.episodes.clone()
150        };
151        let mut hits: Vec<InMemoryHit<E>> = snapshot
152            .into_iter()
153            .filter_map(|(key, ep)| {
154                let score = lexical_score(query, ep.summary());
155                if score > 0.0 {
156                    Some(InMemoryHit {
157                        episode: ep,
158                        score,
159                        key,
160                    })
161                } else {
162                    None
163                }
164            })
165            .collect();
166        hits.sort_by(|a, b| {
167            b.score
168                .partial_cmp(&a.score)
169                .unwrap_or(std::cmp::Ordering::Equal)
170        });
171        hits.truncate(k);
172        Ok(hits)
173    }
174
175    /// Direct lookup by storage key.
176    pub async fn get(&self, key: &str) -> Result<E, InMemoryError> {
177        let inner = self.inner.lock().map_err(|_| InMemoryError::Poisoned)?;
178        inner
179            .episodes
180            .iter()
181            .find(|(k, _)| k == key)
182            .map(|(_, ep)| ep.clone())
183            .ok_or_else(|| InMemoryError::NotFound(key.to_string()))
184    }
185
186    /// Number of episodes currently stored.
187    pub async fn len(&self) -> Result<usize, InMemoryError> {
188        let inner = self.inner.lock().map_err(|_| InMemoryError::Poisoned)?;
189        Ok(inner.episodes.len())
190    }
191
192    /// Whether the store is empty.
193    pub async fn is_empty(&self) -> Result<bool, InMemoryError> {
194        Ok(self.len().await? == 0)
195    }
196}
197
198/// Token-overlap score in `[0, 1]`.
199///
200/// Returns the fraction of distinct normalized query tokens that appear
201/// in `summary`. Normalization is intentionally simple and
202/// deterministic: Unicode-aware lowercase via [`str::to_lowercase`] on
203/// each whitespace-delimited token, and trim leading/trailing
204/// non-alphanumeric characters (a Unicode-aware superset of ASCII
205/// punctuation). An empty query yields `0.0`.
206fn lexical_score(query: &str, summary: &str) -> f32 {
207    let query_tokens = normalized_tokens(query);
208    if query_tokens.is_empty() {
209        return 0.0;
210    }
211    let summary_tokens = normalized_tokens(summary);
212    let intersection = query_tokens.intersection(&summary_tokens).count() as f32;
213    intersection / query_tokens.len() as f32
214}
215
/// Break `input` into a set of comparison-ready tokens.
///
/// Each whitespace-delimited word has its leading and trailing
/// non-alphanumeric characters trimmed (interior punctuation such as the
/// hyphen in `scheduled-task` is kept) and is then lowercased. Words
/// that are empty after trimming are dropped, and duplicates collapse
/// because the result is a set.
fn normalized_tokens(input: &str) -> HashSet<String> {
    let mut tokens = HashSet::new();
    for word in input.split_whitespace() {
        let core = word.trim_matches(|c: char| !c.is_alphanumeric());
        if !core.is_empty() {
            tokens.insert(core.to_lowercase());
        }
    }
    tokens
}
224
#[cfg(test)]
// Tests unwrap freely: a failed unwrap is itself the test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Minimal episode whose payload is just its summary text.
    #[derive(Debug, Clone)]
    struct E(&'static str);

    impl Episode for E {
        fn summary(&self) -> &str {
            self.0
        }
    }

    #[tokio::test]
    async fn append_assigns_unique_ordered_keys() {
        let s = InMemoryStore::<E>::new();
        let k1 = s.append(E("a")).await.unwrap();
        let k2 = s.append(E("b")).await.unwrap();
        assert!(k1 < k2);
    }

    #[tokio::test]
    async fn retrieve_returns_top_k_by_score() {
        let s = InMemoryStore::<E>::new();
        // The weaker match goes in first, so the expected order below
        // can only be produced by sorting on score, not by insertion
        // order.
        s.append(E("powershell beacon")).await.unwrap();
        s.append(E("powershell maintenance scheduled task"))
            .await
            .unwrap();
        s.append(E("ddos amplification spike")).await.unwrap();

        let hits = s.retrieve_similar("powershell scheduled", 5).await.unwrap();
        let summaries: Vec<&str> = hits.iter().map(|h| h.episode.0).collect();
        assert_eq!(
            summaries,
            vec!["powershell maintenance scheduled task", "powershell beacon"]
        );

        // With k smaller than the number of matches only the best hit
        // survives truncation.
        let top = s.retrieve_similar("powershell scheduled", 1).await.unwrap();
        let top_summaries: Vec<&str> = top.iter().map(|h| h.episode.0).collect();
        assert_eq!(top_summaries, vec!["powershell maintenance scheduled task"]);
    }

    #[tokio::test]
    async fn retrieve_breaks_ties_by_insertion_order() {
        let s = InMemoryStore::<E>::new();
        let first = s.append(E("alpha one")).await.unwrap();
        let second = s.append(E("alpha two")).await.unwrap();
        // Both episodes cover the whole query, so their scores are equal.
        let hits = s.retrieve_similar("alpha", 5).await.unwrap();
        let keys: Vec<&str> = hits.iter().map(|h| h.key.as_str()).collect();
        assert_eq!(keys, vec![first.as_str(), second.as_str()]);
    }

    #[tokio::test]
    async fn retrieve_skips_zero_score_hits() {
        let s = InMemoryStore::<E>::new();
        s.append(E("alpha bravo")).await.unwrap();
        let hits = s.retrieve_similar("zulu", 5).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn retrieve_matches_case_insensitively() {
        let s = InMemoryStore::<E>::new();
        s.append(E("PowerShell scheduled task")).await.unwrap();
        let hits = s.retrieve_similar("powershell", 5).await.unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn retrieve_trims_simple_punctuation() {
        let s = InMemoryStore::<E>::new();
        s.append(E("powershell, scheduled-task beacon"))
            .await
            .unwrap();
        let hits = s
            .retrieve_similar("powershell scheduled-task", 5)
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn retrieve_handles_unicode_case_folding() {
        // Cyrillic case folding requires Unicode-aware lowercase.
        let s = InMemoryStore::<E>::new();
        s.append(E("ПОЛЬЗОВАТЕЛЬ logged in")).await.unwrap();
        let hits = s.retrieve_similar("пользователь", 5).await.unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn retrieve_trims_unicode_punctuation() {
        // The trailing 」 is Unicode punctuation, not ASCII; an
        // ASCII-only trim would leave it attached and miss the match.
        let s = InMemoryStore::<E>::new();
        s.append(E("「scheduled-task」 beacon")).await.unwrap();
        let hits = s.retrieve_similar("scheduled-task", 5).await.unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn get_returns_stored_episode() {
        let s = InMemoryStore::<E>::new();
        let key = s.append(E("alpha")).await.unwrap();
        assert_eq!(s.get(&key).await.unwrap().0, "alpha");
    }

    #[tokio::test]
    async fn get_returns_not_found_for_unknown_key() {
        let s = InMemoryStore::<E>::new();
        let err = s.get("nope").await.unwrap_err();
        assert!(matches!(err, InMemoryError::NotFound(_)));
    }

    #[tokio::test]
    async fn len_and_is_empty_track_inserts() {
        let s = InMemoryStore::<E>::new();
        assert!(s.is_empty().await.unwrap());
        s.append(E("x")).await.unwrap();
        assert_eq!(s.len().await.unwrap(), 1);
        assert!(!s.is_empty().await.unwrap());
    }

    #[tokio::test]
    async fn k_zero_returns_empty() {
        let s = InMemoryStore::<E>::new();
        s.append(E("alpha")).await.unwrap();
        assert!(s.retrieve_similar("alpha", 0).await.unwrap().is_empty());
    }
}
329}