Skip to main content

llm_stack/
usage.rs

//! Token usage and cost tracking.
//!
//! Every response carries a [`Usage`] record counting input and output
//! tokens, with optional fields for reasoning and cache tokens when the
//! provider reports them.
//!
//! [`Cost`] tracks monetary cost in **microdollars** (1 USD = 1,000,000
//! microdollars). Integer arithmetic avoids floating-point rounding
//! issues when aggregating costs across many requests. Use
//! [`total_usd`](Cost::total_usd) for display purposes.
//!
//! # Invariant
//!
//! `Cost` enforces `total == input + output` at construction time.
//! The fields are private — use [`Cost::new`] to build one, and the
//! accessor methods to read values. Deserialization recomputes the
//! total from `input` and `output`, ignoring any `total` in the JSON.
18
use std::fmt;
use std::ops::{Add, AddAssign};

use serde::{Deserialize, Serialize};
23
24/// Token counts for a single request/response pair.
25#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
26pub struct Usage {
27    /// Tokens consumed by the prompt (messages + system + tool defs).
28    pub input_tokens: u64,
29    /// Tokens produced by the model's response.
30    pub output_tokens: u64,
31    /// Tokens used for chain-of-thought reasoning, if applicable.
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub reasoning_tokens: Option<u64>,
34    /// Tokens served from the provider's prompt cache (reducing cost).
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub cache_read_tokens: Option<u64>,
37    /// Tokens written into the provider's prompt cache for future reuse.
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub cache_write_tokens: Option<u64>,
40}
41
/// Helper: combines two optional token counters.
///
/// - Both `Some`: the saturating sum of the two values.
/// - Exactly one `Some`: that value, unchanged.
/// - Both `None`: `None` (the counter stays "not reported" rather than
///   becoming `Some(0)`).
fn add_optional(a: Option<u64>, b: Option<u64>) -> Option<u64> {
    match (a, b) {
        (Some(x), Some(y)) => Some(x.saturating_add(y)),
        // At most one side is `Some` here; pass through whichever it is.
        (present, None) | (None, present) => present,
    }
}
50
51impl Add for Usage {
52    type Output = Self;
53
54    /// Adds two `Usage` records field-by-field.
55    ///
56    /// Mandatory fields use saturating addition. Optional fields are
57    /// summed when both are `Some`, preserved when one is `Some`, and
58    /// remain `None` when both are `None`.
59    fn add(self, rhs: Self) -> Self {
60        Self {
61            input_tokens: self.input_tokens.saturating_add(rhs.input_tokens),
62            output_tokens: self.output_tokens.saturating_add(rhs.output_tokens),
63            reasoning_tokens: add_optional(self.reasoning_tokens, rhs.reasoning_tokens),
64            cache_read_tokens: add_optional(self.cache_read_tokens, rhs.cache_read_tokens),
65            cache_write_tokens: add_optional(self.cache_write_tokens, rhs.cache_write_tokens),
66        }
67    }
68}
69
70impl AddAssign for Usage {
71    fn add_assign(&mut self, rhs: Self) {
72        *self += &rhs;
73    }
74}
75
76impl AddAssign<&Usage> for Usage {
77    /// Adds another `Usage` to this one in-place without cloning.
78    ///
79    /// This is more efficient than `AddAssign<Usage>` when you have a reference.
80    fn add_assign(&mut self, rhs: &Self) {
81        self.input_tokens = self.input_tokens.saturating_add(rhs.input_tokens);
82        self.output_tokens = self.output_tokens.saturating_add(rhs.output_tokens);
83        self.reasoning_tokens = add_optional(self.reasoning_tokens, rhs.reasoning_tokens);
84        self.cache_read_tokens = add_optional(self.cache_read_tokens, rhs.cache_read_tokens);
85        self.cache_write_tokens = add_optional(self.cache_write_tokens, rhs.cache_write_tokens);
86    }
87}
88
89/// Monetary cost in microdollars (1 USD = 1,000,000 microdollars).
90///
91/// Uses integer arithmetic to avoid floating-point accumulation errors.
92/// The invariant `total == input + output` is enforced by the
93/// constructor and maintained through deserialization.
94///
95/// # Examples
96///
97/// ```rust
98/// use llm_stack::Cost;
99///
100/// let cost = Cost::new(300_000, 150_000).expect("no overflow");
101/// assert_eq!(cost.total_microdollars(), 450_000);
102/// assert!((cost.total_usd() - 0.45).abs() < f64::EPSILON);
103/// ```
104#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
105pub struct Cost {
106    input: u64,
107    output: u64,
108    total: u64,
109}
110
111impl Default for Cost {
112    /// Returns a zero cost.
113    fn default() -> Self {
114        Self {
115            input: 0,
116            output: 0,
117            total: 0,
118        }
119    }
120}
121
122/// Intermediate type for safe deserialization — recomputes total.
123#[derive(Deserialize)]
124struct CostRaw {
125    input: u64,
126    output: u64,
127}
128
129impl<'de> Deserialize<'de> for Cost {
130    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
131    where
132        D: serde::Deserializer<'de>,
133    {
134        let raw = CostRaw::deserialize(deserializer)?;
135        let total = raw
136            .input
137            .checked_add(raw.output)
138            .ok_or_else(|| serde::de::Error::custom("cost overflow: input + output exceeds u64"))?;
139        Ok(Self {
140            input: raw.input,
141            output: raw.output,
142            total,
143        })
144    }
145}
146
147impl Cost {
148    /// Creates a new `Cost`, returning `None` if `input + output`
149    /// would overflow `u64`.
150    pub fn new(input: u64, output: u64) -> Option<Self> {
151        let total = input.checked_add(output)?;
152        Some(Self {
153            input,
154            output,
155            total,
156        })
157    }
158
159    /// Cost of the input (prompt) in microdollars.
160    pub fn input_microdollars(&self) -> u64 {
161        self.input
162    }
163
164    /// Cost of the output (completion) in microdollars.
165    pub fn output_microdollars(&self) -> u64 {
166        self.output
167    }
168
169    /// Total cost (`input + output`) in microdollars.
170    pub fn total_microdollars(&self) -> u64 {
171        self.total
172    }
173
174    /// Returns the sum of two costs, or `None` on overflow.
175    pub fn checked_add(&self, rhs: &Self) -> Option<Self> {
176        let input = self.input.checked_add(rhs.input)?;
177        let output = self.output.checked_add(rhs.output)?;
178        Self::new(input, output)
179    }
180
181    /// Total cost in US dollars, for display purposes.
182    ///
183    /// Uses floating-point division — prefer
184    /// [`total_microdollars`](Self::total_microdollars) for arithmetic.
185    #[allow(clippy::cast_precision_loss)] // microdollar u64 fits f64 mantissa in practice
186    pub fn total_usd(&self) -> f64 {
187        self.total as f64 / 1_000_000.0
188    }
189}
190
191impl fmt::Display for Cost {
192    /// Formats the cost as a USD string, e.g. `$1.50`.
193    #[allow(clippy::cast_precision_loss)]
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        write!(f, "${:.2}", self.total as f64 / 1_000_000.0)
196    }
197}
198
199impl Add for Cost {
200    type Output = Self;
201
202    /// Adds two costs using saturating arithmetic.
203    ///
204    /// Use [`checked_add`](Self::checked_add) when overflow must be detected.
205    fn add(self, rhs: Self) -> Self {
206        let input = self.input.saturating_add(rhs.input);
207        let output = self.output.saturating_add(rhs.output);
208        Self {
209            input,
210            output,
211            total: input.saturating_add(output),
212        }
213    }
214}
215
216impl AddAssign for Cost {
217    fn add_assign(&mut self, rhs: Self) {
218        self.input = self.input.saturating_add(rhs.input);
219        self.output = self.output.saturating_add(rhs.output);
220        self.total = self.input.saturating_add(self.output);
221    }
222}
223
// ── UsageTracker ────────────────────────────────────────────────────
225
226/// Tracks cumulative token usage across multiple LLM calls.
227///
228/// `UsageTracker` accumulates [`Usage`] records from each request and
229/// provides context-awareness features for detecting when the conversation
230/// is approaching the model's context limit.
231///
232/// # Example
233///
234/// ```rust
235/// use llm_stack::usage::{Usage, UsageTracker};
236///
237/// let mut tracker = UsageTracker::with_context_limit(128_000);
238///
239/// // Record usage from each LLM call
240/// tracker.record(Usage {
241///     input_tokens: 1000,
242///     output_tokens: 500,
243///     ..Default::default()
244/// });
245///
246/// assert_eq!(tracker.total().input_tokens, 1000);
247/// assert!(!tracker.is_near_limit(0.8)); // Not near 80% yet
248/// ```
249///
250/// # Use Cases
251///
252/// - **Billing/cost tracking**: Aggregate costs across a session
253/// - **Budget alerts**: Warn when approaching token limits
254/// - **Compaction triggers**: Signal when context window is nearly full
255/// - **Token debugging**: Analyze per-call consumption patterns
256#[derive(Debug, Clone)]
257pub struct UsageTracker {
258    /// Accumulated usage across all calls.
259    total: Usage,
260    /// Usage from each individual call, in order.
261    by_call: Vec<Usage>,
262    /// Optional context window limit for utilization calculations.
263    context_limit: Option<u64>,
264}
265
266impl Default for UsageTracker {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
272impl UsageTracker {
273    /// Creates a new tracker with no context limit.
274    pub fn new() -> Self {
275        Self {
276            total: Usage::default(),
277            by_call: Vec::new(),
278            context_limit: None,
279        }
280    }
281
282    /// Creates a tracker with a known context window limit.
283    ///
284    /// The limit is used for [`context_utilization`](Self::context_utilization)
285    /// and [`is_near_limit`](Self::is_near_limit) calculations.
286    pub fn with_context_limit(limit: u64) -> Self {
287        Self {
288            total: Usage::default(),
289            by_call: Vec::new(),
290            context_limit: Some(limit),
291        }
292    }
293
294    /// Records a usage sample from an LLM call.
295    ///
296    /// The usage is added to the running total and stored for per-call
297    /// analysis.
298    pub fn record(&mut self, usage: Usage) {
299        self.total += &usage;
300        self.by_call.push(usage);
301    }
302
303    /// Returns the accumulated usage across all recorded calls.
304    pub fn total(&self) -> &Usage {
305        &self.total
306    }
307
308    /// Returns the usage from each individual call, in order.
309    pub fn calls(&self) -> &[Usage] {
310        &self.by_call
311    }
312
313    /// Returns the number of calls recorded.
314    pub fn call_count(&self) -> usize {
315        self.by_call.len()
316    }
317
318    /// Returns the context limit, if set.
319    pub fn context_limit(&self) -> Option<u64> {
320        self.context_limit
321    }
322
323    /// Sets or updates the context limit.
324    ///
325    /// Useful when the model is determined after tracker creation.
326    pub fn set_context_limit(&mut self, limit: u64) {
327        self.context_limit = Some(limit);
328    }
329
330    /// Returns the context utilization as a ratio (0.0 to 1.0+).
331    ///
332    /// Utilization is calculated as `total_input_tokens / context_limit`.
333    /// Returns `None` if no context limit is set.
334    ///
335    /// # Note
336    ///
337    /// The value can exceed 1.0 if the total exceeds the limit (which
338    /// shouldn't happen in practice but is not enforced).
339    #[allow(clippy::cast_precision_loss)] // u64 token counts fit f64 mantissa
340    pub fn context_utilization(&self) -> Option<f64> {
341        self.context_limit.map(|limit| {
342            if limit == 0 {
343                return 0.0;
344            }
345            self.total.input_tokens as f64 / limit as f64
346        })
347    }
348
349    /// Checks if the context utilization is at or above the given threshold.
350    ///
351    /// Returns `false` if no context limit is set.
352    ///
353    /// # Example
354    ///
355    /// ```rust
356    /// use llm_stack::usage::{Usage, UsageTracker};
357    ///
358    /// let mut tracker = UsageTracker::with_context_limit(100_000);
359    /// tracker.record(Usage {
360    ///     input_tokens: 85_000,
361    ///     output_tokens: 1000,
362    ///     ..Default::default()
363    /// });
364    ///
365    /// assert!(tracker.is_near_limit(0.8));   // 85% >= 80%
366    /// assert!(!tracker.is_near_limit(0.9));  // 85% < 90%
367    /// ```
368    pub fn is_near_limit(&self, threshold: f64) -> bool {
369        self.context_utilization()
370            .is_some_and(|util| util >= threshold)
371    }
372
373    /// Computes the cost of all recorded usage given a pricing table.
374    ///
375    /// Uses the pricing rates (per-million tokens) to calculate cost.
376    /// Returns `None` if the cost would overflow.
377    pub fn cost(&self, pricing: &ModelPricing) -> Option<Cost> {
378        pricing.compute_cost(&self.total)
379    }
380
381    /// Resets the tracker, clearing all recorded usage.
382    pub fn reset(&mut self) {
383        self.total = Usage::default();
384        self.by_call.clear();
385    }
386}
387
/// Pricing information for a specific model.
///
/// All prices are in **microdollars per million tokens**. For example,
/// a price of $3.00 per million input tokens would be `3_000_000`.
///
/// # Example
///
/// ```rust
/// use llm_stack::usage::{ModelPricing, Usage};
///
/// // Illustrative rates modeled on Claude 3.5 Sonnet list pricing
/// let pricing = ModelPricing {
///     input_per_million: 3_000_000,   // $3.00 / MTok
///     output_per_million: 15_000_000, // $15.00 / MTok
///     cache_read_per_million: Some(300_000), // $0.30 / MTok
/// };
///
/// let usage = Usage {
///     input_tokens: 1_000_000,
///     output_tokens: 100_000,
///     ..Default::default()
/// };
///
/// let cost = pricing.compute_cost(&usage).unwrap();
/// assert_eq!(cost.total_microdollars(), 4_500_000); // $3 input + $1.50 output
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelPricing {
    /// Cost per million input tokens in microdollars.
    pub input_per_million: u64,
    /// Cost per million output tokens in microdollars.
    pub output_per_million: u64,
    /// Cost per million cache-read tokens in microdollars (if applicable).
    pub cache_read_per_million: Option<u64>,
}
423
424impl ModelPricing {
425    /// Computes the cost for the given usage.
426    ///
427    /// Returns `None` if the calculation would overflow.
428    pub fn compute_cost(&self, usage: &Usage) -> Option<Cost> {
429        // Cost = (tokens * price_per_million) / 1_000_000
430        // To avoid precision loss, we use u128 for intermediate calculations
431        let input_cost = compute_token_cost(usage.input_tokens, self.input_per_million)?;
432        let output_cost = compute_token_cost(usage.output_tokens, self.output_per_million)?;
433
434        // If cache read tokens are present and pricing is set, include them
435        let cache_cost = match (usage.cache_read_tokens, self.cache_read_per_million) {
436            (Some(tokens), Some(rate)) => compute_token_cost(tokens, rate)?,
437            _ => 0,
438        };
439
440        // Combine costs (cache reads reduce effective input cost conceptually,
441        // but for billing they're additive at the cache rate)
442        let total_input = input_cost.checked_add(cache_cost)?;
443        Cost::new(total_input, output_cost)
444    }
445}
446
/// Compute cost for a token count at a given rate.
///
/// Evaluates `tokens * per_million / 1_000_000`, where `per_million` is
/// the price in microdollars per million tokens. The division truncates,
/// so any fractional microdollar is dropped.
///
/// Returns microdollars, or `None` if the result does not fit in `u64`.
fn compute_token_cost(tokens: u64, per_million: u64) -> Option<u64> {
    const TOKENS_PER_MILLION: u128 = 1_000_000;

    // The product of two u64 values always fits in u128, so the
    // multiplication itself cannot overflow; only the final narrowing
    // back to u64 can fail.
    let product = u128::from(tokens) * u128::from(per_million);
    u64::try_from(product / TOKENS_PER_MILLION).ok()
}
457
#[cfg(test)]
mod tests {
    use super::*;

    // --- Usage basics and serde ---

    #[test]
    fn test_usage_clone_eq() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 50,
            reasoning_tokens: Some(10),
            cache_read_tokens: None,
            cache_write_tokens: None,
        };
        assert_eq!(u, u.clone());
    }

    #[test]
    fn test_usage_debug_format() {
        let u = Usage::default();
        let debug = format!("{u:?}");
        assert!(debug.contains("input_tokens"));
        assert!(debug.contains("output_tokens"));
    }

    #[test]
    fn test_usage_optional_fields_none() {
        let u = Usage::default();
        assert_eq!(u.reasoning_tokens, None);
        assert_eq!(u.cache_read_tokens, None);
        assert_eq!(u.cache_write_tokens, None);
    }

    #[test]
    fn test_usage_optional_fields_some() {
        let u = Usage {
            input_tokens: 0,
            output_tokens: 0,
            reasoning_tokens: Some(500),
            cache_read_tokens: Some(200),
            cache_write_tokens: Some(100),
        };
        assert_eq!(u.reasoning_tokens, Some(500));
        assert_eq!(u.cache_read_tokens, Some(200));
        assert_eq!(u.cache_write_tokens, Some(100));
    }

    #[test]
    fn test_usage_serde_roundtrip() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 50,
            reasoning_tokens: Some(10),
            cache_read_tokens: None,
            cache_write_tokens: None,
        };
        let json = serde_json::to_string(&u).unwrap();
        let back: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(u, back);
    }

    #[test]
    fn test_usage_skip_serializing_none_fields() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 50,
            reasoning_tokens: None,
            cache_read_tokens: None,
            cache_write_tokens: None,
        };
        let json = serde_json::to_string(&u).unwrap();
        // None fields should not appear in JSON
        assert!(!json.contains("reasoning_tokens"));
        assert!(!json.contains("cache_read_tokens"));
        assert!(!json.contains("cache_write_tokens"));
        // Required fields should be present
        assert!(json.contains("input_tokens"));
        assert!(json.contains("output_tokens"));
    }

    #[test]
    fn test_usage_serializes_some_fields() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 50,
            reasoning_tokens: Some(10),
            cache_read_tokens: Some(20),
            cache_write_tokens: Some(5),
        };
        let json = serde_json::to_string(&u).unwrap();
        // All fields should appear when Some
        assert!(json.contains("reasoning_tokens"));
        assert!(json.contains("cache_read_tokens"));
        assert!(json.contains("cache_write_tokens"));
    }

    // --- Cost construction and serde ---

    #[test]
    fn test_cost_new_enforces_invariant() {
        let c = Cost::new(1_000_000, 500_000).unwrap();
        assert_eq!(c.input_microdollars(), 1_000_000);
        assert_eq!(c.output_microdollars(), 500_000);
        assert_eq!(c.total_microdollars(), 1_500_000);
    }

    #[test]
    fn test_cost_new_overflow_returns_none() {
        assert!(Cost::new(u64::MAX, 1).is_none());
    }

    #[test]
    fn test_cost_total_usd_exact() {
        let c = Cost::new(1_000_000, 500_000).unwrap();
        assert!((c.total_usd() - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cost_total_usd_zero() {
        let c = Cost::new(0, 0).unwrap();
        assert!((c.total_usd()).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cost_total_usd_sub_cent() {
        let c = Cost::new(300, 200).unwrap();
        assert!((c.total_usd() - 0.0005).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cost_clone_eq() {
        let c = Cost::new(42, 58).unwrap();
        assert_eq!(c, c.clone());
    }

    #[test]
    fn test_cost_serde_roundtrip() {
        let c = Cost::new(1_000_000, 500_000).unwrap();
        let json = serde_json::to_string(&c).unwrap();
        let back: Cost = serde_json::from_str(&json).unwrap();
        assert_eq!(c, back);
    }

    #[test]
    fn test_cost_deserialization_recomputes_total() {
        // Even if JSON has a wrong total, deserialization recomputes it
        let json = r#"{"input":100,"output":200,"total":999}"#;
        let c: Cost = serde_json::from_str(json).unwrap();
        assert_eq!(c.total_microdollars(), 300);
    }

    #[test]
    fn test_cost_deserialization_without_total() {
        let json = r#"{"input":100,"output":200}"#;
        let c: Cost = serde_json::from_str(json).unwrap();
        assert_eq!(c.total_microdollars(), 300);
    }

    #[test]
    fn test_cost_deserialization_overflow_fails() {
        let json = format!(r#"{{"input":{},"output":1}}"#, u64::MAX);
        let result: Result<Cost, _> = serde_json::from_str(&json);
        assert!(result.is_err());
    }

    #[test]
    fn test_cost_default_is_zero() {
        let c = Cost::default();
        assert_eq!(c.input_microdollars(), 0);
        assert_eq!(c.output_microdollars(), 0);
        assert_eq!(c.total_microdollars(), 0);
    }

    // --- Cost Display ---

    #[test]
    fn test_cost_display() {
        let c = Cost::new(1_000_000, 500_000).unwrap();
        assert_eq!(c.to_string(), "$1.50");
    }

    #[test]
    fn test_cost_display_zero() {
        assert_eq!(Cost::default().to_string(), "$0.00");
    }

    #[test]
    fn test_cost_display_sub_cent() {
        let c = Cost::new(500, 0).unwrap();
        assert_eq!(c.to_string(), "$0.00");
    }

    // --- Usage Add/AddAssign ---

    #[test]
    fn test_usage_add_basic() {
        let a = Usage {
            input_tokens: 100,
            output_tokens: 50,
            reasoning_tokens: Some(10),
            cache_read_tokens: None,
            cache_write_tokens: Some(20),
        };
        let b = Usage {
            input_tokens: 200,
            output_tokens: 30,
            reasoning_tokens: Some(5),
            cache_read_tokens: Some(50),
            cache_write_tokens: None,
        };
        let sum = a + b;
        assert_eq!(sum.input_tokens, 300);
        assert_eq!(sum.output_tokens, 80);
        assert_eq!(sum.reasoning_tokens, Some(15));
        assert_eq!(sum.cache_read_tokens, Some(50));
        assert_eq!(sum.cache_write_tokens, Some(20));
    }

    #[test]
    fn test_usage_add_both_none() {
        let a = Usage::default();
        let b = Usage::default();
        let sum = a + b;
        assert_eq!(sum.reasoning_tokens, None);
        assert_eq!(sum.cache_read_tokens, None);
        assert_eq!(sum.cache_write_tokens, None);
    }

    #[test]
    fn test_usage_add_assign() {
        let mut a = Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        };
        a += Usage {
            input_tokens: 200,
            output_tokens: 30,
            ..Default::default()
        };
        assert_eq!(a.input_tokens, 300);
        assert_eq!(a.output_tokens, 80);
    }

    #[test]
    fn test_usage_add_saturates() {
        let a = Usage {
            input_tokens: u64::MAX,
            output_tokens: 0,
            ..Default::default()
        };
        let b = Usage {
            input_tokens: 1,
            output_tokens: 0,
            ..Default::default()
        };
        let sum = a + b;
        assert_eq!(sum.input_tokens, u64::MAX);
    }

    // --- Cost Add/AddAssign/checked_add ---

    #[test]
    fn test_cost_add_basic() {
        let a = Cost::new(100, 200).unwrap();
        let b = Cost::new(300, 400).unwrap();
        let sum = a + b;
        assert_eq!(sum.input_microdollars(), 400);
        assert_eq!(sum.output_microdollars(), 600);
        assert_eq!(sum.total_microdollars(), 1000);
    }

    #[test]
    fn test_cost_add_assign() {
        let mut c = Cost::new(100, 200).unwrap();
        c += Cost::new(50, 50).unwrap();
        assert_eq!(c.input_microdollars(), 150);
        assert_eq!(c.output_microdollars(), 250);
        assert_eq!(c.total_microdollars(), 400);
    }

    #[test]
    fn test_cost_checked_add() {
        let a = Cost::new(100, 200).unwrap();
        let b = Cost::new(300, 400).unwrap();
        let sum = a.checked_add(&b).unwrap();
        assert_eq!(sum.total_microdollars(), 1000);
    }

    #[test]
    fn test_cost_checked_add_overflow() {
        let a = Cost::new(u64::MAX - 1, 0).unwrap();
        let b = Cost::new(2, 0).unwrap();
        assert!(a.checked_add(&b).is_none());
    }

    #[test]
    fn test_cost_add_saturates() {
        let a = Cost::new(u64::MAX - 1, 0).unwrap();
        let b = Cost::new(2, 0).unwrap();
        let sum = a + b;
        assert_eq!(sum.input_microdollars(), u64::MAX);
    }

    // --- UsageTracker ---

    #[test]
    fn test_usage_tracker_new() {
        let tracker = UsageTracker::new();
        assert_eq!(tracker.total().input_tokens, 0);
        assert_eq!(tracker.total().output_tokens, 0);
        assert!(tracker.calls().is_empty());
        assert_eq!(tracker.context_limit(), None);
    }

    #[test]
    fn test_usage_tracker_default() {
        let tracker = UsageTracker::default();
        assert_eq!(tracker.call_count(), 0);
        assert_eq!(tracker.context_limit(), None);
    }

    #[test]
    fn test_usage_tracker_with_context_limit() {
        let tracker = UsageTracker::with_context_limit(128_000);
        assert_eq!(tracker.context_limit(), Some(128_000));
    }

    #[test]
    fn test_usage_tracker_record() {
        let mut tracker = UsageTracker::new();
        tracker.record(Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        });
        tracker.record(Usage {
            input_tokens: 200,
            output_tokens: 100,
            ..Default::default()
        });

        assert_eq!(tracker.total().input_tokens, 300);
        assert_eq!(tracker.total().output_tokens, 150);
        assert_eq!(tracker.call_count(), 2);
        assert_eq!(tracker.calls()[0].input_tokens, 100);
        assert_eq!(tracker.calls()[1].input_tokens, 200);
    }

    #[test]
    fn test_usage_tracker_context_utilization() {
        let mut tracker = UsageTracker::with_context_limit(100_000);
        tracker.record(Usage {
            input_tokens: 50_000,
            output_tokens: 1000,
            ..Default::default()
        });

        let util = tracker.context_utilization().unwrap();
        assert!((util - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_usage_tracker_context_utilization_no_limit() {
        let tracker = UsageTracker::new();
        assert!(tracker.context_utilization().is_none());
    }

    #[test]
    fn test_usage_tracker_context_utilization_zero_limit() {
        let tracker = UsageTracker::with_context_limit(0);
        assert!((tracker.context_utilization().unwrap()).abs() < f64::EPSILON);
    }

    #[test]
    fn test_usage_tracker_is_near_limit() {
        let mut tracker = UsageTracker::with_context_limit(100_000);
        tracker.record(Usage {
            input_tokens: 85_000,
            output_tokens: 1000,
            ..Default::default()
        });

        assert!(tracker.is_near_limit(0.8)); // 85% >= 80%
        assert!(tracker.is_near_limit(0.85)); // 85% >= 85%
        assert!(!tracker.is_near_limit(0.9)); // 85% < 90%
    }

    #[test]
    fn test_usage_tracker_is_near_limit_no_limit() {
        let tracker = UsageTracker::new();
        assert!(!tracker.is_near_limit(0.8));
    }

    #[test]
    fn test_usage_tracker_set_context_limit() {
        let mut tracker = UsageTracker::new();
        assert_eq!(tracker.context_limit(), None);

        tracker.set_context_limit(200_000);
        assert_eq!(tracker.context_limit(), Some(200_000));
    }

    #[test]
    fn test_usage_tracker_reset() {
        let mut tracker = UsageTracker::with_context_limit(100_000);
        tracker.record(Usage {
            input_tokens: 1000,
            output_tokens: 500,
            ..Default::default()
        });
        assert_eq!(tracker.call_count(), 1);
        assert_eq!(tracker.total().input_tokens, 1000);

        tracker.reset();
        assert_eq!(tracker.call_count(), 0);
        assert_eq!(tracker.total().input_tokens, 0);
        // Context limit should be preserved
        assert_eq!(tracker.context_limit(), Some(100_000));
    }

    #[test]
    fn test_usage_tracker_clone() {
        let mut tracker = UsageTracker::with_context_limit(50_000);
        tracker.record(Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        });

        let cloned = tracker.clone();
        assert_eq!(cloned.total().input_tokens, 100);
        assert_eq!(cloned.call_count(), 1);
        assert_eq!(cloned.context_limit(), Some(50_000));
    }

    // --- ModelPricing ---

    #[test]
    fn test_model_pricing_compute_cost() {
        let pricing = ModelPricing {
            input_per_million: 3_000_000,   // $3 per MTok
            output_per_million: 15_000_000, // $15 per MTok
            cache_read_per_million: None,
        };

        let usage = Usage {
            input_tokens: 1_000_000, // 1 MTok input
            output_tokens: 100_000,  // 0.1 MTok output
            ..Default::default()
        };

        let cost = pricing.compute_cost(&usage).unwrap();
        assert_eq!(cost.input_microdollars(), 3_000_000); // $3
        assert_eq!(cost.output_microdollars(), 1_500_000); // $1.50
        assert_eq!(cost.total_microdollars(), 4_500_000); // $4.50
    }

    #[test]
    fn test_model_pricing_with_cache_tokens() {
        let pricing = ModelPricing {
            input_per_million: 3_000_000,
            output_per_million: 15_000_000,
            cache_read_per_million: Some(300_000), // $0.30 per MTok
        };

        let usage = Usage {
            input_tokens: 500_000,
            output_tokens: 100_000,
            cache_read_tokens: Some(500_000), // 0.5 MTok from cache
            ..Default::default()
        };

        let cost = pricing.compute_cost(&usage).unwrap();
        // Input: 500k * $3/MTok = $1.50 = 1_500_000
        // Cache: 500k * $0.30/MTok = $0.15 = 150_000
        // Total input side: $1.65 = 1_650_000
        // Output: 100k * $15/MTok = $1.50 = 1_500_000
        assert_eq!(cost.input_microdollars(), 1_650_000);
        assert_eq!(cost.output_microdollars(), 1_500_000);
    }

    #[test]
    fn test_model_pricing_zero_tokens() {
        let pricing = ModelPricing {
            input_per_million: 3_000_000,
            output_per_million: 15_000_000,
            cache_read_per_million: None,
        };

        let usage = Usage::default();
        let cost = pricing.compute_cost(&usage).unwrap();
        assert_eq!(cost.total_microdollars(), 0);
    }

    #[test]
    fn test_model_pricing_cache_without_pricing() {
        // Cache tokens present but no cache pricing — should ignore
        let pricing = ModelPricing {
            input_per_million: 3_000_000,
            output_per_million: 15_000_000,
            cache_read_per_million: None,
        };

        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 100_000,
            cache_read_tokens: Some(500_000),
            ..Default::default()
        };

        let cost = pricing.compute_cost(&usage).unwrap();
        // Cache tokens ignored since no pricing
        assert_eq!(cost.input_microdollars(), 3_000_000);
    }

    #[test]
    fn test_usage_tracker_cost() {
        let mut tracker = UsageTracker::new();
        tracker.record(Usage {
            input_tokens: 1_000_000,
            output_tokens: 100_000,
            ..Default::default()
        });

        let pricing = ModelPricing {
            input_per_million: 3_000_000,
            output_per_million: 15_000_000,
            cache_read_per_million: None,
        };

        let cost = tracker.cost(&pricing).unwrap();
        assert_eq!(cost.total_microdollars(), 4_500_000);
    }

    #[test]
    fn test_model_pricing_clone_eq() {
        let p1 = ModelPricing {
            input_per_million: 100,
            output_per_million: 200,
            cache_read_per_million: Some(50),
        };
        let p2 = p1.clone();
        assert_eq!(p1, p2);
    }

    #[test]
    fn test_compute_token_cost_large_values() {
        // Test with large but reasonable values
        let cost = compute_token_cost(10_000_000_000, 3_000_000);
        // 10B tokens * $3/MTok = $30,000 = 30_000_000_000 microdollars
        assert_eq!(cost, Some(30_000_000_000));
    }
}