Skip to main content

latch_billing/
pricing.rs

1//! Pricing module - defines pricing models and the `PricingSource` trait.
2//!
3//! The key design decision here is "push mode" for pricing:
4//! - `PricingSource` trait is provided for in-memory/file-based pricing (Phase 3).
5//! - For DB/remote pricing, the caller (e.g., xrouter adapter) should
6//!   asynchronously fetch the `PriceSnapshot` and pass it to `RatingEngine::rate()`.
7
8use crate::observation::MeterKind;
9use crate::CurrencyCode;
10use chrono::{DateTime, Utc};
11use rust_decimal::Decimal;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15/// Reference to a model for pricing purposes.
16///
17/// Pricing is keyed by `(billable_model, vendor, region, tier)`.
18/// The optional fields allow for flexible pricing strategies:
19/// - Global pricing: only `billable_model` is set
20/// - Vendor-specific: add `vendor`
21/// - Regional pricing: add `region`
22/// - Tier-based: add `tier` (e.g., "enterprise", "standard")
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ModelRef {
25    /// The billable model name (e.g., "gpt-4o", "claude-3-opus").
26    pub billable_model: String,
27
28    /// Vendor/provider name (e.g., "openai", "anthropic").
29    pub vendor: Option<String>,
30
31    /// Region for regional pricing (e.g., "us", "eu").
32    pub region: Option<String>,
33
34    /// Pricing tier (e.g., "standard", "enterprise", "batch").
35    pub tier: Option<String>,
36}
37
38/// Reference to a provider instance.
39///
40/// Used to distinguish between different deployments of the same provider.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ProviderRef {
43    /// Provider identifier (e.g., "openai-west", "self-hosted-1").
44    pub provider_id: String,
45}
46
47/// A snapshot of prices for a model at a point in time.
48///
49/// This is the key structure passed to `RatingEngine::rate()`.
50/// It is immutable once created - price changes create a new snapshot.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct PriceSnapshot {
53    /// Unique identifier for this snapshot (for audit/traceability).
54    pub snapshot_id: String,
55
56    /// Which model this snapshot applies to.
57    pub model_ref: ModelRef,
58
59    /// Currency for all prices in this snapshot.
60    pub currency: CurrencyCode,
61
62    /// Price per meter kind.
63    /// Only meter kinds present in this map will be billed.
64    pub prices: HashMap<MeterKind, MeterPrice>,
65
66    /// Optional tier configuration for volume-based pricing.
67    pub tiers: Option<TierConfig>,
68
69    /// When this snapshot becomes effective.
70    pub effective_from: DateTime<Utc>,
71
72    /// When this snapshot expires (None = no expiration).
73    pub effective_until: Option<DateTime<Utc>>,
74}
75
76/// Price for a single meter kind.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct MeterPrice {
79    /// Price per 1 unit of meter (e.g., per 1 token, not per 1M).
80    pub unit_price: Decimal,
81
82    /// Display string for the unit (e.g., "1M tokens", "per image").
83    pub unit_display: String,
84}
85
86/// Tier configuration for volume-based pricing.
87///
88/// Switches prices based on cumulative usage (in MTok) of a baseline meter.
89/// The cumulative value is provided by the caller when calling `RatingEngine::rate()`.
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct TierConfig {
92    /// Which meter to accumulate for tier calculation.
93    pub baseline_meter: TierBaseline,
94
95    /// Who provides the cumulative value.
96    pub accumulation_scope: AccumulationScope,
97
98    /// Sorted list of tier boundaries.
99    ///
100    /// Each boundary defines a pricing threshold.
101    /// The list should be sorted by `up_to_mtok` ascending.
102    pub boundaries: Vec<TierBoundary>,
103}
104
105/// Which meter to use as baseline for tier calculation.
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub enum TierBaseline {
108    /// Accumulate by a single meter kind.
109    Meter(MeterKind),
110    /// Accumulate by total tokens (all meter kinds).
111    TotalTokens,
112}
113
114/// Who provides the cumulative usage value.
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub enum AccumulationScope {
117    /// Caller provides cumulative value via `RatingContext` (only executable path currently).
118    CallerProvided,
119}
120
121/// A single tier boundary.
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct TierBoundary {
124    /// Threshold (inclusive) in MTok (millions of tokens).
125    /// `0` means "starting from 0" (first tier).
126    pub up_to_mtok: u64,
127
128    /// Price multiplier relative to `base_price`.
129    /// If set, final price = `base_price * price_multiplier`.
130    pub price_multiplier: Option<Decimal>,
131
132    /// OR: absolute price per MTok.
133    /// If set, this overrides `price_multiplier`.
134    pub absolute_price_per_mtok: Option<Decimal>,
135}
136
137/// Trait for resolving price snapshots.
138///
139/// **Design note**: This trait is sync and intended for in-memory
140/// or file-based pricing sources (e.g., `TomlPricingSource` in Phase 3).
141///
142/// For DB/remote pricing, use **push mode** instead:
143/// 1. Caller (e.g., xrouter adapter) asynchronously fetches pricing data
144/// 2. Caller constructs `PriceSnapshot`
145/// 3. Caller passes `PriceSnapshot` to `RatingEngine::rate()`
146///
147/// This keeps `tokenbill-core` sync and avoids forcing async on all users.
148pub trait PricingSource: Send + Sync {
149    /// Resolve a price snapshot for the given model and provider.
150    fn resolve_snapshot(
151        &self,
152        model_ref: &ModelRef,
153        provider_ref: Option<&ProviderRef>,
154    ) -> Result<PriceSnapshot, PricingError>;
155}
156
/// Error type for pricing operations (e.g. `PricingSource::resolve_snapshot`).
///
/// Implements [`std::error::Error`], so it can be propagated with `?` into
/// `Box<dyn std::error::Error>` and similar error containers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PricingError {
    /// No price was found for the given model/provider.
    NotFound {
        /// Model identifier that had no price.
        model: String,
        /// Provider identifier that was requested, if any.
        provider: Option<String>,
    },
    /// The pricing configuration is invalid; the payload describes the problem.
    InvalidConfig(String),
    /// Any other pricing failure.
    Other(String),
}

