use std::collections::{HashMap, HashSet};
use std::pin::Pin;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use futures::Stream;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
use crate::provider::Provider;
use crate::types::{
BatchJob, BatchRequest, BatchResult, CompletionRequest, CompletionResponse, StreamChunk,
TokenCountRequest, TokenCountResult,
};
/// Identifier of a tenant: a newtype around an owned `String`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TenantId(String);

impl TenantId {
    /// Builds an identifier from anything convertible into a `String`.
    pub fn new(id: impl Into<String>) -> Self {
        let owned: String = id.into();
        TenantId(owned)
    }

    /// Borrows the identifier text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl std::fmt::Display for TenantId {
    /// Writes the raw identifier text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<String> for TenantId {
    fn from(s: String) -> Self {
        TenantId::new(s)
    }
}

impl From<&str> for TenantId {
    fn from(s: &str) -> Self {
        TenantId::new(s)
    }
}
/// Optional request, token and concurrency quotas for one tenant.
///
/// Every field is optional; `None` leaves that dimension unconstrained.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Maximum number of requests per minute.
    pub requests_per_minute: Option<u32>,
    /// Maximum number of requests per hour.
    pub requests_per_hour: Option<u32>,
    /// Maximum number of requests per day.
    pub requests_per_day: Option<u32>,
    /// Maximum number of tokens per minute.
    pub tokens_per_minute: Option<u64>,
    /// Maximum number of tokens per hour.
    pub tokens_per_hour: Option<u64>,
    /// Maximum number of tokens per day.
    pub tokens_per_day: Option<u64>,
    /// Maximum number of requests in flight at the same time.
    pub max_concurrent: Option<u32>,
}

impl RateLimitConfig {
    /// A config that limits only requests per minute and tokens per minute;
    /// every other quota is left as `None`.
    pub fn basic(requests_per_minute: u32, tokens_per_minute: u64) -> Self {
        Self {
            tokens_per_minute: Some(tokens_per_minute),
            requests_per_minute: Some(requests_per_minute),
            ..Self::default()
        }
    }

    /// Returns `self` with the concurrency cap set to `max`.
    pub fn with_max_concurrent(self, max: u32) -> Self {
        Self {
            max_concurrent: Some(max),
            ..self
        }
    }
}
/// Optional spending caps for one tenant, in US dollars.
///
/// Every field is optional; `None` leaves that cap unset.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CostLimitConfig {
    /// Maximum cost of a single request.
    /// NOTE(review): not consulted by the pre-request checks in this file.
    pub per_request_limit_usd: Option<f64>,
    /// Maximum spend per day.
    pub daily_limit_usd: Option<f64>,
    /// Maximum spend per week.
    pub weekly_limit_usd: Option<f64>,
    /// Maximum spend per month.
    pub monthly_limit_usd: Option<f64>,
    /// Fraction of a limit, clamped to `0.0..=1.0` by `with_alert_threshold`.
    /// NOTE(review): presumably an alerting level; not read by the limit checks here.
    pub alert_threshold: Option<f64>,
}

impl CostLimitConfig {
    /// A config with only daily and monthly caps set.
    pub fn basic(daily_limit: f64, monthly_limit: f64) -> Self {
        Self {
            monthly_limit_usd: Some(monthly_limit),
            daily_limit_usd: Some(daily_limit),
            ..Self::default()
        }
    }

