// robinpath_modules/modules/ratelimit_mod.rs

1use robinpath::{RobinPath, Value};
2use std::sync::{Arc, Mutex};
3use std::collections::HashMap;
4
/// State of one named rate limiter held in the module's registry.
///
/// Every timestamp field is a millisecond value taken from `now_ms`.
enum Limiter {
    /// Bucket of up to `max_tokens` tokens, refilled continuously at `refill_rate`
    /// tokens per second of elapsed time.
    TokenBucket {
        /// Capacity of the bucket; a freshly created bucket starts full.
        max_tokens: f64,
        /// Tokens added per second of elapsed time.
        refill_rate: f64,
        /// Current token level (may be fractional).
        tokens: f64,
        /// Time at which `tokens` was last brought up to date.
        last_refill: u64,
    },
    /// Admits at most `max_requests` units within any trailing `window_ms` span.
    SlidingWindow {
        /// Upper bound on units admitted inside one window.
        max_requests: u64,
        /// Width of the trailing window in milliseconds.
        window_ms: u64,
        /// One timestamp per admitted unit.
        requests: Vec<u64>,
    },
    /// Admits at most `max_requests` units per window of `window_ms` milliseconds that
    /// begins at `window_start`.
    FixedWindow {
        /// Upper bound on units admitted inside one window.
        max_requests: u64,
        /// Length of each window in milliseconds.
        window_ms: u64,
        /// Start time of the current window.
        window_start: u64,
        /// Units admitted so far in the current window.
        count: u64,
    },
}
24
/// Current wall-clock time in whole milliseconds since the Unix epoch.
///
/// If the system clock reads earlier than the epoch, 0 is returned instead of failing.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
31
32pub fn register(rp: &mut RobinPath) {
33    let state = Arc::new(Mutex::new(HashMap::<String, Limiter>::new()));
34
35    // ratelimit.create name type options → config
36    let s = state.clone();
37    rp.register_builtin("ratelimit.create", move |args, _| {
38        let name = args.first().map(|v| v.to_display_string()).unwrap_or_else(|| "default".to_string());
39        let limiter_type = args.get(1).map(|v| v.to_display_string()).unwrap_or_else(|| "token-bucket".to_string());
40        let opts = args.get(2).cloned().unwrap_or(Value::Null);
41        let now = now_ms();
42        let limiter = match limiter_type.as_str() {
43            "sliding-window" => {
44                let window_ms = get_opt_num(&opts, "windowMs", 60000.0) as u64;
45                let max_requests = get_opt_num(&opts, "maxRequests", 100.0) as u64;
46                Limiter::SlidingWindow { requests: Vec::new(), window_ms, max_requests }
47            }
48            "fixed-window" => {
49                let window_ms = get_opt_num(&opts, "windowMs", 60000.0) as u64;
50                let max_requests = get_opt_num(&opts, "maxRequests", 100.0) as u64;
51                Limiter::FixedWindow { count: 0, window_ms, max_requests, window_start: now }
52            }
53            _ => {
54                let max_tokens = get_opt_num(&opts, "maxTokens", 10.0);
55                let refill_rate = get_opt_num(&opts, "refillRate", 1.0);
56                Limiter::TokenBucket { tokens: max_tokens, max_tokens, refill_rate, last_refill: now }
57            }
58        };
59        s.lock().unwrap().insert(name.clone(), limiter);
60        let mut obj = indexmap::IndexMap::new();
61        obj.insert("name".to_string(), Value::String(name));
62        obj.insert("type".to_string(), Value::String(limiter_type));
63        Ok(Value::Object(obj))
64    });
65
66    // ratelimit.acquire name count? → bool
67    let s = state.clone();
68    rp.register_builtin("ratelimit.acquire", move |args, _| {
69        let name = args.first().map(|v| v.to_display_string()).unwrap_or_else(|| "default".to_string());
70        let count = args.get(1).map(|v| v.to_number()).unwrap_or(1.0);
71        let mut limiters = s.lock().unwrap();
72        let limiter = limiters.get_mut(&name)
73            .ok_or_else(|| format!("Rate limiter \"{}\" not found", name))?;
74        let now = now_ms();
75        let allowed = match limiter {
76            Limiter::TokenBucket { tokens, max_tokens, refill_rate, last_refill } => {
77                let elapsed = (now - *last_refill) as f64 / 1000.0;
78                *tokens = (*tokens + elapsed * *refill_rate).min(*max_tokens);
79                *last_refill = now;
80                if *tokens >= count {
81                    *tokens -= count;
82                    true
83                } else {
84                    false
85                }
86            }
87            Limiter::SlidingWindow { requests, window_ms, max_requests } => {
88                requests.retain(|&t| now - t < *window_ms);
89                if (requests.len() as u64) < *max_requests {
90                    for _ in 0..count as u64 {
91                        requests.push(now);
92                    }
93                    true
94                } else {
95                    false
96                }
97            }
98            Limiter::FixedWindow { count: c, window_ms, max_requests, window_start } => {
99                if now - *window_start >= *window_ms {
100                    *c = 0;
101                    *window_start = now;
102                }
103                if *c + count as u64 <= *max_requests {
104                    *c += count as u64;
105                    true
106                } else {
107                    false
108                }
109            }
110        };
111        Ok(Value::Bool(allowed))
112    });
113
114    // ratelimit.check name → bool
115    let s = state.clone();
116    rp.register_builtin("ratelimit.check", move |args, _| {
117        let name = args.first().map(|v| v.to_display_string()).unwrap_or_else(|| "default".to_string());
118        let mut limiters = s.lock().unwrap();
119        let limiter = limiters.get_mut(&name)
120            .ok_or_else(|| format!("Rate limiter \"{}\" not found", name))?;
121        let now = now_ms();
122        let allowed = match limiter {
123            Limiter::TokenBucket { tokens, max_tokens, refill_rate, last_refill } => {
124                let elapsed = (now - *last_refill) as f64 / 1000.0;
125                let current = (*tokens + elapsed * *refill_rate).min(*max_tokens);
126                current >= 1.0
127            }
128            Limiter::SlidingWindow { requests, window_ms, max_requests } => {
129                let active = requests.iter().filter(|&&t| now - t < *window_ms).count() as u64;
130                active < *max_requests
131            }
132            Limiter::FixedWindow { count, window_ms, max_requests, window_start } => {
133                (if now - *window_start >= *window_ms { 0 } else { *count }) < *max_requests
134            }
135        };
136        Ok(Value::Bool(allowed))
137    });
138
139    // ratelimit.remaining name → number
140    let s = state.clone();
141    rp.register_builtin("ratelimit.remaining", move |args, _| {
142        let name = args.first().map(|v| v.to_display_string()).unwrap_or_else(|| "default".to_string());
143        let mut limiters = s.lock().unwrap();
144        let limiter = limiters.get_mut(&name)
145            .ok_or_else(|| format!("Rate limiter \"{}\" not found", name))?;
146        let now = now_ms();
147        let remaining = match limiter {
148            Limiter::TokenBucket { tokens, max_tokens, refill_rate, last_refill } => {
149                let elapsed = (now - *last_refill) as f64 / 1000.0;
150                (*tokens + elapsed * *refill_rate).min(*max_tokens).floor()
151            }
152            Limiter::SlidingWindow { requests, window_ms, max_requests } => {
153                let active = requests.iter().filter(|&&t| now - t < *window_ms).count() as f64;
154                (*max_requests as f64 - active).max(0.0)
155            }
156            Limiter::FixedWindow { count, window_ms, max_requests, window_start } => {
157                let c = if now - *window_start >= *window_ms { 0 } else { *count };
158                (*max_requests - c) as f64
159            }
160        };
161        Ok(Value::Number(remaining))
162    });
163
164    // ratelimit.reset name → bool
165    let s = state.clone();
166    rp.register_builtin("ratelimit.reset", move |args, _| {
167        let name = args.first().map(|v| v.to_display_string()).unwrap_or_else(|| "default".to_string());
168        let mut limiters = s.lock().unwrap();
169        let limiter = limiters.get_mut(&name)
170            .ok_or_else(|| format!("Rate limiter \"{}\" not found", name))?;
171        let now = now_ms();
172        match limiter {
173            Limiter::TokenBucket { tokens, max_tokens, last_refill, .. } => {
174                *tokens = *max_tokens;
175                *last_refill = now;
176            }
177            Limiter::SlidingWindow { requests, .. } => { requests.clear(); }
178            Limiter::FixedWindow { count, window_start, .. } => {
179                *count = 0;
180                *window_start = now;
181            }
182        }
183        Ok(Value::Bool(true))
184    });
185
186    // ratelimit.status name → {name, remaining, ...}
187    let s = state.clone();
188    rp.register_builtin("ratelimit.status", move |args, _| {
189        let name = args.first().map(|v| v.to_display_string()).unwrap_or_else(|| "default".to_string());
190        let mut limiters = s.lock().unwrap();
191        let limiter = limiters.get_mut(&name)
192            .ok_or_else(|| format!("Rate limiter \"{}\" not found", name))?;
193        let now = now_ms();
194        let mut obj = indexmap::IndexMap::new();
195        obj.insert("name".to_string(), Value::String(name));
196        match limiter {
197            Limiter::TokenBucket { tokens, max_tokens, refill_rate, last_refill } => {
198                let elapsed = (now - *last_refill) as f64 / 1000.0;
199                let current = (*tokens + elapsed * *refill_rate).min(*max_tokens);
200                obj.insert("type".to_string(), Value::String("token-bucket".to_string()));
201                obj.insert("tokens".to_string(), Value::Number(current.floor()));
202                obj.insert("maxTokens".to_string(), Value::Number(*max_tokens));
203                obj.insert("refillRate".to_string(), Value::Number(*refill_rate));
204            }
205            Limiter::SlidingWindow { requests, window_ms, max_requests } => {
206                let active = requests.iter().filter(|&&t| now - t < *window_ms).count() as f64;
207                obj.insert("type".to_string(), Value::String("sliding-window".to_string()));
208                obj.insert("used".to_string(), Value::Number(active));
209                obj.insert("remaining".to_string(), Value::Number((*max_requests as f64 - active).max(0.0)));
210                obj.insert("maxRequests".to_string(), Value::Number(*max_requests as f64));
211            }
212            Limiter::FixedWindow { count, window_ms, max_requests, window_start } => {
213                let c = if now - *window_start >= *window_ms { 0 } else { *count };
214                obj.insert("type".to_string(), Value::String("fixed-window".to_string()));
215                obj.insert("used".to_string(), Value::Number(c as f64));
216                obj.insert("remaining".to_string(), Value::Number((*max_requests - c) as f64));
217                obj.insert("maxRequests".to_string(), Value::Number(*max_requests as f64));
218            }
219        }
220        Ok(Value::Object(obj))
221    });
222
223    // ratelimit.destroy name → bool
224    let s = state.clone();
225    rp.register_builtin("ratelimit.destroy", move |args, _| {
226        let name = args.first().map(|v| v.to_display_string()).unwrap_or_else(|| "default".to_string());
227        Ok(Value::Bool(s.lock().unwrap().remove(&name).is_some()))
228    });
229}
230
231fn get_opt_num(opts: &Value, key: &str, default: f64) -> f64 {
232    if let Value::Object(obj) = opts {
233        obj.get(key).map(|v| v.to_number()).unwrap_or(default)
234    } else {
235        default
236    }
237}