relay_core_lib/rule/engine/
state.rs1use async_trait::async_trait;
2use std::time::Duration;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::Mutex;
6use tokio::time::Instant;
7
8#[async_trait]
9pub trait RuleStateStore: Send + Sync + std::fmt::Debug {
10 async fn increment_counter(&self, key: &str, window: Duration) -> u64;
15
16 async fn get_variable(&self, key: &str) -> Option<String>;
17
18 async fn set_variable(&self, key: &str, value: String, ttl: Option<Duration>);
19}
20
21#[derive(Clone, Debug)]
22pub struct InMemoryRuleStateStore {
23 variables: Arc<Mutex<HashMap<String, VariableState>>>,
24 counters: Arc<Mutex<HashMap<String, CounterState>>>,
25}
26
27#[derive(Clone, Debug)]
28struct CounterState {
29 count: u64,
30 expires_at: Instant,
31}
32
33#[derive(Clone, Debug)]
34struct VariableState {
35 value: String,
36 expires_at: Option<Instant>,
37}
38
39impl Default for InMemoryRuleStateStore {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl InMemoryRuleStateStore {
46 pub fn new() -> Self {
47 Self {
48 variables: Arc::new(Mutex::new(HashMap::new())),
49 counters: Arc::new(Mutex::new(HashMap::new())),
50 }
51 }
52}
53
54#[async_trait]
55impl RuleStateStore for InMemoryRuleStateStore {
56 async fn increment_counter(&self, key: &str, window: Duration) -> u64 {
57 let now = Instant::now();
58 let mut counters = self.counters.lock().await;
59 let window = if window.is_zero() {
60 Duration::from_millis(1)
61 } else {
62 window
63 };
64
65 let entry = counters.entry(key.to_string()).or_insert_with(|| CounterState {
66 count: 0,
67 expires_at: now + window,
68 });
69
70 if now >= entry.expires_at {
71 entry.count = 0;
72 entry.expires_at = now + window;
73 }
74
75 entry.count = entry.count.saturating_add(1);
76 entry.count
77 }
78
79 async fn get_variable(&self, key: &str) -> Option<String> {
80 let mut variables = self.variables.lock().await;
81 if let Some(v) = variables.get(key) {
82 if let Some(exp) = v.expires_at
83 && Instant::now() >= exp {
84 variables.remove(key);
85 return None;
86 }
87 return Some(v.value.clone());
88 }
89 None
90 }
91
92 async fn set_variable(&self, key: &str, value: String, ttl: Option<Duration>) {
93 let expires_at = ttl.and_then(|d| {
94 if d.is_zero() {
95 None
96 } else {
97 Some(Instant::now() + d)
98 }
99 });
100 let mut variables = self.variables.lock().await;
101 variables.insert(
102 key.to_string(),
103 VariableState {
104 value,
105 expires_at,
106 },
107 );
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::{InMemoryRuleStateStore, RuleStateStore};
    use std::time::Duration;

    /// Increments inside one window accumulate; once the window lapses the count restarts at 1.
    #[tokio::test]
    async fn test_increment_counter_respects_window() {
        const KEY: &str = "rate:k1";
        let window = Duration::from_millis(30);
        let store = InMemoryRuleStateStore::new();

        let before_expiry = [
            store.increment_counter(KEY, window).await,
            store.increment_counter(KEY, window).await,
        ];
        assert_eq!(before_expiry, [1, 2]);

        tokio::time::sleep(Duration::from_millis(40)).await;
        let after_expiry = store.increment_counter(KEY, window).await;
        assert_eq!(after_expiry, 1, "counter should reset after window expires");
    }

    /// Counters under different keys never share a count.
    #[tokio::test]
    async fn test_increment_counter_isolated_by_key() {
        let window = Duration::from_millis(100);
        let store = InMemoryRuleStateStore::new();

        let mut observed = Vec::new();
        for key in ["a", "b", "a"] {
            observed.push(store.increment_counter(key, window).await);
        }
        assert_eq!(observed, vec![1, 1, 2]);
    }

    /// A variable with a TTL is readable until the TTL elapses, then reads as absent.
    #[tokio::test]
    async fn test_variable_ttl_expires() {
        let store = InMemoryRuleStateStore::new();
        let ttl = Some(Duration::from_millis(30));
        store.set_variable("k1", String::from("v1"), ttl).await;

        let fresh = store.get_variable("k1").await;
        assert_eq!(fresh.as_deref(), Some("v1"));

        tokio::time::sleep(Duration::from_millis(40)).await;
        assert!(store.get_variable("k1").await.is_none());
    }

    /// Without a TTL a variable survives the wait.
    #[tokio::test]
    async fn test_variable_without_ttl_persists() {
        let store = InMemoryRuleStateStore::new();
        store.set_variable("k2", String::from("v2"), None).await;

        tokio::time::sleep(Duration::from_millis(40)).await;
        let value = store.get_variable("k2").await;
        assert_eq!(value.as_deref(), Some("v2"));
    }

    /// A zero TTL behaves like no TTL rather than an immediate expiry.
    #[tokio::test]
    async fn test_variable_zero_ttl_treated_as_no_expiry() {
        let store = InMemoryRuleStateStore::new();
        store
            .set_variable("k3", String::from("v3"), Some(Duration::ZERO))
            .await;

        tokio::time::sleep(Duration::from_millis(40)).await;
        let value = store.get_variable("k3").await;
        assert_eq!(value.as_deref(), Some("v3"));
    }

    /// Overwriting a variable replaces its expiry too: the later no-TTL write wins.
    #[tokio::test]
    async fn test_variable_overwrite_resets_expiry_policy() {
        let store = InMemoryRuleStateStore::new();
        let short_ttl = Some(Duration::from_millis(20));
        store.set_variable("k4", String::from("short"), short_ttl).await;
        tokio::time::sleep(Duration::from_millis(10)).await;
        store.set_variable("k4", String::from("stable"), None).await;

        tokio::time::sleep(Duration::from_millis(30)).await;
        let value = store.get_variable("k4").await;
        assert_eq!(value.as_deref(), Some("stable"));
    }
}