1use crate::{
4 EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmError, LlmProvider, LlmRequest,
5 LlmResponse, Result, Usage,
6};
7use async_trait::async_trait;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10
/// Token prices for a single model, expressed per 1,000 tokens.
///
/// Costs computed from these rates are stored in cent-denominated fields
/// elsewhere in this module (for example `UsageStats::estimated_cost_cents`),
/// so both rates are treated as cents per 1K tokens.
#[derive(Debug, Clone, Copy)]
pub struct ModelPricing {
    /// Rate applied to every 1,000 prompt (input) tokens.
    pub input_per_1k: f64,
    /// Rate applied to every 1,000 completion (output) tokens.
    pub output_per_1k: f64,
}

impl ModelPricing {
    /// Builds a pricing entry from explicit input and output rates.
    pub const fn new(input_per_1k: f64, output_per_1k: f64) -> Self {
        ModelPricing {
            input_per_1k,
            output_per_1k,
        }
    }

    // Chat/completion presets, named after the model they are meant for.
    // NOTE(review): the figures are hard-coded snapshots; confirm them against
    // current vendor price lists before relying on the estimates.
    pub const GPT4: Self = Self::new(3.0, 6.0);
    pub const GPT4_TURBO: Self = Self::new(1.0, 3.0);
    pub const GPT4O: Self = Self::new(0.5, 1.5);
    pub const GPT4O_MINI: Self = Self::new(0.015, 0.06);
    pub const O1_PREVIEW: Self = Self::new(1.5, 6.0);
    pub const O1_MINI: Self = Self::new(0.3, 1.2);
    pub const GPT35_TURBO: Self = Self::new(0.05, 0.15);
    pub const CLAUDE3_OPUS: Self = Self::new(1.5, 7.5);
    pub const CLAUDE35_SONNET: Self = Self::new(0.3, 1.5);
    pub const CLAUDE3_SONNET: Self = Self::new(0.3, 1.5);
    pub const CLAUDE35_HAIKU: Self = Self::new(0.08, 0.4);
    pub const CLAUDE3_HAIKU: Self = Self::new(0.025, 0.125);
    pub const GEMINI_PRO: Self = Self::new(0.0125, 0.0375);
    pub const GEMINI_FLASH: Self = Self::new(0.00375, 0.01125);
    pub const MISTRAL_LARGE: Self = Self::new(0.2, 0.6);
    pub const MISTRAL_SMALL: Self = Self::new(0.1, 0.3);
    pub const COHERE_COMMAND_R: Self = Self::new(0.05, 0.15);
    pub const COHERE_COMMAND_R_PLUS: Self = Self::new(0.3, 1.5);
    /// Zero-cost pricing: every request is estimated at 0.
    pub const FREE: Self = Self::new(0.0, 0.0);

    // Embedding presets: only input tokens are priced, the output rate is 0.
    pub const ADA_EMBEDDING: Self = Self::new(0.01, 0.0);
    pub const TEXT_EMBEDDING_3_SMALL: Self = Self::new(0.002, 0.0);
    pub const TEXT_EMBEDDING_3_LARGE: Self = Self::new(0.013, 0.0);

    /// Returns the cost of a request that used `prompt_tokens` input tokens
    /// and `completion_tokens` output tokens.
    ///
    /// Each side is priced as `(tokens / 1000) * rate` and the two are summed.
    pub fn calculate_cost(&self, prompt_tokens: u32, completion_tokens: u32) -> f64 {
        let charge = |tokens: u32, rate_per_1k: f64| (f64::from(tokens) / 1000.0) * rate_per_1k;
        charge(prompt_tokens, self.input_per_1k) + charge(completion_tokens, self.output_per_1k)
    }

    /// Function form of [`ModelPricing::GPT4`].
    pub const fn gpt4() -> Self {
        Self::GPT4
    }

