remcached 0.3.0

A caching system designed for efficient storage and retrieval of entities from remote repositories (REST APIs, databases, etc.).
Documentation
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::collections::hash_set::Iter;
use std::sync::{Arc, mpsc};
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering::Relaxed;
use std::thread;
use std::time::Instant;

use crate::cache_manager_config::CacheManagerConfig;
use crate::cache_task::CacheTask;
use crate::errors::ManualOperationError;
use crate::async_executor::{AsyncExecutor, AsyncTask};
use crate::metrics::measure;
use crate::r_cache::{CacheTaskProcessor, InputValidator};

/// Counters published by the cache manager worker thread.
///
/// Every counter is an `AtomicU64` accessed with `Relaxed` ordering: each value
/// is individually consistent, but no ordering between different counters is implied.
#[derive(Default, Debug)]
pub struct CacheManagerStatistics {
    /// Distinct tasks pushed into the scheduling queue (accumulated).
    tasks_total: AtomicU64,
    /// Tasks that replaced an equal, already-pending task (accumulated).
    merged_tasks_total: AtomicU64,
    /// Snapshot of the sync task queue length after the last cycle.
    pending_tasks_total: AtomicU64,
    /// Snapshot of in-flight async tasks after the last cycle.
    pending_async_tasks_total: AtomicU64,
    /// Queued tasks dropped because they expired before execution (accumulated).
    expired_tasks_total: AtomicU64,
    /// Async tasks aborted because they expired (accumulated).
    expired_async_tasks_total: AtomicU64,
    /// Worker cycles; presumably incremented by `metrics::measure` — confirm.
    cycles_total: AtomicU64,
    /// Presumably cumulative worker cycle time in ms, fed by `metrics::measure` — confirm.
    cycles_time_ms: AtomicU64,
}

impl CacheManagerStatistics {
    /// Builds a statistics set with every counter at zero.
    fn init() -> Self {
        // The derived `Default` zeroes each `AtomicU64`, exactly like a
        // field-by-field construction would.
        Self::default()
    }

    /// Builds a zeroed statistics set wrapped in an `Arc` so it can be shared
    /// between the manager and its worker thread.
    pub fn share_init() -> Arc<Self> {
        let stats = Self::init();
        Arc::new(stats)
    }

    /// Single place where counters are read; all reads use `Relaxed` ordering.
    fn load_counter(counter: &AtomicU64) -> u64 {
        counter.load(Relaxed)
    }

    pub fn tasks_total(&self) -> u64 {
        Self::load_counter(&self.tasks_total)
    }

    pub fn merged_tasks_total(&self) -> u64 {
        Self::load_counter(&self.merged_tasks_total)
    }

    pub fn pending_tasks_total(&self) -> u64 {
        Self::load_counter(&self.pending_tasks_total)
    }

    pub fn pending_async_tasks_total(&self) -> u64 {
        Self::load_counter(&self.pending_async_tasks_total)
    }

    pub fn expired_tasks_total(&self) -> u64 {
        Self::load_counter(&self.expired_tasks_total)
    }

    pub fn expired_async_tasks_total(&self) -> u64 {
        Self::load_counter(&self.expired_async_tasks_total)
    }

    pub fn cycles_total(&self) -> u64 {
        Self::load_counter(&self.cycles_total)
    }

    pub fn cycles_time_ms(&self) -> u64 {
        Self::load_counter(&self.cycles_time_ms)
    }
}


pub struct CacheManagerInner<E>
where
    E: AsyncExecutor,
{
    tasks: BinaryHeap<CacheTask>,
    async_tasks: Vec<E::Task>,
    task_processors: HashMap<&'static str, Box<dyn CacheTaskProcessor>>,
    config: CacheManagerConfig,
    async_executor: E,
    rx: mpsc::Receiver<CacheTask>,
}

impl<E> CacheManagerInner<E>

