Skip to main content

litellm_rust/
registry.rs

1use serde_json::Value;
2use std::collections::HashMap;
3
4use crate::error::{LiteLLMError, Result};
5
/// Pricing and capacity metadata recorded for a single model.
///
/// Every field is optional: a field is `None` when the corresponding key is
/// missing from the model's registry entry or holds a value of an unexpected
/// JSON type.
#[derive(Debug, Clone)]
pub struct ModelPricing {
    /// Cost of 1,000 input tokens (the registry's `input_cost_per_token`
    /// multiplied by 1,000).
    pub input_cost_per_1k: Option<f64>,
    /// Cost of 1,000 output tokens (the registry's `output_cost_per_token`
    /// multiplied by 1,000).
    pub output_cost_per_1k: Option<f64>,
    /// Input token limit, read from `max_input_tokens`, or from `max_tokens`
    /// when the former key is absent.
    pub max_input_tokens: Option<u32>,
    /// Output token limit, read from `max_output_tokens`.
    pub max_output_tokens: Option<u32>,
    /// Value of the entry's `mode` key.
    pub mode: Option<String>,
    /// Value of the entry's `litellm_provider` key.
    pub provider: Option<String>,
}
15
16#[derive(Debug, Clone)]
17pub struct Registry {
18    pub models: HashMap<String, ModelPricing>,
19}
20
21impl Registry {
22    pub fn load_embedded() -> Result<Self> {
23        let raw = include_str!("../data/model_prices_and_context_window.json");
24        let json: Value = serde_json::from_str(raw)
25            .map_err(|e| LiteLLMError::Parse(format!("model registry: {e}")))?;
26        let mut models = HashMap::new();
27        let map = json
28            .as_object()
29            .ok_or_else(|| LiteLLMError::Parse("model registry not an object".into()))?;
30        for (name, entry) in map {
31            if name == "sample_spec" {
32                continue;
33            }
34            if let Some(obj) = entry.as_object() {
35                let input = obj
36                    .get("input_cost_per_token")
37                    .and_then(|v| v.as_f64())
38                    .map(|v| v * 1000.0);
39                let output = obj
40                    .get("output_cost_per_token")
41                    .and_then(|v| v.as_f64())
42                    .map(|v| v * 1000.0);
43                let max_input = obj
44                    .get("max_input_tokens")
45                    .or_else(|| obj.get("max_tokens"))
46                    .and_then(|v| v.as_u64())
47                    .map(|v| v as u32);
48                let max_output = obj
49                    .get("max_output_tokens")
50                    .and_then(|v| v.as_u64())
51                    .map(|v| v as u32);
52                let mode = obj
53                    .get("mode")
54                    .and_then(|v| v.as_str())
55                    .map(|s| s.to_string());
56                let provider = obj
57                    .get("litellm_provider")
58                    .and_then(|v| v.as_str())
59                    .map(|s| s.to_string());
60                models.insert(
61                    name.to_string(),
62                    ModelPricing {
63                        input_cost_per_1k: input,
64                        output_cost_per_1k: output,
65                        max_input_tokens: max_input,
66                        max_output_tokens: max_output,
67                        mode,
68                        provider,
69                    },
70                );
71            }
72        }
73        Ok(Self { models })
74    }
75
76    pub fn get(&self, model: &str) -> Option<&ModelPricing> {
77        self.models.get(model)
78    }
79
80    pub fn estimate_cost(&self, model: &str, input_tokens: u32, output_tokens: u32) -> Option<f64> {
81        let pricing = self.models.get(model)?;
82        let input = pricing
83            .input_cost_per_1k
84            .map(|v| v * input_tokens as f64 / 1000.0)?;
85        let output = pricing
86            .output_cost_per_1k
87            .map(|v| v * output_tokens as f64 / 1000.0)?;
88        Some(input + output)
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::Registry;

    /// The embedded price table must parse and yield at least one model.
    #[test]
    fn loads_registry() {
        let Registry { models } = Registry::load_embedded().unwrap();
        assert!(!models.is_empty());
    }
}
101}