1use std::collections::HashMap;
12use std::sync::Mutex;
13use std::time::Instant;
14
15use chio_kernel::{Guard, GuardContext, KernelError, Verdict};
16
/// Fixed-point scale: one whole token is stored as this many milli-tokens.
const MT_PER_TOKEN: u64 = 1_000;

/// Token bucket that tracks its balance in integer milli-tokens.
///
/// The refill rate is kept as the exact ratio `refill_mt_per_window / window_ms`
/// rather than a pre-divided "milli-tokens per millisecond" integer. A
/// pre-divided integer truncates (and, when floored at 1, inflates) every rate
/// that is not a whole number of tokens per second: e.g. 3 requests per 60 s
/// would refill at 1 token per second.
struct TokenBucket {
    /// Maximum balance, in milli-tokens.
    capacity_mt: u64,
    /// Current balance, in milli-tokens; never exceeds `capacity_mt`.
    tokens_mt: u64,
    /// Milli-tokens restored over one full window.
    refill_mt_per_window: u64,
    /// Window length in milliseconds; always at least 1.
    window_ms: u64,
    /// Remainder of the last refill division, in units of
    /// `milli-tokens * ms`; always `< window_ms`.
    carry: u64,
    /// Instant up to which elapsed time has already been credited.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a full bucket.
    ///
    /// * `capacity_tokens` - burst size in whole tokens.
    /// * `max_per_window` - tokens restored per window; `0` is treated as `1`
    ///   so the bucket never stalls permanently.
    /// * `window_secs` - window length in seconds; `0` is treated as 1 ms.
    ///
    /// All conversions to milli-units saturate instead of overflowing.
    fn new(capacity_tokens: u64, max_per_window: u64, window_secs: u64) -> Self {
        let capacity_mt = capacity_tokens.saturating_mul(MT_PER_TOKEN);
        Self {
            capacity_mt,
            tokens_mt: capacity_mt,
            refill_mt_per_window: max_per_window.max(1).saturating_mul(MT_PER_TOKEN),
            window_ms: window_secs.saturating_mul(1_000).max(1),
            carry: 0,
            last_refill: Instant::now(),
        }
    }

    /// Refills from the wall clock, then tries to take `amount_tokens` whole
    /// tokens. Returns `true` and deducts the cost when the balance suffices;
    /// otherwise returns `false` and leaves the balance untouched.
    fn try_consume(&mut self, amount_tokens: u64) -> bool {
        self.refill();
        let cost_mt = amount_tokens.saturating_mul(MT_PER_TOKEN);
        match self.tokens_mt.checked_sub(cost_mt) {
            Some(remaining) => {
                self.tokens_mt = remaining;
                true
            }
            None => false,
        }
    }

    /// Credits the whole milliseconds elapsed since `last_refill`.
    fn refill(&mut self) {
        let elapsed_ms =
            u64::try_from(self.last_refill.elapsed().as_millis()).unwrap_or(u64::MAX);
        if elapsed_ms == 0 {
            return;
        }
        self.credit_elapsed_ms(elapsed_ms);
        // Advance by exactly the time that was credited instead of jumping to
        // a fresh `Instant::now()`, so the sub-millisecond remainder is counted
        // by the next refill rather than being dropped on every call.
        self.last_refill = self
            .last_refill
            .checked_add(std::time::Duration::from_millis(elapsed_ms))
            .unwrap_or_else(Instant::now);
    }

