sentinel_proxy/inference/
manager.rs1use dashmap::DashMap;
7use http::HeaderMap;
8use std::sync::Arc;
9use tracing::{debug, info, trace};
10
11use sentinel_common::budget::{BudgetAlert, BudgetCheckResult, CostResult};
12use sentinel_config::{InferenceConfig, TokenEstimation};
13
14use super::budget::TokenBudgetTracker;
15use super::cost::CostCalculator;
16use super::providers::create_provider;
17use super::rate_limit::{TokenRateLimitResult, TokenRateLimiter};
18use super::tokens::{TokenCounter, TokenEstimate, TokenSource};
19
20struct RouteInferenceState {
22 rate_limiter: Option<TokenRateLimiter>,
24 budget_tracker: Option<TokenBudgetTracker>,
26 cost_calculator: Option<CostCalculator>,
28 token_counter: TokenCounter,
30 route_id: String,
32}
33
34pub struct InferenceRateLimitManager {
39 routes: DashMap<String, Arc<RouteInferenceState>>,
41}
42
43impl InferenceRateLimitManager {
44 pub fn new() -> Self {
46 Self {
47 routes: DashMap::new(),
48 }
49 }
50
51 pub fn register_route(&self, route_id: &str, config: &InferenceConfig) {
56 let provider = create_provider(&config.provider);
57
58 let estimation_method = config
60 .rate_limit
61 .as_ref()
62 .map(|rl| rl.estimation_method)
63 .unwrap_or(TokenEstimation::Chars);
64
65 let token_counter = TokenCounter::new(provider, estimation_method);
66
67 let rate_limiter = config.rate_limit.as_ref().map(|rl| {
69 info!(
70 route_id = route_id,
71 tokens_per_minute = rl.tokens_per_minute,
72 requests_per_minute = ?rl.requests_per_minute,
73 burst_tokens = rl.burst_tokens,
74 "Registered inference rate limiter"
75 );
76 TokenRateLimiter::new(rl.clone())
77 });
78
79 let budget_tracker = config.budget.as_ref().map(|budget| {
81 info!(
82 route_id = route_id,
83 period = ?budget.period,
84 limit = budget.limit,
85 enforce = budget.enforce,
86 "Registered token budget tracker"
87 );
88 TokenBudgetTracker::new(budget.clone(), route_id)
89 });
90
91 let cost_calculator = config.cost_attribution.as_ref().map(|cost| {
93 info!(
94 route_id = route_id,
95 enabled = cost.enabled,
96 pricing_rules = cost.pricing.len(),
97 "Registered cost calculator"
98 );
99 CostCalculator::new(cost.clone(), route_id)
100 });
101
102 if rate_limiter.is_some() || budget_tracker.is_some() || cost_calculator.is_some() {
104 let state = RouteInferenceState {
105 rate_limiter,
106 budget_tracker,
107 cost_calculator,
108 token_counter,
109 route_id: route_id.to_string(),
110 };
111
112 self.routes.insert(route_id.to_string(), Arc::new(state));
113
114 info!(
115 route_id = route_id,
116 provider = ?config.provider,
117 has_rate_limit = config.rate_limit.is_some(),
118 has_budget = config.budget.is_some(),
119 has_cost = config.cost_attribution.is_some(),
120 "Registered inference route"
121 );
122 }
123 }
124
125 pub fn has_route(&self, route_id: &str) -> bool {
127 self.routes.contains_key(route_id)
128 }
129
130 pub fn has_budget(&self, route_id: &str) -> bool {
132 self.routes
133 .get(route_id)
134 .map(|s| s.budget_tracker.is_some())
135 .unwrap_or(false)
136 }
137
138 pub fn has_cost_attribution(&self, route_id: &str) -> bool {
140 self.routes
141 .get(route_id)
142 .map(|s| s.cost_calculator.as_ref().map(|c| c.is_enabled()).unwrap_or(false))
143 .unwrap_or(false)
144 }
145
146 pub fn check(
150 &self,
151 route_id: &str,
152 key: &str,
153 headers: &HeaderMap,
154 body: &[u8],
155 ) -> Option<InferenceCheckResult> {
156 let state = self.routes.get(route_id)?;
157
158 let estimate = state.token_counter.estimate_request(headers, body);
160
161 trace!(
162 route_id = route_id,
163 key = key,
164 estimated_tokens = estimate.tokens,
165 model = ?estimate.model,
166 "Checking inference rate limit"
167 );
168
169 let rate_limit_result = if let Some(ref rate_limiter) = state.rate_limiter {
171 rate_limiter.check(key, estimate.tokens)
172 } else {
173 TokenRateLimitResult::Allowed
174 };
175
176 Some(InferenceCheckResult {
177 result: rate_limit_result,
178 estimated_tokens: estimate.tokens,
179 model: estimate.model,
180 })
181 }
182
183 pub fn check_budget(
187 &self,
188 route_id: &str,
189 tenant: &str,
190 estimated_tokens: u64,
191 ) -> Option<BudgetCheckResult> {
192 let state = self.routes.get(route_id)?;
193 let budget_tracker = state.budget_tracker.as_ref()?;
194
195 Some(budget_tracker.check(tenant, estimated_tokens))
196 }
197
198 pub fn record_budget(
202 &self,
203 route_id: &str,
204 tenant: &str,
205 actual_tokens: u64,
206 ) -> Vec<BudgetAlert> {
207 if let Some(state) = self.routes.get(route_id) {
208 if let Some(ref budget_tracker) = state.budget_tracker {
209 return budget_tracker.record(tenant, actual_tokens);
210 }
211 }
212 Vec::new()
213 }
214
215 pub fn budget_status(
217 &self,
218 route_id: &str,
219 tenant: &str,
220 ) -> Option<sentinel_common::budget::TenantBudgetStatus> {
221 let state = self.routes.get(route_id)?;
222 let budget_tracker = state.budget_tracker.as_ref()?;
223 Some(budget_tracker.status(tenant))
224 }
225
226 pub fn calculate_cost(
230 &self,
231 route_id: &str,
232 model: &str,
233 input_tokens: u64,
234 output_tokens: u64,
235 ) -> Option<CostResult> {
236 let state = self.routes.get(route_id)?;
237 let cost_calculator = state.cost_calculator.as_ref()?;
238
239 if !cost_calculator.is_enabled() {
240 return None;
241 }
242
243 Some(cost_calculator.calculate(model, input_tokens, output_tokens))
244 }
245
246 pub fn record_actual(
250 &self,
251 route_id: &str,
252 key: &str,
253 headers: &HeaderMap,
254 body: &[u8],
255 estimated_tokens: u64,
256 ) -> Option<TokenEstimate> {
257 let state = self.routes.get(route_id)?;
258
259 let actual = state.token_counter.tokens_from_response(headers, body);
261
262 if actual.tokens > 0 && actual.source != TokenSource::Estimated {
264 if let Some(ref rate_limiter) = state.rate_limiter {
266 rate_limiter.record_actual(key, actual.tokens, estimated_tokens);
267 }
268
269 debug!(
270 route_id = route_id,
271 key = key,
272 actual_tokens = actual.tokens,
273 estimated_tokens = estimated_tokens,
274 source = ?actual.source,
275 "Recorded actual token usage"
276 );
277 }
278
279 Some(actual)
280 }
281
282 pub fn route_count(&self) -> usize {
284 self.routes.len()
285 }
286
287 pub fn route_stats(&self, route_id: &str) -> Option<InferenceRouteStats> {
289 let state = self.routes.get(route_id)?;
290
291 let (active_keys, tokens_per_minute, requests_per_minute) =
293 if let Some(ref rate_limiter) = state.rate_limiter {
294 let stats = rate_limiter.stats();
295 (stats.active_keys, stats.tokens_per_minute, stats.requests_per_minute)
296 } else {
297 (0, 0, None)
298 };
299
300 Some(InferenceRouteStats {
301 route_id: route_id.to_string(),
302 active_keys,
303 tokens_per_minute,
304 requests_per_minute,
305 has_budget: state.budget_tracker.is_some(),
306 has_cost_attribution: state.cost_calculator.as_ref().map(|c| c.is_enabled()).unwrap_or(false),
307 })
308 }
309
310 pub fn cleanup(&self) {
312 trace!("Inference rate limit cleanup");
315 }
316}
317
318impl Default for InferenceRateLimitManager {
319 fn default() -> Self {
320 Self::new()
321 }
322}
323
324#[derive(Debug)]
326pub struct InferenceCheckResult {
327 pub result: TokenRateLimitResult,
329 pub estimated_tokens: u64,
331 pub model: Option<String>,
333}
334
335impl InferenceCheckResult {
336 pub fn is_allowed(&self) -> bool {
338 self.result.is_allowed()
339 }
340
341 pub fn retry_after_ms(&self) -> u64 {
343 self.result.retry_after_ms()
344 }
345}
346
/// Point-in-time statistics for one registered inference route, as produced by
/// `InferenceRateLimitManager::route_stats`.
///
/// Implements `PartialEq`/`Eq` so snapshots can be compared as a whole
/// (e.g. in tests or change detection) rather than field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InferenceRouteStats {
    /// Route the statistics describe.
    pub route_id: String,
    /// Active key count reported by the route's rate limiter; `0` without one.
    pub active_keys: usize,
    /// Tokens-per-minute figure reported by the route's rate limiter; `0` without one.
    pub tokens_per_minute: u64,
    /// Requests-per-minute figure reported by the rate limiter; `None` without one.
    pub requests_per_minute: Option<u64>,
    /// Whether the route has a token budget tracker.
    pub has_budget: bool,
    /// Whether the route has a cost calculator that is enabled.
    pub has_cost_attribution: bool,
}
363
#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_common::budget::{BudgetPeriod, TokenBudgetConfig};
    use sentinel_config::{InferenceProvider, TokenRateLimit};

    /// OpenAI inference config with only the given rate limit and budget set.
    fn config_with(
        rate_limit: Option<TokenRateLimit>,
        budget: Option<TokenBudgetConfig>,
    ) -> InferenceConfig {
        InferenceConfig {
            provider: InferenceProvider::OpenAi,
            model_header: None,
            rate_limit,
            budget,
            cost_attribution: None,
            routing: None,
            model_routing: None,
            guardrails: None,
        }
    }

    /// Config with a token rate limit and no budget or cost attribution.
    fn test_inference_config() -> InferenceConfig {
        config_with(
            Some(TokenRateLimit {
                tokens_per_minute: 10000,
                requests_per_minute: Some(100),
                burst_tokens: 2000,
                estimation_method: TokenEstimation::Chars,
            }),
            None,
        )
    }

    /// Config with an enforced daily token budget and no rate limit.
    fn budget_only_config() -> InferenceConfig {
        config_with(
            None,
            Some(TokenBudgetConfig {
                period: BudgetPeriod::Daily,
                limit: 100000,
                alert_thresholds: vec![0.80, 0.90],
                enforce: true,
                rollover: false,
                burst_allowance: None,
            }),
        )
    }

    #[test]
    fn test_register_route() {
        let manager = InferenceRateLimitManager::new();
        manager.register_route("test-route", &test_inference_config());

        assert!(manager.has_route("test-route"));
        assert!(!manager.has_route("other-route"));
        assert_eq!(manager.route_count(), 1);
    }

    #[test]
    fn test_check_rate_limit() {
        let manager = InferenceRateLimitManager::new();
        manager.register_route("test-route", &test_inference_config());

        let headers = HeaderMap::new();
        let body = br#"{"messages": [{"content": "Hello world"}]}"#;

        let result = manager.check("test-route", "client-1", &headers, body);
        assert!(result.is_some());

        let check = result.unwrap();
        assert!(check.is_allowed());
        assert!(check.estimated_tokens > 0);
    }

    #[test]
    fn test_no_rate_limit_config() {
        let manager = InferenceRateLimitManager::new();
        manager.register_route("no-limit-route", &config_with(None, None));

        assert!(!manager.has_route("no-limit-route"));
        assert_eq!(manager.route_count(), 0);
    }

    #[test]
    fn test_budget_only_config() {
        let manager = InferenceRateLimitManager::new();
        manager.register_route("budget-route", &budget_only_config());

        assert!(manager.has_route("budget-route"));
        assert!(manager.has_budget("budget-route"));
        assert!(!manager.has_cost_attribution("budget-route"));
        assert!(manager.check_budget("budget-route", "tenant-a", 100).is_some());
        assert!(manager.budget_status("budget-route", "tenant-a").is_some());
    }

    #[test]
    fn test_unknown_route_returns_nothing() {
        let manager = InferenceRateLimitManager::new();
        let headers = HeaderMap::new();

        assert!(manager.check("missing", "client-1", &headers, b"{}").is_none());
        assert!(manager.check_budget("missing", "tenant-a", 10).is_none());
        assert!(manager.budget_status("missing", "tenant-a").is_none());
        assert!(manager.calculate_cost("missing", "gpt-4", 10, 10).is_none());
        assert!(manager.record_budget("missing", "tenant-a", 10).is_empty());
        assert!(manager
            .record_actual("missing", "client-1", &headers, b"{}", 10)
            .is_none());
        assert!(manager.route_stats("missing").is_none());
        assert_eq!(manager.route_count(), 0);
    }

    #[test]
    fn test_rate_limit_only_route_has_no_budget_or_cost() {
        let manager = InferenceRateLimitManager::new();
        manager.register_route("test-route", &test_inference_config());

        assert!(!manager.has_budget("test-route"));
        assert!(!manager.has_cost_attribution("test-route"));
        assert!(manager.check_budget("test-route", "tenant-a", 10).is_none());
        assert!(manager.budget_status("test-route", "tenant-a").is_none());
        assert!(manager.calculate_cost("test-route", "gpt-4", 10, 10).is_none());
        assert!(manager.record_budget("test-route", "tenant-a", 10).is_empty());

        let stats = manager
            .route_stats("test-route")
            .expect("route was just registered");
        assert_eq!(stats.route_id, "test-route");
        assert!(!stats.has_budget);
        assert!(!stats.has_cost_attribution);
    }

    #[test]
    fn test_reregister_replaces_route_state() {
        let manager = InferenceRateLimitManager::new();

        manager.register_route("route", &test_inference_config());
        assert!(!manager.has_budget("route"));

        manager.register_route("route", &budget_only_config());
        assert!(manager.has_budget("route"));
        assert_eq!(manager.route_count(), 1);
    }
}