1use crate::StreamEvent;
10use anyhow::{anyhow, Result};
11use chrono::{DateTime, Duration, Utc};
12use serde::{Deserialize, Serialize};
13use std::collections::{BTreeMap, HashMap};
14use std::path::PathBuf;
15use std::sync::Arc;
16use tokio::fs;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::sync::RwLock;
19use tracing::{error, info};
20use uuid::Uuid;
21
22#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
24pub enum StateBackend {
25 Memory,
27 File,
29 RocksDB,
31 Redis,
33 Custom,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39#[serde(tag = "type", content = "value")]
40pub enum StateValue {
41 String(String),
42 Integer(i64),
43 Float(f64),
44 Boolean(bool),
45 Binary(Vec<u8>),
46 List(Vec<StateValue>),
47 Map(HashMap<String, StateValue>),
48 Counter(i64),
49 Timestamp(DateTime<Utc>),
50}
51
52impl StateValue {
53 pub fn merge(&self, other: &StateValue) -> Result<StateValue> {
55 match (self, other) {
56 (StateValue::Integer(a), StateValue::Integer(b)) => Ok(StateValue::Integer(a + b)),
57 (StateValue::Float(a), StateValue::Float(b)) => Ok(StateValue::Float(a + b)),
58 (StateValue::Counter(a), StateValue::Counter(b)) => Ok(StateValue::Counter(a + b)),
59 (StateValue::List(a), StateValue::List(b)) => {
60 let mut merged = a.clone();
61 merged.extend(b.clone());
62 Ok(StateValue::List(merged))
63 }
64 (StateValue::Map(a), StateValue::Map(b)) => {
65 let mut merged = a.clone();
66 for (k, v) in b {
67 merged.insert(k.clone(), v.clone());
68 }
69 Ok(StateValue::Map(merged))
70 }
71 _ => Err(anyhow!("Cannot merge incompatible state value types")),
72 }
73 }
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct StateConfig {
79 pub backend: StateBackend,
80 pub checkpoint_interval: Duration,
81 pub checkpoint_path: Option<PathBuf>,
82 pub compaction_interval: Duration,
83 pub ttl: Option<Duration>,
84 pub max_size: Option<usize>,
85 pub enable_changelog: bool,
86 pub enable_metrics: bool,
87}
88
89impl Default for StateConfig {
90 fn default() -> Self {
91 Self {
92 backend: StateBackend::Memory,
93 checkpoint_interval: Duration::minutes(5),
94 checkpoint_path: None,
95 compaction_interval: Duration::hours(1),
96 ttl: None,
97 max_size: Some(1_000_000),
98 enable_changelog: true,
99 enable_metrics: true,
100 }
101 }
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct StateOperation {
107 pub timestamp: DateTime<Utc>,
108 pub key: String,
109 pub operation: StateOperationType,
110 pub value: Option<StateValue>,
111 pub metadata: HashMap<String, String>,
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub enum StateOperationType {
116 Put,
117 Delete,
118 Merge,
119 Clear,
120}
121
122#[derive(Debug, Clone, Default, Serialize, Deserialize)]
124pub struct StateStatistics {
125 pub total_keys: usize,
126 pub total_size_bytes: usize,
127 pub reads: u64,
128 pub writes: u64,
129 pub deletes: u64,
130 pub checkpoints: u64,
131 pub last_checkpoint: Option<DateTime<Utc>>,
132 pub last_compaction: Option<DateTime<Utc>>,
133}
134
135#[async_trait::async_trait]
137pub trait StateStore: Send + Sync {
138 async fn get(&self, key: &str) -> Result<Option<StateValue>>;
140
141 async fn put(&self, key: &str, value: StateValue) -> Result<()>;
143
144 async fn delete(&self, key: &str) -> Result<()>;
146
147 async fn multi_get(&self, keys: &[String]) -> Result<HashMap<String, StateValue>>;
149
150 async fn scan(&self, prefix: &str, limit: Option<usize>) -> Result<Vec<(String, StateValue)>>;
152
153 async fn clear(&self) -> Result<()>;
155
156 async fn checkpoint(&self) -> Result<String>;
158
159 async fn restore(&self, checkpoint_id: &str) -> Result<()>;
161
162 async fn statistics(&self) -> Result<StateStatistics>;
164}
165
166pub struct MemoryStateStore {
168 data: Arc<RwLock<BTreeMap<String, StateValue>>>,
169 changelog: Arc<RwLock<Vec<StateOperation>>>,
170 statistics: Arc<RwLock<StateStatistics>>,
171 config: StateConfig,
172}
173
174impl MemoryStateStore {
175 pub fn new(config: StateConfig) -> Self {
176 Self {
177 data: Arc::new(RwLock::new(BTreeMap::new())),
178 changelog: Arc::new(RwLock::new(Vec::new())),
179 statistics: Arc::new(RwLock::new(StateStatistics::default())),
180 config,
181 }
182 }
183
184 async fn add_to_changelog(&self, operation: StateOperation) {
185 if self.config.enable_changelog {
186 self.changelog.write().await.push(operation);
187 }
188 }
189
190 async fn apply_ttl(&self) {
191 if let Some(ttl) = self.config.ttl {
192 let cutoff = Utc::now() - ttl;
193 let mut data = self.data.write().await;
194 let keys_to_remove: Vec<String> = data
195 .iter()
196 .filter_map(|(k, v)| {
197 if let StateValue::Map(map) = v {
198 if let Some(StateValue::Timestamp(ts)) = map.get("_timestamp") {
199 if *ts < cutoff {
200 return Some(k.clone());
201 }
202 }
203 }
204 None
205 })
206 .collect();
207
208 for key in keys_to_remove {
209 data.remove(&key);
210 }
211 }
212 }
213}
214
215#[async_trait::async_trait]
216impl StateStore for MemoryStateStore {
217 async fn get(&self, key: &str) -> Result<Option<StateValue>> {
218 self.statistics.write().await.reads += 1;
219 let data = self.data.read().await;
220 Ok(data.get(key).cloned())
221 }
222
223 async fn put(&self, key: &str, value: StateValue) -> Result<()> {
224 self.statistics.write().await.writes += 1;
225
226 let mut value_with_ts = value;
228 if self.config.ttl.is_some() {
229 if let StateValue::Map(ref mut map) = value_with_ts {
230 map.insert("_timestamp".to_string(), StateValue::Timestamp(Utc::now()));
231 }
232 }
233
234 self.data
235 .write()
236 .await
237 .insert(key.to_string(), value_with_ts.clone());
238
239 self.add_to_changelog(StateOperation {
240 timestamp: Utc::now(),
241 key: key.to_string(),
242 operation: StateOperationType::Put,
243 value: Some(value_with_ts),
244 metadata: HashMap::new(),
245 })
246 .await;
247
248 if let Some(max_size) = self.config.max_size {
250 let data = self.data.read().await;
251 if data.len() > max_size {
252 drop(data);
253 let mut data = self.data.write().await;
255 let to_remove = data.len() - max_size;
256 let keys_to_remove: Vec<String> = data.keys().take(to_remove).cloned().collect();
257 for key in keys_to_remove {
258 data.remove(&key);
259 }
260 }
261 }
262
263 Ok(())
264 }
265
266 async fn delete(&self, key: &str) -> Result<()> {
267 self.statistics.write().await.deletes += 1;
268 self.data.write().await.remove(key);
269
270 self.add_to_changelog(StateOperation {
271 timestamp: Utc::now(),
272 key: key.to_string(),
273 operation: StateOperationType::Delete,
274 value: None,
275 metadata: HashMap::new(),
276 })
277 .await;
278
279 Ok(())
280 }
281
282 async fn multi_get(&self, keys: &[String]) -> Result<HashMap<String, StateValue>> {
283 let mut stats = self.statistics.write().await;
284 stats.reads += keys.len() as u64;
285 drop(stats);
286
287 let data = self.data.read().await;
288 let mut result = HashMap::new();
289
290 for key in keys {
291 if let Some(value) = data.get(key) {
292 result.insert(key.clone(), value.clone());
293 }
294 }
295
296 Ok(result)
297 }
298
299 async fn scan(&self, prefix: &str, limit: Option<usize>) -> Result<Vec<(String, StateValue)>> {
300 self.statistics.write().await.reads += 1;
301
302 let data = self.data.read().await;
303 let iter = data
304 .range(prefix.to_string()..)
305 .take_while(|(k, _)| k.starts_with(prefix));
306
307 let result: Vec<(String, StateValue)> = match limit {
308 Some(n) => iter.take(n).map(|(k, v)| (k.clone(), v.clone())).collect(),
309 None => iter.map(|(k, v)| (k.clone(), v.clone())).collect(),
310 };
311
312 Ok(result)
313 }
314
315 async fn clear(&self) -> Result<()> {
316 self.data.write().await.clear();
317
318 self.add_to_changelog(StateOperation {
319 timestamp: Utc::now(),
320 key: String::new(),
321 operation: StateOperationType::Clear,
322 value: None,
323 metadata: HashMap::new(),
324 })
325 .await;
326
327 Ok(())
328 }
329
330 async fn checkpoint(&self) -> Result<String> {
331 let checkpoint_id = Uuid::new_v4().to_string();
332
333 if let Some(ref checkpoint_path) = self.config.checkpoint_path {
334 let checkpoint_file = checkpoint_path.join(format!("{checkpoint_id}.checkpoint"));
335
336 let data = self.data.read().await;
338 let checkpoint_data = serde_json::to_vec(&*data)?;
339
340 let mut file = fs::File::create(&checkpoint_file).await?;
342 file.write_all(&checkpoint_data).await?;
343 file.sync_all().await?;
344
345 info!(
346 "Created checkpoint {} at {:?}",
347 checkpoint_id, checkpoint_file
348 );
349 }
350
351 let mut stats = self.statistics.write().await;
352 stats.checkpoints += 1;
353 stats.last_checkpoint = Some(Utc::now());
354
355 Ok(checkpoint_id)
356 }
357
358 async fn restore(&self, checkpoint_id: &str) -> Result<()> {
359 if let Some(ref checkpoint_path) = self.config.checkpoint_path {
360 let checkpoint_file = checkpoint_path.join(format!("{checkpoint_id}.checkpoint"));
361
362 let mut file = fs::File::open(&checkpoint_file).await?;
364 let mut checkpoint_data = Vec::new();
365 file.read_to_end(&mut checkpoint_data).await?;
366
367 let restored_data: BTreeMap<String, StateValue> =
369 serde_json::from_slice(&checkpoint_data)?;
370 *self.data.write().await = restored_data;
371
372 info!("Restored from checkpoint {}", checkpoint_id);
373 } else {
374 return Err(anyhow!("No checkpoint path configured"));
375 }
376
377 Ok(())
378 }
379
380 async fn statistics(&self) -> Result<StateStatistics> {
381 self.apply_ttl().await;
382
383 let mut stats = self.statistics.write().await.clone();
384 let data = self.data.read().await;
385 stats.total_keys = data.len();
386
387 stats.total_size_bytes = data
389 .values()
390 .map(|v| serde_json::to_vec(v).map(|vec| vec.len()).unwrap_or(0))
391 .sum();
392
393 Ok(stats)
394 }
395}
396
397pub struct StateProcessor {
399 stores: HashMap<String, Arc<dyn StateStore>>,
400 default_store: Arc<dyn StateStore>,
401 config: StateConfig,
402 checkpoint_task: Option<tokio::task::JoinHandle<()>>,
403}
404
405impl StateProcessor {
406 pub fn new(config: StateConfig) -> Self {
407 let default_store = Arc::new(MemoryStateStore::new(config.clone())) as Arc<dyn StateStore>;
408
409 Self {
410 stores: HashMap::new(),
411 default_store: default_store.clone(),
412 config,
413 checkpoint_task: None,
414 }
415 }
416
417 pub async fn start_checkpointing(&mut self) {
419 let store = self.default_store.clone();
420 let interval = self.config.checkpoint_interval;
421
422 let task = tokio::spawn(async move {
423 let mut interval_timer = tokio::time::interval(
424 interval
425 .to_std()
426 .expect("checkpoint interval should be valid std Duration"),
427 );
428
429 loop {
430 interval_timer.tick().await;
431
432 match store.checkpoint().await {
433 Ok(checkpoint_id) => {
434 info!("Automatic checkpoint created: {}", checkpoint_id);
435 }
436 Err(e) => {
437 error!("Failed to create checkpoint: {}", e);
438 }
439 }
440 }
441 });
442
443 self.checkpoint_task = Some(task);
444 }
445
446 pub fn stop_checkpointing(&mut self) {
448 if let Some(task) = self.checkpoint_task.take() {
449 task.abort();
450 }
451 }
452
453 pub fn register_store(&mut self, name: String, store: Arc<dyn StateStore>) {
455 self.stores.insert(name, store);
456 }
457
458 pub fn get_store(&self, name: &str) -> Option<Arc<dyn StateStore>> {
460 self.stores.get(name).cloned()
461 }
462
463 pub fn default_store(&self) -> Arc<dyn StateStore> {
465 self.default_store.clone()
466 }
467
468 pub async fn process_with_state<F, R>(
470 &self,
471 event: &StreamEvent,
472 state_key: &str,
473 processor: F,
474 ) -> Result<R>
475 where
476 F: FnOnce(&StreamEvent, Option<StateValue>) -> Result<(R, Option<StateValue>)>,
477 {
478 let current_state = self.default_store.get(state_key).await?;
480
481 let (result, new_state) = processor(event, current_state)?;
483
484 if let Some(state) = new_state {
486 self.default_store.put(state_key, state).await?;
487 }
488
489 Ok(result)
490 }
491}
492
493pub struct StateProcessorBuilder {
495 config: StateConfig,
496 stores: HashMap<String, Arc<dyn StateStore>>,
497}
498
499impl Default for StateProcessorBuilder {
500 fn default() -> Self {
501 Self::new()
502 }
503}
504
505impl StateProcessorBuilder {
506 pub fn new() -> Self {
507 Self {
508 config: StateConfig::default(),
509 stores: HashMap::new(),
510 }
511 }
512
513 pub fn with_backend(mut self, backend: StateBackend) -> Self {
514 self.config.backend = backend;
515 self
516 }
517
518 pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
519 self.config.checkpoint_interval = interval;
520 self
521 }
522
523 pub fn with_checkpoint_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
524 self.config.checkpoint_path = Some(path.into());
525 self
526 }
527
528 pub fn with_ttl(mut self, ttl: Duration) -> Self {
529 self.config.ttl = Some(ttl);
530 self
531 }
532
533 pub fn with_max_size(mut self, max_size: usize) -> Self {
534 self.config.max_size = Some(max_size);
535 self
536 }
537
538 pub fn add_store(mut self, name: String, store: Arc<dyn StateStore>) -> Self {
539 self.stores.insert(name, store);
540 self
541 }
542
543 pub fn build(self) -> StateProcessor {
544 let mut processor = StateProcessor::new(self.config);
545
546 for (name, store) in self.stores {
547 processor.register_store(name, store);
548 }
549
550 processor
551 }
552}
553
554pub mod patterns {
556 use super::*;
557
558 pub async fn increment_counter(
560 store: &dyn StateStore,
561 key: &str,
562 increment: i64,
563 ) -> Result<i64> {
564 let current = store.get(key).await?;
565 let new_value = match current {
566 Some(StateValue::Counter(n)) => n + increment,
567 _ => increment,
568 };
569
570 store.put(key, StateValue::Counter(new_value)).await?;
571 Ok(new_value)
572 }
573
574 pub async fn append_to_list(
576 store: &dyn StateStore,
577 key: &str,
578 value: StateValue,
579 ) -> Result<()> {
580 let current = store.get(key).await?;
581 let mut list = match current {
582 Some(StateValue::List(l)) => l,
583 _ => Vec::new(),
584 };
585
586 list.push(value);
587 store.put(key, StateValue::List(list)).await?;
588 Ok(())
589 }
590
591 pub async fn merge_map(
593 store: &dyn StateStore,
594 key: &str,
595 updates: HashMap<String, StateValue>,
596 ) -> Result<()> {
597 let current = store.get(key).await?;
598 let mut map = match current {
599 Some(StateValue::Map(m)) => m,
600 _ => HashMap::new(),
601 };
602
603 for (k, v) in updates {
604 map.insert(k, v);
605 }
606
607 store.put(key, StateValue::Map(map)).await?;
608 Ok(())
609 }
610
611 pub async fn update_time_window(
613 store: &dyn StateStore,
614 key: &str,
615 value: StateValue,
616 window_duration: Duration,
617 ) -> Result<Vec<StateValue>> {
618 let current = store.get(key).await?;
619 let mut window_data = match current {
620 Some(StateValue::List(l)) => l,
621 _ => Vec::new(),
622 };
623
624 let mut value_with_time = HashMap::new();
626 value_with_time.insert("value".to_string(), value);
627 value_with_time.insert("timestamp".to_string(), StateValue::Timestamp(Utc::now()));
628 window_data.push(StateValue::Map(value_with_time));
629
630 let cutoff = Utc::now() - window_duration;
632 window_data.retain(|v| {
633 if let StateValue::Map(m) = v {
634 if let Some(StateValue::Timestamp(ts)) = m.get("timestamp") {
635 return *ts >= cutoff;
636 }
637 }
638 false
639 });
640
641 store
642 .put(key, StateValue::List(window_data.clone()))
643 .await?;
644 Ok(window_data)
645 }
646}
647
648#[cfg(test)]
649mod tests {
650 use super::*;
651 use crate::event::EventMetadata;
652 use tempfile::TempDir;
653
654 #[tokio::test]
655 async fn test_memory_state_store() {
656 let config = StateConfig::default();
657 let store = MemoryStateStore::new(config);
658
659 store
661 .put("key1", StateValue::String("value1".to_string()))
662 .await
663 .unwrap();
664 let value = store.get("key1").await.unwrap();
665 assert!(matches!(value, Some(StateValue::String(s)) if s == "value1"));
666
667 store.delete("key1").await.unwrap();
669 let value = store.get("key1").await.unwrap();
670 assert!(value.is_none());
671
672 let stats = store.statistics().await.unwrap();
674 assert_eq!(stats.writes, 1);
675 assert_eq!(stats.deletes, 1);
676 }
677
678 #[tokio::test]
679 async fn test_state_ttl() {
680 let config = StateConfig {
681 ttl: Some(Duration::milliseconds(100)),
682 ..Default::default()
683 };
684 let store = MemoryStateStore::new(config);
685
686 let mut map = HashMap::new();
688 map.insert("data".to_string(), StateValue::String("test".to_string()));
689 store.put("key1", StateValue::Map(map)).await.unwrap();
690
691 assert!(store.get("key1").await.unwrap().is_some());
693
694 tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
696
697 let _ = store.statistics().await.unwrap();
699
700 assert!(store.get("key1").await.unwrap().is_none());
702 }
703
704 #[tokio::test]
705 async fn test_checkpoint_restore() {
706 let temp_dir = TempDir::new().unwrap();
707 let config = StateConfig {
708 checkpoint_path: Some(temp_dir.path().to_path_buf()),
709 ..Default::default()
710 };
711
712 let store = MemoryStateStore::new(config.clone());
713
714 store
716 .put("key1", StateValue::String("value1".to_string()))
717 .await
718 .unwrap();
719 store.put("key2", StateValue::Integer(42)).await.unwrap();
720
721 let checkpoint_id = store.checkpoint().await.unwrap();
723
724 store.clear().await.unwrap();
726 assert!(store.get("key1").await.unwrap().is_none());
727
728 store.restore(&checkpoint_id).await.unwrap();
730
731 let value1 = store.get("key1").await.unwrap();
733 assert!(matches!(value1, Some(StateValue::String(s)) if s == "value1"));
734
735 let value2 = store.get("key2").await.unwrap();
736 assert!(matches!(value2, Some(StateValue::Integer(i)) if i == 42));
737 }
738
739 #[tokio::test]
740 async fn test_state_processor() {
741 let processor = StateProcessorBuilder::new()
742 .with_backend(StateBackend::Memory)
743 .build();
744
745 let event = StreamEvent::TripleAdded {
746 subject: "http://example.org/s".to_string(),
747 predicate: "http://example.org/p".to_string(),
748 object: "http://example.org/o".to_string(),
749 graph: None,
750 metadata: EventMetadata::default(),
751 };
752
753 let result = processor
755 .process_with_state(&event, "counter", |_event, state| {
756 let count = match state {
757 Some(StateValue::Counter(n)) => n + 1,
758 _ => 1,
759 };
760 Ok((count, Some(StateValue::Counter(count))))
761 })
762 .await
763 .unwrap();
764
765 assert_eq!(result, 1);
766
767 let result = processor
769 .process_with_state(&event, "counter", |_event, state| {
770 let count = match state {
771 Some(StateValue::Counter(n)) => n + 1,
772 _ => 1,
773 };
774 Ok((count, Some(StateValue::Counter(count))))
775 })
776 .await
777 .unwrap();
778
779 assert_eq!(result, 2);
780 }
781
782 #[tokio::test]
783 async fn test_state_patterns() {
784 let config = StateConfig::default();
785 let store = MemoryStateStore::new(config);
786
787 let count = patterns::increment_counter(&store, "counter1", 5)
789 .await
790 .unwrap();
791 assert_eq!(count, 5);
792
793 let count = patterns::increment_counter(&store, "counter1", 3)
794 .await
795 .unwrap();
796 assert_eq!(count, 8);
797
798 patterns::append_to_list(&store, "list1", StateValue::String("item1".to_string()))
800 .await
801 .unwrap();
802 patterns::append_to_list(&store, "list1", StateValue::String("item2".to_string()))
803 .await
804 .unwrap();
805
806 let list = store.get("list1").await.unwrap();
807 if let Some(StateValue::List(items)) = list {
808 assert_eq!(items.len(), 2);
809 } else {
810 panic!("Expected list");
811 }
812
813 let mut updates = HashMap::new();
815 updates.insert(
816 "field1".to_string(),
817 StateValue::String("value1".to_string()),
818 );
819 updates.insert("field2".to_string(), StateValue::Integer(42));
820
821 patterns::merge_map(&store, "map1", updates).await.unwrap();
822
823 let map = store.get("map1").await.unwrap();
824 if let Some(StateValue::Map(m)) = map {
825 assert_eq!(m.len(), 2);
826 assert!(matches!(m.get("field1"), Some(StateValue::String(s)) if s == "value1"));
827 assert!(matches!(m.get("field2"), Some(StateValue::Integer(i)) if *i == 42));
828 } else {
829 panic!("Expected map");
830 }
831 }
832}