use crate::Result;
use async_trait::async_trait;
use std::time::Duration;
/// Asynchronous key/value storage for opaque byte payloads.
///
/// Keys are strings and values are raw byte vectors. Implementors must be
/// `Send + Sync`, so a single manager can be shared between tasks (for
/// example behind an `Arc`).
#[async_trait]
pub trait StateManager: Send + Sync {
    /// Fetches a copy of the value stored under `key`, or `None` when there is none.
    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>>;

    /// Reports whether `key` currently holds a value.
    async fn exists(&self, key: &str) -> Result<bool>;

    /// Stores `value` under `key`, replacing whatever was there before.
    ///
    /// `ttl` is an optional time-to-live measured from this call; `None`
    /// stores the value without an expiry.
    async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()>;

    /// Removes the value stored under `key`.
    async fn delete(&self, key: &str) -> Result<()>;
}
/// One slot of [`MemoryStateManager`]: a payload plus its optional deadline.
struct StateEntry {
    /// Instant at (or after) which the entry is treated as gone; `None` means
    /// the entry never expires.
    expires_at: Option<tokio::time::Instant>,
    /// The stored payload.
    value: Vec<u8>,
}
/// In-process [`StateManager`] backed by a concurrent `DashMap`.
///
/// Entries stored with a TTL are not swept in the background; an expired
/// entry is evicted when `get` or `exists` encounters it.
#[derive(Default)]
pub struct MemoryStateManager {
    store: dashmap::DashMap<String, StateEntry>,
}

impl MemoryStateManager {
    /// Creates a manager with no entries.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
#[async_trait]
impl StateManager for MemoryStateManager {
async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
if let Some(entry) = self.store.get(key) {
if let Some(expires_at) = entry.expires_at {
if tokio::time::Instant::now() >= expires_at {
drop(entry); self.store.remove(key);
return Ok(None);
}
}
Ok(Some(entry.value.clone()))
} else {
Ok(None)
}
}
async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()> {
let expires_at = ttl.map(|d| tokio::time::Instant::now() + d);
self.store
.insert(key.to_string(), StateEntry { value, expires_at });
Ok(())
}
async fn delete(&self, key: &str) -> Result<()> {
self.store.remove(key);
Ok(())
}
async fn exists(&self, key: &str) -> Result<bool> {
if let Some(entry) = self.store.get(key) {
if let Some(expires_at) = entry.expires_at {
if tokio::time::Instant::now() >= expires_at {
drop(entry);
self.store.remove(key);
return Ok(false);
}
}
Ok(true)
} else {
Ok(false)
}
}
}
#[cfg(feature = "persistence")]
pub use trueno_kv::TruenoKvStateManager;
#[cfg(feature = "persistence")]
mod trueno_kv {
    use super::*;
    use crate::Error;
    use tokio::time::Instant;
    use trueno_db::kv::{KvStore, MemoryKvStore};

    /// Wraps a `trueno_db` error into this crate's `Error::StateError`.
    fn state_err<E: ToString>(err: E) -> Error {
        Error::StateError(err.to_string())
    }

    /// [`StateManager`] backed by a `trueno_db` [`MemoryKvStore`].
    ///
    /// The store only ever receives plain key/value writes here, so TTLs are
    /// tracked in a side table (`expirations`). Expired values are purged
    /// lazily, when `get` or `exists` observes them.
    pub struct TruenoKvStateManager {
        store: MemoryKvStore,
        // Deadline per key; a key with no record here never expires.
        expirations: dashmap::DashMap<String, Instant>,
    }

    impl TruenoKvStateManager {
        /// Creates an empty manager.
        #[must_use]
        pub fn new() -> Self {
            Self {
                store: MemoryKvStore::new(),
                expirations: dashmap::DashMap::new(),
            }
        }

        /// Creates an empty manager whose backing store is built with
        /// `MemoryKvStore::with_capacity(capacity)`.
        #[must_use]
        pub fn with_capacity(capacity: usize) -> Self {
            Self {
                store: MemoryKvStore::with_capacity(capacity),
                expirations: dashmap::DashMap::new(),
            }
        }

