// tonbo 0.3.2
//
// An embedded persistent KV database in Rust.
// Documentation
use std::{
    marker::PhantomData,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
};

use crate::record::Record;

/// A condition that is evaluated against records of type `R`.
///
/// The implementors in this file keep an atomic running total and compare it
/// with a fixed threshold. What a caller does once the condition holds is not
/// defined here.
pub trait Trigger<R>: Send + Sync
where
    R: Record,
{
    /// Takes `item` into account and returns `true` when the trigger's limit
    /// has been reached or passed.
    ///
    /// The implementors in this file update their running total on every
    /// call, so calling this method affects the result of later calls.
    fn check_if_exceed(&self, item: &Option<R>) -> bool;

    /// Puts the trigger back into its starting state.
    fn reset(&self);
}

/// Trigger state that compares a running total of record sizes with a fixed
/// threshold.
///
/// `R` is a type-level marker only (held through [`PhantomData`]); no value
/// of type `R` is stored.
#[derive(Debug)]
pub struct SizeOfMemTrigger<R> {
    /// Limit the running total is compared against.
    threshold: usize,
    /// Running total; atomic so it can be updated through `&self`.
    current_size: AtomicUsize,
    _p: PhantomData<R>,
}

impl<R> SizeOfMemTrigger<R> {
    /// Creates a trigger with `max_size` as its threshold and a running total
    /// of zero.
    pub fn new(max_size: usize) -> Self {
        let current_size = AtomicUsize::new(0);
        Self {
            threshold: max_size,
            current_size,
            _p: PhantomData,
        }
    }
}

impl<R> Trigger<R> for SizeOfMemTrigger<R>
where
    R: Record,
{
    /// Adds the size of `item` (zero for `None`) to the running total and
    /// returns `true` when the new total is at least the threshold.
    fn check_if_exceed(&self, item: &Option<R>) -> bool {
        let added = match item {
            Some(record) => record.size(),
            None => 0,
        };
        // `fetch_add` hands back the total from *before* this call, so the
        // size just added has to be included again for the comparison.
        let previous = self.current_size.fetch_add(added, Ordering::SeqCst);
        previous + added >= self.threshold
    }

    /// Sets the running total back to zero.
    fn reset(&self) {
        self.current_size.store(0, Ordering::SeqCst);
    }
}

/// Trigger state that compares a running call count with a fixed threshold.
///
/// `R` is a type-level marker only (held through [`PhantomData`]); no value
/// of type `R` is stored.
#[derive(Debug)]
pub struct LengthTrigger<R> {
    /// Limit the count is compared against.
    threshold: usize,
    /// Running count; atomic so it can be updated through `&self`.
    count: AtomicUsize,
    _p: PhantomData<R>,
}

impl<R> LengthTrigger<R> {
    /// Creates a trigger with the given `threshold` and a count of zero.
    pub fn new(threshold: usize) -> Self {
        let count = AtomicUsize::new(0);
        Self {
            threshold,
            count,
            _p: PhantomData,
        }
    }
}

impl<R> Trigger<R> for LengthTrigger<R>
where
    R: Record,
{
    /// Increments the count by one (the item itself is ignored) and returns
    /// `true` when the new count is at least the threshold.
    fn check_if_exceed(&self, _: &Option<R>) -> bool {
        // `fetch_add` returns the count from before the increment.
        let seen_before = self.count.fetch_add(1, Ordering::SeqCst);
        seen_before + 1 >= self.threshold
    }

    /// Sets the count back to zero.
    fn reset(&self) {
        self.count.store(0, Ordering::SeqCst);
    }
}

/// Selects which trigger implementation to build and carries its threshold.
#[derive(Debug, Clone, Copy)]
pub enum TriggerType {
    /// Size-based trigger; the payload is passed to `SizeOfMemTrigger::new`.
    SizeOfMem(usize),
    /// Count-based trigger; the payload is passed to `LengthTrigger::new`.
    #[allow(unused)]
    Length(usize),
}
pub(crate) struct TriggerFactory<R> {
    _p: PhantomData<R>,
}

