use std::{fmt, sync::Arc};
use moka::sync::Cache;
use crate::{data::error_model::AssayErrorModels, simulator::likelihood::SubjectPredictions};
/// Key type for `PredictionCache`: a pair of `u64` identifiers.
// NOTE(review): presumably (subject hash, parameter hash) — confirm at call sites.
pub(crate) type PredictionKey = (u64, u64);
/// Key type for `SdeLikelihoodCache`: a triple of `u64` identifiers.
pub(crate) type SdeKey = (u64, u64, u64);
/// Key type for `BoundErrorModelCache`: a single `u64` identifier.
pub(crate) type BoundErrorModelKey = u64;
/// Default maximum capacity, in entries, for the larger caches in this module.
// NOTE(review): presumably passed to `PredictionCache::new` / `SdeLikelihoodCache::new`
// by callers — confirm.
pub const DEFAULT_CACHE_SIZE: u64 = 100_000;
/// Default maximum capacity, in entries, for `BoundErrorModelCache`.
pub const DEFAULT_BOUND_ERROR_MODEL_CACHE_SIZE: u64 = 32;
/// Concurrent cache of [`SubjectPredictions`], keyed by a `PredictionKey`.
///
/// This is a thin newtype over `moka::sync::Cache`. Cloning the wrapper does
/// not copy any entries. Every clone refers to the same underlying cache, so
/// inserts made through one handle are visible through all the others.
#[derive(Clone)]
pub struct PredictionCache(Cache<PredictionKey, SubjectPredictions>);

impl PredictionCache {
    /// Builds an empty cache.
    ///
    /// `size` is handed to `moka::sync::Cache::new` as the maximum capacity.
    pub fn new(size: u64) -> Self {
        let inner = Cache::new(size);
        Self(inner)
    }

    /// Looks up `key`.
    ///
    /// Returns a clone of the stored predictions on a hit and `None` on a miss.
    #[inline]
    pub fn get(&self, key: &PredictionKey) -> Option<SubjectPredictions> {
        let Self(inner) = self;
        inner.get(key)
    }

    /// Stores `value` under `key`, replacing any previous value for that key.
    #[inline]
    pub fn insert(&self, key: PredictionKey, value: SubjectPredictions) {
        let Self(inner) = self;
        inner.insert(key, value);
    }

    /// Discards every entry currently held by the cache.
    pub fn invalidate_all(&self) {
        let Self(inner) = self;
        inner.invalidate_all();
    }

    /// Returns the number of entries the cache reports.
    ///
    /// The count can lag behind recent inserts and invalidations until the
    /// cache's pending maintenance work has run. The tests in this file call
    /// `run_pending_tasks` before asserting on it.
    pub fn entry_count(&self) -> u64 {
        let Self(inner) = self;
        inner.entry_count()
    }
}
impl fmt::Debug for PredictionCache {
    // Only the entry count is rendered; the cached predictions themselves are
    // not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let count = self.0.entry_count();
        f.debug_struct("PredictionCache")
            .field("entry_count", &count)
            .finish()
    }
}
/// Concurrent cache of shared [`AssayErrorModels`], keyed by a `u64`.
///
/// Values are stored as `Arc`s, so a hit hands back another handle to the
/// same allocation rather than a deep copy of the models. Cloning the wrapper
/// shares the underlying `moka::sync::Cache` between all clones.
#[derive(Clone)]
pub struct BoundErrorModelCache(Cache<BoundErrorModelKey, Arc<AssayErrorModels>>);

impl BoundErrorModelCache {
    /// Builds an empty cache.
    ///
    /// `size` is the maximum capacity forwarded to `moka::sync::Cache::new`.
    pub fn new(size: u64) -> Self {
        let backing: Cache<BoundErrorModelKey, Arc<AssayErrorModels>> = Cache::new(size);
        Self(backing)
    }

    /// Returns the `Arc` stored under `key`, or `None` if it is absent.
    #[inline]
    pub fn get(&self, key: &BoundErrorModelKey) -> Option<Arc<AssayErrorModels>> {
        let backing = &self.0;
        backing.get(key)
    }

    /// Stores `value` under `key`, replacing any previous value for that key.
    #[inline]
    pub fn insert(&self, key: BoundErrorModelKey, value: Arc<AssayErrorModels>) {
        let backing = &self.0;
        backing.insert(key, value);
    }

    /// Discards every entry currently held by the cache.
    pub fn invalidate_all(&self) {
        let backing = &self.0;
        backing.invalidate_all();
    }

