1use dashmap::DashMap;
7use http::HeaderMap;
8use std::sync::Arc;
9use tracing::{debug, info, trace};
10
11use grapsus_common::budget::{BudgetAlert, BudgetCheckResult, CostResult};
12use grapsus_config::{InferenceConfig, TokenEstimation};
13
14use super::budget::TokenBudgetTracker;
15use super::cost::CostCalculator;
16use super::providers::create_provider;
17use super::rate_limit::{TokenRateLimitResult, TokenRateLimiter};
18use super::tokens::{TokenCounter, TokenEstimate, TokenSource};
19
20struct RouteInferenceState {
22 rate_limiter: Option<TokenRateLimiter>,
24 budget_tracker: Option<TokenBudgetTracker>,
26 cost_calculator: Option<CostCalculator>,
28 token_counter: TokenCounter,
30 route_id: String,
32}
33
34pub struct InferenceRateLimitManager {
39 routes: DashMap<String, Arc<RouteInferenceState>>,
41}
42
43impl InferenceRateLimitManager {
44 pub fn new() -> Self {
46 Self {
47 routes: DashMap::new(),
48 }
49 }
50
51 pub fn register_route(&self, route_id: &str, config: &InferenceConfig) {
56 let provider = create_provider(&config.provider);
57
58 let estimation_method = config
60 .rate_limit
61 .as_ref()
62 .map(|rl| rl.estimation_method)
63 .unwrap_or(TokenEstimation::Chars);
64
65 let token_counter = TokenCounter::new(provider, estimation_method);
66
67 let rate_limiter = config.rate_limit.as_ref().map(|rl| {
69 info!(
70 route_id = route_id,
71 tokens_per_minute = rl.tokens_per_minute,
72 requests_per_minute = ?rl.requests_per_minute,
73 burst_tokens = rl.burst_tokens,
74 "Registered inference rate limiter"
75 );
76 TokenRateLimiter::new(rl.clone())
77 });
78
79 let budget_tracker = config.budget.as_ref().map(|budget| {
81 info!(
82 route_id = route_id,
83 period = ?budget.period,
84 limit = budget.limit,
85 enforce = budget.enforce,
86 "Registered token budget tracker"
87 );
88 TokenBudgetTracker::new(budget.clone(), route_id)
89 });
90
91 let cost_calculator = config.cost_attribution.as_ref().map(|cost| {
93 info!(
94 route_id = route_id,
95 enabled = cost.enabled,
96 pricing_rules = cost.pricing.len(),
97 "Registered cost calculator"
98 );
99 CostCalculator::new(cost.clone(), route_id)
100 });
101
102 if rate_limiter.is_some() || budget_tracker.is_some() || cost_calculator.is_some() {
104 let state = RouteInferenceState {
105 rate_limiter,
106 budget_tracker,
107 cost_calculator,
108 token_counter,
109 route_id: route_id.to_string(),
110 };
111
112 self.routes.insert(route_id.to_string(), Arc::new(state));
113
114 info!(
115 route_id = route_id,
116 provider = ?config.provider,
117 has_rate_limit = config.rate_limit.is_some(),
118 has_budget = config.budget.is_some(),
119 has_cost = config.cost_attribution.is_some(),
120 "Registered inference route"
121 );
122 }
123 }
124
125 pub fn has_route(&self, route_id: &str) -> bool {
127 self.routes.contains_key(route_id)
128 }
129
130 pub fn has_budget(&self, route_id: &str) -> bool {
132 self.routes
133 .get(route_id)
134 .map(|s| s.budget_tracker.is_some())
135 .unwrap_or(false)
136 }
137
138 pub fn has_cost_attribution(&self, route_id: &str) -> bool {
140 self.routes
141 .get(route_id)
142 .map(|s| {
143 s.cost_calculator
144 .as_ref()
145 .map(|c| c.is_enabled())
146 .unwrap_or(false)
147 })
148 .unwrap_or(false)
149 }
150
151 pub fn check(
155 &self,
156 route_id: &str,
157 key: &str,
158 headers: &HeaderMap,
159 body: &[u8],
160 ) -> Option<InferenceCheckResult> {
161 let state = self.routes.get(route_id)?;
162
163 let estimate = state.token_counter.estimate_request(headers, body);
165
166 trace!(
167 route_id = route_id,
168 key = key,
169 estimated_tokens = estimate.tokens,
170 model = ?estimate.model,
171 "Checking inference rate limit"
172 );
173
174 let rate_limit_result = if let Some(ref rate_limiter) = state.rate_limiter {
176 rate_limiter.check(key, estimate.tokens)
177 } else {
178 TokenRateLimitResult::Allowed
179 };
180
181 Some(InferenceCheckResult {
182 result: rate_limit_result,
183 estimated_tokens: estimate.tokens,
184 model: estimate.model,
185 })
186 }
187
188 pub fn check_budget(
192 &self,
193 route_id: &str,
194 tenant: &str,
195 estimated_tokens: u64,
196 ) -> Option<BudgetCheckResult> {
197 let state = self.routes.get(route_id)?;
198 let budget_tracker = state.budget_tracker.as_ref()?;
199
200 Some(budget_tracker.check(tenant, estimated_tokens))
201 }
202
203 pub fn record_budget(
207 &self,
208 route_id: &str,
209 tenant: &str,
210 actual_tokens: u64,
211 ) -> Vec<BudgetAlert> {
212 if let Some(state) = self.routes.get(route_id) {
213 if let Some(ref budget_tracker) = state.budget_tracker {
214 return budget_tracker.record(tenant, actual_tokens);
215 }
216 }
217 Vec::new()
218 }
219
220 pub fn budget_status(
222 &self,
223 route_id: &str,
224 tenant: &str,
225 ) -> Option<grapsus_common::budget::TenantBudgetStatus> {
226 let state = self.routes.get(route_id)?;
227 let budget_tracker = state.budget_tracker.as_ref()?;
228 Some(budget_tracker.status(tenant))
229 }
230
231 pub fn calculate_cost(
235 &self,
236 route_id: &str,
237 model: &str,
238 input_tokens: u64,
239 output_tokens: u64,
240 ) -> Option<CostResult> {
241 let state = self.routes.get(route_id)?;
242 let cost_calculator = state.cost_calculator.as_ref()?;
243
244 if !cost_calculator.is_enabled() {
245 return None;
246 }
247
248 Some(cost_calculator.calculate(model, input_tokens, output_tokens))
249 }
250
251 pub fn record_actual(
255 &self,
256 route_id: &str,
257 key: &str,
258 headers: &HeaderMap,
259 body: &[u8],
260 estimated_tokens: u64,
261 ) -> Option<TokenEstimate> {
262 let state = self.routes.get(route_id)?;
263
264 let actual = state.token_counter.tokens_from_response(headers, body);
266
267 if actual.tokens > 0 && actual.source != TokenSource::Estimated {
269 if let Some(ref rate_limiter) = state.rate_limiter {
271 rate_limiter.record_actual(key, actual.tokens, estimated_tokens);
272 }
273
274 debug!(
275 route_id = route_id,
276 key = key,
277 actual_tokens = actual.tokens,
278 estimated_tokens = estimated_tokens,
279 source = ?actual.source,
280 "Recorded actual token usage"
281 );
282 }
283
284 Some(actual)
285 }
286
287 pub fn route_count(&self) -> usize {
289 self.routes.len()
290 }
291
292 pub fn route_stats(&self, route_id: &str) -> Option<InferenceRouteStats> {
294 let state = self.routes.get(route_id)?;
295
296 let (active_keys, tokens_per_minute, requests_per_minute) =
298 if let Some(ref rate_limiter) = state.rate_limiter {
299 let stats = rate_limiter.stats();
300 (
301 stats.active_keys,
302 stats.tokens_per_minute,
303 stats.requests_per_minute,
304 )
305 } else {
306 (0, 0, None)
307 };
308
309 Some(InferenceRouteStats {
310 route_id: route_id.to_string(),
311 active_keys,
312 tokens_per_minute,
313 requests_per_minute,
314 has_budget: state.budget_tracker.is_some(),
315 has_cost_attribution: state
316 .cost_calculator
317 .as_ref()
318 .map(|c| c.is_enabled())
319 .unwrap_or(false),
320 })
321 }
322
323 pub fn cleanup(&self) {
325 trace!("Inference rate limit cleanup");
328 }
329}
330
331impl Default for InferenceRateLimitManager {
332 fn default() -> Self {
333 Self::new()
334 }
335}
336
337#[derive(Debug)]
339pub struct InferenceCheckResult {
340 pub result: TokenRateLimitResult,
342 pub estimated_tokens: u64,
344 pub model: Option<String>,
346}
347
348impl InferenceCheckResult {
349 pub fn is_allowed(&self) -> bool {
351 self.result.is_allowed()
352 }
353
354 pub fn retry_after_ms(&self) -> u64 {
356 self.result.retry_after_ms()
357 }
358}
359
/// Per-route snapshot returned by `InferenceRateLimitManager::route_stats`.
#[derive(Debug, Clone)]
pub struct InferenceRouteStats {
    /// Id of the route the snapshot describes.
    pub route_id: String,
    /// Active key count from the rate limiter's stats; `0` without a limiter.
    pub active_keys: usize,
    /// Tokens-per-minute figure from the rate limiter's stats; `0` without a limiter.
    pub tokens_per_minute: u64,
    /// Requests-per-minute figure from the rate limiter's stats, if any.
    pub requests_per_minute: Option<u64>,
    /// Whether the route has a budget tracker.
    pub has_budget: bool,
    /// Whether the route has a cost calculator that is enabled.
    pub has_cost_attribution: bool,
}
376
#[cfg(test)]
mod tests {
    use super::*;
    use grapsus_common::budget::{BudgetPeriod, TokenBudgetConfig};
    use grapsus_config::{InferenceProvider, TokenRateLimit};

