Skip to main content

rig_memory_policy/
inmem.rs

//! In-memory append-only episode store with lexical retrieval.
//!
//! [`InMemoryStore`] is a no-disk reference memory backend: it keeps
//! episodes in a process-local `Vec` and ranks them with a deterministic
//! token-overlap score. It is intended for tests, examples, offline modes,
//! and adapters that need an [`Episode`]-shaped retrieval surface without
//! standing up a persistent store.
//!
//! Compared to durable backends, this store:
//!
//! - has no I/O, no file lock, and no embedder;
//! - returns at most `k` hits sorted by descending lexical score; and
//! - is generic over a user-defined [`Episode`] payload, so callers keep
//!   their own domain types without forcing them through a backend-specific
//!   serialization format.
//!
//! ```no_run
//! use rig_memory_policy::inmem::{Episode, InMemoryStore};
//!
//! #[derive(Clone)]
//! struct MyEpisode {
//!     summary: String,
//! }
//!
//! impl Episode for MyEpisode {
//!     fn summary(&self) -> &str { &self.summary }
//! }
//!
//! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
//! let store = InMemoryStore::<MyEpisode>::new();
//! let key = store
//!     .append(MyEpisode { summary: "scheduled maintenance".into() })
//!     .await?;
//! let hits = store.retrieve_similar("maintenance", 5).await?;
//! assert_eq!(hits.first().map(|h| h.episode.summary.as_str()),
//!            Some("scheduled maintenance"));
//! let _ = (key, hits);
//! # Ok(()) }
//! ```
40
41use std::collections::HashSet;
42use std::sync::{Arc, Mutex};
43
44use crate::PolicyError;
45
/// Caller-owned episode payload that exposes a searchable summary.
///
/// The store looks at nothing but [`Episode::summary`], so implementors are
/// free to carry whatever domain data they need alongside it.
pub trait Episode: 'static + Send + Sync + Clone {
    /// Natural-language text that queries are ranked against.
    ///
    /// The text is split on whitespace, each piece is trimmed of
    /// leading/trailing non-alphanumeric characters and lowercased, and a
    /// hit's score is the fraction of distinct normalized query tokens that
    /// also occur here.
    fn summary(&self) -> &str;
}
57
/// One ranked result produced by [`InMemoryStore::retrieve_similar`].
#[derive(Clone, Debug)]
pub struct InMemoryHit<E> {
    /// Clone of the matching episode as it was stored.
    pub episode: E,
    /// Token-overlap score in `[0, 1]`; larger means a closer match.
    pub score: f32,
    /// The key [`InMemoryStore::append`] returned for this episode.
    pub key: String,
}
68
69/// Append-only in-memory episode store with deterministic lexical
70/// retrieval.
71///
72/// `E` is the caller's episode payload; only [`Episode::summary`] is
73/// inspected during retrieval.
74#[derive(Debug)]
75pub struct InMemoryStore<E: Episode> {
76    inner: Mutex<Inner<E>>,
77}
78
79#[derive(Debug)]
80struct Inner<E: Episode> {
81    next_key: u64,
82    episodes: Vec<(String, E)>,
83}
84
85impl<E: Episode> Default for InMemoryStore<E> {
86    fn default() -> Self {
87        Self {
88            inner: Mutex::new(Inner {
89                next_key: 0,
90                episodes: Vec::new(),
91            }),
92        }
93    }
94}
95
96impl<E: Episode> InMemoryStore<E> {
97    /// Create a fresh store wrapped in [`Arc`] for cheap cloning across
98    /// agent tasks.
99    #[must_use]
100    pub fn new() -> Arc<Self> {
101        Arc::new(Self::default())
102    }
103
104    /// Append a new episode and return its storage key.
105    ///
106    /// Keys are stable across the lifetime of the store and follow the
107    /// `ep-{:016x}` template so they sort lexicographically by insertion
108    /// order. If the process appends more than `u64::MAX` episodes, the
109    /// key counter saturates and subsequent appends reuse the terminal key.
110    pub async fn append(&self, episode: E) -> Result<String, PolicyError> {
111        let mut inner = self.inner.lock().map_err(|_| PolicyError::Poisoned)?;
112        inner.next_key = inner.next_key.saturating_add(1);
113        let key = format!("ep-{:016x}", inner.next_key);
114        inner.episodes.push((key.clone(), episode));
115        Ok(key)
116    }
117
118    /// Return up to `k` episodes most similar to `query`, sorted by
119    /// descending lexical score. Hits with score `0.0` are skipped.
120    pub async fn retrieve_similar(
121        &self,
122        query: &str,
123        k: usize,
124    ) -> Result<Vec<InMemoryHit<E>>, PolicyError> {
125        if k == 0 {
126            return Ok(Vec::new());
127        }
128        let snapshot: Vec<(String, E)> = {
129            let inner = self.inner.lock().map_err(|_| PolicyError::Poisoned)?;
130            inner.episodes.clone()
131        };
132        let mut hits: Vec<InMemoryHit<E>> = snapshot
133            .into_iter()
134            .filter_map(|(key, episode)| {
135                let score = lexical_score(query, episode.summary());
136                if score > 0.0 {
137                    Some(InMemoryHit {
138                        episode,
139                        score,
140                        key,
141                    })
142                } else {
143                    None
144                }
145            })
146            .collect();
147        hits.sort_by(|a, b| {
148            b.score
149                .partial_cmp(&a.score)
150                .unwrap_or(std::cmp::Ordering::Equal)
151        });
152        hits.truncate(k);
153        Ok(hits)
154    }
155
156    /// Direct lookup by storage key.
157    pub async fn get(&self, key: &str) -> Result<E, PolicyError> {
158        let inner = self.inner.lock().map_err(|_| PolicyError::Poisoned)?;
159        inner
160            .episodes
161            .iter()
162            .find(|(stored_key, _)| stored_key == key)
163            .map(|(_, episode)| episode.clone())
164            .ok_or_else(|| PolicyError::NotFound(key.to_string()))
165    }
166
167    /// Number of episodes currently stored.
168    pub async fn len(&self) -> Result<usize, PolicyError> {
169        let inner = self.inner.lock().map_err(|_| PolicyError::Poisoned)?;
170        Ok(inner.episodes.len())
171    }
172
173    /// Whether the store is empty.
174    pub async fn is_empty(&self) -> Result<bool, PolicyError> {
175        Ok(self.len().await? == 0)
176    }
177}
178
/// Token-overlap score in `[0, 1]`.
///
/// The score is the share of distinct normalized query tokens that also
/// appear among the normalized tokens of `summary`. A query with no tokens
/// (empty, whitespace, or punctuation only) scores `0.0`.
fn lexical_score(query: &str, summary: &str) -> f32 {
    let wanted = normalized_tokens(query);
    let total = wanted.len();
    if total == 0 {
        return 0.0;
    }
    let present = normalized_tokens(summary);
    let matched = wanted
        .iter()
        .filter(|token| present.contains(*token))
        .count();
    matched as f32 / total as f32
}

