1use crate::Result;
2use async_trait::async_trait;
3use std::time::Duration;
4
5#[async_trait]
16pub trait StateManager: Send + Sync {
17 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>>;
19
20 async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()>;
22
23 async fn delete(&self, key: &str) -> Result<()>;
25
26 async fn exists(&self, key: &str) -> Result<bool>;
28}
29
30struct StateEntry {
32 value: Vec<u8>,
33 expires_at: Option<tokio::time::Instant>,
34}
35
36pub struct MemoryStateManager {
41 store: dashmap::DashMap<String, StateEntry>,
42}
43
44impl MemoryStateManager {
45 #[must_use]
47 pub fn new() -> Self {
48 Self {
49 store: dashmap::DashMap::new(),
50 }
51 }
52}
53
54impl Default for MemoryStateManager {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60#[async_trait]
61impl StateManager for MemoryStateManager {
62 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
63 if let Some(entry) = self.store.get(key) {
64 if let Some(expires_at) = entry.expires_at {
66 if tokio::time::Instant::now() >= expires_at {
67 drop(entry); self.store.remove(key);
70 return Ok(None);
71 }
72 }
73 Ok(Some(entry.value.clone()))
74 } else {
75 Ok(None)
76 }
77 }
78
79 async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()> {
80 let expires_at = ttl.map(|d| tokio::time::Instant::now() + d);
81 self.store
82 .insert(key.to_string(), StateEntry { value, expires_at });
83 Ok(())
84 }
85
86 async fn delete(&self, key: &str) -> Result<()> {
87 self.store.remove(key);
88 Ok(())
89 }
90
91 async fn exists(&self, key: &str) -> Result<bool> {
92 if let Some(entry) = self.store.get(key) {
93 if let Some(expires_at) = entry.expires_at {
95 if tokio::time::Instant::now() >= expires_at {
96 drop(entry);
97 self.store.remove(key);
98 return Ok(false);
99 }
100 }
101 Ok(true)
102 } else {
103 Ok(false)
104 }
105 }
106}
107
#[cfg(feature = "persistence")]
pub use trueno_kv::TruenoKvStateManager;

#[cfg(feature = "persistence")]
mod trueno_kv {
    use super::*;
    use crate::Error;
    use tokio::time::Instant;
    use trueno_db::kv::{KvStore, MemoryKvStore};

    /// Wraps a backend error into this crate's `Error::StateError`.
    fn kv_error<E: ToString>(err: E) -> Error {
        Error::StateError(err.to_string())
    }

    /// [`StateManager`] backed by `trueno_db`'s [`MemoryKvStore`].
    ///
    /// TTLs are not passed to the backing store; deadlines are kept in a side table
    /// and enforced here. An expired key is deleted from the store the next time it
    /// is read through [`StateManager::get`] or [`StateManager::exists`].
    pub struct TruenoKvStateManager {
        /// Backend holding the actual values.
        store: MemoryKvStore,
        /// Deadline for every key that was last stored with a TTL.
        expirations: dashmap::DashMap<String, Instant>,
    }

    impl TruenoKvStateManager {
        /// Creates an empty manager with a default backing store.
        #[must_use]
        pub fn new() -> Self {
            Self::from_store(MemoryKvStore::new())
        }

        /// Creates an empty manager whose backing store is pre-sized for `capacity` keys.
        #[must_use]
        pub fn with_capacity(capacity: usize) -> Self {
            Self::from_store(MemoryKvStore::with_capacity(capacity))
        }

        /// Pairs a backing store with an empty deadline table.
        fn from_store(store: MemoryKvStore) -> Self {
            Self {
                store,
                expirations: dashmap::DashMap::new(),
            }
        }

        /// Returns `true` if `key` has a recorded deadline that `now` has reached.
        ///
        /// Pure check; it does not mutate anything.
        fn is_expired_at(&self, key: &str, now: Instant) -> bool {
            // The `Ref` guard from `get` is dropped at the end of this expression.
            matches!(self.expirations.get(key), Some(deadline) if now >= *deadline)
        }

        /// Evicts `key` if its deadline has passed and returns whether it was expired.
        ///
        /// # Errors
        /// Returns `Error::StateError` if deleting the stale value from the backing
        /// store fails.
        async fn evict_if_expired(&self, key: &str) -> Result<bool> {
            let now = Instant::now();
            if !self.is_expired_at(key, now) {
                return Ok(false);
            }
            // The stale value must be deleted, not just the deadline. If only the
            // deadline were dropped, the next read would see a key with no TTL and
            // return the expired value again.
            self.store.delete(key).await.map_err(kv_error)?;
            // Only drop the deadline if it is still the expired one, so a TTL written
            // by a concurrent `set` is preserved.
            // NOTE(review): `set` and eviction are not atomic with respect to each
            // other; a `set` racing this method can still lose its value.
            self.expirations.remove_if(key, |_, deadline| now >= *deadline);
            Ok(true)
        }

