// oxify_connect_llm/usage.rs

1//! Usage monitoring and cost tracking for LLM providers
2
3use crate::{
4    EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmError, LlmProvider, LlmRequest,
5    LlmResponse, Result, Usage,
6};
7use async_trait::async_trait;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10
/// Per-model token pricing, expressed in USD **cents** per 1K tokens.
///
/// Values are approximations of published provider prices and should be
/// refreshed whenever vendors change their rates.
#[derive(Debug, Clone, Copy)]
pub struct ModelPricing {
    /// Cost per 1K input/prompt tokens (in cents)
    pub input_per_1k: f64,
    /// Cost per 1K output/completion tokens (in cents)
    pub output_per_1k: f64,
}

impl ModelPricing {
    /// Build a custom price point from per-1K input/output rates (cents).
    pub const fn new(input_per_1k: f64, output_per_1k: f64) -> Self {
        Self {
            input_per_1k,
            output_per_1k,
        }
    }

    /// GPT-4 (original) pricing
    pub const GPT4: Self = Self::new(3.0, 6.0); // $0.03/$0.06 per 1K

    /// GPT-4 Turbo pricing (as of 2026)
    pub const GPT4_TURBO: Self = Self::new(1.0, 3.0); // $0.01/$0.03 per 1K

    /// GPT-4o pricing (as of 2026)
    pub const GPT4O: Self = Self::new(0.5, 1.5); // $0.005/$0.015 per 1K

    /// GPT-4o-mini pricing
    pub const GPT4O_MINI: Self = Self::new(0.015, 0.06); // $0.00015/$0.0006 per 1K

    /// o1-preview pricing (reasoning model)
    pub const O1_PREVIEW: Self = Self::new(1.5, 6.0); // $0.015/$0.06 per 1K

    /// o1-mini pricing (reasoning model)
    pub const O1_MINI: Self = Self::new(0.3, 1.2); // $0.003/$0.012 per 1K

    /// GPT-3.5 Turbo pricing
    pub const GPT35_TURBO: Self = Self::new(0.05, 0.15); // $0.0005/$0.0015 per 1K

    /// Claude 3 Opus pricing
    pub const CLAUDE3_OPUS: Self = Self::new(1.5, 7.5); // $0.015/$0.075 per 1K

    /// Claude 3.5 Sonnet pricing (newer model)
    pub const CLAUDE35_SONNET: Self = Self::new(0.3, 1.5); // $0.003/$0.015 per 1K

    /// Claude 3 Sonnet pricing (original)
    pub const CLAUDE3_SONNET: Self = Self::new(0.3, 1.5); // $0.003/$0.015 per 1K

    /// Claude 3.5 Haiku pricing (newer model)
    pub const CLAUDE35_HAIKU: Self = Self::new(0.08, 0.4); // $0.0008/$0.004 per 1K

    /// Claude 3 Haiku pricing (original)
    pub const CLAUDE3_HAIKU: Self = Self::new(0.025, 0.125); // $0.00025/$0.00125 per 1K

    /// Gemini Pro pricing
    pub const GEMINI_PRO: Self = Self::new(0.0125, 0.0375); // Free tier, then $0.000125/$0.000375

    /// Gemini Flash pricing (faster, cheaper)
    pub const GEMINI_FLASH: Self = Self::new(0.00375, 0.01125); // $0.0000375/$0.0001125 per 1K

    /// Mistral Large pricing
    pub const MISTRAL_LARGE: Self = Self::new(0.2, 0.6); // $0.002/$0.006 per 1K

    /// Mistral Small pricing
    pub const MISTRAL_SMALL: Self = Self::new(0.1, 0.3); // $0.001/$0.003 per 1K

    /// Cohere Command-R pricing
    pub const COHERE_COMMAND_R: Self = Self::new(0.05, 0.15); // $0.0005/$0.0015 per 1K

    /// Cohere Command-R+ pricing (larger model)
    pub const COHERE_COMMAND_R_PLUS: Self = Self::new(0.3, 1.5); // $0.003/$0.015 per 1K

