llmkit/
tenant.rs

1//! Tenant context and multi-tenancy support for LLM requests.
2//!
3//! This module provides tenant isolation, rate limiting, and cost controls
4//! for multi-tenant LLM applications.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use llmkit::{TenantConfig, TenantProvider, RateLimitConfig, CostLimitConfig};
10//!
11//! let config = TenantConfig::new("acme-corp")
12//!     .with_allowed_models(vec!["gpt-4o", "claude-sonnet-4-20250514"])
13//!     .with_rate_limit(RateLimitConfig {
14//!         requests_per_minute: Some(60),
15//!         tokens_per_minute: Some(100_000),
16//!         ..Default::default()
17//!     })
18//!     .with_cost_limit(CostLimitConfig {
19//!         daily_limit_usd: Some(100.0),
20//!         monthly_limit_usd: Some(1000.0),
21//!         ..Default::default()
22//!     });
23//!
24//! let provider = TenantProvider::new(inner_provider, config);
25//! ```
26
27use std::collections::{HashMap, HashSet};
28use std::pin::Pin;
29use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
30use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
31
32use async_trait::async_trait;
33use futures::Stream;
34use parking_lot::RwLock;
35use serde::{Deserialize, Serialize};
36
37use crate::error::{Error, Result};
38use crate::provider::Provider;
39use crate::types::{
40    BatchJob, BatchRequest, BatchResult, CompletionRequest, CompletionResponse, StreamChunk,
41    TokenCountRequest, TokenCountResult,
42};
43
44/// Unique identifier for a tenant.
45#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
46pub struct TenantId(String);
47
48impl TenantId {
49    /// Create a new tenant ID.
50    pub fn new(id: impl Into<String>) -> Self {
51        Self(id.into())
52    }
53
54    /// Get the ID as a string slice.
55    pub fn as_str(&self) -> &str {
56        &self.0
57    }
58}
59
60impl std::fmt::Display for TenantId {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        write!(f, "{}", self.0)
63    }
64}
65
66impl From<String> for TenantId {
67    fn from(s: String) -> Self {
68        Self(s)
69    }
70}
71
72impl From<&str> for TenantId {
73    fn from(s: &str) -> Self {
74        Self(s.to_string())
75    }
76}
77
78/// Rate limiting configuration.
79#[derive(Debug, Clone, Default, Serialize, Deserialize)]
80pub struct RateLimitConfig {
81    /// Maximum requests per minute
82    pub requests_per_minute: Option<u32>,
83    /// Maximum requests per hour
84    pub requests_per_hour: Option<u32>,
85    /// Maximum requests per day
86    pub requests_per_day: Option<u32>,
87    /// Maximum tokens per minute (input + output)
88    pub tokens_per_minute: Option<u64>,
89    /// Maximum tokens per hour
90    pub tokens_per_hour: Option<u64>,
91    /// Maximum tokens per day
92    pub tokens_per_day: Option<u64>,
93    /// Maximum concurrent requests
94    pub max_concurrent: Option<u32>,
95}
96
97impl RateLimitConfig {
98    /// Create a basic rate limit config.
99    pub fn basic(requests_per_minute: u32, tokens_per_minute: u64) -> Self {
100        Self {
101            requests_per_minute: Some(requests_per_minute),
102            tokens_per_minute: Some(tokens_per_minute),
103            ..Default::default()
104        }
105    }
106
107    /// Set the max concurrent requests.
108    pub fn with_max_concurrent(mut self, max: u32) -> Self {
109        self.max_concurrent = Some(max);
110        self
111    }
112}
113
114/// Cost limiting configuration.
115#[derive(Debug, Clone, Default, Serialize, Deserialize)]
116pub struct CostLimitConfig {
117    /// Maximum cost per request in USD
118    pub per_request_limit_usd: Option<f64>,
119    /// Maximum daily cost in USD
120    pub daily_limit_usd: Option<f64>,
121    /// Maximum weekly cost in USD
122    pub weekly_limit_usd: Option<f64>,
123    /// Maximum monthly cost in USD
124    pub monthly_limit_usd: Option<f64>,
125    /// Alert threshold as percentage of limit (0.0 to 1.0)
126    pub alert_threshold: Option<f64>,
127}
128
129impl CostLimitConfig {
130    /// Create a basic cost limit config.
131    pub fn basic(daily_limit: f64, monthly_limit: f64) -> Self {
132        Self {
133            daily_limit_usd: Some(daily_limit),
134            monthly_limit_usd: Some(monthly_limit),
135            ..Default::default()
136        }
137    }
138
139    /// Set alert threshold.
140    pub fn with_alert_threshold(mut self, threshold: f64) -> Self {
141        self.alert_threshold = Some(threshold.clamp(0.0, 1.0));
142        self
143    }
144}
145
146/// Tenant configuration.
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct TenantConfig {
149    /// Unique tenant identifier
150    pub id: TenantId,
151    /// Display name
152    pub name: Option<String>,
153    /// Allowed models (empty = all allowed)
154    pub allowed_models: HashSet<String>,
155    /// Blocked models
156    pub blocked_models: HashSet<String>,
157    /// Rate limit configuration
158    pub rate_limit: Option<RateLimitConfig>,
159    /// Cost limit configuration
160    pub cost_limit: Option<CostLimitConfig>,
161    /// Whether the tenant is active
162    pub active: bool,
163    /// Custom metadata
164    pub metadata: HashMap<String, String>,
165}
166
167impl TenantConfig {
168    /// Create a new tenant config.
169    pub fn new(id: impl Into<TenantId>) -> Self {
170        Self {
171            id: id.into(),
172            name: None,
173            allowed_models: HashSet::new(),
174            blocked_models: HashSet::new(),
175            rate_limit: None,
176            cost_limit: None,
177            active: true,
178            metadata: HashMap::new(),
179        }
180    }
181
182    /// Set the display name.
183    pub fn with_name(mut self, name: impl Into<String>) -> Self {
184        self.name = Some(name.into());
185        self
186    }
187
188    /// Set allowed models.
189    pub fn with_allowed_models<I, S>(mut self, models: I) -> Self
190    where
191        I: IntoIterator<Item = S>,
192        S: Into<String>,
193    {
194        self.allowed_models = models.into_iter().map(Into::into).collect();
195        self
196    }
197
198    /// Add an allowed model.
199    pub fn allow_model(mut self, model: impl Into<String>) -> Self {
200        self.allowed_models.insert(model.into());
201        self
202    }
203
204    /// Block a model.
205    pub fn block_model(mut self, model: impl Into<String>) -> Self {
206        self.blocked_models.insert(model.into());
207        self
208    }
209
210    /// Set rate limit configuration.
211    pub fn with_rate_limit(mut self, config: RateLimitConfig) -> Self {
212        self.rate_limit = Some(config);
213        self
214    }
215
216    /// Set cost limit configuration.
217    pub fn with_cost_limit(mut self, config: CostLimitConfig) -> Self {
218        self.cost_limit = Some(config);
219        self
220    }
221
222    /// Add metadata.
223    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
224        self.metadata.insert(key.into(), value.into());
225        self
226    }
227
228    /// Check if a model is allowed for this tenant.
229    pub fn is_model_allowed(&self, model: &str) -> bool {
230        // If blocked, not allowed
231        if self.blocked_models.contains(model) {
232            return false;
233        }
234
235        // If allowed list is empty, all non-blocked models are allowed
236        if self.allowed_models.is_empty() {
237            return true;
238        }
239
240        // Otherwise, must be in allowed list
241        self.allowed_models.contains(model)
242    }
243}
244
245/// Rate limiter state.
246#[derive(Debug)]
247struct RateLimiterState {
248    /// Request counts per window
249    requests_minute: AtomicU32,
250    requests_hour: AtomicU32,
251    requests_day: AtomicU32,
252    /// Token counts per window
253    tokens_minute: AtomicU64,
254    tokens_hour: AtomicU64,
255    tokens_day: AtomicU64,
256    /// Current concurrent requests
257    concurrent: AtomicU32,
258    /// Window start times
259    minute_start: RwLock<Instant>,
260    hour_start: RwLock<Instant>,
261    day_start: RwLock<Instant>,
262}
263
264impl Default for RateLimiterState {
265    fn default() -> Self {
266        let now = Instant::now();
267        Self {
268            requests_minute: AtomicU32::new(0),
269            requests_hour: AtomicU32::new(0),
270            requests_day: AtomicU32::new(0),
271            tokens_minute: AtomicU64::new(0),
272            tokens_hour: AtomicU64::new(0),
273            tokens_day: AtomicU64::new(0),
274            concurrent: AtomicU32::new(0),
275            minute_start: RwLock::new(now),
276            hour_start: RwLock::new(now),
277            day_start: RwLock::new(now),
278        }
279    }
280}
281
282impl RateLimiterState {
283    fn reset_if_needed(&self) {
284        let now = Instant::now();
285
286        // Check minute window
287        {
288            let mut minute_start = self.minute_start.write();
289            if now.duration_since(*minute_start) >= Duration::from_secs(60) {
290                *minute_start = now;
291                self.requests_minute.store(0, Ordering::Relaxed);
292                self.tokens_minute.store(0, Ordering::Relaxed);
293            }
294        }
295
296        // Check hour window
297        {
298            let mut hour_start = self.hour_start.write();
299            if now.duration_since(*hour_start) >= Duration::from_secs(3600) {
300                *hour_start = now;
301                self.requests_hour.store(0, Ordering::Relaxed);
302                self.tokens_hour.store(0, Ordering::Relaxed);
303            }
304        }
305
306        // Check day window
307        {
308            let mut day_start = self.day_start.write();
309            if now.duration_since(*day_start) >= Duration::from_secs(86400) {
310                *day_start = now;
311                self.requests_day.store(0, Ordering::Relaxed);
312                self.tokens_day.store(0, Ordering::Relaxed);
313            }
314        }
315    }
316}
317
/// Running spend totals for the daily, weekly and monthly cost windows.
///
/// Amounts are stored as whole microdollars (1 USD = 1_000_000) and window
/// starts as Unix timestamps in milliseconds, so everything fits in atomics.
#[derive(Debug, Default)]
struct CostTrackerState {
    /// Spend in the current day window, in microdollars.
    daily_cost: AtomicU64,
    /// Spend in the current week window, in microdollars.
    weekly_cost: AtomicU64,
    /// Spend in the current month window, in microdollars.
    monthly_cost: AtomicU64,
    /// Start of the current day window (Unix timestamp, ms).
    day_start_ms: AtomicU64,
    /// Start of the current week window (Unix timestamp, ms).
    week_start_ms: AtomicU64,
    /// Start of the current month window (Unix timestamp, ms).
    month_start_ms: AtomicU64,
}

