rig_memory_policy/
inmem.rs1use std::collections::HashSet;
42use std::sync::{Arc, Mutex};
43
44use crate::PolicyError;
45
/// A record that can be appended to, and retrieved from, an [`InMemoryStore`].
///
/// Apart from cloning it, the store only reads an episode through [`Episode::summary`].
// NOTE(review): the `Send + Sync + 'static` bounds presumably exist so that an
// `Arc<InMemoryStore<_>>` can be shared across threads/tasks — confirm with callers.
pub trait Episode: 'static + Send + Sync + Clone {
    /// Text describing this episode.
    ///
    /// Retrieval splits this string on whitespace, trims punctuation from the edges of each
    /// word, lowercases it, and scores the resulting tokens against the query.
    fn summary(&self) -> &str;
}
57
/// A single match produced by `InMemoryStore::retrieve_similar`.
#[derive(Clone, Debug)]
pub struct InMemoryHit<E> {
    /// A clone of the stored episode that matched the query.
    pub episode: E,
    /// Lexical overlap score in `(0.0, 1.0]`; larger means more query tokens matched.
    pub score: f32,
    /// The key the episode was stored under, as returned by `InMemoryStore::append`.
    pub key: String,
}
68
69#[derive(Debug)]
75pub struct InMemoryStore<E: Episode> {
76 inner: Mutex<Inner<E>>,
77}
78
79#[derive(Debug)]
80struct Inner<E: Episode> {
81 next_key: u64,
82 episodes: Vec<(String, E)>,
83}
84
85impl<E: Episode> Default for InMemoryStore<E> {
86 fn default() -> Self {
87 Self {
88 inner: Mutex::new(Inner {
89 next_key: 0,
90 episodes: Vec::new(),
91 }),
92 }
93 }
94}
95
96impl<E: Episode> InMemoryStore<E> {
97 #[must_use]
100 pub fn new() -> Arc<Self> {
101 Arc::new(Self::default())
102 }
103
104 pub async fn append(&self, episode: E) -> Result<String, PolicyError> {
111 let mut inner = self.inner.lock().map_err(|_| PolicyError::Poisoned)?;
112 inner.next_key = inner.next_key.saturating_add(1);
113 let key = format!("ep-{:016x}", inner.next_key);
114 inner.episodes.push((key.clone(), episode));
115 Ok(key)
116 }
117
118 pub async fn retrieve_similar(
121 &self,
122 query: &str,
123 k: usize,
124 ) -> Result<Vec<InMemoryHit<E>>, PolicyError> {
125 if k == 0 {
126 return Ok(Vec::new());
127 }
128 let snapshot: Vec<(String, E)> = {
129 let inner = self.inner.lock().map_err(|_| PolicyError::Poisoned)?;
130 inner.episodes.clone()
131 };
132 let mut hits: Vec<InMemoryHit<E>> = snapshot
133 .into_iter()
134 .filter_map(|(key, episode)| {
135 let score = lexical_score(query, episode.summary());
136 if score > 0.0 {
137 Some(InMemoryHit {
138 episode,
139 score,
140 key,
141 })
142 } else {
143 None
144 }
145 })
146 .collect();
147 hits.sort_by(|a, b| {
148 b.score
149 .partial_cmp(&a.score)
150 .unwrap_or(std::cmp::Ordering::Equal)
151 });
152 hits.truncate(k);
153 Ok(hits)
154 }
155
156 pub async fn get(&self, key: &str) -> Result<E, PolicyError> {
158 let inner = self.inner.lock().map_err(|_| PolicyError::Poisoned)?;
159 inner
160 .episodes
161 .iter()
162 .find(|(stored_key, _)| stored_key == key)
163 .map(|(_, episode)| episode.clone())
164 .ok_or_else(|| PolicyError::NotFound(key.to_string()))
165 }
166
167 pub async fn len(&self) -> Result<usize, PolicyError> {
169 let inner = self.inner.lock().map_err(|_| PolicyError::Poisoned)?;
170 Ok(inner.episodes.len())
171 }
172
173 pub async fn is_empty(&self) -> Result<bool, PolicyError> {
175 Ok(self.len().await? == 0)
176 }
177}
178
179fn lexical_score(query: &str, summary: &str) -> f32 {
187 let query_tokens = normalized_tokens(query);
188 if query_tokens.is_empty() {
189 return 0.0;
190 }
191 let summary_tokens = normalized_tokens(summary);
192 let intersection = query_tokens.intersection(&summary_tokens).count() as f32;
193 intersection / query_tokens.len() as f32
194}
195
/// Breaks `input` into a set of lowercase tokens for lexical matching.
///
/// The steps are:
/// - Split on Unicode whitespace.
/// - Strip leading and trailing non-alphanumeric characters from each word.
/// - Discard words that end up empty.
/// - Lowercase the rest.
///
/// Interior punctuation survives, so `scheduled-task` stays one token. Duplicates collapse
/// because the result is a set.
fn normalized_tokens(input: &str) -> HashSet<String> {
    let mut tokens = HashSet::new();
    for word in input.split_whitespace() {
        let core = word.trim_matches(|c: char| !c.is_alphanumeric());
        if !core.is_empty() {
            tokens.insert(core.to_lowercase());
        }
    }
    tokens
}
204
#[cfg(test)]
// Test-only code: a failed `unwrap` or an explicit panic is how these tests report failure.
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Minimal episode whose summary is a static string.
    #[derive(Debug, Clone)]
    struct TestEpisode(&'static str);

    impl Episode for TestEpisode {
        fn summary(&self) -> &str {
            self.0
        }
    }

    #[test]
    fn append_assigns_unique_ordered_keys() {
        let store = InMemoryStore::<TestEpisode>::new();
        let key_one = pollster::block_on(store.append(TestEpisode("a"))).unwrap();
        let key_two = pollster::block_on(store.append(TestEpisode("b"))).unwrap();
        assert!(key_one < key_two);
    }

    #[test]
    fn retrieve_returns_top_k_by_score() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("powershell maintenance scheduled task")))
            .unwrap();
        pollster::block_on(store.append(TestEpisode("ddos amplification spike"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("powershell scheduled", 5)).unwrap();
        assert_eq!(hits.len(), 1);
        assert!(hits.first().unwrap().episode.0.contains("powershell"));
    }

    #[test]
    fn retrieve_orders_by_descending_score() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("powershell beacon"))).unwrap();
        pollster::block_on(store.append(TestEpisode("ddos amplification spike"))).unwrap();
        pollster::block_on(store.append(TestEpisode("powershell scheduled task"))).unwrap();
        pollster::block_on(store.append(TestEpisode("scheduled task runner"))).unwrap();
        let hits =
            pollster::block_on(store.retrieve_similar("powershell scheduled task", 5)).unwrap();
        let summaries: Vec<&str> = hits.iter().map(|hit| hit.episode.0).collect();
        assert_eq!(
            summaries,
            vec!["powershell scheduled task", "scheduled task runner", "powershell beacon"]
        );
        let top = hits.first().unwrap();
        assert!((top.score - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn retrieve_truncates_to_k_best() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("powershell beacon"))).unwrap();
        pollster::block_on(store.append(TestEpisode("powershell scheduled task"))).unwrap();
        pollster::block_on(store.append(TestEpisode("scheduled task runner"))).unwrap();
        let hits =
            pollster::block_on(store.retrieve_similar("powershell scheduled task", 2)).unwrap();
        let summaries: Vec<&str> = hits.iter().map(|hit| hit.episode.0).collect();
        assert_eq!(
            summaries,
            vec!["powershell scheduled task", "scheduled task runner"]
        );
    }

    #[test]
    fn retrieve_keeps_insertion_order_for_equal_scores() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("alpha one"))).unwrap();
        pollster::block_on(store.append(TestEpisode("alpha two"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("alpha", 5)).unwrap();
        let summaries: Vec<&str> = hits.iter().map(|hit| hit.episode.0).collect();
        assert_eq!(summaries, vec!["alpha one", "alpha two"]);
    }

    #[test]
    fn retrieve_skips_zero_score_hits() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("alpha bravo"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("zulu", 5)).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn retrieve_with_punctuation_only_query_returns_empty() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("alpha"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("!!! ---", 5)).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn retrieve_matches_case_insensitively() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("PowerShell scheduled task"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("powershell", 5)).unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn retrieve_trims_simple_punctuation() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("powershell, scheduled-task beacon"))).unwrap();
        let hits =
            pollster::block_on(store.retrieve_similar("powershell scheduled-task", 5)).unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn retrieve_handles_unicode_case_folding() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("ПОЛЬЗОВАТЕЛЬ logged in"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("пользователь", 5)).unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn retrieve_trims_unicode_punctuation() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("「scheduled-task」 beacon"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("scheduled-task", 5)).unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn hit_key_round_trips_through_get() {
        let store = InMemoryStore::<TestEpisode>::new();
        let key = pollster::block_on(store.append(TestEpisode("alpha bravo"))).unwrap();
        let hits = pollster::block_on(store.retrieve_similar("alpha", 1)).unwrap();
        let hit = hits.first().unwrap();
        assert_eq!(hit.key, key);
        let fetched = pollster::block_on(store.get(&hit.key)).unwrap();
        assert_eq!(fetched.0, "alpha bravo");
    }

    #[test]
    fn get_returns_not_found_for_unknown_key() {
        let store = InMemoryStore::<TestEpisode>::new();
        let err = pollster::block_on(store.get("nope")).unwrap_err();
        assert!(matches!(err, PolicyError::NotFound(_)));
    }

    #[test]
    fn len_and_is_empty_track_inserts() {
        let store = InMemoryStore::<TestEpisode>::new();
        assert!(pollster::block_on(store.is_empty()).unwrap());
        pollster::block_on(store.append(TestEpisode("x"))).unwrap();
        assert_eq!(pollster::block_on(store.len()).unwrap(), 1);
        assert!(!pollster::block_on(store.is_empty()).unwrap());
    }

    #[test]
    fn k_zero_returns_empty() {
        let store = InMemoryStore::<TestEpisode>::new();
        pollster::block_on(store.append(TestEpisode("alpha"))).unwrap();
        assert!(
            pollster::block_on(store.retrieve_similar("alpha", 0))
                .unwrap()
                .is_empty()
        );
    }
}
305}