1use std::fmt;
20use std::ops::{Add, AddAssign};
21
22use serde::{Deserialize, Serialize};
23
/// Token counts for one model call, or the sum of several calls.
///
/// `input_tokens` and `output_tokens` are always present. The remaining
/// counters are optional: when two `Usage` values are added, a `None` counter
/// is treated as absent rather than as zero, so `None + None` stays `None`.
///
/// Field names and order define the derived serde representation.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of output tokens.
    pub output_tokens: u64,
    /// Number of reasoning tokens, if reported.
    pub reasoning_tokens: Option<u64>,
    /// Number of tokens read from cache, if reported.
    pub cache_read_tokens: Option<u64>,
    /// Number of tokens written to cache, if reported.
    pub cache_write_tokens: Option<u64>,
}
38
/// Sums two optional counters, treating `None` as "no value" rather than zero.
///
/// * Both present: the saturating sum (clamped at `u64::MAX`).
/// * Exactly one present: that value, unchanged.
/// * Neither present: `None`.
fn add_optional(a: Option<u64>, b: Option<u64>) -> Option<u64> {
    match (a, b) {
        // If either side is missing, the result is simply the other side
        // (which may itself be `None`).
        (None, other) | (other, None) => other,
        (Some(lhs), Some(rhs)) => Some(lhs.saturating_add(rhs)),
    }
}
47
48impl Add for Usage {
49 type Output = Self;
50
51 fn add(self, rhs: Self) -> Self {
57 Self {
58 input_tokens: self.input_tokens.saturating_add(rhs.input_tokens),
59 output_tokens: self.output_tokens.saturating_add(rhs.output_tokens),
60 reasoning_tokens: add_optional(self.reasoning_tokens, rhs.reasoning_tokens),
61 cache_read_tokens: add_optional(self.cache_read_tokens, rhs.cache_read_tokens),
62 cache_write_tokens: add_optional(self.cache_write_tokens, rhs.cache_write_tokens),
63 }
64 }
65}
66
67impl AddAssign for Usage {
68 fn add_assign(&mut self, rhs: Self) {
69 *self += &rhs;
70 }
71}
72
73impl AddAssign<&Usage> for Usage {
74 fn add_assign(&mut self, rhs: &Self) {
78 self.input_tokens = self.input_tokens.saturating_add(rhs.input_tokens);
79 self.output_tokens = self.output_tokens.saturating_add(rhs.output_tokens);
80 self.reasoning_tokens = add_optional(self.reasoning_tokens, rhs.reasoning_tokens);
81 self.cache_read_tokens = add_optional(self.cache_read_tokens, rhs.cache_read_tokens);
82 self.cache_write_tokens = add_optional(self.cache_write_tokens, rhs.cache_write_tokens);
83 }
84}
85
86#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
102pub struct Cost {
103 input: u64,
104 output: u64,
105 total: u64,
106}
107
108impl Default for Cost {
109 fn default() -> Self {
111 Self {
112 input: 0,
113 output: 0,
114 total: 0,
115 }
116 }
117}
118
/// Deserialization helper for `Cost`: only the two independent components.
///
/// There is deliberately no `total` field here; the hand-written
/// `Deserialize` impl for `Cost` derives the total from these two values.
#[derive(Deserialize)]
struct CostRaw {
    /// Input-side cost as read from the serialized form.
    input: u64,
    /// Output-side cost as read from the serialized form.
    output: u64,
}
125
126impl<'de> Deserialize<'de> for Cost {
127 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
128 where
129 D: serde::Deserializer<'de>,
130 {
131 let raw = CostRaw::deserialize(deserializer)?;
132 let total = raw
133 .input
134 .checked_add(raw.output)
135 .ok_or_else(|| serde::de::Error::custom("cost overflow: input + output exceeds u64"))?;
136 Ok(Self {
137 input: raw.input,
138 output: raw.output,
139 total,
140 })
141 }
142}
143
144impl Cost {
145 pub fn new(input: u64, output: u64) -> Option<Self> {
148 let total = input.checked_add(output)?;
149 Some(Self {
150 input,
151 output,
152 total,
153 })
154 }
155
156 pub fn input_microdollars(&self) -> u64 {
158 self.input
159 }
160
161 pub fn output_microdollars(&self) -> u64 {
163 self.output
164 }
165
166 pub fn total_microdollars(&self) -> u64 {
168 self.total
169 }
170
171 pub fn checked_add(&self, rhs: &Self) -> Option<Self> {
173 let input = self.input.checked_add(rhs.input)?;
174 let output = self.output.checked_add(rhs.output)?;
175 Self::new(input, output)
176 }
177
178 #[allow(clippy::cast_precision_loss)] pub fn total_usd(&self) -> f64 {
184 self.total as f64 / 1_000_000.0
185 }
186}
187
188impl fmt::Display for Cost {
189 #[allow(clippy::cast_precision_loss)]
191 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192 write!(f, "${:.2}", self.total as f64 / 1_000_000.0)
193 }
194}
195
196impl Add for Cost {
197 type Output = Self;
198
199 fn add(self, rhs: Self) -> Self {
203 let input = self.input.saturating_add(rhs.input);
204 let output = self.output.saturating_add(rhs.output);
205 Self {
206 input,
207 output,
208 total: input.saturating_add(output),
209 }
210 }
211}
212
213impl AddAssign for Cost {
214 fn add_assign(&mut self, rhs: Self) {
215 self.input = self.input.saturating_add(rhs.input);
216 self.output = self.output.saturating_add(rhs.output);
217 self.total = self.input.saturating_add(self.output);
218 }
219}
220
/// Accumulates `Usage` values, keeping a running total and the per-call history.
#[derive(Debug, Clone)]
pub struct UsageTracker {
    /// Sum of every usage recorded since construction or the last reset.
    total: Usage,
    /// Each recorded usage, in the order it was recorded.
    by_call: Vec<Usage>,
    /// Optional token limit that `context_utilization` measures against.
    context_limit: Option<u64>,
}
262
263impl Default for UsageTracker {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
269impl UsageTracker {
270 pub fn new() -> Self {
272 Self {
273 total: Usage::default(),
274 by_call: Vec::new(),
275 context_limit: None,
276 }
277 }
278
279 pub fn with_context_limit(limit: u64) -> Self {
284 Self {
285 total: Usage::default(),
286 by_call: Vec::new(),
287 context_limit: Some(limit),
288 }
289 }
290
291 pub fn record(&mut self, usage: Usage) {
296 self.total += &usage;
297 self.by_call.push(usage);
298 }
299
300 pub fn total(&self) -> &Usage {
302 &self.total
303 }
304
305 pub fn calls(&self) -> &[Usage] {
307 &self.by_call
308 }
309
310 pub fn call_count(&self) -> usize {
312 self.by_call.len()
313 }
314
315 pub fn context_limit(&self) -> Option<u64> {
317 self.context_limit
318 }
319
320 pub fn set_context_limit(&mut self, limit: u64) {
324 self.context_limit = Some(limit);
325 }
326
327 #[allow(clippy::cast_precision_loss)] pub fn context_utilization(&self) -> Option<f64> {
338 self.context_limit.map(|limit| {
339 if limit == 0 {
340 return 0.0;
341 }
342 self.total.input_tokens as f64 / limit as f64
343 })
344 }
345
346 pub fn is_near_limit(&self, threshold: f64) -> bool {
366 self.context_utilization()
367 .is_some_and(|util| util >= threshold)
368 }
369
370 pub fn cost(&self, pricing: &ModelPricing) -> Option<Cost> {
375 pricing.compute_cost(&self.total)
376 }
377
378 pub fn reset(&mut self) {
380 self.total = Usage::default();
381 self.by_call.clear();
382 }
383}
384
/// Per-model token prices used to turn a `Usage` into a `Cost`.
///
/// `compute_cost` applies each rate as `tokens * rate / 1_000_000` and stores
/// the result as microdollars, so every rate here is expressed in
/// microdollars per one million tokens.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelPricing {
    /// Price of input tokens, in microdollars per million tokens.
    pub input_per_million: u64,
    /// Price of output tokens, in microdollars per million tokens.
    pub output_per_million: u64,
    /// Price of cache-read tokens, in microdollars per million tokens;
    /// `None` means cache reads are not charged.
    pub cache_read_per_million: Option<u64>,
}
420
421impl ModelPricing {
422 pub fn compute_cost(&self, usage: &Usage) -> Option<Cost> {
426 let input_cost = compute_token_cost(usage.input_tokens, self.input_per_million)?;
429 let output_cost = compute_token_cost(usage.output_tokens, self.output_per_million)?;
430
431 let cache_cost = match (usage.cache_read_tokens, self.cache_read_per_million) {
433 (Some(tokens), Some(rate)) => compute_token_cost(tokens, rate)?,
434 _ => 0,
435 };
436
437 let total_input = input_cost.checked_add(cache_cost)?;
440 Cost::new(total_input, output_cost)
441 }
442}
443
/// Computes `tokens * per_million / 1_000_000`, rounding down.
///
/// The multiplication is done in `u128`, where the product of two `u64`
/// values cannot overflow. Returns `None` only when the final quotient does
/// not fit back into a `u64`.
fn compute_token_cost(tokens: u64, per_million: u64) -> Option<u64> {
    const TOKENS_PER_MILLION: u128 = 1_000_000;
    let scaled = u128::from(tokens) * u128::from(per_million) / TOKENS_PER_MILLION;
    u64::try_from(scaled).ok()
}
454
// Unit tests for usage accounting, cost arithmetic, the usage tracker and
// model pricing. Monetary values are microdollars throughout.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Usage: derives and serde ----

    #[test]
    fn test_usage_clone_eq() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 50,
            reasoning_tokens: Some(10),
            cache_read_tokens: None,
            cache_write_tokens: None,
        };
        assert_eq!(u, u.clone());
    }

    #[test]
    fn test_usage_debug_format() {
        let u = Usage::default();
        let debug = format!("{u:?}");
        assert!(debug.contains("input_tokens"));
        assert!(debug.contains("output_tokens"));
    }

    #[test]
    fn test_usage_optional_fields_none() {
        let u = Usage::default();
        assert_eq!(u.reasoning_tokens, None);
        assert_eq!(u.cache_read_tokens, None);
        assert_eq!(u.cache_write_tokens, None);
    }

    #[test]
    fn test_usage_optional_fields_some() {
        let u = Usage {
            input_tokens: 0,
            output_tokens: 0,
            reasoning_tokens: Some(500),
            cache_read_tokens: Some(200),
            cache_write_tokens: Some(100),
        };
        assert_eq!(u.reasoning_tokens, Some(500));
        assert_eq!(u.cache_read_tokens, Some(200));
        assert_eq!(u.cache_write_tokens, Some(100));
    }

    #[test]
    fn test_usage_serde_roundtrip() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 50,
            reasoning_tokens: Some(10),
            cache_read_tokens: None,
            cache_write_tokens: None,
        };
        let json = serde_json::to_string(&u).unwrap();
        let back: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(u, back);
    }

    // ---- Cost: construction, conversion and serde ----

    #[test]
    fn test_cost_new_enforces_invariant() {
        let c = Cost::new(1_000_000, 500_000).unwrap();
        assert_eq!(c.input_microdollars(), 1_000_000);
        assert_eq!(c.output_microdollars(), 500_000);
        assert_eq!(c.total_microdollars(), 1_500_000);
    }

    #[test]
    fn test_cost_new_overflow_returns_none() {
        assert!(Cost::new(u64::MAX, 1).is_none());
    }

    #[test]
    fn test_cost_total_usd_exact() {
        let c = Cost::new(1_000_000, 500_000).unwrap();
        assert!((c.total_usd() - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cost_total_usd_zero() {
        let c = Cost::new(0, 0).unwrap();
        assert!((c.total_usd()).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cost_total_usd_sub_cent() {
        let c = Cost::new(300, 200).unwrap();
        assert!((c.total_usd() - 0.0005).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cost_clone_eq() {
        let c = Cost::new(42, 58).unwrap();
        assert_eq!(c, c.clone());
    }

    #[test]
    fn test_cost_serde_roundtrip() {
        let c = Cost::new(1_000_000, 500_000).unwrap();
        let json = serde_json::to_string(&c).unwrap();
        let back: Cost = serde_json::from_str(&json).unwrap();
        assert_eq!(c, back);
    }

    #[test]
    fn test_cost_deserialization_recomputes_total() {
        // The serialized `total` (999) disagrees with input + output (300);
        // deserialization must ignore it and recompute.
        let json = r#"{"input":100,"output":200,"total":999}"#;
        let c: Cost = serde_json::from_str(json).unwrap();
        assert_eq!(c.total_microdollars(), 300);
    }

    #[test]
    fn test_cost_deserialization_without_total() {
        let json = r#"{"input":100,"output":200}"#;
        let c: Cost = serde_json::from_str(json).unwrap();
        assert_eq!(c.total_microdollars(), 300);
    }

    #[test]
    fn test_cost_deserialization_overflow_fails() {
        let json = format!(r#"{{"input":{},"output":1}}"#, u64::MAX);
        let result: Result<Cost, _> = serde_json::from_str(&json);
        assert!(result.is_err());
    }

    #[test]
    fn test_cost_default_is_zero() {
        let c = Cost::default();
        assert_eq!(c.input_microdollars(), 0);
        assert_eq!(c.output_microdollars(), 0);
        assert_eq!(c.total_microdollars(), 0);
    }

    // ---- Cost: Display ----

    #[test]
    fn test_cost_display() {
        let c = Cost::new(1_000_000, 500_000).unwrap();
        assert_eq!(c.to_string(), "$1.50");
    }

    #[test]
    fn test_cost_display_zero() {
        assert_eq!(Cost::default().to_string(), "$0.00");
    }

    #[test]
    fn test_cost_display_sub_cent() {
        // 500 microdollars is $0.0005, which rounds to two decimals as $0.00.
        let c = Cost::new(500, 0).unwrap();
        assert_eq!(c.to_string(), "$0.00");
    }

    // ---- Usage: Add / AddAssign ----

    #[test]
    fn test_usage_add_basic() {
        let a = Usage {
            input_tokens: 100,
            output_tokens: 50,
            reasoning_tokens: Some(10),
            cache_read_tokens: None,
            cache_write_tokens: Some(20),
        };
        let b = Usage {
            input_tokens: 200,
            output_tokens: 30,
            reasoning_tokens: Some(5),
            cache_read_tokens: Some(50),
            cache_write_tokens: None,
        };
        let sum = a + b;
        assert_eq!(sum.input_tokens, 300);
        assert_eq!(sum.output_tokens, 80);
        // Some + Some adds; Some + None keeps the present value.
        assert_eq!(sum.reasoning_tokens, Some(15));
        assert_eq!(sum.cache_read_tokens, Some(50));
        assert_eq!(sum.cache_write_tokens, Some(20));
    }

    #[test]
    fn test_usage_add_both_none() {
        let a = Usage::default();
        let b = Usage::default();
        let sum = a + b;
        assert_eq!(sum.reasoning_tokens, None);
        assert_eq!(sum.cache_read_tokens, None);
        assert_eq!(sum.cache_write_tokens, None);
    }

    #[test]
    fn test_usage_add_assign() {
        let mut a = Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        };
        a += Usage {
            input_tokens: 200,
            output_tokens: 30,
            ..Default::default()
        };
        assert_eq!(a.input_tokens, 300);
        assert_eq!(a.output_tokens, 80);
    }

    #[test]
    fn test_usage_add_saturates() {
        let a = Usage {
            input_tokens: u64::MAX,
            output_tokens: 0,
            ..Default::default()
        };
        let b = Usage {
            input_tokens: 1,
            output_tokens: 0,
            ..Default::default()
        };
        let sum = a + b;
        assert_eq!(sum.input_tokens, u64::MAX);
    }

    // ---- Cost: Add / AddAssign / checked_add ----

    #[test]
    fn test_cost_add_basic() {
        let a = Cost::new(100, 200).unwrap();
        let b = Cost::new(300, 400).unwrap();
        let sum = a + b;
        assert_eq!(sum.input_microdollars(), 400);
        assert_eq!(sum.output_microdollars(), 600);
        assert_eq!(sum.total_microdollars(), 1000);
    }

    #[test]
    fn test_cost_add_assign() {
        let mut c = Cost::new(100, 200).unwrap();
        c += Cost::new(50, 50).unwrap();
        assert_eq!(c.input_microdollars(), 150);
        assert_eq!(c.output_microdollars(), 250);
        assert_eq!(c.total_microdollars(), 400);
    }

    #[test]
    fn test_cost_checked_add() {
        let a = Cost::new(100, 200).unwrap();
        let b = Cost::new(300, 400).unwrap();
        let sum = a.checked_add(&b).unwrap();
        assert_eq!(sum.total_microdollars(), 1000);
    }

    #[test]
    fn test_cost_checked_add_overflow() {
        let a = Cost::new(u64::MAX - 1, 0).unwrap();
        let b = Cost::new(2, 0).unwrap();
        assert!(a.checked_add(&b).is_none());
    }

    #[test]
    fn test_cost_add_saturates() {
        // Same operands as the checked overflow test, but `+` clamps instead.
        let a = Cost::new(u64::MAX - 1, 0).unwrap();
        let b = Cost::new(2, 0).unwrap();
        let sum = a + b;
        assert_eq!(sum.input_microdollars(), u64::MAX);
    }

    // ---- UsageTracker ----

    #[test]
    fn test_usage_tracker_new() {
        let tracker = UsageTracker::new();
        assert_eq!(tracker.total().input_tokens, 0);
        assert_eq!(tracker.total().output_tokens, 0);
        assert!(tracker.calls().is_empty());
        assert_eq!(tracker.context_limit(), None);
    }

    #[test]
    fn test_usage_tracker_default() {
        let tracker = UsageTracker::default();
        assert_eq!(tracker.call_count(), 0);
        assert_eq!(tracker.context_limit(), None);
    }

    #[test]
    fn test_usage_tracker_with_context_limit() {
        let tracker = UsageTracker::with_context_limit(128_000);
        assert_eq!(tracker.context_limit(), Some(128_000));
    }

    #[test]
    fn test_usage_tracker_record() {
        let mut tracker = UsageTracker::new();
        tracker.record(Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        });
        tracker.record(Usage {
            input_tokens: 200,
            output_tokens: 100,
            ..Default::default()
        });

        assert_eq!(tracker.total().input_tokens, 300);
        assert_eq!(tracker.total().output_tokens, 150);
        assert_eq!(tracker.call_count(), 2);
        assert_eq!(tracker.calls()[0].input_tokens, 100);
        assert_eq!(tracker.calls()[1].input_tokens, 200);
    }

    #[test]
    fn test_usage_tracker_context_utilization() {
        let mut tracker = UsageTracker::with_context_limit(100_000);
        tracker.record(Usage {
            input_tokens: 50_000,
            output_tokens: 1000,
            ..Default::default()
        });

        // Only input tokens count: 50_000 / 100_000 = 0.5.
        let util = tracker.context_utilization().unwrap();
        assert!((util - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_usage_tracker_context_utilization_no_limit() {
        let tracker = UsageTracker::new();
        assert!(tracker.context_utilization().is_none());
    }

    #[test]
    fn test_usage_tracker_context_utilization_zero_limit() {
        let tracker = UsageTracker::with_context_limit(0);
        assert!((tracker.context_utilization().unwrap()).abs() < f64::EPSILON);
    }

    #[test]
    fn test_usage_tracker_is_near_limit() {
        let mut tracker = UsageTracker::with_context_limit(100_000);
        tracker.record(Usage {
            input_tokens: 85_000,
            output_tokens: 1000,
            ..Default::default()
        });

        // Utilization is 0.85; the threshold comparison is inclusive.
        assert!(tracker.is_near_limit(0.8));
        assert!(tracker.is_near_limit(0.85));
        assert!(!tracker.is_near_limit(0.9));
    }

    #[test]
    fn test_usage_tracker_is_near_limit_no_limit() {
        let tracker = UsageTracker::new();
        assert!(!tracker.is_near_limit(0.8));
    }

    #[test]
    fn test_usage_tracker_set_context_limit() {
        let mut tracker = UsageTracker::new();
        assert_eq!(tracker.context_limit(), None);

        tracker.set_context_limit(200_000);
        assert_eq!(tracker.context_limit(), Some(200_000));
    }

    #[test]
    fn test_usage_tracker_reset() {
        let mut tracker = UsageTracker::with_context_limit(100_000);
        tracker.record(Usage {
            input_tokens: 1000,
            output_tokens: 500,
            ..Default::default()
        });
        assert_eq!(tracker.call_count(), 1);
        assert_eq!(tracker.total().input_tokens, 1000);

        tracker.reset();
        assert_eq!(tracker.call_count(), 0);
        assert_eq!(tracker.total().input_tokens, 0);
        // Reset clears usage but keeps the configured limit.
        assert_eq!(tracker.context_limit(), Some(100_000));
    }

    #[test]
    fn test_usage_tracker_clone() {
        let mut tracker = UsageTracker::with_context_limit(50_000);
        tracker.record(Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        });

        let cloned = tracker.clone();
        assert_eq!(cloned.total().input_tokens, 100);
        assert_eq!(cloned.call_count(), 1);
        assert_eq!(cloned.context_limit(), Some(50_000));
    }

    // ---- ModelPricing ----

    #[test]
    fn test_model_pricing_compute_cost() {
        let pricing = ModelPricing {
            // 3_000_000 microdollars ($3.00) per million input tokens.
            input_per_million: 3_000_000,
            // 15_000_000 microdollars ($15.00) per million output tokens.
            output_per_million: 15_000_000,
            cache_read_per_million: None,
        };

        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 100_000,
            ..Default::default()
        };

        let cost = pricing.compute_cost(&usage).unwrap();
        // 1_000_000 * 3_000_000 / 1_000_000
        assert_eq!(cost.input_microdollars(), 3_000_000);
        // 100_000 * 15_000_000 / 1_000_000
        assert_eq!(cost.output_microdollars(), 1_500_000);
        assert_eq!(cost.total_microdollars(), 4_500_000);
    }

    #[test]
    fn test_model_pricing_with_cache_tokens() {
        let pricing = ModelPricing {
            input_per_million: 3_000_000,
            output_per_million: 15_000_000,
            cache_read_per_million: Some(300_000),
        };

        let usage = Usage {
            input_tokens: 500_000,
            output_tokens: 100_000,
            cache_read_tokens: Some(500_000),
            ..Default::default()
        };

        let cost = pricing.compute_cost(&usage).unwrap();
        // Input side = regular input (500_000 * 3_000_000 / 1e6 = 1_500_000)
        //            + cache reads  (500_000 *   300_000 / 1e6 =   150_000).
        assert_eq!(cost.input_microdollars(), 1_650_000);
        assert_eq!(cost.output_microdollars(), 1_500_000);
    }

    #[test]
    fn test_model_pricing_zero_tokens() {
        let pricing = ModelPricing {
            input_per_million: 3_000_000,
            output_per_million: 15_000_000,
            cache_read_per_million: None,
        };

        let usage = Usage::default();
        let cost = pricing.compute_cost(&usage).unwrap();
        assert_eq!(cost.total_microdollars(), 0);
    }

    #[test]
    fn test_model_pricing_cache_without_pricing() {
        let pricing = ModelPricing {
            input_per_million: 3_000_000,
            output_per_million: 15_000_000,
            cache_read_per_million: None,
        };

        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 100_000,
            cache_read_tokens: Some(500_000),
            ..Default::default()
        };

        let cost = pricing.compute_cost(&usage).unwrap();
        // With no cache-read rate, cache tokens add nothing to the input cost.
        assert_eq!(cost.input_microdollars(), 3_000_000);
    }

    #[test]
    fn test_usage_tracker_cost() {
        let mut tracker = UsageTracker::new();
        tracker.record(Usage {
            input_tokens: 1_000_000,
            output_tokens: 100_000,
            ..Default::default()
        });

        let pricing = ModelPricing {
            input_per_million: 3_000_000,
            output_per_million: 15_000_000,
            cache_read_per_million: None,
        };

        let cost = tracker.cost(&pricing).unwrap();
        assert_eq!(cost.total_microdollars(), 4_500_000);
    }

    #[test]
    fn test_model_pricing_clone_eq() {
        let p1 = ModelPricing {
            input_per_million: 100,
            output_per_million: 200,
            cache_read_per_million: Some(50),
        };
        let p2 = p1.clone();
        assert_eq!(p1, p2);
    }

    #[test]
    fn test_compute_token_cost_large_values() {
        // 10^10 * 3 * 10^6 = 3 * 10^16 as the intermediate product; the
        // helper widens to u128 before multiplying, then divides by 10^6.
        let cost = compute_token_cost(10_000_000_000, 3_000_000);
        assert_eq!(cost, Some(30_000_000_000));
    }
}