impl<R: Record> TriggerFactory<R> {
    pub fn create(trigger_type: TriggerType) -> Arc<dyn Trigger<R>> {
        match trigger_type {
            TriggerType::SizeOfMem(threshold) => Arc::new(SizeOfMemTrigger::new(threshold)),
            TriggerType::Length(threshold) => Arc::new(LengthTrigger::new(threshold)),
        }
    }
}

/// Unit tests for the triggers and the factory in this module.
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use super::*;
    use crate::tests::Test;

    /// `SizeOfMemTrigger` fires once the accumulated record size reaches the
    /// threshold and stops firing after `reset`.
    #[tokio::test]
    async fn test_size_of_mem_trigger() {
        let threshold = 10;
        let trigger = SizeOfMemTrigger::new(threshold);

        let record = Some(Test {
            vstring: "test".to_string(),
            vu32: 0,
            vbool: None,
        });

        // Borrow the record to measure it; no need to clone it first.
        let record_size = record.as_ref().unwrap().size();
        assert_eq!(record_size, 8);

        // Running total: 8 < 10.
        assert!(
            !trigger.check_if_exceed(&record),
            "Trigger should not be exceeded after 1 record"
        );

        // Running total: 16 >= 10. Assert on the second call itself; checking
        // only a later call would not prove the trigger fires at this point.
        assert!(
            trigger.check_if_exceed(&record),
            "Trigger should be exceeded after 2 records"
        );
        // Running total: 24. The trigger keeps firing until it is reset.
        assert!(
            trigger.check_if_exceed(&record),
            "Trigger should stay exceeded after 3 records"
        );

        trigger.reset();
        assert!(
            !trigger.check_if_exceed(&record),
            "Trigger should not be exceeded after reset"
        );
    }

    /// `LengthTrigger` fires once the number of checks reaches the threshold
    /// and stops firing after `reset`.
    #[tokio::test]
    async fn test_length_trigger() {
        let threshold = 2;
        let trigger = LengthTrigger::new(threshold);

        let record = Some(Test {
            vstring: "test".to_string(),
            vu32: 0,
            vbool: None,
        });

        // Count: 1 < 2.
        assert!(
            !trigger.check_if_exceed(&record),
            "Trigger should not be exceeded after 1 record"
        );

        // Count: 2 >= 2. Assert on the second call itself; checking only a
        // later call would not prove the trigger fires at this point.
        assert!(
            trigger.check_if_exceed(&record),
            "Trigger should be exceeded after 2 records"
        );
        // Count: 3. The trigger keeps firing until it is reset.
        assert!(
            trigger.check_if_exceed(&record),
            "Trigger should stay exceeded after 3 records"
        );

        trigger.reset();
        assert!(
            !trigger.check_if_exceed(&record),
            "Trigger should not be exceeded after reset"
        );
    }

    /// The factory returns a working trigger for each `TriggerType` variant.
    #[tokio::test]
    async fn test_trigger_factory() {
        let size_of_mem_trigger = TriggerFactory::<Test>::create(TriggerType::SizeOfMem(9));
        let length_trigger = TriggerFactory::<Test>::create(TriggerType::Length(2));

        assert!(!size_of_mem_trigger.check_if_exceed(&Some(Test {
            vstring: "test".to_string(),
            vu32: 0,
            vbool: None
        })));
        assert!(size_of_mem_trigger.check_if_exceed(&Some(Test {
            vstring: "test".to_string(),
            vu32: 0,
            vbool: None
        })));

        assert!(!length_trigger.check_if_exceed(&Some(Test {
            vstring: "test".to_string(),
            vu32: 1,
            vbool: Some(true)
        })));
        assert!(length_trigger.check_if_exceed(&Some(Test {
            vstring: "test".to_string(),
            vu32: 1,
            vbool: Some(true)
        })));
    }
}