Skip to main content

starweaver_usage/
lib.rs

1//! Usage accounting, limits, and optional pricing primitives for Starweaver.
2
3use std::collections::BTreeMap;
4
5use serde::{Deserialize, Serialize};
6use thiserror::Error;
7
8#[cfg(feature = "pricing")]
9pub mod pricing;
10
11/// Token and request usage accumulated by model and runtime layers.
12#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
13pub struct Usage {
14    /// Number of provider requests.
15    pub requests: u64,
16    /// Input or prompt tokens.
17    ///
18    /// Provider adapters should normalize this as total provider-billed input
19    /// tokens for the request, including cache-write and cache-read tokens when
20    /// those subtotals are present. Pricing helpers subtract the cache subtotals
21    /// before applying cache-specific rates.
22    pub input_tokens: u64,
23    /// Tokens written to a provider prompt cache.
24    #[serde(default)]
25    pub cache_write_tokens: u64,
26    /// Tokens read from a provider prompt cache.
27    #[serde(default)]
28    pub cache_read_tokens: u64,
29    /// Output or completion tokens.
30    pub output_tokens: u64,
31    /// Total tokens.
32    pub total_tokens: u64,
33    /// Number of successful function tool calls executed by the runtime.
34    #[serde(default)]
35    pub tool_calls: u64,
36}
37
38impl Usage {
39    /// Add another usage value into this one.
40    pub const fn add_assign(&mut self, other: &Self) {
41        self.requests = self.requests.saturating_add(other.requests);
42        self.input_tokens = self.input_tokens.saturating_add(other.input_tokens);
43        self.cache_write_tokens = self
44            .cache_write_tokens
45            .saturating_add(other.cache_write_tokens);
46        self.cache_read_tokens = self
47            .cache_read_tokens
48            .saturating_add(other.cache_read_tokens);
49        self.output_tokens = self.output_tokens.saturating_add(other.output_tokens);
50        self.total_tokens = self.total_tokens.saturating_add(other.total_tokens);
51        self.tool_calls = self.tool_calls.saturating_add(other.tool_calls);
52    }
53
54    /// Return whether no usage has been recorded.
55    #[must_use]
56    pub const fn is_empty(&self) -> bool {
57        self.requests == 0
58            && self.input_tokens == 0
59            && self.cache_write_tokens == 0
60            && self.cache_read_tokens == 0
61            && self.output_tokens == 0
62            && self.total_tokens == 0
63            && self.tool_calls == 0
64    }
65
66    /// Return a copy with additional successful tool calls applied.
67    #[must_use]
68    pub const fn with_additional_tool_calls(mut self, tool_calls: u64) -> Self {
69        self.tool_calls = self.tool_calls.saturating_add(tool_calls);
70        self
71    }
72}
73
74/// Estimated USD pricing for usage.
75#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Serialize)]
76pub struct PricingEstimate {
77    /// Estimated cost in micro USD units.
78    #[serde(default)]
79    pub amount_micros_usd: u64,
80}
81
82impl PricingEstimate {
83    /// Create an estimate from micro USD units.
84    #[must_use]
85    pub const fn from_micros_usd(amount_micros_usd: u64) -> Self {
86        Self { amount_micros_usd }
87    }
88
89    /// Add another estimate into this one.
90    pub const fn add_assign(&mut self, other: &Self) {
91        self.amount_micros_usd = self
92            .amount_micros_usd
93            .saturating_add(other.amount_micros_usd);
94    }
95
96    /// Return whether the estimate is zero.
97    #[must_use]
98    pub const fn is_zero(&self) -> bool {
99        self.amount_micros_usd == 0
100    }
101}
102
103/// Cumulative usage for one agent or usage source in the current run.
104#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
105pub struct UsageSnapshotEntry {
106    /// Agent or source instance that generated this usage.
107    pub agent_id: String,
108    /// Human-readable agent or source name.
109    pub agent_name: String,
110    /// Model identifier that generated this usage.
111    pub model_id: String,
112    /// Cumulative token usage for this agent/source in the run.
113    pub usage: Usage,
114    /// Estimated cumulative pricing for this entry, in USD.
115    #[serde(default, skip_serializing_if = "Option::is_none")]
116    pub estimate_pricing: Option<PricingEstimate>,
117    /// Stable usage record id for idempotent updates.
118    #[serde(default, skip_serializing_if = "Option::is_none")]
119    pub usage_id: Option<String>,
120    /// Component that reported this usage.
121    #[serde(default = "default_usage_source")]
122    pub source: String,
123}
124
125/// Cumulative usage grouped by agent/source.
126#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
127pub struct UsageAgentTotal {
128    /// Human-readable agent or source name.
129    pub agent_name: String,
130    /// Model identifier, or `multiple` when a source used more than one model.
131    pub model_id: String,
132    /// Cumulative token usage for this agent/source.
133    pub usage: Usage,
134    /// Estimated cumulative pricing for this agent/source, in USD.
135    #[serde(default, skip_serializing_if = "Option::is_none")]
136    pub estimate_pricing: Option<PricingEstimate>,
137    /// Stable usage record id when all grouped entries share one id.
138    #[serde(default, skip_serializing_if = "Option::is_none")]
139    pub usage_id: Option<String>,
140    /// Component that reported this usage.
141    #[serde(default = "default_usage_source")]
142    pub source: String,
143}
144
145/// Cumulative usage snapshot for one run.
146///
147/// Realtime consumers should treat each snapshot as a replacement for the
148/// previous snapshot with the same run id.
149#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
150pub struct UsageSnapshot {
151    /// Run identifier for the snapshot.
152    pub run_id: String,
153    /// Usage reported by the latest provider request that produced this snapshot.
154    ///
155    /// This is intentionally separate from `total_usage`: realtime UI surfaces may use
156    /// the latest request total tokens as the current context-window estimate,
157    /// while `total_usage` remains the cumulative run ledger.
158    #[serde(default, skip_serializing_if = "Option::is_none")]
159    pub latest_usage: Option<Usage>,
160    /// Cumulative usage across all known entries in this run.
161    #[serde(default)]
162    pub total_usage: Usage,
163    /// Estimated cumulative pricing across all known entries in this run, in USD.
164    #[serde(default, skip_serializing_if = "Option::is_none")]
165    pub estimate_pricing: Option<PricingEstimate>,
166    /// Per-agent/source cumulative usage entries.
167    #[serde(default, skip_serializing_if = "Vec::is_empty")]
168    pub entries: Vec<UsageSnapshotEntry>,
169    /// Cumulative usage grouped by agent id.
170    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
171    pub agent_usages: BTreeMap<String, UsageAgentTotal>,
172    /// Cumulative usage grouped by model id.
173    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
174    pub model_usages: BTreeMap<String, Usage>,
175    /// Estimated cumulative pricing grouped by model id, in USD.
176    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
177    pub model_estimate_pricing: BTreeMap<String, PricingEstimate>,
178}
179
/// Serde fallback for a missing `source` field: the `model_request` label.
fn default_usage_source() -> String {
    String::from("model_request")
}
183
184/// Runtime usage limits for one agent run.
185#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
186pub struct UsageLimits {
187    /// Maximum provider requests allowed in one run.
188    #[serde(default, skip_serializing_if = "Option::is_none")]
189    pub request_limit: Option<u64>,
190    /// Maximum input tokens allowed.
191    #[serde(default, skip_serializing_if = "Option::is_none")]
192    pub input_tokens_limit: Option<u64>,
193    /// Maximum output tokens allowed.
194    #[serde(default, skip_serializing_if = "Option::is_none")]
195    pub output_tokens_limit: Option<u64>,
196    /// Maximum total tokens allowed.
197    #[serde(default, skip_serializing_if = "Option::is_none")]
198    pub total_tokens_limit: Option<u64>,
199    /// Maximum successful function tool calls allowed.
200    #[serde(default, skip_serializing_if = "Option::is_none")]
201    pub tool_calls_limit: Option<u64>,
202    /// Optional USD pricing budget based on accumulated usage.
203    #[cfg(feature = "pricing")]
204    #[serde(default, skip_serializing_if = "Option::is_none")]
205    pub cost_budget: Option<pricing::CostBudget>,
206}
207
208impl UsageLimits {
209    /// Create empty limits.
210    #[must_use]
211    pub const fn new() -> Self {
212        Self {
213            request_limit: None,
214            input_tokens_limit: None,
215            output_tokens_limit: None,
216            total_tokens_limit: None,
217            tool_calls_limit: None,
218            #[cfg(feature = "pricing")]
219            cost_budget: None,
220        }
221    }
222
223    /// Set request limit.
224    #[must_use]
225    pub const fn with_request_limit(mut self, limit: u64) -> Self {
226        self.request_limit = Some(limit);
227        self
228    }
229
230    /// Set input token limit.
231    #[must_use]
232    pub const fn with_input_tokens_limit(mut self, limit: u64) -> Self {
233        self.input_tokens_limit = Some(limit);
234        self
235    }
236
237    /// Set output token limit.
238    #[must_use]
239    pub const fn with_output_tokens_limit(mut self, limit: u64) -> Self {
240        self.output_tokens_limit = Some(limit);
241        self
242    }
243
244    /// Set total token limit.
245    #[must_use]
246    pub const fn with_total_tokens_limit(mut self, limit: u64) -> Self {
247        self.total_tokens_limit = Some(limit);
248        self
249    }
250
251    /// Set successful tool-call limit.
252    #[must_use]
253    pub const fn with_tool_calls_limit(mut self, limit: u64) -> Self {
254        self.tool_calls_limit = Some(limit);
255        self
256    }
257
258    /// Set USD cost budget.
259    #[cfg(feature = "pricing")]
260    #[must_use]
261    pub const fn with_cost_budget(mut self, budget: pricing::CostBudget) -> Self {
262        self.cost_budget = Some(budget);
263        self
264    }
265
266    /// Estimate current USD cost in micro-units when a cost budget is configured.
267    #[cfg(feature = "pricing")]
268    #[must_use]
269    pub fn estimate_cost_micros(&self, usage: &Usage) -> Option<u64> {
270        self.cost_budget
271            .as_ref()
272            .map(|budget| budget.estimate_micros(usage))
273    }
274
275    /// Estimate current USD pricing when a cost budget is configured.
276    #[cfg(feature = "pricing")]
277    #[must_use]
278    pub fn estimate_pricing(&self, usage: &Usage) -> Option<PricingEstimate> {
279        self.cost_budget
280            .as_ref()
281            .map(|budget| budget.estimate_pricing(usage))
282    }
283
284    /// Check whether the next model request would exceed the request limit.
285    ///
286    /// # Errors
287    ///
288    /// Returns an error when another request would exceed the configured request limit.
289    pub const fn check_before_request(&self, current: &Usage) -> Result<(), UsageLimitError> {
290        if let Some(limit) = self.request_limit {
291            let next = current.requests.saturating_add(1);
292            if next > limit {
293                return Err(UsageLimitError::NextRequest {
294                    limit,
295                    next_requests: next,
296                });
297            }
298        }
299        Ok(())
300    }
301
302    /// Check whether projected tool calls would exceed the tool-call limit.
303    ///
304    /// # Errors
305    ///
306    /// Returns an error when executing the next successful tool calls would exceed the configured limit.
307    pub const fn check_tool_calls(&self, projected: &Usage) -> Result<(), UsageLimitError> {
308        if let Some(limit) = self.tool_calls_limit {
309            if projected.tool_calls > limit {
310                return Err(UsageLimitError::ToolCalls {
311                    limit,
312                    tool_calls: projected.tool_calls,
313                });
314            }
315        }
316        Ok(())
317    }
318
319    /// Check whether accumulated usage exceeds configured token or pricing limits.
320    ///
321    /// # Errors
322    ///
323    /// Returns an error when accumulated usage exceeds any configured limit.
324    pub fn check_usage(&self, usage: &Usage) -> Result<(), UsageLimitError> {
325        check_limit(
326            UsageTokenKind::InputTokens,
327            self.input_tokens_limit,
328            usage.input_tokens,
329        )?;
330        check_limit(
331            UsageTokenKind::OutputTokens,
332            self.output_tokens_limit,
333            usage.output_tokens,
334        )?;
335        check_limit(
336            UsageTokenKind::TotalTokens,
337            self.total_tokens_limit,
338            usage.total_tokens,
339        )?;
340        #[cfg(feature = "pricing")]
341        if let Some(budget) = &self.cost_budget {
342            budget.check_usage(usage)?;
343        }
344        Ok(())
345    }
346}
347
348/// Token counter checked by a usage limit.
349#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
350#[serde(rename_all = "snake_case")]
351pub enum UsageTokenKind {
352    /// Input or prompt tokens.
353    InputTokens,
354    /// Output or completion tokens.
355    OutputTokens,
356    /// Total tokens.
357    TotalTokens,
358}
359
360impl std::fmt::Display for UsageTokenKind {
361    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362        let value = match self {
363            Self::InputTokens => "input_tokens",
364            Self::OutputTokens => "output_tokens",
365            Self::TotalTokens => "total_tokens",
366        };
367        formatter.write_str(value)
368    }
369}
370
371/// Usage limit error.
372#[derive(Clone, Debug, Error, Deserialize, Eq, PartialEq, Serialize)]
373pub enum UsageLimitError {
374    /// The next request would exceed request budget.
375    #[error("the next request would exceed the request_limit of {limit} (next_requests={next_requests})")]
376    NextRequest {
377        /// Configured limit.
378        limit: u64,
379        /// Requests after the next request.
380        next_requests: u64,
381    },
382    /// Accumulated usage exceeded a token budget.
383    #[error("exceeded the {kind}_limit of {limit} ({kind}={actual})")]
384    Token {
385        /// Usage kind.
386        kind: UsageTokenKind,
387        /// Configured limit.
388        limit: u64,
389        /// Actual usage value.
390        actual: u64,
391    },
392    /// Accumulated usage exceeded a USD pricing budget.
393    #[cfg(feature = "pricing")]
394    #[error(
395        "exceeded the total_cost_limit_micros of {limit_micros} (cost_micros={actual_micros})"
396    )]
397    Cost {
398        /// Configured cost limit in micro USD units.
399        limit_micros: u64,
400        /// Actual estimated cost in micro USD units.
401        actual_micros: u64,
402    },
403    /// Projected successful function tool calls would exceed the configured budget.
404    #[error("the next tool call(s) would exceed the tool_calls_limit of {limit} (tool_calls={tool_calls})")]
405    ToolCalls {
406        /// Configured tool-call limit.
407        limit: u64,
408        /// Projected successful tool calls.
409        tool_calls: u64,
410    },
411}
412
413const fn check_limit(
414    kind: UsageTokenKind,
415    limit: Option<u64>,
416    actual: u64,
417) -> Result<(), UsageLimitError> {
418    if let Some(limit) = limit {
419        if actual > limit {
420            return Err(UsageLimitError::Token {
421                kind,
422                limit,
423                actual,
424            });
425        }
426    }
427    Ok(())
428}
429
430/// Aggregate an optional pricing estimate into a running total.
431pub fn add_optional_pricing(
432    total: &mut Option<PricingEstimate>,
433    estimate: Option<&PricingEstimate>,
434) {
435    if let Some(estimate) = estimate {
436        match total {
437            Some(total) => total.add_assign(estimate),
438            None => *total = Some(estimate.clone()),
439        }
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Usage` whose every counter holds `value`.
    fn uniform_usage(value: u64) -> Usage {
        Usage {
            requests: value,
            input_tokens: value,
            cache_write_tokens: value,
            cache_read_tokens: value,
            output_tokens: value,
            total_tokens: value,
            tool_calls: value,
        }
    }

    #[test]
    fn usage_add_assign_and_empty_work() {
        let mut accumulated = Usage {
            requests: 1,
            input_tokens: 2,
            cache_write_tokens: 7,
            cache_read_tokens: 11,
            output_tokens: 3,
            total_tokens: 5,
            tool_calls: 1,
        };
        let increment = Usage {
            requests: 2,
            input_tokens: 4,
            cache_write_tokens: 13,
            cache_read_tokens: 17,
            output_tokens: 6,
            total_tokens: 10,
            tool_calls: 3,
        };
        accumulated.add_assign(&increment);

        // Every counter is the field-wise sum of the two operands.
        assert_eq!(
            accumulated,
            Usage {
                requests: 3,
                input_tokens: 6,
                cache_write_tokens: 20,
                cache_read_tokens: 28,
                output_tokens: 9,
                total_tokens: 15,
                tool_calls: 4,
            }
        );
        assert_eq!(
            accumulated.clone().with_additional_tool_calls(2).tool_calls,
            6
        );
        assert!(Usage::default().is_empty());
        assert!(!accumulated.is_empty());
    }

    #[test]
    fn usage_add_assign_saturates() {
        let mut saturated = uniform_usage(u64::MAX);
        saturated.add_assign(&uniform_usage(1));

        // Adding to counters already at the maximum must clamp, not wrap.
        assert_eq!(saturated, uniform_usage(u64::MAX));
    }

    #[test]
    fn usage_limit_error_token_kind_is_owned_ser_de_contract() {
        let error = UsageLimitError::Token {
            kind: UsageTokenKind::TotalTokens,
            limit: 5,
            actual: 6,
        };
        let encoded = serde_json::to_value(&error)
            .unwrap_or_else(|err| panic!("usage limit error should serialize: {err}"));
        let restored: UsageLimitError = serde_json::from_value(encoded)
            .unwrap_or_else(|err| panic!("usage limit error should deserialize: {err}"));
        assert_eq!(restored, error);
    }

    #[test]
    fn usage_snapshot_accepts_missing_pricing_fields() {
        // Only the run id and the required usage counters are supplied.
        let payload = serde_json::json!({
            "run_id": "run_1",
            "total_usage": {
                "requests": 1,
                "input_tokens": 2,
                "output_tokens": 3,
                "total_tokens": 5
            }
        });
        let snapshot: UsageSnapshot = serde_json::from_value(payload)
            .unwrap_or_else(|err| panic!("usage snapshot should deserialize: {err}"));
        assert_eq!(snapshot.run_id, "run_1");
        assert!(snapshot.estimate_pricing.is_none());
        assert!(snapshot.model_estimate_pricing.is_empty());
    }
}
544}