streamling-state 0.1.0

State management and persistence for Streamling.
Documentation
use std::fmt::Debug;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::{StateBackendError, StateBackendErrorKind, StateKey, StateOperatorBackend};

/// Decides whether an injected failure should be triggered for a given key.
///
/// [`FailableStateBackend`] consults this before each `put` and returns an
/// "injected failure" error when [`FailCondition::should_fail`] is `true`.
#[derive(Debug)]
pub enum FailCondition {
    /// Never inject a failure.
    Never,
    /// Inject a failure on every call.
    Always,
    /// Inject a failure whenever the key starts with the given prefix.
    OnKeyPrefix(String),
    /// Inject a failure on exactly the `n`-th call (1-based) to
    /// [`FailCondition::should_fail`].
    ///
    /// The `AtomicUsize` holds the number of calls made so far and should
    /// normally start at 0 — prefer [`FailCondition::fail_on_nth_call`] to
    /// build this variant. With `n == 0` the condition never fires.
    FailOnNthCall(AtomicUsize, usize),
}

impl FailCondition {
    /// Builds a [`FailCondition::FailOnNthCall`] that fires on the `n`-th call
    /// (1-based), with its call counter starting at zero.
    #[must_use]
    pub fn fail_on_nth_call(n: usize) -> Self {
        FailCondition::FailOnNthCall(AtomicUsize::new(0), n)
    }

    /// Returns `true` if an injected failure should be triggered for `key`.
    ///
    /// For [`FailCondition::FailOnNthCall`], every invocation advances the
    /// counter regardless of `key`, so the result depends on call history
    /// rather than on the key alone.
    pub fn should_fail(&self, key: &str) -> bool {
        match self {
            FailCondition::Never => false,
            FailCondition::Always => true,
            FailCondition::OnKeyPrefix(prefix) => key.starts_with(prefix.as_str()),
            FailCondition::FailOnNthCall(counter, n) => {
                // `fetch_add` returns the previous count; +1 gives the 1-based call number.
                let call = counter.fetch_add(1, Ordering::SeqCst) + 1;
                call == *n
            }
        }
    }
}

/// Wraps another [`StateOperatorBackend`] and can inject errors into `put`
/// according to a [`FailCondition`]; `get`, `remove` and `clear` are
/// forwarded to the wrapped backend unchanged.
pub struct FailableStateBackend<V: Serialize + for<'de> Deserialize<'de>> {
    /// The real backend that receives every call that is not failed.
    inner: Arc<dyn StateOperatorBackend<V>>,
    /// Condition consulted at the start of each `put`.
    put_condition: Arc<RwLock<FailCondition>>,
}

/// Manual `Debug` implementation that prints only the type name.
///
/// The fields (the wrapped backend and the failure condition) are deliberately
/// left out; `finish_non_exhaustive` renders this as
/// `FailableStateBackend { .. }` so readers can tell fields were omitted.
impl<V> Debug for FailableStateBackend<V>
where
    V: Serialize + for<'de> Deserialize<'de>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FailableStateBackend").finish_non_exhaustive()
    }
}

impl<V> FailableStateBackend<V>
where
    V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Clone + Debug + 'static,
{
    /// Wraps `inner` with the put condition set to [`FailCondition::Never`],
    /// so no failures are injected until a condition is configured.
    pub fn new(inner: Arc<dyn StateOperatorBackend<V>>) -> Self {
        Self {
            inner,
            put_condition: Arc::new(RwLock::new(FailCondition::Never)),
        }
    }

    /// Builder-style setter: returns `self` with `condition` as the rule
    /// checked before every `put`.
    #[must_use]
    pub fn with_put_condition(mut self, condition: FailCondition) -> Self {
        self.put_condition = Arc::new(RwLock::new(condition));
        self
    }

    /// Replaces the `put` failure condition on an existing backend.
    ///
    /// This lets a caller switch behavior while the backend is in use, for
    /// example back to [`FailCondition::Never`] after exercising a failure path.
    /// A poisoned lock is recovered rather than propagated, because the
    /// condition is overwritten wholesale and no partial state can leak.
    pub fn set_put_condition(&self, condition: FailCondition) {
        let mut guard = self
            .put_condition
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *guard = condition;
    }
}

/// Delegates every operation to the wrapped backend. The exception is `put`:
/// it first evaluates the configured [`FailCondition`] against the key and,
/// when that fires, returns a `StateBackendErrorKind::Query` error without
/// calling the inner backend.
#[async_trait]
impl<V> StateOperatorBackend<V> for FailableStateBackend<V>
where
    V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Clone + Debug + 'static,
{
    async fn get(&self, key: StateKey) -> Result<Option<V>, StateBackendError> {
        let backend = &self.inner;
        backend.get(key).await
    }

    async fn put(&self, key: StateKey, value: V) -> Result<(), StateBackendError> {
        // The read guard is a temporary of this statement, so it is dropped
        // here — before any `.await` — keeping the future free of the guard.
        let inject_failure = self.put_condition.read().unwrap().should_fail(&key.0);
        if !inject_failure {
            return self.inner.put(key, value).await;
        }
        Err(StateBackendError::new(
            StateBackendErrorKind::Query,
            "injected failure",
        ))
    }

    async fn remove(&self, key: StateKey) -> Result<(), StateBackendError> {
        let backend = &self.inner;
        backend.remove(key).await
    }

    async fn clear(&self) -> Result<(), StateBackendError> {
        let backend = &self.inner;
        backend.clear().await
    }
}