        /// Purges `key` if its TTL has elapsed and reports whether it did.
        ///
        /// Returns `Ok(false)`, touching nothing, when `key` has no deadline or
        /// its deadline is still in the future.
        ///
        /// # Errors
        ///
        /// Returns `Error::StateError` if deleting the value from the backing
        /// store fails. The expiration record is kept in that case, so the key
        /// keeps reading as expired instead of exposing the stale value.
        async fn purge_if_expired(&self, key: &str) -> Result<bool> {
            // Copy the deadline out so the DashMap read guard is released
            // before any `.await` below.
            let deadline = match self.expirations.get(key) {
                Some(record) => *record,
                None => return Ok(false),
            };
            if Instant::now() < deadline {
                return Ok(false);
            }
            // Delete the value before the record. Dropping only the record
            // would make the stale value readable again, with no TTL.
            self.store.delete(key).await.map_err(state_err)?;
            // Drop the record only if it is still the one observed above; a
            // concurrent `set` that installed a new deadline keeps it.
            // NOTE(review): `set` writes the record and the value in two
            // separate steps. A `set` racing with this purge can therefore
            // still have its freshly written value deleted.
            self.expirations
                .remove_if(key, |_, current| *current == deadline);
            Ok(true)
        }

        /// Number of values held by the backing store.
        ///
        /// Values whose TTL has elapsed but that no `get`/`exists` call has
        /// purged yet are still counted.
        #[must_use]
        pub fn len(&self) -> usize {
            self.store.len()
        }

        /// Whether the backing store holds no values.
        ///
        /// As with [`Self::len`], values that have expired but were not purged
        /// yet still count.
        #[must_use]
        pub fn is_empty(&self) -> bool {
            self.store.is_empty()
        }

        /// Removes every value and every TTL record.
        pub fn clear(&self) {
            self.store.clear();
            // Without this the side table would retain one stale deadline per
            // key that was ever stored with a TTL.
            self.expirations.clear();
        }

        #[cfg(test)]
        pub(crate) fn set_expiration_for_test(&self, key: &str, expires_at: Instant) {
            self.expirations.insert(key.to_string(), expires_at);
        }
    }

    impl Default for TruenoKvStateManager {
        fn default() -> Self {
            Self::new()
        }
    }

    #[async_trait]
    impl StateManager for TruenoKvStateManager {
        /// Returns the value under `key`.
        ///
        /// Returns `None` if the key is absent or its TTL has elapsed; an
        /// expired value is purged from the store.
        async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
            if self.purge_if_expired(key).await? {
                return Ok(None);
            }
            self.store.get(key).await.map_err(state_err)
        }

        /// Records the TTL (or clears any previous one), then writes the value.
        async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()> {
            if let Some(duration) = ttl {
                let expires_at = Instant::now() + duration;
                self.expirations.insert(key.to_string(), expires_at);
            } else {
                self.expirations.remove(key);
            }
            self.store.set(key, value).await.map_err(state_err)
        }

        /// Removes both the TTL record and the value for `key`.
        async fn delete(&self, key: &str) -> Result<()> {
            self.expirations.remove(key);
            self.store.delete(key).await.map_err(state_err)
        }

