use crate::identity::{CorrelationIds, UsageEventId};
use crate::pricing::{ModelRef, ProviderRef};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// One recorded usage event: the meter quantities attributed to a billing
/// subject for a model, together with their provenance, the request outcome,
/// timestamps, correlation ids, and free-form attributes.
///
/// NOTE(review): field order is observable through the derived `Debug` and
/// `Serialize` output; keep it stable unless downstream consumers are updated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageObservation {
/// Identifier of this usage event.
pub event_id: UsageEventId,
/// Billing subject the usage is attributed to.
pub subject: crate::identity::BillingSubject,
/// Quantities recorded per [`MeterKind`].
pub meter_set: MeterSet,
/// Model the usage was recorded against.
pub model_ref: ModelRef,
/// Provider reference, when one was recorded; `None` otherwise.
pub provider_ref: Option<ProviderRef>,
/// Provenance of the quantities (provider report, stream accumulation, estimate, or correction).
pub source: UsageSource,
/// Outcome of the request (success, error code, timeout, or unknown).
pub outcome: UsageOutcome,
/// Observation and (optional) completion timestamps, in UTC.
pub timing: UsageTiming,
/// Correlation identifiers for this event.
pub correlation: CorrelationIds,
/// String key/value metadata (see [`Attributes`]).
pub attributes: Attributes,
}
/// Per-meter quantities for one usage observation, keyed by [`MeterKind`].
///
/// Reading a kind that was never recorded through [`MeterSet::get`] yields zero.
///
/// NOTE(review): `meters` is public, so callers can write it directly and
/// bypass the overflow check performed by `MeterSet::accumulate`.
/// NOTE(review): serialized key order follows `HashMap` iteration order, which
/// is not deterministic across processes.
/// NOTE(review): `MeterKind::Custom` is a newtype variant; serializers that
/// require string map keys (e.g. serde_json) will likely fail to serialize this
/// map when a `Custom` meter is present — confirm which format is used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeterSet {
/// Raw per-kind totals. Prefer `accumulate` for overflow-checked updates.
pub meters: HashMap<MeterKind, u64>,
}
impl MeterSet {
pub fn new() -> Self {
Self {
meters: HashMap::new(),
}
}
pub fn accumulate(&mut self, kind: MeterKind, quantity: u64) -> Result<(), MeterSetError> {
use std::collections::hash_map::Entry;
match self.meters.entry(kind) {
Entry::Occupied(mut e) => {
let new_val = e
.get()
.checked_add(quantity)
.ok_or_else(|| MeterSetError::Overflow(e.key().clone()))?;
e.insert(new_val);
}
Entry::Vacant(e) => {
e.insert(quantity);
}
}
Ok(())
}
pub fn get(&self, kind: &MeterKind) -> u64 {
self.meters.get(kind).copied().unwrap_or(0)
}
}
impl Default for MeterSet {
fn default() -> Self {
Self::new()
}
}
/// Error returned by [`MeterSet::accumulate`].
#[derive(Debug, Clone)]
pub enum MeterSetError {
    /// Adding to the total for this meter kind would exceed `u64::MAX`.
    Overflow(MeterKind),
}

impl std::fmt::Display for MeterSetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MeterSetError::Overflow(kind) => write!(f, "Meter overflow for {kind:?}"),
        }
    }
}

