use crate::CurrencyCode;
use crate::observation::{MeterKind, UsageObservation};
use crate::pricing::{PriceSnapshot, TierConfig};
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
/// A usage observation together with the rating that was computed for it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RatedUsageRecord {
    /// Identifier of this rated record.
    pub rated_record_id: String,
    /// The observation that was rated.
    pub observation: UsageObservation,
    /// Pricing outcome for `observation`.
    pub rating: RatingResult,
    /// Optional identifier of an earlier record this one replaces.
    /// NOTE(review): semantics inferred from the field name — confirm with the writer.
    pub supersedes: Option<String>,
}
/// Outcome of rating one observation against one price snapshot.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RatingResult {
    /// Priced entries; `DefaultRatingEngine` emits one per observed meter.
    pub line_items: Vec<RatedLineItem>,
    /// Sum of the line items' subtotals (as computed by `DefaultRatingEngine`).
    pub total_cost: Decimal,
    /// Currency of every amount in this result, taken from the price snapshot.
    pub currency: CurrencyCode,
    /// Identifier of the price snapshot that was used.
    pub price_snapshot_id: String,
    /// Moment the rating was produced.
    pub rated_at: DateTime<Utc>,
}
/// Extra inputs to a rating call that are not part of the observation itself.
///
/// `Default` yields zero cumulative usage and no period/scope.
#[derive(Default, Clone, Debug)]
pub struct RatingContext {
    /// Cumulative baseline usage; `DefaultRatingEngine` compares it against tier
    /// boundaries (`up_to_mtok`) to choose a price multiplier.
    pub cumulative_baseline_usage_mtok: u64,
    /// Billing period label. Not read by the rating code in this module.
    pub billing_period: Option<String>,
    /// Tenant scope label. Not read by the rating code in this module.
    pub tenant_scope: Option<String>,
}
/// A single priced meter inside a [`RatingResult`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RatedLineItem {
    /// Meter being priced.
    pub meter_kind: MeterKind,
    /// Quantity observed for `meter_kind`.
    pub quantity: u64,
    /// Per-unit price taken from the price snapshot.
    pub unit_price: Decimal,
    /// Cost of this line. `DefaultRatingEngine` sets it to `unit_price * quantity`,
    /// scaled by the tier multiplier when one applies.
    pub subtotal: Decimal,
}
/// Strategy that turns a [`UsageObservation`] into a priced [`RatingResult`].
///
/// The `Send + Sync` supertraits let a single engine be shared between threads.
pub trait RatingEngine: Send + Sync {
    /// Prices `observation` using the rates in `snapshot`; `context` carries
    /// inputs that are not part of the observation (e.g. cumulative usage used
    /// for tier selection by [`DefaultRatingEngine`]).
    ///
    /// # Errors
    ///
    /// Returns a [`RatingError`] when the observation cannot be rated.
    fn rate(
        &self,
        observation: &UsageObservation,
        snapshot: &PriceSnapshot,
        context: &RatingContext,
    ) -> Result<RatingResult, RatingError>;
}
/// Reasons a usage observation could not be rated.
#[derive(Debug, Clone)]
pub enum RatingError {
    /// The price snapshot has no price for a meter present in the observation.
    NoPriceForMeter {
        meter_kind: MeterKind,
        snapshot_id: String,
    },
    /// A checked `Decimal` multiplication or addition overflowed while rating
    /// `meter_kind`.
    DecimalOverflow {
        meter_kind: MeterKind,
        quantity: u64,
        unit_price: Decimal,
    },
    /// The tier configuration is unusable; the payload describes why.
    /// NOTE(review): not raised by the code in this module.
    InvalidTierConfig(String),
    /// No tier matched the given usage.
    /// NOTE(review): not raised by the code in this module.
    NoMatchingTier {
        usage_mtok: u64,
    },
    /// Any other failure, described by the payload.
    Other(String),
}