    /// Returns `self` with the alert threshold set to `threshold`,
    /// clamped into `0.0..=1.0`.
    pub fn with_alert_threshold(self, threshold: f64) -> Self {
        let clamped = threshold.clamp(0.0, 1.0);
        Self {
            alert_threshold: Some(clamped),
            ..self
        }
    }
}
/// Policy for one tenant: identity, model access rules, optional limits and
/// free-form metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TenantConfig {
    /// Unique identifier of the tenant.
    pub id: TenantId,
    /// Optional human-readable name.
    pub name: Option<String>,
    /// When non-empty, only these models are permitted.
    pub allowed_models: HashSet<String>,
    /// Models that are always refused, even if also listed in `allowed_models`.
    pub blocked_models: HashSet<String>,
    /// Rate limits; `None` means no rate limiting.
    pub rate_limit: Option<RateLimitConfig>,
    /// Cost limits; `None` means no cost limiting.
    pub cost_limit: Option<CostLimitConfig>,
    /// When `false`, requests checked against this config are refused.
    pub active: bool,
    /// Arbitrary key/value annotations.
    pub metadata: HashMap<String, String>,
}

impl TenantConfig {
    /// An active tenant with no model restrictions, no limits and no metadata.
    pub fn new(id: impl Into<TenantId>) -> Self {
        TenantConfig {
            id: id.into(),
            name: None,
            allowed_models: HashSet::new(),
            blocked_models: HashSet::new(),
            rate_limit: None,
            cost_limit: None,
            active: true,
            metadata: HashMap::new(),
        }
    }

    /// Sets the display name.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        TenantConfig {
            name: Some(name.into()),
            ..self
        }
    }

    /// Replaces the whole allow list with `models`.
    pub fn with_allowed_models<I, S>(self, models: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let allowed: HashSet<String> = models.into_iter().map(Into::into).collect();
        TenantConfig {
            allowed_models: allowed,
            ..self
        }
    }

    /// Adds one model to the allow list.
    pub fn allow_model(mut self, model: impl Into<String>) -> Self {
        let model = model.into();
        self.allowed_models.insert(model);
        self
    }

    /// Adds one model to the block list.
    pub fn block_model(mut self, model: impl Into<String>) -> Self {
        let model = model.into();
        self.blocked_models.insert(model);
        self
    }

    /// Sets the rate-limit policy.
    pub fn with_rate_limit(self, config: RateLimitConfig) -> Self {
        TenantConfig {
            rate_limit: Some(config),
            ..self
        }
    }

    /// Sets the cost-limit policy.
    pub fn with_cost_limit(self, config: CostLimitConfig) -> Self {
        TenantConfig {
            cost_limit: Some(config),
            ..self
        }
    }

    /// Inserts (or overwrites) one metadata entry.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let key = key.into();
        let value = value.into();
        self.metadata.insert(key, value);
        self
    }

    /// Returns whether `model` may be used.
    ///
    /// The block list always wins. Otherwise, an empty allow list permits
    /// every model and a non-empty one permits only its members.
    pub fn is_model_allowed(&self, model: &str) -> bool {
        !self.blocked_models.contains(model)
            && (self.allowed_models.is_empty() || self.allowed_models.contains(model))
    }
}
/// Per-tenant request/token counters for fixed minute, hour and day windows,
/// plus a concurrency counter.
///
/// Each window starts at the instant it was last reset. Once its length has
/// elapsed, the next `reset_if_needed` call zeroes that window's counters.
#[derive(Debug)]
struct RateLimiterState {
    requests_minute: AtomicU32,
    requests_hour: AtomicU32,
    requests_day: AtomicU32,
    tokens_minute: AtomicU64,
    tokens_hour: AtomicU64,
    tokens_day: AtomicU64,
    /// Requests currently in flight.
    concurrent: AtomicU32,
    minute_start: RwLock<Instant>,
    hour_start: RwLock<Instant>,
    day_start: RwLock<Instant>,
}

impl Default for RateLimiterState {
    /// All counters at zero; every window starts now.
    fn default() -> Self {
        let start = Instant::now();
        Self {
            minute_start: RwLock::new(start),
            hour_start: RwLock::new(start),
            day_start: RwLock::new(start),
            requests_minute: AtomicU32::new(0),
            requests_hour: AtomicU32::new(0),
            requests_day: AtomicU32::new(0),
            tokens_minute: AtomicU64::new(0),
            tokens_hour: AtomicU64::new(0),
            tokens_day: AtomicU64::new(0),
            concurrent: AtomicU32::new(0),
        }
    }
}