impl CostTrackerState {
    /// Creates a tracker with zero spend whose three windows start now.
    fn new() -> Self {
        let now_ms = Self::now_ms();

        Self {
            daily_cost: AtomicU64::new(0),
            weekly_cost: AtomicU64::new(0),
            monthly_cost: AtomicU64::new(0),
            day_start_ms: AtomicU64::new(now_ms),
            week_start_ms: AtomicU64::new(now_ms),
            month_start_ms: AtomicU64::new(now_ms),
        }
    }

    /// Current wall-clock time as milliseconds since the Unix epoch, or 0
    /// when the system clock reports a time before the epoch.
    fn now_ms() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64
    }

    /// Zeroes the total of every window that has run for its full length
    /// (1 day, 7 days, 30 days) and restarts that window at the current time.
    fn reset_if_needed(&self) {
        const DAY_MS: u64 = 86400 * 1000;

        let now_ms = Self::now_ms();
        Self::roll_window(now_ms, DAY_MS, &self.day_start_ms, &self.daily_cost);
        Self::roll_window(now_ms, 7 * DAY_MS, &self.week_start_ms, &self.weekly_cost);
        Self::roll_window(now_ms, 30 * DAY_MS, &self.month_start_ms, &self.monthly_cost);
    }

    /// Rolls one window over if at least `window_ms` has passed since it started.
    fn roll_window(now_ms: u64, window_ms: u64, started_ms: &AtomicU64, total: &AtomicU64) {
        // The wall clock can be adjusted backwards, leaving the stored start
        // ahead of `now_ms`. A plain subtraction would then underflow (panic
        // in debug builds, wrap to a huge value and trigger a spurious reset
        // in release builds); saturating keeps the window open instead.
        let elapsed_ms = now_ms.saturating_sub(started_ms.load(Ordering::Relaxed));
        if elapsed_ms >= window_ms {
            started_ms.store(now_ms, Ordering::Relaxed);
            total.store(0, Ordering::Relaxed);
        }
    }

    /// Adds `cost_usd` to all three window totals.
    ///
    /// The amount is rounded to the nearest microdollar; truncating instead
    /// would systematically under-count (e.g. a product that evaluates to
    /// 2.9999999 would be booked as 2). Float-to-integer `as` casts saturate,
    /// so a negative or NaN cost is recorded as 0.
    fn add_cost(&self, cost_usd: f64) {
        let microdollars = (cost_usd * 1_000_000.0).round() as u64;
        self.daily_cost.fetch_add(microdollars, Ordering::Relaxed);
        self.weekly_cost.fetch_add(microdollars, Ordering::Relaxed);
        self.monthly_cost.fetch_add(microdollars, Ordering::Relaxed);
    }

    /// Spend in the current day window, in USD.
    fn daily_cost_usd(&self) -> f64 {
        self.daily_cost.load(Ordering::Relaxed) as f64 / 1_000_000.0
    }

    /// Spend in the current week window, in USD.
    fn weekly_cost_usd(&self) -> f64 {
        self.weekly_cost.load(Ordering::Relaxed) as f64 / 1_000_000.0
    }

    /// Spend in the current month window, in USD.
    fn monthly_cost_usd(&self) -> f64 {
        self.monthly_cost.load(Ordering::Relaxed) as f64 / 1_000_000.0
    }
}
396
/// Details of a rate limit that a request ran into.
#[derive(Debug, Clone)]
pub struct RateLimitExceeded {
    /// Which limit was hit.
    pub limit_type: RateLimitType,
    /// Counter value observed when the limit was hit.
    pub current: u64,
    /// Configured ceiling for that counter.
    pub limit: u64,
    /// Time until the limit resets, when known.
    pub retry_after: Option<Duration>,
}