/// Allows `MeterSetError` to be boxed as `dyn Error` and propagated with `?`
/// into generic error types.
impl std::error::Error for MeterSetError {}
/// The kind of quantity a meter counts.
///
/// `Custom(name)` never compares equal to a built-in variant, even when `name`
/// spells a built-in variant's name (e.g. `Custom("InputTokens")`).
///
/// NOTE(review): variant names/order are part of the derived serde encoding;
/// renaming or reordering changes the wire format.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
pub enum MeterKind {
/// Input tokens.
InputTokens,
/// Output tokens.
OutputTokens,
/// Input tokens served from a cache (presumably priced differently — confirm against pricing).
CachedInputTokens,
/// Tokens written to a cache (presumably cache-creation tokens — confirm against pricing).
CachedWriteTokens,
/// Reasoning tokens. NOTE(review): confirm whether producers also count these in `OutputTokens`.
ReasoningTokens,
/// Audio input tokens.
AudioInputTokens,
/// Audio output tokens.
AudioOutputTokens,
/// Image count (unit inferred from the name — confirm with producers).
ImageCount,
/// Caller-defined meter identified by name.
Custom(String),
}
/// Provenance of the quantities in a usage observation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UsageSource {
/// Quantities reported by the provider.
ProviderReported,
/// Quantities accumulated from a streamed response (accumulation strategy defined by the producer — confirm).
StreamAccumulated,
/// Quantities that were estimated rather than reported.
Estimated,
/// A correction; `correction_of` is the id of the observation being corrected
/// (whether it replaces or adjusts that observation is not defined here — confirm with consumers).
Corrected { correction_of: UsageEventId },
}
/// How the request behind a usage observation ended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UsageOutcome {
/// The request succeeded.
Success,
/// The request failed; `code` is an error code whose format is not defined in this file.
Error { code: String },
/// The request timed out.
Timeout,
/// The outcome was not known when the usage was recorded (presumably — confirm with producers).
Unknown,
}
/// String key/value metadata attached to a usage observation.
///
/// `Attributes::insert` rejects keys with the reserved `sys.` prefix and keys or
/// values over the byte-length limits. Serialized transparently as a plain map.
///
/// NOTE(review): the derived `Deserialize` fills `inner` directly, so data that
/// arrives through deserialization is NOT checked against those rules.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Attributes {
/// Backing map. Private; within this file only `insert` mutates it.
inner: HashMap<String, String>,
}
impl Attributes {
pub const MAX_KEY_LEN: usize = 64;
pub const MAX_VALUE_LEN: usize = 256;
pub fn new() -> Self {
Self {
inner: HashMap::new(),
}
}
pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) -> Result<(), AttributeError> {
let key = key.into();
let value = value.into();
if key.starts_with("sys.") {
return Err(AttributeError::ReservedPrefix(key));
}
if key.len() > Self::MAX_KEY_LEN {
let len = key.len();
return Err(AttributeError::KeyTooLong { key, len });
}
if value.len() > Self::MAX_VALUE_LEN {
let len = value.len();
return Err(AttributeError::ValueTooLong { key, len });
}
self.inner.insert(key, value);
Ok(())
}
pub fn get(&self, key: &str) -> Option<&str> {
self.inner.get(key).map(|s| s.as_str())
}
pub fn iter(&self) -> impl Iterator<Item = (&String, &String)> {
self.inner.iter()
}
}
/// Reasons `Attributes::insert` can reject an entry.
#[derive(Debug, Clone)]
pub enum AttributeError {
    /// The key (carried verbatim) starts with the reserved `sys.` prefix.
    ReservedPrefix(String),
    /// The key is longer than allowed; `len` is its length in bytes.
    KeyTooLong { key: String, len: usize },
    /// The value stored under `key` is longer than allowed; `len` is the value's length in bytes.
    ValueTooLong { key: String, len: usize },
}

impl std::fmt::Display for AttributeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AttributeError::ReservedPrefix(key) => {
                write!(f, "Reserved attribute key prefix 'sys.': {key}")
            }
            AttributeError::KeyTooLong { key, len } => {
                write!(f, "Attribute key too long ({len} bytes): {key}")
            }
            AttributeError::ValueTooLong { key, len } => {
                write!(f, "Attribute value too long ({len} bytes) for key: {key}")
            }
        }
    }
}

/// Allows `AttributeError` to be boxed as `dyn Error` and propagated with `?`
/// into generic error types.
impl std::error::Error for AttributeError {}
/// Timestamps for a usage observation, in UTC.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageTiming {
/// When the usage was observed.
pub observed_at: DateTime<Utc>,
/// When the request completed, if recorded.
/// NOTE(review): nothing here checks that this is not earlier than `observed_at`.
pub completed_at: Option<DateTime<Utc>>,
}
/// A three-letter currency code such as `"USD"`.
///
/// Parsing via `FromStr` accepts exactly three ASCII uppercase letters.
/// NOTE(review): the inner field is public, so `CurrencyCode(..)` can be built
/// without that validation; the derived `Deserialize` also skips it.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
pub struct CurrencyCode(pub String);
impl CurrencyCode {
pub fn usd() -> Self {
CurrencyCode("USD".to_string())
}
pub fn cny() -> Self {
CurrencyCode("CNY".to_string())
}
pub fn eur() -> Self {
CurrencyCode("EUR".to_string())
}
}
impl std::str::FromStr for CurrencyCode {
    type Err = CurrencyCodeError;

    /// Accepts exactly three ASCII uppercase letters (`"USD"`, `"EUR"`, ...).
    /// Lowercase, mixed case, non-ASCII input or any other length is rejected
    /// with [`CurrencyCodeError::Invalid`] carrying the original input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Non-ASCII bytes are never ASCII uppercase, so a byte-level check is
        // equivalent to checking characters.
        let raw = s.as_bytes();
        let well_formed = raw.len() == 3 && raw.iter().all(u8::is_ascii_uppercase);
        if !well_formed {
            return Err(CurrencyCodeError::Invalid(s.to_string()));
        }
        Ok(CurrencyCode(s.to_string()))
    }
}
/// Error returned when a string cannot be parsed as a [`CurrencyCode`].
#[derive(Debug, Clone)]
pub enum CurrencyCodeError {
    /// The input (kept verbatim) is not exactly three ASCII uppercase letters.
    Invalid(String),
}

impl std::fmt::Display for CurrencyCodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CurrencyCodeError::Invalid(s) => write!(f, "Invalid currency code: {s}"),
        }
    }
}

/// Allows `CurrencyCodeError` to be boxed as `dyn Error` and propagated with
/// `?` into generic error types.
impl std::error::Error for CurrencyCodeError {}
#[cfg(test)]
mod tests {
    // Unit tests for meter accumulation, attribute validation and currency parsing.
    use super::*;

