Skip to main content

llm_stack/
usage.rs

//! Token usage and cost tracking.
//!
//! Every response carries a [`Usage`] record counting input and output
//! tokens, with optional fields for reasoning and cache tokens when the
//! provider reports them.
//!
//! [`Cost`] tracks monetary cost in **microdollars** (1 USD = 1,000,000
//! microdollars). Integer arithmetic avoids floating-point rounding
//! issues when aggregating costs across many requests. Use
//! [`total_usd`](Cost::total_usd) for display purposes.
//!
//! # Invariant
//!
//! `Cost` enforces `total == input + output` at construction time.
//! The fields are private — use [`Cost::new`] to build one, and the
//! accessor methods to read values. Deserialization recomputes the
//! total from `input` and `output`, ignoring any `total` in the JSON.
//! The saturating `+`/`+=` operators clamp at `u64::MAX`; use
//! [`Cost::checked_add`] when overflow must be detected instead.
18
19use std::fmt;
20use std::ops::{Add, AddAssign};
21
22use serde::{Deserialize, Serialize};
23
24/// Token counts for a single request/response pair.
25#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
26pub struct Usage {
27    /// Tokens consumed by the prompt (messages + system + tool defs).
28    pub input_tokens: u64,
29    /// Tokens produced by the model's response.
30    pub output_tokens: u64,
31    /// Tokens used for chain-of-thought reasoning, if applicable.
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub reasoning_tokens: Option<u64>,
34    /// Tokens served from the provider's prompt cache (reducing cost).
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub cache_read_tokens: Option<u64>,
37    /// Tokens written into the provider's prompt cache for future reuse.
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub cache_write_tokens: Option<u64>,
40}
41
42impl Usage {
43    /// Total tokens across all categories.
44    ///
45    /// Sums `input_tokens`, `output_tokens`, and all optional fields
46    /// (`reasoning_tokens`, `cache_read_tokens`, `cache_write_tokens`),
47    /// treating `None` as zero. Uses saturating arithmetic.
48    pub fn total_tokens(&self) -> u64 {
49        self.input_tokens
50            .saturating_add(self.output_tokens)
51            .saturating_add(self.reasoning_tokens.unwrap_or(0))
52            .saturating_add(self.cache_read_tokens.unwrap_or(0))
53            .saturating_add(self.cache_write_tokens.unwrap_or(0))
54    }
55}
56
/// Helper: combines two optional counters, treating a missing side as zero.
///
/// Yields `None` only when neither side is present; otherwise the present
/// values are summed with saturating arithmetic.
fn add_optional(a: Option<u64>, b: Option<u64>) -> Option<u64> {
    if a.is_none() && b.is_none() {
        return None;
    }
    let lhs = a.unwrap_or(0);
    let rhs = b.unwrap_or(0);
    Some(lhs.saturating_add(rhs))
}
65
66impl Add for Usage {
67    type Output = Self;
68
69    /// Adds two `Usage` records field-by-field.
70    ///
71    /// Mandatory fields use saturating addition. Optional fields are
72    /// summed when both are `Some`, preserved when one is `Some`, and
73    /// remain `None` when both are `None`.
74    fn add(self, rhs: Self) -> Self {
75        Self {
76            input_tokens: self.input_tokens.saturating_add(rhs.input_tokens),
77            output_tokens: self.output_tokens.saturating_add(rhs.output_tokens),
78            reasoning_tokens: add_optional(self.reasoning_tokens, rhs.reasoning_tokens),
79            cache_read_tokens: add_optional(self.cache_read_tokens, rhs.cache_read_tokens),
80            cache_write_tokens: add_optional(self.cache_write_tokens, rhs.cache_write_tokens),
81        }
82    }
83}
84
85impl AddAssign for Usage {
86    fn add_assign(&mut self, rhs: Self) {
87        *self += &rhs;
88    }
89}
90
91impl AddAssign<&Usage> for Usage {
92    /// Adds another `Usage` to this one in-place without cloning.
93    ///
94    /// This is more efficient than `AddAssign<Usage>` when you have a reference.
95    fn add_assign(&mut self, rhs: &Self) {
96        self.input_tokens = self.input_tokens.saturating_add(rhs.input_tokens);
97        self.output_tokens = self.output_tokens.saturating_add(rhs.output_tokens);
98        self.reasoning_tokens = add_optional(self.reasoning_tokens, rhs.reasoning_tokens);
99        self.cache_read_tokens = add_optional(self.cache_read_tokens, rhs.cache_read_tokens);
100        self.cache_write_tokens = add_optional(self.cache_write_tokens, rhs.cache_write_tokens);
101    }
102}
103
104/// Monetary cost in microdollars (1 USD = 1,000,000 microdollars).
105///
106/// Uses integer arithmetic to avoid floating-point accumulation errors.
107/// The invariant `total == input + output` is enforced by the
108/// constructor and maintained through deserialization.
109///
110/// # Examples
111///
112/// ```rust
113/// use llm_stack::usage::Cost;
114///
115/// let cost = Cost::new(300_000, 150_000).expect("no overflow");
116/// assert_eq!(cost.total_microdollars(), 450_000);
117/// assert!((cost.total_usd() - 0.45).abs() < f64::EPSILON);
118/// ```
119#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
120pub struct Cost {
121    input: u64,
122    output: u64,
123    total: u64,
124}
125
126impl Default for Cost {
127    /// Returns a zero cost.
128    fn default() -> Self {
129        Self {
130            input: 0,
131            output: 0,
132            total: 0,
133        }
134    }
135}
136
137/// Intermediate type for safe deserialization — recomputes total.
138#[derive(Deserialize)]
139struct CostRaw {
140    input: u64,
141    output: u64,
142}
143
144impl<'de> Deserialize<'de> for Cost {
145    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
146    where
147        D: serde::Deserializer<'de>,
148    {
149        let raw = CostRaw::deserialize(deserializer)?;
150        let total = raw
151            .input
152            .checked_add(raw.output)
153            .ok_or_else(|| serde::de::Error::custom("cost overflow: input + output exceeds u64"))?;
154        Ok(Self {
155            input: raw.input,
156            output: raw.output,
157            total,
158        })
159    }
160}
161
162impl Cost {
163    /// Creates a new `Cost`, returning `None` if `input + output`
164    /// would overflow `u64`.
165    pub fn new(input: u64, output: u64) -> Option<Self> {
166        let total = input.checked_add(output)?;
167        Some(Self {
168            input,
169            output,
170            total,
171        })
172    }
173
174    /// Cost of the input (prompt) in microdollars.
175    pub fn input_microdollars(&self) -> u64 {
176        self.input
177    }
178
179    /// Cost of the output (completion) in microdollars.
180    pub fn output_microdollars(&self) -> u64 {
181        self.output
182    }
183
184    /// Total cost (`input + output`) in microdollars.
185    pub fn total_microdollars(&self) -> u64 {
186        self.total
187    }
188
189    /// Returns the sum of two costs, or `None` on overflow.
190    pub fn checked_add(&self, rhs: &Self) -> Option<Self> {
191        let input = self.input.checked_add(rhs.input)?;
192        let output = self.output.checked_add(rhs.output)?;
193        Self::new(input, output)
194    }
195
196    /// Total cost in US dollars, for display purposes.
197    ///
198    /// Uses floating-point division — prefer
199    /// [`total_microdollars`](Self::total_microdollars) for arithmetic.
200    #[allow(clippy::cast_precision_loss)] // microdollar u64 fits f64 mantissa in practice
201    pub fn total_usd(&self) -> f64 {
202        self.total as f64 / 1_000_000.0
203    }
204}
205
206impl fmt::Display for Cost {
207    /// Formats the cost as a USD string, e.g. `$1.50`.
208    #[allow(clippy::cast_precision_loss)]
209    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210        write!(f, "${:.2}", self.total as f64 / 1_000_000.0)
211    }
212}
213
214impl Add for Cost {
215    type Output = Self;
216
217    /// Adds two costs using saturating arithmetic.
218    ///
219    /// Use [`checked_add`](Self::checked_add) when overflow must be detected.
220    fn add(self, rhs: Self) -> Self {
221        let input = self.input.saturating_add(rhs.input);
222        let output = self.output.saturating_add(rhs.output);
223        Self {
224            input,
225            output,
226            total: input.saturating_add(output),
227        }
228    }
229}
230
231impl AddAssign for Cost {
232    fn add_assign(&mut self, rhs: Self) {
233        self.input = self.input.saturating_add(rhs.input);
234        self.output = self.output.saturating_add(rhs.output);
235        self.total = self.input.saturating_add(self.output);
236    }
237}
238
239// ── UsageTracker ────────────────────────────────────────────────────
240
241/// Tracks cumulative token usage across multiple LLM calls.
242///
243/// `UsageTracker` accumulates [`Usage`] records from each request and
244/// provides context-awareness features for detecting when the conversation
245/// is approaching the model's context limit.
246///
247/// # Example
248///
249/// ```rust
250/// use llm_stack::usage::{Usage, UsageTracker};
251///
252/// let mut tracker = UsageTracker::with_context_limit(128_000);
253///
254/// // Record usage from each LLM call
255/// tracker.record(Usage {
256///     input_tokens: 1000,
257///     output_tokens: 500,
258///     ..Default::default()
259/// });
260///
261/// assert_eq!(tracker.total().input_tokens, 1000);
262/// assert!(!tracker.is_near_limit(0.8)); // Not near 80% yet
263/// ```
264///
265/// # Use Cases
266///
267/// - **Billing/cost tracking**: Aggregate costs across a session
268/// - **Budget alerts**: Warn when approaching token limits
269/// - **Compaction triggers**: Signal when context window is nearly full
270/// - **Token debugging**: Analyze per-call consumption patterns
271#[derive(Debug, Clone)]
272pub struct UsageTracker {
273    /// Accumulated usage across all calls.
274    total: Usage,
275    /// Usage from each individual call, in order.
276    by_call: Vec<Usage>,
277    /// Optional context window limit for utilization calculations.
278    context_limit: Option<u64>,
279}
280
281impl Default for UsageTracker {
282    fn default() -> Self {
283        Self::new()
284    }
285}
286
287impl UsageTracker {
288    /// Creates a new tracker with no context limit.
289    pub fn new() -> Self {
290        Self {
291            total: Usage::default(),
292            by_call: Vec::new(),
293            context_limit: None,
294        }
295    }
296
297    /// Creates a tracker with a known context window limit.
298    ///
299    /// The limit is used for [`context_utilization`](Self::context_utilization)
300    /// and [`is_near_limit`](Self::is_near_limit) calculations.
301    pub fn with_context_limit(limit: u64) -> Self {
302        Self {
303            total: Usage::default(),
304            by_call: Vec::new(),
305            context_limit: Some(limit),
306        }
307    }
308
309    /// Records a usage sample from an LLM call.
310    ///
311    /// The usage is added to the running total and stored for per-call
312    /// analysis.
313    pub fn record(&mut self, usage: Usage) {
314        self.total += &usage;
315        self.by_call.push(usage);
316    }
317
318    /// Returns the accumulated usage across all recorded calls.
319    pub fn total(&self) -> &Usage {
320        &self.total
321    }
322
323    /// Returns the usage from each individual call, in order.
324    pub fn calls(&self) -> &[Usage] {
325        &self.by_call
326    }
327
328    /// Returns the number of calls recorded.
329    pub fn call_count(&self) -> usize {
330        self.by_call.len()
331    }
332
333    /// Returns the context limit, if set.
334    pub fn context_limit(&self) -> Option<u64> {
335        self.context_limit
336    }
337
338    /// Sets or updates the context limit.
339    ///
340    /// Useful when the model is determined after tracker creation.
341    pub fn set_context_limit(&mut self, limit: u64) {
342        self.context_limit = Some(limit);
343    }
344
345    /// Returns the context utilization as a ratio (0.0 to 1.0+).
346    ///
347    /// Utilization is calculated as `total_input_tokens / context_limit`.
348    /// Returns `None` if no context limit is set.
349    ///
350    /// # Note
351    ///
352    /// The value can exceed 1.0 if the total exceeds the limit (which
353    /// shouldn't happen in practice but is not enforced).
354    #[allow(clippy::cast_precision_loss)] // u64 token counts fit f64 mantissa
355    pub fn context_utilization(&self) -> Option<f64> {
356        self.context_limit.map(|limit| {
357            if limit == 0 {
358                return 0.0;
359            }
360            self.total.input_tokens as f64 / limit as f64
361        })
362    }
363
364    /// Checks if the context utilization is at or above the given threshold.
365    ///
366    /// Returns `false` if no context limit is set.
367    ///
368    /// # Example
369    ///
370    /// ```rust
371    /// use llm_stack::usage::{Usage, UsageTracker};
372    ///
373    /// let mut tracker = UsageTracker::with_context_limit(100_000);
374    /// tracker.record(Usage {
375    ///     input_tokens: 85_000,
376    ///     output_tokens: 1000,
377    ///     ..Default::default()
378    /// });
379    ///
380    /// assert!(tracker.is_near_limit(0.8));   // 85% >= 80%
381    /// assert!(!tracker.is_near_limit(0.9));  // 85% < 90%
382    /// ```
383    pub fn is_near_limit(&self, threshold: f64) -> bool {
384        self.context_utilization()
385            .is_some_and(|util| util >= threshold)
386    }
387
388    /// Computes the cost of all recorded usage given a pricing table.
389    ///
390    /// Uses the pricing rates (per-million tokens) to calculate cost.
391    /// Returns `None` if the cost would overflow.
392    pub fn cost(&self, pricing: &ModelPricing) -> Option<Cost> {
393        pricing.compute_cost(&self.total)
394    }
395
396    /// Resets the tracker, clearing all recorded usage.
397    pub fn reset(&mut self) {
398        self.total = Usage::default();
399        self.by_call.clear();
400    }
401}
402
/// Pricing information for a specific model.
///
/// All prices are in **microdollars per million tokens**. For example,
/// a price of $3.00 per million input tokens is written `3_000_000`.
///
/// Only input, output and (optionally) cache-read tokens have rates here.
/// [`Usage::cache_write_tokens`] and [`Usage::reasoning_tokens`] are not
/// priced by [`compute_cost`](Self::compute_cost).
///
/// # Example
///
/// ```rust
/// use llm_stack::usage::{ModelPricing, Usage};
///
/// // Illustrative rates (Claude 3.5 Sonnet list pricing, 2024)
/// let pricing = ModelPricing {
///     input_per_million: 3_000_000,   // $3.00 / MTok
///     output_per_million: 15_000_000, // $15.00 / MTok
///     cache_read_per_million: Some(300_000), // $0.30 / MTok
/// };
///
/// let usage = Usage {
///     input_tokens: 1_000_000,
///     output_tokens: 100_000,
///     ..Default::default()
/// };
///
/// let cost = pricing.compute_cost(&usage).unwrap();
/// assert_eq!(cost.total_microdollars(), 4_500_000); // $3 input + $1.50 output
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelPricing {
    /// Cost per million input tokens, in microdollars.
    pub input_per_million: u64,
    /// Cost per million output tokens, in microdollars.
    pub output_per_million: u64,
    /// Cost per million cache-read tokens, in microdollars.
    ///
    /// When `None`, cache-read tokens add nothing to the computed cost.
    pub cache_read_per_million: Option<u64>,
}
438
439impl ModelPricing {
440    /// Computes the cost for the given usage.
441    ///
442    /// Returns `None` if the calculation would overflow.
443    pub fn compute_cost(&self, usage: &Usage) -> Option<Cost> {
444        // Cost = (tokens * price_per_million) / 1_000_000
445        // To avoid precision loss, we use u128 for intermediate calculations
446        let input_cost = compute_token_cost(usage.input_tokens, self.input_per_million)?;
447        let output_cost = compute_token_cost(usage.output_tokens, self.output_per_million)?;
448
449        // If cache read tokens are present and pricing is set, include them
450        let cache_cost = match (usage.cache_read_tokens, self.cache_read_per_million) {
451            (Some(tokens), Some(rate)) => compute_token_cost(tokens, rate)?,
452            _ => 0,
453        };
454
455        // Combine costs (cache reads reduce effective input cost conceptually,
456        // but for billing they're additive at the cache rate)
457        let total_input = input_cost.checked_add(cache_cost)?;
458        Cost::new(total_input, output_cost)
459    }
460}
461
/// Prices `tokens` at `per_million` microdollars per million tokens.
///
/// The multiplication is done in `u128`, where the product of two `u64`
/// values cannot overflow. The division by one million truncates toward
/// zero. Returns the cost in microdollars, or `None` if the result does
/// not fit in a `u64`.
fn compute_token_cost(tokens: u64, per_million: u64) -> Option<u64> {
    const TOKENS_PER_MILLION: u128 = 1_000_000;
    let wide = u128::from(tokens) * u128::from(per_million) / TOKENS_PER_MILLION;
    u64::try_from(wide).ok()
}
472
473#[cfg(test)]
474mod tests {
475    use super::*;
476
477    #[test]
478    fn test_total_tokens_all_fields() {
479        let u = Usage {
480            input_tokens: 100,
481            output_tokens: 50,
482            reasoning_tokens: Some(30),
483            cache_read_tokens: Some(20),
484            cache_write_tokens: Some(10),
485        };
486        assert_eq!(u.total_tokens(), 210);
487    }
488
489    #[test]
490    fn test_total_tokens_none_fields() {
491        let u = Usage {
492            input_tokens: 100,
493            output_tokens: 50,
494            reasoning_tokens: None,
495            cache_read_tokens: None,
496            cache_write_tokens: None,
497        };
498        assert_eq!(u.total_tokens(), 150);
499    }
500
501    #[test]
502    fn test_total_tokens_default() {
503        assert_eq!(Usage::default().total_tokens(), 0);
504    }
505
506    #[test]
507    fn test_usage_clone_eq() {
508        let u = Usage {
509            input_tokens: 100,
510            output_tokens: 50,
511            reasoning_tokens: Some(10),
512            cache_read_tokens: None,
513            cache_write_tokens: None,
514        };
515        assert_eq!(u, u.clone());
516    }
517
518    #[test]
519    fn test_usage_debug_format() {
520        let u = Usage::default();
521        let debug = format!("{u:?}");
522        assert!(debug.contains("input_tokens"));
523        assert!(debug.contains("output_tokens"));
524    }
525
526    #[test]
527    fn test_usage_optional_fields_none() {
528        let u = Usage::default();
529        assert_eq!(u.reasoning_tokens, None);
530        assert_eq!(u.cache_read_tokens, None);
531        assert_eq!(u.cache_write_tokens, None);
532    }
533
534    #[test]
535    fn test_usage_optional_fields_some() {
536        let u = Usage {
537            input_tokens: 0,
538            output_tokens: 0,
539            reasoning_tokens: Some(500),
540            cache_read_tokens: Some(200),
541            cache_write_tokens: Some(100),
542        };
543        assert_eq!(u.reasoning_tokens, Some(500));
544        assert_eq!(u.cache_read_tokens, Some(200));
545        assert_eq!(u.cache_write_tokens, Some(100));
546    }
547
548    #[test]
549    fn test_usage_serde_roundtrip() {
550        let u = Usage {
551            input_tokens: 100,
552            output_tokens: 50,
553            reasoning_tokens: Some(10),
554            cache_read_tokens: None,
555            cache_write_tokens: None,
556        };
557        let json = serde_json::to_string(&u).unwrap();
558        let back: Usage = serde_json::from_str(&json).unwrap();
559        assert_eq!(u, back);
560    }
561
562    #[test]
563    fn test_usage_skip_serializing_none_fields() {
564        let u = Usage {
565            input_tokens: 100,
566            output_tokens: 50,
567            reasoning_tokens: None,
568            cache_read_tokens: None,
569            cache_write_tokens: None,
570        };
571        let json = serde_json::to_string(&u).unwrap();
572        // None fields should not appear in JSON
573        assert!(!json.contains("reasoning_tokens"));
574        assert!(!json.contains("cache_read_tokens"));
575        assert!(!json.contains("cache_write_tokens"));
576        // Required fields should be present
577        assert!(json.contains("input_tokens"));
578        assert!(json.contains("output_tokens"));
579    }
580
581    #[test]
582    fn test_usage_serializes_some_fields() {
583        let u = Usage {
584            input_tokens: 100,
585            output_tokens: 50,
586            reasoning_tokens: Some(10),
587            cache_read_tokens: Some(20),
588            cache_write_tokens: Some(5),
589        };
590        let json = serde_json::to_string(&u).unwrap();
591        // All fields should appear when Some
592        assert!(json.contains("reasoning_tokens"));
593        assert!(json.contains("cache_read_tokens"));
594        assert!(json.contains("cache_write_tokens"));
595    }
596
597    #[test]
598    fn test_cost_new_enforces_invariant() {
599        let c = Cost::new(1_000_000, 500_000).unwrap();
600        assert_eq!(c.input_microdollars(), 1_000_000);
601        assert_eq!(c.output_microdollars(), 500_000);
602        assert_eq!(c.total_microdollars(), 1_500_000);
603    }
604
605    #[test]
606    fn test_cost_new_overflow_returns_none() {
607        assert!(Cost::new(u64::MAX, 1).is_none());
608    }
609
610    #[test]
611    fn test_cost_total_usd_exact() {
612        let c = Cost::new(1_000_000, 500_000).unwrap();
613        assert!((c.total_usd() - 1.5).abs() < f64::EPSILON);
614    }
615
616    #[test]
617    fn test_cost_total_usd_zero() {
618        let c = Cost::new(0, 0).unwrap();
619        assert!((c.total_usd()).abs() < f64::EPSILON);
620    }
621
622    #[test]
623    fn test_cost_total_usd_sub_cent() {
624        let c = Cost::new(300, 200).unwrap();
625        assert!((c.total_usd() - 0.0005).abs() < f64::EPSILON);
626    }
627
628    #[test]
629    fn test_cost_clone_eq() {
630        let c = Cost::new(42, 58).unwrap();
631        assert_eq!(c, c.clone());
632    }
633
634    #[test]
635    fn test_cost_serde_roundtrip() {
636        let c = Cost::new(1_000_000, 500_000).unwrap();
637        let json = serde_json::to_string(&c).unwrap();
638        let back: Cost = serde_json::from_str(&json).unwrap();
639        assert_eq!(c, back);
640    }
641
642    #[test]
643    fn test_cost_deserialization_recomputes_total() {
644        // Even if JSON has a wrong total, deserialization recomputes it
645        let json = r#"{"input":100,"output":200,"total":999}"#;
646        let c: Cost = serde_json::from_str(json).unwrap();
647        assert_eq!(c.total_microdollars(), 300);
648    }
649
650    #[test]
651    fn test_cost_deserialization_without_total() {
652        let json = r#"{"input":100,"output":200}"#;
653        let c: Cost = serde_json::from_str(json).unwrap();
654        assert_eq!(c.total_microdollars(), 300);
655    }
656
657    #[test]
658    fn test_cost_deserialization_overflow_fails() {
659        let json = format!(r#"{{"input":{},"output":1}}"#, u64::MAX);
660        let result: Result<Cost, _> = serde_json::from_str(&json);
661        assert!(result.is_err());
662    }
663
664    #[test]
665    fn test_cost_default_is_zero() {
666        let c = Cost::default();
667        assert_eq!(c.input_microdollars(), 0);
668        assert_eq!(c.output_microdollars(), 0);
669        assert_eq!(c.total_microdollars(), 0);
670    }
671
672    // --- Cost Display ---
673
674    #[test]
675    fn test_cost_display() {
676        let c = Cost::new(1_000_000, 500_000).unwrap();
677        assert_eq!(c.to_string(), "$1.50");
678    }
679
680    #[test]
681    fn test_cost_display_zero() {
682        assert_eq!(Cost::default().to_string(), "$0.00");
683    }
684
685    #[test]
686    fn test_cost_display_sub_cent() {
687        let c = Cost::new(500, 0).unwrap();
688        assert_eq!(c.to_string(), "$0.00");
689    }
690
691    // --- Usage Add/AddAssign ---
692
693    #[test]
694    fn test_usage_add_basic() {
695        let a = Usage {
696            input_tokens: 100,
697            output_tokens: 50,
698            reasoning_tokens: Some(10),
699            cache_read_tokens: None,
700            cache_write_tokens: Some(20),
701        };
702        let b = Usage {
703            input_tokens: 200,
704            output_tokens: 30,
705            reasoning_tokens: Some(5),
706            cache_read_tokens: Some(50),
707            cache_write_tokens: None,
708        };
709        let sum = a + b;
710        assert_eq!(sum.input_tokens, 300);
711        assert_eq!(sum.output_tokens, 80);
712        assert_eq!(sum.reasoning_tokens, Some(15));
713        assert_eq!(sum.cache_read_tokens, Some(50));
714        assert_eq!(sum.cache_write_tokens, Some(20));
715    }
716
717    #[test]
718    fn test_usage_add_both_none() {
719        let a = Usage::default();
720        let b = Usage::default();
721        let sum = a + b;
722        assert_eq!(sum.reasoning_tokens, None);
723        assert_eq!(sum.cache_read_tokens, None);
724        assert_eq!(sum.cache_write_tokens, None);
725    }
726
727    #[test]
728    fn test_usage_add_assign() {
729        let mut a = Usage {
730            input_tokens: 100,
731            output_tokens: 50,
732            ..Default::default()
733        };
734        a += Usage {
735            input_tokens: 200,
736            output_tokens: 30,
737            ..Default::default()
738        };
739        assert_eq!(a.input_tokens, 300);
740        assert_eq!(a.output_tokens, 80);
741    }
742
743    #[test]
744    fn test_usage_add_saturates() {
745        let a = Usage {
746            input_tokens: u64::MAX,
747            output_tokens: 0,
748            ..Default::default()
749        };
750        let b = Usage {
751            input_tokens: 1,
752            output_tokens: 0,
753            ..Default::default()
754        };
755        let sum = a + b;
756        assert_eq!(sum.input_tokens, u64::MAX);
757    }
758
759    // --- Cost Add/AddAssign/checked_add ---
760
761    #[test]
762    fn test_cost_add_basic() {
763        let a = Cost::new(100, 200).unwrap();
764        let b = Cost::new(300, 400).unwrap();
765        let sum = a + b;
766        assert_eq!(sum.input_microdollars(), 400);
767        assert_eq!(sum.output_microdollars(), 600);
768        assert_eq!(sum.total_microdollars(), 1000);
769    }
770
771    #[test]
772    fn test_cost_add_assign() {
773        let mut c = Cost::new(100, 200).unwrap();
774        c += Cost::new(50, 50).unwrap();
775        assert_eq!(c.input_microdollars(), 150);
776        assert_eq!(c.output_microdollars(), 250);
777        assert_eq!(c.total_microdollars(), 400);
778    }
779
780    #[test]
781    fn test_cost_checked_add() {
782        let a = Cost::new(100, 200).unwrap();
783        let b = Cost::new(300, 400).unwrap();
784        let sum = a.checked_add(&b).unwrap();
785        assert_eq!(sum.total_microdollars(), 1000);
786    }
787
788    #[test]
789    fn test_cost_checked_add_overflow() {
790        let a = Cost::new(u64::MAX - 1, 0).unwrap();
791        let b = Cost::new(2, 0).unwrap();
792        assert!(a.checked_add(&b).is_none());
793    }
794
795    #[test]
796    fn test_cost_add_saturates() {
797        let a = Cost::new(u64::MAX - 1, 0).unwrap();
798        let b = Cost::new(2, 0).unwrap();
799        let sum = a + b;
800        assert_eq!(sum.input_microdollars(), u64::MAX);
801    }
802
803    // --- UsageTracker ---
804
805    #[test]
806    fn test_usage_tracker_new() {
807        let tracker = UsageTracker::new();
808        assert_eq!(tracker.total().input_tokens, 0);
809        assert_eq!(tracker.total().output_tokens, 0);
810        assert!(tracker.calls().is_empty());
811        assert_eq!(tracker.context_limit(), None);
812    }
813
814    #[test]
815    fn test_usage_tracker_default() {
816        let tracker = UsageTracker::default();
817        assert_eq!(tracker.call_count(), 0);
818        assert_eq!(tracker.context_limit(), None);
819    }
820
821    #[test]
822    fn test_usage_tracker_with_context_limit() {
823        let tracker = UsageTracker::with_context_limit(128_000);
824        assert_eq!(tracker.context_limit(), Some(128_000));
825    }
826
827    #[test]
828    fn test_usage_tracker_record() {
829        let mut tracker = UsageTracker::new();
830        tracker.record(Usage {
831            input_tokens: 100,
832            output_tokens: 50,
833            ..Default::default()
834        });
835        tracker.record(Usage {
836            input_tokens: 200,
837            output_tokens: 100,
838            ..Default::default()
839        });
840
841        assert_eq!(tracker.total().input_tokens, 300);
842        assert_eq!(tracker.total().output_tokens, 150);
843        assert_eq!(tracker.call_count(), 2);
844        assert_eq!(tracker.calls()[0].input_tokens, 100);
845        assert_eq!(tracker.calls()[1].input_tokens, 200);
846    }
847
848    #[test]
849    fn test_usage_tracker_context_utilization() {
850        let mut tracker = UsageTracker::with_context_limit(100_000);
851        tracker.record(Usage {
852            input_tokens: 50_000,
853            output_tokens: 1000,
854            ..Default::default()
855        });
856
857        let util = tracker.context_utilization().unwrap();
858        assert!((util - 0.5).abs() < f64::EPSILON);
859    }
860
861    #[test]
862    fn test_usage_tracker_context_utilization_no_limit() {
863        let tracker = UsageTracker::new();
864        assert!(tracker.context_utilization().is_none());
865    }
866
867    #[test]
868    fn test_usage_tracker_context_utilization_zero_limit() {
869        let tracker = UsageTracker::with_context_limit(0);
870        assert!((tracker.context_utilization().unwrap()).abs() < f64::EPSILON);
871    }
872
873    #[test]
874    fn test_usage_tracker_is_near_limit() {
875        let mut tracker = UsageTracker::with_context_limit(100_000);
876        tracker.record(Usage {
877            input_tokens: 85_000,
878            output_tokens: 1000,
879            ..Default::default()
880        });
881
882        assert!(tracker.is_near_limit(0.8)); // 85% >= 80%
883        assert!(tracker.is_near_limit(0.85)); // 85% >= 85%
884        assert!(!tracker.is_near_limit(0.9)); // 85% < 90%
885    }
886
887    #[test]
888    fn test_usage_tracker_is_near_limit_no_limit() {
889        let tracker = UsageTracker::new();
890        assert!(!tracker.is_near_limit(0.8));
891    }
892
893    #[test]
894    fn test_usage_tracker_set_context_limit() {
895        let mut tracker = UsageTracker::new();
896        assert_eq!(tracker.context_limit(), None);
897
898        tracker.set_context_limit(200_000);
899        assert_eq!(tracker.context_limit(), Some(200_000));
900    }
901
902    #[test]
903    fn test_usage_tracker_reset() {
904        let mut tracker = UsageTracker::with_context_limit(100_000);
905        tracker.record(Usage {
906            input_tokens: 1000,
907            output_tokens: 500,
908            ..Default::default()
909        });
910        assert_eq!(tracker.call_count(), 1);
911        assert_eq!(tracker.total().input_tokens, 1000);
912
913        tracker.reset();
914        assert_eq!(tracker.call_count(), 0);
915        assert_eq!(tracker.total().input_tokens, 0);
916        // Context limit should be preserved
917        assert_eq!(tracker.context_limit(), Some(100_000));
918    }
919
920    #[test]
921    fn test_usage_tracker_clone() {
922        let mut tracker = UsageTracker::with_context_limit(50_000);
923        tracker.record(Usage {
924            input_tokens: 100,
925            output_tokens: 50,
926            ..Default::default()
927        });
928
929        let cloned = tracker.clone();
930        assert_eq!(cloned.total().input_tokens, 100);
931        assert_eq!(cloned.call_count(), 1);
932        assert_eq!(cloned.context_limit(), Some(50_000));
933    }
934
935    // --- ModelPricing ---
936
937    #[test]
938    fn test_model_pricing_compute_cost() {
939        let pricing = ModelPricing {
940            input_per_million: 3_000_000,   // $3 per MTok
941            output_per_million: 15_000_000, // $15 per MTok
942            cache_read_per_million: None,
943        };
944
945        let usage = Usage {
946            input_tokens: 1_000_000, // 1 MTok input
947            output_tokens: 100_000,  // 0.1 MTok output
948            ..Default::default()
949        };
950
951        let cost = pricing.compute_cost(&usage).unwrap();
952        assert_eq!(cost.input_microdollars(), 3_000_000); // $3
953        assert_eq!(cost.output_microdollars(), 1_500_000); // $1.50
954        assert_eq!(cost.total_microdollars(), 4_500_000); // $4.50
955    }
956
957    #[test]
958    fn test_model_pricing_with_cache_tokens() {
959        let pricing = ModelPricing {
960            input_per_million: 3_000_000,
961            output_per_million: 15_000_000,
962            cache_read_per_million: Some(300_000), // $0.30 per MTok
963        };
964
965        let usage = Usage {
966            input_tokens: 500_000,
967            output_tokens: 100_000,
968            cache_read_tokens: Some(500_000), // 0.5 MTok from cache
969            ..Default::default()
970        };
971
972        let cost = pricing.compute_cost(&usage).unwrap();
973        // Input: 500k * $3/MTok = $1.50 = 1_500_000
974        // Cache: 500k * $0.30/MTok = $0.15 = 150_000
975        // Total input side: $1.65 = 1_650_000
976        // Output: 100k * $15/MTok = $1.50 = 1_500_000
977        assert_eq!(cost.input_microdollars(), 1_650_000);
978        assert_eq!(cost.output_microdollars(), 1_500_000);
979    }
980
981    #[test]
982    fn test_model_pricing_zero_tokens() {
983        let pricing = ModelPricing {
984            input_per_million: 3_000_000,
985            output_per_million: 15_000_000,
986            cache_read_per_million: None,
987        };
988
989        let usage = Usage::default();
990        let cost = pricing.compute_cost(&usage).unwrap();
991        assert_eq!(cost.total_microdollars(), 0);
992    }
993
994    #[test]
995    fn test_model_pricing_cache_without_pricing() {
996        // Cache tokens present but no cache pricing — should ignore
997        let pricing = ModelPricing {
998            input_per_million: 3_000_000,
999            output_per_million: 15_000_000,
1000            cache_read_per_million: None,
1001        };
1002
1003        let usage = Usage {
1004            input_tokens: 1_000_000,
1005            output_tokens: 100_000,
1006            cache_read_tokens: Some(500_000),
1007            ..Default::default()
1008        };
1009
1010        let cost = pricing.compute_cost(&usage).unwrap();
1011        // Cache tokens ignored since no pricing
1012        assert_eq!(cost.input_microdollars(), 3_000_000);
1013    }
1014
1015    #[test]
1016    fn test_usage_tracker_cost() {
1017        let mut tracker = UsageTracker::new();
1018        tracker.record(Usage {
1019            input_tokens: 1_000_000,
1020            output_tokens: 100_000,
1021            ..Default::default()
1022        });
1023
1024        let pricing = ModelPricing {
1025            input_per_million: 3_000_000,
1026            output_per_million: 15_000_000,
1027            cache_read_per_million: None,
1028        };
1029
1030        let cost = tracker.cost(&pricing).unwrap();
1031        assert_eq!(cost.total_microdollars(), 4_500_000);
1032    }
1033
1034    #[test]
1035    fn test_model_pricing_clone_eq() {
1036        let p1 = ModelPricing {
1037            input_per_million: 100,
1038            output_per_million: 200,
1039            cache_read_per_million: Some(50),
1040        };
1041        let p2 = p1.clone();
1042        assert_eq!(p1, p2);
1043    }
1044
1045    #[test]
1046    fn test_compute_token_cost_large_values() {
1047        // Test with large but reasonable values
1048        let cost = compute_token_cost(10_000_000_000, 3_000_000);
1049        // 10B tokens * $3/MTok = $30,000 = 30_000_000_000 microdollars
1050        assert_eq!(cost, Some(30_000_000_000));
1051    }
1052}