//! Per-route inference governance for proxied LLM traffic: token-aware
//! rate limiting, per-tenant budgets, and cost attribution.

use dashmap::DashMap;
use http::HeaderMap;
use std::sync::Arc;
use tracing::{debug, info, trace};
use zentinel_common::budget::{BudgetAlert, BudgetCheckResult, CostResult};
use zentinel_config::{InferenceConfig, TokenEstimation};
use super::budget::TokenBudgetTracker;
use super::cost::CostCalculator;
use super::providers::create_provider;
use super::rate_limit::{TokenRateLimitResult, TokenRateLimiter};
use super::tokens::{TokenCounter, TokenEstimate, TokenSource};
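
/// Per-route inference state: the token counter plus whichever optional
/// components (rate limiter, budget tracker, cost calculator) the route
/// configured.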
struct RouteInferenceState {
rate_limiter: Option<TokenRateLimiter>,
budget_tracker: Option<TokenBudgetTracker>,
cost_calculator: Option<CostCalculator>,
token_counter: TokenCounter,
route_id: String,
}
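
/// Manages inference governance state for registered routes: token-aware
/// rate limiting, per-tenant budgets, and cost attribution, keyed by route
/// ID. All methods take `&self`; the backing `DashMap` makes concurrent use
/// safe.
///
/// A sketch of the typical proxy-side flow (assuming `config`, `headers`,
/// and `body` come from the surrounding request handling; not a doctest):
///
/// ```ignore
/// let manager = InferenceRateLimitManager::new();
/// manager.register_route("chat", &config);
/// if let Some(check) = manager.check("chat", "tenant-a", &headers, &body) {
///     if !check.is_allowed() {
///         // Reject; a retry hint is available via check.retry_after_ms().
///     }
/// }
/// ```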
pub struct InferenceRateLimitManager {
routes: DashMap<String, Arc<RouteInferenceState>>,
}
impl InferenceRateLimitManager {
pub fn new() -> Self {
Self {
routes: DashMap::new(),
}
}
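
    /// Registers a route's inference configuration, constructing a token
    /// counter plus any configured rate limiter, budget tracker, and cost
    /// calculator. A route with none of the three is not tracked.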
pub fn register_route(&self, route_id: &str, config: &InferenceConfig) {
let provider = create_provider(&config.provider);
let estimation_method = config
.rate_limit
.as_ref()
.map(|rl| rl.estimation_method)
.unwrap_or(TokenEstimation::Chars);
let token_counter = TokenCounter::new(provider, estimation_method);
let rate_limiter = config.rate_limit.as_ref().map(|rl| {
info!(
route_id = route_id,
tokens_per_minute = rl.tokens_per_minute,
requests_per_minute = ?rl.requests_per_minute,
burst_tokens = rl.burst_tokens,
"Registered inference rate limiter"
);
TokenRateLimiter::new(rl.clone())
});
let budget_tracker = config.budget.as_ref().map(|budget| {
info!(
route_id = route_id,
period = ?budget.period,
limit = budget.limit,
enforce = budget.enforce,
"Registered token budget tracker"
);
TokenBudgetTracker::new(budget.clone(), route_id)
});
let cost_calculator = config.cost_attribution.as_ref().map(|cost| {
info!(
route_id = route_id,
enabled = cost.enabled,
pricing_rules = cost.pricing.len(),
"Registered cost calculator"
);
CostCalculator::new(cost.clone(), route_id)
});
if rate_limiter.is_some() || budget_tracker.is_some() || cost_calculator.is_some() {
let state = RouteInferenceState {
rate_limiter,
budget_tracker,
cost_calculator,
token_counter,
route_id: route_id.to_string(),
};
self.routes.insert(route_id.to_string(), Arc::new(state));
info!(
route_id = route_id,
provider = ?config.provider,
has_rate_limit = config.rate_limit.is_some(),
has_budget = config.budget.is_some(),
has_cost = config.cost_attribution.is_some(),
"Registered inference route"
);
}
}
pub fn has_route(&self, route_id: &str) -> bool {
self.routes.contains_key(route_id)
}
pub fn has_budget(&self, route_id: &str) -> bool {
self.routes
.get(route_id)
.map(|s| s.budget_tracker.is_some())
.unwrap_or(false)
}
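
    /// Returns `true` only when the route has cost attribution configured
    /// *and* enabled.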
pub fn has_cost_attribution(&self, route_id: &str) -> bool {
        self.routes
            .get(route_id)
            .is_some_and(|s| s.cost_calculator.as_ref().is_some_and(|c| c.is_enabled()))
}
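
    /// Estimates the token cost of an inbound request and checks it against
    /// the route's rate limiter. Returns `None` for unregistered routes;
    /// routes without a rate limiter always allow.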
pub fn check(
&self,
route_id: &str,
key: &str,
headers: &HeaderMap,
body: &[u8],
) -> Option<InferenceCheckResult> {
let state = self.routes.get(route_id)?;
let estimate = state.token_counter.estimate_request(headers, body);
trace!(
route_id = route_id,
key = key,
estimated_tokens = estimate.tokens,
model = ?estimate.model,
"Checking inference rate limit"
);
let rate_limit_result = if let Some(ref rate_limiter) = state.rate_limiter {
rate_limiter.check(key, estimate.tokens)
} else {
TokenRateLimitResult::Allowed
};
Some(InferenceCheckResult {
result: rate_limit_result,
estimated_tokens: estimate.tokens,
model: estimate.model,
})
}
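
    /// Checks whether `tenant` has budget headroom for `estimated_tokens`.
    /// Returns `None` if the route is unregistered or has no budget tracker.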
pub fn check_budget(
&self,
route_id: &str,
tenant: &str,
estimated_tokens: u64,
) -> Option<BudgetCheckResult> {
let state = self.routes.get(route_id)?;
let budget_tracker = state.budget_tracker.as_ref()?;
Some(budget_tracker.check(tenant, estimated_tokens))
}
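
    /// Records actual token consumption against `tenant`'s budget and returns
    /// any alerts that fired. Returns an empty list when the route has no
    /// budget tracker.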
pub fn record_budget(
&self,
route_id: &str,
tenant: &str,
actual_tokens: u64,
) -> Vec<BudgetAlert> {
if let Some(state) = self.routes.get(route_id) {
if let Some(ref budget_tracker) = state.budget_tracker {
return budget_tracker.record(tenant, actual_tokens);
}
}
Vec::new()
}
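
    /// Returns the tenant's current budget status, if the route tracks
    /// budgets.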
pub fn budget_status(
&self,
route_id: &str,
tenant: &str,
) -> Option<zentinel_common::budget::TenantBudgetStatus> {
let state = self.routes.get(route_id)?;
let budget_tracker = state.budget_tracker.as_ref()?;
Some(budget_tracker.status(tenant))
}
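
    /// Calculates the attributed cost of a completed request. Returns `None`
    /// unless the route has cost attribution configured and enabled.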
pub fn calculate_cost(
&self,
route_id: &str,
model: &str,
input_tokens: u64,
output_tokens: u64,
) -> Option<CostResult> {
let state = self.routes.get(route_id)?;
let cost_calculator = state.cost_calculator.as_ref()?;
if !cost_calculator.is_enabled() {
return None;
}
Some(cost_calculator.calculate(model, input_tokens, output_tokens))
}
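
    /// Parses actual token usage from a response. When the count comes from a
    /// non-estimated source, reconciles the rate limiter against the original
    /// estimate.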
pub fn record_actual(
&self,
route_id: &str,
key: &str,
headers: &HeaderMap,
body: &[u8],
estimated_tokens: u64,
) -> Option<TokenEstimate> {
let state = self.routes.get(route_id)?;
let actual = state.token_counter.tokens_from_response(headers, body);
if actual.tokens > 0 && actual.source != TokenSource::Estimated {
if let Some(ref rate_limiter) = state.rate_limiter {
rate_limiter.record_actual(key, actual.tokens, estimated_tokens);
}
debug!(
route_id = route_id,
key = key,
actual_tokens = actual.tokens,
estimated_tokens = estimated_tokens,
source = ?actual.source,
"Recorded actual token usage"
);
}
Some(actual)
}
pub fn route_count(&self) -> usize {
self.routes.len()
}
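
    /// Returns a point-in-time snapshot of the route's rate-limit and
    /// governance state.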
pub fn route_stats(&self, route_id: &str) -> Option<InferenceRouteStats> {
let state = self.routes.get(route_id)?;
let (active_keys, tokens_per_minute, requests_per_minute) =
if let Some(ref rate_limiter) = state.rate_limiter {
let stats = rate_limiter.stats();
(
stats.active_keys,
stats.tokens_per_minute,
stats.requests_per_minute,
)
} else {
(0, 0, None)
};
Some(InferenceRouteStats {
route_id: route_id.to_string(),
active_keys,
tokens_per_minute,
requests_per_minute,
has_budget: state.budget_tracker.is_some(),
            has_cost_attribution: state
                .cost_calculator
                .as_ref()
                .is_some_and(|c| c.is_enabled()),
})
}
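
    /// Periodic maintenance hook; currently a no-op beyond emitting a trace
    /// event.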
pub fn cleanup(&self) {
trace!("Inference rate limit cleanup");
}
}
impl Default for InferenceRateLimitManager {
fn default() -> Self {
Self::new()
}
}
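
/// Outcome of a pre-request inference check: the rate-limit decision plus
/// the token estimate and model it was based on.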
#[derive(Debug)]
pub struct InferenceCheckResult {
pub result: TokenRateLimitResult,
pub estimated_tokens: u64,
pub model: Option<String>,
}
impl InferenceCheckResult {
pub fn is_allowed(&self) -> bool {
self.result.is_allowed()
}
pub fn retry_after_ms(&self) -> u64 {
self.result.retry_after_ms()
}
}
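
/// Point-in-time statistics for a registered inference route.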
#[derive(Debug, Clone)]
pub struct InferenceRouteStats {
pub route_id: String,
pub active_keys: usize,
pub tokens_per_minute: u64,
pub requests_per_minute: Option<u64>,
pub has_budget: bool,
pub has_cost_attribution: bool,
}
#[cfg(test)]
mod tests {
use super::*;
use zentinel_config::{InferenceProvider, TokenRateLimit};
fn test_inference_config() -> InferenceConfig {
InferenceConfig {
provider: InferenceProvider::OpenAi,
model_header: None,
rate_limit: Some(TokenRateLimit {
tokens_per_minute: 10000,
requests_per_minute: Some(100),
burst_tokens: 2000,
estimation_method: TokenEstimation::Chars,
}),
budget: None,
cost_attribution: None,
routing: None,
model_routing: None,
guardrails: None,
}
}
#[test]
fn test_register_route() {
let manager = InferenceRateLimitManager::new();
manager.register_route("test-route", &test_inference_config());
assert!(manager.has_route("test-route"));
assert!(!manager.has_route("other-route"));
}
#[test]
fn test_check_rate_limit() {
let manager = InferenceRateLimitManager::new();
manager.register_route("test-route", &test_inference_config());
let headers = HeaderMap::new();
let body = br#"{"messages": [{"content": "Hello world"}]}"#;
let result = manager.check("test-route", "client-1", &headers, body);
assert!(result.is_some());
let check = result.unwrap();
assert!(check.is_allowed());
assert!(check.estimated_tokens > 0);
}
#[test]
fn test_no_rate_limit_config() {
let manager = InferenceRateLimitManager::new();
let config = InferenceConfig {
provider: InferenceProvider::OpenAi,
model_header: None,
rate_limit: None,
budget: None,
cost_attribution: None,
routing: None,
model_routing: None,
guardrails: None,
};
manager.register_route("no-limit-route", &config);
assert!(!manager.has_route("no-limit-route"));
}
#[test]
fn test_budget_only_config() {
use zentinel_common::budget::{BudgetPeriod, TokenBudgetConfig};
let manager = InferenceRateLimitManager::new();
let config = InferenceConfig {
provider: InferenceProvider::OpenAi,
model_header: None,
rate_limit: None,
budget: Some(TokenBudgetConfig {
period: BudgetPeriod::Daily,
limit: 100000,
alert_thresholds: vec![0.80, 0.90],
enforce: true,
rollover: false,
burst_allowance: None,
}),
cost_attribution: None,
routing: None,
model_routing: None,
guardrails: None,
};
manager.register_route("budget-route", &config);
assert!(manager.has_route("budget-route"));
assert!(manager.has_budget("budget-route"));
assert!(!manager.has_cost_attribution("budget-route"));
}
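
    // Sketches of additional edge cases, grounded in the behavior visible in
    // this module: record_budget returns no alerts when the route has no
    // budget tracker, and check returns None for unregistered routes.
    #[test]
    fn test_record_budget_without_tracker() {
        let manager = InferenceRateLimitManager::new();
        manager.register_route("test-route", &test_inference_config());
        // test_inference_config() configures no budget, so no alerts can fire.
        assert!(manager
            .record_budget("test-route", "tenant-a", 500)
            .is_empty());
    }

    #[test]
    fn test_check_unregistered_route() {
        let manager = InferenceRateLimitManager::new();
        let headers = HeaderMap::new();
        assert!(manager
            .check("missing-route", "client-1", &headers, b"{}")
            .is_none());
    }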
}