sentinel_proxy/inference/
manager.rs

1//! Inference rate limit manager
2//!
3//! Manages token-based rate limiters, budgets, and cost calculators per route,
4//! integrating with the request flow for inference endpoints.
5
6use dashmap::DashMap;
7use http::HeaderMap;
8use std::sync::Arc;
9use tracing::{debug, info, trace};
10
11use sentinel_common::budget::{BudgetAlert, BudgetCheckResult, CostResult};
12use sentinel_config::{InferenceConfig, TokenEstimation};
13
14use super::budget::TokenBudgetTracker;
15use super::cost::CostCalculator;
16use super::providers::create_provider;
17use super::rate_limit::{TokenRateLimitResult, TokenRateLimiter};
18use super::tokens::{TokenCounter, TokenEstimate, TokenSource};
19
/// Per-route inference state with rate limiter, budget, and cost tracking.
///
/// One instance is built per registered route and stored behind an `Arc`
/// in the manager's route map; each component is optional and present only
/// when the route's configuration enables it.
struct RouteInferenceState {
    /// Token rate limiter (per-minute); `None` when the route has no
    /// `rate_limit` section configured.
    rate_limiter: Option<TokenRateLimiter>,
    /// Token budget tracker (per-period cumulative); `None` when the route
    /// has no `budget` section configured.
    budget_tracker: Option<TokenBudgetTracker>,
    /// Cost calculator; `None` when cost attribution is not configured.
    /// Note: may be present but disabled (`is_enabled()` is checked at use).
    cost_calculator: Option<CostCalculator>,
    /// Token counter used both to estimate request tokens and to extract
    /// actual token counts from responses. Always present.
    token_counter: TokenCounter,
    /// Route ID for logging.
    // NOTE(review): not read by any method visible in this file — presumably
    // kept for logging/future use; verify before removing.
    route_id: String,
}
33
/// Manager for inference rate limiting, budgets, and cost tracking.
///
/// Each route with inference configuration gets its own TokenRateLimiter,
/// TokenBudgetTracker, and CostCalculator based on the route's configuration.
/// State is held in a concurrent map (`DashMap`), with entries `Arc`-shared
/// so per-route state can be used without holding a map reference.
pub struct InferenceRateLimitManager {
    /// Per-route inference state (keyed by route ID)
    routes: DashMap<String, Arc<RouteInferenceState>>,
}
42
43impl InferenceRateLimitManager {
44    /// Create a new inference rate limit manager
45    pub fn new() -> Self {
46        Self {
47            routes: DashMap::new(),
48        }
49    }
50
51    /// Register a route with inference configuration.
52    ///
53    /// Creates a TokenRateLimiter, TokenBudgetTracker, and CostCalculator
54    /// as configured for the route.
55    pub fn register_route(&self, route_id: &str, config: &InferenceConfig) {
56        let provider = create_provider(&config.provider);
57
58        // Determine estimation method (from rate_limit or default)
59        let estimation_method = config
60            .rate_limit
61            .as_ref()
62            .map(|rl| rl.estimation_method)
63            .unwrap_or(TokenEstimation::Chars);
64
65        let token_counter = TokenCounter::new(provider, estimation_method);
66
67        // Create rate limiter if configured
68        let rate_limiter = config.rate_limit.as_ref().map(|rl| {
69            info!(
70                route_id = route_id,
71                tokens_per_minute = rl.tokens_per_minute,
72                requests_per_minute = ?rl.requests_per_minute,
73                burst_tokens = rl.burst_tokens,
74                "Registered inference rate limiter"
75            );
76            TokenRateLimiter::new(rl.clone())
77        });
78
79        // Create budget tracker if configured
80        let budget_tracker = config.budget.as_ref().map(|budget| {
81            info!(
82                route_id = route_id,
83                period = ?budget.period,
84                limit = budget.limit,
85                enforce = budget.enforce,
86                "Registered token budget tracker"
87            );
88            TokenBudgetTracker::new(budget.clone(), route_id)
89        });
90
91        // Create cost calculator if configured
92        let cost_calculator = config.cost_attribution.as_ref().map(|cost| {
93            info!(
94                route_id = route_id,
95                enabled = cost.enabled,
96                pricing_rules = cost.pricing.len(),
97                "Registered cost calculator"
98            );
99            CostCalculator::new(cost.clone(), route_id)
100        });
101
102        // Only register if at least one feature is enabled
103        if rate_limiter.is_some() || budget_tracker.is_some() || cost_calculator.is_some() {
104            let state = RouteInferenceState {
105                rate_limiter,
106                budget_tracker,
107                cost_calculator,
108                token_counter,
109                route_id: route_id.to_string(),
110            };
111
112            self.routes.insert(route_id.to_string(), Arc::new(state));
113
114            info!(
115                route_id = route_id,
116                provider = ?config.provider,
117                has_rate_limit = config.rate_limit.is_some(),
118                has_budget = config.budget.is_some(),
119                has_cost = config.cost_attribution.is_some(),
120                "Registered inference route"
121            );
122        }
123    }
124
125    /// Check if a route has inference configuration registered.
126    pub fn has_route(&self, route_id: &str) -> bool {
127        self.routes.contains_key(route_id)
128    }
129
130    /// Check if a route has budget tracking enabled.
131    pub fn has_budget(&self, route_id: &str) -> bool {
132        self.routes
133            .get(route_id)
134            .map(|s| s.budget_tracker.is_some())
135            .unwrap_or(false)
136    }
137
138    /// Check if a route has cost attribution enabled.
139    pub fn has_cost_attribution(&self, route_id: &str) -> bool {
140        self.routes
141            .get(route_id)
142            .map(|s| s.cost_calculator.as_ref().map(|c| c.is_enabled()).unwrap_or(false))
143            .unwrap_or(false)
144    }
145
146    /// Check rate limit for a request.
147    ///
148    /// Returns the rate limit result and the estimated token count.
149    pub fn check(
150        &self,
151        route_id: &str,
152        key: &str,
153        headers: &HeaderMap,
154        body: &[u8],
155    ) -> Option<InferenceCheckResult> {
156        let state = self.routes.get(route_id)?;
157
158        // Estimate tokens for the request
159        let estimate = state.token_counter.estimate_request(headers, body);
160
161        trace!(
162            route_id = route_id,
163            key = key,
164            estimated_tokens = estimate.tokens,
165            model = ?estimate.model,
166            "Checking inference rate limit"
167        );
168
169        // Check rate limit if configured
170        let rate_limit_result = if let Some(ref rate_limiter) = state.rate_limiter {
171            rate_limiter.check(key, estimate.tokens)
172        } else {
173            TokenRateLimitResult::Allowed
174        };
175
176        Some(InferenceCheckResult {
177            result: rate_limit_result,
178            estimated_tokens: estimate.tokens,
179            model: estimate.model,
180        })
181    }
182
183    /// Check budget for a request.
184    ///
185    /// Returns the budget check result, or None if no budget is configured.
186    pub fn check_budget(
187        &self,
188        route_id: &str,
189        tenant: &str,
190        estimated_tokens: u64,
191    ) -> Option<BudgetCheckResult> {
192        let state = self.routes.get(route_id)?;
193        let budget_tracker = state.budget_tracker.as_ref()?;
194
195        Some(budget_tracker.check(tenant, estimated_tokens))
196    }
197
198    /// Record budget usage after a request completes.
199    ///
200    /// Returns any budget alerts that were triggered.
201    pub fn record_budget(
202        &self,
203        route_id: &str,
204        tenant: &str,
205        actual_tokens: u64,
206    ) -> Vec<BudgetAlert> {
207        if let Some(state) = self.routes.get(route_id) {
208            if let Some(ref budget_tracker) = state.budget_tracker {
209                return budget_tracker.record(tenant, actual_tokens);
210            }
211        }
212        Vec::new()
213    }
214
215    /// Get budget status for a tenant.
216    pub fn budget_status(
217        &self,
218        route_id: &str,
219        tenant: &str,
220    ) -> Option<sentinel_common::budget::TenantBudgetStatus> {
221        let state = self.routes.get(route_id)?;
222        let budget_tracker = state.budget_tracker.as_ref()?;
223        Some(budget_tracker.status(tenant))
224    }
225
226    /// Calculate cost for a request.
227    ///
228    /// Returns the cost result, or None if cost attribution is not configured.
229    pub fn calculate_cost(
230        &self,
231        route_id: &str,
232        model: &str,
233        input_tokens: u64,
234        output_tokens: u64,
235    ) -> Option<CostResult> {
236        let state = self.routes.get(route_id)?;
237        let cost_calculator = state.cost_calculator.as_ref()?;
238
239        if !cost_calculator.is_enabled() {
240            return None;
241        }
242
243        Some(cost_calculator.calculate(model, input_tokens, output_tokens))
244    }
245
246    /// Record actual token usage from response.
247    ///
248    /// This adjusts the rate limiter based on actual vs estimated usage.
249    pub fn record_actual(
250        &self,
251        route_id: &str,
252        key: &str,
253        headers: &HeaderMap,
254        body: &[u8],
255        estimated_tokens: u64,
256    ) -> Option<TokenEstimate> {
257        let state = self.routes.get(route_id)?;
258
259        // Get actual token count from response
260        let actual = state.token_counter.tokens_from_response(headers, body);
261
262        // Only record if we got actual tokens
263        if actual.tokens > 0 && actual.source != TokenSource::Estimated {
264            // Update rate limiter if configured
265            if let Some(ref rate_limiter) = state.rate_limiter {
266                rate_limiter.record_actual(key, actual.tokens, estimated_tokens);
267            }
268
269            debug!(
270                route_id = route_id,
271                key = key,
272                actual_tokens = actual.tokens,
273                estimated_tokens = estimated_tokens,
274                source = ?actual.source,
275                "Recorded actual token usage"
276            );
277        }
278
279        Some(actual)
280    }
281
282    /// Get the number of registered routes
283    pub fn route_count(&self) -> usize {
284        self.routes.len()
285    }
286
287    /// Get stats for a route.
288    pub fn route_stats(&self, route_id: &str) -> Option<InferenceRouteStats> {
289        let state = self.routes.get(route_id)?;
290
291        // Get rate limit stats if available
292        let (active_keys, tokens_per_minute, requests_per_minute) =
293            if let Some(ref rate_limiter) = state.rate_limiter {
294                let stats = rate_limiter.stats();
295                (stats.active_keys, stats.tokens_per_minute, stats.requests_per_minute)
296            } else {
297                (0, 0, None)
298            };
299
300        Some(InferenceRouteStats {
301            route_id: route_id.to_string(),
302            active_keys,
303            tokens_per_minute,
304            requests_per_minute,
305            has_budget: state.budget_tracker.is_some(),
306            has_cost_attribution: state.cost_calculator.as_ref().map(|c| c.is_enabled()).unwrap_or(false),
307        })
308    }
309
310    /// Clean up idle rate limiters (called periodically)
311    pub fn cleanup(&self) {
312        // Currently, cleanup is handled internally by the rate limiters
313        // This is a hook for future cleanup logic
314        trace!("Inference rate limit cleanup");
315    }
316}
317
318impl Default for InferenceRateLimitManager {
319    fn default() -> Self {
320        Self::new()
321    }
322}
323
/// Result of an inference rate limit check.
///
/// Carries both the rate-limit decision and the token estimate it was based
/// on, so the caller can later reconcile the estimate against actual usage.
#[derive(Debug)]
pub struct InferenceCheckResult {
    /// Rate limit decision
    pub result: TokenRateLimitResult,
    /// Estimated tokens for this request (the value the limiter was charged)
    pub estimated_tokens: u64,
    /// Model name if it could be detected from the request
    pub model: Option<String>,
}
334
335impl InferenceCheckResult {
336    /// Returns true if the request is allowed
337    pub fn is_allowed(&self) -> bool {
338        self.result.is_allowed()
339    }
340
341    /// Get retry-after value in milliseconds (0 if allowed)
342    pub fn retry_after_ms(&self) -> u64 {
343        self.result.retry_after_ms()
344    }
345}
346
/// Stats snapshot for a route's inference configuration.
///
/// Produced by `InferenceRateLimitManager::route_stats`; rate-limit fields
/// are zero/`None` when the route has no limiter configured.
#[derive(Debug, Clone)]
pub struct InferenceRouteStats {
    /// Route ID
    pub route_id: String,
    /// Number of active rate limit keys
    pub active_keys: usize,
    /// Configured tokens per minute (0 if no rate limiting)
    pub tokens_per_minute: u64,
    /// Configured requests per minute (if any)
    pub requests_per_minute: Option<u64>,
    /// Whether budget tracking is enabled
    pub has_budget: bool,
    /// Whether cost attribution is configured and enabled
    pub has_cost_attribution: bool,
}
363
364#[cfg(test)]
365mod tests {
366    use super::*;
367    use sentinel_config::{InferenceProvider, TokenRateLimit};
368
369    fn test_inference_config() -> InferenceConfig {
370        InferenceConfig {
371            provider: InferenceProvider::OpenAi,
372            model_header: None,
373            rate_limit: Some(TokenRateLimit {
374                tokens_per_minute: 10000,
375                requests_per_minute: Some(100),
376                burst_tokens: 2000,
377                estimation_method: TokenEstimation::Chars,
378            }),
379            budget: None,
380            cost_attribution: None,
381            routing: None,
382            model_routing: None,
383            guardrails: None,
384        }
385    }
386
387    #[test]
388    fn test_register_route() {
389        let manager = InferenceRateLimitManager::new();
390        manager.register_route("test-route", &test_inference_config());
391
392        assert!(manager.has_route("test-route"));
393        assert!(!manager.has_route("other-route"));
394    }
395
396    #[test]
397    fn test_check_rate_limit() {
398        let manager = InferenceRateLimitManager::new();
399        manager.register_route("test-route", &test_inference_config());
400
401        let headers = HeaderMap::new();
402        let body = br#"{"messages": [{"content": "Hello world"}]}"#;
403
404        let result = manager.check("test-route", "client-1", &headers, body);
405        assert!(result.is_some());
406
407        let check = result.unwrap();
408        assert!(check.is_allowed());
409        assert!(check.estimated_tokens > 0);
410    }
411
412    #[test]
413    fn test_no_rate_limit_config() {
414        let manager = InferenceRateLimitManager::new();
415
416        // Config without any features should not register
417        let config = InferenceConfig {
418            provider: InferenceProvider::OpenAi,
419            model_header: None,
420            rate_limit: None,
421            budget: None,
422            cost_attribution: None,
423            routing: None,
424            model_routing: None,
425            guardrails: None,
426        };
427        manager.register_route("no-limit-route", &config);
428
429        assert!(!manager.has_route("no-limit-route"));
430    }
431
432    #[test]
433    fn test_budget_only_config() {
434        use sentinel_common::budget::{BudgetPeriod, TokenBudgetConfig};
435
436        let manager = InferenceRateLimitManager::new();
437
438        let config = InferenceConfig {
439            provider: InferenceProvider::OpenAi,
440            model_header: None,
441            rate_limit: None,
442            budget: Some(TokenBudgetConfig {
443                period: BudgetPeriod::Daily,
444                limit: 100000,
445                alert_thresholds: vec![0.80, 0.90],
446                enforce: true,
447                rollover: false,
448                burst_allowance: None,
449            }),
450            cost_attribution: None,
451            routing: None,
452            model_routing: None,
453            guardrails: None,
454        };
455        manager.register_route("budget-route", &config);
456
457        assert!(manager.has_route("budget-route"));
458        assert!(manager.has_budget("budget-route"));
459        assert!(!manager.has_cost_attribution("budget-route"));
460    }
461}