Skip to main content

chio_guards/
agent_velocity.rs

1//! Agent velocity guard -- per-agent and per-session rate limiting.
2//!
3//! Unlike the existing `VelocityGuard` which keys on (capability_id, grant_index),
4//! this guard rate-limits by agent identity and (optionally) session, providing
5//! cross-capability rate limiting for individual agents.
6//!
7//! Uses token-bucket semantics with integer milli-token arithmetic to
8//! avoid floating-point drift. Produces `GuardEvidence` entries and
9//! fails closed on internal errors.
10
11use std::collections::HashMap;
12use std::sync::Mutex;
13use std::time::Instant;
14
15use chio_kernel::{Guard, GuardContext, KernelError, Verdict};
16
17// ---------------------------------------------------------------------------
18// Token bucket (private, same algorithm as velocity.rs)
19// ---------------------------------------------------------------------------
20
/// Milli-tokens per logical token.
const MT_PER_TOKEN: u64 = 1_000;

/// Token bucket using integer milli-token accounting.
///
/// The bucket starts full and is replenished lazily whenever a consume is
/// attempted. The refill rate is kept as an exact fraction
/// (`refill_mt_per_window / window_ms`) rather than a pre-divided integer
/// rate: pre-dividing truncates badly for low limits (for example 3 requests
/// per 60 s is 0.05 milli-tokens per millisecond, which cannot be represented
/// as an integer rate). The division remainder is carried between refills so
/// no fraction of a token is lost or invented over time.
struct TokenBucket {
    /// Maximum balance, in milli-tokens.
    capacity_mt: u64,
    /// Current balance, in milli-tokens; never exceeds `capacity_mt`.
    tokens_mt: u64,
    /// Milli-tokens granted over one full window.
    refill_mt_per_window: u64,
    /// Window length in milliseconds; always at least 1.
    window_ms: u64,
    /// Numerator left over from the previous refill division; always less
    /// than `window_ms`.
    refill_remainder: u64,
    /// Instant up to which elapsed time has already been converted to tokens.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a full bucket.
    ///
    /// * `capacity_tokens` - maximum number of whole tokens the bucket holds.
    /// * `max_per_window` - whole tokens replenished per window. Zero means
    ///   the bucket never refills.
    /// * `window_secs` - window length in seconds; zero is treated as one
    ///   millisecond so the rate is always well defined.
    fn new(capacity_tokens: u64, max_per_window: u64, window_secs: u64) -> Self {
        let capacity_mt = capacity_tokens.saturating_mul(MT_PER_TOKEN);
        Self {
            capacity_mt,
            tokens_mt: capacity_mt,
            refill_mt_per_window: max_per_window.saturating_mul(MT_PER_TOKEN),
            window_ms: window_secs.saturating_mul(1_000).max(1),
            refill_remainder: 0,
            last_refill: Instant::now(),
        }
    }

    /// Refills for the time elapsed up to now, then tries to take
    /// `amount_tokens` whole tokens.
    ///
    /// Returns `true` and deducts the tokens when the balance is sufficient;
    /// otherwise returns `false` and leaves the balance untouched.
    fn try_consume(&mut self, amount_tokens: u64) -> bool {
        self.try_consume_at(amount_tokens, Instant::now())
    }

    /// Same as [`TokenBucket::try_consume`], but with the current time
    /// supplied by the caller so the behaviour can be driven deterministically.
    fn try_consume_at(&mut self, amount_tokens: u64, now: Instant) -> bool {
        self.refill_at(now);
        let cost_mt = amount_tokens.saturating_mul(MT_PER_TOKEN);
        match self.tokens_mt.checked_sub(cost_mt) {
            Some(remaining) => {
                self.tokens_mt = remaining;
                true
            }
            None => false,
        }
    }

