1use anyhow::Result;
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use sled::Db;
5use std::path::Path;
6use std::time::Duration;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
9pub enum PredicateResult {
10 True,
11 False,
12}
13
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15struct StateEntry {
16 result: PredicateResult,
17 timestamp: DateTime<Utc>,
18}
19
/// Edge between a rule's previously stored predicate result and its newest one.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleTransition {
    /// The result is unchanged (a missing stored result counts as `False`).
    None,
    /// The result moved from `False` to `True`.
    Activated,
    /// The result moved from `True` to `False`.
    Deactivated,
}
26
27#[async_trait]
28pub trait StateStore: Send + Sync {
29 async fn get_last_result(&self, rule_id: &str) -> Result<PredicateResult>;
30 async fn update_result(
31 &self,
32 rule_id: &str,
33 current: PredicateResult,
34 ) -> Result<RuleTransition>;
35 async fn cleanup_expired(&self, rule_id: &str, ttl_seconds: u64) -> Result<bool>;
36 async fn get_last_transition_time(&self, rule_id: &str) -> Result<Option<DateTime<Utc>>>;
37}
38
39pub struct SledStateStore {
40 db: Db,
41}
42
43impl SledStateStore {
44 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
45 let db = sled::open(path)?;
46 Ok(Self { db })
47 }
48
49 pub fn start_cleanup_task(&self, rules_ttl: std::collections::HashMap<String, u64>) {
51 let db = self.db.clone();
52 tokio::spawn(async move {
53 let mut interval = tokio::time::interval(Duration::from_secs(60)); loop {
55 interval.tick().await;
56 for (rule_id, ttl_seconds) in &rules_ttl {
57 if let Err(e) = Self::cleanup_expired_static(&db, rule_id, *ttl_seconds) {
58 tracing::warn!(error = %e, rule_id = %rule_id, "Failed to cleanup expired state");
59 }
60 }
61 }
62 });
63 }
64
65 fn cleanup_expired_static(db: &Db, rule_id: &str, ttl_seconds: u64) -> Result<bool> {
66 let key = format!("rule_state:{}", rule_id);
67 if let Some(bytes) = db.get(&key)? {
68 if let Ok(entry) = serde_json::from_slice::<StateEntry>(&bytes) {
69 let age = Utc::now().signed_duration_since(entry.timestamp);
70 if age.num_seconds() > ttl_seconds as i64 {
71 db.remove(&key)?;
72 return Ok(true);
73 }
74 }
75 }
76 Ok(false)
77 }
78}
79
80#[async_trait]
81impl StateStore for SledStateStore {
82 async fn get_last_result(&self, rule_id: &str) -> Result<PredicateResult> {
83 let key = format!("rule_state:{}", rule_id);
84 let last_bytes = self.db.get(&key)?;
85
86 if let Some(bytes) = last_bytes {
87 if let Ok(entry) = serde_json::from_slice::<StateEntry>(&bytes) {
89 Ok(entry.result)
90 } else {
91 Ok(serde_json::from_slice(&bytes)?)
93 }
94 } else {
95 Ok(PredicateResult::False)
96 }
97 }
98
99 async fn update_result(
100 &self,
101 rule_id: &str,
102 current: PredicateResult,
103 ) -> Result<RuleTransition> {
104 let last_result = self.get_last_result(rule_id).await?;
105
106 let transition = match (last_result, current) {
107 (PredicateResult::False, PredicateResult::True) => RuleTransition::Activated,
108 (PredicateResult::True, PredicateResult::False) => RuleTransition::Deactivated,
109 _ => RuleTransition::None,
110 };
111
112 let key = format!("rule_state:{}", rule_id);
113 let entry = StateEntry {
114 result: current,
115 timestamp: Utc::now(),
116 };
117 let current_bytes = serde_json::to_vec(&entry)?;
118 self.db.insert(key, current_bytes)?;
119
120 Ok(transition)
121 }
122
123 async fn cleanup_expired(&self, rule_id: &str, ttl_seconds: u64) -> Result<bool> {
124 Ok(Self::cleanup_expired_static(
125 &self.db,
126 rule_id,
127 ttl_seconds,
128 )?)
129 }
130
131 async fn get_last_transition_time(&self, rule_id: &str) -> Result<Option<DateTime<Utc>>> {
132 let key = format!("rule_state:{}", rule_id);
133 if let Some(bytes) = self.db.get(&key)? {
134 if let Ok(entry) = serde_json::from_slice::<StateEntry>(&bytes) {
135 return Ok(Some(entry.timestamp));
136 }
137 }
138 Ok(None)
139 }
140}