use crate::observation::MeterKind;
use crate::CurrencyCode;
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Identifies the model whose prices a [`PricingSource`] should resolve.
///
/// `PartialEq`, `Eq` and `Hash` are derived so a reference can be compared
/// directly (all four fields participate) or used as a map/cache key.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelRef {
    /// Model name used for billing, e.g. `"gpt-4o"`.
    pub billable_model: String,
    /// Optional vendor qualifier, e.g. `"openai"`.
    pub vendor: Option<String>,
    /// Optional region qualifier; how it affects resolution is defined by the
    /// `PricingSource` implementation.
    pub region: Option<String>,
    /// Optional free-form tier qualifier (a string, distinct from [`TierConfig`]);
    /// interpretation is defined by the `PricingSource` implementation.
    pub tier: Option<String>,
}
/// Identifies a provider by id; optionally passed to
/// [`PricingSource::resolve_snapshot`] alongside a [`ModelRef`].
///
/// `PartialEq`, `Eq` and `Hash` are derived so it can be compared directly or
/// used as part of a map/cache key.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ProviderRef {
    /// Provider identifier.
    pub provider_id: String,
}
/// Prices for one model, together with the time window they are effective for.
///
/// NOTE(review): nothing in this type checks that `effective_until` is later
/// than `effective_from`; confirm where that invariant is enforced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PriceSnapshot {
/// Identifier of this snapshot.
pub snapshot_id: String,
/// The model these prices belong to.
pub model_ref: ModelRef,
/// Currency of this snapshot's prices (presumably applies to every entry in
/// `prices` — confirm).
pub currency: CurrencyCode,
/// Unit price for each meter priced by this snapshot.
pub prices: HashMap<MeterKind, MeterPrice>,
/// Optional tiered-pricing configuration; `None` presumably means flat
/// pricing — confirm with consumers.
pub tiers: Option<TierConfig>,
/// Start of the effective window, in UTC.
pub effective_from: DateTime<Utc>,
/// End of the effective window, in UTC; `None` presumably means open-ended.
/// NOTE(review): inclusive vs. exclusive bound is not defined here.
pub effective_until: Option<DateTime<Utc>>,
}
/// Price for a single meter within a [`PriceSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeterPrice {
/// Price per unit described by `unit_display` (presumably denominated in the
/// enclosing snapshot's `currency` — confirm).
pub unit_price: Decimal,
/// Label for the unit the price is quoted per, e.g. `"1M tokens"`
/// (presumably for display only — confirm it is not parsed anywhere).
pub unit_display: String,
}
/// Tiered-pricing configuration attached to a [`PriceSnapshot`] via its
/// `tiers` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TierConfig {
/// Usage quantity that is compared against the boundaries' thresholds
/// (presumably — confirm with the tier-selection code).
pub baseline_meter: TierBaseline,
/// Where the accumulated usage used for tier selection comes from.
pub accumulation_scope: AccumulationScope,
/// Tier boundaries.
/// NOTE(review): presumably expected sorted ascending by `up_to_mtok`;
/// nothing in this type enforces that ordering.
pub boundaries: Vec<TierBoundary>,
}
/// The usage quantity that tier boundaries are measured against.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TierBaseline {
/// Measure tiers against the quantity of a single meter.
Meter(MeterKind),
/// Measure tiers against the total token count (presumably summed across
/// all token meters — confirm).
TotalTokens,
}
/// Where the accumulated usage used for tier selection comes from.
///
/// Fieldless, so it is `Copy` and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AccumulationScope {
    /// The caller supplies the accumulated usage figure (presumably alongside
    /// the request being priced — confirm with the pricing engine).
    CallerProvided,
}
/// One tier: its upper bound and how the price is adjusted within it.
///
/// NOTE(review): `price_multiplier` and `absolute_price_per_mtok` look like
/// alternatives, but the type allows both or neither to be set; confirm how
/// consumers resolve those cases.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TierBoundary {
/// Upper bound of this tier, presumably in millions of tokens ("mtok").
pub up_to_mtok: u64,
/// Optional multiplier, presumably applied to the base meter price within
/// this tier.
pub price_multiplier: Option<Decimal>,
/// Optional absolute price per million tokens within this tier, presumably
/// used instead of the base meter price.
pub absolute_price_per_mtok: Option<Decimal>,
}
/// Resolves the [`PriceSnapshot`] for a model, optionally scoped to a provider.
///
/// The `Send + Sync` supertraits allow an implementation to be shared across
/// threads (e.g. behind an `Arc`).
pub trait PricingSource: Send + Sync {
/// Returns the price snapshot for `model_ref`, optionally narrowed by
/// `provider_ref`.
///
/// There is no timestamp parameter, so which snapshot is returned when several
/// effective windows exist is up to the implementation.
///
/// # Errors
///
/// Returns a [`PricingError`] on failure; `PricingError::NotFound` is
/// presumably used when no snapshot matches — confirm against implementations.
fn resolve_snapshot(
&self,
model_ref: &ModelRef,
provider_ref: Option<&ProviderRef>,
) -> Result<PriceSnapshot, PricingError>;
}
/// Errors produced while resolving prices.
#[derive(Debug, Clone)]
pub enum PricingError {
/// No price was found for the requested model (and provider, if any).
NotFound {
/// The requested model identifier (presumably `ModelRef::billable_model`).
model: String,
/// The requested provider id, if one was given.
provider: Option<String>,
},
/// The pricing configuration is invalid; the payload describes the problem.
InvalidConfig(String),
/// Any other failure; the payload is a free-form description.
Other(String),
}
impl std::fmt::Display for PricingError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PricingError::NotFound { model, provider } => {
write!(f, "Price not found for model={model}, provider={provider:?}")
}
PricingError::InvalidConfig(s) => write!(f, "Invalid pricing config: {s}"),
PricingError::Other(s) => write!(f, "Pricing error: {s}"),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_ref_equality() {
        let m1 = ModelRef {
            billable_model: "gpt-4o".to_string(),
            vendor: Some("openai".to_string()),
            region: None,
            tier: None,
        };
        let m2 = ModelRef {
            billable_model: "gpt-4o".to_string(),
            vendor: Some("openai".to_string()),
            region: None,
            tier: None,
        };
        // Compare every field, not just `billable_model`, so a mismatch in any
        // qualifier is caught. Field-wise comparison keeps the test independent
        // of whether `ModelRef` derives `PartialEq`.
        assert_eq!(m1.billable_model, m2.billable_model);
        assert_eq!(m1.vendor, m2.vendor);
        assert_eq!(m1.region, m2.region);
        assert_eq!(m1.tier, m2.tier);
    }

    #[test]
    fn meter_price_creation() {
        let price = MeterPrice {
            unit_price: Decimal::new(30, 2),
            unit_display: "1M tokens".to_string(),
        };
        assert_eq!(price.unit_price, Decimal::new(30, 2));
        assert_eq!(price.unit_display, "1M tokens");
    }

    #[test]
    fn pricing_error_display() {
        let with_provider = PricingError::NotFound {
            model: "gpt-4o".to_string(),
            provider: Some("openai".to_string()),
        };
        assert_eq!(
            with_provider.to_string(),
            "Price not found for model=gpt-4o, provider=Some(\"openai\")"
        );

        let without_provider = PricingError::NotFound {
            model: "gpt-4o".to_string(),
            provider: None,
        };
        assert_eq!(
            without_provider.to_string(),
            "Price not found for model=gpt-4o, provider=None"
        );

        assert_eq!(
            PricingError::InvalidConfig("empty tiers".to_string()).to_string(),
            "Invalid pricing config: empty tiers"
        );
        assert_eq!(
            PricingError::Other("timeout".to_string()).to_string(),
            "Pricing error: timeout"
        );
    }
}