mocra 0.3.0

A distributed, event-driven crawling and data collection framework
use super::backend::CacheBackend;
use crate::errors::CacheError;
use dashmap::DashMap;
use std::time::{Duration, Instant};

pub struct LocalBackend {
    pub(crate) store: DashMap<String, (Vec<u8>, Option<Instant>)>,
    zstore: DashMap<String, Vec<(f64, Vec<u8>)>>,
}

impl LocalBackend {
    pub fn new() -> Self {
        Self {
            store: DashMap::new(),
            zstore: DashMap::new(),
        }
    }
}

#[async_trait::async_trait]
impl CacheBackend for LocalBackend {
    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, CacheError> {
        if let Some(entry) = self.store.get(key) {
            let (val, expires_at) = entry.value();
            if let Some(exp) = expires_at
                && Instant::now() > *exp
            {
                drop(entry);
                self.store.remove(key);
                return Ok(None);
            }
            return Ok(Some(val.clone()));
        }
        Ok(None)
    }

    async fn set(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<(), CacheError> {
        let expires_at = ttl.map(|d| Instant::now() + d);
        self.store
            .insert(key.to_string(), (value.to_vec(), expires_at));
        Ok(())
    }

    async fn del(&self, key: &str) -> Result<(), CacheError> {
        self.store.remove(key);
        Ok(())
    }

    async fn del_batch(&self, keys: &[&str]) -> Result<u64, CacheError> {
        let mut count = 0u64;
        for key in keys {
            if self.store.remove(*key).is_some() {
                count += 1;
            }
        }
        Ok(count)
    }

    async fn keys(&self, pattern: &str) -> Result<Vec<String>, CacheError> {
        self.keys_with_limit(pattern, usize::MAX).await
    }

    async fn keys_with_limit(
        &self,
        pattern: &str,
        limit: usize,
    ) -> Result<Vec<String>, CacheError> {
        let prefix = pattern.trim_end_matches('*');
        let now = Instant::now();
        let mut keys = Vec::new();

        for entry in self.store.iter() {
            if keys.len() >= limit {
                break;
            }
            let key = entry.key();
            let (_, expires_at) = entry.value();
            if let Some(exp) = expires_at
                && now > *exp
            {
                let key_to_remove = key.clone();
                drop(entry);
                self.store.remove(&key_to_remove);
                continue;
            }
            if key.starts_with(prefix) {
                keys.push(key.clone());
            }
        }
        Ok(keys)
    }

    async fn set_nx(
        &self,
        key: &str,
        value: &[u8],
        ttl: Option<Duration>,
    ) -> Result<bool, CacheError> {
        let now = Instant::now();
        let expires_at = ttl.map(|d| now + d);

        let entry = self.store.entry(key.to_string());
        match entry {
            dashmap::mapref::entry::Entry::Occupied(mut occupied) => {
                let (_, old_expires_at) = occupied.get();
                if let Some(exp) = old_expires_at {
                    if now < *exp {
                        return Ok(false);
                    }
                } else {
                    return Ok(false);
                }

                occupied.insert((value.to_vec(), expires_at));
                Ok(true)
            }
            dashmap::mapref::entry::Entry::Vacant(vacant) => {
                vacant.insert((value.to_vec(), expires_at));
                Ok(true)
            }
        }
    }

    async fn set_nx_batch(
        &self,
        keys: &[&str],
        value: &[u8],
        ttl: Option<Duration>,
    ) -> Result<Vec<bool>, CacheError> {
        let mut results = Vec::with_capacity(keys.len());
        for key in keys {
            results.push(self.set_nx(key, value, ttl).await?);
        }
        Ok(results)
    }

    async fn mget(&self, keys: &[&str]) -> Result<Vec<Option<Vec<u8>>>, CacheError> {
        let mut results = Vec::with_capacity(keys.len());
        for key in keys {
            results.push(self.get(key).await?);
        }
        Ok(results)
    }

    async fn incr(&self, key: &str, delta: i64) -> Result<i64, CacheError> {
        let entry = self.store.entry(key.to_string());
        match entry {
            dashmap::mapref::entry::Entry::Occupied(mut occupied) => {
                let (val, _) = occupied.get();
                let s = String::from_utf8(val.clone()).unwrap_or_default();
                let current = s.parse::<i64>().unwrap_or(0);
                let new_val = current + delta;
                occupied.insert((new_val.to_string().into_bytes(), None));
                Ok(new_val)
            }
            dashmap::mapref::entry::Entry::Vacant(vacant) => {
                vacant.insert((delta.to_string().into_bytes(), None));
                Ok(delta)
            }
        }
    }

    async fn ping(&self) -> Result<(), CacheError> {
        Ok(())
    }

    async fn zadd(&self, key: &str, score: f64, member: &[u8]) -> Result<i64, CacheError> {
        let mut entry = self.zstore.entry(key.to_string()).or_default();
        let vec = entry.value_mut();

        let mut removed = false;
        if let Some(pos) = vec.iter().position(|(_, m)| m == member) {
            vec.remove(pos);
            removed = true;
        }

        vec.push((score, member.to_vec()));
        vec.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

        Ok(if removed { 0 } else { 1 })
    }

    async fn zrangebyscore(
        &self,
        key: &str,
        min: f64,
        max: f64,
    ) -> Result<Vec<Vec<u8>>, CacheError> {
        if let Some(entry) = self.zstore.get(key) {
            let vec = entry.value();
            let res = vec
                .iter()
                .filter(|(s, _)| *s >= min && *s <= max)
                .map(|(_, m)| m.clone())
                .collect();
            return Ok(res);
        }
        Ok(Vec::new())
    }

    async fn zremrangebyscore(&self, key: &str, min: f64, max: f64) -> Result<i64, CacheError> {
        if let Some(mut entry) = self.zstore.get_mut(key) {
            let vec = entry.value_mut();
            let len_before = vec.len();
            vec.retain(|(s, _)| *s < min || *s > max);
            return Ok((len_before - vec.len()) as i64);
        }
        Ok(0)
    }
}