    /// Free/local models (Ollama, etc.)
    pub const FREE: Self = Self::new(0.0, 0.0);

    /// OpenAI Ada embedding pricing (text-embedding-ada-002)
    pub const ADA_EMBEDDING: Self = Self::new(0.01, 0.0); // $0.0001 per 1K tokens

    /// OpenAI text-embedding-3-small pricing
    pub const TEXT_EMBEDDING_3_SMALL: Self = Self::new(0.002, 0.0); // $0.00002 per 1K tokens

    /// OpenAI text-embedding-3-large pricing
    pub const TEXT_EMBEDDING_3_LARGE: Self = Self::new(0.013, 0.0); // $0.00013 per 1K tokens

    /// Total cost in cents for the given prompt/completion token counts.
    /// Each side is billed independently at its per-1K rate and summed.
    pub fn calculate_cost(&self, prompt_tokens: u32, completion_tokens: u32) -> f64 {
        // Same arithmetic order as (tokens / 1000) * rate to keep results
        // bit-for-bit stable.
        let cost_of = |tokens: u32, rate_per_1k: f64| (tokens as f64 / 1000.0) * rate_per_1k;
        cost_of(prompt_tokens, self.input_per_1k) + cost_of(completion_tokens, self.output_per_1k)
    }

    /// Helper function to get GPT-4 pricing
    pub const fn gpt4() -> Self {
        Self::GPT4
    }

    /// Helper function to get Ada embedding pricing
    pub const fn ada_embedding() -> Self {
        Self::ADA_EMBEDDING
    }
}
113
/// Snapshot of accumulated usage counters, as produced by a tracker.
#[derive(Debug, Clone)]
pub struct UsageStats {
    /// Total prompt/input tokens
    pub total_prompt_tokens: u64,
    /// Total completion/output tokens
    pub total_completion_tokens: u64,
    /// Total tokens (prompt + completion)
    pub total_tokens: u64,
    /// Number of requests
    pub request_count: u64,
    /// Estimated cost in cents (if pricing is set)
    pub estimated_cost_cents: f64,
}

impl UsageStats {
    /// Estimated cost converted from cents to US dollars.
    pub fn estimated_cost_usd(&self) -> f64 {
        self.estimated_cost_cents / 100.0
    }

