agent_tools_interface/core/
rate.rs1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
5use crate::core::dirs;
6use crate::core::scope::matches_wildcard;
7
8#[derive(Debug, Clone)]
9pub struct RateConfig {
10 pub limits: HashMap<String, RateLimit>,
12}
13
/// A single limit: at most `count` calls within a sliding window of `window_secs` seconds.
///
/// This is a small plain-data value, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimit {
    /// Maximum number of calls allowed inside the window (a value of 0 blocks every call).
    pub count: u64,
    /// Window length in seconds.
    pub window_secs: u64,
}
19
20#[derive(Debug, Clone, Serialize, Deserialize, Default)]
21pub struct RateState {
22 pub calls: HashMap<String, Vec<u64>>,
24}
25
26#[derive(Debug, thiserror::Error)]
27pub enum RateError {
28 #[error("Rate limit exceeded for '{pattern}': {count}/{window} (limit: {limit}/{window})")]
29 Exceeded {
30 pattern: String,
31 count: u64,
32 limit: u64,
33 window: String,
34 },
35 #[error("Invalid rate spec '{0}': {1}")]
36 InvalidSpec(String, String),
37 #[error("Rate state I/O error: {0}")]
38 Io(#[from] std::io::Error),
39}
40
41pub fn parse_rate_spec(spec: &str) -> Result<RateLimit, RateError> {
43 let parts: Vec<&str> = spec.split('/').collect();
44 if parts.len() != 2 {
45 return Err(RateError::InvalidSpec(
46 spec.to_string(),
47 "expected format: count/unit (e.g. 10/hour)".into(),
48 ));
49 }
50 let count: u64 = parts[0]
51 .parse()
52 .map_err(|_| RateError::InvalidSpec(spec.to_string(), "invalid count".into()))?;
53 let window_secs = dirs::unit_to_secs(parts[1].trim()).ok_or_else(|| {
54 RateError::InvalidSpec(spec.to_string(), format!("unknown unit: {}", parts[1]))
55 })?;
56 Ok(RateLimit { count, window_secs })
57}
58
59pub fn parse_rate_config(rate_map: &HashMap<String, String>) -> Result<RateConfig, RateError> {
62 let mut limits = HashMap::new();
63 for (pattern, spec) in rate_map {
64 limits.insert(pattern.clone(), parse_rate_spec(spec)?);
65 }
66 Ok(RateConfig { limits })
67}
68
69pub fn check_and_record(tool_name: &str, config: &RateConfig) -> Result<(), RateError> {
72 let now = now_secs();
73 let mut state = load_state()?;
74
75 for (pattern, limit) in &config.limits {
76 let tool_scope = format!("tool:{}", tool_name);
78 if matches_wildcard(&tool_scope, pattern) {
79 let calls = state.calls.entry(pattern.clone()).or_default();
80
81 let cutoff = now.saturating_sub(limit.window_secs);
83 calls.retain(|&ts| ts > cutoff);
84
85 if calls.len() as u64 >= limit.count {
87 let count = calls.len() as u64;
88 let limit_count = limit.count;
89 let window_str = format_window(limit.window_secs);
90 let pattern_clone = pattern.clone();
91 let _ = calls;
92 save_state(&state)?;
93 return Err(RateError::Exceeded {
94 pattern: pattern_clone,
95 count,
96 limit: limit_count,
97 window: window_str,
98 });
99 }
100
101 calls.push(now);
103 }
104 }
105
106 save_state(&state)?;
107 Ok(())
108}
109
/// Renders a window length for error messages.
///
/// The canonical single units (1, 60, 3600 and 86400 seconds) become their names.
/// Any other length is shown as `<secs>s`.
fn format_window(secs: u64) -> String {
    let unit_name = match secs {
        1 => "second",
        60 => "minute",
        3600 => "hour",
        86400 => "day",
        other => return format!("{other}s"),
    };
    unit_name.to_string()
}
119
120fn rate_state_path() -> PathBuf {
121 dirs::ati_dir().join("rate-state.json")
122}
123
124fn load_state() -> Result<RateState, RateError> {
125 let path = rate_state_path();
126 if !path.exists() {
127 return Ok(RateState::default());
128 }
129 let content = std::fs::read_to_string(&path)?;
130 match serde_json::from_str(&content) {
131 Ok(state) => Ok(state),
132 Err(_) => {
133 let _ = std::fs::remove_file(&path);
135 Ok(RateState::default())
136 }
137 }
138}
139
140fn save_state(state: &RateState) -> Result<(), RateError> {
143 let path = rate_state_path();
144 if let Some(parent) = path.parent() {
145 std::fs::create_dir_all(parent)?;
146 }
147 let content =
148 serde_json::to_string(state).map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
149 let tmp_path = path.with_extension("json.tmp");
150 std::fs::write(&tmp_path, content)?;
151 std::fs::rename(&tmp_path, &path)?;
152 Ok(())
153}
154
/// Current Unix time in whole seconds; 0 if the system clock is before the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `spec` and checks both the call budget and the window length.
    fn expect_limit(spec: &str, count: u64, window_secs: u64) {
        let limit = parse_rate_spec(spec).unwrap();
        assert_eq!(limit.count, count, "count for {spec:?}");
        assert_eq!(limit.window_secs, window_secs, "window for {spec:?}");
    }

    #[test]
    fn test_parse_rate_spec_hour() {
        expect_limit("10/hour", 10, 3600);
    }

    #[test]
    fn test_parse_rate_spec_minute() {
        expect_limit("5/minute", 5, 60);
    }

    #[test]
    fn test_parse_rate_spec_second() {
        expect_limit("1/second", 1, 1);
    }

    #[test]
    fn test_parse_rate_spec_day() {
        expect_limit("100/day", 100, 86400);
    }

    #[test]
    fn test_parse_rate_spec_short_units() {
        let cases: [(&str, u64); 7] = [
            ("1/s", 1),
            ("1/m", 60),
            ("1/h", 3600),
            ("1/d", 86400),
            ("1/sec", 1),
            ("1/min", 60),
            ("1/hr", 3600),
        ];
        for &(spec, window_secs) in cases.iter() {
            assert_eq!(parse_rate_spec(spec).unwrap().window_secs, window_secs, "spec {spec:?}");
        }
    }

    #[test]
    fn test_parse_rate_spec_invalid() {
        let malformed = ["abc/hour", "10", "10/week", "", "10/hour/extra"];
        for spec in malformed.iter() {
            assert!(parse_rate_spec(spec).is_err(), "expected {spec:?} to be rejected");
        }
    }

    #[test]
    fn test_parse_rate_config() {
        let map: HashMap<String, String> = [("tool:github__*", "10/hour"), ("tool:*", "100/hour")]
            .iter()
            .map(|&(pattern, spec)| (pattern.to_string(), spec.to_string()))
            .collect();

        let config = parse_rate_config(&map).unwrap();
        assert_eq!(config.limits.len(), 2);
        assert_eq!(config.limits["tool:github__*"].count, 10);
        assert_eq!(config.limits["tool:*"].count, 100);
    }

    #[test]
    fn test_format_window() {
        let cases = [
            (1, "second"),
            (60, "minute"),
            (3600, "hour"),
            (86400, "day"),
            (7200, "7200s"),
        ];
        for &(secs, expected) in cases.iter() {
            assert_eq!(format_window(secs), expected);
        }
    }
}