impl RateLimiterState {
    /// Zeroes the counters of every window whose length has fully elapsed and
    /// restarts that window at the current instant.
    fn reset_if_needed(&self) {
        let now = Instant::now();
        Self::roll_window(
            &self.minute_start,
            Duration::from_secs(60),
            now,
            &self.requests_minute,
            &self.tokens_minute,
        );
        Self::roll_window(
            &self.hour_start,
            Duration::from_secs(3600),
            now,
            &self.requests_hour,
            &self.tokens_hour,
        );
        Self::roll_window(
            &self.day_start,
            Duration::from_secs(86400),
            now,
            &self.requests_day,
            &self.tokens_day,
        );
    }

    /// Restarts one window if `window` has passed since it began. The write
    /// lock is held for the whole check-and-reset so two callers cannot both
    /// restart the same window.
    fn roll_window(
        start: &RwLock<Instant>,
        window: Duration,
        now: Instant,
        requests: &AtomicU32,
        tokens: &AtomicU64,
    ) {
        let mut window_start = start.write();
        if now.duration_since(*window_start) < window {
            return;
        }
        *window_start = now;
        requests.store(0, Ordering::Relaxed);
        tokens.store(0, Ordering::Relaxed);
    }
}
/// Accumulated spend for one tenant over fixed day, week and 30-day windows.
///
/// Costs are stored as whole microdollars (millionths of a USD) in atomics.
/// Window starts are wall-clock milliseconds since the Unix epoch.
#[derive(Debug, Default)]
struct CostTrackerState {
    /// Spend in the current day window, in microdollars.
    daily_cost: AtomicU64,
    /// Spend in the current week window, in microdollars.
    weekly_cost: AtomicU64,
    /// Spend in the current 30-day window, in microdollars.
    monthly_cost: AtomicU64,
    /// Start of the current day window (ms since the Unix epoch).
    day_start_ms: AtomicU64,
    /// Start of the current week window (ms since the Unix epoch).
    week_start_ms: AtomicU64,
    /// Start of the current 30-day window (ms since the Unix epoch).
    month_start_ms: AtomicU64,
}

impl CostTrackerState {
    const DAY_MS: u64 = 86_400 * 1_000;
    const WEEK_MS: u64 = 7 * Self::DAY_MS;
    /// The "month" window is a flat 30 days, not a calendar month.
    const MONTH_MS: u64 = 30 * Self::DAY_MS;

    /// A tracker with zero spend whose windows all start now.
    fn new() -> Self {
        let now_ms = Self::unix_now_ms();
        Self {
            daily_cost: AtomicU64::new(0),
            weekly_cost: AtomicU64::new(0),
            monthly_cost: AtomicU64::new(0),
            day_start_ms: AtomicU64::new(now_ms),
            week_start_ms: AtomicU64::new(now_ms),
            month_start_ms: AtomicU64::new(now_ms),
        }
    }