    /// Adds the tokens earned over `elapsed_ms` milliseconds, capped at
    /// capacity. Independent of the clock, which keeps the arithmetic
    /// deterministic.
    fn credit_elapsed_ms(&mut self, elapsed_ms: u64) {
        // u64 * u64 + u64 always fits in u128, so this cannot overflow.
        let numerator = u128::from(elapsed_ms) * u128::from(self.refill_mt_per_window)
            + u128::from(self.carry);
        let window = u128::from(self.window_ms);
        let earned_mt = u64::try_from(numerator / window).unwrap_or(u64::MAX);

        self.tokens_mt = self.tokens_mt.saturating_add(earned_mt).min(self.capacity_mt);
        self.carry = if self.tokens_mt == self.capacity_mt {
            // A full bucket must not bank fractional credit for later.
            0
        } else {
            // The remainder is < window_ms, so it always fits in u64.
            (numerator % window) as u64
        };
    }
}
68
/// Limits applied by [`AgentVelocityGuard`].
#[derive(Debug, Clone)]
pub struct AgentVelocityConfig {
    /// Requests allowed per window for a single agent; `None` disables the
    /// per-agent limit.
    pub max_requests_per_agent: Option<u32>,
    /// Requests allowed per window for a single (agent, capability id) pair;
    /// `None` disables the per-session limit.
    pub max_requests_per_session: Option<u32>,
    /// Length of the rate window in seconds. The guard treats `0` as `1`.
    pub window_secs: u64,
    /// Multiplier applied to a limit to obtain the bucket's burst capacity
    /// (rounded, and never less than one token).
    pub burst_factor: f64,
}

impl Default for AgentVelocityConfig {
    /// No limits configured, a 60-second window, and no extra burst.
    fn default() -> Self {
        let unlimited: Option<u32> = None;
        Self {
            window_secs: 60,
            burst_factor: 1.0,
            max_requests_per_agent: unlimited,
            max_requests_per_session: unlimited,
        }
    }
}
96
97pub struct AgentVelocityGuard {
108 agent_buckets: Mutex<HashMap<String, TokenBucket>>,
109 session_buckets: Mutex<HashMap<(String, String), TokenBucket>>,
110 config: AgentVelocityConfig,
111}
112
113impl AgentVelocityGuard {
114 pub fn new(config: AgentVelocityConfig) -> Self {
116 Self {
117 agent_buckets: Mutex::new(HashMap::new()),
118 session_buckets: Mutex::new(HashMap::new()),
119 config,
120 }
121 }
122}
123
124impl Guard for AgentVelocityGuard {
125 fn name(&self) -> &str {
126 "agent-velocity"
127 }
128
129 fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
130 let agent_id = ctx.agent_id.clone();
131 let cap_id = ctx.request.capability.id.clone();
132 let window_secs = self.config.window_secs.max(1);
133
134 if let Some(max_per_agent) = self.config.max_requests_per_agent {
136 let capacity =
137 ((max_per_agent as f64 * self.config.burst_factor).round() as u64).max(1);
138
139 let mut buckets = self.agent_buckets.lock().map_err(|_| {
140 KernelError::Internal("agent-velocity agent lock poisoned".to_string())
141 })?;
142 let bucket = buckets
143 .entry(agent_id.clone())
144 .or_insert_with(|| TokenBucket::new(capacity, max_per_agent as u64, window_secs));
145 if !bucket.try_consume(1) {
146 return Ok(Verdict::Deny);
147 }
148 }
149
150 if let Some(max_per_session) = self.config.max_requests_per_session {
152 let capacity =
153 ((max_per_session as f64 * self.config.burst_factor).round() as u64).max(1);
154
155 let session_key = (agent_id, cap_id);
156 let mut buckets = self.session_buckets.lock().map_err(|_| {
157 KernelError::Internal("agent-velocity session lock poisoned".to_string())
158 })?;
159 let bucket = buckets
160 .entry(session_key)
161 .or_insert_with(|| TokenBucket::new(capacity, max_per_session as u64, window_secs));
162 if !bucket.try_consume(1) {
163 return Ok(Verdict::Deny);
164 }
165 }
166
167 Ok(Verdict::Allow)
168 }
169}
170
#[cfg(test)]
mod tests {
    use std::thread;
    use std::time::Duration;

    use chio_core::capability::{CapabilityToken, CapabilityTokenBody, ChioScope};
    use chio_core::crypto::Keypair;