    #[test]
    fn meter_set_accumulate_accumulates() {
        let mut ms = MeterSet::new();
        ms.accumulate(MeterKind::InputTokens, 100).unwrap();
        ms.accumulate(MeterKind::InputTokens, 50).unwrap();
        assert_eq!(ms.get(&MeterKind::InputTokens), 150);
    }

    #[test]
    fn meter_set_overflow_returns_error() {
        let mut ms = MeterSet::new();
        ms.accumulate(MeterKind::InputTokens, u64::MAX).unwrap();
        let result = ms.accumulate(MeterKind::InputTokens, 1);
        assert!(matches!(result, Err(MeterSetError::Overflow(_))));
    }

    #[test]
    fn meter_set_overflow_leaves_previous_total() {
        let mut ms = MeterSet::new();
        ms.accumulate(MeterKind::OutputTokens, u64::MAX).unwrap();
        assert!(ms.accumulate(MeterKind::OutputTokens, 1).is_err());
        assert_eq!(ms.get(&MeterKind::OutputTokens), u64::MAX);
    }

    #[test]
    fn meter_set_tracks_kinds_independently() {
        let mut ms = MeterSet::new();
        ms.accumulate(MeterKind::InputTokens, 10).unwrap();
        ms.accumulate(MeterKind::Custom("gpu_seconds".to_string()), 3)
            .unwrap();
        assert_eq!(ms.get(&MeterKind::InputTokens), 10);
        assert_eq!(ms.get(&MeterKind::Custom("gpu_seconds".to_string())), 3);
        assert_eq!(ms.get(&MeterKind::OutputTokens), 0);
    }

    #[test]
    fn meter_set_get_missing_returns_zero() {
        let ms = MeterSet::new();
        assert_eq!(ms.get(&MeterKind::OutputTokens), 0);
    }

    #[test]
    fn meter_kind_custom_hash() {
        let mut map = HashMap::new();
        map.insert(MeterKind::Custom("test".to_string()), 1);
        assert_eq!(map.get(&MeterKind::Custom("test".to_string())), Some(&1));
    }

    #[test]
    fn meter_kind_enum_variant_not_confused_with_custom() {
        let mut map = HashMap::new();
        map.insert(MeterKind::InputTokens, 1);
        map.insert(MeterKind::Custom("InputTokens".to_string()), 2);
        assert_eq!(map.len(), 2);
    }

    #[test]
    fn attributes_insert_valid() {
        let mut attrs = Attributes::new();
        assert!(attrs.insert("key1", "value1").is_ok());
        assert_eq!(attrs.get("key1"), Some("value1"));
    }

    #[test]
    fn attributes_insert_overwrites_existing_value() {
        let mut attrs = Attributes::new();
        attrs.insert("key", "old").unwrap();
        attrs.insert("key", "new").unwrap();
        assert_eq!(attrs.get("key"), Some("new"));
        assert_eq!(attrs.iter().count(), 1);
    }

    #[test]
    fn attributes_rejects_reserved_prefix() {
        let mut attrs = Attributes::new();
        let result = attrs.insert("sys.test", "value");
        assert!(matches!(result, Err(AttributeError::ReservedPrefix(_))));
    }

    #[test]
    fn attributes_reserved_prefix_is_not_stored() {
        let mut attrs = Attributes::new();
        let _ = attrs.insert("sys.test", "value");
        assert_eq!(attrs.get("sys.test"), None);
    }

    #[test]
    fn attributes_rejects_too_long_key() {
        let mut attrs = Attributes::new();
        let long_key = "a".repeat(65);
        let result = attrs.insert(long_key, "value");
        assert!(matches!(result, Err(AttributeError::KeyTooLong { .. })));
    }

    #[test]
    fn attributes_accepts_key_at_max_len() {
        let mut attrs = Attributes::new();
        let key = "k".repeat(Attributes::MAX_KEY_LEN);
        assert!(attrs.insert(key.clone(), "v").is_ok());
        assert_eq!(attrs.get(&key), Some("v"));
    }

    #[test]
    fn attributes_rejects_too_long_value() {
        let mut attrs = Attributes::new();
        let result = attrs.insert("key", "v".repeat(Attributes::MAX_VALUE_LEN + 1));
        assert!(matches!(
            result,
            Err(AttributeError::ValueTooLong { len, .. }) if len == Attributes::MAX_VALUE_LEN + 1
        ));
        assert_eq!(attrs.get("key"), None);
    }

    #[test]
    fn currency_code_parses_valid_code() {
        let parsed: CurrencyCode = "USD".parse().unwrap();
        assert_eq!(parsed, CurrencyCode::usd());
    }

    #[test]
    fn currency_code_rejects_malformed_codes() {
        for bad in ["usd", "Usd", "US", "USDX", "", "U5D"] {
            let err = bad.parse::<CurrencyCode>().unwrap_err();
            assert!(
                matches!(err, CurrencyCodeError::Invalid(ref s) if s == bad),
                "{bad}"
            );
        }
    }

    #[test]
    fn currency_code_error_display() {
        let err = "eur".parse::<CurrencyCode>().unwrap_err();
        assert_eq!(err.to_string(), "Invalid currency code: eur");
    }
}