    /// Returns the number of entries the cache reports.
    ///
    /// This may lag behind recent writes until pending maintenance has run.
    pub fn entry_count(&self) -> u64 {
        let backing = &self.0;
        backing.entry_count()
    }
}
impl fmt::Debug for BoundErrorModelCache {
    // Summarise the cache by its entry count; the error models are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("BoundErrorModelCache");
        builder.field("entry_count", &self.0.entry_count());
        builder.finish()
    }
}
/// Concurrent cache mapping an `SdeKey` triple to an `f64` value.
///
/// NOTE(review): the values appear to be likelihoods (hence the name), and the
/// tests store negative numbers. Whether they are log-likelihoods should be
/// confirmed at the call sites.
///
/// Clones share the same underlying `moka::sync::Cache`.
#[derive(Clone)]
pub struct SdeLikelihoodCache(Cache<SdeKey, f64>);

impl SdeLikelihoodCache {
    /// Builds an empty cache.
    ///
    /// `size` is the maximum capacity given to `moka::sync::Cache::new`.
    pub fn new(size: u64) -> Self {
        Self(Cache::<SdeKey, f64>::new(size))
    }

    /// Returns the value cached for `key`, or `None` on a miss.
    #[inline]
    pub fn get(&self, key: &SdeKey) -> Option<f64> {
        let Self(store) = self;
        store.get(key)
    }

    /// Stores `value` under `key`, overwriting any earlier value for that key.
    #[inline]
    pub fn insert(&self, key: SdeKey, value: f64) {
        let Self(store) = self;
        store.insert(key, value);
    }

    /// Discards every entry currently held by the cache.
    pub fn invalidate_all(&self) {
        let Self(store) = self;
        store.invalidate_all();
    }

    /// Returns the number of entries the cache reports.
    ///
    /// This may lag behind recent writes until pending maintenance has run.
    pub fn entry_count(&self) -> u64 {
        let Self(store) = self;
        store.entry_count()
    }
}
impl fmt::Debug for SdeLikelihoodCache {
    // Report only how many entries are cached, not the cached values.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let entries = self.0.entry_count();
        let mut out = f.debug_struct("SdeLikelihoodCache");
        out.field("entry_count", &entries);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ----------------------------------------------------------------------
    // PredictionCache
    // ----------------------------------------------------------------------

    #[test]
    fn prediction_cache_miss_returns_none() {
        let cache = PredictionCache::new(10);
        assert!(cache.get(&(1, 2)).is_none());
    }

    #[test]
    fn prediction_cache_hit_returns_value() {
        let cache = PredictionCache::new(10);
        let key: PredictionKey = (42, 99);
        // The value is not used after insertion, so move it in rather than cloning.
        cache.insert(key, SubjectPredictions::default());
        assert!(cache.get(&key).is_some());
    }

    #[test]
    fn prediction_cache_entry_count() {
        let cache = PredictionCache::new(10);
        assert_eq!(cache.entry_count(), 0);
        cache.insert((1, 1), SubjectPredictions::default());
        cache.insert((2, 2), SubjectPredictions::default());
        // entry_count is only up to date once pending maintenance has run.
        cache.0.run_pending_tasks();
        assert_eq!(cache.entry_count(), 2);
    }

    #[test]
    fn prediction_cache_invalidate_all_clears_entries() {
        let cache = PredictionCache::new(10);
        cache.insert((1, 1), SubjectPredictions::default());
        cache.insert((2, 2), SubjectPredictions::default());
        cache.0.run_pending_tasks();
        assert_eq!(cache.entry_count(), 2);
        cache.invalidate_all();
        cache.0.run_pending_tasks();
        assert_eq!(cache.entry_count(), 0);
        assert!(cache.get(&(1, 1)).is_none());
    }

    #[test]
    fn prediction_cache_overwrite_same_key() {
        let cache = PredictionCache::new(10);
        let key: PredictionKey = (1, 1);
        cache.insert(key, SubjectPredictions::default());
        cache.insert(key, SubjectPredictions::default());
        cache.0.run_pending_tasks();
        assert_eq!(cache.entry_count(), 1);
    }

    #[test]
    fn prediction_cache_clone_shares_data() {
        let cache = PredictionCache::new(10);
        cache.insert((1, 1), SubjectPredictions::default());
        let clone = cache.clone();
        assert!(clone.get(&(1, 1)).is_some());
        clone.insert((2, 2), SubjectPredictions::default());
        assert!(cache.get(&(2, 2)).is_some());
    }

    #[test]
    fn prediction_cache_debug_format() {
        let cache = PredictionCache::new(10);
        let dbg = format!("{:?}", cache);
        assert!(dbg.contains("PredictionCache"));
        assert!(dbg.contains("entry_count"));
    }