where
    E: AsyncExecutor,
{
    pub fn new(async_executor: E, rx: mpsc::Receiver<CacheTask>, config: CacheManagerConfig) -> Self {
        Self { tasks: Default::default(), async_tasks: vec![], task_processors: Default::default(), config, async_executor, rx }
    }
    pub fn register<T>(&mut self, task_processor: T)
    where
        T: CacheTaskProcessor + 'static,
    {
        if self.task_processors.contains_key(task_processor.cache_id()) {
            panic!("#{} cache currently registered!", task_processor.cache_id());
        }

        self.task_processors.insert(task_processor.cache_id(), Box::new(task_processor));
    }
    fn push_tasks(&mut self) -> (u64, u64) {
        let mut pending_tasks = HashSet::new();
        let mut merged_tasks = 0;
        let now = Instant::now();
        if let Ok(task) = self.rx.recv_timeout(self.config.max_pending_first_poll_ms_await()) {
            if pending_tasks.replace(task).is_some() {
                merged_tasks += 1;
            }

            if now.elapsed() < self.config.max_pending_ms_await() {
                while pending_tasks.len() < self.config.max_task_drain_size() as usize {
                    match self.rx.recv_timeout(self.config.max_pending_bulk_poll_ms_await()) {
                        Ok(task) => {
                            if pending_tasks.replace(task).is_some() {
                                merged_tasks += 1;
                            }
                        }
                        Err(_) => break,
                    }
                    if now.elapsed() >= self.config.max_pending_ms_await() {
                        break;
                    }
                }
            }
        }

        let tasks_len = self.tasks.len();
        self.tasks.retain(|task| !pending_tasks.contains(task));
        let merged_tasks = merged_tasks + (tasks_len - self.tasks.len());

        let tasks_pushed = pending_tasks.len();
        self.tasks.extend(pending_tasks);
        (merged_tasks as u64, tasks_pushed as u64)
    }

    fn check_async_tasks(&mut self) -> (u64, u64) {
        let mut processed_async_tasks = 0;
        let mut expired_async_tasks = 0;

        self.async_tasks.retain(|task| {
            if task.is_finished() {
                processed_async_tasks += 1;
                return false;
            }

            if task.is_expired() {
                expired_async_tasks += 1;
                task.abort();
                return false;
            }

            true
        });

        (expired_async_tasks, processed_async_tasks)
    }

    fn process_tasks(&mut self) -> (u64, u64) {
        let mut processed_tasks = 0;
        let mut expired_tasks = 0;

        loop {
            match self.tasks.peek() {
                Some(val) if !val.is_executable() => break,
                Some(val) if val.is_expired() => {
                    self.tasks.pop().expect("Task Found");
                    expired_tasks += 1;
                    continue;
                }
                Some(_) => {
                    let task = self.tasks.pop().expect("Task Found");

                    if task.is_async() {
                        self.async_execute(task);
                    } else {
                        self.execute(task);
                        processed_tasks += 1;
                    }
                }
                None => break,
            }
        }

        (expired_tasks, processed_tasks)
    }


    fn execute(&mut self, task: CacheTask) {
        let task_processor = self.task_processors.get(task.cache_id()).expect("Task processor found");
        match task {
            CacheTask::InvalidationWithProperties { invalidation, .. } => task_processor.invalidate_with_properties(invalidation),
            CacheTask::EntryExpiration { key, .. } => task_processor.expire(key),
            CacheTask::Init { .. } => task_processor.init(),
            CacheTask::Stop { .. } => task_processor.stop(),
            CacheTask::FlushAll { .. } => task_processor.flush_all(),
            _ => {}
        }
    }

    fn async_execute(&mut self, task: CacheTask) {
        let task_processor = self.task_processors.get(task.cache_id()).expect("Task processor found");
        let async_task = match task {
            CacheTask::Invalidation { exp_time, key, .. } => self.async_executor.execute(exp_time, task_processor.invalidate(key)),
            _ => {
                return;
            }
        };
        self.async_tasks.push(async_task);
    }
}

/// Public front-end of the cache manager.
///
/// Owns the sending half of the task channel and the per-cache input
/// validators. It validates manual operations before enqueueing them. The
/// worker state (`CacheManagerInner`) moves to a background thread in `start`.
pub struct CacheManager<E>
where
    E: AsyncExecutor,
{
    /// Worker state; `Some` until `start` moves it onto the worker thread.
    inner: Option<CacheManagerInner<E>>,
    /// Ids of every registered cache; tasks for other ids are rejected by `send`.
    cache_ids: HashSet<&'static str>,
    /// Per-cache validators applied to task keys / payloads before sending.
    input_validators: HashMap<&'static str, Box<dyn InputValidator>>,
    /// Counters shared with the worker thread.
    statistics: Arc<CacheManagerStatistics>,
    /// Sending half of the task channel consumed by the worker.
    tx: mpsc::Sender<CacheTask>,
}