    /// Mean tokens per request; defined as 0.0 when no requests were made
    /// (avoids a division by zero).
    pub fn avg_tokens_per_request(&self) -> f64 {
        match self.request_count {
            0 => 0.0,
            n => self.total_tokens as f64 / n as f64,
        }
    }
}
144
/// Thread-safe usage tracker
///
/// Counters are lock-free `AtomicU64`s updated with `Ordering::Relaxed`, so
/// recording from many threads is cheap. Note the counters are bumped
/// individually (see `record`), so a concurrent reader may observe a record
/// that is only partially applied.
#[derive(Debug)]
pub struct UsageTracker {
    // Accumulated prompt/input tokens across all recorded requests.
    prompt_tokens: AtomicU64,
    // Accumulated completion/output tokens across all recorded requests.
    completion_tokens: AtomicU64,
    // Number of `record` calls made (one per request).
    request_count: AtomicU64,
    // Optional pricing used by `stats()`; when `None`, estimated cost is 0.0.
    pricing: Option<ModelPricing>,
}
153
154impl Clone for UsageTracker {
155    fn clone(&self) -> Self {
156        Self {
157            prompt_tokens: AtomicU64::new(self.prompt_tokens.load(Ordering::Relaxed)),
158            completion_tokens: AtomicU64::new(self.completion_tokens.load(Ordering::Relaxed)),
159            request_count: AtomicU64::new(self.request_count.load(Ordering::Relaxed)),
160            pricing: self.pricing,
161        }
162    }
163}
164
impl Default for UsageTracker {
    /// Equivalent to [`UsageTracker::new`]: zeroed counters and no pricing.
    fn default() -> Self {
        Self::new()
    }
}
170
171impl UsageTracker {
172    /// Create a new usage tracker without pricing
173    pub fn new() -> Self {
174        Self {
175            prompt_tokens: AtomicU64::new(0),
176            completion_tokens: AtomicU64::new(0),
177            request_count: AtomicU64::new(0),
178            pricing: None,
179        }
180    }
181
182    /// Create a new usage tracker with pricing
183    pub fn with_pricing(pricing: ModelPricing) -> Self {
184        Self {
185            prompt_tokens: AtomicU64::new(0),
186            completion_tokens: AtomicU64::new(0),
187            request_count: AtomicU64::new(0),
188            pricing: Some(pricing),
189        }
190    }
191
192    /// Record usage from an LLM response
193    pub fn record(&self, usage: &Usage) {
194        self.prompt_tokens
195            .fetch_add(usage.prompt_tokens as u64, Ordering::Relaxed);
196        self.completion_tokens
197            .fetch_add(usage.completion_tokens as u64, Ordering::Relaxed);
198        self.request_count.fetch_add(1, Ordering::Relaxed);
199    }
200
201    /// Get current usage statistics
202    pub fn stats(&self) -> UsageStats {
203        let prompt = self.prompt_tokens.load(Ordering::Relaxed);
204        let completion = self.completion_tokens.load(Ordering::Relaxed);
205        let count = self.request_count.load(Ordering::Relaxed);
206
207        let cost = self
208            .pricing
209            .map(|p| p.calculate_cost(prompt as u32, completion as u32))
210            .unwrap_or(0.0);
211
212        UsageStats {
213            total_prompt_tokens: prompt,
214            total_completion_tokens: completion,
215            total_tokens: prompt + completion,
216            request_count: count,
217            estimated_cost_cents: cost,
218        }
219    }
220
221    /// Reset all counters
222    pub fn reset(&self) {
223        self.prompt_tokens.store(0, Ordering::Relaxed);
224        self.completion_tokens.store(0, Ordering::Relaxed);
225        self.request_count.store(0, Ordering::Relaxed);
226    }
227}
228
229/// A wrapper that tracks usage for any LLM provider
230pub struct TrackedProvider<P> {
231    inner: P,
232    tracker: Arc<UsageTracker>,
233}
234
235impl<P> TrackedProvider<P> {
236    /// Create a new TrackedProvider without pricing
237    pub fn new(provider: P) -> Self {
238        Self {
239            inner: provider,
240            tracker: Arc::new(UsageTracker::new()),
241        }
242    }
243
244    /// Create a new TrackedProvider with pricing
245    pub fn with_pricing(provider: P, pricing: ModelPricing) -> Self {
246        Self {
247            inner: provider,
248            tracker: Arc::new(UsageTracker::with_pricing(pricing)),
249        }
250    }
251
252    /// Create a new TrackedProvider with a shared tracker
253    pub fn with_tracker(provider: P, tracker: Arc<UsageTracker>) -> Self {
254        Self {
255            inner: provider,
256            tracker,
257        }
258    }
259
260    /// Get a reference to the inner provider
261    pub fn inner(&self) -> &P {
262        &self.inner
263    }
264
265    /// Get a mutable reference to the inner provider
266    pub fn inner_mut(&mut self) -> &mut P {
267        &mut self.inner
268    }
269
270    /// Get a reference to the usage tracker
271    pub fn tracker(&self) -> &Arc<UsageTracker> {
272        &self.tracker
273    }
274
275    /// Get current usage statistics
276    pub fn stats(&self) -> UsageStats {
277        self.tracker.stats()
278    }
279
280    /// Reset usage counters
281    pub fn reset(&self) {
282        self.tracker.reset();
283    }
284}
285
286#[async_trait]
287impl<P: LlmProvider> LlmProvider for TrackedProvider<P> {
288    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
289        let response = self.inner.complete(request).await?;
290
291        // Track usage if available
292        if let Some(usage) = &response.usage {
293            self.tracker.record(usage);
294        }
295
296        Ok(response)
297    }
298}
299
300#[async_trait]
301impl<P: EmbeddingProvider> EmbeddingProvider for TrackedProvider<P> {
302    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
303        let response = self.inner.embed(request).await?;
304
305        // Track usage if available
306        if let Some(usage) = &response.usage {
307            self.tracker.record(&Usage {
308                prompt_tokens: usage.prompt_tokens,
309                completion_tokens: 0,
310                total_tokens: usage.total_tokens,
311            });
312        }
313
314        Ok(response)
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    // Pricing math: input and output tokens are billed at their own
    // per-1K rates and summed.
    #[test]
    fn test_model_pricing_calculation() {
        let pricing = ModelPricing::GPT4_TURBO;

        // 1000 input + 500 output tokens
        let cost = pricing.calculate_cost(1000, 500);

        // 1.0 cents for input + 1.5 cents for output = 2.5 cents
        assert!((cost - 2.5).abs() < 0.001);
    }

    // Counters accumulate across multiple record() calls.
    #[test]
    fn test_usage_tracker() {
        let tracker = UsageTracker::with_pricing(ModelPricing::GPT35_TURBO);

        tracker.record(&Usage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
        });

        let stats = tracker.stats();
        assert_eq!(stats.total_prompt_tokens, 100);
        assert_eq!(stats.total_completion_tokens, 50);
        assert_eq!(stats.total_tokens, 150);
        assert_eq!(stats.request_count, 1);

        // Record another request
        tracker.record(&Usage {
            prompt_tokens: 200,
            completion_tokens: 100,
            total_tokens: 300,
        });

        let stats = tracker.stats();
        assert_eq!(stats.total_prompt_tokens, 300);
        assert_eq!(stats.total_completion_tokens, 150);
        assert_eq!(stats.total_tokens, 450);
        assert_eq!(stats.request_count, 2);
        assert_eq!(stats.avg_tokens_per_request(), 225.0);
    }

    // reset() zeroes every counter while leaving the tracker usable.
    #[test]
    fn test_usage_tracker_reset() {
        let tracker = UsageTracker::new();

        tracker.record(&Usage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
        });

        assert_eq!(tracker.stats().total_tokens, 150);

        tracker.reset();

        assert_eq!(tracker.stats().total_tokens, 0);
        assert_eq!(tracker.stats().request_count, 0);
    }

    // FREE pricing yields a zero cost regardless of token counts.
    #[test]
    fn test_free_pricing() {
        let pricing = ModelPricing::FREE;
        let cost = pricing.calculate_cost(10000, 5000);
        assert_eq!(cost, 0.0);
    }

    // Cents-to-dollars conversion on the accumulated stats.
    #[test]
    fn test_usage_stats_usd() {
        let tracker = UsageTracker::with_pricing(ModelPricing::new(100.0, 100.0)); // 1 dollar per 1K

        tracker.record(&Usage {
            prompt_tokens: 1000,
            completion_tokens: 1000,
            total_tokens: 2000,
        });

        let stats = tracker.stats();
        // 100 cents input + 100 cents output = 200 cents = $2.00
        assert_eq!(stats.estimated_cost_cents, 200.0);
        assert_eq!(stats.estimated_cost_usd(), 2.0);
    }

    // A request whose cost fits within the budget passes through untouched.
    #[tokio::test]
    async fn test_budget_provider_under_limit() {
        // Minimal stub provider that always reports 150 tokens of usage.
        struct MockProvider;
        #[async_trait]
        impl LlmProvider for MockProvider {
            async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
                Ok(LlmResponse {
                    content: format!("Response to: {}", request.prompt),
                    model: "mock".to_string(),
                    usage: Some(Usage {
                        prompt_tokens: 100,
                        completion_tokens: 50,
                        total_tokens: 150,
                    }),
                    tool_calls: Vec::new(),
                })
            }
        }

        let budget = BudgetLimit::new(100.0); // $1.00 budget
        let provider = BudgetProvider::new(MockProvider, budget, ModelPricing::new(10.0, 10.0));

        let request = LlmRequest {
            prompt: "test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        };

        let result = provider.complete(request).await;
        assert!(result.is_ok());
    }

    // Once a completed request pushes consumption past the limit, the next
    // request is rejected before reaching the inner provider. Note the first
    // request is allowed to overshoot: budget is checked before, and charged
    // after, each call.
    #[tokio::test]
    async fn test_budget_provider_exceeds_limit() {
        // Stub provider reporting very large usage so one call blows the budget.
        struct MockProvider;
        #[async_trait]
        impl LlmProvider for MockProvider {
            async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
                Ok(LlmResponse {
                    content: format!("Response to: {}", request.prompt),
                    model: "mock".to_string(),
                    usage: Some(Usage {
                        prompt_tokens: 10000,
                        completion_tokens: 5000,
                        total_tokens: 15000,
                    }),
                    tool_calls: Vec::new(),
                })
            }
        }

        let budget = BudgetLimit::new(0.5); // $0.005 budget (very small)
        let provider = BudgetProvider::new(MockProvider, budget, ModelPricing::new(100.0, 100.0));

        let request = LlmRequest {
            prompt: "test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        };

        // First request should succeed
        let result = provider.complete(request.clone()).await;
        assert!(result.is_ok());

        // Second request should fail due to budget exceeded
        let result = provider.complete(request).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::ApiError(_)));
    }

    // consume()/reset() bookkeeping on a raw BudgetLimit.
    #[test]
    fn test_budget_limit_remaining() {
        let budget = BudgetLimit::new(100.0);
        assert_eq!(budget.remaining_cents(), 100.0);

        budget.consume(50.0);
        assert_eq!(budget.remaining_cents(), 50.0);

        budget.reset();
        assert_eq!(budget.remaining_cents(), 100.0);
    }
}
493
494// ===== Budget Limits =====
495
/// Budget limit configuration
///
/// The consumed amount is stored as the bit pattern of an `f64` inside an
/// `AtomicU64`, shared behind an `Arc` — so `Clone`d copies observe and
/// update the same running total.
#[derive(Debug, Clone)]
pub struct BudgetLimit {
    /// Maximum budget in cents
    max_budget_cents: f64,
    /// Consumed budget in cents, stored via `f64::to_bits` in an atomic
    /// (shared across clones) for thread safety
    consumed_cents: Arc<AtomicU64>,
}

