use serde::{Deserialize, Serialize};
use crate::{PricingEstimate, Usage};
use super::ModelPricing;
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct ModelPricingDetails {
pub input_micros_per_million_tokens: u64,
pub output_micros_per_million_tokens: u64,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cache_write_micros_per_million_tokens: Option<u64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cache_read_micros_per_million_tokens: Option<u64>,
}
impl ModelPricingDetails {
#[must_use]
pub const fn new(input_micros: u64, output_micros: u64) -> Self {
Self {
input_micros_per_million_tokens: input_micros,
output_micros_per_million_tokens: output_micros,
cache_write_micros_per_million_tokens: None,
cache_read_micros_per_million_tokens: None,
}
}
#[must_use]
pub const fn standard_pricing(&self) -> ModelPricing {
ModelPricing::new(
self.input_micros_per_million_tokens,
self.output_micros_per_million_tokens,
)
}
#[must_use]
pub const fn with_cache_write_micros_per_million_tokens(mut self, micros: u64) -> Self {
self.cache_write_micros_per_million_tokens = Some(micros);
self
}
#[must_use]
pub const fn with_cache_read_micros_per_million_tokens(mut self, micros: u64) -> Self {
self.cache_read_micros_per_million_tokens = Some(micros);
self
}
#[must_use]
pub const fn estimate_micros(&self, usage: &Usage) -> u64 {
let cache_write_rate = match self.cache_write_micros_per_million_tokens {
Some(rate) => rate,
None => self.input_micros_per_million_tokens,
};
let cache_read_rate = match self.cache_read_micros_per_million_tokens {
Some(rate) => rate,
None => self.input_micros_per_million_tokens,
};
let standard_input_tokens = usage
.input_tokens
.saturating_sub(usage.cache_write_tokens)
.saturating_sub(usage.cache_read_tokens);
cost_for_tokens(standard_input_tokens, self.input_micros_per_million_tokens)
.saturating_add(cost_for_tokens(usage.cache_write_tokens, cache_write_rate))
.saturating_add(cost_for_tokens(usage.cache_read_tokens, cache_read_rate))
.saturating_add(cost_for_tokens(
usage.output_tokens,
self.output_micros_per_million_tokens,
))
}
#[must_use]
pub const fn estimate(&self, usage: &Usage) -> PricingEstimate {
PricingEstimate::from_micros_usd(self.estimate_micros(usage))
}
}
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct ModelPricingTier {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_input_tokens: Option<u64>,
pub pricing: ModelPricingDetails,
}
impl ModelPricingTier {
#[must_use]
pub const fn new(max_input_tokens: Option<u64>, pricing: ModelPricingDetails) -> Self {
Self {
max_input_tokens,
pricing,
}
}
#[must_use]
pub const fn matches(&self, input_tokens: u64) -> bool {
match self.max_input_tokens {
Some(max_input_tokens) => input_tokens <= max_input_tokens,
None => true,
}
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ModelPricingProfileKind {
Single(ModelPricingDetails),
Tiered(&'static [ModelPricingTier]),
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ModelPricingProfile {
kind: ModelPricingProfileKind,
}
impl ModelPricingProfile {
#[must_use]
pub const fn from_details(pricing: ModelPricingDetails) -> Self {
Self {
kind: ModelPricingProfileKind::Single(pricing),
}
}
#[must_use]
pub const fn from_tiers(tiers: &'static [ModelPricingTier]) -> Self {
Self {
kind: ModelPricingProfileKind::Tiered(tiers),
}
}
#[must_use]
pub fn default_details(&self) -> ModelPricingDetails {
match self.kind {
ModelPricingProfileKind::Single(pricing) => pricing,
ModelPricingProfileKind::Tiered(tiers) => tiers
.first()
.map_or_else(ModelPricingDetails::default, |tier| tier.pricing),
}
}
#[must_use]
pub const fn is_tiered(&self) -> bool {
matches!(self.kind, ModelPricingProfileKind::Tiered(_))
}
#[must_use]
pub const fn tiers(&self) -> Option<&'static [ModelPricingTier]> {
match self.kind {
ModelPricingProfileKind::Single(_) => None,
ModelPricingProfileKind::Tiered(tiers) => Some(tiers),
}
}
#[must_use]
pub fn standard_pricing(&self) -> ModelPricing {
self.default_details().standard_pricing()
}
#[must_use]
pub fn details_for_input_tokens(&self, input_tokens: u64) -> ModelPricingDetails {
match self.kind {
ModelPricingProfileKind::Single(pricing) => pricing,
ModelPricingProfileKind::Tiered(tiers) => {
let mut fallback = ModelPricingDetails::default();
for tier in tiers {
fallback = tier.pricing;
if tier.matches(input_tokens) {
return tier.pricing;
}
}
fallback
}
}
}
#[must_use]
pub fn details_for_usage(&self, usage: &Usage) -> ModelPricingDetails {
self.details_for_input_tokens(usage.input_tokens)
}
#[must_use]
pub fn estimate_micros(&self, usage: &Usage) -> u64 {
self.details_for_usage(usage).estimate_micros(usage)
}
#[must_use]
pub fn estimate(&self, usage: &Usage) -> PricingEstimate {
PricingEstimate::from_micros_usd(self.estimate_micros(usage))
}
}
#[allow(clippy::cast_possible_truncation)]
pub(crate) const fn cost_for_tokens(tokens: u64, micros_per_million_tokens: u64) -> u64 {
let cost = (tokens as u128)
.saturating_mul(micros_per_million_tokens as u128)
.saturating_add(999_999)
/ 1_000_000;
if cost > u64::MAX as u128 {
u64::MAX
} else {
cost as u64
}
}
const CNY_PER_USD_MILLIS: u64 = 7_200;
#[allow(clippy::cast_possible_truncation)]
pub(crate) const fn cny_millis_to_usd_micros(cny_millis: u64) -> u64 {
let micros = (cny_millis as u128)
.saturating_mul(1_000_000)
.saturating_add((CNY_PER_USD_MILLIS / 2) as u128)
/ CNY_PER_USD_MILLIS as u128;
if micros > u64::MAX as u128 {
u64::MAX
} else {
micros as u64
}
}
#[allow(clippy::cast_possible_truncation)]
pub(crate) const fn scale_rate(rate: u64, numerator: u64, denominator: u64) -> u64 {
let scaled = (rate as u128)
.saturating_mul(numerator as u128)
.saturating_add((denominator - 1) as u128)
/ denominator as u128;
if scaled > u64::MAX as u128 {
u64::MAX
} else {
scaled as u64
}
}