Skip to main content

llm_stack/
usage.rs

1//! Token usage and cost tracking.
2//!
3//! Every response carries a [`Usage`] record counting input and output
4//! tokens, with optional fields for reasoning and cache tokens when the
5//! provider reports them.
6//!
7//! [`Cost`] tracks monetary cost in **microdollars** (1 USD = 1,000,000
8//! microdollars). Integer arithmetic avoids floating-point rounding
9//! issues when aggregating costs across many requests. Use
10//! [`total_usd`](Cost::total_usd) for display purposes.
11//!
12//! # Invariant
13//!
14//! `Cost` enforces `total == input + output` at construction time.
15//! The fields are private — use [`Cost::new`] to build one, and the
16//! accessor methods to read values. Deserialization recomputes the
17//! total from `input` and `output`, ignoring any `total` in the JSON.
18
19use std::fmt;
20use std::ops::{Add, AddAssign};
21
22use serde::{Deserialize, Serialize};
23
24/// Token counts for a single request/response pair.
25#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
26pub struct Usage {
27    /// Tokens consumed by the prompt (messages + system + tool defs).
28    pub input_tokens: u64,
29    /// Tokens produced by the model's response.
30    pub output_tokens: u64,
31    /// Tokens used for chain-of-thought reasoning, if applicable.
32    pub reasoning_tokens: Option<u64>,
33    /// Tokens served from the provider's prompt cache (reducing cost).
34    pub cache_read_tokens: Option<u64>,
35    /// Tokens written into the provider's prompt cache for future reuse.
36    pub cache_write_tokens: Option<u64>,
37}
38
/// Sums two optional token counts.
///
/// A missing count contributes zero and the sum saturates at `u64::MAX`.
/// The result is `None` only when *neither* side carried a value, so
/// "never reported" is not silently turned into `Some(0)`.
fn add_optional(a: Option<u64>, b: Option<u64>) -> Option<u64> {
    if a.is_none() && b.is_none() {
        return None;
    }
    // At least one side is `Some`; the absent side (if any) counts as zero.
    let lhs = a.unwrap_or(0);
    let rhs = b.unwrap_or(0);
    Some(lhs.saturating_add(rhs))
}
47
48impl Add for Usage {
49    type Output = Self;
50
51    /// Adds two `Usage` records field-by-field.
52    ///
53    /// Mandatory fields use saturating addition. Optional fields are
54    /// summed when both are `Some`, preserved when one is `Some`, and
55    /// remain `None` when both are `None`.
56    fn add(self, rhs: Self) -> Self {
57        Self {
58            input_tokens: self.input_tokens.saturating_add(rhs.input_tokens),
59            output_tokens: self.output_tokens.saturating_add(rhs.output_tokens),
60            reasoning_tokens: add_optional(self.reasoning_tokens, rhs.reasoning_tokens),
61            cache_read_tokens: add_optional(self.cache_read_tokens, rhs.cache_read_tokens),
62            cache_write_tokens: add_optional(self.cache_write_tokens, rhs.cache_write_tokens),
63        }
64    }
65}
66
67impl AddAssign for Usage {
68    fn add_assign(&mut self, rhs: Self) {
69        *self += &rhs;
70    }
71}
72
73impl AddAssign<&Usage> for Usage {
74    /// Adds another `Usage` to this one in-place without cloning.
75    ///
76    /// This is more efficient than `AddAssign<Usage>` when you have a reference.
77    fn add_assign(&mut self, rhs: &Self) {
78        self.input_tokens = self.input_tokens.saturating_add(rhs.input_tokens);
79        self.output_tokens = self.output_tokens.saturating_add(rhs.output_tokens);
80        self.reasoning_tokens = add_optional(self.reasoning_tokens, rhs.reasoning_tokens);
81        self.cache_read_tokens = add_optional(self.cache_read_tokens, rhs.cache_read_tokens);
82        self.cache_write_tokens = add_optional(self.cache_write_tokens, rhs.cache_write_tokens);
83    }
84}
85
86/// Monetary cost in microdollars (1 USD = 1,000,000 microdollars).
87///
88/// Uses integer arithmetic to avoid floating-point accumulation errors.
89/// The invariant `total == input + output` is enforced by the
90/// constructor and maintained through deserialization.
91///
92/// # Examples
93///
94/// ```rust
95/// use llm_stack::Cost;
96///
97/// let cost = Cost::new(300_000, 150_000).expect("no overflow");
98/// assert_eq!(cost.total_microdollars(), 450_000);
99/// assert!((cost.total_usd() - 0.45).abs() < f64::EPSILON);
100/// ```
101#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
102pub struct Cost {
103    input: u64,
104    output: u64,
105    total: u64,
106}
107
108impl Default for Cost {
109    /// Returns a zero cost.
110    fn default() -> Self {
111        Self {
112            input: 0,
113            output: 0,
114            total: 0,
115        }
116    }
117}
118
119/// Intermediate type for safe deserialization — recomputes total.
120#[derive(Deserialize)]
121struct CostRaw {
122    input: u64,
123    output: u64,
124}
125
126impl<'de> Deserialize<'de> for Cost {
127    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
128    where
129        D: serde::Deserializer<'de>,
130    {
131        let raw = CostRaw::deserialize(deserializer)?;
132        let total = raw
133            .input
134            .checked_add(raw.output)
135            .ok_or_else(|| serde::de::Error::custom("cost overflow: input + output exceeds u64"))?;
136        Ok(Self {
137            input: raw.input,
138            output: raw.output,
139            total,
140        })
141    }
142}
143
144impl Cost {
145    /// Creates a new `Cost`, returning `None` if `input + output`
146    /// would overflow `u64`.
147    pub fn new(input: u64, output: u64) -> Option<Self> {
148        let total = input.checked_add(output)?;
149        Some(Self {
150            input,
151            output,
152            total,
153        })
154    }
155
156    /// Cost of the input (prompt) in microdollars.
157    pub fn input_microdollars(&self) -> u64 {
158        self.input
159    }
160
161    /// Cost of the output (completion) in microdollars.
162    pub fn output_microdollars(&self) -> u64 {
163        self.output
164    }
165
166    /// Total cost (`input + output`) in microdollars.
167    pub fn total_microdollars(&self) -> u64 {
168        self.total
169    }
170
171    /// Returns the sum of two costs, or `None` on overflow.
172    pub fn checked_add(&self, rhs: &Self) -> Option<Self> {
173        let input = self.input.checked_add(rhs.input)?;
174        let output = self.output.checked_add(rhs.output)?;
175        Self::new(input, output)
176    }
177
178    /// Total cost in US dollars, for display purposes.
179    ///
180    /// Uses floating-point division — prefer
181    /// [`total_microdollars`](Self::total_microdollars) for arithmetic.
182    #[allow(clippy::cast_precision_loss)] // microdollar u64 fits f64 mantissa in practice
183    pub fn total_usd(&self) -> f64 {
184        self.total as f64 / 1_000_000.0
185    }
186}
187
188impl fmt::Display for Cost {
189    /// Formats the cost as a USD string, e.g. `$1.50`.
190    #[allow(clippy::cast_precision_loss)]
191    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192        write!(f, "${:.2}", self.total as f64 / 1_000_000.0)
193    }
194}
195
196impl Add for Cost {
197    type Output = Self;
198
199    /// Adds two costs using saturating arithmetic.
200    ///
201    /// Use [`checked_add`](Self::checked_add) when overflow must be detected.
202    fn add(self, rhs: Self) -> Self {
203        let input = self.input.saturating_add(rhs.input);
204        let output = self.output.saturating_add(rhs.output);
205        Self {
206            input,
207            output,
208            total: input.saturating_add(output),
209        }
210    }
211}
212
213impl AddAssign for Cost {
214    fn add_assign(&mut self, rhs: Self) {
215        self.input = self.input.saturating_add(rhs.input);
216        self.output = self.output.saturating_add(rhs.output);
217        self.total = self.input.saturating_add(self.output);
218    }
219}
220
221// ── UsageTracker ────────────────────────────────────────────────────
222
223/// Tracks cumulative token usage across multiple LLM calls.
224///
225/// `UsageTracker` accumulates [`Usage`] records from each request and
226/// provides context-awareness features for detecting when the conversation
227/// is approaching the model's context limit.
228///
229/// # Example
230///
231/// ```rust
232/// use llm_stack::usage::{Usage, UsageTracker};
233///
234/// let mut tracker = UsageTracker::with_context_limit(128_000);
235///
236/// // Record usage from each LLM call
237/// tracker.record(Usage {
238///     input_tokens: 1000,
239///     output_tokens: 500,
240///     ..Default::default()
241/// });
242///
243/// assert_eq!(tracker.total().input_tokens, 1000);
244/// assert!(!tracker.is_near_limit(0.8)); // Not near 80% yet
245/// ```
246///
247/// # Use Cases
248///
249/// - **Billing/cost tracking**: Aggregate costs across a session
250/// - **Budget alerts**: Warn when approaching token limits
251/// - **Compaction triggers**: Signal when context window is nearly full
252/// - **Token debugging**: Analyze per-call consumption patterns
253#[derive(Debug, Clone)]
254pub struct UsageTracker {
255    /// Accumulated usage across all calls.
256    total: Usage,
257    /// Usage from each individual call, in order.
258    by_call: Vec<Usage>,
259    /// Optional context window limit for utilization calculations.
260    context_limit: Option<u64>,
261}
262
263impl Default for UsageTracker {
264    fn default() -> Self {
265        Self::new()
266    }
267}
268
269impl UsageTracker {
270    /// Creates a new tracker with no context limit.
271    pub fn new() -> Self {
272        Self {
273            total: Usage::default(),
274            by_call: Vec::new(),
275            context_limit: None,
276        }
277    }
278
279    /// Creates a tracker with a known context window limit.
280    ///
281    /// The limit is used for [`context_utilization`](Self::context_utilization)
282    /// and [`is_near_limit`](Self::is_near_limit) calculations.
283    pub fn with_context_limit(limit: u64) -> Self {
284        Self {
285            total: Usage::default(),
286            by_call: Vec::new(),
287            context_limit: Some(limit),
288        }
289    }
290
291    /// Records a usage sample from an LLM call.
292    ///
293    /// The usage is added to the running total and stored for per-call
294    /// analysis.
295    pub fn record(&mut self, usage: Usage) {
296        self.total += &usage;
297        self.by_call.push(usage);
298    }
299
300    /// Returns the accumulated usage across all recorded calls.
301    pub fn total(&self) -> &Usage {
302        &self.total
303    }
304
305    /// Returns the usage from each individual call, in order.
306    pub fn calls(&self) -> &[Usage] {
307        &self.by_call
308    }
309
310    /// Returns the number of calls recorded.
311    pub fn call_count(&self) -> usize {
312        self.by_call.len()
313    }
314
315    /// Returns the context limit, if set.
316    pub fn context_limit(&self) -> Option<u64> {
317        self.context_limit
318    }
319
320    /// Sets or updates the context limit.
321    ///
322    /// Useful when the model is determined after tracker creation.
323    pub fn set_context_limit(&mut self, limit: u64) {
324        self.context_limit = Some(limit);
325    }
326
327    /// Returns the context utilization as a ratio (0.0 to 1.0+).
328    ///
329    /// Utilization is calculated as `total_input_tokens / context_limit`.
330    /// Returns `None` if no context limit is set.
331    ///
332    /// # Note
333    ///
334    /// The value can exceed 1.0 if the total exceeds the limit (which
335    /// shouldn't happen in practice but is not enforced).
336    #[allow(clippy::cast_precision_loss)] // u64 token counts fit f64 mantissa
337    pub fn context_utilization(&self) -> Option<f64> {
338        self.context_limit.map(|limit| {
339            if limit == 0 {
340                return 0.0;
341            }
342            self.total.input_tokens as f64 / limit as f64
343        })
344    }
345
346    /// Checks if the context utilization is at or above the given threshold.
347    ///
348    /// Returns `false` if no context limit is set.
349    ///
350    /// # Example
351    ///
352    /// ```rust
353    /// use llm_stack::usage::{Usage, UsageTracker};
354    ///
355    /// let mut tracker = UsageTracker::with_context_limit(100_000);
356    /// tracker.record(Usage {
357    ///     input_tokens: 85_000,
358    ///     output_tokens: 1000,
359    ///     ..Default::default()
360    /// });
361    ///
362    /// assert!(tracker.is_near_limit(0.8));   // 85% >= 80%
363    /// assert!(!tracker.is_near_limit(0.9));  // 85% < 90%
364    /// ```
365    pub fn is_near_limit(&self, threshold: f64) -> bool {
366        self.context_utilization()
367            .is_some_and(|util| util >= threshold)
368    }
369
370    /// Computes the cost of all recorded usage given a pricing table.
371    ///
372    /// Uses the pricing rates (per-million tokens) to calculate cost.
373    /// Returns `None` if the cost would overflow.
374    pub fn cost(&self, pricing: &ModelPricing) -> Option<Cost> {
375        pricing.compute_cost(&self.total)
376    }
377
378    /// Resets the tracker, clearing all recorded usage.
379    pub fn reset(&mut self) {
380        self.total = Usage::default();
381        self.by_call.clear();
382    }
383}
384
/// Per-model token prices, in **microdollars per million tokens**.
///
/// A list price of $3.00 per million input tokens is written `3_000_000`.
///
/// Rates exist for input, output, and (optionally) cache-read tokens only;
/// there is no rate for cache-write tokens.
///
/// # Example
///
/// ```rust
/// use llm_stack::usage::{ModelPricing, Usage};
///
/// // Claude 3.5 Sonnet list prices (2024)
/// let pricing = ModelPricing {
///     input_per_million: 3_000_000,   // $3.00 / MTok
///     output_per_million: 15_000_000, // $15.00 / MTok
///     cache_read_per_million: Some(300_000), // $0.30 / MTok
/// };
///
/// let usage = Usage {
///     input_tokens: 1_000_000,
///     output_tokens: 100_000,
///     ..Default::default()
/// };
///
/// let cost = pricing.compute_cost(&usage).unwrap();
/// assert_eq!(cost.total_microdollars(), 4_500_000); // $3 input + $1.50 output
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelPricing {
    /// Price of one million input tokens, in microdollars.
    pub input_per_million: u64,
    /// Price of one million output tokens, in microdollars.
    pub output_per_million: u64,
    /// Price of one million cache-read tokens, in microdollars; `None` when
    /// the model has no cache-read rate.
    pub cache_read_per_million: Option<u64>,
}
420
421impl ModelPricing {
422    /// Computes the cost for the given usage.
423    ///
424    /// Returns `None` if the calculation would overflow.
425    pub fn compute_cost(&self, usage: &Usage) -> Option<Cost> {
426        // Cost = (tokens * price_per_million) / 1_000_000
427        // To avoid precision loss, we use u128 for intermediate calculations
428        let input_cost = compute_token_cost(usage.input_tokens, self.input_per_million)?;
429        let output_cost = compute_token_cost(usage.output_tokens, self.output_per_million)?;
430
431        // If cache read tokens are present and pricing is set, include them
432        let cache_cost = match (usage.cache_read_tokens, self.cache_read_per_million) {
433            (Some(tokens), Some(rate)) => compute_token_cost(tokens, rate)?,
434            _ => 0,
435        };
436
437        // Combine costs (cache reads reduce effective input cost conceptually,
438        // but for billing they're additive at the cache rate)
439        let total_input = input_cost.checked_add(cache_cost)?;
440        Cost::new(total_input, output_cost)
441    }
442}
443
/// Prices `tokens` at `per_million` microdollars per million tokens.
///
/// The multiplication is done in `u128`, where the product of two `u64`
/// values always fits, and the division truncates toward zero. Returns the
/// cost in microdollars, or `None` if it does not fit in a `u64`.
fn compute_token_cost(tokens: u64, per_million: u64) -> Option<u64> {
    const TOKENS_PER_MILLION: u128 = 1_000_000;
    let microdollars = u128::from(tokens) * u128::from(per_million) / TOKENS_PER_MILLION;
    u64::try_from(microdollars).ok()
}
454
#[cfg(test)]
mod tests {
    use super::*;

