1use crate::api_provider::ApiProvider;
2use crate::error::Error;
3use lazy_static::lazy_static;
4use ragit_fs::join4;
5use ragit_pdl::Message;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::io::{Read, Write, stdin, stdout};
9
/// Runtime description of a model. [`ModelRaw`] is the serializable counterpart;
/// the two are converted with `TryFrom<&ModelRaw>` and `From<&Model>`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Model {
    /// Name the model is looked up by (`get_model_by_name`) and matched
    /// against when searching external model files for an api key.
    pub name: String,
    /// Model identifier on the provider side. `get_api_url` inserts it into
    /// the request url for Google.
    pub api_name: String,
    /// NOTE(review): presumably tells whether image inputs may be sent to the
    /// model -- it is only copied around in this file, confirm at the call sites.
    pub can_read_images: bool,
    /// Which api the model is served by; decides the url in `get_api_url`.
    pub api_provider: ApiProvider,
    /// Input price in dollars per one billion tokens. `TryFrom<&ModelRaw>`
    /// computes it as `input_price * 1000`, rounded to an integer.
    pub dollars_per_1b_input_tokens: u64,
    /// Output price in dollars per one billion tokens. `TryFrom<&ModelRaw>`
    /// computes it as `output_price * 1000`, rounded to an integer.
    pub dollars_per_1b_output_tokens: u64,
    /// Request timeout. Falls back to 180 when the raw model does not set it.
    /// NOTE(review): presumably in seconds -- confirm against the request code.
    pub api_timeout: u64,
    /// Optional free-form text, copied verbatim from/to `ModelRaw`.
    pub explanation: Option<String>,
    /// Explicit api key. When set, `get_api_key` returns it without looking
    /// anywhere else.
    pub api_key: Option<String>,
    /// Name of the environment variable `get_api_key` reads the key from
    /// when `api_key` is not set.
    pub api_env_var: Option<String>,
}
23
24impl Model {
25 pub fn dummy() -> Self {
27 Model {
28 name: String::from("dummy"),
29 api_name: String::from("test-model-dummy-v0"),
30 can_read_images: false,
31 api_provider: ApiProvider::Test(TestModel::Dummy),
32 dollars_per_1b_input_tokens: 0,
33 dollars_per_1b_output_tokens: 0,
34 api_timeout: 180,
35 explanation: None,
36 api_key: None,
37 api_env_var: None,
38 }
39 }
40
41 pub fn stdin() -> Self {
43 Model {
44 name: String::from("stdin"),
45 api_name: String::from("test-model-stdin-v0"),
46 can_read_images: false,
47 api_provider: ApiProvider::Test(TestModel::Stdin),
48 dollars_per_1b_input_tokens: 0,
49 dollars_per_1b_output_tokens: 0,
50 api_timeout: 180,
51 explanation: None,
52 api_key: None,
53 api_env_var: None,
54 }
55 }
56
57 pub fn error() -> Self {
59 Model {
60 name: String::from("error"),
61 api_name: String::from("test-model-error-v0"),
62 can_read_images: false,
63 api_provider: ApiProvider::Test(TestModel::Error),
64 dollars_per_1b_input_tokens: 0,
65 dollars_per_1b_output_tokens: 0,
66 api_timeout: 180,
67 explanation: None,
68 api_key: None,
69 api_env_var: None,
70 }
71 }
72
73 pub fn get_api_url(&self) -> Result<String, Error> {
74 let url = match &self.api_provider {
75 ApiProvider::Anthropic => String::from("https://api.anthropic.com/v1/messages"),
76 ApiProvider::Cohere => String::from("https://api.cohere.com/v2/chat"),
77 ApiProvider::OpenAi { url } => url.to_string(),
78 ApiProvider::Google => format!(
79 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
80 self.api_name,
81 self.get_api_key()?,
82 ),
83 ApiProvider::Test(_) => String::new(),
84 };
85
86 Ok(url)
87 }
88
89 pub fn get_api_key(&self) -> Result<String, Error> {
90 if let Some(key) = &self.api_key {
92 return Ok(key.to_string());
93 }
94
95 if let Some(var) = &self.api_env_var {
97 if let Ok(key) = std::env::var(var) {
98 return Ok(key.to_string());
99 }
100
101 }
103
104 if let Some(key) = self.find_api_key_in_external_files()? {
106 return Ok(key);
107 }
108
109 if let Some(var) = &self.api_env_var {
111 return Err(Error::ApiKeyNotFound { env_var: Some(var.to_string()) });
112 }
113
114 Ok(String::new())
118 }
119
120 fn find_api_key_in_external_files(&self) -> Result<Option<String>, Error> {
121 if let Ok(file_path) = std::env::var("RAGIT_MODEL_FILE") {
123 if let Some(key) = self.find_api_key_in_file(&file_path)? {
124 return Ok(Some(key));
125 }
126 }
127
128 if let Ok(home_dir) = std::env::var("HOME") {
130 let config_path = join4(
131 &home_dir,
132 ".config",
133 "ragit",
134 "models.json",
135 )?;
136
137 if let Some(key) = self.find_api_key_in_file(&config_path)? {
138 return Ok(Some(key));
139 }
140 }
141
142 Ok(None)
143 }
144
145 fn find_api_key_in_file(&self, file_path: &str) -> Result<Option<String>, Error> {
146 use std::fs::File;
147 use std::io::Read;
148
149 let file = match File::open(file_path) {
151 Ok(file) => file,
152 Err(_) => return Ok(None), };
154
155 let mut content = String::new();
157 if let Err(_) = file.take(10_000_000).read_to_string(&mut content) {
158 return Ok(None); }
160
161 let models: Vec<ModelRaw> = match serde_json::from_str(&content) {
163 Ok(models) => models,
164 Err(_) => return Ok(None), };
166
167 for model in models {
169 if model.name == self.name {
170 if let Some(key) = model.api_key {
172 return Ok(Some(key));
173 }
174
175 if let Some(var) = model.api_env_var {
177 if let Ok(key) = std::env::var(&var) {
178 return Ok(Some(key));
179 }
180 }
181 }
182 }
183
184 Ok(None)
185 }
186
187 pub fn is_test_model(&self) -> bool {
188 matches!(self.api_provider, ApiProvider::Test(_))
189 }
190
191 pub fn default_models() -> Vec<Model> {
192 ModelRaw::default_models().iter().map(
193 |model| model.try_into().unwrap()
194 ).collect()
195 }
196}
197
/// Serializable form of a model: this is the type that is read from
/// `models.json` and from the external model files. It is converted to the
/// runtime type with `Model::try_from`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelRaw {
    /// Name of the model. The built-in models are keyed by it, and api keys
    /// in external model files are matched by it.
    pub name: String,

    /// Becomes `Model::api_name`.
    pub api_name: String,

    /// Becomes `Model::can_read_images`.
    pub can_read_images: bool,

    /// Api provider as text. It is parsed by `ApiProvider::parse`.
    pub api_provider: String,

    /// Url handed to `ApiProvider::parse` together with `api_provider`.
    /// Of the providers visible in this file, only `ApiProvider::OpenAi`
    /// stores a url.
    pub api_url: Option<String>,

    /// Input price. It is multiplied by 1000 to get dollars per one billion
    /// tokens, i.e. this is dollars per one million input tokens.
    pub input_price: f64,

    /// Output price. It is multiplied by 1000 to get dollars per one billion
    /// tokens, i.e. this is dollars per one million output tokens.
    pub output_price: f64,

    /// Request timeout. `Model::try_from` uses 180 if it is not set.
    /// NOTE(review): presumably in seconds -- confirm against the request code.
    #[serde(default)]
    pub api_timeout: Option<u64>,

    /// Optional free-form text, copied verbatim to `Model::explanation`.
    pub explanation: Option<String>,

    /// Hard-coded api key. It takes precedence over `api_env_var` wherever
    /// this file resolves a key.
    #[serde(default)]
    pub api_key: Option<String>,

    /// Name of the environment variable that holds the api key.
    pub api_env_var: Option<String>,
}
256
257lazy_static! {
258 static ref DEFAULT_MODELS: HashMap<String, ModelRaw> = {
259 let models_dot_json = include_str!("../models.json");
260 let models = serde_json::from_str::<Vec<ModelRaw>>(&models_dot_json).unwrap();
261 models.into_iter().map(
262 |model| (model.name.clone(), model)
263 ).collect()
264 };
265}
266
267impl ModelRaw {
268 pub fn llama_70b() -> Self {
269 DEFAULT_MODELS.get("llama3.3-70b-groq").unwrap().clone()
270 }
271
272 pub fn llama_8b() -> Self {
273 DEFAULT_MODELS.get("llama3.1-8b-groq").unwrap().clone()
274 }
275
276 pub fn gpt_4o() -> Self {
277 DEFAULT_MODELS.get("gpt-4o").unwrap().clone()
278 }
279
280 pub fn gpt_4o_mini() -> Self {
281 DEFAULT_MODELS.get("gpt-4o-mini").unwrap().clone()
282 }
283
284 pub fn gemini_2_flash() -> Self {
285 DEFAULT_MODELS.get("gemini-2.0-flash").unwrap().clone()
286 }
287
288 pub fn sonnet() -> Self {
289 DEFAULT_MODELS.get("claude-3.7-sonnet").unwrap().clone()
290 }
291
292 pub fn phi_4_14b() -> Self {
293 DEFAULT_MODELS.get("phi-4-14b-ollama").unwrap().clone()
294 }
295
296 pub fn command_r() -> Self {
297 DEFAULT_MODELS.get("command-r").unwrap().clone()
298 }
299
300 pub fn command_r_plus() -> Self {
301 DEFAULT_MODELS.get("command-r-plus").unwrap().clone()
302 }
303
304 pub fn default_models() -> Vec<ModelRaw> {
305 DEFAULT_MODELS.values().map(|model| model.clone()).collect()
306 }
307}
308
309pub fn get_model_by_name(models: &[Model], name: &str) -> Result<Model, Error> {
310 let mut partial_matches = vec![];
311
312 for model in models.iter() {
313 if model.name == name {
314 return Ok(model.clone());
315 }
316
317 if partial_match(&model.name, name) {
318 partial_matches.push(model);
319 }
320 }
321
322 if partial_matches.len() == 1 {
323 Ok(partial_matches[0].clone())
324 }
325
326 else if name == "dummy" {
327 Ok(Model::dummy())
328 }
329
330 else if name == "stdin" {
331 Ok(Model::stdin())
332 }
333
334 else if name == "error" {
335 Ok(Model::error())
336 }
337
338 else{
339 Err(Error::InvalidModelName {
340 name: name.to_string(),
341 candidates: partial_matches.iter().map(
342 |model| model.name.to_string()
343 ).collect(),
344 })
345 }
346}
347
348impl TryFrom<&ModelRaw> for Model {
349 type Error = Error;
350
351 fn try_from(m: &ModelRaw) -> Result<Model, Error> {
352 Ok(Model {
353 name: m.name.clone(),
354 api_name: m.api_name.clone(),
355 can_read_images: m.can_read_images,
356 api_provider: ApiProvider::parse(
357 &m.api_provider,
358 &m.api_url,
359 )?,
360 dollars_per_1b_input_tokens: (m.input_price * 1000.0).round() as u64,
361 dollars_per_1b_output_tokens: (m.output_price * 1000.0).round() as u64,
362 api_timeout: m.api_timeout.unwrap_or(180),
363 explanation: m.explanation.clone(),
364 api_key: m.api_key.clone(),
365 api_env_var: m.api_env_var.clone(),
366 })
367 }
368}
369
370impl From<&Model> for ModelRaw {
371 fn from(m: &Model) -> ModelRaw {
372 ModelRaw {
373 name: m.name.clone(),
374 api_name: m.api_name.clone(),
375 can_read_images: m.can_read_images,
376 api_provider: m.api_provider.to_string(),
377
378 api_url: m.get_api_url().ok(),
383
384 input_price: m.dollars_per_1b_input_tokens as f64 / 1000.0,
385 output_price: m.dollars_per_1b_output_tokens as f64 / 1000.0,
386 api_timeout: Some(m.api_timeout),
387 explanation: m.explanation.clone(),
388 api_key: m.api_key.clone(),
389 api_env_var: m.api_env_var.clone(),
390 }
391 }
392}
393
/// Fake models used in place of a real api provider
/// (see `TestModel::get_dummy_response`).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TestModel {
    /// Always responds with the string `"dummy"`.
    Dummy,
    /// Prints the messages to stdout and reads the response from stdin.
    Stdin,
    /// Always fails with `Error::TestModel`.
    Error,
}
400
401impl TestModel {
402 pub fn get_dummy_response(&self, messages: &[Message]) -> Result<String, Error> {
403 match self {
404 TestModel::Dummy => Ok(String::from("dummy")),
405 TestModel::Stdin => {
406 for message in messages.iter() {
407 println!(
408 "<|{:?}|>\n\n{}\n\n",
409 message.role,
410 message.content.iter().map(|c| c.to_string()).collect::<Vec<String>>().join(""),
411 );
412 }
413
414 print!("<|Assistant|>\n\n>>> ");
415 stdout().flush()?;
416
417 let mut s = String::new();
418 stdin().read_to_string(&mut s)?;
419 Ok(s)
420 },
421 TestModel::Error => Err(Error::TestModel),
422 }
423 }
424}
425
/// Tells whether `needle` is a subsequence of `haystack`: all bytes of
/// `needle` must occur in `haystack` in the same order, but not necessarily
/// next to each other. The comparison is byte-wise and case-sensitive.
///
/// An empty `needle` matches every `haystack`.
fn partial_match(haystack: &str, needle: &str) -> bool {
    let mut remaining = haystack.bytes();

    // `any` consumes `remaining` up to and including the byte it finds, so
    // each needle byte is searched only after the previous one's position.
    needle
        .bytes()
        .all(|wanted| remaining.any(|candidate| candidate == wanted))
}
445
#[cfg(test)]
mod tests {
    use super::{DEFAULT_MODELS, Model};

    /// Every entry of the embedded `models.json` must convert into a `Model`.
    #[test]
    fn validate_models_dot_json() {
        DEFAULT_MODELS.values().for_each(|raw| {
            Model::try_from(raw).unwrap();
        });
    }
}
456}