statsig-rust 0.19.2

Statsig Rust SDK for use in multi-user server environments.
Documentation
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc, Mutex,
};

use async_trait::async_trait;
use statsig_rust::{
    data_store_interface::{
        DataStoreBytesResponse, DataStoreResponse, DataStoreTrait, RequestPath,
    },
    StatsigErr,
};

/// Raw byte storage backing [`MockDataStore`], with one slot per cache-key flavor.
///
/// Both slots start out empty (`None`) via `Default`.
#[derive(Default)]
struct MockDataStoreByteCache {
    /// Bytes for keys that are *not* proto cache keys (treated as JSON).
    json: Option<Vec<u8>>,
    /// Bytes for keys recognized as proto cache keys.
    proto: Option<Vec<u8>>,
}

/// In-memory [`DataStoreTrait`] test double with injectable responses and errors,
/// an optional byte cache, and per-method call counters.
pub struct MockDataStore {
    // --- configuration fixed at construction ---
    /// Value reported by `support_polling_updates_for`, regardless of the path.
    supports_polling: bool,
    /// When false, `get_bytes`/`set_bytes` return `BytesNotImplemented` and `set`
    /// does not mirror string values into `byte_cache`.
    byte_cache_enabled: bool,

    // --- mutable mock state ---
    /// One-shot response handed out (and cleared) by the next `get`.
    response: Mutex<Option<DataStoreResponse>>,
    /// Stored proto / JSON bytes.
    byte_cache: Mutex<MockDataStoreByteCache>,
    /// Injected failure message for `get_bytes`; it is not cleared after use.
    get_bytes_error: Mutex<Option<String>>,

    // --- call counters, one per trait method ---
    get_call_count: Arc<AtomicUsize>,
    get_bytes_call_count: Arc<AtomicUsize>,
    set_call_count: Arc<AtomicUsize>,
    set_bytes_call_count: Arc<AtomicUsize>,
}

impl MockDataStore {
    /// Builds a mock with an empty response slot, an empty byte cache, no injected
    /// `get_bytes` error, and every call counter at zero.
    ///
    /// The byte cache starts disabled; see [`MockDataStore::new_with_byte_cache`].
    pub fn new(supports_polling: bool) -> Self {
        let zeroed_counter = || Arc::new(AtomicUsize::new(0));
        Self {
            supports_polling,
            byte_cache_enabled: false,
            response: Mutex::new(None),
            byte_cache: Mutex::new(MockDataStoreByteCache::default()),
            get_bytes_error: Mutex::new(None),
            get_call_count: zeroed_counter(),
            get_bytes_call_count: zeroed_counter(),
            set_call_count: zeroed_counter(),
            set_bytes_call_count: zeroed_counter(),
        }
    }

    /// Same as [`MockDataStore::new`], but with the byte cache enabled.
    pub fn new_with_byte_cache(supports_polling: bool) -> Self {
        let mut store = Self::new(supports_polling);
        store.byte_cache_enabled = true;
        store
    }

    /// Byte-cache-enabled, non-polling mock whose proto slot is pre-filled with `proto`.
    pub fn with_proto_cache(proto: &[u8]) -> Self {
        let seeded = Self::new_with_byte_cache(false);
        seeded.mock_proto_bytes(proto);
        seeded
    }

    /// Byte-cache-enabled, non-polling mock whose JSON slot is pre-filled with `json`.
    pub fn with_json_cache(json: &str) -> Self {
        let seeded = Self::new_with_byte_cache(false);
        seeded.mock_json_bytes(json);
        seeded
    }

    /// Queues `response` to be returned (once) by the next `get` call.
    pub async fn mock_response(&self, response: DataStoreResponse) {
        *self.response.lock().unwrap() = Some(response);
    }

    /// Overwrites the proto byte slot with a copy of `proto`.
    pub fn mock_proto_bytes(&self, proto: &[u8]) {
        let mut cache = self.byte_cache.lock().unwrap();
        cache.proto = Some(Vec::from(proto));
    }

    /// Overwrites the JSON byte slot with the UTF-8 bytes of `json`.
    pub fn mock_json_bytes(&self, json: &str) {
        let mut cache = self.byte_cache.lock().unwrap();
        cache.json = Some(json.to_owned().into_bytes());
    }

    /// Makes `get_bytes` fail with `message` until replaced.
    pub fn mock_get_bytes_error(&self, message: &str) {
        let mut slot = self.get_bytes_error.lock().unwrap();
        *slot = Some(message.to_owned());
    }

    /// Snapshot of the proto byte slot.
    pub fn stored_proto_bytes(&self) -> Option<Vec<u8>> {
        self.byte_cache.lock().unwrap().proto.as_deref().map(<[u8]>::to_vec)
    }

    /// Snapshot of the JSON byte slot.
    pub fn stored_json_bytes(&self) -> Option<Vec<u8>> {
        self.byte_cache.lock().unwrap().json.as_deref().map(<[u8]>::to_vec)
    }

    /// Number of `get` calls so far.
    pub fn num_get_calls(&self) -> usize {
        Self::read_counter(&self.get_call_count)
    }

    /// Number of `get_bytes` calls so far, including ones that returned an error.
    pub fn num_get_bytes_calls(&self) -> usize {
        Self::read_counter(&self.get_bytes_call_count)
    }