    // ----------------------------------------------------------------------
    // BoundErrorModelCache
    // ----------------------------------------------------------------------

    #[test]
    fn bound_error_model_cache_miss_returns_none() {
        let cache = BoundErrorModelCache::new(10);
        assert!(cache.get(&7).is_none());
    }

    #[test]
    fn bound_error_model_cache_hit_returns_same_arc() {
        let cache = BoundErrorModelCache::new(10);
        let models = Arc::new(AssayErrorModels::empty());
        cache.insert(3, Arc::clone(&models));
        let hit = cache.get(&3).expect("entry was just inserted");
        // A hit must hand back the same allocation, not a deep copy.
        assert!(Arc::ptr_eq(&hit, &models));
    }

    #[test]
    fn bound_error_model_cache_clone_shares_data() {
        let cache = BoundErrorModelCache::new(10);
        let models = Arc::new(AssayErrorModels::empty());
        cache.insert(7, Arc::clone(&models));
        let clone = cache.clone();
        let hit = clone.get(&7).expect("clone must see entries inserted via the original");
        assert!(Arc::ptr_eq(&hit, &models));
        clone.insert(8, Arc::clone(&models));
        assert!(cache.get(&8).is_some());
    }

    #[test]
    fn bound_error_model_cache_invalidate_all_clears_entries() {
        let cache = BoundErrorModelCache::new(10);
        cache.insert(1, Arc::new(AssayErrorModels::empty()));
        cache.insert(2, Arc::new(AssayErrorModels::empty()));
        cache.0.run_pending_tasks();
        assert_eq!(cache.entry_count(), 2);
        cache.invalidate_all();
        cache.0.run_pending_tasks();
        assert_eq!(cache.entry_count(), 0);
        assert!(cache.get(&1).is_none());
    }

    #[test]
    fn bound_error_model_cache_debug_format() {
        let cache = BoundErrorModelCache::new(10);
        let dbg = format!("{:?}", cache);
        assert!(dbg.contains("BoundErrorModelCache"));
        assert!(dbg.contains("entry_count"));
    }

    // ----------------------------------------------------------------------
    // SdeLikelihoodCache
    // ----------------------------------------------------------------------

    #[test]
    fn sde_cache_miss_returns_none() {
        let cache = SdeLikelihoodCache::new(10);
        assert!(cache.get(&(1, 2, 3)).is_none());
    }

    #[test]
    fn sde_cache_hit_returns_value() {
        let cache = SdeLikelihoodCache::new(10);
        let key: SdeKey = (10, 20, 30);
        cache.insert(key, -42.5);
        assert_eq!(cache.get(&key), Some(-42.5));
    }

    #[test]
    fn sde_cache_entry_count() {
        let cache = SdeLikelihoodCache::new(10);
        cache.insert((1, 1, 1), 0.0);
        cache.insert((2, 2, 2), 1.0);
        cache.0.run_pending_tasks();
        assert_eq!(cache.entry_count(), 2);
    }

    #[test]
    fn sde_cache_invalidate_all_clears_entries() {
        let cache = SdeLikelihoodCache::new(10);
        cache.insert((1, 1, 1), 0.0);
        cache.insert((2, 2, 2), 1.0);
        cache.0.run_pending_tasks();
        assert_eq!(cache.entry_count(), 2);
        cache.invalidate_all();
        cache.0.run_pending_tasks();
        assert_eq!(cache.entry_count(), 0);
        assert!(cache.get(&(1, 1, 1)).is_none());
    }

    #[test]
    fn sde_cache_overwrite_same_key() {
        let cache = SdeLikelihoodCache::new(10);
        let key: SdeKey = (1, 1, 1);
        cache.insert(key, 1.0);
        cache.insert(key, 2.0);
        cache.0.run_pending_tasks();
        assert_eq!(cache.entry_count(), 1);
        assert_eq!(cache.get(&key), Some(2.0));
    }

    #[test]
    fn sde_cache_clone_shares_data() {
        let cache = SdeLikelihoodCache::new(10);
        cache.insert((1, 1, 1), 5.0);
        let clone = cache.clone();
        assert_eq!(clone.get(&(1, 1, 1)), Some(5.0));
        clone.insert((2, 2, 2), 10.0);
        assert_eq!(cache.get(&(2, 2, 2)), Some(10.0));
    }

    #[test]
    fn sde_cache_debug_format() {
        let cache = SdeLikelihoodCache::new(10);
        let dbg = format!("{:?}", cache);
        assert!(dbg.contains("SdeLikelihoodCache"));
        assert!(dbg.contains("entry_count"));
    }
}