1use serde_json::Value;
2use std::collections::HashMap;
3
4use crate::error::{LiteLLMError, Result};
5
/// Pricing and context-window limits for a single model, as read from the registry JSON.
///
/// Every field is optional because registry entries do not all carry every key.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ModelPricing {
    /// Cost of 1,000 input tokens: the registry's `input_cost_per_token` multiplied by 1000.
    /// Expressed in whatever currency unit the registry uses.
    pub input_cost_per_1k: Option<f64>,
    /// Cost of 1,000 output tokens: the registry's `output_cost_per_token` multiplied by 1000.
    pub output_cost_per_1k: Option<f64>,
    /// Maximum input tokens, from `max_input_tokens` (with `max_tokens` as a fallback).
    pub max_input_tokens: Option<u32>,
    /// Maximum output tokens, from `max_output_tokens`.
    pub max_output_tokens: Option<u32>,
    /// The registry's raw `mode` string for this model.
    pub mode: Option<String>,
    /// The registry's raw `litellm_provider` string for this model.
    pub provider: Option<String>,
}
15
16#[derive(Debug, Clone)]
17pub struct Registry {
18 pub models: HashMap<String, ModelPricing>,
19}
20
21impl Registry {
22 pub fn load_embedded() -> Result<Self> {
23 let raw = include_str!("../data/model_prices_and_context_window.json");
24 let json: Value = serde_json::from_str(raw)
25 .map_err(|e| LiteLLMError::Parse(format!("model registry: {e}")))?;
26 let mut models = HashMap::new();
27 let map = json
28 .as_object()
29 .ok_or_else(|| LiteLLMError::Parse("model registry not an object".into()))?;
30 for (name, entry) in map {
31 if name == "sample_spec" {
32 continue;
33 }
34 if let Some(obj) = entry.as_object() {
35 let input = obj
36 .get("input_cost_per_token")
37 .and_then(|v| v.as_f64())
38 .map(|v| v * 1000.0);
39 let output = obj
40 .get("output_cost_per_token")
41 .and_then(|v| v.as_f64())
42 .map(|v| v * 1000.0);
43 let max_input = obj
44 .get("max_input_tokens")
45 .or_else(|| obj.get("max_tokens"))
46 .and_then(|v| v.as_u64())
47 .map(|v| v as u32);
48 let max_output = obj
49 .get("max_output_tokens")
50 .and_then(|v| v.as_u64())
51 .map(|v| v as u32);
52 let mode = obj
53 .get("mode")
54 .and_then(|v| v.as_str())
55 .map(|s| s.to_string());
56 let provider = obj
57 .get("litellm_provider")
58 .and_then(|v| v.as_str())
59 .map(|s| s.to_string());
60 models.insert(
61 name.to_string(),
62 ModelPricing {
63 input_cost_per_1k: input,
64 output_cost_per_1k: output,
65 max_input_tokens: max_input,
66 max_output_tokens: max_output,
67 mode,
68 provider,
69 },
70 );
71 }
72 }
73 Ok(Self { models })
74 }
75
76 pub fn get(&self, model: &str) -> Option<&ModelPricing> {
77 self.models.get(model)
78 }
79
80 pub fn estimate_cost(&self, model: &str, input_tokens: u32, output_tokens: u32) -> Option<f64> {
81 let pricing = self.models.get(model)?;
82 let input = pricing
83 .input_cost_per_1k
84 .map(|v| v * input_tokens as f64 / 1000.0)?;
85 let output = pricing
86 .output_cost_per_1k
87 .map(|v| v * output_tokens as f64 / 1000.0)?;
88 Some(input + output)
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the embedded registry; a failure here means the bundled JSON is broken.
    fn registry() -> Registry {
        Registry::load_embedded().expect("embedded model registry should parse")
    }

    #[test]
    fn loads_registry() {
        assert!(!registry().models.is_empty());
    }

    #[test]
    fn skips_sample_spec() {
        assert!(registry().get("sample_spec").is_none());
    }

    #[test]
    fn unknown_model_has_no_pricing() {
        let registry = registry();
        let name = "definitely-not-a-registered-model";
        assert!(registry.get(name).is_none());
        assert!(registry.estimate_cost(name, 10, 10).is_none());
    }

    #[test]
    fn zero_tokens_cost_nothing_when_fully_priced() {
        let registry = registry();
        for (name, pricing) in &registry.models {
            if pricing.input_cost_per_1k.is_some() && pricing.output_cost_per_1k.is_some() {
                assert_eq!(registry.estimate_cost(name, 0, 0), Some(0.0), "model {name}");
            }
        }
    }
}