    // --- Fixtures ---

    /// Builds a `Usage` carrying only the two mandatory counters.
    fn tokens(input_tokens: u64, output_tokens: u64) -> Usage {
        Usage {
            input_tokens,
            output_tokens,
            ..Usage::default()
        }
    }

    /// Builds a `Cost` from components that are known not to overflow.
    fn micros(input: u64, output: u64) -> Cost {
        Cost::new(input, output).unwrap()
    }

    /// Pricing fixture: $3 / MTok input, $15 / MTok output, and a
    /// caller-chosen cache-read rate.
    fn standard_pricing(cache_read_per_million: Option<u64>) -> ModelPricing {
        ModelPricing {
            input_per_million: 3_000_000,
            output_per_million: 15_000_000,
            cache_read_per_million,
        }
    }

    // --- Usage basics ---

    #[test]
    fn test_usage_clone_eq() {
        let usage = Usage {
            reasoning_tokens: Some(10),
            ..tokens(100, 50)
        };
        assert_eq!(usage, usage.clone());
    }

    #[test]
    fn test_usage_debug_format() {
        let rendered = format!("{:?}", Usage::default());
        assert!(rendered.contains("input_tokens"));
        assert!(rendered.contains("output_tokens"));
    }

    #[test]
    fn test_usage_optional_fields_none() {
        let usage = Usage::default();
        assert_eq!(usage.reasoning_tokens, None);
        assert_eq!(usage.cache_read_tokens, None);
        assert_eq!(usage.cache_write_tokens, None);
    }