    use super::*;

    /// Builds a `read_file` request carrying `cap` for the given agent/server.
    fn make_request(
        cap: &CapabilityToken,
        agent_id: &str,
        server_id: &str,
    ) -> chio_kernel::ToolCallRequest {
        chio_kernel::ToolCallRequest {
            request_id: "req-test".to_string(),
            capability: cap.clone(),
            tool_name: "read_file".to_string(),
            server_id: server_id.to_string(),
            agent_id: agent_id.to_string(),
            arguments: serde_json::json!({}),
            dpop_proof: None,
            governed_intent: None,
            approval_token: None,
            model_metadata: None,
            federated_origin_kernel_id: None,
        }
    }

    /// Signs a self-issued capability with the given id and a default scope.
    fn signed_cap(kp: &Keypair, cap_id: &str) -> CapabilityToken {
        let body = CapabilityTokenBody {
            id: cap_id.to_string(),
            issuer: kp.public_key(),
            subject: kp.public_key(),
            scope: ChioScope::default(),
            issued_at: 0,
            expires_at: u64::MAX,
            delegation_chain: vec![],
        };
        CapabilityToken::sign(body, kp).expect("sign cap")
    }

    /// Wraps the borrowed pieces into a guard context.
    fn guard_ctx<'a>(
        request: &'a chio_kernel::ToolCallRequest,
        scope: &'a ChioScope,
        agent_id: &'a String,
        server_id: &'a String,
    ) -> chio_kernel::GuardContext<'a> {
        chio_kernel::GuardContext {
            request,
            scope,
            agent_id,
            server_id,
            session_filesystem_roots: None,
            matched_grant_index: None,
        }
    }

    /// Guard with the given limits and a burst factor of 1.0.
    fn limited(
        per_agent: Option<u32>,
        per_session: Option<u32>,
        window_secs: u64,
    ) -> AgentVelocityGuard {
        AgentVelocityGuard::new(AgentVelocityConfig {
            max_requests_per_agent: per_agent,
            max_requests_per_session: per_session,
            window_secs,
            burst_factor: 1.0,
        })
    }

    /// Owns everything a guard context borrows for one agent + capability.
    struct Scenario {
        request: chio_kernel::ToolCallRequest,
        scope: ChioScope,
        agent: String,
        server: String,
    }

    impl Scenario {
        /// Request from the agent identified by `agent_kp`, carrying `cap`,
        /// addressed to server `"srv"`.
        fn new(cap: &CapabilityToken, agent_kp: &Keypair) -> Self {
            let agent = agent_kp.public_key().to_hex();
            let server = "srv".to_string();
            let request = make_request(cap, &agent, &server);
            Self {
                request,
                scope: ChioScope::default(),
                agent,
                server,
            }
        }

        /// Fresh guard context over this scenario's data.
        fn ctx(&self) -> chio_kernel::GuardContext<'_> {
            guard_ctx(&self.request, &self.scope, &self.agent, &self.server)
        }

        /// Evaluates the guard once and unwraps the verdict.
        fn verdict(&self, guard: &AgentVelocityGuard) -> Verdict {
            guard.evaluate(&self.ctx()).expect("ok")
        }
    }

    #[test]
    fn guard_name() {
        let guard = AgentVelocityGuard::new(AgentVelocityConfig::default());
        assert_eq!(guard.name(), "agent-velocity");
    }

    #[test]
    fn unlimited_config_allows_all() {
        let guard = AgentVelocityGuard::new(AgentVelocityConfig::default());
        let kp = Keypair::generate();
        let scenario = Scenario::new(&signed_cap(&kp, "cap-1"), &kp);

        for _ in 0..100 {
            let ctx = scenario.ctx();
            let result = guard.evaluate(&ctx).expect("should not error");
            assert_eq!(result, Verdict::Allow);
        }
    }

    #[test]
    fn per_agent_limit_enforced() {
        let guard = limited(Some(3), None, 60);
        let kp = Keypair::generate();
        let scenario = Scenario::new(&signed_cap(&kp, "cap-1"), &kp);

        // The first three requests fit in the bucket.
        for _ in 0..3 {
            assert_eq!(scenario.verdict(&guard), Verdict::Allow);
        }

        // The fourth exceeds the per-agent limit.
        assert_eq!(scenario.verdict(&guard), Verdict::Deny);
    }

    #[test]
    fn per_session_limit_enforced() {
        let guard = limited(None, Some(2), 60);
        let kp = Keypair::generate();
        let scenario = Scenario::new(&signed_cap(&kp, "cap-session"), &kp);

        for _ in 0..2 {
            assert_eq!(scenario.verdict(&guard), Verdict::Allow);
        }

        assert_eq!(scenario.verdict(&guard), Verdict::Deny);
    }

    #[test]
    fn different_agents_get_separate_buckets() {
        let guard = limited(Some(1), None, 60);
        let kp1 = Keypair::generate();
        let kp2 = Keypair::generate();
        let cap = signed_cap(&kp1, "cap-shared");

        // The first agent exhausts its own bucket...
        let first = Scenario::new(&cap, &kp1);
        assert_eq!(first.verdict(&guard), Verdict::Allow);
        assert_eq!(first.verdict(&guard), Verdict::Deny);

        // ...while the second agent still has a full one.
        let second = Scenario::new(&cap, &kp2);
        assert_eq!(second.verdict(&guard), Verdict::Allow);
    }

    #[test]
    fn different_sessions_get_separate_buckets() {
        let guard = limited(None, Some(1), 60);
        let kp = Keypair::generate();
        let cap_a = signed_cap(&kp, "session-a");
        let cap_b = signed_cap(&kp, "session-b");

        // Same agent, first capability: one request, then denied.
        let session_a = Scenario::new(&cap_a, &kp);
        assert_eq!(session_a.verdict(&guard), Verdict::Allow);
        assert_eq!(session_a.verdict(&guard), Verdict::Deny);

        // Same agent, second capability: independent bucket.
        let session_b = Scenario::new(&cap_b, &kp);
        assert_eq!(session_b.verdict(&guard), Verdict::Allow);
    }

    #[test]
    fn tokens_refill_over_time() {
        let guard = limited(Some(1), None, 1);
        let kp = Keypair::generate();
        let scenario = Scenario::new(&signed_cap(&kp, "cap-refill"), &kp);

        assert_eq!(scenario.verdict(&guard), Verdict::Allow);
        assert_eq!(scenario.verdict(&guard), Verdict::Deny);

        // Wait slightly longer than the one-second window.
        thread::sleep(Duration::from_millis(1100));

        assert_eq!(scenario.verdict(&guard), Verdict::Allow);
    }

    #[test]
    fn both_limits_applied() {
        // The session limit (2) is tighter than the agent limit (10).
        let guard = limited(Some(10), Some(2), 60);
        let kp = Keypair::generate();
        let scenario = Scenario::new(&signed_cap(&kp, "cap-both"), &kp);

        for _ in 0..2 {
            assert_eq!(scenario.verdict(&guard), Verdict::Allow);
        }
        assert_eq!(scenario.verdict(&guard), Verdict::Deny);
    }

    #[test]
    fn returns_verdict_deny_not_err() {
        let guard = limited(Some(1), None, 60);
        let kp = Keypair::generate();
        let scenario = Scenario::new(&signed_cap(&kp, "cap-deny-type"), &kp);

        let first_ctx = scenario.ctx();
        guard.evaluate(&first_ctx).expect("ok");

        // Exceeding the limit is reported as a verdict, not as an error.
        let second_ctx = scenario.ctx();
        let result = guard.evaluate(&second_ctx);
        assert!(result.is_ok());
        assert_eq!(result.expect("ok"), Verdict::Deny);
    }
}