Skip to main content

grapsus_proxy/inference/
manager.rs

//! Inference rate limit manager
//!
//! Manages per-route token rate limiters, token budgets, and cost calculators,
//! and exposes the check/record hooks that the request flow calls for inference
//! endpoints.
5
6use dashmap::DashMap;
7use http::HeaderMap;
8use std::sync::Arc;
9use tracing::{debug, info, trace};
10
11use grapsus_common::budget::{BudgetAlert, BudgetCheckResult, CostResult};
12use grapsus_config::{InferenceConfig, TokenEstimation};
13
14use super::budget::TokenBudgetTracker;
15use super::cost::CostCalculator;
16use super::providers::create_provider;
17use super::rate_limit::{TokenRateLimitResult, TokenRateLimiter};
18use super::tokens::{TokenCounter, TokenEstimate, TokenSource};
19
20/// Per-route inference state with rate limiter, budget, and cost tracking.
21struct RouteInferenceState {
22    /// Token rate limiter (per-minute)
23    rate_limiter: Option<TokenRateLimiter>,
24    /// Token budget tracker (per-period cumulative)
25    budget_tracker: Option<TokenBudgetTracker>,
26    /// Cost calculator
27    cost_calculator: Option<CostCalculator>,
28    /// Token counter (for estimation and actual counting)
29    token_counter: TokenCounter,
30    /// Route ID for logging
31    route_id: String,
32}
33
34/// Manager for inference rate limiting, budgets, and cost tracking.
35///
36/// Each route with inference configuration gets its own TokenRateLimiter,
37/// TokenBudgetTracker, and CostCalculator based on the route's configuration.
38pub struct InferenceRateLimitManager {
39    /// Per-route inference state (keyed by route ID)
40    routes: DashMap<String, Arc<RouteInferenceState>>,
41}
42
43impl InferenceRateLimitManager {
44    /// Create a new inference rate limit manager
45    pub fn new() -> Self {
46        Self {
47            routes: DashMap::new(),
48        }
49    }
50
51    /// Register a route with inference configuration.
52    ///
53    /// Creates a TokenRateLimiter, TokenBudgetTracker, and CostCalculator
54    /// as configured for the route.
55    pub fn register_route(&self, route_id: &str, config: &InferenceConfig) {
56        let provider = create_provider(&config.provider);
57
58        // Determine estimation method (from rate_limit or default)
59        let estimation_method = config
60            .rate_limit
61            .as_ref()
62            .map(|rl| rl.estimation_method)
63            .unwrap_or(TokenEstimation::Chars);
64
65        let token_counter = TokenCounter::new(provider, estimation_method);
66
67        // Create rate limiter if configured
68        let rate_limiter = config.rate_limit.as_ref().map(|rl| {
69            info!(
70                route_id = route_id,
71                tokens_per_minute = rl.tokens_per_minute,
72                requests_per_minute = ?rl.requests_per_minute,
73                burst_tokens = rl.burst_tokens,
74                "Registered inference rate limiter"
75            );
76            TokenRateLimiter::new(rl.clone())
77        });
78
79        // Create budget tracker if configured
80        let budget_tracker = config.budget.as_ref().map(|budget| {
81            info!(
82                route_id = route_id,
83                period = ?budget.period,
84                limit = budget.limit,
85                enforce = budget.enforce,
86                "Registered token budget tracker"
87            );
88            TokenBudgetTracker::new(budget.clone(), route_id)
89        });
90
91        // Create cost calculator if configured
92        let cost_calculator = config.cost_attribution.as_ref().map(|cost| {
93            info!(
94                route_id = route_id,
95                enabled = cost.enabled,
96                pricing_rules = cost.pricing.len(),
97                "Registered cost calculator"
98            );
99            CostCalculator::new(cost.clone(), route_id)
100        });
101
102        // Only register if at least one feature is enabled
103        if rate_limiter.is_some() || budget_tracker.is_some() || cost_calculator.is_some() {
104            let state = RouteInferenceState {
105                rate_limiter,
106                budget_tracker,
107                cost_calculator,
108                token_counter,
109                route_id: route_id.to_string(),
110            };
111
112            self.routes.insert(route_id.to_string(), Arc::new(state));
113
114            info!(
115                route_id = route_id,
116                provider = ?config.provider,
117                has_rate_limit = config.rate_limit.is_some(),
118                has_budget = config.budget.is_some(),
119                has_cost = config.cost_attribution.is_some(),
120                "Registered inference route"
121            );
122        }
123    }
124
125    /// Check if a route has inference configuration registered.
126    pub fn has_route(&self, route_id: &str) -> bool {
127        self.routes.contains_key(route_id)
128    }
129
130    /// Check if a route has budget tracking enabled.
131    pub fn has_budget(&self, route_id: &str) -> bool {
132        self.routes
133            .get(route_id)
134            .map(|s| s.budget_tracker.is_some())
135            .unwrap_or(false)
136    }
137
138    /// Check if a route has cost attribution enabled.
139    pub fn has_cost_attribution(&self, route_id: &str) -> bool {
140        self.routes
141            .get(route_id)
142            .map(|s| {
143                s.cost_calculator
144                    .as_ref()
145                    .map(|c| c.is_enabled())
146                    .unwrap_or(false)
147            })
148            .unwrap_or(false)
149    }
150
151    /// Check rate limit for a request.
152    ///
153    /// Returns the rate limit result and the estimated token count.
154    pub fn check(
155        &self,
156        route_id: &str,
157        key: &str,
158        headers: &HeaderMap,
159        body: &[u8],
160    ) -> Option<InferenceCheckResult> {
161        let state = self.routes.get(route_id)?;
162
163        // Estimate tokens for the request
164        let estimate = state.token_counter.estimate_request(headers, body);
165
166        trace!(
167            route_id = route_id,
168            key = key,
169            estimated_tokens = estimate.tokens,
170            model = ?estimate.model,
171            "Checking inference rate limit"
172        );
173
174        // Check rate limit if configured
175        let rate_limit_result = if let Some(ref rate_limiter) = state.rate_limiter {
176            rate_limiter.check(key, estimate.tokens)
177        } else {
178            TokenRateLimitResult::Allowed
179        };
180
181        Some(InferenceCheckResult {
182            result: rate_limit_result,
183            estimated_tokens: estimate.tokens,
184            model: estimate.model,
185        })
186    }
187
188    /// Check budget for a request.
189    ///
190    /// Returns the budget check result, or None if no budget is configured.
191    pub fn check_budget(
192        &self,
193        route_id: &str,
194        tenant: &str,
195        estimated_tokens: u64,
196    ) -> Option<BudgetCheckResult> {
197        let state = self.routes.get(route_id)?;
198        let budget_tracker = state.budget_tracker.as_ref()?;
199
200        Some(budget_tracker.check(tenant, estimated_tokens))
201    }
202
203    /// Record budget usage after a request completes.
204    ///
205    /// Returns any budget alerts that were triggered.
206    pub fn record_budget(
207        &self,
208        route_id: &str,
209        tenant: &str,
210        actual_tokens: u64,
211    ) -> Vec<BudgetAlert> {
212        if let Some(state) = self.routes.get(route_id) {
213            if let Some(ref budget_tracker) = state.budget_tracker {
214                return budget_tracker.record(tenant, actual_tokens);
215            }
216        }
217        Vec::new()
218    }
219
220    /// Get budget status for a tenant.
221    pub fn budget_status(
222        &self,
223        route_id: &str,
224        tenant: &str,
225    ) -> Option<grapsus_common::budget::TenantBudgetStatus> {
226        let state = self.routes.get(route_id)?;
227        let budget_tracker = state.budget_tracker.as_ref()?;
228        Some(budget_tracker.status(tenant))
229    }
230
231    /// Calculate cost for a request.
232    ///
233    /// Returns the cost result, or None if cost attribution is not configured.
234    pub fn calculate_cost(
235        &self,
236        route_id: &str,
237        model: &str,
238        input_tokens: u64,
239        output_tokens: u64,
240    ) -> Option<CostResult> {
241        let state = self.routes.get(route_id)?;
242        let cost_calculator = state.cost_calculator.as_ref()?;
243
244        if !cost_calculator.is_enabled() {
245            return None;
246        }
247
248        Some(cost_calculator.calculate(model, input_tokens, output_tokens))
249    }
250
251    /// Record actual token usage from response.
252    ///
253    /// This adjusts the rate limiter based on actual vs estimated usage.
254    pub fn record_actual(
255        &self,
256        route_id: &str,
257        key: &str,
258        headers: &HeaderMap,
259        body: &[u8],
260        estimated_tokens: u64,
261    ) -> Option<TokenEstimate> {
262        let state = self.routes.get(route_id)?;
263
264        // Get actual token count from response
265        let actual = state.token_counter.tokens_from_response(headers, body);
266
267        // Only record if we got actual tokens
268        if actual.tokens > 0 && actual.source != TokenSource::Estimated {
269            // Update rate limiter if configured
270            if let Some(ref rate_limiter) = state.rate_limiter {
271                rate_limiter.record_actual(key, actual.tokens, estimated_tokens);
272            }
273
274            debug!(
275                route_id = route_id,
276                key = key,
277                actual_tokens = actual.tokens,
278                estimated_tokens = estimated_tokens,
279                source = ?actual.source,
280                "Recorded actual token usage"
281            );
282        }
283
284        Some(actual)
285    }
286
287    /// Get the number of registered routes
288    pub fn route_count(&self) -> usize {
289        self.routes.len()
290    }
291
292    /// Get stats for a route.
293    pub fn route_stats(&self, route_id: &str) -> Option<InferenceRouteStats> {
294        let state = self.routes.get(route_id)?;
295
296        // Get rate limit stats if available
297        let (active_keys, tokens_per_minute, requests_per_minute) =
298            if let Some(ref rate_limiter) = state.rate_limiter {
299                let stats = rate_limiter.stats();
300                (
301                    stats.active_keys,
302                    stats.tokens_per_minute,
303                    stats.requests_per_minute,
304                )
305            } else {
306                (0, 0, None)
307            };
308
309        Some(InferenceRouteStats {
310            route_id: route_id.to_string(),
311            active_keys,
312            tokens_per_minute,
313            requests_per_minute,
314            has_budget: state.budget_tracker.is_some(),
315            has_cost_attribution: state
316                .cost_calculator
317                .as_ref()
318                .map(|c| c.is_enabled())
319                .unwrap_or(false),
320        })
321    }
322
323    /// Clean up idle rate limiters (called periodically)
324    pub fn cleanup(&self) {
325        // Currently, cleanup is handled internally by the rate limiters
326        // This is a hook for future cleanup logic
327        trace!("Inference rate limit cleanup");
328    }
329}
330
331impl Default for InferenceRateLimitManager {
332    fn default() -> Self {
333        Self::new()
334    }
335}
336
337/// Result of an inference rate limit check
338#[derive(Debug)]
339pub struct InferenceCheckResult {
340    /// Rate limit decision
341    pub result: TokenRateLimitResult,
342    /// Estimated tokens for this request
343    pub estimated_tokens: u64,
344    /// Model name if detected
345    pub model: Option<String>,
346}
347
348impl InferenceCheckResult {
349    /// Returns true if the request is allowed
350    pub fn is_allowed(&self) -> bool {
351        self.result.is_allowed()
352    }
353
354    /// Get retry-after value in milliseconds (0 if allowed)
355    pub fn retry_after_ms(&self) -> u64 {
356        self.result.retry_after_ms()
357    }
358}
359
/// Snapshot of a single route's inference configuration and limiter state.
#[derive(Clone, Debug)]
pub struct InferenceRouteStats {
    /// Identifier of the route these stats describe.
    pub route_id: String,
    /// How many rate-limit keys are currently active.
    pub active_keys: usize,
    /// Configured token-per-minute limit; 0 when the route has no rate limiter.
    pub tokens_per_minute: u64,
    /// Configured request-per-minute limit, if one is set.
    pub requests_per_minute: Option<u64>,
    /// True when a token budget is tracked for this route.
    pub has_budget: bool,
    /// True when cost attribution is configured and enabled.
    pub has_cost_attribution: bool,
}
376
#[cfg(test)]
mod tests {
    use super::*;
    use grapsus_config::{InferenceProvider, TokenRateLimit};