    #[test]
    fn test_usage_optional_fields_some() {
        let usage = Usage {
            reasoning_tokens: Some(500),
            cache_read_tokens: Some(200),
            cache_write_tokens: Some(100),
            ..tokens(0, 0)
        };
        assert_eq!(usage.reasoning_tokens, Some(500));
        assert_eq!(usage.cache_read_tokens, Some(200));
        assert_eq!(usage.cache_write_tokens, Some(100));
    }

    #[test]
    fn test_usage_serde_roundtrip() {
        let usage = Usage {
            reasoning_tokens: Some(10),
            ..tokens(100, 50)
        };
        let json = serde_json::to_string(&usage).unwrap();
        let decoded: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(usage, decoded);
    }

    // --- Cost construction and serde ---

    #[test]
    fn test_cost_new_enforces_invariant() {
        let cost = micros(1_000_000, 500_000);
        assert_eq!(cost.input_microdollars(), 1_000_000);
        assert_eq!(cost.output_microdollars(), 500_000);
        assert_eq!(cost.total_microdollars(), 1_500_000);
    }

    #[test]
    fn test_cost_new_overflow_returns_none() {
        assert!(Cost::new(u64::MAX, 1).is_none());
    }

    #[test]
    fn test_cost_total_usd_exact() {
        let dollars = micros(1_000_000, 500_000).total_usd();
        assert!((dollars - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cost_total_usd_zero() {
        let dollars = micros(0, 0).total_usd();
        assert!(dollars.abs() < f64::EPSILON);
    }

    #[test]
    fn test_cost_total_usd_sub_cent() {
        let dollars = micros(300, 200).total_usd();
        assert!((dollars - 0.0005).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cost_clone_eq() {
        let cost = micros(42, 58);
        assert_eq!(cost, cost.clone());
    }

    #[test]
    fn test_cost_serde_roundtrip() {
        let cost = micros(1_000_000, 500_000);
        let json = serde_json::to_string(&cost).unwrap();
        let decoded: Cost = serde_json::from_str(&json).unwrap();
        assert_eq!(cost, decoded);
    }

    #[test]
    fn test_cost_deserialization_recomputes_total() {
        // A wrong `total` in the JSON is ignored; the total is recomputed.
        let json = r#"{"input":100,"output":200,"total":999}"#;
        let cost: Cost = serde_json::from_str(json).unwrap();
        assert_eq!(cost.total_microdollars(), 300);
    }

    #[test]
    fn test_cost_deserialization_without_total() {
        let json = r#"{"input":100,"output":200}"#;
        let cost: Cost = serde_json::from_str(json).unwrap();
        assert_eq!(cost.total_microdollars(), 300);
    }

    #[test]
    fn test_cost_deserialization_overflow_fails() {
        let json = format!(r#"{{"input":{},"output":1}}"#, u64::MAX);
        assert!(serde_json::from_str::<Cost>(&json).is_err());
    }

    #[test]
    fn test_cost_default_is_zero() {
        let cost = Cost::default();
        assert_eq!(cost.input_microdollars(), 0);
        assert_eq!(cost.output_microdollars(), 0);
        assert_eq!(cost.total_microdollars(), 0);
    }

    // --- Cost Display ---

    #[test]
    fn test_cost_display() {
        assert_eq!(micros(1_000_000, 500_000).to_string(), "$1.50");
    }

    #[test]
    fn test_cost_display_zero() {
        assert_eq!(Cost::default().to_string(), "$0.00");
    }

    #[test]
    fn test_cost_display_sub_cent() {
        assert_eq!(micros(500, 0).to_string(), "$0.00");
    }

    // --- Usage Add/AddAssign ---

    #[test]
    fn test_usage_add_basic() {
        let lhs = Usage {
            reasoning_tokens: Some(10),
            cache_write_tokens: Some(20),
            ..tokens(100, 50)
        };
        let rhs = Usage {
            reasoning_tokens: Some(5),
            cache_read_tokens: Some(50),
            ..tokens(200, 30)
        };
        let sum = lhs + rhs;
        assert_eq!(sum.input_tokens, 300);
        assert_eq!(sum.output_tokens, 80);
        assert_eq!(sum.reasoning_tokens, Some(15));
        assert_eq!(sum.cache_read_tokens, Some(50));
        assert_eq!(sum.cache_write_tokens, Some(20));
    }

    #[test]
    fn test_usage_add_both_none() {
        let sum = Usage::default() + Usage::default();
        assert_eq!(sum.reasoning_tokens, None);
        assert_eq!(sum.cache_read_tokens, None);
        assert_eq!(sum.cache_write_tokens, None);
    }

    #[test]
    fn test_usage_add_assign() {
        let mut acc = tokens(100, 50);
        acc += tokens(200, 30);
        assert_eq!(acc.input_tokens, 300);
        assert_eq!(acc.output_tokens, 80);
    }

    #[test]
    fn test_usage_add_saturates() {
        let sum = tokens(u64::MAX, 0) + tokens(1, 0);
        assert_eq!(sum.input_tokens, u64::MAX);
    }

    // --- Cost Add/AddAssign/checked_add ---

    #[test]
    fn test_cost_add_basic() {
        let sum = micros(100, 200) + micros(300, 400);
        assert_eq!(sum.input_microdollars(), 400);
        assert_eq!(sum.output_microdollars(), 600);
        assert_eq!(sum.total_microdollars(), 1000);
    }

    #[test]
    fn test_cost_add_assign() {
        let mut acc = micros(100, 200);
        acc += micros(50, 50);
        assert_eq!(acc.input_microdollars(), 150);
        assert_eq!(acc.output_microdollars(), 250);
        assert_eq!(acc.total_microdollars(), 400);
    }

    #[test]
    fn test_cost_checked_add() {
        let sum = micros(100, 200).checked_add(&micros(300, 400)).unwrap();
        assert_eq!(sum.total_microdollars(), 1000);
    }

    #[test]
    fn test_cost_checked_add_overflow() {
        let nearly_max = micros(u64::MAX - 1, 0);
        assert!(nearly_max.checked_add(&micros(2, 0)).is_none());
    }

    #[test]
    fn test_cost_add_saturates() {
        let sum = micros(u64::MAX - 1, 0) + micros(2, 0);
        assert_eq!(sum.input_microdollars(), u64::MAX);
    }

    // --- UsageTracker ---

    #[test]
    fn test_usage_tracker_new() {
        let tracker = UsageTracker::new();
        assert_eq!(tracker.total().input_tokens, 0);
        assert_eq!(tracker.total().output_tokens, 0);
        assert!(tracker.calls().is_empty());
        assert_eq!(tracker.context_limit(), None);
    }

    #[test]
    fn test_usage_tracker_default() {
        let tracker = UsageTracker::default();
        assert_eq!(tracker.call_count(), 0);
        assert_eq!(tracker.context_limit(), None);
    }

    #[test]
    fn test_usage_tracker_with_context_limit() {
        let tracker = UsageTracker::with_context_limit(128_000);
        assert_eq!(tracker.context_limit(), Some(128_000));
    }

    #[test]
    fn test_usage_tracker_record() {
        let mut tracker = UsageTracker::new();
        tracker.record(tokens(100, 50));
        tracker.record(tokens(200, 100));

        assert_eq!(tracker.total().input_tokens, 300);
        assert_eq!(tracker.total().output_tokens, 150);
        assert_eq!(tracker.call_count(), 2);
        assert_eq!(tracker.calls()[0].input_tokens, 100);
        assert_eq!(tracker.calls()[1].input_tokens, 200);
    }

    #[test]
    fn test_usage_tracker_context_utilization() {
        let mut tracker = UsageTracker::with_context_limit(100_000);
        tracker.record(tokens(50_000, 1000));

        let utilization = tracker.context_utilization().unwrap();
        assert!((utilization - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_usage_tracker_context_utilization_no_limit() {
        assert!(UsageTracker::new().context_utilization().is_none());
    }

    #[test]
    fn test_usage_tracker_context_utilization_zero_limit() {
        let utilization = UsageTracker::with_context_limit(0)
            .context_utilization()
            .unwrap();
        assert!(utilization.abs() < f64::EPSILON);
    }

    #[test]
    fn test_usage_tracker_is_near_limit() {
        let mut tracker = UsageTracker::with_context_limit(100_000);
        tracker.record(tokens(85_000, 1000));

        assert!(tracker.is_near_limit(0.8)); // 85% >= 80%
        assert!(tracker.is_near_limit(0.85)); // 85% >= 85%
        assert!(!tracker.is_near_limit(0.9)); // 85% < 90%
    }

    #[test]
    fn test_usage_tracker_is_near_limit_no_limit() {
        assert!(!UsageTracker::new().is_near_limit(0.8));
    }

    #[test]
    fn test_usage_tracker_set_context_limit() {
        let mut tracker = UsageTracker::new();
        assert_eq!(tracker.context_limit(), None);

        tracker.set_context_limit(200_000);
        assert_eq!(tracker.context_limit(), Some(200_000));
    }

    #[test]
    fn test_usage_tracker_reset() {
        let mut tracker = UsageTracker::with_context_limit(100_000);
        tracker.record(tokens(1000, 500));
        assert_eq!(tracker.call_count(), 1);
        assert_eq!(tracker.total().input_tokens, 1000);

        tracker.reset();
        assert_eq!(tracker.call_count(), 0);
        assert_eq!(tracker.total().input_tokens, 0);
        // The context limit survives a reset.
        assert_eq!(tracker.context_limit(), Some(100_000));
    }

    #[test]
    fn test_usage_tracker_clone() {
        let mut tracker = UsageTracker::with_context_limit(50_000);
        tracker.record(tokens(100, 50));

        let copy = tracker.clone();
        assert_eq!(copy.total().input_tokens, 100);
        assert_eq!(copy.call_count(), 1);
        assert_eq!(copy.context_limit(), Some(50_000));
    }

    // --- ModelPricing ---

    #[test]
    fn test_model_pricing_compute_cost() {
        // 1 MTok input at $3 / MTok, 0.1 MTok output at $15 / MTok.
        let cost = standard_pricing(None)
            .compute_cost(&tokens(1_000_000, 100_000))
            .unwrap();
        assert_eq!(cost.input_microdollars(), 3_000_000); // $3
        assert_eq!(cost.output_microdollars(), 1_500_000); // $1.50
        assert_eq!(cost.total_microdollars(), 4_500_000); // $4.50
    }

    #[test]
    fn test_model_pricing_with_cache_tokens() {
        let usage = Usage {
            cache_read_tokens: Some(500_000), // 0.5 MTok from cache
            ..tokens(500_000, 100_000)
        };

        // Cache reads priced at $0.30 / MTok.
        let cost = standard_pricing(Some(300_000))
            .compute_cost(&usage)
            .unwrap();
        // Input: 500k * $3/MTok = $1.50 = 1_500_000
        // Cache: 500k * $0.30/MTok = $0.15 = 150_000
        // Total input side: $1.65 = 1_650_000
        // Output: 100k * $15/MTok = $1.50 = 1_500_000
        assert_eq!(cost.input_microdollars(), 1_650_000);
        assert_eq!(cost.output_microdollars(), 1_500_000);
    }

    #[test]
    fn test_model_pricing_zero_tokens() {
        let cost = standard_pricing(None)
            .compute_cost(&Usage::default())
            .unwrap();
        assert_eq!(cost.total_microdollars(), 0);
    }

    #[test]
    fn test_model_pricing_cache_without_pricing() {
        // Cache tokens are reported but no cache rate is configured.
        let usage = Usage {
            cache_read_tokens: Some(500_000),
            ..tokens(1_000_000, 100_000)
        };

        let cost = standard_pricing(None).compute_cost(&usage).unwrap();
        // Without a rate, the cache tokens contribute nothing.
        assert_eq!(cost.input_microdollars(), 3_000_000);
    }

    #[test]
    fn test_usage_tracker_cost() {
        let mut tracker = UsageTracker::new();
        tracker.record(tokens(1_000_000, 100_000));

        let cost = tracker.cost(&standard_pricing(None)).unwrap();
        assert_eq!(cost.total_microdollars(), 4_500_000);
    }

    #[test]
    fn test_model_pricing_clone_eq() {
        let original = ModelPricing {
            input_per_million: 100,
            output_per_million: 200,
            cache_read_per_million: Some(50),
        };
        let copy = original.clone();
        assert_eq!(original, copy);
    }

    #[test]
    fn test_compute_token_cost_large_values() {
        // Large but plausible: 10B tokens * $3/MTok = $30,000
        // = 30_000_000_000 microdollars.
        let cost = compute_token_cost(10_000_000_000, 3_000_000);
        assert_eq!(cost, Some(30_000_000_000));
    }
}