/// Split `input` on whitespace into a set of normalized tokens.
///
/// Each whitespace-delimited piece first has leading and trailing
/// non-alphanumeric characters trimmed (interior ones such as the `-` in
/// `scheduled-task` are kept). Pieces that become empty are dropped, and the
/// rest are lowercased with [`str::to_lowercase`]. This is a lowercase
/// mapping, not full Unicode case folding; for example `ß` and `SS` do not
/// match.
fn normalized_tokens(input: &str) -> HashSet<String> {
    let mut tokens = HashSet::new();
    for raw in input.split_whitespace() {
        let trimmed = raw.trim_matches(|c: char| !c.is_alphanumeric());
        if !trimmed.is_empty() {
            tokens.insert(trimmed.to_lowercase());
        }
    }
    tokens
}
204
#[cfg(test)]
// Test code unwraps freely: any failure should abort the test with a panic
// rather than be propagated.
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[derive(Debug, Clone)]
    struct TestEpisode(&'static str);

    impl Episode for TestEpisode {
        fn summary(&self) -> &str {
            self.0
        }
    }

    #[test]
    fn append_assigns_unique_ordered_keys() {
        let store = InMemoryStore::<TestEpisode>::new();
        let key_one = pollster::block_on(store.append(TestEpisode("a"))).unwrap();
        let key_two = pollster::block_on(store.append(TestEpisode("b"))).unwrap();
        assert_eq!(key_one, "ep-0000000000000001");
        assert_eq!(key_two, "ep-0000000000000002");
        assert!(key_one < key_two);
    }

    #[test]
    fn retrieve_returns_top_k_by_score() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("powershell maintenance scheduled task")))
            .unwrap();
        pollster::block_on(store.append(TestEpisode("ddos amplification spike"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("powershell scheduled", 5)).unwrap();
        assert_eq!(hits.len(), 1);
        assert!(hits.first().unwrap().episode.0.contains("powershell"));
    }

    #[test]
    fn retrieve_orders_by_score_and_truncates_to_k() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("alpha"))).unwrap();
        pollster::block_on(store.append(TestEpisode("alpha bravo charlie"))).unwrap();
        pollster::block_on(store.append(TestEpisode("alpha bravo"))).unwrap();
        let hits =
            pollster::block_on(store.retrieve_similar("alpha bravo charlie", 2)).unwrap();
        let summaries: Vec<&str> = hits.iter().map(|hit| hit.episode.0).collect();
        assert_eq!(summaries, vec!["alpha bravo charlie", "alpha bravo"]);
        assert!(hits.iter().all(|hit| hit.score > 0.0 && hit.score <= 1.0));
        assert!(hits.first().unwrap().score > hits.last().unwrap().score);
    }

    #[test]
    fn retrieve_keeps_insertion_order_for_equal_scores() {
        let store = InMemoryStore::<TestEpisode>::new();
        let first = pollster::block_on(store.append(TestEpisode("beacon one"))).unwrap();
        let second = pollster::block_on(store.append(TestEpisode("beacon two"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("beacon", 5)).unwrap();
        let keys: Vec<&str> = hits.iter().map(|hit| hit.key.as_str()).collect();
        assert_eq!(keys, vec![first.as_str(), second.as_str()]);
    }

    #[test]
    fn retrieve_skips_zero_score_hits() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("alpha bravo"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("zulu", 5)).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn retrieve_with_blank_query_returns_nothing() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("alpha bravo"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("   ", 5)).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn retrieve_matches_case_insensitively() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("PowerShell scheduled task"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("powershell", 5)).unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn retrieve_trims_simple_punctuation() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("powershell, scheduled-task beacon"))).unwrap();
        let hits =
            pollster::block_on(store.retrieve_similar("powershell scheduled-task", 5)).unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn retrieve_handles_unicode_case_folding() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("ПОЛЬЗОВАТЕЛЬ logged in"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("пользователь", 5)).unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn retrieve_trims_unicode_punctuation() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("「scheduled-task」 beacon"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("scheduled-task", 5)).unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn get_returns_appended_episode() {
        let store = InMemoryStore::<TestEpisode>::new();
        let key = pollster::block_on(store.append(TestEpisode("alpha"))).unwrap();
        let episode = pollster::block_on(store.get(&key)).unwrap();
        assert_eq!(episode.0, "alpha");
    }

    #[test]
    fn get_returns_not_found_for_unknown_key() {
        let store = InMemoryStore::<TestEpisode>::new();
        let err = pollster::block_on(store.get("nope")).unwrap_err();
        assert!(matches!(err, PolicyError::NotFound(_)));
    }

    #[test]
    fn len_and_is_empty_track_inserts() {
        let store = InMemoryStore::<TestEpisode>::new();
        assert!(pollster::block_on(store.is_empty()).unwrap());
        pollster::block_on(store.append(TestEpisode("x"))).unwrap();
        assert_eq!(pollster::block_on(store.len()).unwrap(), 1);
        assert!(!pollster::block_on(store.is_empty()).unwrap());
    }

    #[test]
    fn k_zero_returns_empty() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("alpha"))).unwrap();
        assert!(
            pollster::block_on(store.retrieve_similar("alpha", 0))
                .unwrap()
                .is_empty()
        );
    }
}