//! Token usage counters, per-model pricing, and a built-in price table.
//!
//! Model identifiers (including `arn:aws:bedrock:` ARNs) are reduced to a
//! canonical key by [`normalize_model_id`] before being looked up in
//! [`DEFAULT_PRICING_TABLE`] via [`default_pricing`].
#![deny(missing_docs)]
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Leading prefixes that [`normalize_model_id`] removes from the front of a
/// model id. They are tried in order and at most one is removed.
// NOTE(review): presumably these are the geography prefixes of cross-region
// inference profiles; any other prefix is left in place — confirm the list is
// complete for the deployments this crate targets.
const INFERENCE_PROFILE_PREFIXES: &[&str] = &["us.", "eu.", "apac."];
/// Input and output token counts.
///
/// With the `serde` feature enabled the type can be (de)serialized; because of
/// the container-level `serde(default)`, any field missing from the input is
/// filled from [`Usage::default`], i.e. zero.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(default))]
pub struct Usage {
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of output tokens.
    pub output_tokens: u64,
}

impl Usage {
    /// Returns the sum of input and output tokens.
    ///
    /// The addition saturates at [`u64::MAX`] rather than overflowing, so
    /// extreme (e.g. deserialized) counter values can neither panic in debug
    /// builds nor silently wrap around in release builds.
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}
/// Price of a model, expressed per one million tokens.
///
/// Input and output tokens are priced separately. The currency is not encoded
/// in the type; it is whatever unit the rates were entered in.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Pricing {
    /// Price charged for one million input tokens.
    pub input_per_mtok: f64,
    /// Price charged for one million output tokens.
    pub output_per_mtok: f64,
}

impl Pricing {
    /// Computes the cost of `usage` under this pricing.
    ///
    /// Each token count is multiplied by its per-million rate, the two
    /// products are summed, and the sum is divided by one million. The result
    /// is in the same currency unit as the rates.
    pub fn cost_for(&self, usage: &Usage) -> f64 {
        // Rates are quoted per million tokens, so scale the total back down.
        const TOKENS_PER_MTOK: f64 = 1_000_000.0;

        let input_cost = usage.input_tokens as f64 * self.input_per_mtok;
        let output_cost = usage.output_tokens as f64 * self.output_per_mtok;
        (input_cost + output_cost) / TOKENS_PER_MTOK
    }
}
/// Reduces a model identifier to the canonical key used for pricing lookups.
///
/// The result is always a sub-slice of `id`; nothing is allocated. Three
/// reductions are applied in order:
///
/// 1. If `id` starts with `arn:aws:bedrock:`, only the text after the last
///    `/` is kept (the id is left as is when the ARN contains no `/`).
/// 2. At most one leading inference-profile prefix (`us.`, `eu.`, `apac.`) is
///    removed.
/// 3. A trailing version suffix is removed: `-v<digits>` optionally followed
///    by `:<digits>`, e.g. `-v1:0` or `-v1`.
///
/// For example `us.meta.llama3-1-70b-instruct-v1:0` and
/// `amazon.titan-text-express-v1` become `meta.llama3-1-70b-instruct` and
/// `amazon.titan-text-express` respectively. Identifiers that match none of
/// the rules are returned unchanged.
pub fn normalize_model_id(id: &str) -> &str {
    let mut s = id;

    // ARNs carry the model or profile id as the last path segment.
    if s.starts_with("arn:aws:bedrock:") {
        if let Some((_, resource)) = s.rsplit_once('/') {
            s = resource;
        }
    }

    // Only the first matching prefix is stripped, mirroring list order.
    let s = INFERENCE_PROFILE_PREFIXES
        .iter()
        .copied()
        .find_map(|prefix| s.strip_prefix(prefix))
        .unwrap_or(s);

    // Only the last `-v` is considered, so a `-v` inside the model name itself
    // is never mistaken for the version marker.
    match s.rsplit_once("-v") {
        Some((base, tail)) if is_version_suffix(tail) => base,
        _ => s,
    }
}

/// Returns `true` when `tail` (the text after `-v`) is `<digits>` or
/// `<digits>:<digits>`.
///
/// Some model ids end in a bare `-v1` with no `:<revision>` part, so the
/// revision is optional. Anything else (`1:`, `1:0:8k`, `next`, an empty
/// string) is not treated as a version.
fn is_version_suffix(tail: &str) -> bool {
    let is_number = |part: &str| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit());
    match tail.split_once(':') {
        Some((major, revision)) => is_number(major) && is_number(revision),
        None => is_number(tail),
    }
}
/// Built-in price table: `(model id, pricing)` pairs.
///
/// [`default_pricing`] compares the keys for exact equality against the
/// output of [`normalize_model_id`], so each key is written without an ARN,
/// inference-profile prefix, or trailing version suffix. Rates are per one
/// million tokens (see [`Pricing`]).
// NOTE(review): the currency, region, and date these rates were taken from are
// not recorded here — confirm against the provider's current price list
// before relying on them.
pub const DEFAULT_PRICING_TABLE: &[(&str, Pricing)] = &[
(
"meta.llama3-8b-instruct",
Pricing { input_per_mtok: 0.30, output_per_mtok: 0.60 },
),
(
"meta.llama3-70b-instruct",
Pricing { input_per_mtok: 2.65, output_per_mtok: 3.50 },
),
(
"meta.llama3-1-8b-instruct",
Pricing { input_per_mtok: 0.22, output_per_mtok: 0.22 },
),
(
"meta.llama3-1-70b-instruct",
Pricing { input_per_mtok: 0.72, output_per_mtok: 0.72 },
),
(
"meta.llama3-1-405b-instruct",
Pricing { input_per_mtok: 5.32, output_per_mtok: 16.00 },
),
(
"mistral.mistral-large-2407",
Pricing { input_per_mtok: 2.00, output_per_mtok: 6.00 },
),
(
"mistral.mistral-small-2402",
Pricing { input_per_mtok: 1.00, output_per_mtok: 3.00 },
),
(
"cohere.command-r-plus",
Pricing { input_per_mtok: 3.00, output_per_mtok: 15.00 },
),
(
"cohere.command-r",
Pricing { input_per_mtok: 0.50, output_per_mtok: 1.50 },
),
(
"amazon.titan-text-premier",
Pricing { input_per_mtok: 0.50, output_per_mtok: 1.50 },
),
(
"amazon.titan-text-express",
Pricing { input_per_mtok: 0.20, output_per_mtok: 0.60 },
),
(
"amazon.titan-text-lite",
Pricing { input_per_mtok: 0.15, output_per_mtok: 0.20 },
),
(
"ai21.jamba-1-5-large",
Pricing { input_per_mtok: 2.00, output_per_mtok: 8.00 },
),
(
"ai21.jamba-1-5-mini",
Pricing { input_per_mtok: 0.20, output_per_mtok: 0.40 },
),
];
/// Looks up the built-in pricing for `model_id`.
///
/// The id is first passed through [`normalize_model_id`]; the result is then
/// compared for exact equality against the keys of
/// [`DEFAULT_PRICING_TABLE`]. Returns a copy of the first matching entry's
/// [`Pricing`], or `None` when the table has no entry for the model.
pub fn default_pricing(model_id: &str) -> Option<Pricing> {
    let wanted = normalize_model_id(model_id);
    for &(name, pricing) in DEFAULT_PRICING_TABLE {
        if name == wanted {
            return Some(pricing);
        }
    }
    None
}