Skip to main content

agent_tools_interface/core/
rate.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
5use crate::core::dirs;
6use crate::core::scope::matches_wildcard;
7
8#[derive(Debug, Clone)]
9pub struct RateConfig {
10    /// Map from tool pattern (e.g. "tool:github__*") to rate limit
11    pub limits: HashMap<String, RateLimit>,
12}
13
/// One limit: at most `count` calls inside a sliding window of `window_secs` seconds.
#[derive(Clone, Debug)]
pub struct RateLimit {
    /// Maximum number of calls permitted inside one window.
    pub count: u64,
    /// Length of the window, in seconds.
    pub window_secs: u64,
}
19
20#[derive(Debug, Clone, Serialize, Deserialize, Default)]
21pub struct RateState {
22    /// Map from tool pattern to list of call timestamps (Unix seconds)
23    pub calls: HashMap<String, Vec<u64>>,
24}
25
26#[derive(Debug, thiserror::Error)]
27pub enum RateError {
28    #[error("Rate limit exceeded for '{pattern}': {count}/{window} (limit: {limit}/{window})")]
29    Exceeded {
30        pattern: String,
31        count: u64,
32        limit: u64,
33        window: String,
34    },
35    #[error("Invalid rate spec '{0}': {1}")]
36    InvalidSpec(String, String),
37    #[error("Rate state I/O error: {0}")]
38    Io(#[from] std::io::Error),
39}
40
41/// Parse a rate spec like "10/hour" into a RateLimit.
42pub fn parse_rate_spec(spec: &str) -> Result<RateLimit, RateError> {
43    let parts: Vec<&str> = spec.split('/').collect();
44    if parts.len() != 2 {
45        return Err(RateError::InvalidSpec(
46            spec.to_string(),
47            "expected format: count/unit (e.g. 10/hour)".into(),
48        ));
49    }
50    let count: u64 = parts[0]
51        .parse()
52        .map_err(|_| RateError::InvalidSpec(spec.to_string(), "invalid count".into()))?;
53    let window_secs = dirs::unit_to_secs(parts[1].trim()).ok_or_else(|| {
54        RateError::InvalidSpec(spec.to_string(), format!("unknown unit: {}", parts[1]))
55    })?;
56    Ok(RateLimit { count, window_secs })
57}
58
59/// Parse rate claims from JWT AtiNamespace.rate HashMap.
60/// Format: {"tool:github__*": "10/hour", "tool:*": "100/hour"}
61pub fn parse_rate_config(rate_map: &HashMap<String, String>) -> Result<RateConfig, RateError> {
62    let mut limits = HashMap::new();
63    for (pattern, spec) in rate_map {
64        limits.insert(pattern.clone(), parse_rate_spec(spec)?);
65    }
66    Ok(RateConfig { limits })
67}
68
69/// Check if a tool call is within rate limits and record it.
70/// Returns Ok(()) if allowed, Err(RateError::Exceeded) if rate limited.
71pub fn check_and_record(tool_name: &str, config: &RateConfig) -> Result<(), RateError> {
72    let now = now_secs();
73    let mut state = load_state()?;
74
75    for (pattern, limit) in &config.limits {
76        // Prepend "tool:" to the tool name for pattern matching against rate patterns
77        let tool_scope = format!("tool:{}", tool_name);
78        if matches_wildcard(&tool_scope, pattern) {
79            let calls = state.calls.entry(pattern.clone()).or_default();
80
81            // Prune expired entries
82            let cutoff = now.saturating_sub(limit.window_secs);
83            calls.retain(|&ts| ts > cutoff);
84
85            // Check if over limit
86            if calls.len() as u64 >= limit.count {
87                let count = calls.len() as u64;
88                let limit_count = limit.count;
89                let window_str = format_window(limit.window_secs);
90                let pattern_clone = pattern.clone();
91                let _ = calls;
92                save_state(&state)?;
93                return Err(RateError::Exceeded {
94                    pattern: pattern_clone,
95                    count,
96                    limit: limit_count,
97                    window: window_str,
98                });
99            }
100
101            // Record this call
102            calls.push(now);
103        }
104    }
105
106    save_state(&state)?;
107    Ok(())
108}
109
/// Render a window length for error messages.
///
/// Exact one-second, one-minute, one-hour and one-day windows get their unit name
/// (`"second"`, `"minute"`, `"hour"`, `"day"`). Any other length is rendered as raw seconds
/// with an `s` suffix, e.g. `"7200s"`.
fn format_window(secs: u64) -> String {
    const NAMED_WINDOWS: [(u64, &str); 4] =
        [(1, "second"), (60, "minute"), (3600, "hour"), (86400, "day")];

    NAMED_WINDOWS
        .iter()
        .find(|&&(window, _)| window == secs)
        .map_or_else(|| format!("{secs}s"), |&(_, name)| name.to_owned())
}
119
120fn rate_state_path() -> PathBuf {
121    dirs::ati_dir().join("rate-state.json")
122}
123
124fn load_state() -> Result<RateState, RateError> {
125    let path = rate_state_path();
126    if !path.exists() {
127        return Ok(RateState::default());
128    }
129    let content = std::fs::read_to_string(&path)?;
130    match serde_json::from_str(&content) {
131        Ok(state) => Ok(state),
132        Err(_) => {
133            // Corrupted state file -- reset
134            let _ = std::fs::remove_file(&path);
135            Ok(RateState::default())
136        }
137    }
138}
139
140/// Save state atomically: write to a temp file, then rename into place.
141/// This prevents corruption from concurrent `ati run` invocations.
142fn save_state(state: &RateState) -> Result<(), RateError> {
143    let path = rate_state_path();
144    if let Some(parent) = path.parent() {
145        std::fs::create_dir_all(parent)?;
146    }
147    let content =
148        serde_json::to_string(state).map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
149    let tmp_path = path.with_extension("json.tmp");
150    std::fs::write(&tmp_path, content)?;
151    std::fs::rename(&tmp_path, &path)?;
152    Ok(())
153}
154
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0)
}
161
#[cfg(test)]
mod tests {
    // Stateless unit tests for spec parsing and window formatting.
    use super::*;