    /// Number of `set` calls so far.
    pub fn num_set_calls(&self) -> usize {
        Self::read_counter(&self.set_call_count)
    }

    /// Number of `set_bytes` calls so far, including ones that returned an error.
    pub fn num_set_bytes_calls(&self) -> usize {
        Self::read_counter(&self.set_bytes_call_count)
    }

    /// Reads a call counter with sequentially consistent ordering.
    fn read_counter(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    /// Copies the bytes for `key`: the proto slot for proto cache keys, otherwise the JSON slot.
    fn get_bytes_cache_for_key(&self, key: &str) -> Option<Vec<u8>> {
        let guard = self.byte_cache.lock().unwrap();
        let selected = if is_proto_cache_key(key) {
            &guard.proto
        } else {
            &guard.json
        };
        selected.clone()
    }
}

#[async_trait]
impl DataStoreTrait for MockDataStore {
    /// No-op; always succeeds.
    async fn initialize(&self) -> Result<(), StatsigErr> {
        Ok(())
    }

    /// No-op; always succeeds.
    async fn shutdown(&self) -> Result<(), StatsigErr> {
        Ok(())
    }

    /// Returns the pending mocked response, if any, and clears it. Otherwise it
    /// decodes the cached bytes for `key` as a UTF-8 string, with `time` fixed at 1.
    ///
    /// # Errors
    /// `DataStoreFailure` when nothing is cached for `key`, or when the cached bytes
    /// are not valid UTF-8.
    async fn get(&self, key: &str) -> Result<DataStoreResponse, StatsigErr> {
        self.get_call_count.fetch_add(1, Ordering::SeqCst);

        // Taken in its own statement so the lock is released before the byte cache is read.
        let pending = self.response.lock().unwrap().take();
        if let Some(pending) = pending {
            return Ok(pending);
        }

        let raw = self
            .get_bytes_cache_for_key(key)
            .ok_or_else(|| StatsigErr::DataStoreFailure("Failed to get".to_string()))?;
        let text = String::from_utf8(raw).map_err(|e| {
            StatsigErr::DataStoreFailure(format!("Cached value is not UTF-8: {e}"))
        })?;

        Ok(DataStoreResponse {
            result: Some(text),
            time: Some(1),
        })
    }

    /// Counts the call. When the byte cache is enabled and `key` is not a proto
    /// cache key, it also stores `value` in the JSON slot. Always succeeds.
    async fn set(&self, key: &str, value: &str, _time: Option<u64>) -> Result<(), StatsigErr> {
        self.set_call_count.fetch_add(1, Ordering::SeqCst);

        let mirror_to_json = self.byte_cache_enabled && !is_proto_cache_key(key);
        if mirror_to_json {
            let mut cache = self.byte_cache.lock().unwrap();
            cache.json = Some(value.to_owned().into_bytes());
        }

        Ok(())
    }

    /// Returns the cached bytes for `key` (possibly `None`), with `time` fixed at 1.
    ///
    /// # Errors
    /// `BytesNotImplemented` when the byte cache is disabled. `DataStoreFailure`
    /// carrying the injected message when `mock_get_bytes_error` was called.
    async fn get_bytes(&self, key: &str) -> Result<DataStoreBytesResponse, StatsigErr> {
        self.get_bytes_call_count.fetch_add(1, Ordering::SeqCst);
        if !self.byte_cache_enabled {
            return Err(StatsigErr::BytesNotImplemented);
        }

        // Snapshot so the error lock is released before the byte cache is locked.
        let injected_error = self.get_bytes_error.lock().unwrap().clone();
        match injected_error {
            Some(message) => Err(StatsigErr::DataStoreFailure(message)),
            None => Ok(DataStoreBytesResponse {
                result: self.get_bytes_cache_for_key(key),
                time: Some(1),
            }),
        }
    }

    /// Stores `value` in the proto slot for proto cache keys, otherwise in the JSON slot.
    ///
    /// # Errors
    /// `BytesNotImplemented` when the byte cache is disabled. The call is still counted.
    async fn set_bytes(
        &self,
        key: &str,
        value: &[u8],
        _time: Option<u64>,
    ) -> Result<(), StatsigErr> {
        self.set_bytes_call_count.fetch_add(1, Ordering::SeqCst);
        if !self.byte_cache_enabled {
            return Err(StatsigErr::BytesNotImplemented);
        }

        let mut cache = self.byte_cache.lock().unwrap();
        let slot = if is_proto_cache_key(key) {
            &mut cache.proto
        } else {
            &mut cache.json
        };
        *slot = Some(value.to_vec());

        Ok(())
    }

    /// Reports the `supports_polling` flag given at construction, ignoring `_path`.
    async fn support_polling_updates_for(&self, _path: RequestPath) -> bool {
        self.supports_polling
    }
}

/// True when `key` contains the `|statsig-br|` marker anywhere.
///
/// The mock routes such keys to the proto slot and all other keys to the JSON slot.
fn is_proto_cache_key(key: &str) -> bool {
    const PROTO_KEY_MARKER: &str = "|statsig-br|";
    key.contains(PROTO_KEY_MARKER)
}