impl std::fmt::Display for PricingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PricingError::NotFound { model, provider } => {
                // `provider` is rendered with `Debug`, i.e. `Some("...")` or `None`.
                write!(f, "Price not found for model={model}, provider={provider:?}")
            }
            PricingError::InvalidConfig(s) => write!(f, "Invalid pricing config: {s}"),
            PricingError::Other(s) => write!(f, "Pricing error: {s}"),
        }
    }
}

// No variant wraps an underlying cause, so the default `source()` (which returns
// `None`) is correct.
impl std::error::Error for PricingError {}
182
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_ref_equality() {
        let m1 = ModelRef {
            billable_model: "gpt-4o".to_string(),
            vendor: Some("openai".to_string()),
            region: None,
            tier: None,
        };
        let m2 = ModelRef {
            billable_model: "gpt-4o".to_string(),
            vendor: Some("openai".to_string()),
            region: None,
            tier: None,
        };
        // ModelRef intentionally does not derive Eq/PartialEq; we compare by
        // snapshot_id or use custom comparison. All four fields form the
        // pricing key, so compare every one of them. Checking only
        // `billable_model` would let a vendor/region/tier mismatch go unnoticed.
        assert_eq!(m1.billable_model, m2.billable_model);
        assert_eq!(m1.vendor, m2.vendor);
        assert_eq!(m1.region, m2.region);
        assert_eq!(m1.tier, m2.tier);

        // A ref that differs only in region must be distinguishable.
        let m3 = ModelRef {
            region: Some("eu".to_string()),
            ..m1.clone()
        };
        assert_eq!(m1.billable_model, m3.billable_model);
        assert_ne!(m1.region, m3.region);
    }

    #[test]
    fn meter_price_creation() {
        let price = MeterPrice {
            unit_price: Decimal::new(30, 2), // 0.30
            unit_display: "1M tokens".to_string(),
        };
        assert_eq!(price.unit_price, Decimal::new(30, 2));
        assert_eq!(price.unit_display, "1M tokens");
    }

    #[test]
    fn pricing_error_display() {
        let not_found = PricingError::NotFound {
            model: "gpt-4o".to_string(),
            provider: Some("openai-west".to_string()),
        };
        assert_eq!(
            not_found.to_string(),
            "Price not found for model=gpt-4o, provider=Some(\"openai-west\")"
        );
        assert_eq!(
            PricingError::InvalidConfig("bad tier".to_string()).to_string(),
            "Invalid pricing config: bad tier"
        );
        assert_eq!(
            PricingError::Other("boom".to_string()).to_string(),
            "Pricing error: boom"
        );
    }
}