use kovan_map::HopscotchMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicI64, Ordering};
/// Key/value state storage whose values are raw bytes, plus integer counters.
///
/// The `Send + Sync` supertraits allow one store to be shared across threads
/// behind a [`SharedStateStore`].
pub trait StateStore: Send + Sync {
    /// Returns the bytes stored under `key`, or `None` when nothing is stored.
    fn get(&self, key: &str) -> Option<Vec<u8>>;

    /// Stores `value` under `key`.
    fn set(&self, key: &str, value: Vec<u8>);

    /// Removes whatever is stored under `key`.
    fn delete(&self, key: &str);

    /// Adds `delta` to the integer counter at `key` and returns the resulting value.
    fn increment_i64(&self, key: &str, delta: i64) -> i64;

    /// Lists stored keys for `prefix` (presumably a key-prefix filter — confirm
    /// with implementors).
    ///
    /// The default implementation does not enumerate anything and always yields
    /// an empty list; stores that support listing must override it.
    fn list_keys(&self, _prefix: &str) -> Vec<String> {
        Vec::default()
    }
}

/// Reference-counted, type-erased handle to any [`StateStore`] implementation.
pub type SharedStateStore = Arc<dyn StateStore>;
/// Process-local [`StateStore`] backed by two `HopscotchMap`s that are mutated
/// through `&self`.
///
/// Plain values written with `set` live in `bytes`. A key passed to
/// `increment_i64` gets an atomic counter in `counters`, and while that counter
/// exists it shadows any entry for the same key in `bytes`.
#[derive(Default)]
pub struct InMemoryStateStore {
    /// Raw byte values, keyed by their owned key string.
    bytes: HopscotchMap<String, Vec<u8>>,
    /// Live integer counters; a key present here takes precedence over `bytes`.
    counters: HopscotchMap<String, Arc<AtomicI64>>,
}
impl InMemoryStateStore {
    /// Creates an empty store with no byte values and no counters.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates an empty store already wrapped in a [`SharedStateStore`] handle.
    pub fn shared() -> SharedStateStore {
        let store = Self::new();
        Arc::new(store)
    }
}
impl StateStore for InMemoryStateStore {
    /// Returns the value for `key`.
    ///
    /// A live counter wins over any byte value and is rendered as its decimal
    /// ASCII text. Callers therefore always see bytes, and a counter reads back the
    /// same way a `set` of its textual value would. The byte entry used to seed a
    /// counter is left in `bytes` but stays shadowed while the counter exists.
    fn get(&self, key: &str) -> Option<Vec<u8>> {
        if let Some(counter) = self.counters.get(key) {
            return Some(counter.load(Ordering::Acquire).to_string().into_bytes());
        }
        self.bytes.get(key)
    }

    /// Stores raw bytes under `key`, discarding any counter for that key.
    fn set(&self, key: &str, value: Vec<u8>) {
        // Publish the new bytes *before* dropping the counter (the same order
        // `delete` uses). `bytes` is not updated while a counter is live, so it may
        // still hold an old seed. With the reverse order, an `increment_i64` landing
        // between the two steps would miss the counter, reseed from that stale
        // value, and install a counter that shadows this write.
        self.bytes.insert(key.to_string(), value);
        self.counters.remove(key);
    }

    /// Removes both the byte value and the counter (if any) for `key`.
    fn delete(&self, key: &str) {
        // Bytes first: an increment that still sees the counter is linearised
        // before the delete; one that misses it finds no bytes to reseed from.
        self.bytes.remove(key);
        self.counters.remove(key);
    }

    /// Atomically adds `delta` to the counter at `key` and returns the new value.
    ///
    /// On first use the counter is seeded from the current byte value if it parses
    /// as a UTF-8 decimal `i64`, otherwise from 0. Arithmetic wraps on overflow,
    /// matching `AtomicI64::fetch_add`.
    ///
    /// NOTE(review): `bytes` and `counters` are separate maps. The first increment
    /// of a key is therefore not atomic with a concurrent `set`/`delete` of the same
    /// key. The incrementer can read the pre-write bytes and then install a counter
    /// that shadows the write or resurrects a deleted key. Closing that fully would
    /// need a single map or a lock.
    fn increment_i64(&self, key: &str, delta: i64) -> i64 {
        // `fetch_add` wraps on overflow; compute the returned value the same way so
        // it always equals what was stored. A plain `+` panics in debug builds once
        // the counter has wrapped.
        if let Some(counter) = self.counters.get(key) {
            return counter.fetch_add(delta, Ordering::AcqRel).wrapping_add(delta);
        }

        let seed = self
            .bytes
            .get(key)
            .and_then(|raw| std::str::from_utf8(&raw).ok()?.parse::<i64>().ok())
            .unwrap_or(0);

        // If another thread created the counter first, adopt theirs so concurrent
        // first increments are all applied to the same atomic.
        let fresh = Arc::new(AtomicI64::new(seed));
        let counter = self
            .counters
            .insert_if_absent(key.to_string(), Arc::clone(&fresh))
            .unwrap_or(fresh);
        counter.fetch_add(delta, Ordering::AcqRel).wrapping_add(delta)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a stored counter value (decimal text) back into an `i64`.
    fn decode_i64(raw: &[u8]) -> i64 {
        std::str::from_utf8(raw).unwrap().parse().unwrap()
    }

    #[test]
    fn in_memory_get_set_delete() {
        let store = InMemoryStateStore::new();
        assert_eq!(store.get("k"), None);

        // Each write overwrites the previous value for the same key.
        for value in [b"v1", b"v2"] {
            store.set("k", value.to_vec());
            assert_eq!(store.get("k"), Some(value.to_vec()));
        }

        store.delete("k");
        assert_eq!(store.get("k"), None);
    }

    #[test]
    fn increment_is_atomic_under_concurrency() {
        use std::thread;

        const THREADS: i64 = 16;
        const PER_THREAD: i64 = 1000;

        let store = InMemoryStateStore::shared();
        // `collect` forces every thread to be spawned before any is joined.
        let workers: Vec<_> = (0..THREADS)
            .map(|_| {
                let handle = store.clone();
                thread::spawn(move || {
                    for _ in 0..PER_THREAD {
                        handle.increment_i64("hits", 1);
                    }
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }

        let total = decode_i64(&store.get("hits").unwrap());
        assert_eq!(total, THREADS * PER_THREAD);
    }

    #[test]
    fn increment_seeds_from_set_value() {
        let store = InMemoryStateStore::new();
        store.set("k", b"5".to_vec());
        let after = store.increment_i64("k", 3);
        assert_eq!(after, 8);
        assert_eq!(store.get("k"), Some(b"8".to_vec()));
    }

    #[test]
    fn set_after_increment_clears_counter() {
        let store = InMemoryStateStore::new();
        let _ = store.increment_i64("k", 10);
        store.set("k", b"fresh".to_vec());
        assert_eq!(store.get("k"), Some(b"fresh".to_vec()));
    }
}