ragit_api/
model.rs

1use crate::api_provider::ApiProvider;
2use crate::error::Error;
3use lazy_static::lazy_static;
4use ragit_fs::join4;
5use ragit_pdl::Message;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::io::{Read, Write, stdin, stdout};
9
10#[derive(Clone, Debug, Eq, Hash, PartialEq)]
11pub struct Model {
12    pub name: String,
13    pub api_name: String,
14    pub can_read_images: bool,
15    pub api_provider: ApiProvider,
16    pub dollars_per_1b_input_tokens: u64,
17    pub dollars_per_1b_output_tokens: u64,
18    pub api_timeout: u64,
19    pub explanation: Option<String>,
20    pub api_key: Option<String>,
21    pub api_env_var: Option<String>,
22}
23
24impl Model {
25    /// This is a test model. It always returns a string `"dummy"`.
26    pub fn dummy() -> Self {
27        Model {
28            name: String::from("dummy"),
29            api_name: String::from("test-model-dummy-v0"),
30            can_read_images: false,
31            api_provider: ApiProvider::Test(TestModel::Dummy),
32            dollars_per_1b_input_tokens: 0,
33            dollars_per_1b_output_tokens: 0,
34            api_timeout: 180,
35            explanation: None,
36            api_key: None,
37            api_env_var: None,
38        }
39    }
40
41    /// This is a test model. It takes a response from you.
42    pub fn stdin() -> Self {
43        Model {
44            name: String::from("stdin"),
45            api_name: String::from("test-model-stdin-v0"),
46            can_read_images: false,
47            api_provider: ApiProvider::Test(TestModel::Stdin),
48            dollars_per_1b_input_tokens: 0,
49            dollars_per_1b_output_tokens: 0,
50            api_timeout: 180,
51            explanation: None,
52            api_key: None,
53            api_env_var: None,
54        }
55    }
56
57    /// This is a test model. It always throws an error.
58    pub fn error() -> Self {
59        Model {
60            name: String::from("error"),
61            api_name: String::from("test-model-error-v0"),
62            can_read_images: false,
63            api_provider: ApiProvider::Test(TestModel::Error),
64            dollars_per_1b_input_tokens: 0,
65            dollars_per_1b_output_tokens: 0,
66            api_timeout: 180,
67            explanation: None,
68            api_key: None,
69            api_env_var: None,
70        }
71    }
72
73    pub fn get_api_url(&self) -> Result<String, Error> {
74        let url = match &self.api_provider {
75            ApiProvider::Anthropic => String::from("https://api.anthropic.com/v1/messages"),
76            ApiProvider::Cohere => String::from("https://api.cohere.com/v2/chat"),
77            ApiProvider::OpenAi { url } => url.to_string(),
78            ApiProvider::Google => format!(
79                "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
80                self.api_name,
81                self.get_api_key()?,
82            ),
83            ApiProvider::Test(_) => String::new(),
84        };
85
86        Ok(url)
87    }
88
89    pub fn get_api_key(&self) -> Result<String, Error> {
90        // First, check if the API key is directly set in the model
91        if let Some(key) = &self.api_key {
92            return Ok(key.to_string());
93        }
94
95        // Next, check if an environment variable is specified and try to get the API key from it
96        if let Some(var) = &self.api_env_var {
97            if let Ok(key) = std::env::var(var) {
98                return Ok(key.to_string());
99            }
100
101            // Don't return an error yet, try the other methods first
102        }
103
104        // If we get here, try to find the API key in external model files
105        if let Some(key) = self.find_api_key_in_external_files()? {
106            return Ok(key);
107        }
108
109        // If we have an api_env_var but couldn't find the key anywhere, return an error
110        if let Some(var) = &self.api_env_var {
111            return Err(Error::ApiKeyNotFound { env_var: Some(var.to_string()) });
112        }
113
114        // If both `api_key` and `api_env_var` are not set,
115        // it assumes that the model does not require an
116        // api key.
117        Ok(String::new())
118    }
119
120    fn find_api_key_in_external_files(&self) -> Result<Option<String>, Error> {
121        // Try to find the API key in the file indicated by RAGIT_MODEL_FILE
122        if let Ok(file_path) = std::env::var("RAGIT_MODEL_FILE") {
123            if let Some(key) = self.find_api_key_in_file(&file_path)? {
124                return Ok(Some(key));
125            }
126        }
127
128        // Try to find the API key in ~/.config/ragit/models.json
129        if let Ok(home_dir) = std::env::var("HOME") {
130            let config_path = join4(
131                &home_dir,
132                ".config",
133                "ragit",
134                "models.json",
135            )?;
136
137            if let Some(key) = self.find_api_key_in_file(&config_path)? {
138                return Ok(Some(key));
139            }
140        }
141
142        Ok(None)
143    }
144
145    fn find_api_key_in_file(&self, file_path: &str) -> Result<Option<String>, Error> {
146        use std::fs::File;
147        use std::io::Read;
148
149        // Check if the file exists
150        let file = match File::open(file_path) {
151            Ok(file) => file,
152            Err(_) => return Ok(None), // File doesn't exist or can't be opened
153        };
154
155        // Read the file content
156        let mut content = String::new();
157        if let Err(_) = file.take(10_000_000).read_to_string(&mut content) {
158            return Ok(None); // Can't read the file
159        }
160
161        // Parse the JSON
162        let models: Vec<ModelRaw> = match serde_json::from_str(&content) {
163            Ok(models) => models,
164            Err(_) => return Ok(None), // Can't parse the JSON
165        };
166
167        // Find the model with the same name
168        for model in models {
169            if model.name == self.name {
170                // If the model has an API key, return it
171                if let Some(key) = model.api_key {
172                    return Ok(Some(key));
173                }
174
175                // If the model has an environment variable, try to get the API key from it
176                if let Some(var) = model.api_env_var {
177                    if let Ok(key) = std::env::var(&var) {
178                        return Ok(Some(key));
179                    }
180                }
181            }
182        }
183
184        Ok(None)
185    }
186
187    pub fn is_test_model(&self) -> bool {
188        matches!(self.api_provider, ApiProvider::Test(_))
189    }
190
191    pub fn default_models() -> Vec<Model> {
192        ModelRaw::default_models().iter().map(
193            |model| model.try_into().unwrap()
194        ).collect()
195    }
196}
197
198/// There are 2 types for models: `Model` and `ModelRaw`. I know it's confusing, I'm sorry.
199/// `Model` is the type ragit internally uses and `ModelRaw` is only for json serialization.
200/// Long time ago, there was only `Model` type. But then I implemented `models.json` interface.
201/// I wanted people to directly edit the json file and found that `Model` isn't intuitive to
202/// edit directly. So I added this struct.
203#[derive(Clone, Debug, Deserialize, Serialize)]
204pub struct ModelRaw {
205    /// Model name shown to user.
206    /// `rag config --set model` also
207    /// uses this name.
208    pub name: String,
209
210    /// Model name used for api requests.
211    pub api_name: String,
212
213    pub can_read_images: bool,
214
215    /// `openai | cohere | anthropic | google`
216    ///
217    /// If you're using an openai-compatible
218    /// api, set this to `openai`.
219    pub api_provider: String,
220
221    /// It's necessary if you're using an
222    /// openai-compatible api. If it's not
223    /// set, ragit uses the default url of
224    /// each api provider.
225    pub api_url: Option<String>,
226
227    /// Dollars per 1 million input tokens.
228    pub input_price: f64,
229
230    /// Dollars per 1 million output tokens.
231    pub output_price: f64,
232
233    // FIXME: I set the default value to 180 seconds long ago.
234    //        At that time, it's very common for LLMs to take
235    //        1 ~ 2 minutes to respond. But now, nobody would
236    //        wait 180 seconds. Do I have to reduce it?
237    /// The number is in seconds.
238    /// If not set, it's default to 180 seconds.
239    #[serde(default)]
240    pub api_timeout: Option<u64>,
241
242    pub explanation: Option<String>,
243
244    /// If you don't want to use an env var, you
245    /// can hard-code your api key in this field.
246    #[serde(default)]
247    pub api_key: Option<String>,
248
249    /// If you've hard-coded your api key,
250    /// you don't have to set this. If neither
251    /// `api_key`, nor `api_env_var` is set,
252    /// it assumes that the model doesn't require
253    /// an api key.
254    pub api_env_var: Option<String>,
255}
256
257lazy_static! {
258    static ref DEFAULT_MODELS: HashMap<String, ModelRaw> = {
259        let models_dot_json = include_str!("../models.json");
260        let models = serde_json::from_str::<Vec<ModelRaw>>(&models_dot_json).unwrap();
261        models.into_iter().map(
262            |model| (model.name.clone(), model)
263        ).collect()
264    };
265}
266
267impl ModelRaw {
268    pub fn llama_70b() -> Self {
269        DEFAULT_MODELS.get("llama3.3-70b-groq").unwrap().clone()
270    }
271
272    pub fn llama_8b() -> Self {
273        DEFAULT_MODELS.get("llama3.1-8b-groq").unwrap().clone()
274    }
275
276    pub fn gpt_4o() -> Self {
277        DEFAULT_MODELS.get("gpt-4o").unwrap().clone()
278    }
279
280    pub fn gpt_4o_mini() -> Self {
281        DEFAULT_MODELS.get("gpt-4o-mini").unwrap().clone()
282    }
283
284    pub fn gemini_2_flash() -> Self {
285        DEFAULT_MODELS.get("gemini-2.0-flash").unwrap().clone()
286    }
287
288    pub fn sonnet() -> Self {
289        DEFAULT_MODELS.get("claude-3.7-sonnet").unwrap().clone()
290    }
291
292    pub fn phi_4_14b() -> Self {
293        DEFAULT_MODELS.get("phi-4-14b-ollama").unwrap().clone()
294    }
295
296    pub fn command_r() -> Self {
297        DEFAULT_MODELS.get("command-r").unwrap().clone()
298    }
299
300    pub fn command_r_plus() -> Self {
301        DEFAULT_MODELS.get("command-r-plus").unwrap().clone()
302    }
303
304    pub fn default_models() -> Vec<ModelRaw> {
305        DEFAULT_MODELS.values().map(|model| model.clone()).collect()
306    }
307}
308
309pub fn get_model_by_name(models: &[Model], name: &str) -> Result<Model, Error> {
310    let mut partial_matches = vec![];
311
312    for model in models.iter() {
313        if model.name == name {
314            return Ok(model.clone());
315        }
316
317        if partial_match(&model.name, name) {
318            partial_matches.push(model);
319        }
320    }
321
322    if partial_matches.len() == 1 {
323        Ok(partial_matches[0].clone())
324    }
325
326    else if name == "dummy" {
327        Ok(Model::dummy())
328    }
329
330    else if name == "stdin" {
331        Ok(Model::stdin())
332    }
333
334    else if name == "error" {
335        Ok(Model::error())
336    }
337
338    else{
339        Err(Error::InvalidModelName {
340            name: name.to_string(),
341            candidates: partial_matches.iter().map(
342                |model| model.name.to_string()
343            ).collect(),
344        })
345    }
346}
347
348impl TryFrom<&ModelRaw> for Model {
349    type Error = Error;
350
351    fn try_from(m: &ModelRaw) -> Result<Model, Error> {
352        Ok(Model {
353            name: m.name.clone(),
354            api_name: m.api_name.clone(),
355            can_read_images: m.can_read_images,
356            api_provider: ApiProvider::parse(
357                &m.api_provider,
358                &m.api_url,
359            )?,
360            dollars_per_1b_input_tokens: (m.input_price * 1000.0).round() as u64,
361            dollars_per_1b_output_tokens: (m.output_price * 1000.0).round() as u64,
362            api_timeout: m.api_timeout.unwrap_or(180),
363            explanation: m.explanation.clone(),
364            api_key: m.api_key.clone(),
365            api_env_var: m.api_env_var.clone(),
366        })
367    }
368}
369
370impl From<&Model> for ModelRaw {
371    fn from(m: &Model) -> ModelRaw {
372        ModelRaw {
373            name: m.name.clone(),
374            api_name: m.api_name.clone(),
375            can_read_images: m.can_read_images,
376            api_provider: m.api_provider.to_string(),
377
378            // This field is for openai-compatible apis. The other api
379            // providers do not need this field. The problem is that
380            // `m.get_api_url()` may fail if api provider is google.
381            // So it just ignores errors.
382            api_url: m.get_api_url().ok(),
383
384            input_price: m.dollars_per_1b_input_tokens as f64 / 1000.0,
385            output_price: m.dollars_per_1b_output_tokens as f64 / 1000.0,
386            api_timeout: Some(m.api_timeout),
387            explanation: m.explanation.clone(),
388            api_key: m.api_key.clone(),
389            api_env_var: m.api_env_var.clone(),
390        }
391    }
392}
393
/// Behavior of the built-in test models (`Model::dummy`, `Model::stdin`
/// and `Model::error`).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TestModel {
    /// Always responds with `"dummy"`.
    Dummy,

    /// Prints the conversation to stdout and reads the response from stdin.
    Stdin,

    /// Always fails with an error.
    Error,
}
400
401impl TestModel {
402    pub fn get_dummy_response(&self, messages: &[Message]) -> Result<String, Error> {
403        match self {
404            TestModel::Dummy => Ok(String::from("dummy")),
405            TestModel::Stdin => {
406                for message in messages.iter() {
407                    println!(
408                        "<|{:?}|>\n\n{}\n\n",
409                        message.role,
410                        message.content.iter().map(|c| c.to_string()).collect::<Vec<String>>().join(""),
411                    );
412                }
413
414                print!("<|Assistant|>\n\n>>> ");
415                stdout().flush()?;
416
417                let mut s = String::new();
418                stdin().read_to_string(&mut s)?;
419                Ok(s)
420            },
421            TestModel::Error => Err(Error::TestModel),
422        }
423    }
424}
425
/// Returns `true` if `needle` is a (not necessarily contiguous) subsequence
/// of `haystack`, comparing bytes case-sensitively. For example, `"4omini"`
/// matches `"gpt-4o-mini"`. An empty `needle` matches any haystack.
fn partial_match(haystack: &str, needle: &str) -> bool {
    let mut remaining = haystack.bytes();

    // For each needle byte in order, greedily consume the haystack until it's found.
    needle.bytes().all(|wanted| remaining.any(|b| b == wanted))
}
445
#[cfg(test)]
mod tests {
    use super::{DEFAULT_MODELS, Model, ModelRaw, get_model_by_name, partial_match};