        /// Number of keys held by the backing store.
        ///
        /// Keys whose TTL has elapsed but that have not been read since (and so not
        /// yet evicted) are still counted.
        #[must_use]
        pub fn len(&self) -> usize {
            self.store.len()
        }

        /// Returns `true` if the backing store holds no keys.
        #[must_use]
        pub fn is_empty(&self) -> bool {
            self.store.is_empty()
        }

        /// Removes every key and every recorded deadline.
        pub fn clear(&self) {
            self.store.clear();
            // Stale deadlines would otherwise outlive the values they belonged to.
            self.expirations.clear();
        }

        /// Overrides the deadline for `key` directly (test-only).
        #[cfg(test)]
        pub(crate) fn set_expiration_for_test(&self, key: &str, expires_at: Instant) {
            self.expirations.insert(key.to_string(), expires_at);
        }
    }

    impl Default for TruenoKvStateManager {
        fn default() -> Self {
            Self::new()
        }
    }

    #[async_trait]
    impl StateManager for TruenoKvStateManager {
        /// Returns the value under `key`, or `None` if it is missing or expired.
        ///
        /// # Errors
        /// Returns `Error::StateError` if the backing store fails, including while
        /// evicting an expired key.
        async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
            if self.evict_if_expired(key).await? {
                return Ok(None);
            }
            self.store.get(key).await.map_err(kv_error)
        }

        /// Stores `value` under `key` and records (or clears) its deadline.
        ///
        /// A TTL too large to be represented as an `Instant` is treated as "never
        /// expires" instead of panicking on overflow.
        async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()> {
            match ttl.and_then(|d| Instant::now().checked_add(d)) {
                Some(deadline) => {
                    self.expirations.insert(key.to_string(), deadline);
                }
                None => {
                    self.expirations.remove(key);
                }
            }
            self.store.set(key, value).await.map_err(kv_error)
        }

        /// Removes `key` and its deadline.
        async fn delete(&self, key: &str) -> Result<()> {
            self.expirations.remove(key);
            self.store.delete(key).await.map_err(kv_error)
        }