impl PartialEq for BudgetLimit {
    /// Two limits are equal when their maximum budgets match; the consumed
    /// amount is deliberately ignored.
    fn eq(&self, other: &Self) -> bool {
        self.max_budget_cents == other.max_budget_cents
    }
}

impl BudgetLimit {
    /// Create a new budget limit in cents
    pub fn new(max_budget_cents: f64) -> Self {
        Self {
            max_budget_cents,
            consumed_cents: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Create a new budget limit in cents
    pub fn cents(max_budget_cents: u64) -> Self {
        Self::new(max_budget_cents as f64)
    }

    /// Create a new budget limit in dollars
    pub fn from_usd(max_budget_usd: f64) -> Self {
        Self::new(max_budget_usd * 100.0)
    }

    /// Create a new budget limit in dollars (alias for from_usd)
    pub fn dollars(max_budget_usd: u64) -> Self {
        Self::from_usd(max_budget_usd as f64)
    }

    /// Get the maximum budget in cents (fractional cents are truncated)
    pub fn as_cents(&self) -> u64 {
        self.max_budget_cents as u64
    }

    /// Get the maximum budget in dollars
    pub fn as_usd(&self) -> f64 {
        self.max_budget_cents / 100.0
    }

    /// Check if budget has been exceeded (consumed >= maximum)
    pub fn is_exceeded(&self) -> bool {
        self.consumed_cents() >= self.max_budget_cents
    }

    /// Get remaining budget in cents (never negative)
    pub fn remaining_cents(&self) -> f64 {
        (self.max_budget_cents - self.consumed_cents()).max(0.0)
    }

    /// Get remaining budget in dollars
    pub fn remaining_usd(&self) -> f64 {
        self.remaining_cents() / 100.0
    }

    /// Get consumed budget in cents
    pub fn consumed_cents(&self) -> f64 {
        // The atomic holds the raw IEEE-754 bit pattern of the f64 total
        // (NOT an integer cent count, as an older comment claimed).
        f64::from_bits(self.consumed_cents.load(Ordering::Relaxed))
    }

    /// Consume budget (internal use)
    fn consume(&self, cents: f64) {
        // BUGFIX: the previous separate load-then-store was not atomic — two
        // threads could read the same old total and one addition would be
        // lost. `fetch_update` performs a CAS loop, retrying on contention,
        // so every `consume` is applied exactly once.
        let _ = self
            .consumed_cents
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |bits| {
                Some((f64::from_bits(bits) + cents).to_bits())
            });
    }

    /// Reset consumed budget to zero (the maximum is unchanged)
    pub fn reset(&self) {
        self.consumed_cents.store(0_f64.to_bits(), Ordering::Relaxed);
    }
}
584
585/// A provider that enforces budget limits
586pub struct BudgetProvider<P> {
587    inner: P,
588    budget: BudgetLimit,
589    pricing: ModelPricing,
590}
591
592impl<P> BudgetProvider<P> {
593    /// Create a new budget-limited provider
594    pub fn new(provider: P, budget: BudgetLimit, pricing: ModelPricing) -> Self {
595        Self {
596            inner: provider,
597            budget,
598            pricing,
599        }
600    }
601
602    /// Get a reference to the inner provider
603    pub fn inner(&self) -> &P {
604        &self.inner
605    }
606
607    /// Get remaining budget in cents
608    pub fn remaining_budget_cents(&self) -> f64 {
609        self.budget.remaining_cents()
610    }
611
612    /// Get remaining budget in dollars
613    pub fn remaining_budget_usd(&self) -> f64 {
614        self.budget.remaining_usd()
615    }
616
617    /// Check if budget is exceeded
618    pub fn is_budget_exceeded(&self) -> bool {
619        self.budget.is_exceeded()
620    }
621
622    /// Reset budget (starts from zero again)
623    pub fn reset_budget(&self) {
624        self.budget.reset();
625    }
626}
627
628#[async_trait]
629impl<P: LlmProvider> LlmProvider for BudgetProvider<P> {
630    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
631        // Check budget before making request
632        if self.budget.is_exceeded() {
633            return Err(LlmError::ApiError(format!(
634                "Budget exceeded: ${:.4} spent of ${:.4} limit",
635                self.budget.consumed_cents() / 100.0,
636                self.budget.max_budget_cents / 100.0
637            )));
638        }
639
640        let response = self.inner.complete(request).await?;
641
642        // Track cost after successful response
643        if let Some(usage) = &response.usage {
644            let cost = self
645                .pricing
646                .calculate_cost(usage.prompt_tokens, usage.completion_tokens);
647            self.budget.consume(cost);
648
649            // Log budget status
650            tracing::info!(
651                cost_cents = cost,
652                remaining_cents = self.budget.remaining_cents(),
653                "Request cost tracked against budget"
654            );
655
656            // Warn if budget is getting low
657            let remaining_pct =
658                self.budget.remaining_cents() / self.budget.max_budget_cents * 100.0;
659            if remaining_pct < 10.0 && remaining_pct > 0.0 {
660                tracing::warn!(
661                    remaining_pct = format!("{:.1}%", remaining_pct),
662                    "Budget running low"
663                );
664            }
665        }
666
667        Ok(response)
668    }
669}
670
671#[async_trait]
672impl<P: EmbeddingProvider> EmbeddingProvider for BudgetProvider<P> {
673    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
674        // Check budget before making request
675        if self.budget.is_exceeded() {
676            return Err(LlmError::ApiError(format!(
677                "Budget exceeded: ${:.4} spent of ${:.4} limit",
678                self.budget.consumed_cents() / 100.0,
679                self.budget.max_budget_cents / 100.0
680            )));
681        }
682
683        let response = self.inner.embed(request).await?;
684
685        // Track cost after successful response
686        if let Some(usage) = &response.usage {
687            let cost = self.pricing.calculate_cost(usage.prompt_tokens, 0);
688            self.budget.consume(cost);
689
690            tracing::info!(
691                cost_cents = cost,
692                remaining_cents = self.budget.remaining_cents(),
693                "Embedding cost tracked against budget"
694            );
695        }
696
697        Ok(response)
698    }
699}