    /// Every entry of the embedded `models.json` must convert into a `Model`.
    #[test]
    fn validate_models_dot_json() {
        for (name, model) in DEFAULT_MODELS.iter() {
            if let Err(e) = Model::try_from(model) {
                panic!("invalid model `{}` in models.json: {:?}", name, e);
            }
        }
    }

    /// The named `ModelRaw` constructors panic at runtime if their model is
    /// missing from `models.json`, so check them here instead.
    #[test]
    fn named_default_models_exist() {
        let models = [
            ModelRaw::llama_70b(),
            ModelRaw::llama_8b(),
            ModelRaw::gpt_4o(),
            ModelRaw::gpt_4o_mini(),
            ModelRaw::gemini_2_flash(),
            ModelRaw::sonnet(),
            ModelRaw::phi_4_14b(),
            ModelRaw::command_r(),
            ModelRaw::command_r_plus(),
        ];

        for model in models.iter() {
            Model::try_from(model).unwrap();
        }
    }

    #[test]
    fn default_models_cover_models_dot_json() {
        assert_eq!(Model::default_models().len(), DEFAULT_MODELS.len());
    }

    #[test]
    fn get_model_by_name_resolution() {
        // Test models are available even with an empty model list.
        assert_eq!(get_model_by_name(&[], "dummy").unwrap(), Model::dummy());
        assert_eq!(get_model_by_name(&[], "stdin").unwrap(), Model::stdin());
        assert_eq!(get_model_by_name(&[], "error").unwrap(), Model::error());
        assert!(get_model_by_name(&[], "no-such-model").is_err());

        // A unique partial match resolves; an ambiguous one does not.
        let models = [Model::dummy(), Model::error()];
        assert_eq!(get_model_by_name(&models, "dum").unwrap(), Model::dummy());

        let models = [Model::dummy(), Model::stdin()];
        assert!(get_model_by_name(&models, "d").is_err());
    }

    #[test]
    fn partial_match_is_subsequence_match() {
        assert!(partial_match("gpt-4o-mini", "4omini"));
        assert!(partial_match("gpt-4o", ""));
        assert!(!partial_match("gpt-4o", "gpt-4o-mini"));
        assert!(!partial_match("abc", "cba"));
    }
}