    /// Current wall-clock time in ms since the Unix epoch, or 0 if the clock
    /// reads earlier than the epoch.
    fn unix_now_ms() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64
    }

    /// Zeroes the spend of every window whose length has fully elapsed and
    /// restarts that window now.
    fn reset_if_needed(&self) {
        let now_ms = Self::unix_now_ms();
        Self::roll_window(&self.day_start_ms, &self.daily_cost, Self::DAY_MS, now_ms);
        Self::roll_window(&self.week_start_ms, &self.weekly_cost, Self::WEEK_MS, now_ms);
        Self::roll_window(&self.month_start_ms, &self.monthly_cost, Self::MONTH_MS, now_ms);
    }

    /// Restarts one window at `now_ms` if `window_ms` has elapsed since it began.
    fn roll_window(start_ms: &AtomicU64, cost: &AtomicU64, window_ms: u64, now_ms: u64) {
        let start = start_ms.load(Ordering::Relaxed);
        // `SystemTime` is not monotonic. If the clock steps backwards, a plain
        // `now_ms - start` would panic in debug builds and wrap (forcing a
        // spurious reset) in release builds.
        if now_ms.saturating_sub(start) < window_ms {
            return;
        }
        // Only the caller that wins the CAS clears the counter. A second racer
        // must not wipe cost recorded after the first reset.
        if start_ms
            .compare_exchange(start, now_ms, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok()
        {
            cost.store(0, Ordering::Relaxed);
        }
    }

    /// Adds `cost_usd` to every window, rounded to the nearest microdollar.
    ///
    /// Zero, negative and NaN costs are ignored. Values too large to
    /// represent saturate the counters at `u64::MAX` instead of wrapping.
    fn add_cost(&self, cost_usd: f64) {
        // Round rather than truncate: float error such as 0.7 - 0.6 =
        // 0.0999… would otherwise lose a microdollar on every call.
        let micros = (cost_usd * 1_000_000.0).round();
        if micros.is_nan() || micros <= 0.0 {
            return;
        }
        // Float-to-int `as` saturates, so +inf / huge values become u64::MAX.
        let micros = micros as u64;
        for counter in [&self.daily_cost, &self.weekly_cost, &self.monthly_cost] {
            // The closure always returns `Some`, so the update cannot fail.
            let _ = counter.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |total| {
                Some(total.saturating_add(micros))
            });
        }
    }

    /// Spend in the current day window, in USD.
    fn daily_cost_usd(&self) -> f64 {
        self.daily_cost.load(Ordering::Relaxed) as f64 / 1_000_000.0
    }

    /// Spend in the current week window, in USD.
    fn weekly_cost_usd(&self) -> f64 {
        self.weekly_cost.load(Ordering::Relaxed) as f64 / 1_000_000.0
    }

    /// Spend in the current 30-day window, in USD.
    fn monthly_cost_usd(&self) -> f64 {
        self.monthly_cost.load(Ordering::Relaxed) as f64 / 1_000_000.0
    }
}
/// Details of a request refused by a rate limit.
#[derive(Debug, Clone)]
pub struct RateLimitExceeded {
    /// Which limit was hit.
    pub limit_type: RateLimitType,
    /// Usage observed when the request was refused.
    pub current: u64,
    /// The configured ceiling.
    pub limit: u64,
    /// Suggested wait before retrying, if one is known.
    pub retry_after: Option<Duration>,
}

/// The dimension a rate limit is measured in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitType {
    RequestsPerMinute,
    RequestsPerHour,
    RequestsPerDay,
    TokensPerMinute,
    TokensPerHour,
    TokensPerDay,
    Concurrent,
}

impl std::fmt::Display for RateLimitType {
    /// Writes a lowercase, human-readable label for the limit.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            RateLimitType::RequestsPerMinute => "requests per minute",
            RateLimitType::RequestsPerHour => "requests per hour",
            RateLimitType::RequestsPerDay => "requests per day",
            RateLimitType::TokensPerMinute => "tokens per minute",
            RateLimitType::TokensPerHour => "tokens per hour",
            RateLimitType::TokensPerDay => "tokens per day",
            RateLimitType::Concurrent => "concurrent requests",
        };
        f.write_str(label)
    }
}
/// Details of a request refused by a cost limit.
#[derive(Debug, Clone)]
pub struct CostLimitExceeded {
    /// Which cost window was exhausted.
    pub limit_type: CostLimitType,
    /// Spend in that window when the request was refused, in USD.
    pub current_usd: f64,
    /// The configured cap, in USD.
    pub limit_usd: f64,
}

/// The window a cost limit applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CostLimitType {
    Daily,
    Weekly,
    Monthly,
}

