1#![deny(missing_docs)]
38
39#[cfg(feature = "serde")]
40use serde::{Deserialize, Serialize};
41
/// Leading segments of cross-region inference-profile ids (e.g. `us.meta.llama3-...`)
/// that `normalize_model_id` strips so the remainder matches a pricing-table key.
///
/// Covers the geographic `us.`, `eu.` and `apac.` profiles as well as the AWS GovCloud
/// (`us-gov.`) and `global.` profiles. Only the first matching prefix is removed.
const INFERENCE_PROFILE_PREFIXES: &[&str] = &["us.", "eu.", "apac.", "us-gov.", "global."];
43
/// Input/output token counts, e.g. as reported for a model invocation.
///
/// With the `serde` feature enabled, missing fields deserialize as `0`
/// (via `#[serde(default)]` and the derived `Default`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(default))]
pub struct Usage {
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of output tokens.
    pub output_tokens: u64,
}

impl Usage {
    /// Returns `input_tokens + output_tokens`, saturating at `u64::MAX`.
    ///
    /// The counts may originate from deserialized data, so a plain `+` could panic in
    /// debug builds or silently wrap around in release builds; saturating keeps the
    /// result well-defined in both.
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}
62
63#[derive(Debug, Clone, Copy, PartialEq)]
65pub struct Pricing {
66 pub input_per_mtok: f64,
68 pub output_per_mtok: f64,
70}
71
72impl Pricing {
73 pub fn cost_for(&self, usage: &Usage) -> f64 {
75 (usage.input_tokens as f64 * self.input_per_mtok
76 + usage.output_tokens as f64 * self.output_per_mtok)
77 / 1_000_000.0
78 }
79}
80
81pub fn normalize_model_id(id: &str) -> &str {
86 let mut s = id;
87
88 if s.starts_with("arn:aws:bedrock:") {
90 if let Some(slash) = s.rfind('/') {
91 s = &s[slash + 1..];
92 }
93 }
94
95 for prefix in INFERENCE_PROFILE_PREFIXES {
97 if let Some(rest) = s.strip_prefix(prefix) {
98 s = rest;
99 break;
100 }
101 }
102
103 if let Some(idx) = s.rfind("-v") {
105 let tail = &s[idx + 2..];
106 if tail
107 .splitn(2, ':')
108 .all(|part| !part.is_empty() && part.chars().all(|c| c.is_ascii_digit()))
109 && tail.contains(':')
110 {
111 s = &s[..idx];
112 }
113 }
114
115 s
116}
117
118pub const DEFAULT_PRICING_TABLE: &[(&str, Pricing)] = &[
122 (
124 "meta.llama3-8b-instruct",
125 Pricing { input_per_mtok: 0.30, output_per_mtok: 0.60 },
126 ),
127 (
128 "meta.llama3-70b-instruct",
129 Pricing { input_per_mtok: 2.65, output_per_mtok: 3.50 },
130 ),
131 (
132 "meta.llama3-1-8b-instruct",
133 Pricing { input_per_mtok: 0.22, output_per_mtok: 0.22 },
134 ),
135 (
136 "meta.llama3-1-70b-instruct",
137 Pricing { input_per_mtok: 0.72, output_per_mtok: 0.72 },
138 ),
139 (
140 "meta.llama3-1-405b-instruct",
141 Pricing { input_per_mtok: 5.32, output_per_mtok: 16.00 },
142 ),
143 (
145 "mistral.mistral-large-2407",
146 Pricing { input_per_mtok: 2.00, output_per_mtok: 6.00 },
147 ),
148 (
149 "mistral.mistral-small-2402",
150 Pricing { input_per_mtok: 1.00, output_per_mtok: 3.00 },
151 ),
152 (
154 "cohere.command-r-plus",
155 Pricing { input_per_mtok: 3.00, output_per_mtok: 15.00 },
156 ),
157 (
158 "cohere.command-r",
159 Pricing { input_per_mtok: 0.50, output_per_mtok: 1.50 },
160 ),
161 (
163 "amazon.titan-text-premier",
164 Pricing { input_per_mtok: 0.50, output_per_mtok: 1.50 },
165 ),
166 (
167 "amazon.titan-text-express",
168 Pricing { input_per_mtok: 0.20, output_per_mtok: 0.60 },
169 ),
170 (
171 "amazon.titan-text-lite",
172 Pricing { input_per_mtok: 0.15, output_per_mtok: 0.20 },
173 ),
174 (
176 "ai21.jamba-1-5-large",
177 Pricing { input_per_mtok: 2.00, output_per_mtok: 8.00 },
178 ),
179 (
180 "ai21.jamba-1-5-mini",
181 Pricing { input_per_mtok: 0.20, output_per_mtok: 0.40 },
182 ),
183];
184
185pub fn default_pricing(model_id: &str) -> Option<Pricing> {
187 let key = normalize_model_id(model_id);
188 DEFAULT_PRICING_TABLE
189 .iter()
190 .find(|(k, _)| *k == key)
191 .map(|(_, p)| *p)
192}