Skip to main content

relay_core_lib/rule/engine/
state.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::Mutex;
6use tokio::time::Instant;
7
8#[async_trait]
9pub trait RuleStateStore: Send + Sync + std::fmt::Debug {
10    /// Increment a counter for the given key.
11    /// Returns the new value.
12    /// The window parameter suggests a time window for rate limiting,
13    /// but in this simple interface it might just set the TTL for the key if it's new.
14    async fn increment_counter(&self, key: &str, window: Duration) -> u64;
15
16    async fn get_variable(&self, key: &str) -> Option<String>;
17
18    async fn set_variable(&self, key: &str, value: String, ttl: Option<Duration>);
19}
20
21#[derive(Clone, Debug)]
22pub struct InMemoryRuleStateStore {
23    variables: Arc<Mutex<HashMap<String, VariableState>>>,
24    counters: Arc<Mutex<HashMap<String, CounterState>>>,
25}
26
/// Per-key state of a fixed-window counter.
#[derive(Debug, Clone)]
struct CounterState {
    /// Number of increments recorded in the current window.
    count: u64,
    /// End of the current window; the next increment at or after this instant restarts the count.
    expires_at: Instant,
}
32
/// A stored variable together with its optional expiry deadline.
#[derive(Debug, Clone)]
struct VariableState {
    /// The stored value.
    value: String,
    /// Instant from which the value is no longer readable; `None` means it never expires.
    expires_at: Option<Instant>,
}
38
39impl Default for InMemoryRuleStateStore {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl InMemoryRuleStateStore {
46    pub fn new() -> Self {
47        Self {
48            variables: Arc::new(Mutex::new(HashMap::new())),
49            counters: Arc::new(Mutex::new(HashMap::new())),
50        }
51    }
52}
53
54#[async_trait]
55impl RuleStateStore for InMemoryRuleStateStore {
56    async fn increment_counter(&self, key: &str, window: Duration) -> u64 {
57        let now = Instant::now();
58        let mut counters = self.counters.lock().await;
59        let window = if window.is_zero() {
60            Duration::from_millis(1)
61        } else {
62            window
63        };
64
65        let entry = counters
66            .entry(key.to_string())
67            .or_insert_with(|| CounterState {
68                count: 0,
69                expires_at: now + window,
70            });
71
72        if now >= entry.expires_at {
73            entry.count = 0;
74            entry.expires_at = now + window;
75        }
76
77        entry.count = entry.count.saturating_add(1);
78        entry.count
79    }
80
81    async fn get_variable(&self, key: &str) -> Option<String> {
82        let mut variables = self.variables.lock().await;
83        if let Some(v) = variables.get(key) {
84            if let Some(exp) = v.expires_at
85                && Instant::now() >= exp
86            {
87                variables.remove(key);
88                return None;
89            }
90            return Some(v.value.clone());
91        }
92        None
93    }
94
95    async fn set_variable(&self, key: &str, value: String, ttl: Option<Duration>) {
96        let expires_at = ttl.and_then(|d| {
97            if d.is_zero() {
98                None
99            } else {
100                Some(Instant::now() + d)
101            }
102        });
103        let mut variables = self.variables.lock().await;
104        variables.insert(key.to_string(), VariableState { value, expires_at });
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::{InMemoryRuleStateStore, RuleStateStore};
    use std::time::Duration;

    /// Shorthand for the millisecond durations used throughout these timing-based tests.
    fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    #[tokio::test]
    async fn test_increment_counter_respects_window() {
        let store = InMemoryRuleStateStore::new();
        let (key, window) = ("rate:k1", ms(30));

        let first = store.increment_counter(key, window).await;
        let second = store.increment_counter(key, window).await;
        assert_eq!((first, second), (1, 2));

        // Outlive the window so the next increment opens a fresh one.
        tokio::time::sleep(ms(40)).await;
        let after_reset = store.increment_counter(key, window).await;
        assert_eq!(after_reset, 1, "counter should reset after window expires");
    }

    #[tokio::test]
    async fn test_increment_counter_isolated_by_key() {
        let store = InMemoryRuleStateStore::new();
        let window = ms(100);

        let mut observed = Vec::new();
        for key in ["a", "b", "a"] {
            observed.push(store.increment_counter(key, window).await);
        }
        assert_eq!(observed, vec![1, 1, 2]);
    }

    #[tokio::test]
    async fn test_variable_ttl_expires() {
        let store = InMemoryRuleStateStore::new();
        store.set_variable("k1", String::from("v1"), Some(ms(30))).await;

        let before = store.get_variable("k1").await;
        assert_eq!(before.as_deref(), Some("v1"));

        tokio::time::sleep(ms(40)).await;
        assert!(store.get_variable("k1").await.is_none());
    }

    #[tokio::test]
    async fn test_variable_without_ttl_persists() {
        let store = InMemoryRuleStateStore::new();
        store.set_variable("k2", String::from("v2"), None).await;

        tokio::time::sleep(ms(40)).await;
        let value = store.get_variable("k2").await;
        assert_eq!(value.as_deref(), Some("v2"));
    }

    #[tokio::test]
    async fn test_variable_zero_ttl_treated_as_no_expiry() {
        let store = InMemoryRuleStateStore::new();
        store.set_variable("k3", String::from("v3"), Some(Duration::ZERO)).await;

        tokio::time::sleep(ms(40)).await;
        let value = store.get_variable("k3").await;
        assert_eq!(value.as_deref(), Some("v3"));
    }

    #[tokio::test]
    async fn test_variable_overwrite_resets_expiry_policy() {
        let store = InMemoryRuleStateStore::new();
        store.set_variable("k4", String::from("short"), Some(ms(20))).await;
        tokio::time::sleep(ms(10)).await;
        // Re-setting without a TTL must discard the earlier 20ms expiry.
        store.set_variable("k4", String::from("stable"), None).await;

        tokio::time::sleep(ms(30)).await;
        let value = store.get_variable("k4").await;
        assert_eq!(value.as_deref(), Some("stable"));
    }
}