/// The kinds of rate limit a tenant can run into.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitType {
    /// Request count within a minute.
    RequestsPerMinute,
    /// Request count within an hour.
    RequestsPerHour,
    /// Request count within a day.
    RequestsPerDay,
    /// Token count within a minute.
    TokensPerMinute,
    /// Token count within an hour.
    TokensPerHour,
    /// Token count within a day.
    TokensPerDay,
    /// Number of simultaneously running requests.
    Concurrent,
}

impl std::fmt::Display for RateLimitType {
    /// Writes a lowercase, human-readable name for the limit.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            RateLimitType::RequestsPerMinute => "requests per minute",
            RateLimitType::RequestsPerHour => "requests per hour",
            RateLimitType::RequestsPerDay => "requests per day",
            RateLimitType::TokensPerMinute => "tokens per minute",
            RateLimitType::TokensPerHour => "tokens per hour",
            RateLimitType::TokensPerDay => "tokens per day",
            RateLimitType::Concurrent => "concurrent requests",
        };
        f.write_str(label)
    }
}
435
/// Details of a cost limit that a request ran into.
#[derive(Debug, Clone)]
pub struct CostLimitExceeded {
    /// Which spending window was exhausted.
    pub limit_type: CostLimitType,
    /// Spend observed in that window, in USD.
    pub current_usd: f64,
    /// Configured ceiling for that window, in USD.
    pub limit_usd: f64,
}