    /// Builds an OpenAI inference config with the given rate-limit and budget
    /// sections; every other optional section is left unset.
    fn config_with(
        rate_limit: Option<TokenRateLimit>,
        budget: Option<TokenBudgetConfig>,
    ) -> InferenceConfig {
        InferenceConfig {
            provider: InferenceProvider::OpenAi,
            model_header: None,
            rate_limit,
            budget,
            cost_attribution: None,
            routing: None,
            model_routing: None,
            guardrails: None,
        }
    }

    /// Config with only a token rate limit configured.
    fn test_inference_config() -> InferenceConfig {
        let rate_limit = TokenRateLimit {
            tokens_per_minute: 10000,
            requests_per_minute: Some(100),
            burst_tokens: 2000,
            estimation_method: TokenEstimation::Chars,
        };
        config_with(Some(rate_limit), None)
    }

    /// A route with a rate limit is registered; unrelated ids are not.
    #[test]
    fn test_register_route() {
        let manager = InferenceRateLimitManager::new();
        manager.register_route("test-route", &test_inference_config());

        assert!(manager.has_route("test-route"));
        assert!(!manager.has_route("other-route"));
    }

    /// A small request on a fresh limiter is allowed and has a non-zero estimate.
    #[test]
    fn test_check_rate_limit() {
        let manager = InferenceRateLimitManager::new();
        manager.register_route("test-route", &test_inference_config());

        let headers = HeaderMap::new();
        let body = br#"{"messages": [{"content": "Hello world"}]}"#;

        let result = manager.check("test-route", "client-1", &headers, body);
        assert!(result.is_some());

        let check = result.unwrap();
        assert!(check.is_allowed());
        assert!(check.estimated_tokens > 0);
    }

    /// A config with nothing to enforce does not register the route.
    #[test]
    fn test_no_rate_limit_config() {
        let manager = InferenceRateLimitManager::new();

        let config = config_with(None, None);
        manager.register_route("no-limit-route", &config);

        assert!(!manager.has_route("no-limit-route"));
    }

    /// A budget alone is enough to register the route.
    #[test]
    fn test_budget_only_config() {
        let manager = InferenceRateLimitManager::new();

        let budget = TokenBudgetConfig {
            period: BudgetPeriod::Daily,
            limit: 100000,
            alert_thresholds: vec![0.80, 0.90],
            enforce: true,
            rollover: false,
            burst_allowance: None,
        };
        let config = config_with(None, Some(budget));
        manager.register_route("budget-route", &config);

        assert!(manager.has_route("budget-route"));
        assert!(manager.has_budget("budget-route"));
        assert!(!manager.has_cost_attribution("budget-route"));
    }
}