        /// Reports whether `key` holds an unexpired value, evicting it if expired.
        ///
        /// # Errors
        /// Returns `Error::StateError` if the backing store fails.
        async fn exists(&self, key: &str) -> Result<bool> {
            if self.evict_if_expired(key).await? {
                return Ok(false);
            }
            self.store.exists(key).await.map_err(kv_error)
        }
    }
}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_state_basic() {
        let mgr = MemoryStateManager::new();

        mgr.set("key1", b"value1".to_vec(), None).await.unwrap();
        assert_eq!(mgr.get("key1").await.unwrap(), Some(b"value1".to_vec()));

        // A written key is present; a never-written key is not.
        assert!(mgr.exists("key1").await.unwrap());
        assert!(!mgr.exists("key2").await.unwrap());

        // After deletion the key is gone.
        mgr.delete("key1").await.unwrap();
        assert!(!mgr.exists("key1").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_state_overwrite() {
        let mgr = MemoryStateManager::new();

        // The second write must replace the first.
        for payload in [b"value1", b"value2"] {
            mgr.set("key", payload.to_vec(), None).await.unwrap();
        }

        assert_eq!(mgr.get("key").await.unwrap(), Some(b"value2".to_vec()));
    }

    #[tokio::test]
    async fn test_memory_state_concurrent() {
        use std::sync::Arc;

        let mgr = Arc::new(MemoryStateManager::new());

        // Ten tasks, each writing its own key.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let mgr = Arc::clone(&mgr);
                tokio::spawn(async move {
                    let key = format!("key{i}");
                    let value = format!("value{i}").into_bytes();
                    mgr.set(&key, value, None).await.unwrap();
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        for i in 0..10 {
            assert!(mgr.exists(&format!("key{i}")).await.unwrap());
        }
    }

    #[tokio::test(start_paused = true)]
    async fn test_memory_state_ttl_expiration() {
        let mgr = MemoryStateManager::new();
        let ttl = Duration::from_millis(50);

        mgr.set("ttl_key", b"value".to_vec(), Some(ttl)).await.unwrap();

        // Still live before the deadline.
        assert!(mgr.exists("ttl_key").await.unwrap());
        assert_eq!(mgr.get("ttl_key").await.unwrap(), Some(b"value".to_vec()));

        // Move the paused clock past the deadline.
        tokio::time::advance(Duration::from_millis(60)).await;

        assert!(!mgr.exists("ttl_key").await.unwrap());
        assert_eq!(mgr.get("ttl_key").await.unwrap(), None);
    }

    #[tokio::test(start_paused = true)]
    async fn test_memory_state_ttl_no_expiration() {
        let mgr = MemoryStateManager::new();

        mgr.set("no_ttl", b"value".to_vec(), None).await.unwrap();

        // Without a TTL, elapsed time has no effect.
        tokio::time::advance(Duration::from_millis(10)).await;

        assert!(mgr.exists("no_ttl").await.unwrap());
        assert_eq!(mgr.get("no_ttl").await.unwrap(), Some(b"value".to_vec()));
    }

    #[tokio::test(start_paused = true)]
    async fn test_memory_state_ttl_overwrite_extends() {
        let mgr = MemoryStateManager::new();
        let short_ttl = Duration::from_millis(30);
        let long_ttl = Duration::from_millis(100);
        let step = Duration::from_millis(20);

        mgr.set("key", b"v1".to_vec(), Some(short_ttl)).await.unwrap();
        tokio::time::advance(step).await;

        // Re-setting replaces the short deadline with a longer one.
        mgr.set("key", b"v2".to_vec(), Some(long_ttl)).await.unwrap();
        tokio::time::advance(step).await;

        // 40ms elapsed in total: past the first TTL, well within the second.
        assert_eq!(mgr.get("key").await.unwrap(), Some(b"v2".to_vec()));
    }

    #[cfg(feature = "persistence")]
    mod trueno_kv_tests {
        use super::*;
        use crate::state::TruenoKvStateManager;
        use tokio::time::Instant;

        #[tokio::test]
        async fn test_trueno_kv_basic() {
            let mgr = TruenoKvStateManager::new();

            mgr.set("key1", b"value1".to_vec(), None).await.unwrap();
            assert_eq!(mgr.get("key1").await.unwrap(), Some(b"value1".to_vec()));

            assert!(mgr.exists("key1").await.unwrap());
            assert!(!mgr.exists("key2").await.unwrap());

            mgr.delete("key1").await.unwrap();
            assert!(!mgr.exists("key1").await.unwrap());
        }

        #[tokio::test]
        async fn test_trueno_kv_overwrite() {
            let mgr = TruenoKvStateManager::new();

            for payload in [b"value1", b"value2"] {
                mgr.set("key", payload.to_vec(), None).await.unwrap();
            }

            assert_eq!(mgr.get("key").await.unwrap(), Some(b"value2".to_vec()));
        }

        #[tokio::test]
        async fn test_trueno_kv_with_capacity() {
            let mgr = TruenoKvStateManager::with_capacity(100);
            mgr.set("key", b"value".to_vec(), None).await.unwrap();
            assert_eq!(mgr.get("key").await.unwrap(), Some(b"value".to_vec()));
        }

        #[tokio::test]
        async fn test_trueno_kv_len_and_clear() {
            let mgr = TruenoKvStateManager::new();

            assert!(mgr.is_empty());
            assert_eq!(mgr.len(), 0);

            mgr.set("key1", b"value1".to_vec(), None).await.unwrap();
            assert!(!mgr.is_empty());
            assert_eq!(mgr.len(), 1);

            mgr.set("key2", b"value2".to_vec(), None).await.unwrap();
            assert_eq!(mgr.len(), 2);

            mgr.clear();
            assert!(mgr.is_empty());
        }

        #[test]
        fn test_trueno_kv_default() {
            let mgr = TruenoKvStateManager::default();
            assert!(mgr.is_empty());
        }

        #[tokio::test]
        async fn test_trueno_kv_ttl_expiration() {
            let mgr = TruenoKvStateManager::new();

            mgr.set("ttl_key", b"value".to_vec(), None)
                .await
                .expect("set should succeed");

            let present = mgr
                .exists("ttl_key")
                .await
                .expect("exists check should succeed");
            assert!(present);

            // Force a deadline of "now" so the key counts as already expired.
            mgr.set_expiration_for_test("ttl_key", Instant::now());

            tokio::task::yield_now().await;

            let still_present = mgr
                .exists("ttl_key")
                .await
                .expect("exists check should succeed");
            assert!(!still_present);
        }

        #[tokio::test]
        async fn test_trueno_kv_ttl_no_expiration() {
            let mgr = TruenoKvStateManager::new();

            mgr.set("no_ttl", b"value".to_vec(), None)
                .await
                .expect("set should succeed");

            // A deadline one hour away must not hide the key.
            let an_hour_from_now = Instant::now() + Duration::from_secs(3600);
            mgr.set_expiration_for_test("no_ttl", an_hour_from_now);

            let present = mgr
                .exists("no_ttl")
                .await
                .expect("exists check should succeed");
            assert!(present);

            let fetched = mgr.get("no_ttl").await.expect("get should succeed");
            assert_eq!(fetched, Some(b"value".to_vec()));
        }
    }
}