1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
/// Per-model token pricing, expressed in USD per one million tokens.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct ModelCost {
    /// USD per 1M input (prompt) tokens.
    pub input_per_m: f64,
    /// USD per 1M output (completion) tokens.
    pub output_per_m: f64,
    /// USD per 1M tokens written to the prompt cache.
    pub cache_write_per_m: f64,
    /// USD per 1M tokens read back from the prompt cache.
    pub cache_read_per_m: f64,
}
11
12impl ModelCost {
13 pub fn estimate_usd(&self, input: u64, output: u64, cache_write: u64, cache_read: u64) -> f64 {
14 (input as f64 / 1_000_000.0 * self.input_per_m)
15 + (output as f64 / 1_000_000.0 * self.output_per_m)
16 + (cache_write as f64 / 1_000_000.0 * self.cache_write_per_m)
17 + (cache_read as f64 / 1_000_000.0 * self.cache_read_per_m)
18 }
19}
20
/// How a requested model name was matched against the pricing table.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum PricingMatchKind {
    /// The normalized name equaled a known table key exactly.
    Exact,
    /// Matched via a known alias (substring match on a `gpt-5.4*` family name).
    Alias,
    /// Matched by substring heuristics on family keywords (e.g. "claude" + "sonnet").
    Heuristic,
    /// No match found; the blended fallback pricing applies.
    Fallback,
}
28
/// The result of a pricing lookup: which table entry was chosen,
/// its rates, and how confident the match was.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelQuote {
    /// Key of the pricing-table entry that was selected.
    pub model_key: String,
    /// Per-million-token rates for that entry.
    pub cost: ModelCost,
    /// How the requested name was matched (exact, alias, heuristic, or fallback).
    pub match_kind: PricingMatchKind,
}
35
/// Pricing table mapping normalized model keys to their token rates.
#[derive(Debug, Clone)]
pub struct ModelPricing {
    // Keyed by normalized model name (lowercase, space->hyphen; see `normalize`).
    models: HashMap<String, ModelCost>,
}
40
41impl ModelPricing {
42 pub fn load() -> Self {
43 let mut p = Self::embedded();
44 p.apply_env_override();
45 p
46 }
47
48 pub fn embedded() -> Self {
49 let mut models: HashMap<String, ModelCost> = HashMap::new();
50
51 models.insert(
53 "claude-3.5-sonnet".to_string(),
54 ModelCost {
55 input_per_m: 3.00,
56 output_per_m: 15.00,
57 cache_write_per_m: 3.75,
58 cache_read_per_m: 0.30,
59 },
60 );
61 models.insert(
62 "claude-3-opus".to_string(),
63 ModelCost {
64 input_per_m: 15.00,
65 output_per_m: 75.00,
66 cache_write_per_m: 18.75,
67 cache_read_per_m: 1.50,
68 },
69 );
70 models.insert(
71 "claude-3-haiku".to_string(),
72 ModelCost {
73 input_per_m: 0.25,
74 output_per_m: 1.25,
75 cache_write_per_m: 0.30,
76 cache_read_per_m: 0.03,
77 },
78 );
79
80 models.insert(
82 "gpt-5.4".to_string(),
83 ModelCost {
84 input_per_m: 2.50,
85 output_per_m: 15.00,
86 cache_write_per_m: 2.50,
87 cache_read_per_m: 0.25,
88 },
89 );
90 models.insert(
91 "gpt-5.4-mini".to_string(),
92 ModelCost {
93 input_per_m: 0.75,
94 output_per_m: 4.50,
95 cache_write_per_m: 0.75,
96 cache_read_per_m: 0.075,
97 },
98 );
99 models.insert(
100 "gpt-5.4-nano".to_string(),
101 ModelCost {
102 input_per_m: 0.20,
103 output_per_m: 1.25,
104 cache_write_per_m: 0.20,
105 cache_read_per_m: 0.02,
106 },
107 );
108
109 models.insert(
112 "gemini-2.5-pro".to_string(),
113 ModelCost {
114 input_per_m: 1.25,
115 output_per_m: 10.00,
116 cache_write_per_m: 1.25,
117 cache_read_per_m: 1.25,
118 },
119 );
120 models.insert(
121 "gemini-2.5-flash".to_string(),
122 ModelCost {
123 input_per_m: 0.30,
124 output_per_m: 2.50,
125 cache_write_per_m: 0.30,
126 cache_read_per_m: 0.30,
127 },
128 );
129 models.insert(
130 "gemini-2.5-flash-lite".to_string(),
131 ModelCost {
132 input_per_m: 0.10,
133 output_per_m: 0.40,
134 cache_write_per_m: 0.10,
135 cache_read_per_m: 0.10,
136 },
137 );
138
139 models.insert(
141 "fallback-blended".to_string(),
142 ModelCost {
143 input_per_m: 2.50,
144 output_per_m: 10.00,
145 cache_write_per_m: 2.50,
146 cache_read_per_m: 2.50,
147 },
148 );
149
150 Self { models }
151 }
152
153 pub fn quote(&self, model: Option<&str>) -> ModelQuote {
154 let raw = model.unwrap_or_default();
155 if let Some(k) = Self::infer_model_key(raw) {
156 if let Some(cost) = self.models.get(&k).copied() {
157 return ModelQuote {
158 model_key: k,
159 cost,
160 match_kind: PricingMatchKind::Exact,
161 };
162 }
163 }
164
165 if let Some((k, kind)) = Self::heuristic_key(raw) {
166 if let Some(cost) = self.models.get(&k).copied() {
167 return ModelQuote {
168 model_key: k,
169 cost,
170 match_kind: kind,
171 };
172 }
173 }
174
175 let cost = self
176 .models
177 .get("fallback-blended")
178 .copied()
179 .unwrap_or(ModelCost {
180 input_per_m: 2.50,
181 output_per_m: 10.00,
182 cache_write_per_m: 2.50,
183 cache_read_per_m: 2.50,
184 });
185 ModelQuote {
186 model_key: "fallback-blended".to_string(),
187 cost,
188 match_kind: PricingMatchKind::Fallback,
189 }
190 }
191
192 pub fn quote_from_env_or_agent_type(&self, agent_type: &str) -> ModelQuote {
193 let env_model = std::env::var("LEAN_CTX_MODEL")
194 .or_else(|_| std::env::var("LCTX_MODEL"))
195 .ok();
196 self.quote(env_model.as_deref().or(Some(agent_type)))
197 }
198
199 pub fn infer_model_key(model: &str) -> Option<String> {
200 let m = normalize(model);
201 if m.is_empty() {
202 return None;
203 }
204
205 let exact_keys = [
206 "claude-3.5-sonnet",
207 "claude-3-opus",
208 "claude-3-haiku",
209 "gpt-5.4",
210 "gpt-5.4-mini",
211 "gpt-5.4-nano",
212 "gemini-2.5-pro",
213 "gemini-2.5-flash",
214 "gemini-2.5-flash-lite",
215 "fallback-blended",
216 ];
217 for k in exact_keys {
218 if m == k {
219 return Some(k.to_string());
220 }
221 }
222 None
223 }
224
225 fn heuristic_key(model: &str) -> Option<(String, PricingMatchKind)> {
226 let m = normalize(model);
227 if m.is_empty() {
228 return None;
229 }
230
231 if m.contains("claude") {
233 if m.contains("sonnet") {
234 return Some(("claude-3.5-sonnet".to_string(), PricingMatchKind::Heuristic));
235 }
236 if m.contains("opus") {
237 return Some(("claude-3-opus".to_string(), PricingMatchKind::Heuristic));
238 }
239 if m.contains("haiku") {
240 return Some(("claude-3-haiku".to_string(), PricingMatchKind::Heuristic));
241 }
242 }
243
244 if m.contains("gemini") {
245 if m.contains("2.5") && m.contains("pro") {
246 return Some(("gemini-2.5-pro".to_string(), PricingMatchKind::Heuristic));
247 }
248 if m.contains("2.5") && m.contains("flash-lite") {
249 return Some((
250 "gemini-2.5-flash-lite".to_string(),
251 PricingMatchKind::Heuristic,
252 ));
253 }
254 if m.contains("2.5") && m.contains("flash") {
255 return Some(("gemini-2.5-flash".to_string(), PricingMatchKind::Heuristic));
256 }
257 }
258
259 if m.contains("gpt-5.4") && m.contains("mini") {
261 return Some(("gpt-5.4-mini".to_string(), PricingMatchKind::Alias));
262 }
263 if m.contains("gpt-5.4") && m.contains("nano") {
264 return Some(("gpt-5.4-nano".to_string(), PricingMatchKind::Alias));
265 }
266 if m.contains("gpt-5.4") {
267 return Some(("gpt-5.4".to_string(), PricingMatchKind::Alias));
268 }
269 if m.contains("gpt-4o") {
270 return Some(("fallback-blended".to_string(), PricingMatchKind::Heuristic));
271 }
272
273 None
274 }
275
276 fn apply_env_override(&mut self) {
277 let raw = std::env::var("LEAN_CTX_MODEL_PRICING_JSON")
278 .or_else(|_| std::env::var("LCTX_MODEL_PRICING_JSON"))
279 .ok();
280 let Some(raw) = raw else { return };
281
282 let Ok(v) = serde_json::from_str::<serde_json::Value>(&raw) else {
283 return;
284 };
285 let Some(models) = v.get("models").and_then(|m| m.as_object()) else {
286 return;
287 };
288 for (k, vv) in models {
289 let Some(obj) = vv.as_object() else { continue };
290 let input_per_m = obj.get("input_per_m").and_then(|x| x.as_f64());
291 let output_per_m = obj.get("output_per_m").and_then(|x| x.as_f64());
292 if input_per_m.is_none() && output_per_m.is_none() {
293 continue;
294 }
295
296 let key_norm = normalize(k);
297 let base = self.models.get(&key_norm).copied().unwrap_or_else(|| {
298 self.models
299 .get("fallback-blended")
300 .copied()
301 .unwrap_or(ModelCost {
302 input_per_m: 2.50,
303 output_per_m: 10.00,
304 cache_write_per_m: 2.50,
305 cache_read_per_m: 2.50,
306 })
307 });
308
309 let merged = ModelCost {
310 input_per_m: input_per_m.unwrap_or(base.input_per_m),
311 output_per_m: output_per_m.unwrap_or(base.output_per_m),
312 cache_write_per_m: obj
313 .get("cache_write_per_m")
314 .and_then(|x| x.as_f64())
315 .unwrap_or(base.cache_write_per_m),
316 cache_read_per_m: obj
317 .get("cache_read_per_m")
318 .and_then(|x| x.as_f64())
319 .unwrap_or(base.cache_read_per_m),
320 };
321 self.models.insert(key_norm, merged);
322 }
323 }
324}
325
/// Canonicalize a model name for table lookup: trim surrounding
/// whitespace, lowercase everything, and turn interior spaces into
/// hyphens so lookups are insensitive to input formatting.
fn normalize(s: &str) -> String {
    let lowered = s.trim().to_lowercase();
    lowered
        .chars()
        .map(|c| if c == ' ' { '-' } else { c })
        .collect()
}
329
330#[cfg(test)]
331mod tests {
332 use super::*;
333
334 #[test]
335 fn quote_falls_back() {
336 let p = ModelPricing::embedded();
337 let q = p.quote(Some("unknown-model"));
338 assert_eq!(q.match_kind, PricingMatchKind::Fallback);
339 }
340
341 #[test]
342 fn claude_sonnet_heuristic() {
343 let p = ModelPricing::embedded();
344 let q = p.quote(Some("claude-4.6-sonnet"));
345 assert!(matches!(
346 q.match_kind,
347 PricingMatchKind::Heuristic | PricingMatchKind::Alias
348 ));
349 assert_eq!(q.model_key, "claude-3.5-sonnet");
350 }
351}