    /// Credits the tokens earned between `last_refill` and `now`, capped at
    /// the bucket capacity.
    fn refill_at(&mut self, now: Instant) {
        // Saturates to zero if `now` is earlier than `last_refill`.
        let elapsed = now.saturating_duration_since(self.last_refill);
        // Clamp before narrowing so the cast below is lossless.
        let elapsed_ms = elapsed.as_millis().min(u128::from(u64::MAX)) as u64;
        if elapsed_ms == 0 {
            // Leave `last_refill` alone so sub-millisecond time keeps accruing.
            return;
        }

        // u64 * u64 + u64 always fits in u128, so this cannot overflow.
        let window = u128::from(self.window_ms);
        let numerator = u128::from(elapsed_ms) * u128::from(self.refill_mt_per_window)
            + u128::from(self.refill_remainder);
        let added_mt = (numerator / window).min(u128::from(u64::MAX)) as u64;

        let refilled = self.tokens_mt.saturating_add(added_mt);
        if refilled >= self.capacity_mt {
            // A full bucket earns nothing further, so drop the fraction too.
            self.tokens_mt = self.capacity_mt;
            self.refill_remainder = 0;
        } else {
            self.tokens_mt = refilled;
            // The remainder is strictly below `window_ms`, hence fits in u64.
            self.refill_remainder = (numerator % window) as u64;
        }

        // Advance by exactly the whole milliseconds accounted for, so the
        // sub-millisecond part of `elapsed` is counted by the next refill.
        self.last_refill = self
            .last_refill
            .checked_add(std::time::Duration::from_millis(elapsed_ms))
            .unwrap_or(now);
    }
}
68
69// ---------------------------------------------------------------------------
70// AgentVelocityConfig
71// ---------------------------------------------------------------------------
72
/// Tunable limits for the per-agent velocity guard.
#[derive(Debug, Clone)]
pub struct AgentVelocityConfig {
    /// Request budget per agent for one window. `None` turns the per-agent
    /// check off (unlimited).
    pub max_requests_per_agent: Option<u32>,
    /// Request budget per session for one window. `None` turns the
    /// per-session check off (unlimited).
    pub max_requests_per_session: Option<u32>,
    /// Length of the rate-limit window, in seconds.
    pub window_secs: u64,
    /// Multiplier applied to the per-window budget to size the bucket;
    /// 1.0 allows no burst above the steady rate.
    pub burst_factor: f64,
}

impl Default for AgentVelocityConfig {
    /// Both limits disabled, a 60-second window, and no extra burst.
    fn default() -> Self {
        AgentVelocityConfig {
            window_secs: 60,
            burst_factor: 1.0,
            max_requests_per_session: None,
            max_requests_per_agent: None,
        }
    }
}
96
97// ---------------------------------------------------------------------------
98// AgentVelocityGuard
99// ---------------------------------------------------------------------------
100
101/// Guard that rate-limits by agent identity and session.
102///
103/// Per-agent buckets are keyed by `agent_id`. Per-session buckets are keyed
104/// by `(agent_id, capability_id)` as a session proxy (since the guard context
105/// does not directly expose session IDs, the capability ID serves as a
106/// session-scoped discriminator).
107pub struct AgentVelocityGuard {
108    agent_buckets: Mutex<HashMap<String, TokenBucket>>,
109    session_buckets: Mutex<HashMap<(String, String), TokenBucket>>,
110    config: AgentVelocityConfig,
111}
112
113impl AgentVelocityGuard {
114    /// Create a new guard with the given configuration.
115    pub fn new(config: AgentVelocityConfig) -> Self {
116        Self {
117            agent_buckets: Mutex::new(HashMap::new()),
118            session_buckets: Mutex::new(HashMap::new()),
119            config,
120        }
121    }
122}
123
124impl Guard for AgentVelocityGuard {
125    fn name(&self) -> &str {
126        "agent-velocity"
127    }
128
129    fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
130        let agent_id = ctx.agent_id.clone();
131        let cap_id = ctx.request.capability.id.clone();
132        let window_secs = self.config.window_secs.max(1);
133
134        // Check per-agent rate limit.
135        if let Some(max_per_agent) = self.config.max_requests_per_agent {
136            let capacity =
137                ((max_per_agent as f64 * self.config.burst_factor).round() as u64).max(1);
138
139            let mut buckets = self.agent_buckets.lock().map_err(|_| {
140                KernelError::Internal("agent-velocity agent lock poisoned".to_string())
141            })?;
142            let bucket = buckets
143                .entry(agent_id.clone())
144                .or_insert_with(|| TokenBucket::new(capacity, max_per_agent as u64, window_secs));
145            if !bucket.try_consume(1) {
146                return Ok(Verdict::Deny);
147            }
148        }
149
150        // Check per-session rate limit.
151        if let Some(max_per_session) = self.config.max_requests_per_session {
152            let capacity =
153                ((max_per_session as f64 * self.config.burst_factor).round() as u64).max(1);
154
155            let session_key = (agent_id, cap_id);
156            let mut buckets = self.session_buckets.lock().map_err(|_| {
157                KernelError::Internal("agent-velocity session lock poisoned".to_string())
158            })?;
159            let bucket = buckets
160                .entry(session_key)
161                .or_insert_with(|| TokenBucket::new(capacity, max_per_session as u64, window_secs));
162            if !bucket.try_consume(1) {
163                return Ok(Verdict::Deny);
164            }
165        }
166
167        Ok(Verdict::Allow)
168    }
169}
170
171// ---------------------------------------------------------------------------
172// Tests
173// ---------------------------------------------------------------------------
174
#[cfg(test)]
mod tests {
    use std::thread;
    use std::time::Duration;

