use std::collections::HashSet;
use std::sync::{Arc, Mutex};
/// Errors returned by [`InMemoryStore`] operations.
#[derive(Debug, thiserror::Error)]
pub enum InMemoryError {
    /// No stored episode has the requested key; carries the key that was looked up.
    #[error("episode not found: {0}")]
    NotFound(String),
    /// Locking the store's internal `std::sync::Mutex` failed because it is poisoned
    /// (a thread panicked while holding it).
    #[error("in-memory store mutex poisoned")]
    Poisoned,
}
/// An item that [`InMemoryStore`] can hold and retrieve by its textual summary.
///
/// The store returns owned clones of episodes, hence `Clone`; the
/// `Send + Sync + 'static` bounds allow a store of episodes to be shared across
/// threads/tasks.
pub trait Episode: Clone + Send + Sync + 'static {
    /// Text that `retrieve_similar` tokenizes and scores against a query.
    fn summary(&self) -> &str;
}
/// One result returned by `InMemoryStore::retrieve_similar`.
#[derive(Debug, Clone)]
pub struct InMemoryHit<E> {
    /// Owned copy of the matching episode.
    pub episode: E,
    /// Fraction of distinct normalized query tokens that appear in the episode
    /// summary; always in `(0.0, 1.0]` because zero-score episodes are not returned.
    pub score: f32,
    /// Key that `InMemoryStore::append` assigned to the episode.
    pub key: String,
}
/// In-process episode store with naive lexical (token-overlap) retrieval.
///
/// All state lives behind a single `std::sync::Mutex`; operations fail with
/// [`InMemoryError::Poisoned`] if that mutex is poisoned.
#[derive(Debug)]
pub struct InMemoryStore<E: Episode> {
    inner: Mutex<Inner<E>>,
}
/// Mutex-protected state of an [`InMemoryStore`].
#[derive(Debug)]
struct Inner<E: Episode> {
    /// Counter value of the most recently issued key (despite the name, `append`
    /// increments it *before* formatting); 0 means no key has been issued yet.
    next_key: u64,
    /// `(key, episode)` pairs in insertion order.
    episodes: Vec<(String, E)>,
}
impl<E: Episode> Default for InMemoryStore<E> {
    /// Builds an empty store: no episodes, key counter at zero.
    fn default() -> Self {
        let empty = Inner {
            episodes: Vec::new(),
            next_key: 0,
        };
        Self {
            inner: Mutex::new(empty),
        }
    }
}
impl<E: Episode> InMemoryStore<E> {
    /// Creates an empty store wrapped in an [`Arc`] so it can be shared.
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }

    /// Stores `episode` and returns the key assigned to it.
    ///
    /// Keys are `ep-` followed by the store's counter as 16 zero-padded lowercase
    /// hex digits, starting at 1, so keys from one store sort in insertion order.
    ///
    /// # Errors
    ///
    /// Returns [`InMemoryError::Poisoned`] if the internal mutex is poisoned.
    pub async fn append(&self, episode: E) -> Result<String, InMemoryError> {
        let mut inner = self.inner.lock().map_err(|_| InMemoryError::Poisoned)?;
        // Saturating rather than wrapping; u64::MAX appends cannot fit in memory anyway.
        inner.next_key = inner.next_key.saturating_add(1);
        let key = format!("ep-{:016x}", inner.next_key);
        inner.episodes.push((key.clone(), episode));
        Ok(key)
    }

    /// Returns up to `k` episodes sharing at least one normalized token with `query`,
    /// ordered by descending score.
    ///
    /// The score (see `lexical_score`) is the fraction of distinct query tokens found
    /// in the episode summary. Episodes scoring `0.0` are omitted, and equal scores
    /// keep insertion order. `k == 0` returns an empty vector without locking.
    ///
    /// # Errors
    ///
    /// Returns [`InMemoryError::Poisoned`] if the internal mutex is poisoned.
    pub async fn retrieve_similar(
        &self,
        query: &str,
        k: usize,
    ) -> Result<Vec<InMemoryHit<E>>, InMemoryError> {
        if k == 0 {
            return Ok(Vec::new());
        }
        let inner = self.inner.lock().map_err(|_| InMemoryError::Poisoned)?;
        // Score by reference and clone only the matches, instead of deep-copying
        // every key and episode in the store on each query.
        let mut hits: Vec<InMemoryHit<E>> = inner
            .episodes
            .iter()
            .filter_map(|(key, ep)| {
                let score = lexical_score(query, ep.summary());
                (score > 0.0).then(|| InMemoryHit {
                    episode: ep.clone(),
                    score,
                    key: key.clone(),
                })
            })
            .collect();
        // Release the lock before sorting; the hits are owned copies.
        drop(inner);
        // Scores are finite ratios, so `total_cmp` is plain numeric order here;
        // `sort_by` is stable, so ties stay in insertion order.
        hits.sort_by(|a, b| b.score.total_cmp(&a.score));
        hits.truncate(k);
        Ok(hits)
    }

    /// Returns a clone of the episode stored under `key`.
    ///
    /// # Errors
    ///
    /// Returns [`InMemoryError::NotFound`] if no episode has that key, or
    /// [`InMemoryError::Poisoned`] if the internal mutex is poisoned.
    pub async fn get(&self, key: &str) -> Result<E, InMemoryError> {
        let inner = self.inner.lock().map_err(|_| InMemoryError::Poisoned)?;
        inner
            .episodes
            .iter()
            .find(|(k, _)| k == key)
            .map(|(_, ep)| ep.clone())
            .ok_or_else(|| InMemoryError::NotFound(key.to_string()))
    }

    /// Returns the number of stored episodes.
    ///
    /// # Errors
    ///
    /// Returns [`InMemoryError::Poisoned`] if the internal mutex is poisoned.
    pub async fn len(&self) -> Result<usize, InMemoryError> {
        let inner = self.inner.lock().map_err(|_| InMemoryError::Poisoned)?;
        Ok(inner.episodes.len())
    }

    /// Returns `true` if the store holds no episodes.
    ///
    /// # Errors
    ///
    /// Returns [`InMemoryError::Poisoned`] if the internal mutex is poisoned.
    pub async fn is_empty(&self) -> Result<bool, InMemoryError> {
        Ok(self.len().await? == 0)
    }
}
/// Scores how much of `query` is covered by `summary`, on a `0.0..=1.0` scale.
///
/// Both strings are tokenized with `normalized_tokens`. The result is the number of
/// distinct query tokens present in the summary divided by the number of distinct
/// query tokens. A query that yields no tokens scores `0.0`, and the summary is not
/// tokenized in that case.
fn lexical_score(query: &str, summary: &str) -> f32 {
    let wanted = normalized_tokens(query);
    match wanted.len() {
        0 => 0.0,
        total => {
            let present = normalized_tokens(summary);
            let shared = wanted.iter().filter(|tok| present.contains(*tok)).count();
            shared as f32 / total as f32
        }
    }
}
/// Splits `input` on Unicode whitespace into a set of lowercase tokens.
///
/// Leading and trailing non-alphanumeric characters are stripped from each word.
/// Interior punctuation, such as the hyphen in `scheduled-task`, is kept. Words that
/// become empty after stripping are dropped, and duplicates collapse into one entry.
fn normalized_tokens(input: &str) -> HashSet<String> {
    let is_edge_noise = |c: char| !c.is_alphanumeric();
    let mut tokens = HashSet::new();
    for word in input.split_whitespace() {
        let core = word.trim_matches(is_edge_noise);
        if !core.is_empty() {
            tokens.insert(core.to_lowercase());
        }
    }
    tokens
}
#[cfg(test)]
// Unwrapping is acceptable here: a failed store operation should abort the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Minimal episode whose summary is a static string.
    #[derive(Debug, Clone)]
    struct E(&'static str);

    impl Episode for E {
        fn summary(&self) -> &str {
            self.0
        }
    }

    #[tokio::test]
    async fn append_assigns_unique_ordered_keys() {
        let s = InMemoryStore::<E>::new();
        let k1 = s.append(E("a")).await.unwrap();
        let k2 = s.append(E("b")).await.unwrap();
        assert!(k1 < k2);
    }

    #[tokio::test]
    async fn retrieve_returns_top_k_by_score() {
        let s = InMemoryStore::<E>::new();
        s.append(E("powershell maintenance scheduled task"))
            .await
            .unwrap();
        s.append(E("ddos amplification spike")).await.unwrap();
        let hits = s.retrieve_similar("powershell scheduled", 5).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert!(hits.first().unwrap().episode.0.contains("powershell"));
    }

    #[tokio::test]
    async fn retrieve_orders_by_score_and_truncates_to_k() {
        let s = InMemoryStore::<E>::new();
        s.append(E("alpha delta")).await.unwrap();
        s.append(E("alpha bravo charlie")).await.unwrap();
        s.append(E("alpha bravo")).await.unwrap();
        let hits = s.retrieve_similar("alpha bravo", 2).await.unwrap();
        let summaries: Vec<&str> = hits.iter().map(|h| h.episode.0).collect();
        // Both full matches outrank the half match; equal scores keep insertion order.
        assert_eq!(summaries, vec!["alpha bravo charlie", "alpha bravo"]);
        assert!(hits.iter().all(|h| (h.score - 1.0).abs() < f32::EPSILON));
    }

    #[tokio::test]
    async fn retrieve_skips_zero_score_hits() {
        let s = InMemoryStore::<E>::new();
        s.append(E("alpha bravo")).await.unwrap();
        let hits = s.retrieve_similar("zulu", 5).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn retrieve_with_blank_query_returns_nothing() {
        let s = InMemoryStore::<E>::new();
        s.append(E("alpha")).await.unwrap();
        assert!(s.retrieve_similar("  ...  ", 5).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn retrieve_matches_case_insensitively() {
        let s = InMemoryStore::<E>::new();
        s.append(E("PowerShell scheduled task")).await.unwrap();
        let hits = s.retrieve_similar("powershell", 5).await.unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn retrieve_trims_simple_punctuation() {
        let s = InMemoryStore::<E>::new();
        s.append(E("powershell, scheduled-task beacon"))
            .await
            .unwrap();
        let hits = s
            .retrieve_similar("powershell scheduled-task", 5)
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn retrieve_handles_unicode_case_folding() {
        let s = InMemoryStore::<E>::new();
        s.append(E("ПОЛЬЗОВАТЕЛЬ logged in")).await.unwrap();
        let hits = s.retrieve_similar("пользователь", 5).await.unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn retrieve_trims_unicode_punctuation() {
        let s = InMemoryStore::<E>::new();
        s.append(E("「scheduled-task」 beacon")).await.unwrap();
        let hits = s.retrieve_similar("scheduled-task", 5).await.unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn get_returns_appended_episode() {
        let s = InMemoryStore::<E>::new();
        let key = s.append(E("alpha")).await.unwrap();
        s.append(E("bravo")).await.unwrap();
        assert_eq!(s.get(&key).await.unwrap().0, "alpha");
    }

    #[tokio::test]
    async fn get_returns_not_found_for_unknown_key() {
        let s = InMemoryStore::<E>::new();
        let err = s.get("nope").await.unwrap_err();
        assert!(matches!(err, InMemoryError::NotFound(_)));
    }

    #[tokio::test]
    async fn len_and_is_empty_track_inserts() {
        let s = InMemoryStore::<E>::new();
        assert!(s.is_empty().await.unwrap());
        s.append(E("x")).await.unwrap();
        assert_eq!(s.len().await.unwrap(), 1);
        assert!(!s.is_empty().await.unwrap());
    }

    #[tokio::test]
    async fn k_zero_returns_empty() {
        let s = InMemoryStore::<E>::new();
        s.append(E("alpha")).await.unwrap();
        assert!(s.retrieve_similar("alpha", 0).await.unwrap().is_empty());
    }
}