/// The spending windows a cost limit can apply to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CostLimitType {
    /// The daily spending window.
    Daily,
    /// The weekly spending window.
    Weekly,
    /// The monthly spending window.
    Monthly,
}

impl std::fmt::Display for CostLimitType {
    /// Writes the lowercase name of the window.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            CostLimitType::Daily => "daily",
            CostLimitType::Weekly => "weekly",
            CostLimitType::Monthly => "monthly",
        };
        f.write_str(label)
    }
}
464
465/// Tenant error types.
466#[derive(Debug)]
467pub enum TenantError {
468    /// Tenant is not active
469    Inactive,
470    /// Model is not allowed
471    ModelNotAllowed(String),
472    /// Rate limit exceeded
473    RateLimitExceeded(RateLimitExceeded),
474    /// Cost limit exceeded
475    CostLimitExceeded(CostLimitExceeded),
476}
477
478impl std::fmt::Display for TenantError {
479    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
480        match self {
481            Self::Inactive => write!(f, "Tenant is inactive"),
482            Self::ModelNotAllowed(model) => write!(f, "Model '{}' is not allowed", model),
483            Self::RateLimitExceeded(info) => {
484                write!(
485                    f,
486                    "Rate limit exceeded: {} ({}/{})",
487                    info.limit_type, info.current, info.limit
488                )
489            }
490            Self::CostLimitExceeded(info) => {
491                write!(
492                    f,
493                    "Cost limit exceeded: {} (${:.2}/${:.2})",
494                    info.limit_type, info.current_usd, info.limit_usd
495                )
496            }
497        }
498    }
499}
500
501impl std::error::Error for TenantError {}
502
503/// Provider wrapper that enforces tenant restrictions.
504pub struct TenantProvider<P: Provider> {
505    inner: P,
506    config: TenantConfig,
507    rate_state: RateLimiterState,
508    cost_state: CostTrackerState,
509}
510
511impl<P: Provider> TenantProvider<P> {
512    /// Create a new tenant provider.
513    pub fn new(inner: P, config: TenantConfig) -> Self {
514        Self {
515            inner,
516            config,
517            rate_state: RateLimiterState::default(),
518            cost_state: CostTrackerState::new(),
519        }
520    }
521
522    /// Get the tenant ID.
523    pub fn tenant_id(&self) -> &TenantId {
524        &self.config.id
525    }
526
527    /// Get the tenant config.
528    pub fn config(&self) -> &TenantConfig {
529        &self.config
530    }
531
532    /// Check if a request is allowed.
533    fn check_request(&self, model: &str) -> std::result::Result<(), TenantError> {
534        // Check if tenant is active
535        if !self.config.active {
536            return Err(TenantError::Inactive);
537        }
538
539        // Check if model is allowed
540        if !self.config.is_model_allowed(model) {
541            return Err(TenantError::ModelNotAllowed(model.to_string()));
542        }
543
544        // Reset rate limiter windows if needed
545        self.rate_state.reset_if_needed();
546
547        // Check rate limits
548        if let Some(ref limits) = self.config.rate_limit {
549            // Check concurrent
550            if let Some(max_concurrent) = limits.max_concurrent {
551                let current = self.rate_state.concurrent.load(Ordering::Relaxed);
552                if current >= max_concurrent {
553                    return Err(TenantError::RateLimitExceeded(RateLimitExceeded {
554                        limit_type: RateLimitType::Concurrent,
555                        current: current as u64,
556                        limit: max_concurrent as u64,
557                        retry_after: None,
558                    }));
559                }
560            }
561
562            // Check requests per minute
563            if let Some(rpm) = limits.requests_per_minute {
564                let current = self.rate_state.requests_minute.load(Ordering::Relaxed);
565                if current >= rpm {
566                    return Err(TenantError::RateLimitExceeded(RateLimitExceeded {
567                        limit_type: RateLimitType::RequestsPerMinute,
568                        current: current as u64,
569                        limit: rpm as u64,
570                        retry_after: Some(Duration::from_secs(60)),
571                    }));
572                }
573            }
574
575            // Check requests per hour
576            if let Some(rph) = limits.requests_per_hour {
577                let current = self.rate_state.requests_hour.load(Ordering::Relaxed);
578                if current >= rph {
579                    return Err(TenantError::RateLimitExceeded(RateLimitExceeded {
580                        limit_type: RateLimitType::RequestsPerHour,
581                        current: current as u64,
582                        limit: rph as u64,
583                        retry_after: Some(Duration::from_secs(3600)),
584                    }));
585                }
586            }
587
588            // Check requests per day
589            if let Some(rpd) = limits.requests_per_day {
590                let current = self.rate_state.requests_day.load(Ordering::Relaxed);
591                if current >= rpd {
592                    return Err(TenantError::RateLimitExceeded(RateLimitExceeded {
593                        limit_type: RateLimitType::RequestsPerDay,
594                        current: current as u64,
595                        limit: rpd as u64,
596                        retry_after: Some(Duration::from_secs(86400)),
597                    }));
598                }
599            }
600        }
601
602        // Reset cost windows if needed
603        self.cost_state.reset_if_needed();
604
605        // Check cost limits
606        if let Some(ref limits) = self.config.cost_limit {
607            if let Some(daily) = limits.daily_limit_usd {
608                let current = self.cost_state.daily_cost_usd();
609                if current >= daily {
610                    return Err(TenantError::CostLimitExceeded(CostLimitExceeded {
611                        limit_type: CostLimitType::Daily,
612                        current_usd: current,
613                        limit_usd: daily,
614                    }));
615                }
616            }
617
618            if let Some(weekly) = limits.weekly_limit_usd {
619                let current = self.cost_state.weekly_cost_usd();
620                if current >= weekly {
621                    return Err(TenantError::CostLimitExceeded(CostLimitExceeded {
622                        limit_type: CostLimitType::Weekly,
623                        current_usd: current,
624                        limit_usd: weekly,
625                    }));
626                }
627            }
628
629            if let Some(monthly) = limits.monthly_limit_usd {
630                let current = self.cost_state.monthly_cost_usd();
631                if current >= monthly {
632                    return Err(TenantError::CostLimitExceeded(CostLimitExceeded {
633                        limit_type: CostLimitType::Monthly,
634                        current_usd: current,
635                        limit_usd: monthly,
636                    }));
637                }
638            }
639        }
640
641        Ok(())
642    }
643
644    /// Record a completed request.
645    fn record_request(&self, tokens: u64, cost_usd: f64) {
646        // Increment request counters
647        self.rate_state
648            .requests_minute
649            .fetch_add(1, Ordering::Relaxed);
650        self.rate_state
651            .requests_hour
652            .fetch_add(1, Ordering::Relaxed);
653        self.rate_state.requests_day.fetch_add(1, Ordering::Relaxed);
654
655        // Increment token counters
656        self.rate_state
657            .tokens_minute
658            .fetch_add(tokens, Ordering::Relaxed);
659        self.rate_state
660            .tokens_hour
661            .fetch_add(tokens, Ordering::Relaxed);
662        self.rate_state
663            .tokens_day
664            .fetch_add(tokens, Ordering::Relaxed);
665
666        // Record cost
667        self.cost_state.add_cost(cost_usd);
668    }
669
670    /// Start a request (increment concurrent counter).
671    fn start_request(&self) {
672        self.rate_state.concurrent.fetch_add(1, Ordering::Relaxed);
673    }
674
675    /// End a request (decrement concurrent counter).
676    fn end_request(&self) {
677        self.rate_state.concurrent.fetch_sub(1, Ordering::Relaxed);
678    }
679
680    /// Get current usage stats.
681    pub fn usage_stats(&self) -> TenantUsageStats {
682        self.rate_state.reset_if_needed();
683        self.cost_state.reset_if_needed();
684
685        TenantUsageStats {
686            requests_minute: self.rate_state.requests_minute.load(Ordering::Relaxed),
687            requests_hour: self.rate_state.requests_hour.load(Ordering::Relaxed),
688            requests_day: self.rate_state.requests_day.load(Ordering::Relaxed),
689            tokens_minute: self.rate_state.tokens_minute.load(Ordering::Relaxed),
690            tokens_hour: self.rate_state.tokens_hour.load(Ordering::Relaxed),
691            tokens_day: self.rate_state.tokens_day.load(Ordering::Relaxed),
692            concurrent: self.rate_state.concurrent.load(Ordering::Relaxed),
693            daily_cost_usd: self.cost_state.daily_cost_usd(),
694            weekly_cost_usd: self.cost_state.weekly_cost_usd(),
695            monthly_cost_usd: self.cost_state.monthly_cost_usd(),
696        }
697    }
698}
699
700/// Current usage statistics for a tenant.
701#[derive(Debug, Clone, Serialize, Deserialize)]
702pub struct TenantUsageStats {
703    /// Requests in current minute
704    pub requests_minute: u32,
705    /// Requests in current hour
706    pub requests_hour: u32,
707    /// Requests in current day
708    pub requests_day: u32,
709    /// Tokens in current minute
710    pub tokens_minute: u64,
711    /// Tokens in current hour
712    pub tokens_hour: u64,
713    /// Tokens in current day
714    pub tokens_day: u64,
715    /// Current concurrent requests
716    pub concurrent: u32,
717    /// Cost today in USD
718    pub daily_cost_usd: f64,
719    /// Cost this week in USD
720    pub weekly_cost_usd: f64,
721    /// Cost this month in USD
722    pub monthly_cost_usd: f64,
723}
724
725#[async_trait]
726impl<P: Provider> Provider for TenantProvider<P> {
727    fn name(&self) -> &str {
728        self.inner.name()
729    }
730
731    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
732        // Check if request is allowed
733        self.check_request(&request.model)
734            .map_err(|e| Error::other(e.to_string()))?;
735
736        self.start_request();
737        let result = self.inner.complete(request).await;
738        self.end_request();
739
740        if let Ok(ref response) = result {
741            let tokens = (response.usage.input_tokens + response.usage.output_tokens) as u64;
742            // Estimate cost - in production, use actual pricing
743            let cost_usd = tokens as f64 * 0.000001; // Placeholder
744            self.record_request(tokens, cost_usd);
745        }
746
747        result
748    }
749
750    async fn complete_stream(
751        &self,
752        request: CompletionRequest,
753    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>> {
754        // Check if request is allowed
755        self.check_request(&request.model)
756            .map_err(|e| Error::other(e.to_string()))?;
757
758        self.start_request();
759        // Note: For proper tracking, we'd need to wrap the stream
760        // to track completion and tokens
761        self.inner.complete_stream(request).await
762    }
763
764    fn supports_tools(&self) -> bool {
765        self.inner.supports_tools()
766    }
767
768    fn supports_vision(&self) -> bool {
769        self.inner.supports_vision()
770    }
771
772    fn supports_streaming(&self) -> bool {
773        self.inner.supports_streaming()
774    }
775
776    fn supports_token_counting(&self) -> bool {
777        self.inner.supports_token_counting()
778    }
779
780    async fn count_tokens(&self, request: TokenCountRequest) -> Result<TokenCountResult> {
781        self.inner.count_tokens(request).await
782    }
783
784    fn supports_batch(&self) -> bool {
785        self.inner.supports_batch()
786    }
787
788    async fn create_batch(&self, requests: Vec<BatchRequest>) -> Result<BatchJob> {
789        self.inner.create_batch(requests).await
790    }
791
792    async fn get_batch(&self, batch_id: &str) -> Result<BatchJob> {
793        self.inner.get_batch(batch_id).await
794    }
795
796    async fn get_batch_results(&self, batch_id: &str) -> Result<Vec<BatchResult>> {
797        self.inner.get_batch_results(batch_id).await
798    }
799
800    async fn cancel_batch(&self, batch_id: &str) -> Result<BatchJob> {
801        self.inner.cancel_batch(batch_id).await
802    }
803
804    async fn list_batches(&self, limit: Option<u32>) -> Result<Vec<BatchJob>> {
805        self.inner.list_batches(limit).await
806    }
807}
808
809/// Manager for multiple tenants.
810pub struct TenantManager {
811    tenants: RwLock<HashMap<TenantId, TenantConfig>>,
812}
813
814impl Default for TenantManager {
815    fn default() -> Self {
816        Self::new()
817    }
818}
819
820impl TenantManager {
821    /// Create a new tenant manager.
822    pub fn new() -> Self {
823        Self {
824            tenants: RwLock::new(HashMap::new()),
825        }
826    }
827
828    /// Register a tenant.
829    pub fn register(&self, config: TenantConfig) {
830        self.tenants.write().insert(config.id.clone(), config);
831    }
832
833    /// Get a tenant config.
834    pub fn get(&self, id: &TenantId) -> Option<TenantConfig> {
835        self.tenants.read().get(id).cloned()
836    }
837
838    /// Remove a tenant.
839    pub fn remove(&self, id: &TenantId) -> Option<TenantConfig> {
840        self.tenants.write().remove(id)
841    }
842
843    /// List all tenant IDs.
844    pub fn list(&self) -> Vec<TenantId> {
845        self.tenants.read().keys().cloned().collect()
846    }
847
848    /// Check if a tenant exists.
849    pub fn exists(&self, id: &TenantId) -> bool {
850        self.tenants.read().contains_key(id)
851    }
852
853    /// Update a tenant config.
854    pub fn update(&self, config: TenantConfig) -> bool {
855        let mut tenants = self.tenants.write();
856        if tenants.contains_key(&config.id) {
857            tenants.insert(config.id.clone(), config);
858            true
859        } else {
860            false
861        }
862    }
863}
864
#[cfg(test)]
mod tests {
    use super::*;

