relay_core_lib/rule/engine/
state.rs1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::Mutex;
6use tokio::time::Instant;
7
8#[async_trait]
9pub trait RuleStateStore: Send + Sync + std::fmt::Debug {
10 async fn increment_counter(&self, key: &str, window: Duration) -> u64;
15
16 async fn get_variable(&self, key: &str) -> Option<String>;
17
18 async fn set_variable(&self, key: &str, value: String, ttl: Option<Duration>);
19}
20
21#[derive(Clone, Debug)]
22pub struct InMemoryRuleStateStore {
23 variables: Arc<Mutex<HashMap<String, VariableState>>>,
24 counters: Arc<Mutex<HashMap<String, CounterState>>>,
25}
26
27#[derive(Clone, Debug)]
28struct CounterState {
29 count: u64,
30 expires_at: Instant,
31}
32
33#[derive(Clone, Debug)]
34struct VariableState {
35 value: String,
36 expires_at: Option<Instant>,
37}
38
39impl Default for InMemoryRuleStateStore {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl InMemoryRuleStateStore {
46 pub fn new() -> Self {
47 Self {
48 variables: Arc::new(Mutex::new(HashMap::new())),
49 counters: Arc::new(Mutex::new(HashMap::new())),
50 }
51 }
52}
53
54#[async_trait]
55impl RuleStateStore for InMemoryRuleStateStore {
56 async fn increment_counter(&self, key: &str, window: Duration) -> u64 {
57 let now = Instant::now();
58 let mut counters = self.counters.lock().await;
59 let window = if window.is_zero() {
60 Duration::from_millis(1)
61 } else {
62 window
63 };
64
65 let entry = counters
66 .entry(key.to_string())
67 .or_insert_with(|| CounterState {
68 count: 0,
69 expires_at: now + window,
70 });
71
72 if now >= entry.expires_at {
73 entry.count = 0;
74 entry.expires_at = now + window;
75 }
76
77 entry.count = entry.count.saturating_add(1);
78 entry.count
79 }
80
81 async fn get_variable(&self, key: &str) -> Option<String> {
82 let mut variables = self.variables.lock().await;
83 if let Some(v) = variables.get(key) {
84 if let Some(exp) = v.expires_at
85 && Instant::now() >= exp
86 {
87 variables.remove(key);
88 return None;
89 }
90 return Some(v.value.clone());
91 }
92 None
93 }
94
95 async fn set_variable(&self, key: &str, value: String, ttl: Option<Duration>) {
96 let expires_at = ttl.and_then(|d| {
97 if d.is_zero() {
98 None
99 } else {
100 Some(Instant::now() + d)
101 }
102 });
103 let mut variables = self.variables.lock().await;
104 variables.insert(key.to_string(), VariableState { value, expires_at });
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::{InMemoryRuleStateStore, RuleStateStore};
    use std::time::Duration;

    /// Counts accumulate inside a window and start over once it has elapsed.
    #[tokio::test]
    async fn test_increment_counter_respects_window() {
        const KEY: &str = "rate:k1";
        let store = InMemoryRuleStateStore::new();
        let window = Duration::from_millis(30);

        let first = store.increment_counter(KEY, window).await;
        let second = store.increment_counter(KEY, window).await;
        assert_eq!(first, 1);
        assert_eq!(second, 2);

        // Sleep past the end of the 30 ms window.
        tokio::time::sleep(Duration::from_millis(40)).await;
        let after_expiry = store.increment_counter(KEY, window).await;
        assert_eq!(after_expiry, 1, "counter should reset after window expires");
    }

    /// Counters under different keys do not affect each other.
    #[tokio::test]
    async fn test_increment_counter_isolated_by_key() {
        let store = InMemoryRuleStateStore::new();
        let window = Duration::from_millis(100);

        let first_a = store.increment_counter("a", window).await;
        let first_b = store.increment_counter("b", window).await;
        let second_a = store.increment_counter("a", window).await;

        assert_eq!(first_a, 1);
        assert_eq!(first_b, 1);
        assert_eq!(second_a, 2);
    }

    /// A variable with a TTL is readable before the TTL and gone after it.
    #[tokio::test]
    async fn test_variable_ttl_expires() {
        let store = InMemoryRuleStateStore::new();
        let ttl = Some(Duration::from_millis(30));
        store.set_variable("k1", "v1".to_string(), ttl).await;

        let before = store.get_variable("k1").await;
        assert_eq!(before.as_deref(), Some("v1"));

        tokio::time::sleep(Duration::from_millis(40)).await;
        let after = store.get_variable("k1").await;
        assert_eq!(after, None);
    }

    /// A variable stored without a TTL is still present after a delay.
    #[tokio::test]
    async fn test_variable_without_ttl_persists() {
        let store = InMemoryRuleStateStore::new();
        store.set_variable("k2", "v2".to_string(), None).await;

        tokio::time::sleep(Duration::from_millis(40)).await;
        let found = store.get_variable("k2").await;
        assert_eq!(found.as_deref(), Some("v2"));
    }

    /// A zero TTL means "no expiry", not "expire immediately".
    #[tokio::test]
    async fn test_variable_zero_ttl_treated_as_no_expiry() {
        let store = InMemoryRuleStateStore::new();
        let ttl = Some(Duration::ZERO);
        store.set_variable("k3", "v3".to_string(), ttl).await;

        tokio::time::sleep(Duration::from_millis(40)).await;
        let found = store.get_variable("k3").await;
        assert_eq!(found.as_deref(), Some("v3"));
    }

    /// Overwriting a variable replaces its expiry as well as its value.
    #[tokio::test]
    async fn test_variable_overwrite_resets_expiry_policy() {
        let store = InMemoryRuleStateStore::new();
        let short_ttl = Some(Duration::from_millis(20));
        store.set_variable("k4", "short".to_string(), short_ttl).await;

        tokio::time::sleep(Duration::from_millis(10)).await;
        store.set_variable("k4", "stable".to_string(), None).await;

        // Total elapsed time now exceeds the original 20 ms TTL.
        tokio::time::sleep(Duration::from_millis(30)).await;
        let found = store.get_variable("k4").await;
        assert_eq!(found.as_deref(), Some("stable"));
    }
}