fuse_rule/
state.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use sled::Db;
5use std::path::Path;
6use std::time::Duration;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
9pub enum PredicateResult {
10    True,
11    False,
12}
13
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15struct StateEntry {
16    result: PredicateResult,
17    timestamp: DateTime<Utc>,
18}
19
/// Edge observed when a rule's stored predicate result is replaced by a new one.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleTransition {
    /// The stored result did not change.
    None,
    /// The result moved from `False` to `True`.
    Activated,
    /// The result moved from `True` to `False`.
    Deactivated,
}
26
27#[async_trait]
28pub trait StateStore: Send + Sync {
29    async fn get_last_result(&self, rule_id: &str) -> Result<PredicateResult>;
30    async fn update_result(
31        &self,
32        rule_id: &str,
33        current: PredicateResult,
34    ) -> Result<RuleTransition>;
35    async fn cleanup_expired(&self, rule_id: &str, ttl_seconds: u64) -> Result<bool>;
36    async fn get_last_transition_time(&self, rule_id: &str) -> Result<Option<DateTime<Utc>>>;
37}
38
39pub struct SledStateStore {
40    db: Db,
41}
42
43impl SledStateStore {
44    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
45        let db = sled::open(path)?;
46        Ok(Self { db })
47    }
48
49    /// Background task to clean up expired state entries
50    pub fn start_cleanup_task(&self, rules_ttl: std::collections::HashMap<String, u64>) {
51        let db = self.db.clone();
52        tokio::spawn(async move {
53            let mut interval = tokio::time::interval(Duration::from_secs(60)); // Run every minute
54            loop {
55                interval.tick().await;
56                for (rule_id, ttl_seconds) in &rules_ttl {
57                    if let Err(e) = Self::cleanup_expired_static(&db, rule_id, *ttl_seconds) {
58                        tracing::warn!(error = %e, rule_id = %rule_id, "Failed to cleanup expired state");
59                    }
60                }
61            }
62        });
63    }
64
65    fn cleanup_expired_static(db: &Db, rule_id: &str, ttl_seconds: u64) -> Result<bool> {
66        let key = format!("rule_state:{}", rule_id);
67        if let Some(bytes) = db.get(&key)? {
68            if let Ok(entry) = serde_json::from_slice::<StateEntry>(&bytes) {
69                let age = Utc::now().signed_duration_since(entry.timestamp);
70                if age.num_seconds() > ttl_seconds as i64 {
71                    db.remove(&key)?;
72                    return Ok(true);
73                }
74            }
75        }
76        Ok(false)
77    }
78}
79
80#[async_trait]
81impl StateStore for SledStateStore {
82    async fn get_last_result(&self, rule_id: &str) -> Result<PredicateResult> {
83        let key = format!("rule_state:{}", rule_id);
84        let last_bytes = self.db.get(&key)?;
85
86        if let Some(bytes) = last_bytes {
87            // Try to deserialize as StateEntry (new format with timestamp)
88            if let Ok(entry) = serde_json::from_slice::<StateEntry>(&bytes) {
89                Ok(entry.result)
90            } else {
91                // Fallback to old format (just PredicateResult)
92                Ok(serde_json::from_slice(&bytes)?)
93            }
94        } else {
95            Ok(PredicateResult::False)
96        }
97    }
98
99    async fn update_result(
100        &self,
101        rule_id: &str,
102        current: PredicateResult,
103    ) -> Result<RuleTransition> {
104        let last_result = self.get_last_result(rule_id).await?;
105
106        let transition = match (last_result, current) {
107            (PredicateResult::False, PredicateResult::True) => RuleTransition::Activated,
108            (PredicateResult::True, PredicateResult::False) => RuleTransition::Deactivated,
109            _ => RuleTransition::None,
110        };
111
112        let key = format!("rule_state:{}", rule_id);
113        let entry = StateEntry {
114            result: current,
115            timestamp: Utc::now(),
116        };
117        let current_bytes = serde_json::to_vec(&entry)?;
118        self.db.insert(key, current_bytes)?;
119
120        Ok(transition)
121    }
122
123    async fn cleanup_expired(&self, rule_id: &str, ttl_seconds: u64) -> Result<bool> {
124        Ok(Self::cleanup_expired_static(
125            &self.db,
126            rule_id,
127            ttl_seconds,
128        )?)
129    }
130
131    async fn get_last_transition_time(&self, rule_id: &str) -> Result<Option<DateTime<Utc>>> {
132        let key = format!("rule_state:{}", rule_id);
133        if let Some(bytes) = self.db.get(&key)? {
134            if let Ok(entry) = serde_json::from_slice::<StateEntry>(&bytes) {
135                return Ok(Some(entry.timestamp));
136            }
137        }
138        Ok(None)
139    }
140}