    /// `as_str` and `Display` both expose the raw identifier text.
    #[test]
    fn test_tenant_id() {
        let raw = "test-tenant";
        let id = TenantId::new(raw);

        assert_eq!(id.as_str(), raw);
        assert_eq!(id.to_string(), raw);
    }

    /// With an allow-list, only listed models are accepted.
    #[test]
    fn test_tenant_config_allowed_models() {
        let listed = ["gpt-4o", "claude-sonnet-4-20250514"];
        let config = TenantConfig::new("test").with_allowed_models(listed.to_vec());

        for model in listed.iter().copied() {
            assert!(config.is_model_allowed(model));
        }
        assert!(!config.is_model_allowed("gpt-3.5-turbo"));
    }

    /// Without an allow-list, everything except blocked models is accepted.
    #[test]
    fn test_tenant_config_blocked_models() {
        let blocked = "gpt-3.5-turbo";
        let config = TenantConfig::new("test").block_model(blocked);

        assert!(config.is_model_allowed("gpt-4o"));
        assert!(!config.is_model_allowed(blocked));
    }

    /// `basic` plus `with_max_concurrent` set exactly the requested fields.
    #[test]
    fn test_rate_limit_config() {
        let RateLimitConfig {
            requests_per_minute,
            tokens_per_minute,
            max_concurrent,
            ..
        } = RateLimitConfig::basic(60, 100_000).with_max_concurrent(10);

        assert_eq!(requests_per_minute, Some(60));
        assert_eq!(tokens_per_minute, Some(100_000));
        assert_eq!(max_concurrent, Some(10));
    }

    /// `basic` plus `with_alert_threshold` set exactly the requested fields.
    #[test]
    fn test_cost_limit_config() {
        let CostLimitConfig {
            daily_limit_usd,
            monthly_limit_usd,
            alert_threshold,
            ..
        } = CostLimitConfig::basic(100.0, 1000.0).with_alert_threshold(0.8);

        assert_eq!(daily_limit_usd, Some(100.0));
        assert_eq!(monthly_limit_usd, Some(1000.0));
        assert_eq!(alert_threshold, Some(0.8));
    }

    /// Register, look up, list and remove a tenant.
    #[test]
    fn test_tenant_manager() {
        let acme = TenantId::new("acme");
        let other = TenantId::new("other");
        let manager = TenantManager::new();

        manager.register(TenantConfig::new("acme"));

        assert!(manager.exists(&acme));
        assert!(!manager.exists(&other));
        assert_eq!(manager.list().len(), 1);

        assert!(manager.remove(&acme).is_some());
        assert!(!manager.exists(&acme));
    }
}