use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[cfg(feature = "pricing")]
pub mod pricing;
/// Aggregated usage counters for one or more model requests.
///
/// All counters are `u64`. The cache and tool-call counters may be absent in
/// serialized input and then deserialize as `0`; the other four are required.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
    /// Number of requests counted.
    pub requests: u64,
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of cache-write tokens (`0` when missing from serialized input).
    #[serde(default)]
    pub cache_write_tokens: u64,
    /// Number of cache-read tokens (`0` when missing from serialized input).
    #[serde(default)]
    pub cache_read_tokens: u64,
    /// Number of output tokens.
    pub output_tokens: u64,
    /// Total token count as reported alongside the other counters.
    pub total_tokens: u64,
    /// Number of tool calls (`0` when missing from serialized input).
    #[serde(default)]
    pub tool_calls: u64,
}
impl Usage {
    /// Accumulates `other` into `self`, counter by counter.
    ///
    /// Every counter uses saturating addition, so a total that would overflow
    /// stays at `u64::MAX` instead of wrapping or panicking.
    pub fn add_assign(&mut self, other: &Self) {
        // Exhaustive destructuring: adding a field to `Usage` makes this fail to
        // compile until the new counter is accumulated here too.
        let Self {
            requests,
            input_tokens,
            cache_write_tokens,
            cache_read_tokens,
            output_tokens,
            total_tokens,
            tool_calls,
        } = *other;
        self.requests = self.requests.saturating_add(requests);
        self.input_tokens = self.input_tokens.saturating_add(input_tokens);
        self.cache_write_tokens = self.cache_write_tokens.saturating_add(cache_write_tokens);
        self.cache_read_tokens = self.cache_read_tokens.saturating_add(cache_read_tokens);
        self.output_tokens = self.output_tokens.saturating_add(output_tokens);
        self.total_tokens = self.total_tokens.saturating_add(total_tokens);
        self.tool_calls = self.tool_calls.saturating_add(tool_calls);
    }

    /// Returns `true` when every counter is zero.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        matches!(
            self,
            Self {
                requests: 0,
                input_tokens: 0,
                cache_write_tokens: 0,
                cache_read_tokens: 0,
                output_tokens: 0,
                total_tokens: 0,
                tool_calls: 0,
            }
        )
    }

    /// Returns a copy with `tool_calls` extra tool calls added (saturating).
    /// All other counters are unchanged.
    #[must_use]
    pub const fn with_additional_tool_calls(self, tool_calls: u64) -> Self {
        Self {
            tool_calls: self.tool_calls.saturating_add(tool_calls),
            ..self
        }
    }
}
/// An estimated cost held as a single integer amount.
///
/// The unit is presumably millionths of a US dollar, per the field name
/// `amount_micros_usd`; confirm with the pricing module.
/// Ordering compares the amount.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct PricingEstimate {
    /// Estimated amount; deserializes as `0` when the field is missing.
    #[serde(default)]
    pub amount_micros_usd: u64,
}
impl PricingEstimate {
    /// Wraps a raw `amount_micros_usd` value.
    #[must_use]
    pub const fn from_micros_usd(amount_micros_usd: u64) -> Self {
        Self { amount_micros_usd }
    }

    /// Adds `other` into `self`, clamping at `u64::MAX` instead of overflowing.
    pub fn add_assign(&mut self, other: &Self) {
        let Self {
            amount_micros_usd: increment,
        } = *other;
        self.amount_micros_usd = self.amount_micros_usd.saturating_add(increment);
    }