impl std::fmt::Display for CostLimitType {
    /// Writes the lowercase window name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            CostLimitType::Daily => "daily",
            CostLimitType::Weekly => "weekly",
            CostLimitType::Monthly => "monthly",
        })
    }
}
#[derive(Debug)]
pub enum TenantError {
Inactive,
ModelNotAllowed(String),
RateLimitExceeded(RateLimitExceeded),
CostLimitExceeded(CostLimitExceeded),
}
impl std::fmt::Display for TenantError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Inactive => write!(f, "Tenant is inactive"),
Self::ModelNotAllowed(model) => write!(f, "Model '{}' is not allowed", model),
Self::RateLimitExceeded(info) => {
write!(
f,
"Rate limit exceeded: {} ({}/{})",
info.limit_type, info.current, info.limit
)
}
Self::CostLimitExceeded(info) => {
write!(
f,
"Cost limit exceeded: {} (${:.2}/${:.2})",
info.limit_type, info.current_usd, info.limit_usd
)
}
}
}
}
impl std::error::Error for TenantError {}
pub struct TenantProvider<P: Provider> {
inner: P,
config: TenantConfig,
rate_state: RateLimiterState,
cost_state: CostTrackerState,
}
impl<P: Provider> TenantProvider<P> {
pub fn new(inner: P, config: TenantConfig) -> Self {
Self {
inner,
config,
rate_state: RateLimiterState::default(),
cost_state: CostTrackerState::new(),
}
}
pub fn tenant_id(&self) -> &TenantId {
&self.config.id
}
pub fn config(&self) -> &TenantConfig {
&self.config
}
fn check_request(&self, model: &str) -> std::result::Result<(), TenantError> {
if !self.config.active {
return Err(TenantError::Inactive);
}
if !self.config.is_model_allowed(model) {
return Err(TenantError::ModelNotAllowed(model.to_string()));
}
self.rate_state.reset_if_needed();
if let Some(ref limits) = self.config.rate_limit {
if let Some(max_concurrent) = limits.max_concurrent {
let current = self.rate_state.concurrent.load(Ordering::Relaxed);
if current >= max_concurrent {
return Err(TenantError::RateLimitExceeded(RateLimitExceeded {
limit_type: RateLimitType::Concurrent,
current: current as u64,
limit: max_concurrent as u64,
retry_after: None,
}));
}
}
if let Some(rpm) = limits.requests_per_minute {
let current = self.rate_state.requests_minute.load(Ordering::Relaxed);
if current >= rpm {
return Err(TenantError::RateLimitExceeded(RateLimitExceeded {
limit_type: RateLimitType::RequestsPerMinute,
current: current as u64,
limit: rpm as u64,
retry_after: Some(Duration::from_secs(60)),
}));
}
}
if let Some(rph) = limits.requests_per_hour {
let current = self.rate_state.requests_hour.load(Ordering::Relaxed);
if current >= rph {
return Err(TenantError::RateLimitExceeded(RateLimitExceeded {
limit_type: RateLimitType::RequestsPerHour,
current: current as u64,
limit: rph as u64,
retry_after: Some(Duration::from_secs(3600)),
}));
}
}
if let Some(rpd) = limits.requests_per_day {
let current = self.rate_state.requests_day.load(Ordering::Relaxed);
if current >= rpd {
return Err(TenantError::RateLimitExceeded(RateLimitExceeded {
limit_type: RateLimitType::RequestsPerDay,
current: current as u64,
limit: rpd as u64,
retry_after: Some(Duration::from_secs(86400)),
}));
}
}
}
self.cost_state.reset_if_needed();
if let Some(ref limits) = self.config.cost_limit {
if let Some(daily) = limits.daily_limit_usd {
let current = self.cost_state.daily_cost_usd();
if current >= daily {
return Err(TenantError::CostLimitExceeded(CostLimitExceeded {
limit_type: CostLimitType::Daily,
current_usd: current,
limit_usd: daily,
}));
}
}
if let Some(weekly) = limits.weekly_limit_usd {
let current = self.cost_state.weekly_cost_usd();
if current >= weekly {
return Err(TenantError::CostLimitExceeded(CostLimitExceeded {
limit_type: CostLimitType::Weekly,
current_usd: current,
limit_usd: weekly,
}));
}
}
if let Some(monthly) = limits.monthly_limit_usd {
let current = self.cost_state.monthly_cost_usd();
if current >= monthly {
return Err(TenantError::CostLimitExceeded(CostLimitExceeded {
limit_type: CostLimitType::Monthly,
current_usd: current,
limit_usd: monthly,
}));
}
}
}
Ok(())
}
fn record_request(&self, tokens: u64, cost_usd: f64) {
self.rate_state
.requests_minute
.fetch_add(1, Ordering::Relaxed);
self.rate_state
.requests_hour
.fetch_add(1, Ordering::Relaxed);
self.rate_state.requests_day.fetch_add(1, Ordering::Relaxed);
self.rate_state
.tokens_minute
.fetch_add(tokens, Ordering::Relaxed);
self.rate_state
.tokens_hour
.fetch_add(tokens, Ordering::Relaxed);
self.rate_state
.tokens_day
.fetch_add(tokens, Ordering::Relaxed);
self.cost_state.add_cost(cost_usd);
}
fn start_request(&self) {
self.rate_state.concurrent.fetch_add(1, Ordering::Relaxed);
}
fn end_request(&self) {
self.rate_state.concurrent.fetch_sub(1, Ordering::Relaxed);
}
pub fn usage_stats(&self) -> TenantUsageStats {
self.rate_state.reset_if_needed();
self.cost_state.reset_if_needed();
TenantUsageStats {
requests_minute: self.rate_state.requests_minute.load(Ordering::Relaxed),
requests_hour: self.rate_state.requests_hour.load(Ordering::Relaxed),
requests_day: self.rate_state.requests_day.load(Ordering::Relaxed),
tokens_minute: self.rate_state.tokens_minute.load(Ordering::Relaxed),
tokens_hour: self.rate_state.tokens_hour.load(Ordering::Relaxed),
tokens_day: self.rate_state.tokens_day.load(Ordering::Relaxed),
concurrent: self.rate_state.concurrent.load(Ordering::Relaxed),
daily_cost_usd: self.cost_state.daily_cost_usd(),
weekly_cost_usd: self.cost_state.weekly_cost_usd(),
monthly_cost_usd: self.cost_state.monthly_cost_usd(),
}
}
}
/// Point-in-time snapshot of one tenant's usage counters, as assembled by
/// `TenantProvider::usage_stats`.
///
/// Request and token fields count usage in the current fixed minute/hour/day
/// window. Cost fields are US dollars spent in the current day, week and
/// 30-day window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TenantUsageStats {
/// Requests recorded in the current minute window.
pub requests_minute: u32,
/// Requests recorded in the current hour window.
pub requests_hour: u32,
/// Requests recorded in the current day window.
pub requests_day: u32,
/// Tokens recorded in the current minute window.
pub tokens_minute: u64,
/// Tokens recorded in the current hour window.
pub tokens_hour: u64,
/// Tokens recorded in the current day window.
pub tokens_day: u64,
/// Requests in flight when the snapshot was taken.
pub concurrent: u32,
/// Spend in the current day window, in USD.
pub daily_cost_usd: f64,
/// Spend in the current week window, in USD.
pub weekly_cost_usd: f64,
/// Spend in the current 30-day window, in USD.
pub monthly_cost_usd: f64,
}
/// Flat placeholder price used to estimate request cost: $1 per million
/// tokens, independent of the model.
/// NOTE(review): replace with real per-model pricing.
const PLACEHOLDER_USD_PER_TOKEN: f64 = 0.000_001;