    /// OpenAI config with every optional inference feature switched off.
    fn empty_config() -> InferenceConfig {
        InferenceConfig {
            provider: InferenceProvider::OpenAi,
            model_header: None,
            rate_limit: None,
            budget: None,
            cost_attribution: None,
            routing: None,
            model_routing: None,
            guardrails: None,
        }
    }

    /// OpenAI config with only a token rate limit (10k TPM, 100 RPM, 2k burst).
    fn test_inference_config() -> InferenceConfig {
        InferenceConfig {
            rate_limit: Some(TokenRateLimit {
                tokens_per_minute: 10000,
                requests_per_minute: Some(100),
                burst_tokens: 2000,
                estimation_method: TokenEstimation::Chars,
            }),
            ..empty_config()
        }
    }

    #[test]
    fn test_register_route() {
        let manager = InferenceRateLimitManager::new();
        manager.register_route("test-route", &test_inference_config());

        assert!(manager.has_route("test-route"));
        assert!(!manager.has_route("other-route"));
    }

    #[test]
    fn test_check_rate_limit() {
        let manager = InferenceRateLimitManager::new();
        manager.register_route("test-route", &test_inference_config());

        let payload = br#"{"messages": [{"content": "Hello world"}]}"#;
        let check = manager
            .check("test-route", "client-1", &HeaderMap::new(), payload)
            .expect("a registered route must produce a check result");

        assert!(check.is_allowed());
        assert!(check.estimated_tokens > 0);
    }

    #[test]
    fn test_no_rate_limit_config() {
        // A config with no inference features must not be registered.
        let manager = InferenceRateLimitManager::new();
        manager.register_route("no-limit-route", &empty_config());

        assert!(!manager.has_route("no-limit-route"));
    }

    #[test]
    fn test_budget_only_config() {
        use grapsus_common::budget::{BudgetPeriod, TokenBudgetConfig};

        let config = InferenceConfig {
            budget: Some(TokenBudgetConfig {
                period: BudgetPeriod::Daily,
                limit: 100000,
                alert_thresholds: vec![0.80, 0.90],
                enforce: true,
                rollover: false,
                burst_allowance: None,
            }),
            ..empty_config()
        };

        let manager = InferenceRateLimitManager::new();
        manager.register_route("budget-route", &config);

        assert!(manager.has_route("budget-route"));
        assert!(manager.has_budget("budget-route"));
        assert!(!manager.has_cost_attribution("budget-route"));
    }
}