// lean_ctx/core/gain/model_pricing.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
/// Per-model price card: USD cost per one million tokens for each token class.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct ModelCost {
    /// USD per 1M uncached input (prompt) tokens.
    pub input_per_m: f64,
    /// USD per 1M output (completion) tokens.
    pub output_per_m: f64,
    /// USD per 1M tokens written to the provider's prompt cache.
    pub cache_write_per_m: f64,
    /// USD per 1M tokens served from the provider's prompt cache.
    pub cache_read_per_m: f64,
}
11
12impl ModelCost {
13    pub fn estimate_usd(&self, input: u64, output: u64, cache_write: u64, cache_read: u64) -> f64 {
14        (input as f64 / 1_000_000.0 * self.input_per_m)
15            + (output as f64 / 1_000_000.0 * self.output_per_m)
16            + (cache_write as f64 / 1_000_000.0 * self.cache_write_per_m)
17            + (cache_read as f64 / 1_000_000.0 * self.cache_read_per_m)
18    }
19}
20
/// How a model name was resolved to a pricing-table entry.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum PricingMatchKind {
    /// The normalized name equals a table key verbatim.
    Exact,
    /// Matched through a known alternate spelling of a table key.
    Alias,
    /// Matched through loose family/substring rules.
    Heuristic,
    /// Nothing matched; the blended fallback entry was used.
    Fallback,
}
28
/// Result of a pricing lookup: the chosen table key, its cost card, and how
/// confident the match was.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelQuote {
    /// Canonical key of the matched pricing entry (e.g. "claude-3.5-sonnet").
    pub model_key: String,
    /// Per-million-token rates for the matched entry.
    pub cost: ModelCost,
    /// Whether the match was exact, an alias, heuristic, or the fallback.
    pub match_kind: PricingMatchKind,
}
35
/// Lookup table from canonical model key to per-million-token costs.
/// Built from an embedded table and optionally patched via env-var JSON.
#[derive(Debug, Clone)]
pub struct ModelPricing {
    // Keyed by normalized model name (lowercase, spaces replaced by hyphens).
    models: HashMap<String, ModelCost>,
}
40
41impl ModelPricing {
42    pub fn load() -> Self {
43        let mut p = Self::embedded();
44        p.apply_env_override();
45        p
46    }
47
48    pub fn embedded() -> Self {
49        let mut models: HashMap<String, ModelCost> = HashMap::new();
50
51        // Anthropic prompt caching pricing (public, GA) — source: https://anthropic.com/news/prompt-caching
52        models.insert(
53            "claude-3.5-sonnet".to_string(),
54            ModelCost {
55                input_per_m: 3.00,
56                output_per_m: 15.00,
57                cache_write_per_m: 3.75,
58                cache_read_per_m: 0.30,
59            },
60        );
61        models.insert(
62            "claude-3-opus".to_string(),
63            ModelCost {
64                input_per_m: 15.00,
65                output_per_m: 75.00,
66                cache_write_per_m: 18.75,
67                cache_read_per_m: 1.50,
68            },
69        );
70        models.insert(
71            "claude-3-haiku".to_string(),
72            ModelCost {
73                input_per_m: 0.25,
74                output_per_m: 1.25,
75                cache_write_per_m: 0.30,
76                cache_read_per_m: 0.03,
77            },
78        );
79
80        // OpenAI API pricing (Flagship) — source: https://openai.com/api/pricing/
81        models.insert(
82            "gpt-5.4".to_string(),
83            ModelCost {
84                input_per_m: 2.50,
85                output_per_m: 15.00,
86                cache_write_per_m: 2.50,
87                cache_read_per_m: 0.25,
88            },
89        );
90        models.insert(
91            "gpt-5.4-mini".to_string(),
92            ModelCost {
93                input_per_m: 0.75,
94                output_per_m: 4.50,
95                cache_write_per_m: 0.75,
96                cache_read_per_m: 0.075,
97            },
98        );
99        models.insert(
100            "gpt-5.4-nano".to_string(),
101            ModelCost {
102                input_per_m: 0.20,
103                output_per_m: 1.25,
104                cache_write_per_m: 0.20,
105                cache_read_per_m: 0.02,
106            },
107        );
108
109        // Google Gemini API pricing — source: https://ai.google.dev/pricing
110        // (No separate cache pricing published → treat cache read/write as input.)
111        models.insert(
112            "gemini-2.5-pro".to_string(),
113            ModelCost {
114                input_per_m: 1.25,
115                output_per_m: 10.00,
116                cache_write_per_m: 1.25,
117                cache_read_per_m: 1.25,
118            },
119        );
120        models.insert(
121            "gemini-2.5-flash".to_string(),
122            ModelCost {
123                input_per_m: 0.30,
124                output_per_m: 2.50,
125                cache_write_per_m: 0.30,
126                cache_read_per_m: 0.30,
127            },
128        );
129        models.insert(
130            "gemini-2.5-flash-lite".to_string(),
131            ModelCost {
132                input_per_m: 0.10,
133                output_per_m: 0.40,
134                cache_write_per_m: 0.10,
135                cache_read_per_m: 0.10,
136            },
137        );
138
139        // Conservative blended fallback (used by legacy stats output).
140        models.insert(
141            "fallback-blended".to_string(),
142            ModelCost {
143                input_per_m: 2.50,
144                output_per_m: 10.00,
145                cache_write_per_m: 2.50,
146                cache_read_per_m: 2.50,
147            },
148        );
149
150        Self { models }
151    }
152
153    pub fn quote(&self, model: Option<&str>) -> ModelQuote {
154        let raw = model.unwrap_or_default();
155        if let Some(k) = Self::infer_model_key(raw) {
156            if let Some(cost) = self.models.get(&k).copied() {
157                return ModelQuote {
158                    model_key: k,
159                    cost,
160                    match_kind: PricingMatchKind::Exact,
161                };
162            }
163        }
164
165        if let Some((k, kind)) = Self::heuristic_key(raw) {
166            if let Some(cost) = self.models.get(&k).copied() {
167                return ModelQuote {
168                    model_key: k,
169                    cost,
170                    match_kind: kind,
171                };
172            }
173        }
174
175        let cost = self
176            .models
177            .get("fallback-blended")
178            .copied()
179            .unwrap_or(ModelCost {
180                input_per_m: 2.50,
181                output_per_m: 10.00,
182                cache_write_per_m: 2.50,
183                cache_read_per_m: 2.50,
184            });
185        ModelQuote {
186            model_key: "fallback-blended".to_string(),
187            cost,
188            match_kind: PricingMatchKind::Fallback,
189        }
190    }
191
192    pub fn quote_from_env_or_agent_type(&self, agent_type: &str) -> ModelQuote {
193        let env_model = std::env::var("LEAN_CTX_MODEL")
194            .or_else(|_| std::env::var("LCTX_MODEL"))
195            .ok();
196        self.quote(env_model.as_deref().or(Some(agent_type)))
197    }
198
199    pub fn infer_model_key(model: &str) -> Option<String> {
200        let m = normalize(model);
201        if m.is_empty() {
202            return None;
203        }
204
205        let exact_keys = [
206            "claude-3.5-sonnet",
207            "claude-3-opus",
208            "claude-3-haiku",
209            "gpt-5.4",
210            "gpt-5.4-mini",
211            "gpt-5.4-nano",
212            "gemini-2.5-pro",
213            "gemini-2.5-flash",
214            "gemini-2.5-flash-lite",
215            "fallback-blended",
216        ];
217        for k in exact_keys {
218            if m == k {
219                return Some(k.to_string());
220            }
221        }
222        None
223    }
224
225    fn heuristic_key(model: &str) -> Option<(String, PricingMatchKind)> {
226        let m = normalize(model);
227        if m.is_empty() {
228            return None;
229        }
230
231        // Claude family: accept loose naming (e.g. "claude sonnet", "claude-4.6-sonnet").
232        if m.contains("claude") {
233            if m.contains("sonnet") {
234                return Some(("claude-3.5-sonnet".to_string(), PricingMatchKind::Heuristic));
235            }
236            if m.contains("opus") {
237                return Some(("claude-3-opus".to_string(), PricingMatchKind::Heuristic));
238            }
239            if m.contains("haiku") {
240                return Some(("claude-3-haiku".to_string(), PricingMatchKind::Heuristic));
241            }
242        }
243
244        if m.contains("gemini") {
245            if m.contains("2.5") && m.contains("pro") {
246                return Some(("gemini-2.5-pro".to_string(), PricingMatchKind::Heuristic));
247            }
248            if m.contains("2.5") && m.contains("flash-lite") {
249                return Some((
250                    "gemini-2.5-flash-lite".to_string(),
251                    PricingMatchKind::Heuristic,
252                ));
253            }
254            if m.contains("2.5") && m.contains("flash") {
255                return Some(("gemini-2.5-flash".to_string(), PricingMatchKind::Heuristic));
256            }
257        }
258
259        // OpenAI family: accept "gpt-5.4" variants and legacy "gpt-4o" as alias to blended fallback.
260        if m.contains("gpt-5.4") && m.contains("mini") {
261            return Some(("gpt-5.4-mini".to_string(), PricingMatchKind::Alias));
262        }
263        if m.contains("gpt-5.4") && m.contains("nano") {
264            return Some(("gpt-5.4-nano".to_string(), PricingMatchKind::Alias));
265        }
266        if m.contains("gpt-5.4") {
267            return Some(("gpt-5.4".to_string(), PricingMatchKind::Alias));
268        }
269        if m.contains("gpt-4o") {
270            return Some(("fallback-blended".to_string(), PricingMatchKind::Heuristic));
271        }
272
273        None
274    }
275
276    fn apply_env_override(&mut self) {
277        let raw = std::env::var("LEAN_CTX_MODEL_PRICING_JSON")
278            .or_else(|_| std::env::var("LCTX_MODEL_PRICING_JSON"))
279            .ok();
280        let Some(raw) = raw else { return };
281
282        let Ok(v) = serde_json::from_str::<serde_json::Value>(&raw) else {
283            return;
284        };
285        let Some(models) = v.get("models").and_then(|m| m.as_object()) else {
286            return;
287        };
288        for (k, vv) in models {
289            let Some(obj) = vv.as_object() else { continue };
290            let input_per_m = obj.get("input_per_m").and_then(|x| x.as_f64());
291            let output_per_m = obj.get("output_per_m").and_then(|x| x.as_f64());
292            if input_per_m.is_none() && output_per_m.is_none() {
293                continue;
294            }
295
296            let key_norm = normalize(k);
297            let base = self.models.get(&key_norm).copied().unwrap_or_else(|| {
298                self.models
299                    .get("fallback-blended")
300                    .copied()
301                    .unwrap_or(ModelCost {
302                        input_per_m: 2.50,
303                        output_per_m: 10.00,
304                        cache_write_per_m: 2.50,
305                        cache_read_per_m: 2.50,
306                    })
307            });
308
309            let merged = ModelCost {
310                input_per_m: input_per_m.unwrap_or(base.input_per_m),
311                output_per_m: output_per_m.unwrap_or(base.output_per_m),
312                cache_write_per_m: obj
313                    .get("cache_write_per_m")
314                    .and_then(|x| x.as_f64())
315                    .unwrap_or(base.cache_write_per_m),
316                cache_read_per_m: obj
317                    .get("cache_read_per_m")
318                    .and_then(|x| x.as_f64())
319                    .unwrap_or(base.cache_read_per_m),
320            };
321            self.models.insert(key_norm, merged);
322        }
323    }
324}
325
/// Canonicalizes a model name for table lookup: trims surrounding whitespace,
/// lowercases, and turns interior spaces into hyphens.
fn normalize(s: &str) -> String {
    let lowered = s.trim().to_lowercase();
    lowered.replace(' ', "-")
}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Unknown model names must resolve to the blended fallback entry.
    #[test]
    fn quote_falls_back() {
        let pricing = ModelPricing::embedded();
        let quote = pricing.quote(Some("unknown-model"));
        assert_eq!(quote.match_kind, PricingMatchKind::Fallback);
    }

    /// Loose Claude naming (including future version numbers) should map onto
    /// the embedded sonnet entry via the heuristic/alias path.
    #[test]
    fn claude_sonnet_heuristic() {
        let pricing = ModelPricing::embedded();
        let quote = pricing.quote(Some("claude-4.6-sonnet"));
        let loose_match = matches!(
            quote.match_kind,
            PricingMatchKind::Heuristic | PricingMatchKind::Alias
        );
        assert!(loose_match);
        assert_eq!(quote.model_key, "claude-3.5-sonnet");
    }
}