/// Holds one of a tenant's concurrency slots for as long as it is alive.
///
/// The slot is released in `Drop` rather than by an explicit call after the
/// `.await`. That way it is also freed when the request future is cancelled
/// (dropped mid-await) or the inner provider panics. Otherwise each such
/// request would consume a slot forever and eventually trip `max_concurrent`
/// permanently.
struct InFlight<'a, P: Provider> {
    provider: &'a TenantProvider<P>,
}

impl<'a, P: Provider> InFlight<'a, P> {
    /// Claims a concurrency slot on `provider`.
    fn start(provider: &'a TenantProvider<P>) -> Self {
        provider.start_request();
        Self { provider }
    }
}

impl<P: Provider> Drop for InFlight<'_, P> {
    fn drop(&mut self) {
        self.provider.end_request();
    }
}

#[async_trait]
impl<P: Provider> Provider for TenantProvider<P> {
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Checks the tenant policy, forwards the request, and on success records
    /// its token usage and estimated cost.
    ///
    /// # Errors
    /// A policy violation is returned as `Error::other` carrying the
    /// [`TenantError`] message. Errors from the inner provider pass through
    /// unchanged and are not recorded as usage.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        self.check_request(&request.model)
            .map_err(|e| Error::other(e.to_string()))?;

        let result = {
            let _in_flight = InFlight::start(self);
            self.inner.complete(request).await
        };

