Skip to main content

relay_core_lib/rule/engine/state.rs

1use async_trait::async_trait;
2use std::time::Duration;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::Mutex;
6use tokio::time::Instant;
7
8#[async_trait]
9pub trait RuleStateStore: Send + Sync + std::fmt::Debug {
10    /// Increment a counter for the given key.
11    /// Returns the new value.
12    /// The window parameter suggests a time window for rate limiting, 
13    /// but in this simple interface it might just set the TTL for the key if it's new.
14    async fn increment_counter(&self, key: &str, window: Duration) -> u64;
15    
16    async fn get_variable(&self, key: &str) -> Option<String>;
17    
18    async fn set_variable(&self, key: &str, value: String, ttl: Option<Duration>);
19}
20
21#[derive(Clone, Debug)]
22pub struct InMemoryRuleStateStore {
23    variables: Arc<Mutex<HashMap<String, VariableState>>>,
24    counters: Arc<Mutex<HashMap<String, CounterState>>>,
25}
26
27#[derive(Clone, Debug)]
28struct CounterState {
29    count: u64,
30    expires_at: Instant,
31}
32
33#[derive(Clone, Debug)]
34struct VariableState {
35    value: String,
36    expires_at: Option<Instant>,
37}
38
39impl Default for InMemoryRuleStateStore {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl InMemoryRuleStateStore {
46    pub fn new() -> Self {
47        Self {
48            variables: Arc::new(Mutex::new(HashMap::new())),
49            counters: Arc::new(Mutex::new(HashMap::new())),
50        }
51    }
52}
53
54#[async_trait]
55impl RuleStateStore for InMemoryRuleStateStore {
56    async fn increment_counter(&self, key: &str, window: Duration) -> u64 {
57        let now = Instant::now();
58        let mut counters = self.counters.lock().await;
59        let window = if window.is_zero() {
60            Duration::from_millis(1)
61        } else {
62            window
63        };
64
65        let entry = counters.entry(key.to_string()).or_insert_with(|| CounterState {
66            count: 0,
67            expires_at: now + window,
68        });
69
70        if now >= entry.expires_at {
71            entry.count = 0;
72            entry.expires_at = now + window;
73        }
74
75        entry.count = entry.count.saturating_add(1);
76        entry.count
77    }
78    
79    async fn get_variable(&self, key: &str) -> Option<String> {
80        let mut variables = self.variables.lock().await;
81        if let Some(v) = variables.get(key) {
82            if let Some(exp) = v.expires_at
83                && Instant::now() >= exp {
84                    variables.remove(key);
85                    return None;
86                }
87            return Some(v.value.clone());
88        }
89        None
90    }
91    
92    async fn set_variable(&self, key: &str, value: String, ttl: Option<Duration>) {
93        let expires_at = ttl.and_then(|d| {
94            if d.is_zero() {
95                None
96            } else {
97                Some(Instant::now() + d)
98            }
99        });
100        let mut variables = self.variables.lock().await;
101        variables.insert(
102            key.to_string(),
103            VariableState {
104                value,
105                expires_at,
106            },
107        );
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::{InMemoryRuleStateStore, RuleStateStore};
    use std::time::Duration;

    /// Shorthand for a millisecond `Duration`.
    fn ms(n: u64) -> Duration {
        Duration::from_millis(n)
    }

    // NOTE: these tests sleep on the real clock, with margins of about 10ms
    // between a TTL/window and the sleep that should outlast it. They can
    // therefore be sensitive to heavily loaded machines.

    #[tokio::test]
    async fn test_increment_counter_respects_window() {
        let store = InMemoryRuleStateStore::new();
        let (key, window) = ("rate:k1", ms(30));

        let first = store.increment_counter(key, window).await;
        let second = store.increment_counter(key, window).await;
        assert_eq!(first, 1);
        assert_eq!(second, 2);

        tokio::time::sleep(ms(40)).await;
        let after_window = store.increment_counter(key, window).await;
        assert_eq!(after_window, 1, "counter should reset after window expires");
    }

    #[tokio::test]
    async fn test_increment_counter_isolated_by_key() {
        let store = InMemoryRuleStateStore::new();
        let window = ms(100);

        let mut observed = Vec::new();
        for key in ["a", "b", "a"] {
            observed.push(store.increment_counter(key, window).await);
        }
        assert_eq!(observed, [1, 1, 2]);
    }

    #[tokio::test]
    async fn test_variable_ttl_expires() {
        let store = InMemoryRuleStateStore::new();
        store.set_variable("k1", "v1".to_string(), Some(ms(30))).await;

        let fresh = store.get_variable("k1").await;
        assert_eq!(fresh.as_deref(), Some("v1"));

        tokio::time::sleep(ms(40)).await;
        assert!(store.get_variable("k1").await.is_none());
    }

    #[tokio::test]
    async fn test_variable_without_ttl_persists() {
        let store = InMemoryRuleStateStore::new();
        store.set_variable("k2", "v2".to_string(), None).await;

        tokio::time::sleep(ms(40)).await;
        let value = store.get_variable("k2").await;
        assert_eq!(value.as_deref(), Some("v2"));
    }

    #[tokio::test]
    async fn test_variable_zero_ttl_treated_as_no_expiry() {
        let store = InMemoryRuleStateStore::new();
        store
            .set_variable("k3", String::from("v3"), Some(Duration::ZERO))
            .await;

        tokio::time::sleep(ms(40)).await;
        let value = store.get_variable("k3").await;
        assert_eq!(value.as_deref(), Some("v3"));
    }

    #[tokio::test]
    async fn test_variable_overwrite_resets_expiry_policy() {
        let store = InMemoryRuleStateStore::new();
        store.set_variable("k4", "short".into(), Some(ms(20))).await;
        tokio::time::sleep(ms(10)).await;
        store.set_variable("k4", "stable".into(), None).await;

        tokio::time::sleep(ms(30)).await;
        let value = store.get_variable("k4").await;
        assert_eq!(value.as_deref(), Some("stable"));
    }
}