    use chio_core::capability::{CapabilityToken, CapabilityTokenBody, ChioScope};
    use chio_core::crypto::Keypair;

    use super::*;

    /// Builds a minimal `read_file` tool call that presents `cap` on behalf
    /// of `agent_id`.
    fn make_request(
        cap: &CapabilityToken,
        agent_id: &str,
        server_id: &str,
    ) -> chio_kernel::ToolCallRequest {
        chio_kernel::ToolCallRequest {
            request_id: String::from("req-test"),
            capability: cap.clone(),
            tool_name: String::from("read_file"),
            server_id: server_id.to_owned(),
            agent_id: agent_id.to_owned(),
            arguments: serde_json::json!({}),
            dpop_proof: None,
            governed_intent: None,
            approval_token: None,
            model_metadata: None,
            federated_origin_kernel_id: None,
        }
    }

    /// Issues a self-signed, never-expiring capability with the given ID and
    /// a default scope.
    fn signed_cap(kp: &Keypair, cap_id: &str) -> CapabilityToken {
        let body = CapabilityTokenBody {
            id: cap_id.to_owned(),
            issuer: kp.public_key(),
            subject: kp.public_key(),
            scope: ChioScope::default(),
            issued_at: 0,
            expires_at: u64::MAX,
            delegation_chain: vec![],
        };
        CapabilityToken::sign(body, kp).expect("sign cap")
    }