    /// Function form of [`ModelPricing::ADA_EMBEDDING`].
    pub const fn ada_embedding() -> Self {
        Self::ADA_EMBEDDING
    }
}
113
/// Point-in-time summary of recorded token usage.
#[derive(Debug, Clone)]
pub struct UsageStats {
    /// Sum of prompt (input) tokens over all recorded requests.
    pub total_prompt_tokens: u64,
    /// Sum of completion (output) tokens over all recorded requests.
    pub total_completion_tokens: u64,
    /// Prompt plus completion tokens.
    pub total_tokens: u64,
    /// Number of requests that contributed to the totals.
    pub request_count: u64,
    /// Estimated spend in cents.
    pub estimated_cost_cents: f64,
}

impl UsageStats {
    /// Estimated spend converted from cents to dollars.
    pub fn estimated_cost_usd(&self) -> f64 {
        const CENTS_PER_USD: f64 = 100.0;
        self.estimated_cost_cents / CENTS_PER_USD
    }

    /// Mean number of tokens per request, or `0.0` when no request has been
    /// counted (avoids dividing by zero).
    pub fn avg_tokens_per_request(&self) -> f64 {
        match self.request_count {
            0 => 0.0,
            requests => self.total_tokens as f64 / requests as f64,
        }
    }
}
144
145#[derive(Debug)]
147pub struct UsageTracker {
148 prompt_tokens: AtomicU64,
149 completion_tokens: AtomicU64,
150 request_count: AtomicU64,
151 pricing: Option<ModelPricing>,
152}
153
154impl Clone for UsageTracker {
155 fn clone(&self) -> Self {
156 Self {
157 prompt_tokens: AtomicU64::new(self.prompt_tokens.load(Ordering::Relaxed)),
158 completion_tokens: AtomicU64::new(self.completion_tokens.load(Ordering::Relaxed)),
159 request_count: AtomicU64::new(self.request_count.load(Ordering::Relaxed)),
160 pricing: self.pricing,
161 }
162 }
163}
164
165impl Default for UsageTracker {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
171impl UsageTracker {
172 pub fn new() -> Self {
174 Self {
175 prompt_tokens: AtomicU64::new(0),
176 completion_tokens: AtomicU64::new(0),
177 request_count: AtomicU64::new(0),
178 pricing: None,
179 }
180 }
181
182 pub fn with_pricing(pricing: ModelPricing) -> Self {
184 Self {
185 prompt_tokens: AtomicU64::new(0),
186 completion_tokens: AtomicU64::new(0),
187 request_count: AtomicU64::new(0),
188 pricing: Some(pricing),
189 }
190 }
191
192 pub fn record(&self, usage: &Usage) {
194 self.prompt_tokens
195 .fetch_add(usage.prompt_tokens as u64, Ordering::Relaxed);
196 self.completion_tokens
197 .fetch_add(usage.completion_tokens as u64, Ordering::Relaxed);
198 self.request_count.fetch_add(1, Ordering::Relaxed);
199 }
200
201 pub fn stats(&self) -> UsageStats {
203 let prompt = self.prompt_tokens.load(Ordering::Relaxed);
204 let completion = self.completion_tokens.load(Ordering::Relaxed);
205 let count = self.request_count.load(Ordering::Relaxed);
206
207 let cost = self
208 .pricing
209 .map(|p| p.calculate_cost(prompt as u32, completion as u32))
210 .unwrap_or(0.0);
211
212 UsageStats {
213 total_prompt_tokens: prompt,
214 total_completion_tokens: completion,
215 total_tokens: prompt + completion,
216 request_count: count,
217 estimated_cost_cents: cost,
218 }
219 }
220
221 pub fn reset(&self) {
223 self.prompt_tokens.store(0, Ordering::Relaxed);
224 self.completion_tokens.store(0, Ordering::Relaxed);
225 self.request_count.store(0, Ordering::Relaxed);
226 }
227}
228
229pub struct TrackedProvider<P> {
231 inner: P,
232 tracker: Arc<UsageTracker>,
233}
234
235impl<P> TrackedProvider<P> {
236 pub fn new(provider: P) -> Self {
238 Self {
239 inner: provider,
240 tracker: Arc::new(UsageTracker::new()),
241 }
242 }
243
244 pub fn with_pricing(provider: P, pricing: ModelPricing) -> Self {
246 Self {
247 inner: provider,
248 tracker: Arc::new(UsageTracker::with_pricing(pricing)),
249 }
250 }
251
252 pub fn with_tracker(provider: P, tracker: Arc<UsageTracker>) -> Self {
254 Self {
255 inner: provider,
256 tracker,
257 }
258 }
259
260 pub fn inner(&self) -> &P {
262 &self.inner
263 }
264
265 pub fn inner_mut(&mut self) -> &mut P {
267 &mut self.inner
268 }
269
270 pub fn tracker(&self) -> &Arc<UsageTracker> {
272 &self.tracker
273 }
274
275 pub fn stats(&self) -> UsageStats {
277 self.tracker.stats()
278 }
279
280 pub fn reset(&self) {
282 self.tracker.reset();
283 }
284}
285
286#[async_trait]
287impl<P: LlmProvider> LlmProvider for TrackedProvider<P> {
288 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
289 let response = self.inner.complete(request).await?;
290
291 if let Some(usage) = &response.usage {
293 self.tracker.record(usage);
294 }
295
296 Ok(response)
297 }
298}
299
300#[async_trait]
301impl<P: EmbeddingProvider> EmbeddingProvider for TrackedProvider<P> {
302 async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
303 let response = self.inner.embed(request).await?;
304
305 if let Some(usage) = &response.usage {
307 self.tracker.record(&Usage {
308 prompt_tokens: usage.prompt_tokens,
309 completion_tokens: 0,
310 total_tokens: usage.total_tokens,
311 });
312 }
313
314 Ok(response)
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double: every completion succeeds and reports the usage produced
    /// by the wrapped function.
    struct MockProvider(fn() -> Usage);

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
            Ok(LlmResponse {
                content: format!("Response to: {}", request.prompt),
                model: "mock".to_string(),
                usage: Some((self.0)()),
                tool_calls: Vec::new(),
            })
        }
    }

    /// Minimal request used by the budget-provider tests.
    fn test_request() -> LlmRequest {
        LlmRequest {
            prompt: "test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        }
    }

    #[test]
    fn test_model_pricing_calculation() {
        // 1000 prompt tokens at 1.0 plus 500 completion tokens at 3.0 per 1K.
        let expected = 2.5;
        let actual = ModelPricing::GPT4_TURBO.calculate_cost(1000, 500);

        assert!((actual - expected).abs() < 0.001);
    }

    #[test]
    fn test_usage_tracker() {
        let tracker = UsageTracker::with_pricing(ModelPricing::GPT35_TURBO);

        let first = Usage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
        };
        tracker.record(&first);

        let after_first = tracker.stats();
        assert_eq!(after_first.total_prompt_tokens, 100);
        assert_eq!(after_first.total_completion_tokens, 50);
        assert_eq!(after_first.total_tokens, 150);
        assert_eq!(after_first.request_count, 1);

        let second = Usage {
            prompt_tokens: 200,
            completion_tokens: 100,
            total_tokens: 300,
        };
        tracker.record(&second);

        // Totals accumulate across both recorded requests.
        let after_second = tracker.stats();
        assert_eq!(after_second.total_prompt_tokens, 300);
        assert_eq!(after_second.total_completion_tokens, 150);
        assert_eq!(after_second.total_tokens, 450);
        assert_eq!(after_second.request_count, 2);
        assert_eq!(after_second.avg_tokens_per_request(), 225.0);
    }

    #[test]
    fn test_usage_tracker_reset() {
        let tracker = UsageTracker::new();
        let sample = Usage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
        };

        tracker.record(&sample);
        assert_eq!(tracker.stats().total_tokens, 150);

        tracker.reset();
        assert_eq!(tracker.stats().total_tokens, 0);
        assert_eq!(tracker.stats().request_count, 0);
    }

    #[test]
    fn test_free_pricing() {
        let cost = ModelPricing::FREE.calculate_cost(10000, 5000);
        assert_eq!(cost, 0.0);
    }

    #[test]
    fn test_usage_stats_usd() {
        // 100 cents per 1K tokens on both sides, 1K tokens each way.
        let tracker = UsageTracker::with_pricing(ModelPricing::new(100.0, 100.0));
        let sample = Usage {
            prompt_tokens: 1000,
            completion_tokens: 1000,
            total_tokens: 2000,
        };
        tracker.record(&sample);

        let summary = tracker.stats();
        assert_eq!(summary.estimated_cost_cents, 200.0);
        assert_eq!(summary.estimated_cost_usd(), 2.0);
    }

    #[tokio::test]
    async fn test_budget_provider_under_limit() {
        let mock = MockProvider(|| Usage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
        });
        let provider = BudgetProvider::new(
            mock,
            BudgetLimit::new(100.0),
            ModelPricing::new(10.0, 10.0),
        );

        let outcome = provider.complete(test_request()).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_budget_provider_exceeds_limit() {
        let mock = MockProvider(|| Usage {
            prompt_tokens: 10000,
            completion_tokens: 5000,
            total_tokens: 15000,
        });
        let provider = BudgetProvider::new(
            mock,
            BudgetLimit::new(0.5),
            ModelPricing::new(100.0, 100.0),
        );

        // Nothing has been consumed yet, so the first call is let through even
        // though its cost alone is far above the limit.
        let first = provider.complete(test_request()).await;
        assert!(first.is_ok());

        // The first call's cost has now been charged; the second is rejected.
        let second = provider.complete(test_request()).await;
        assert!(second.is_err());
        assert!(matches!(second.unwrap_err(), LlmError::ApiError(_)));
    }

    #[test]
    fn test_budget_limit_remaining() {
        let limit = BudgetLimit::new(100.0);
        assert_eq!(limit.remaining_cents(), 100.0);

        limit.consume(50.0);
        assert_eq!(limit.remaining_cents(), 50.0);

        limit.reset();
        assert_eq!(limit.remaining_cents(), 100.0);
    }
}
493
/// A spending cap in cents together with a running total of what has been
/// consumed so far.
///
/// The running total lives behind an `Arc`, so clones of a `BudgetLimit` share
/// the same consumption counter.
#[derive(Debug, Clone)]
pub struct BudgetLimit {
    /// Maximum spend allowed, in cents.
    max_budget_cents: f64,
    /// Cents consumed so far, stored as the bit pattern of an `f64` because
    /// the standard library has no atomic float type.
    consumed_cents: Arc<AtomicU64>,
}