impl std::fmt::Display for RatingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoPriceForMeter {
                meter_kind,
                snapshot_id,
            } => write!(
                f,
                "no price for meter {:?} in price snapshot {}",
                meter_kind, snapshot_id
            ),
            Self::DecimalOverflow {
                meter_kind,
                quantity,
                unit_price,
            } => write!(
                f,
                "decimal overflow rating meter {:?} (quantity {}, unit price {})",
                meter_kind, quantity, unit_price
            ),
            Self::InvalidTierConfig(reason) => {
                write!(f, "invalid tier configuration: {}", reason)
            }
            Self::NoMatchingTier { usage_mtok } => {
                write!(f, "no pricing tier matches usage of {} mtok", usage_mtok)
            }
            Self::Other(message) => f.write_str(message),
        }
    }
}

// Lets `RatingError` flow into `Box<dyn Error>` / `anyhow`-style error handling.
impl std::error::Error for RatingError {}
/// Stateless rating engine: prices each observed meter against a price snapshot,
/// scaling subtotals by a tier multiplier when the snapshot has tiers.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultRatingEngine;

impl DefaultRatingEngine {
    /// Creates a new engine; equivalent to `DefaultRatingEngine::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
impl RatingEngine for DefaultRatingEngine {
fn rate(
&self,
observation: &UsageObservation,
snapshot: &PriceSnapshot,
context: &RatingContext,
) -> Result<RatingResult, RatingError> {
let mut line_items = Vec::new();
let mut total_cost: Decimal = 0.into();
for (meter_kind, &quantity) in &observation.meter_set.meters {
let price = snapshot.prices.get(meter_kind).ok_or_else(|| {
RatingError::NoPriceForMeter {
meter_kind: meter_kind.clone(),
snapshot_id: snapshot.snapshot_id.clone(),
}
})?;
let quantity_dec: Decimal = quantity.into();
let unit_price = price.unit_price;
let subtotal = unit_price
.checked_mul(quantity_dec)
.ok_or_else(|| RatingError::DecimalOverflow {
meter_kind: meter_kind.clone(),
quantity,
unit_price,
})?;
let final_subtotal = if let Some(ref tier_config) = snapshot.tiers {
let multiplier = calculate_tier_multiplier(tier_config, context)?;
if let Some(m) = multiplier {
subtotal
.checked_mul(m)
.ok_or_else(|| RatingError::DecimalOverflow {
meter_kind: meter_kind.clone(),
quantity,
unit_price,
})?
} else {
subtotal
}
} else {
subtotal
};
total_cost = total_cost
.checked_add(final_subtotal)
.ok_or_else(|| RatingError::DecimalOverflow {
meter_kind: meter_kind.clone(),
quantity,
unit_price,
})?;
line_items.push(RatedLineItem {
meter_kind: meter_kind.clone(),
quantity,
unit_price,
subtotal: final_subtotal,
});
}
Ok(RatingResult {
line_items,
total_cost,
currency: snapshot.currency.clone(),
price_snapshot_id: snapshot.snapshot_id.clone(),
rated_at: Utc::now(),
})
}
}
/// Resolves the price multiplier of the tier that contains the context's usage.
///
/// A boundary contains the usage when
/// `context.cumulative_baseline_usage_mtok <= boundary.up_to_mtok`. Among all
/// containing boundaries, the one with the smallest `up_to_mtok` is the tier the
/// usage actually falls into. This holds regardless of the order of
/// `tier_config.boundaries`. That tier's `price_multiplier` is returned as-is, so
/// `Ok(None)` can mean "this tier applies no adjustment".
///
/// Also returns `Ok(None)` when no boundary contains the usage, leaving subtotals
/// unscaled.
/// NOTE(review): `RatingError::NoMatchingTier` exists but is not raised here —
/// confirm whether usage beyond the last tier should be rejected instead.
///
/// # Errors
///
/// Currently never returns `Err`.
fn calculate_tier_multiplier(
    tier_config: &TierConfig,
    context: &RatingContext,
) -> Result<Option<Decimal>, RatingError> {
    let cumulative = context.cumulative_baseline_usage_mtok;
    // Pick the tightest containing boundary. The previous "last match wins" loop
    // selected the outermost tier for ascending boundaries.
    let containing_tier = tier_config
        .boundaries
        .iter()
        .filter(|boundary| cumulative <= boundary.up_to_mtok)
        .min_by_key(|boundary| boundary.up_to_mtok);
    Ok(containing_tier.and_then(|boundary| boundary.price_multiplier))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::observation::MeterSet;
    use crate::pricing::MeterPrice;
    use rust_decimal_macros::dec;
    use std::collections::HashMap;

    /// Model reference shared by the fixture observation and snapshot.
    fn test_model_ref() -> crate::pricing::ModelRef {
        crate::pricing::ModelRef {
            billable_model: "test".to_string(),
            vendor: None,
            region: None,
            tier: None,
        }
    }

    /// Builds an estimated, successful observation carrying `meter_set`.
    fn test_observation(meter_set: MeterSet) -> UsageObservation {
        UsageObservation {
            event_id: crate::identity::UsageEventId::from_raw("test-1"),
            subject: crate::identity::BillingSubject::default(),
            meter_set,
            model_ref: test_model_ref(),
            provider_ref: None,
            source: crate::observation::UsageSource::Estimated,
            outcome: crate::observation::UsageOutcome::Success,
            timing: crate::observation::UsageTiming {
                observed_at: Utc::now(),
                completed_at: None,
            },
            correlation: crate::identity::CorrelationIds::default(),
            attributes: crate::observation::Attributes::new(),
        }
    }

    /// Builds an untiered USD snapshot with id `test-snap` and the given prices.
    fn test_snapshot(prices: HashMap<MeterKind, MeterPrice>) -> PriceSnapshot {
        PriceSnapshot {
            snapshot_id: "test-snap".to_string(),
            model_ref: test_model_ref(),
            currency: CurrencyCode::usd(),
            prices,
            tiers: None,
            effective_from: Utc::now(),
            effective_until: None,
        }
    }

    #[test]
    fn rated_line_item_calculation() {
        let item = RatedLineItem {
            meter_kind: MeterKind::InputTokens,
            quantity: 1000,
            unit_price: dec!(0.0003),
            subtotal: dec!(0.30),
        };
        assert_eq!(item.unit_price * Decimal::from(item.quantity), item.subtotal);
    }

    #[test]
    fn rating_result_total_cost() {
        let result = RatingResult {
            line_items: vec![],
            total_cost: dec!(1.50),
            currency: CurrencyCode::usd(),
            price_snapshot_id: "snap-123".to_string(),
            rated_at: Utc::now(),
        };
        assert_eq!(result.total_cost, dec!(1.50));
    }

    #[test]
    fn default_rating_engine_basic() {
        let engine = DefaultRatingEngine::new();
        let mut meter_set = MeterSet::new();
        meter_set.accumulate(MeterKind::InputTokens, 1000).unwrap();
        let observation = test_observation(meter_set);
        let mut prices = HashMap::new();
        prices.insert(
            MeterKind::InputTokens,
            MeterPrice {
                unit_price: dec!(0.0003),
                unit_display: "1M tokens".to_string(),
            },
        );
        let snapshot = test_snapshot(prices);
        let context = RatingContext::default();
        let result = engine.rate(&observation, &snapshot, &context).unwrap();
        assert_eq!(result.total_cost, dec!(0.30));
        assert_eq!(result.line_items.len(), 1);
        assert_eq!(result.line_items[0].subtotal, dec!(0.30));
        assert_eq!(result.price_snapshot_id, "test-snap");
    }

    #[test]
    fn missing_price_is_reported() {
        let mut meter_set = MeterSet::new();
        meter_set.accumulate(MeterKind::InputTokens, 10).unwrap();
        let err = DefaultRatingEngine::new()
            .rate(
                &test_observation(meter_set),
                &test_snapshot(HashMap::new()),
                &RatingContext::default(),
            )
            .unwrap_err();
        assert!(matches!(
            err,
            RatingError::NoPriceForMeter { ref snapshot_id, .. } if snapshot_id == "test-snap"
        ));
    }

    #[test]
    fn empty_observation_costs_nothing() {
        let result = DefaultRatingEngine::new()
            .rate(
                &test_observation(MeterSet::new()),
                &test_snapshot(HashMap::new()),
                &RatingContext::default(),
            )
            .unwrap();
        assert!(result.line_items.is_empty());
        assert_eq!(result.total_cost, Decimal::ZERO);
    }
}