    /// Returns `true` when the estimated amount is zero.
    #[must_use]
    pub const fn is_zero(&self) -> bool {
        matches!(self, Self { amount_micros_usd: 0 })
    }
}
/// One usage record tagged with agent and model identifiers.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct UsageSnapshotEntry {
    /// Agent identifier for this record.
    pub agent_id: String,
    /// Agent name for this record.
    pub agent_name: String,
    /// Model identifier for this record.
    pub model_id: String,
    /// Counters recorded for this entry.
    pub usage: Usage,
    /// Optional cost estimate; omitted from serialized output when `None`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub estimate_pricing: Option<PricingEstimate>,
    /// Optional usage identifier; omitted from serialized output when `None`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_id: Option<String>,
    /// Source label; deserializes to `"model_request"` when the field is missing.
    #[serde(default = "default_usage_source")]
    pub source: String,
}
/// Usage totals tagged with an agent name and model identifier.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct UsageAgentTotal {
    /// Agent name for this total.
    pub agent_name: String,
    /// Model identifier for this total.
    pub model_id: String,
    /// Accumulated counters.
    pub usage: Usage,
    /// Optional cost estimate; omitted from serialized output when `None`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub estimate_pricing: Option<PricingEstimate>,
    /// Optional usage identifier; omitted from serialized output when `None`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_id: Option<String>,
    /// Source label; deserializes to `"model_request"` when the field is missing.
    #[serde(default = "default_usage_source")]
    pub source: String,
}
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct UsageSnapshot {
pub run_id: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub latest_usage: Option<Usage>,
#[serde(default)]
pub total_usage: Usage,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub estimate_pricing: Option<PricingEstimate>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub entries: Vec<UsageSnapshotEntry>,
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub agent_usages: BTreeMap<String, UsageAgentTotal>,
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub model_usages: BTreeMap<String, Usage>,
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub model_estimate_pricing: BTreeMap<String, PricingEstimate>,
}
/// Serde default for the `source` field of usage records.
fn default_usage_source() -> String {
    String::from("model_request")
}
/// Optional caps on usage.
///
/// A limit set to `None` is never enforced. Unset limits are omitted when
/// serializing and default to `None` when deserializing.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct UsageLimits {
    /// Maximum number of requests.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_limit: Option<u64>,
    /// Maximum number of input tokens.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens_limit: Option<u64>,
    /// Maximum number of output tokens.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens_limit: Option<u64>,
    /// Maximum total token count.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_tokens_limit: Option<u64>,
    /// Maximum number of tool calls.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls_limit: Option<u64>,
    /// Optional cost budget (only present with the `pricing` feature).
    #[cfg(feature = "pricing")]
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost_budget: Option<pricing::CostBudget>,
}
impl UsageLimits {
    /// Creates limits with nothing set, equal to `UsageLimits::default()`.
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            tool_calls_limit: None,
            total_tokens_limit: None,
            output_tokens_limit: None,
            input_tokens_limit: None,
            request_limit: None,
            #[cfg(feature = "pricing")]
            cost_budget: None,
        }
    }

    /// Sets `request_limit`, enforced by [`UsageLimits::check_before_request`].
    #[must_use]
    pub const fn with_request_limit(mut self, limit: u64) -> Self {
        self.request_limit = Some(limit);
        self
    }

    /// Sets `input_tokens_limit`, enforced by [`UsageLimits::check_usage`].
    #[must_use]
    pub const fn with_input_tokens_limit(mut self, limit: u64) -> Self {
        self.input_tokens_limit = Some(limit);
        self
    }

    /// Sets `output_tokens_limit`, enforced by [`UsageLimits::check_usage`].
    #[must_use]
    pub const fn with_output_tokens_limit(mut self, limit: u64) -> Self {
        self.output_tokens_limit = Some(limit);
        self
    }

    /// Sets `total_tokens_limit`, enforced by [`UsageLimits::check_usage`].
    #[must_use]
    pub const fn with_total_tokens_limit(mut self, limit: u64) -> Self {
        self.total_tokens_limit = Some(limit);
        self
    }

    /// Sets `tool_calls_limit`, enforced by [`UsageLimits::check_tool_calls`].
    #[must_use]
    pub const fn with_tool_calls_limit(mut self, limit: u64) -> Self {
        self.tool_calls_limit = Some(limit);
        self
    }

    /// Sets the cost budget consulted by [`UsageLimits::check_usage`].
    #[cfg(feature = "pricing")]
    #[must_use]
    pub const fn with_cost_budget(mut self, budget: pricing::CostBudget) -> Self {
        self.cost_budget = Some(budget);
        self
    }

    /// Estimates the cost of `usage` via the budget, or `None` when no budget is set.
    #[cfg(feature = "pricing")]
    #[must_use]
    pub fn estimate_cost_micros(&self, usage: &Usage) -> Option<u64> {
        let budget = self.cost_budget.as_ref()?;
        Some(budget.estimate_micros(usage))
    }

    /// Builds a [`PricingEstimate`] for `usage` via the budget, or `None` when no
    /// budget is set.
    #[cfg(feature = "pricing")]
    #[must_use]
    pub fn estimate_pricing(&self, usage: &Usage) -> Option<PricingEstimate> {
        let budget = self.cost_budget.as_ref()?;
        Some(budget.estimate_pricing(usage))
    }

    /// Checks whether one more request would still fit under `request_limit`.
    ///
    /// # Errors
    ///
    /// Returns [`UsageLimitError::NextRequest`] when `current.requests + 1`
    /// (saturating) is greater than the configured limit.
    pub const fn check_before_request(&self, current: &Usage) -> Result<(), UsageLimitError> {
        let next_requests = current.requests.saturating_add(1);
        match self.request_limit {
            Some(limit) if next_requests > limit => Err(UsageLimitError::NextRequest {
                limit,
                next_requests,
            }),
            _ => Ok(()),
        }
    }

    /// Checks a projected tool-call count against `tool_calls_limit`.
    ///
    /// # Errors
    ///
    /// Returns [`UsageLimitError::ToolCalls`] when `projected.tool_calls` is
    /// greater than the configured limit.
    pub const fn check_tool_calls(&self, projected: &Usage) -> Result<(), UsageLimitError> {
        let tool_calls = projected.tool_calls;
        match self.tool_calls_limit {
            Some(limit) if tool_calls > limit => {
                Err(UsageLimitError::ToolCalls { limit, tool_calls })
            }
            _ => Ok(()),
        }
    }

    /// Checks the token counters of `usage` and, with the `pricing` feature,
    /// the cost budget.
    ///
    /// The request and tool-call limits are not checked here.
    ///
    /// # Errors
    ///
    /// Returns the first failure, in this order:
    /// 1. input tokens, output tokens, then total tokens, as
    ///    [`UsageLimitError::Token`];
    /// 2. whatever the cost budget's own check reports.
    pub fn check_usage(&self, usage: &Usage) -> Result<(), UsageLimitError> {
        // Order matters: the first exceeded limit is the one reported.
        let token_checks = [
            (
                UsageTokenKind::InputTokens,
                self.input_tokens_limit,
                usage.input_tokens,
            ),
            (
                UsageTokenKind::OutputTokens,
                self.output_tokens_limit,
                usage.output_tokens,
            ),
            (
                UsageTokenKind::TotalTokens,
                self.total_tokens_limit,
                usage.total_tokens,
            ),
        ];
        for (kind, limit, actual) in token_checks {
            check_limit(kind, limit, actual)?;
        }
        #[cfg(feature = "pricing")]
        if let Some(budget) = self.cost_budget.as_ref() {
            budget.check_usage(usage)?;
        }
        Ok(())
    }
}
/// Identifies which token counter a [`UsageLimitError::Token`] refers to.
///
/// Serialized in snake_case (e.g. `InputTokens` becomes `"input_tokens"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageTokenKind {
    /// The `input_tokens` counter.
    InputTokens,
    /// The `output_tokens` counter.
    OutputTokens,
    /// The `total_tokens` counter.
    TotalTokens,
}
impl std::fmt::Display for UsageTokenKind {
    /// Writes the snake_case name of the kind, matching its serde representation.
    ///
    /// Uses [`std::fmt::Formatter::pad`] rather than `write_str`, so width, fill,
    /// alignment and precision flags are honoured (e.g. `format!("{kind:<14}")`).
    /// With no flags the output is exactly the bare name, so error messages that
    /// interpolate `{kind}` are unchanged.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let value = match self {
            Self::InputTokens => "input_tokens",
            Self::OutputTokens => "output_tokens",
            Self::TotalTokens => "total_tokens",
        };
        formatter.pad(value)
    }
}
#[derive(Clone, Debug, Error, Deserialize, Eq, PartialEq, Serialize)]
pub enum UsageLimitError {
#[error("the next request would exceed the request_limit of {limit} (next_requests={next_requests})")]
NextRequest {
limit: u64,
next_requests: u64,
},
#[error("exceeded the {kind}_limit of {limit} ({kind}={actual})")]
Token {
kind: UsageTokenKind,
limit: u64,
actual: u64,
},
#[cfg(feature = "pricing")]
#[error(
"exceeded the total_cost_limit_micros of {limit_micros} (cost_micros={actual_micros})"
)]
Cost {
limit_micros: u64,
actual_micros: u64,
},
#[error("the next tool call(s) would exceed the tool_calls_limit of {limit} (tool_calls={tool_calls})")]
ToolCalls {
limit: u64,
tool_calls: u64,
},
}
/// Compares one token counter against its optional limit.
///
/// A `None` limit always passes; otherwise `actual` may equal but not exceed `limit`.
///
/// # Errors
///
/// Returns [`UsageLimitError::Token`] carrying `kind`, `limit` and `actual`
/// when `actual > limit`.
const fn check_limit(
    kind: UsageTokenKind,
    limit: Option<u64>,
    actual: u64,
) -> Result<(), UsageLimitError> {
    match limit {
        Some(limit) if actual > limit => Err(UsageLimitError::Token {
            kind,
            limit,
            actual,
        }),
        Some(_) | None => Ok(()),
    }
}
/// Folds an optional estimate into an optional running total.
///
/// - A `None` estimate leaves `total` untouched.
/// - If `total` is `None`, it becomes a clone of the estimate.
/// - Otherwise the estimate is added to `total` with saturating addition.
pub fn add_optional_pricing(
    total: &mut Option<PricingEstimate>,
    estimate: Option<&PricingEstimate>,
) {
    if let Some(increment) = estimate {
        if let Some(running) = total {
            running.add_assign(increment);
        } else {
            *total = Some(increment.clone());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a `Usage` whose seven counters all equal `value`.
    fn uniform_usage(value: u64) -> Usage {
        Usage {
            requests: value,
            input_tokens: value,
            cache_write_tokens: value,
            cache_read_tokens: value,
            output_tokens: value,
            total_tokens: value,
            tool_calls: value,
        }
    }

    #[test]
    fn usage_add_assign_and_empty_work() {
        let mut usage = Usage {
            requests: 1,
            input_tokens: 2,
            cache_write_tokens: 7,
            cache_read_tokens: 11,
            output_tokens: 3,
            total_tokens: 5,
            tool_calls: 1,
        };
        let delta = Usage {
            requests: 2,
            input_tokens: 4,
            cache_write_tokens: 13,
            cache_read_tokens: 17,
            output_tokens: 6,
            total_tokens: 10,
            tool_calls: 3,
        };
        usage.add_assign(&delta);

        // Whole-struct equality checks every counter in one assertion.
        let expected = Usage {
            requests: 3,
            input_tokens: 6,
            cache_write_tokens: 20,
            cache_read_tokens: 28,
            output_tokens: 9,
            total_tokens: 15,
            tool_calls: 4,
        };
        assert_eq!(usage, expected);

        let bumped = usage.clone().with_additional_tool_calls(2);
        assert_eq!(bumped.tool_calls, 6);
        assert!(Usage::default().is_empty());
        assert!(!usage.is_empty());
    }

    #[test]
    fn usage_add_assign_saturates() {
        let mut usage = uniform_usage(u64::MAX);
        usage.add_assign(&uniform_usage(1));
        assert_eq!(usage, uniform_usage(u64::MAX));
    }

    #[test]
    fn usage_limit_error_token_kind_is_owned_ser_de_contract() {
        let error = UsageLimitError::Token {
            kind: UsageTokenKind::TotalTokens,
            limit: 5,
            actual: 6,
        };
        let value = serde_json::to_value(&error)
            .unwrap_or_else(|err| panic!("usage limit error should serialize: {err}"));
        let restored = serde_json::from_value::<UsageLimitError>(value)
            .unwrap_or_else(|err| panic!("usage limit error should deserialize: {err}"));
        assert_eq!(restored, error);
    }

    #[test]
    fn usage_snapshot_accepts_missing_pricing_fields() {
        let payload = serde_json::json!({
            "run_id": "run_1",
            "total_usage": {
                "requests": 1,
                "input_tokens": 2,
                "output_tokens": 3,
                "total_tokens": 5
            }
        });
        let snapshot = serde_json::from_value::<UsageSnapshot>(payload)
            .unwrap_or_else(|err| panic!("usage snapshot should deserialize: {err}"));
        assert_eq!(snapshot.run_id, "run_1");
        assert!(snapshot.estimate_pricing.is_none());
        assert!(snapshot.model_estimate_pricing.is_empty());
    }
}