    /// Assembles a guard context with no filesystem roots and no matched grant.
    fn guard_ctx<'a>(
        request: &'a chio_kernel::ToolCallRequest,
        scope: &'a ChioScope,
        agent_id: &'a String,
        server_id: &'a String,
    ) -> chio_kernel::GuardContext<'a> {
        chio_kernel::GuardContext {
            request,
            scope,
            agent_id,
            server_id,
            session_filesystem_roots: None,
            matched_grant_index: None,
        }
    }

    /// Guard with the given limits and a burst factor of 1.0.
    fn guard_with(
        per_agent: Option<u32>,
        per_session: Option<u32>,
        window_secs: u64,
    ) -> AgentVelocityGuard {
        AgentVelocityGuard::new(AgentVelocityConfig {
            max_requests_per_agent: per_agent,
            max_requests_per_session: per_session,
            window_secs,
            burst_factor: 1.0,
        })
    }

    /// One agent calling the `srv` server with a single capability.
    struct Caller {
        request: chio_kernel::ToolCallRequest,
        scope: ChioScope,
        agent: String,
        server: String,
    }

    impl Caller {
        /// The agent ID is the hex form of `agent_kp`'s public key.
        fn new(agent_kp: &Keypair, cap: &CapabilityToken) -> Self {
            let agent = agent_kp.public_key().to_hex();
            let server = String::from("srv");
            let request = make_request(cap, &agent, &server);
            Self {
                request,
                scope: ChioScope::default(),
                agent,
                server,
            }
        }

        /// Runs the guard once against a fresh context for this caller.
        fn call(&self, guard: &AgentVelocityGuard) -> Result<Verdict, KernelError> {
            let ctx = guard_ctx(&self.request, &self.scope, &self.agent, &self.server);
            guard.evaluate(&ctx)
        }
    }

    #[test]
    fn guard_name() {
        let guard = AgentVelocityGuard::new(AgentVelocityConfig::default());
        assert_eq!(guard.name(), "agent-velocity");
    }

    #[test]
    fn unlimited_config_allows_all() {
        let guard = AgentVelocityGuard::new(AgentVelocityConfig::default());
        let kp = Keypair::generate();
        let caller = Caller::new(&kp, &signed_cap(&kp, "cap-1"));

        for _ in 0..100 {
            let result = caller.call(&guard).expect("should not error");
            assert_eq!(result, Verdict::Allow);
        }
    }

    #[test]
    fn per_agent_limit_enforced() {
        let guard = guard_with(Some(3), None, 60);
        let kp = Keypair::generate();
        let caller = Caller::new(&kp, &signed_cap(&kp, "cap-1"));

        // The budget of three is admitted in full.
        for _ in 0..3 {
            assert_eq!(caller.call(&guard).expect("ok"), Verdict::Allow);
        }

        // The fourth request exceeds it.
        assert_eq!(caller.call(&guard).expect("ok"), Verdict::Deny);
    }

    #[test]
    fn per_session_limit_enforced() {
        let guard = guard_with(None, Some(2), 60);
        let kp = Keypair::generate();
        let caller = Caller::new(&kp, &signed_cap(&kp, "cap-session"));

        // The budget of two is admitted in full.
        for _ in 0..2 {
            assert_eq!(caller.call(&guard).expect("ok"), Verdict::Allow);
        }

        // The third request exceeds it.
        assert_eq!(caller.call(&guard).expect("ok"), Verdict::Deny);
    }

    #[test]
    fn different_agents_get_separate_buckets() {
        let guard = guard_with(Some(1), None, 60);
        let kp1 = Keypair::generate();
        let kp2 = Keypair::generate();
        // Both agents present the same capability.
        let cap = signed_cap(&kp1, "cap-shared");

        // The first agent drains its own bucket.
        let first = Caller::new(&kp1, &cap);
        assert_eq!(first.call(&guard).expect("ok"), Verdict::Allow);
        assert_eq!(first.call(&guard).expect("ok"), Verdict::Deny);

        // The second agent is unaffected.
        let second = Caller::new(&kp2, &cap);
        assert_eq!(second.call(&guard).expect("ok"), Verdict::Allow);
    }

    #[test]
    fn different_sessions_get_separate_buckets() {
        let guard = guard_with(None, Some(1), 60);
        let kp = Keypair::generate();
        let cap_a = signed_cap(&kp, "session-a");
        let cap_b = signed_cap(&kp, "session-b");

        // Session A is drained.
        let session_a = Caller::new(&kp, &cap_a);
        assert_eq!(session_a.call(&guard).expect("ok"), Verdict::Allow);
        assert_eq!(session_a.call(&guard).expect("ok"), Verdict::Deny);

        // Session B, same agent, starts with a fresh bucket.
        let session_b = Caller::new(&kp, &cap_b);
        assert_eq!(session_b.call(&guard).expect("ok"), Verdict::Allow);
    }

    #[test]
    fn tokens_refill_over_time() {
        let guard = guard_with(Some(1), None, 1);
        let kp = Keypair::generate();
        let caller = Caller::new(&kp, &signed_cap(&kp, "cap-refill"));

        // Drain the single token.
        assert_eq!(caller.call(&guard).expect("ok"), Verdict::Allow);
        assert_eq!(caller.call(&guard).expect("ok"), Verdict::Deny);

        // Slightly more than one full window restores it.
        thread::sleep(Duration::from_millis(1100));

        assert_eq!(caller.call(&guard).expect("ok"), Verdict::Allow);
    }

    #[test]
    fn both_limits_applied() {
        // Agent budget 10, session budget 2: the session limit bites first.
        let guard = guard_with(Some(10), Some(2), 60);
        let kp = Keypair::generate();
        let caller = Caller::new(&kp, &signed_cap(&kp, "cap-both"));

        for _ in 0..2 {
            assert_eq!(caller.call(&guard).expect("ok"), Verdict::Allow);
        }
        // The third request is rejected by the session limit.
        assert_eq!(caller.call(&guard).expect("ok"), Verdict::Deny);
    }

    #[test]
    fn returns_verdict_deny_not_err() {
        let guard = guard_with(Some(1), None, 60);
        let kp = Keypair::generate();
        let caller = Caller::new(&kp, &signed_cap(&kp, "cap-deny-type"));

        caller.call(&guard).expect("ok");

        // Exceeding the limit is a verdict, not an error.
        let result = caller.call(&guard);
        assert!(result.is_ok());
        assert_eq!(result.expect("ok"), Verdict::Deny);
    }
}