        /// Reports whether `key` holds a live value.
        ///
        /// An expired value reports `false` and is purged from the store.
        async fn exists(&self, key: &str) -> Result<bool> {
            if self.purge_if_expired(key).await? {
                return Ok(false);
            }
            self.store.exists(key).await.map_err(state_err)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_state_basic() {
        let state = MemoryStateManager::new();
        state.set("key1", b"value1".to_vec(), None).await.unwrap();
        let value = state.get("key1").await.unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));
        assert!(state.exists("key1").await.unwrap());
        assert!(!state.exists("key2").await.unwrap());
        state.delete("key1").await.unwrap();
        assert!(!state.exists("key1").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_state_overwrite() {
        let state = MemoryStateManager::new();
        state.set("key", b"value1".to_vec(), None).await.unwrap();
        state.set("key", b"value2".to_vec(), None).await.unwrap();
        let value = state.get("key").await.unwrap();
        assert_eq!(value, Some(b"value2".to_vec()));
    }

    #[tokio::test]
    async fn test_memory_state_concurrent() {
        use std::sync::Arc;
        let state = Arc::new(MemoryStateManager::new());
        let mut handles = vec![];
        for i in 0..10 {
            let state = Arc::clone(&state);
            handles.push(tokio::spawn(async move {
                let key = format!("key{i}");
                let value = format!("value{i}").into_bytes();
                state.set(&key, value, None).await.unwrap();
            }));
        }
        for handle in handles {
            handle.await.unwrap();
        }
        for i in 0..10 {
            let key = format!("key{i}");
            assert!(state.exists(&key).await.unwrap());
        }
    }

    #[tokio::test]
    async fn test_memory_state_delete_missing_key_is_ok() {
        let state = MemoryStateManager::new();
        state.delete("missing").await.unwrap();
        assert!(!state.exists("missing").await.unwrap());
    }

    #[tokio::test(start_paused = true)]
    async fn test_memory_state_ttl_expiration() {
        let state = MemoryStateManager::new();
        state
            .set(
                "ttl_key",
                b"value".to_vec(),
                Some(Duration::from_millis(50)),
            )
            .await
            .unwrap();
        assert!(state.exists("ttl_key").await.unwrap());
        assert_eq!(state.get("ttl_key").await.unwrap(), Some(b"value".to_vec()));
        tokio::time::advance(Duration::from_millis(60)).await;
        assert!(!state.exists("ttl_key").await.unwrap());
        assert_eq!(state.get("ttl_key").await.unwrap(), None);
    }

    #[tokio::test(start_paused = true)]
    async fn test_memory_state_ttl_no_expiration() {
        let state = MemoryStateManager::new();
        state.set("no_ttl", b"value".to_vec(), None).await.unwrap();
        tokio::time::advance(Duration::from_millis(10)).await;
        assert!(state.exists("no_ttl").await.unwrap());
        assert_eq!(state.get("no_ttl").await.unwrap(), Some(b"value".to_vec()));
    }

    #[tokio::test(start_paused = true)]
    async fn test_memory_state_ttl_overwrite_extends() {
        let state = MemoryStateManager::new();
        state
            .set("key", b"v1".to_vec(), Some(Duration::from_millis(30)))
            .await
            .unwrap();
        tokio::time::advance(Duration::from_millis(20)).await;
        state
            .set("key", b"v2".to_vec(), Some(Duration::from_millis(100)))
            .await
            .unwrap();
        tokio::time::advance(Duration::from_millis(20)).await;
        assert_eq!(state.get("key").await.unwrap(), Some(b"v2".to_vec()));
    }

    #[tokio::test(start_paused = true)]
    async fn test_memory_state_untimed_set_clears_ttl() {
        let state = MemoryStateManager::new();
        state
            .set("key", b"v1".to_vec(), Some(Duration::from_millis(30)))
            .await
            .unwrap();
        // Re-setting without a TTL must drop the earlier deadline.
        state.set("key", b"v2".to_vec(), None).await.unwrap();
        tokio::time::advance(Duration::from_millis(60)).await;
        assert_eq!(state.get("key").await.unwrap(), Some(b"v2".to_vec()));
    }

    #[cfg(feature = "persistence")]
    mod trueno_kv_tests {
        use super::*;
        use crate::state::TruenoKvStateManager;

        #[tokio::test]
        async fn test_trueno_kv_basic() {
            let state = TruenoKvStateManager::new();
            state.set("key1", b"value1".to_vec(), None).await.unwrap();
            let value = state.get("key1").await.unwrap();
            assert_eq!(value, Some(b"value1".to_vec()));
            assert!(state.exists("key1").await.unwrap());
            assert!(!state.exists("key2").await.unwrap());
            state.delete("key1").await.unwrap();
            assert!(!state.exists("key1").await.unwrap());
        }

        #[tokio::test]
        async fn test_trueno_kv_overwrite() {
            let state = TruenoKvStateManager::new();
            state.set("key", b"value1".to_vec(), None).await.unwrap();
            state.set("key", b"value2".to_vec(), None).await.unwrap();
            let value = state.get("key").await.unwrap();
            assert_eq!(value, Some(b"value2".to_vec()));
        }

        #[tokio::test]
        async fn test_trueno_kv_with_capacity() {
            let state = TruenoKvStateManager::with_capacity(100);
            state.set("key", b"value".to_vec(), None).await.unwrap();
            assert_eq!(state.get("key").await.unwrap(), Some(b"value".to_vec()));
        }

        #[tokio::test]
        async fn test_trueno_kv_len_and_clear() {
            let state = TruenoKvStateManager::new();
            assert!(state.is_empty());
            assert_eq!(state.len(), 0);
            state.set("key1", b"value1".to_vec(), None).await.unwrap();
            assert!(!state.is_empty());
            assert_eq!(state.len(), 1);
            state.set("key2", b"value2".to_vec(), None).await.unwrap();
            assert_eq!(state.len(), 2);
            state.clear();
            assert!(state.is_empty());
        }

        #[test]
        fn test_trueno_kv_default() {
            let state: TruenoKvStateManager = Default::default();
            assert!(state.is_empty());
        }

        #[tokio::test]
        async fn test_trueno_kv_ttl_expiration() {
            use tokio::time::Instant;
            let state = TruenoKvStateManager::new();
            state
                .set("ttl_key", b"value".to_vec(), None)
                .await
                .expect("set should succeed");
            assert!(state
                .exists("ttl_key")
                .await
                .expect("exists check should succeed"));
            // A deadline of "now" is already elapsed (expiry uses `>=`).
            state.set_expiration_for_test("ttl_key", Instant::now());
            assert!(!state
                .exists("ttl_key")
                .await
                .expect("exists check should succeed"));
        }

        #[tokio::test]
        async fn test_trueno_kv_ttl_future_expiration_keeps_value() {
            use tokio::time::Instant;
            let state = TruenoKvStateManager::new();
            state
                .set("future_ttl", b"value".to_vec(), None)
                .await
                .expect("set should succeed");
            let future = Instant::now() + Duration::from_secs(3600);
            state.set_expiration_for_test("future_ttl", future);
            assert!(state
                .exists("future_ttl")
                .await
                .expect("exists check should succeed"));
            assert_eq!(
                state.get("future_ttl").await.expect("get should succeed"),
                Some(b"value".to_vec())
            );
        }

        #[tokio::test(start_paused = true)]
        async fn test_trueno_kv_ttl_via_set() {
            let state = TruenoKvStateManager::new();
            state
                .set("ttl_key", b"value".to_vec(), Some(Duration::from_millis(50)))
                .await
                .expect("set should succeed");
            assert_eq!(
                state.get("ttl_key").await.expect("get should succeed"),
                Some(b"value".to_vec())
            );
            tokio::time::advance(Duration::from_millis(60)).await;
            assert_eq!(state.get("ttl_key").await.expect("get should succeed"), None);
        }

        #[tokio::test(start_paused = true)]
        async fn test_trueno_kv_untimed_set_clears_ttl() {
            let state = TruenoKvStateManager::new();
            state
                .set("key", b"v1".to_vec(), Some(Duration::from_millis(30)))
                .await
                .expect("set should succeed");
            // Re-setting without a TTL must drop the earlier deadline.
            state
                .set("key", b"v2".to_vec(), None)
                .await
                .expect("set should succeed");
            tokio::time::advance(Duration::from_millis(60)).await;
            assert_eq!(
                state.get("key").await.expect("get should succeed"),
                Some(b"v2".to_vec())
            );
        }
    }
}