Skip to main content

ati/core/
rate.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
5use crate::core::dirs;
6use crate::core::scope::matches_wildcard;
7
/// Parsed rate-limit configuration: one [`RateLimit`] per tool scope pattern.
///
/// An empty (default) config imposes no limits.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RateConfig {
    /// Map from tool pattern (e.g. "tool:github:*") to rate limit
    pub limits: HashMap<String, RateLimit>,
}

/// A sliding-window limit: at most `count` calls within any `window_secs`-second window.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimit {
    /// Maximum number of calls admitted within one window.
    pub count: u64,
    /// Window length in seconds.
    pub window_secs: u64,
}
19
20#[derive(Debug, Clone, Serialize, Deserialize, Default)]
21pub struct RateState {
22    /// Map from tool pattern to list of call timestamps (Unix seconds)
23    pub calls: HashMap<String, Vec<u64>>,
24}
25
26#[derive(Debug, thiserror::Error)]
27pub enum RateError {
28    #[error("Rate limit exceeded for '{pattern}': {count}/{window} (limit: {limit}/{window})")]
29    Exceeded {
30        pattern: String,
31        count: u64,
32        limit: u64,
33        window: String,
34    },
35    #[error("Invalid rate spec '{0}': {1}")]
36    InvalidSpec(String, String),
37    #[error("Rate state I/O error: {0}")]
38    Io(#[from] std::io::Error),
39}
40
41/// Parse a rate spec like "10/hour" into a RateLimit.
42pub fn parse_rate_spec(spec: &str) -> Result<RateLimit, RateError> {
43    let parts: Vec<&str> = spec.split('/').collect();
44    if parts.len() != 2 {
45        return Err(RateError::InvalidSpec(
46            spec.to_string(),
47            "expected format: count/unit (e.g. 10/hour)".into(),
48        ));
49    }
50    let count: u64 = parts[0]
51        .parse()
52        .map_err(|_| RateError::InvalidSpec(spec.to_string(), "invalid count".into()))?;
53    let window_secs = dirs::unit_to_secs(parts[1].trim()).ok_or_else(|| {
54        RateError::InvalidSpec(spec.to_string(), format!("unknown unit: {}", parts[1]))
55    })?;
56    Ok(RateLimit { count, window_secs })
57}
58
59/// Parse rate claims from JWT AtiNamespace.rate HashMap.
60/// Format: {"tool:github:*": "10/hour", "tool:*": "100/hour"}
61pub fn parse_rate_config(rate_map: &HashMap<String, String>) -> Result<RateConfig, RateError> {
62    let mut limits = HashMap::new();
63    for (pattern, spec) in rate_map {
64        limits.insert(pattern.clone(), parse_rate_spec(spec)?);
65    }
66    Ok(RateConfig { limits })
67}
68
69/// Check if a tool call is within rate limits and record it.
70/// Returns Ok(()) if allowed, Err(RateError::Exceeded) if rate limited.
71pub fn check_and_record(tool_name: &str, config: &RateConfig) -> Result<(), RateError> {
72    let now = now_secs();
73    let mut state = load_state()?;
74
75    for (pattern, limit) in &config.limits {
76        // Prepend "tool:" to the tool name for pattern matching against rate patterns
77        let tool_scope = format!("tool:{}", tool_name);
78        if matches_wildcard(&tool_scope, pattern) {
79            let calls = state.calls.entry(pattern.clone()).or_default();
80
81            // Prune expired entries
82            let cutoff = now.saturating_sub(limit.window_secs);
83            calls.retain(|&ts| ts > cutoff);
84
85            // Check if over limit
86            if calls.len() as u64 >= limit.count {
87                let count = calls.len() as u64;
88                let limit_count = limit.count;
89                let window_str = format_window(limit.window_secs);
90                let pattern_clone = pattern.clone();
91                let _ = calls;
92                save_state(&state)?;
93                return Err(RateError::Exceeded {
94                    pattern: pattern_clone,
95                    count,
96                    limit: limit_count,
97                    window: window_str,
98                });
99            }
100
101            // Record this call
102            calls.push(now);
103        }
104    }
105
106    save_state(&state)?;
107    Ok(())
108}
109
/// Human-readable label for a window length, used in rate-limit error messages.
///
/// A window that exactly equals a named unit renders as that unit's name
/// (`"second"`, `"minute"`, `"hour"`, `"day"`). Any other length renders as the
/// raw number of seconds with an `s` suffix, e.g. `"7200s"`.
fn format_window(secs: u64) -> String {
    const NAMED_WINDOWS: [(u64, &str); 4] = [
        (1, "second"),
        (60, "minute"),
        (3_600, "hour"),
        (86_400, "day"),
    ];
    NAMED_WINDOWS
        .iter()
        .find(|&&(len, _)| len == secs)
        .map_or_else(|| format!("{secs}s"), |&(_, name)| name.to_owned())
}
119
120fn rate_state_path() -> PathBuf {
121    dirs::ati_dir().join("rate-state.json")
122}
123
124fn load_state() -> Result<RateState, RateError> {
125    let path = rate_state_path();
126    if !path.exists() {
127        return Ok(RateState::default());
128    }
129    let content = std::fs::read_to_string(&path)?;
130    match serde_json::from_str(&content) {
131        Ok(state) => Ok(state),
132        Err(_) => {
133            // Corrupted state file -- reset
134            let _ = std::fs::remove_file(&path);
135            Ok(RateState::default())
136        }
137    }
138}
139
140/// Save state atomically: write to a temp file, then rename into place.
141/// This prevents corruption from concurrent `ati run` invocations.
142fn save_state(state: &RateState) -> Result<(), RateError> {
143    let path = rate_state_path();
144    if let Some(parent) = path.parent() {
145        std::fs::create_dir_all(parent)?;
146    }
147    let content = serde_json::to_string(state).map_err(std::io::Error::other)?;
148    let tmp_path = path.with_extension("json.tmp");
149    std::fs::write(&tmp_path, content)?;
150    std::fs::rename(&tmp_path, &path)?;
151    Ok(())
152}
153
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `spec` and return `(count, window_secs)`, panicking if it is rejected.
    fn parts_of(spec: &str) -> (u64, u64) {
        let limit = parse_rate_spec(spec).unwrap();
        (limit.count, limit.window_secs)
    }

    #[test]
    fn test_parse_rate_spec_hour() {
        assert_eq!(parts_of("10/hour"), (10, 3600));
    }

    #[test]
    fn test_parse_rate_spec_minute() {
        assert_eq!(parts_of("5/minute"), (5, 60));
    }

    #[test]
    fn test_parse_rate_spec_second() {
        assert_eq!(parts_of("1/second"), (1, 1));
    }

    #[test]
    fn test_parse_rate_spec_day() {
        assert_eq!(parts_of("100/day"), (100, 86400));
    }

    #[test]
    fn test_parse_rate_spec_short_units() {
        let cases = [
            ("1/s", 1),
            ("1/m", 60),
            ("1/h", 3600),
            ("1/d", 86400),
            ("1/sec", 1),
            ("1/min", 60),
            ("1/hr", 3600),
        ];
        for (spec, expected_secs) in cases {
            assert_eq!(parts_of(spec).1, expected_secs, "spec {spec:?}");
        }
    }

    #[test]
    fn test_parse_rate_spec_invalid() {
        for spec in ["abc/hour", "10", "10/week", "", "10/hour/extra"] {
            assert!(parse_rate_spec(spec).is_err(), "expected {spec:?} to be rejected");
        }
    }

    #[test]
    fn test_parse_rate_config() {
        let map: HashMap<String, String> = [("tool:github:*", "10/hour"), ("tool:*", "100/hour")]
            .into_iter()
            .map(|(pattern, spec)| (pattern.to_string(), spec.to_string()))
            .collect();

        let config = parse_rate_config(&map).unwrap();
        assert_eq!(config.limits.len(), 2);
        assert_eq!(config.limits["tool:github:*"].count, 10);
        assert_eq!(config.limits["tool:*"].count, 100);
    }

    // Stateful tests (check_and_record, persistence) are in tests/rate_test.rs
    // to avoid env var races with parallel unit tests.

    #[test]
    fn test_format_window() {
        let cases = [
            (1, "second"),
            (60, "minute"),
            (3600, "hour"),
            (86400, "day"),
            (7200, "7200s"),
        ];
        for (secs, expected) in cases {
            assert_eq!(format_window(secs), expected);
        }
    }
}