        if let Ok(ref response) = result {
            // Widen each count before adding, so the sum cannot overflow if
            // the usage fields are a narrower integer type.
            let tokens =
                response.usage.input_tokens as u64 + response.usage.output_tokens as u64;
            self.record_request(tokens, tokens as f64 * PLACEHOLDER_USD_PER_TOKEN);
        }
        result
    }

    /// Checks the tenant policy and opens a stream on the inner provider.
    ///
    /// The returned stream is `'static` and cannot borrow `self`. As a result:
    /// - the concurrency slot covers only stream setup, not its consumption;
    /// - the request counts toward the request-rate limits, but its tokens and
    ///   cost are not recorded, since they are only known as the stream is
    ///   consumed.
    ///
    /// # Errors
    /// Same as [`Provider::complete`].
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>> {
        self.check_request(&request.model)
            .map_err(|e| Error::other(e.to_string()))?;

        let result = {
            let _in_flight = InFlight::start(self);
            self.inner.complete_stream(request).await
        };

        if result.is_ok() {
            self.record_request(0, 0.0);
        }
        result
    }

    fn supports_tools(&self) -> bool {
        self.inner.supports_tools()
    }

    fn supports_vision(&self) -> bool {
        self.inner.supports_vision()
    }

    fn supports_streaming(&self) -> bool {
        self.inner.supports_streaming()
    }

    fn supports_token_counting(&self) -> bool {
        self.inner.supports_token_counting()
    }

    // NOTE(review): token counting and all batch operations below forward
    // straight to `inner` without `check_request`. They therefore bypass the
    // active flag, the model lists and every limit. Confirm this is intended.

    async fn count_tokens(&self, request: TokenCountRequest) -> Result<TokenCountResult> {
        self.inner.count_tokens(request).await
    }

    fn supports_batch(&self) -> bool {
        self.inner.supports_batch()
    }

    async fn create_batch(&self, requests: Vec<BatchRequest>) -> Result<BatchJob> {
        self.inner.create_batch(requests).await
    }

    async fn get_batch(&self, batch_id: &str) -> Result<BatchJob> {
        self.inner.get_batch(batch_id).await
    }

    async fn get_batch_results(&self, batch_id: &str) -> Result<Vec<BatchResult>> {
        self.inner.get_batch_results(batch_id).await
    }

    async fn cancel_batch(&self, batch_id: &str) -> Result<BatchJob> {
        self.inner.cancel_batch(batch_id).await
    }

    async fn list_batches(&self, limit: Option<u32>) -> Result<Vec<BatchJob>> {
        self.inner.list_batches(limit).await
    }
}
/// In-memory registry of tenant configurations keyed by [`TenantId`],
/// guarded by a read/write lock so it can be shared by reference.
pub struct TenantManager {
    tenants: RwLock<HashMap<TenantId, TenantConfig>>,
}