    #[test]
    fn test_parse_rate_spec_hour() {
        let rl = parse_rate_spec("10/hour").unwrap();
        assert_eq!(rl.count, 10);
        assert_eq!(rl.window_secs, 3600);
    }

    #[test]
    fn test_parse_rate_spec_minute() {
        let rl = parse_rate_spec("5/minute").unwrap();
        assert_eq!(rl.count, 5);
        assert_eq!(rl.window_secs, 60);
    }

    #[test]
    fn test_parse_rate_spec_second() {
        let rl = parse_rate_spec("1/second").unwrap();
        assert_eq!(rl.count, 1);
        assert_eq!(rl.window_secs, 1);
    }

    #[test]
    fn test_parse_rate_spec_day() {
        let rl = parse_rate_spec("100/day").unwrap();
        assert_eq!(rl.count, 100);
        assert_eq!(rl.window_secs, 86400);
    }

    #[test]
    fn test_parse_rate_spec_short_units() {
        assert_eq!(parse_rate_spec("1/s").unwrap().window_secs, 1);
        assert_eq!(parse_rate_spec("1/m").unwrap().window_secs, 60);
        assert_eq!(parse_rate_spec("1/h").unwrap().window_secs, 3600);
        assert_eq!(parse_rate_spec("1/d").unwrap().window_secs, 86400);
        assert_eq!(parse_rate_spec("1/sec").unwrap().window_secs, 1);
        assert_eq!(parse_rate_spec("1/min").unwrap().window_secs, 60);
        assert_eq!(parse_rate_spec("1/hr").unwrap().window_secs, 3600);
    }

    #[test]
    fn test_parse_rate_spec_zero_count() {
        // A zero count is accepted; it yields a limit that refuses every matching call.
        assert_eq!(parse_rate_spec("0/hour").unwrap().count, 0);
    }

    #[test]
    fn test_parse_rate_spec_invalid() {
        assert!(parse_rate_spec("abc/hour").is_err());
        assert!(parse_rate_spec("10").is_err());
        assert!(parse_rate_spec("10/week").is_err());
        assert!(parse_rate_spec("").is_err());
        assert!(parse_rate_spec("10/hour/extra").is_err());
    }

    #[test]
    fn test_parse_rate_config() {
        let mut map = HashMap::new();
        map.insert("tool:github__*".to_string(), "10/hour".to_string());
        map.insert("tool:*".to_string(), "100/hour".to_string());

        let config = parse_rate_config(&map).unwrap();
        assert_eq!(config.limits.len(), 2);
        assert_eq!(config.limits["tool:github__*"].count, 10);
        assert_eq!(config.limits["tool:*"].count, 100);
    }

    #[test]
    fn test_parse_rate_config_empty() {
        let config = parse_rate_config(&HashMap::new()).unwrap();
        assert!(config.limits.is_empty());
    }

    #[test]
    fn test_parse_rate_config_propagates_invalid_spec() {
        // One bad entry must fail the whole config, reporting the offending spec.
        let mut map = HashMap::new();
        map.insert("tool:*".to_string(), "100/hour".to_string());
        map.insert("tool:github__*".to_string(), "lots/hour".to_string());

        match parse_rate_config(&map) {
            Err(RateError::InvalidSpec(spec, _)) => assert_eq!(spec, "lots/hour"),
            other => panic!("expected InvalidSpec, got {other:?}"),
        }
    }

    // Stateful tests (check_and_record, persistence) are in tests/rate_test.rs
    // to avoid env var races with parallel unit tests.

    #[test]
    fn test_format_window() {
        assert_eq!(format_window(1), "second");
        assert_eq!(format_window(60), "minute");
        assert_eq!(format_window(3600), "hour");
        assert_eq!(format_window(86400), "day");
        assert_eq!(format_window(7200), "7200s");
        assert_eq!(format_window(0), "0s");
    }
}