impl PartialEq for BudgetLimit {
    /// Two limits are equal when their caps are equal; the amount already
    /// consumed is not compared.
    fn eq(&self, other: &Self) -> bool {
        self.max_budget_cents == other.max_budget_cents
    }
}

impl BudgetLimit {
    /// Creates a limit of `max_budget_cents` cents with nothing consumed.
    pub fn new(max_budget_cents: f64) -> Self {
        Self {
            max_budget_cents,
            // The all-zero bit pattern is `0.0_f64`.
            consumed_cents: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Creates a limit from a whole number of cents.
    pub fn cents(max_budget_cents: u64) -> Self {
        Self::new(max_budget_cents as f64)
    }

    /// Creates a limit from an amount in dollars.
    pub fn from_usd(max_budget_usd: f64) -> Self {
        Self::new(max_budget_usd * 100.0)
    }

    /// Creates a limit from a whole number of dollars.
    pub fn dollars(max_budget_usd: u64) -> Self {
        Self::from_usd(max_budget_usd as f64)
    }

    /// The cap in whole cents; any fractional part is dropped by the cast.
    pub fn as_cents(&self) -> u64 {
        self.max_budget_cents as u64
    }

    /// The cap in dollars.
    pub fn as_usd(&self) -> f64 {
        self.max_budget_cents / 100.0
    }

    /// Returns `true` once the consumed amount has reached or passed the cap.
    pub fn is_exceeded(&self) -> bool {
        let consumed = self.consumed_cents();
        consumed >= self.max_budget_cents
    }

    /// Cents still available, clamped so it never goes below zero.
    pub fn remaining_cents(&self) -> f64 {
        let consumed = self.consumed_cents();
        (self.max_budget_cents - consumed).max(0.0)
    }

    /// Dollars still available, never below zero.
    pub fn remaining_usd(&self) -> f64 {
        self.remaining_cents() / 100.0
    }

    /// Cents consumed so far.
    pub fn consumed_cents(&self) -> f64 {
        let bits = self.consumed_cents.load(Ordering::Relaxed);
        f64::from_bits(bits)
    }

    /// Adds `cents` to the consumed total.
    ///
    /// The addition is one atomic read-modify-write: `fetch_update` retries
    /// whenever another thread changed the total in between. A separate load
    /// followed by a store would let concurrent callers overwrite each other
    /// and under-count the spend.
    fn consume(&self, cents: f64) {
        // The closure always returns `Some`, so `fetch_update` cannot fail and
        // its result carries no information.
        let _ = self
            .consumed_cents
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |bits| {
                Some((f64::from_bits(bits) + cents).to_bits())
            });
    }

    /// Sets the consumed total back to zero. The cap is unchanged.
    pub fn reset(&self) {
        self.consumed_cents
            .store(0_f64.to_bits(), Ordering::Relaxed);
    }
}
584
585pub struct BudgetProvider<P> {
587 inner: P,
588 budget: BudgetLimit,
589 pricing: ModelPricing,
590}
591
592impl<P> BudgetProvider<P> {
593 pub fn new(provider: P, budget: BudgetLimit, pricing: ModelPricing) -> Self {
595 Self {
596 inner: provider,
597 budget,
598 pricing,
599 }
600 }
601
602 pub fn inner(&self) -> &P {
604 &self.inner
605 }
606
607 pub fn remaining_budget_cents(&self) -> f64 {
609 self.budget.remaining_cents()
610 }
611
612 pub fn remaining_budget_usd(&self) -> f64 {
614 self.budget.remaining_usd()
615 }
616
617 pub fn is_budget_exceeded(&self) -> bool {
619 self.budget.is_exceeded()
620 }
621
622 pub fn reset_budget(&self) {
624 self.budget.reset();
625 }
626}
627
628#[async_trait]
629impl<P: LlmProvider> LlmProvider for BudgetProvider<P> {
630 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
631 if self.budget.is_exceeded() {
633 return Err(LlmError::ApiError(format!(
634 "Budget exceeded: ${:.4} spent of ${:.4} limit",
635 self.budget.consumed_cents() / 100.0,
636 self.budget.max_budget_cents / 100.0
637 )));
638 }
639
640 let response = self.inner.complete(request).await?;
641
642 if let Some(usage) = &response.usage {
644 let cost = self
645 .pricing
646 .calculate_cost(usage.prompt_tokens, usage.completion_tokens);
647 self.budget.consume(cost);
648
649 tracing::info!(
651 cost_cents = cost,
652 remaining_cents = self.budget.remaining_cents(),
653 "Request cost tracked against budget"
654 );
655
656 let remaining_pct =
658 self.budget.remaining_cents() / self.budget.max_budget_cents * 100.0;
659 if remaining_pct < 10.0 && remaining_pct > 0.0 {
660 tracing::warn!(
661 remaining_pct = format!("{:.1}%", remaining_pct),
662 "Budget running low"
663 );
664 }
665 }
666
667 Ok(response)
668 }
669}
670
671#[async_trait]
672impl<P: EmbeddingProvider> EmbeddingProvider for BudgetProvider<P> {
673 async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
674 if self.budget.is_exceeded() {
676 return Err(LlmError::ApiError(format!(
677 "Budget exceeded: ${:.4} spent of ${:.4} limit",
678 self.budget.consumed_cents() / 100.0,
679 self.budget.max_budget_cents / 100.0
680 )));
681 }
682
683 let response = self.inner.embed(request).await?;
684
685 if let Some(usage) = &response.usage {
687 let cost = self.pricing.calculate_cost(usage.prompt_tokens, 0);
688 self.budget.consume(cost);
689
690 tracing::info!(
691 cost_cents = cost,
692 remaining_cents = self.budget.remaining_cents(),
693 "Embedding cost tracked against budget"
694 );
695 }
696
697 Ok(response)
698 }
699}