//! Capability, ranking, pricing, and context-window lookups keyed by model name.
//!
//! Data comes from the tables in the `generated` module (re-exported below),
//! supplemented by name-pattern heuristics for models missing from those tables.

mod generated;

pub use generated::{VISION_MODELS, TEXT_ONLY_MODELS, AUDIO_MODELS, ALL_MODELS};
pub use generated::{ModelInfoEntry, MODEL_INFO};

/// Input modalities a model accepts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelCapabilities {
    /// Accepts image inputs.
    pub vision: bool,
    /// Accepts audio inputs.
    pub audio: bool,
    /// Accepts video inputs.
    pub video: bool,
    /// Accepts file (PDF) inputs.
    pub file: bool,
}

/// Per-category leaderboard scores; `None` when not available for the model.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelRanks {
    pub overall: Option<f32>,
    pub coding: Option<f32>,
    pub math: Option<f32>,
    pub hard_prompts: Option<f32>,
    pub instruction_following: Option<f32>,
    pub vision_rank: Option<f32>,
    pub style_control: Option<f32>,
}

/// Cost per million tokens (input and output); `None` when unknown.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelPricing {
    pub input_cost_per_m_tokens: Option<f32>,
    pub output_cost_per_m_tokens: Option<f32>,
}

/// Combined metadata for a model: capabilities, ranks, pricing, and token limits.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelProfile {
    pub capabilities: ModelCapabilities,
    pub ranks: ModelRanks,
    pub pricing: ModelPricing,
    pub max_input_tokens: u32,
    pub max_output_tokens: u32,
}

impl ModelCapabilities {
    /// Vision input only; audio, video, and file are disabled.
    pub const fn vision_only() -> Self {
        Self {
            vision: true,
            audio: false,
            video: false,
            file: false,
        }
    }

    /// Text input only; all extra modalities disabled.
    pub const fn text_only() -> Self {
        Self {
            vision: false,
            audio: false,
            video: false,
            file: false,
        }
    }

    /// Every supported modality enabled.
    pub const fn full_multimodal() -> Self {
        Self {
            vision: true,
            audio: true,
            video: true,
            file: true,
        }
    }

    /// Looks up capabilities for `model`, merging the generated `MODEL_INFO` table,
    /// the per-modality model lists, and the name-pattern heuristics. Returns `None`
    /// when the model is unknown to every source.
    pub fn lookup(model: &str) -> Option<Self> {
        let lower = model.to_lowercase();

        let info = lookup_model_info(&lower);

        let vision = info.map_or(false, |i| i.supports_vision)
            || is_in_list(&lower, VISION_MODELS)
            || supports_vision_by_pattern(&lower);

        let audio = info.map_or(false, |i| i.supports_audio)
            || is_in_list(&lower, AUDIO_MODELS);

        let video = info.map_or(false, |i| i.supports_video);
        let file = info.map_or(false, |i| i.supports_pdf);

        if info.is_some() || vision || audio || is_in_list(&lower, TEXT_ONLY_MODELS) {
            Some(Self { vision, audio, video, file })
        } else {
            None
        }
    }
}
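
// Usage sketch (illustrative, not part of the API): gate request attachments on the
// merged capability lookup, falling back to text-only when the model is unknown.
//
//     let caps = ModelCapabilities::lookup("openai/gpt-4o")
//         .unwrap_or(ModelCapabilities::text_only());
//     if caps.vision {
//         // attach image parts to the request
//     }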

/// Overall arena score for `model` (the table's `arena_overall` value divided by
/// 100), or `None` when the model is unknown or has no recorded score.
pub fn arena_rank(model: &str) -> Option<f32> {
    let lower = model.to_lowercase();
    let info = lookup_model_info(&lower)?;
    if info.arena_overall == 0 {
        None
    } else {
        Some(info.arena_overall as f32 / 100.0)
    }
}

/// Builds a full [`ModelProfile`] from the generated `MODEL_INFO` table, or returns
/// `None` when the model has no entry there. The table stores costs multiplied by
/// 1000; zero cost or score values are treated as unknown.
pub fn model_profile(model: &str) -> Option<ModelProfile> {
    let lower = model.to_lowercase();
    let info = lookup_model_info(&lower)?;

    let capabilities = ModelCapabilities {
        vision: info.supports_vision,
        audio: info.supports_audio,
        video: info.supports_video,
        file: info.supports_pdf,
    };

    let ranks = ModelRanks {
        overall: if info.arena_overall > 0 {
            Some(info.arena_overall as f32 / 100.0)
        } else {
            None
        },
        coding: None,
        math: None,
        hard_prompts: None,
        instruction_following: None,
        vision_rank: None,
        style_control: None,
    };

    let pricing = ModelPricing {
        input_cost_per_m_tokens: if info.cost_input_x1000 > 0 {
            Some(info.cost_input_x1000 as f32 / 1000.0)
        } else {
            None
        },
        output_cost_per_m_tokens: if info.cost_output_x1000 > 0 {
            Some(info.cost_output_x1000 as f32 / 1000.0)
        } else {
            None
        },
    };

    Some(ModelProfile {
        capabilities,
        ranks,
        pricing,
        max_input_tokens: info.max_input_tokens,
        max_output_tokens: info.max_output_tokens,
    })
}
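
// Cost arithmetic sketch (illustrative): with prices expressed per million tokens,
// a request costs `tokens / 1_000_000.0 * price` for each direction.
//
//     if let Some(profile) = model_profile("gpt-4o") {
//         if let (Some(inp), Some(out)) = (
//             profile.pricing.input_cost_per_m_tokens,
//             profile.pricing.output_cost_per_m_tokens,
//         ) {
//             // e.g. 12_000 prompt tokens and 800 completion tokens
//             let cost = 12_000.0 / 1_000_000.0 * inp + 800.0 / 1_000_000.0 * out;
//             let _ = cost;
//         }
//     }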

/// True when the generated table marks the model as accepting video input.
pub fn supports_video(model: &str) -> bool {
    let lower = model.to_lowercase();
    if let Some(info) = lookup_model_info(&lower) {
        return info.supports_video;
    }
    false
}

/// True when the generated table marks the model as accepting PDF/file input.
pub fn supports_pdf(model: &str) -> bool {
    let lower = model.to_lowercase();
    if let Some(info) = lookup_model_info(&lower) {
        return info.supports_pdf;
    }
    false
}

/// True when any source reports vision support: the generated table, the
/// `VISION_MODELS` list, or the name-pattern heuristics.
pub fn supports_vision(model: &str) -> bool {
    let lower = model.to_lowercase();

    // A table entry can confirm support outright; a `false` entry still falls
    // through to the list and pattern checks below.
    if let Some(info) = lookup_model_info(&lower) {
        if info.supports_vision {
            return true;
        }
    }

    if is_in_list(&lower, VISION_MODELS) {
        return true;
    }

    supports_vision_by_pattern(&lower)
}

/// True when the generated table or the `AUDIO_MODELS` list reports audio support.
pub fn supports_audio(model: &str) -> bool {
    let lower = model.to_lowercase();

    if let Some(info) = lookup_model_info(&lower) {
        if info.supports_audio {
            return true;
        }
    }

    is_in_list(&lower, AUDIO_MODELS)
}

/// True when the model supports neither vision nor audio input.
pub fn is_text_only(model: &str) -> bool {
    !supports_vision(model) && !supports_audio(model)
}

/// Resolves a lowercased model name to a `MODEL_INFO` entry: exact match first, then
/// the segment after the last `/` (for provider-prefixed names), then a substring
/// fallback that prefers the longest matching entry name.
fn lookup_model_info(model: &str) -> Option<&'static ModelInfoEntry> {
    // Exact match via binary search; MODEL_INFO is sorted by name.
    if let Ok(idx) = MODEL_INFO.binary_search_by(|entry| entry.name.cmp(model)) {
        return Some(&MODEL_INFO[idx]);
    }

    // Strip a provider prefix such as "openai/" and retry the exact lookup.
    if let Some(pos) = model.rfind('/') {
        let short = &model[pos + 1..];
        if let Ok(idx) = MODEL_INFO.binary_search_by(|entry| entry.name.cmp(short)) {
            return Some(&MODEL_INFO[idx]);
        }
    }

    // Substring fallback: track the longest entry name contained in `model`; accept
    // immediately if an entry name contains `model` itself (at least 4 chars long).
    let mut best: Option<&'static ModelInfoEntry> = None;
    let mut best_len = 0;

    for entry in MODEL_INFO {
        if model.contains(entry.name) && entry.name.len() > best_len {
            best = Some(entry);
            best_len = entry.name.len();
        }
        if entry.name.contains(model) && model.len() >= 4 {
            return Some(entry);
        }
    }

    best
}
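
// Resolution-order sketch (illustrative; the actual table contents are assumptions):
//
//     lookup_model_info("gpt-4o");            // exact hit via binary search
//     lookup_model_info("openai/gpt-4o");     // provider prefix stripped, then exact hit
//     lookup_model_info("gpt-4o-2024-08-06"); // substring fallback to the "gpt-4o" entry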

/// True when `model` equals a list entry, contains one, or (when at least 4
/// characters long) is itself contained in one.
fn is_in_list(model: &str, list: &[&str]) -> bool {
    for entry in list {
        if model == *entry {
            return true;
        }
        if model.contains(entry) {
            return true;
        }
        if entry.contains(model) && model.len() >= 4 {
            return true;
        }
    }
    false
}

/// Heuristic vision detection for models missing from the generated data, based on
/// well-known model-family names and `vl`/`vision` markers.
fn supports_vision_by_pattern(model: &str) -> bool {
    const VISION_PATTERNS: &[&str] = &[
        "gpt-4o",
        "gpt-4-turbo",
        "gpt-4-vision",
        "o1",
        "o3",
        "o4",
        "claude-3",
        "claude-4",
        "gemini-1.5",
        "gemini-2",
        "gemini-flash",
        "gemini-pro-vision",
        "qwen2-vl",
        "qwen2.5-vl",
        "qwen-vl",
        "qwq",
        "llama-3.2-vision",
        "-vision",
        "-vl-",
        "-vl:",
        "/vl-",
    ];

    for pattern in VISION_PATTERNS {
        if model.contains(pattern) {
            return true;
        }
    }

    model.ends_with("-vl") || model.ends_with(":vl") || model.ends_with("/vl")
}
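
// Pattern sketch (hypothetical names for illustration): "acme-vl-7b" matches the
// "-vl-" pattern, "acme-7b-vl" matches the "-vl" suffix check, and "plain-chat-7b"
// matches nothing and returns false.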

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_supports_vision_openai() {
        assert!(supports_vision("gpt-4o"));
        assert!(supports_vision("gpt-4o-mini"));
        assert!(supports_vision("openai/gpt-4o"));
        assert!(!supports_vision("gpt-3.5-turbo"));
    }

    #[test]
    fn test_supports_vision_anthropic() {
        assert!(supports_vision("claude-3-sonnet"));
        assert!(supports_vision("anthropic/claude-3-opus"));
        assert!(!supports_vision("claude-2"));
    }

    #[test]
    fn test_supports_vision_google() {
        assert!(supports_vision("gemini-2.0-flash"));
        assert!(supports_vision("google/gemini-1.5-pro"));
    }

    #[test]
    fn test_is_text_only() {
        assert!(is_text_only("gpt-3.5-turbo"));
        assert!(is_text_only("claude-2"));
        assert!(!is_text_only("gpt-4o"));
    }

    #[test]
    fn test_model_capabilities_lookup() {
        let caps = ModelCapabilities::lookup("gpt-4o");
        assert!(caps.is_some());
        assert!(caps.unwrap().vision);
    }

    #[test]
    fn test_arena_rank_known_models() {
        if let Some(rank) = arena_rank("gpt-4o") {
            assert!(rank > 50.0 && rank <= 100.0);
        }
    }

    #[test]
    fn test_model_profile_has_pricing() {
        if let Some(profile) = model_profile("gpt-4o") {
            assert!(profile.pricing.input_cost_per_m_tokens.is_some());
            assert!(profile.pricing.output_cost_per_m_tokens.is_some());
        }
    }

    #[test]
    fn test_model_profile_has_context_window() {
        if let Some(profile) = model_profile("gpt-4o") {
            assert!(profile.max_input_tokens > 0);
        }
    }

    #[test]
    fn test_supports_video_and_pdf() {
        assert!(supports_video("gemini-2.5-pro") || supports_video("gemini-2.0-flash"));
    }

    #[test]
    fn test_lookup_model_info_binary_search() {
        let info = lookup_model_info("gpt-4o");
        assert!(info.is_some());
    }

    #[test]
    fn test_lookup_model_info_substring() {
        let info = lookup_model_info("openai/gpt-4o");
        assert!(info.is_some());
    }

    #[test]
    fn test_model_info_sorted() {
        for window in MODEL_INFO.windows(2) {
            assert!(
                window[0].name <= window[1].name,
                "MODEL_INFO not sorted: {:?} > {:?}",
                window[0].name,
                window[1].name
            );
        }
    }

    #[test]
    fn test_supports_vision_qwen_vl() {
        assert!(supports_vision("qwen2-vl-72b"));
        assert!(supports_vision("qwen2.5-vl-7b"));
        assert!(supports_vision("qwen-vl-max"));
        assert!(supports_vision("QWEN2-VL"));
    }

    #[test]
    fn test_supports_vision_case_insensitive() {
        assert!(supports_vision("GPT-4O"));
        assert!(supports_vision("Claude-3-Sonnet"));
        assert!(supports_vision("Gemini-2.0-Flash"));
    }

    #[test]
    fn test_capabilities_merge_sources() {
        let caps = ModelCapabilities::lookup("qwen2-vl-72b");
        assert!(caps.is_some());
        assert!(caps.unwrap().vision);
    }
}