impl<E> CacheManager<E>
where
    E: AsyncExecutor + 'static,
{
    /// Creates a manager with no registered caches; call `register`, then `start`.
    pub fn new(async_executor: E, config: CacheManagerConfig) -> Self {
        let (tx, rx) = mpsc::channel::<CacheTask>();
        let inner = Some(CacheManagerInner::new(async_executor, rx, config));
        Self {
            inner,
            cache_ids: Default::default(),
            input_validators: Default::default(),
            statistics: CacheManagerStatistics::share_init(),
            tx,
        }
    }

    /// Registers a cache: its input validator and its task processor, both
    /// keyed by `task_processor.cache_id()`.
    ///
    /// # Panics
    ///
    /// Panics if called after `start`, or if a cache with the same id is
    /// already registered.
    pub fn register<I, T>(&mut self, input_validator: I, task_processor: T)
    where
        I: InputValidator + 'static,
        T: CacheTaskProcessor + 'static,
    {
        let cache_id = task_processor.cache_id();
        // Register the processor first: it is the step that can panic
        // (duplicate id, or `start` already took `inner`). Doing it before
        // touching the local maps means a duplicate never overwrites the
        // existing validator.
        self.inner.as_mut()
            .expect("Cache manager inner found").register(task_processor);
        self.cache_ids.insert(cache_id);
        self.input_validators.insert(cache_id, Box::new(input_validator));
    }

    /// Returns a clone of the raw task sender.
    ///
    /// Tasks sent through it skip the cache-id and input validation performed
    /// by the manual-operation methods below.
    pub fn sender(&self) -> mpsc::Sender<CacheTask> {
        self.tx.clone()
    }

    /// Starts the background worker.
    ///
    /// Calls `init()` on every registered task processor on the calling
    /// thread. It then spawns a thread that loops forever: it drains the
    /// channel, runs due tasks, reaps async tasks and publishes statistics
    /// each cycle.
    ///
    /// # Panics
    ///
    /// Panics if called more than once.
    pub fn start(&mut self) {
        let mut inner = self.inner.take().expect("Cache manager inner found");

        inner.task_processors
            .values()
            .for_each(|val| val.init());

        thread::spawn({
            let statistics = Arc::clone(&self.statistics);
            move || {
                loop {
                    measure(&statistics.cycles_total, &statistics.cycles_time_ms, || {
                        let (merged, pushed) = inner.push_tasks();
                        log::debug!("#{pushed} pushed cache tasks");
                        log::debug!("#{merged} merged cache tasks");
                        statistics.tasks_total.fetch_add(pushed, Relaxed);
                        statistics.merged_tasks_total.fetch_add(merged, Relaxed);

                        let (expired_tasks, processed_tasks) = inner.process_tasks();
                        log::debug!("#{processed_tasks} processed sync cache tasks");
                        log::debug!("#{expired_tasks} expired non executed cache tasks");
                        statistics.expired_tasks_total.fetch_add(expired_tasks, Relaxed);
                        // Pending counts are snapshots, so they are stored rather than added.
                        statistics.pending_tasks_total.store(inner.tasks.len() as u64, Relaxed);
                        log::debug!("#{} remaining sync cache tasks to be processed", inner.tasks.len());

                        let (expired_async_tasks, processed_async_tasks) = inner.check_async_tasks();
                        log::debug!("#{processed_async_tasks} processed async cache tasks");
                        log::debug!("#{expired_async_tasks} expired executed async cache tasks");
                        statistics.pending_async_tasks_total.store(inner.async_tasks.len() as u64, Relaxed);
                        log::debug!("#{} remaining async cache tasks to be processed", inner.async_tasks.len());
                        statistics.expired_async_tasks_total.fetch_add(expired_async_tasks, Relaxed);
                    });
                }
            }
        });
    }

    /// Enqueues an `Init` task for `cache_id`.
    ///
    /// # Errors
    ///
    /// `CacheNotFound` for an unregistered cache; `SendError` if the worker's
    /// receiver is gone.
    pub fn init(&self, cache_id: &'static str) -> Result<(), ManualOperationError> {
        self.send(CacheTask::init(cache_id))
    }

    /// Enqueues a `Stop` task for `cache_id`. Errors as for `init`.
    pub fn stop(&self, cache_id: &'static str) -> Result<(), ManualOperationError> {
        self.send(CacheTask::stop(cache_id))
    }

    /// Enqueues a `FlushAll` task for `cache_id`. Errors as for `init`.
    pub fn flush_all(&self, cache_id: &'static str) -> Result<(), ManualOperationError> {
        self.send(CacheTask::flush_all(cache_id))
    }

    /// Enqueues an entry-expiration task for `key`, built with a `0` first
    /// argument (presumably "expire now" — confirm in `CacheTask::entry_expiration`).
    ///
    /// # Errors
    ///
    /// As for `init`, plus `InvalidInput` if the cache's validator rejects `key`.
    pub fn force_expiration<K>(&self, cache_id: &'static str, key: K) -> Result<(), ManualOperationError>
    where
        K: Send + ToString + 'static,
    {
        self.send(CacheTask::entry_expiration(0, cache_id, key))
    }

    /// Enqueues a synchronous invalidation of `key` that carries `value`.
    ///
    /// # Errors
    ///
    /// As for `force_expiration`. The validator is called with its second
    /// argument set to `true` for this task kind.
    pub fn invalidate_with_properties<K, V>(&self, cache_id: &'static str, key: K, value: V) -> Result<(), ManualOperationError>
    where
        K: Send + ToString + 'static,
        V: Send + 'static,
    {
        self.send(CacheTask::invalidation_with_properties(cache_id, key, value))
    }

    /// Enqueues an async invalidation of `key`. `expires_in` is forwarded to
    /// `CacheTask::invalidation` (its unit is defined there).
    ///
    /// # Errors
    ///
    /// As for `force_expiration`.
    pub fn invalidate<K>(&self, expires_in: u64, cache_id: &'static str, key: K) -> Result<(), ManualOperationError>
    where
        K: Send + ToString + 'static,
    {
        self.send(CacheTask::invalidation(expires_in, cache_id, key))
    }

    /// Validates `task` (cache id first, then input) and pushes it onto the channel.
    fn send(&self, task: CacheTask) -> Result<(), ManualOperationError> {
        self.validate_cache_id(task.cache_id())?;
        self.validate_task(&task)?;
        self.tx.send(task).map_err(|_| ManualOperationError::SendError)
    }

    /// Runs the cache's input validator on tasks that carry user input.
    /// Other task kinds always pass.
    ///
    /// Assumes `validate_cache_id` already succeeded. `register` keeps
    /// `cache_ids` and `input_validators` in sync, so the lookup cannot miss.
    fn validate_task(&self, task: &CacheTask) -> Result<(), ManualOperationError> {
        let result = match task {
            CacheTask::EntryExpiration { cache_id, key, .. } => {
                self.input_validators.get(cache_id)
                    .expect("Input validator found").validate(key, false)
            }
            CacheTask::Invalidation { cache_id, key, .. } => {
                self.input_validators.get(cache_id)
                    .expect("Input validator found").validate(key, false)
            }
            CacheTask::InvalidationWithProperties { cache_id, invalidation, .. } => {
                self.input_validators.get(cache_id)
                    .expect("Input validator found").validate(invalidation, true)
            }
            _ => true,
        };

        if result {
            Ok(())
        } else {
            Err(ManualOperationError::InvalidInput)
        }
    }

    /// Rejects tasks addressed to a cache that was never registered.
    fn validate_cache_id(&self, cache_id: &'static str) -> Result<(), ManualOperationError> {
        if self.cache_ids.contains(cache_id) {
            Ok(())
        } else {
            Err(ManualOperationError::CacheNotFound { cache_id })
        }
    }

    /// Iterates over the ids of all registered caches.
    pub fn cache_ids(&self) -> Iter<&'static str> {
        self.cache_ids.iter()
    }

    /// Shared statistics updated by the worker thread.
    pub fn statistics(&self) -> &Arc<CacheManagerStatistics> {
        &self.statistics
    }
}