impl Default for TenantManager {
    /// An empty registry.
    fn default() -> Self {
        TenantManager::new()
    }
}

impl TenantManager {
    /// Creates an empty registry.
    pub fn new() -> Self {
        TenantManager {
            tenants: RwLock::new(HashMap::new()),
        }
    }

    /// Adds `config`, replacing any existing entry with the same id.
    pub fn register(&self, config: TenantConfig) {
        let key = config.id.clone();
        let mut tenants = self.tenants.write();
        tenants.insert(key, config);
    }

    /// Returns a copy of the config registered under `id`, if any.
    pub fn get(&self, id: &TenantId) -> Option<TenantConfig> {
        let tenants = self.tenants.read();
        tenants.get(id).cloned()
    }

    /// Removes and returns the config registered under `id`, if any.
    pub fn remove(&self, id: &TenantId) -> Option<TenantConfig> {
        let mut tenants = self.tenants.write();
        tenants.remove(id)
    }

    /// Ids of all registered tenants, in unspecified order.
    pub fn list(&self) -> Vec<TenantId> {
        let tenants = self.tenants.read();
        tenants.keys().cloned().collect()
    }

    /// Whether a tenant with `id` is registered.
    pub fn exists(&self, id: &TenantId) -> bool {
        let tenants = self.tenants.read();
        tenants.contains_key(id)
    }

    /// Replaces an already-registered tenant's config.
    ///
    /// Returns `false`, leaving the registry untouched, when no tenant with
    /// `config.id` exists.
    pub fn update(&self, config: TenantConfig) -> bool {
        match self.tenants.write().get_mut(&config.id) {
            Some(slot) => {
                *slot = config;
                true
            }
            None => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tenant_id() {
        let tenant = TenantId::new("test-tenant");
        assert_eq!(tenant.as_str(), "test-tenant");
        assert_eq!(tenant.to_string(), "test-tenant");
    }

    #[test]
    fn test_tenant_config_allowed_models() {
        let allowed = vec!["gpt-4o", "claude-sonnet-4-20250514"];
        let config = TenantConfig::new("test").with_allowed_models(allowed.iter().copied());
        for model in &allowed {
            assert!(config.is_model_allowed(model));
        }
        assert!(!config.is_model_allowed("gpt-3.5-turbo"));
    }

    #[test]
    fn test_tenant_config_blocked_models() {
        let config = TenantConfig::new("test").block_model("gpt-3.5-turbo");
        assert!(!config.is_model_allowed("gpt-3.5-turbo"));
        assert!(config.is_model_allowed("gpt-4o"));
    }

    #[test]
    fn test_rate_limit_config() {
        let limits = RateLimitConfig::basic(60, 100_000).with_max_concurrent(10);
        assert_eq!(
            (limits.requests_per_minute, limits.tokens_per_minute, limits.max_concurrent),
            (Some(60), Some(100_000), Some(10))
        );
    }

    #[test]
    fn test_cost_limit_config() {
        let limits = CostLimitConfig::basic(100.0, 1000.0).with_alert_threshold(0.8);
        assert_eq!(
            (limits.daily_limit_usd, limits.monthly_limit_usd, limits.alert_threshold),
            (Some(100.0), Some(1000.0), Some(0.8))
        );
    }

    #[test]
    fn test_tenant_manager() {
        let manager = TenantManager::new();
        let acme = TenantId::new("acme");
        manager.register(TenantConfig::new("acme"));
        assert!(manager.exists(&acme));
        assert!(!manager.exists(&TenantId::new("other")));
        assert_eq!(manager.list().len(), 1);
        assert!(manager.remove(&acme).is_some());
        assert!(!manager.exists(&acme));
    }
}