use lru::LruCache;
use rusmes_proto::MessageId;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
/// Default maximum number of entries the result cache retains before LRU eviction.
pub const DEFAULT_CAPACITY: usize = 256;
/// Cache key: `(normalized query text, user identifier)` — the user component is
/// `""` for user-less lookups (see `ResultCache::make_key`).
pub type CacheKey = (String, String);
/// A cached result set tagged with the cache version it was written under.
#[derive(Clone, Debug)]
struct CacheValue {
    // Message ids produced for the cached query.
    ids: Vec<MessageId>,
    // Snapshot of `ResultCache::version` at write time; an entry whose tag
    // lags the current counter is considered stale and is evicted on lookup.
    version: u64,
}
/// Thread-safe LRU cache of query results with O(1) whole-cache invalidation:
/// instead of clearing the map, `invalidate_all` bumps a version counter and
/// stale entries are dropped lazily when next looked up.
pub struct ResultCache {
    // LRU map guarded by a mutex; all map access goes through this lock.
    inner: Mutex<LruCache<CacheKey, CacheValue>>,
    // Monotonically increasing invalidation counter; entries tagged with an
    // older value are treated as misses.
    version: AtomicU64,
}
impl ResultCache {
pub fn new_default() -> Self {
let cap = NonZeroUsize::new(DEFAULT_CAPACITY).unwrap_or(NonZeroUsize::MIN);
Self::with_capacity(cap)
}
pub fn with_capacity(cap: NonZeroUsize) -> Self {
Self {
inner: Mutex::new(LruCache::new(cap)),
version: AtomicU64::new(0),
}
}
pub fn normalize_query(query: &str) -> String {
let lower = query.to_lowercase();
lower.split_whitespace().collect::<Vec<_>>().join(" ")
}
pub fn make_key(query: &str, user: Option<&str>) -> CacheKey {
(Self::normalize_query(query), user.unwrap_or("").to_string())
}
pub fn get(&self, key: &CacheKey) -> Option<Vec<MessageId>> {
let current = self.version.load(Ordering::Acquire);
let mut guard = match self.inner.lock() {
Ok(g) => g,
Err(poisoned) => poisoned.into_inner(),
};
let value = guard.get(key)?;
if value.version == current {
Some(value.ids.clone())
} else {
guard.pop(key);
None
}
}
pub fn put(&self, key: CacheKey, ids: Vec<MessageId>) {
let current = self.version.load(Ordering::Acquire);
let mut guard = match self.inner.lock() {
Ok(g) => g,
Err(poisoned) => poisoned.into_inner(),
};
guard.put(
key,
CacheValue {
ids,
version: current,
},
);
}
pub fn invalidate_all(&self) {
self.version.fetch_add(1, Ordering::AcqRel);
}
pub fn len(&self) -> usize {
let guard = match self.inner.lock() {
Ok(g) => g,
Err(poisoned) => poisoned.into_inner(),
};
guard.len()
}
pub fn version(&self) -> u64 {
self.version.load(Ordering::Acquire)
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl Default for ResultCache {
fn default() -> Self {
Self::new_default()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rusmes_proto::MessageId;

    #[test]
    fn normalize_lowercases_and_collapses_whitespace() {
        // Mixed case plus tabs, newlines, and padding collapse to single spaces.
        assert_eq!(
            ResultCache::normalize_query(" Hello WORLD\t\nfoo "),
            "hello world foo"
        );
    }

    #[test]
    fn put_get_roundtrip_returns_ids() {
        let cache = ResultCache::new_default();
        let key = ResultCache::make_key("hello world", Some("alice"));
        let stored = vec![MessageId::new(), MessageId::new()];
        cache.put(key.clone(), stored.clone());
        // A fresh put must be returned verbatim on the next get.
        assert_eq!(cache.get(&key), Some(stored));
    }

    #[test]
    fn invalidate_all_makes_existing_entries_stale() {
        let cache = ResultCache::new_default();
        let key = ResultCache::make_key("q", None);
        cache.put(key.clone(), vec![MessageId::new()]);
        // Entry is visible before the invalidation...
        assert!(cache.get(&key).is_some());
        cache.invalidate_all();
        // ...and reads as a miss afterwards.
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn key_is_user_aware() {
        let cache = ResultCache::new_default();
        let key_for_alice = ResultCache::make_key("foo", Some("alice"));
        let key_for_bob = ResultCache::make_key("foo", Some("bob"));
        cache.put(key_for_alice.clone(), vec![MessageId::new()]);
        // The same query text under another user must not hit alice's entry.
        assert!(cache.get(&key_for_alice).is_some());
        assert!(cache.get(&key_for_bob).is_none());
    }

    #[test]
    fn make_key_normalizes_query_text() {
        // Keys that differ only in query casing produce identical tuples.
        assert_eq!(
            ResultCache::make_key("Hello World", Some("u")),
            ResultCache::make_key("hello world", Some("u"))
        );
    }
}