1use serde::{Deserialize, Serialize};
2
/// GGUF quantization levels ordered from highest quality / most memory to
/// lowest quality / least memory; `best_quant_for_budget` walks this list
/// and picks the first level that fits the memory budget.
pub const QUANT_HIERARCHY: &[&str] = &["Q8_0", "Q6_K", "Q5_K_M", "Q4_K_M", "Q3_K_M", "Q2_K"];

/// MLX quantization levels, ordered best-first like `QUANT_HIERARCHY`.
pub const MLX_QUANT_HIERARCHY: &[&str] = &["mlx-8bit", "mlx-4bit"];
9
/// Effective storage cost per parameter for a quantization label.
///
/// The values are in GB per billion parameters (numerically the same as
/// bytes per parameter); multiplying by `params_b()` yields model size in
/// GB. Block-quantized GGUF levels carry slightly more than their nominal
/// bit width (e.g. `Q8_0` => 1.05), presumably to account for scale/block
/// metadata. Unknown labels fall back to a 4-bit-like 0.58.
pub fn quant_bpp(quant: &str) -> f64 {
    match quant {
        "F32" => 4.0,
        "F16" | "BF16" => 2.0,
        "Q8_0" => 1.05,
        // All plain 8-bit pre-quantized formats cost one byte per parameter.
        "mlx-8bit" | "AWQ-8bit" | "GPTQ-Int8" => 1.0,
        "Q6_K" => 0.80,
        "Q5_K_M" => 0.68,
        "Q4_K_M" | "Q4_0" => 0.58,
        "mlx-4bit" => 0.55,
        "AWQ-4bit" | "GPTQ-Int4" => 0.5,
        "Q3_K_M" => 0.48,
        "Q2_K" => 0.37,
        // Unrecognized label: assume a typical 4-bit GGUF footprint.
        _ => 0.58,
    }
}
30
/// Relative token-throughput multiplier for a quantization level, with
/// `Q5_K_M` (and any unknown label) as the 1.0 baseline.
///
/// Smaller quants move less data and so run faster (>1.0); full/half
/// precision is slower (<1.0).
pub fn quant_speed_multiplier(quant: &str) -> f64 {
    match quant {
        "F16" | "BF16" => 0.6,
        "Q8_0" => 0.8,
        // 8-bit pre-quantized formats share a single throughput estimate.
        "mlx-8bit" | "AWQ-8bit" | "GPTQ-Int8" => 0.85,
        "Q6_K" => 0.95,
        "Q5_K_M" => 1.0,
        "Q4_K_M" | "Q4_0" | "mlx-4bit" => 1.15,
        "AWQ-4bit" | "GPTQ-Int4" => 1.2,
        "Q3_K_M" => 1.25,
        "Q2_K" => 1.35,
        // Unknown labels are treated as baseline speed.
        _ => 1.0,
    }
}
48
/// Nominal storage bytes per parameter for a quantization level
/// (8-bit => 1.0, 4-bit => 0.5, …).
///
/// Compare `quant_bpp`, whose slightly larger values for GGUF levels
/// appear to include per-block metadata overhead; this function returns
/// the raw bit-width cost. Unknown labels fall back to 0.5 (4-bit).
pub fn quant_bytes_per_param(quant: &str) -> f64 {
    match quant {
        // Fix: F32 previously fell through to the 0.5 default, even though
        // the sibling quant_bpp() handles it; a float32 weight is 4 bytes.
        "F32" => 4.0,
        "F16" | "BF16" => 2.0,
        "Q8_0" => 1.0,
        "Q6_K" => 0.75,
        "Q5_K_M" => 0.625,
        "Q4_K_M" | "Q4_0" => 0.5,
        "Q3_K_M" => 0.375,
        "Q2_K" => 0.25,
        "mlx-4bit" => 0.5,
        "mlx-8bit" => 1.0,
        "AWQ-4bit" | "GPTQ-Int4" => 0.5,
        "AWQ-8bit" | "GPTQ-Int8" => 1.0,
        _ => 0.5,
    }
}
67
/// Heuristic quality penalty (<= 0) attributed to a quantization level;
/// 0.0 means effectively lossless, more negative means more degradation.
/// The unit looks like benchmark-score points — not verified here.
/// Unknown labels get a conservative 4-bit-like -5.0.
pub fn quant_quality_penalty(quant: &str) -> f64 {
    match quant {
        // Fix: F32 previously fell through to the -5.0 default; full
        // precision has no quantization loss, consistent with F16/Q8_0
        // and with quant_bpp() handling "F32" explicitly.
        "F32" => 0.0,
        "F16" | "BF16" => 0.0,
        "Q8_0" => 0.0,
        "Q6_K" => -1.0,
        "Q5_K_M" => -2.0,
        "Q4_K_M" | "Q4_0" => -5.0,
        "Q3_K_M" => -8.0,
        "Q2_K" => -12.0,
        "mlx-4bit" => -4.0,
        "mlx-8bit" => 0.0,
        "AWQ-4bit" => -3.0,
        "AWQ-8bit" => 0.0,
        "GPTQ-Int4" => -3.0,
        "GPTQ-Int8" => 0.0,
        _ => -5.0,
    }
}
87
/// Optional model capability flags, serialized in snake_case
/// (e.g. "vision", "tool_use").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Capability {
    /// Accepts image/visual input.
    Vision,
    /// Supports function calling / tool invocation.
    ToolUse,
}
95
96impl Capability {
97 pub fn label(&self) -> &'static str {
98 match self {
99 Capability::Vision => "Vision",
100 Capability::ToolUse => "Tool Use",
101 }
102 }
103
104 pub fn all() -> &'static [Capability] {
105 &[Capability::Vision, Capability::ToolUse]
106 }
107
108 pub fn infer(model: &LlmModel) -> Vec<Capability> {
110 let mut caps = model.capabilities.clone();
111 let name = model.name.to_lowercase();
112 let use_case = model.use_case.to_lowercase();
113
114 if !caps.contains(&Capability::Vision)
116 && (name.contains("vision")
117 || name.contains("-vl-")
118 || name.ends_with("-vl")
119 || name.contains("llava")
120 || name.contains("onevision")
121 || name.contains("pixtral")
122 || use_case.contains("vision")
123 || use_case.contains("multimodal"))
124 {
125 caps.push(Capability::Vision);
126 }
127
128 if !caps.contains(&Capability::ToolUse)
130 && (use_case.contains("tool")
131 || use_case.contains("function call")
132 || name.contains("qwen3")
133 || name.contains("qwen2.5")
134 || name.contains("command-r")
135 || (name.contains("llama-3") && name.contains("instruct"))
136 || (name.contains("mistral") && name.contains("instruct"))
137 || name.contains("hermes"))
138 {
139 caps.push(Capability::ToolUse);
140 }
141
142 caps
143 }
144}
145
146#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
148#[serde(rename_all = "lowercase")]
149#[derive(Default)]
150pub enum ModelFormat {
151 #[default]
152 Gguf,
153 Awq,
154 Gptq,
155 Mlx,
156 Safetensors,
157}
158
159impl ModelFormat {
160 pub fn is_prequantized(&self) -> bool {
163 matches!(self, ModelFormat::Awq | ModelFormat::Gptq)
164 }
165}
166
167#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
169pub enum UseCase {
170 General,
171 Coding,
172 Reasoning,
173 Chat,
174 Multimodal,
175 Embedding,
176}
177
178impl UseCase {
179 pub fn label(&self) -> &'static str {
180 match self {
181 UseCase::General => "General",
182 UseCase::Coding => "Coding",
183 UseCase::Reasoning => "Reasoning",
184 UseCase::Chat => "Chat",
185 UseCase::Multimodal => "Multimodal",
186 UseCase::Embedding => "Embedding",
187 }
188 }
189
190 pub fn from_model(model: &LlmModel) -> Self {
192 let name = model.name.to_lowercase();
193 let use_case = model.use_case.to_lowercase();
194
195 if use_case.contains("embedding") || name.contains("embed") || name.contains("bge") {
196 UseCase::Embedding
197 } else if name.contains("code") || use_case.contains("code") {
198 UseCase::Coding
199 } else if use_case.contains("vision") || use_case.contains("multimodal") {
200 UseCase::Multimodal
201 } else if use_case.contains("reason")
202 || use_case.contains("chain-of-thought")
203 || name.contains("deepseek-r1")
204 {
205 UseCase::Reasoning
206 } else if use_case.contains("chat") || use_case.contains("instruction") {
207 UseCase::Chat
208 } else {
209 UseCase::General
210 }
211 }
212}
213
/// A catalog entry describing an LLM: identity, size, memory requirements,
/// quantization, and optional architecture / MoE metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmModel {
    /// Model name, possibly a HuggingFace-style "org/repo" id.
    pub name: String,
    /// Publishing organization (e.g. "Meta", "Qwen").
    pub provider: String,
    /// Human-readable parameter count such as "7B" or "500M".
    pub parameter_count: String,
    /// Exact parameter count when known; preferred over `parameter_count`
    /// by `params_b()`.
    #[serde(default)]
    pub parameters_raw: Option<u64>,
    /// Minimum system RAM (GB) required to run at all.
    pub min_ram_gb: f64,
    /// RAM (GB) recommended for comfortable use.
    pub recommended_ram_gb: f64,
    /// Minimum GPU VRAM (GB), when GPU execution applies.
    pub min_vram_gb: Option<f64>,
    /// Reference quantization label (e.g. "Q4_K_M", "mlx-4bit").
    pub quantization: String,
    /// Maximum context window in tokens.
    pub context_length: u32,
    /// Free-text use-case description; parsed heuristically by
    /// `UseCase::from_model` and `Capability::infer`.
    pub use_case: String,
    /// Whether this is a Mixture-of-Experts model.
    #[serde(default)]
    pub is_moe: bool,
    /// Total expert count (MoE only).
    #[serde(default)]
    pub num_experts: Option<u32>,
    /// Experts active per token (MoE only).
    #[serde(default)]
    pub active_experts: Option<u32>,
    /// Parameters active per token (MoE only).
    #[serde(default)]
    pub active_parameters: Option<u64>,
    /// Release date as free-form text; format not enforced here.
    #[serde(default)]
    pub release_date: Option<String>,
    /// Known repositories hosting GGUF conversions of this model.
    #[serde(default)]
    pub gguf_sources: Vec<GgufSource>,
    /// Explicit capability flags (inference may add more).
    #[serde(default)]
    pub capabilities: Vec<Capability>,
    /// Weight distribution format; defaults to GGUF.
    #[serde(default)]
    pub format: ModelFormat,
    /// Attention head count, used by tensor-parallel validity checks;
    /// name-based heuristics fill in when absent.
    #[serde(default)]
    pub num_attention_heads: Option<u32>,
    /// Key/value head count (GQA); defaults to the attention head count
    /// when only that is known.
    #[serde(default)]
    pub num_key_value_heads: Option<u32>,
    /// License identifier string, if known.
    #[serde(default)]
    pub license: Option<String>,
}
256
/// Returns true when `license` matches one of the comma-separated entries
/// in `filter`, case-insensitively and ignoring surrounding whitespace.
/// Models with no license never match.
pub fn matches_license_filter(license: &Option<String>, filter: &str) -> bool {
    // No license on record: always filtered out.
    let Some(license) = license else {
        return false;
    };
    let license_lower = license.to_lowercase();
    // Compare lazily instead of collecting the filter entries into a Vec
    // first — short-circuits on the first match and avoids the allocation.
    filter
        .split(',')
        .any(|entry| entry.trim().to_lowercase() == license_lower)
}
266
/// A repository hosting GGUF files for a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GgufSource {
    /// Repository id, e.g. "unsloth/Llama-3.1-8B-Instruct-GGUF".
    pub repo: String,
    /// Who produced the GGUF conversion (e.g. "unsloth").
    pub provider: String,
}
275
impl LlmModel {
    /// True when the name marks this as an MLX build ("-mlx-" infix or
    /// "-mlx" suffix, case-insensitive via lowercasing the name).
    /// NOTE(review): `is_mlx_only` below uses a looser check (any "-MLX"
    /// substring) — confirm whether the difference is intentional.
    pub fn is_mlx_model(&self) -> bool {
        let name_lower = self.name.to_lowercase();
        name_lower.contains("-mlx-") || name_lower.ends_with("-mlx")
    }

    /// True when the weights ship already quantized (AWQ/GPTQ format).
    pub fn is_prequantized(&self) -> bool {
        self.format.is_prequantized()
    }

    /// Whether the model can be sharded across `tp_size` tensor-parallel
    /// ranks: both attention and KV head counts must divide evenly.
    /// `tp_size <= 1` is trivially supported.
    pub fn supports_tp(&self, tp_size: u32) -> bool {
        if tp_size <= 1 {
            return true;
        }
        let (attn, kv) = self.infer_head_counts();
        attn % tp_size == 0 && kv % tp_size == 0
    }

    /// All tensor-parallel sizes in 1..=8 this model supports.
    pub fn valid_tp_sizes(&self) -> Vec<u32> {
        (1..=8).filter(|&tp| self.supports_tp(tp)).collect()
    }

    /// (attention heads, KV heads): explicit metadata when present,
    /// attention count reused for KV when only it is known, otherwise a
    /// name/size-based heuristic guess.
    fn infer_head_counts(&self) -> (u32, u32) {
        if let (Some(attn), Some(kv)) = (self.num_attention_heads, self.num_key_value_heads) {
            return (attn, kv);
        }
        if let Some(attn) = self.num_attention_heads {
            return (attn, attn);
        }
        infer_heads_from_name(&self.name, self.params_b())
    }

    /// Storage cost per parameter for this model's own quantization label.
    fn quant_bpp(&self) -> f64 {
        quant_bpp(&self.quantization)
    }

    /// Parameter count in billions. Prefers the exact `parameters_raw`;
    /// otherwise parses "7B"/"500M"-style strings, defaulting to 7.0 for
    /// an unparseable "B" value or an unrecognized suffix, and 0.0 for an
    /// unparseable "M" value.
    pub fn params_b(&self) -> f64 {
        if let Some(raw) = self.parameters_raw {
            raw as f64 / 1_000_000_000.0
        } else {
            let s = self.parameter_count.trim().to_uppercase();
            if let Some(num_str) = s.strip_suffix('B') {
                num_str.parse::<f64>().unwrap_or(7.0)
            } else if let Some(num_str) = s.strip_suffix('M') {
                num_str.parse::<f64>().unwrap_or(0.0) / 1000.0
            } else {
                7.0
            }
        }
    }

    /// Rough total memory (GB) to run at quantization `quant` with a
    /// `ctx`-token context: weights + linear KV-cache term + fixed 0.5 GB
    /// runtime overhead.
    /// NOTE(review): the 0.000008 KV coefficient scales with total params
    /// as a proxy for layer/head dims — heuristic, not derived here.
    pub fn estimate_memory_gb(&self, quant: &str, ctx: u32) -> f64 {
        let bpp = quant_bpp(quant);
        let params = self.params_b();
        let model_mem = params * bpp;
        let kv_cache = 0.000008 * params * ctx as f64;
        let overhead = 0.5;
        model_mem + kv_cache + overhead
    }

    /// Best GGUF quantization (highest quality first) fitting `budget_gb`
    /// at context `ctx`; see `best_quant_for_budget_with`.
    pub fn best_quant_for_budget(&self, budget_gb: f64, ctx: u32) -> Option<(&'static str, f64)> {
        self.best_quant_for_budget_with(budget_gb, ctx, QUANT_HIERARCHY)
    }

    /// Walks `hierarchy` (assumed ordered best-first) and returns the
    /// first (quant, estimated GB) fitting `budget_gb` at context `ctx`.
    /// If nothing fits, retries at half the context — but only when the
    /// halved context is still at least 1024 tokens. Returns None when
    /// even that fails.
    pub fn best_quant_for_budget_with(
        &self,
        budget_gb: f64,
        ctx: u32,
        hierarchy: &[&'static str],
    ) -> Option<(&'static str, f64)> {
        for &q in hierarchy {
            let mem = self.estimate_memory_gb(q, ctx);
            if mem <= budget_gb {
                return Some((q, mem));
            }
        }
        let half_ctx = ctx / 2;
        if half_ctx >= 1024 {
            for &q in hierarchy {
                let mem = self.estimate_memory_gb(q, half_ctx);
                if mem <= budget_gb {
                    return Some((q, mem));
                }
            }
        }
        None
    }

    /// VRAM (GB) for the *active* expert weights of an MoE model at its
    /// own quantization: active params × bytes/param, +10% margin,
    /// floored at 0.5 GB. None for dense models or when the active
    /// parameter count is unknown.
    pub fn moe_active_vram_gb(&self) -> Option<f64> {
        if !self.is_moe {
            return None;
        }
        let active_params = self.active_parameters? as f64;
        let bpp = self.quant_bpp();
        // Raw parameter count here, so divide by bytes-per-GiB explicitly
        // (unlike estimate_memory_gb, which works in billions).
        let size_gb = (active_params * bpp) / (1024.0 * 1024.0 * 1024.0);
        Some((size_gb * 1.1).max(0.5))
    }

    /// True when the name contains "-MLX" in any casing.
    /// NOTE(review): broader than `is_mlx_model` (also matches e.g.
    /// "-mlxfoo") — confirm whether both predicates should agree.
    pub fn is_mlx_only(&self) -> bool {
        self.name.to_uppercase().contains("-MLX")
    }

    /// RAM (GB) needed to host the *inactive* expert weights of an MoE
    /// model when they are offloaded off-GPU. None for dense models or
    /// when active/total parameter counts are unknown; 0.0 when active
    /// covers the total (inconsistent metadata guard).
    pub fn moe_offloaded_ram_gb(&self) -> Option<f64> {
        if !self.is_moe {
            return None;
        }
        let active = self.active_parameters? as f64;
        let total = self.parameters_raw? as f64;
        let inactive = total - active;
        if inactive <= 0.0 {
            return Some(0.0);
        }
        let bpp = self.quant_bpp();
        Some((inactive * bpp) / (1024.0 * 1024.0 * 1024.0))
    }
}
421
/// Raw deserialization target for entries in the embedded
/// `hf_models.json`. Mirrors `LlmModel` field-for-field, plus HuggingFace
/// popularity counters (`hf_downloads`, `hf_likes`) that `load_embedded`
/// parses but does not carry into `LlmModel`.
#[derive(Debug, Clone, Deserialize)]
struct HfModelEntry {
    name: String,
    provider: String,
    parameter_count: String,
    #[serde(default)]
    parameters_raw: Option<u64>,
    min_ram_gb: f64,
    recommended_ram_gb: f64,
    min_vram_gb: Option<f64>,
    quantization: String,
    context_length: u32,
    use_case: String,
    #[serde(default)]
    is_moe: bool,
    #[serde(default)]
    num_experts: Option<u32>,
    #[serde(default)]
    active_experts: Option<u32>,
    #[serde(default)]
    active_parameters: Option<u64>,
    #[serde(default)]
    release_date: Option<String>,
    #[serde(default)]
    gguf_sources: Vec<GgufSource>,
    #[serde(default)]
    capabilities: Vec<Capability>,
    #[serde(default)]
    format: ModelFormat,
    // HF popularity metrics; accepted from JSON but currently unused when
    // converting to LlmModel.
    #[serde(default)]
    hf_downloads: u64,
    #[serde(default)]
    hf_likes: u64,
    #[serde(default)]
    license: Option<String>,
}
460
/// Model catalog JSON baked into the binary at compile time.
const HF_MODELS_JSON: &str = include_str!("../data/hf_models.json");

/// In-memory catalog of known LLM models: the embedded list, optionally
/// merged with cached updates (see `new`).
pub struct ModelDatabase {
    models: Vec<LlmModel>,
}
466
impl Default for ModelDatabase {
    /// Equivalent to `ModelDatabase::new()`: embedded models merged with
    /// any cached updates.
    fn default() -> Self {
        Self::new()
    }
}
472
/// Canonical dedup key for a model name: take the last '/'-separated
/// segment (so "org/repo" and bare "repo" agree), lowercase it, and drop
/// '-', '_' and '.' separators.
pub(crate) fn canonical_slug(name: &str) -> String {
    let tail = name.rsplit('/').next().unwrap_or(name);
    tail.to_lowercase()
        .chars()
        .filter(|c| !matches!(c, '-' | '_' | '.'))
        .collect()
}
481
/// Parses the embedded JSON catalog into `LlmModel`s, running capability
/// inference on each entry.
///
/// Panics if the embedded `hf_models.json` is malformed — acceptable
/// because the file is compiled in, so failure indicates a build-time
/// data bug rather than a runtime condition.
fn load_embedded() -> Vec<LlmModel> {
    let entries: Vec<HfModelEntry> =
        serde_json::from_str(HF_MODELS_JSON).expect("Failed to parse embedded hf_models.json");
    entries
        .into_iter()
        .map(|e| {
            // `hf_downloads`/`hf_likes` from the entry are dropped here —
            // LlmModel has no corresponding fields.
            let mut model = LlmModel {
                name: e.name,
                provider: e.provider,
                parameter_count: e.parameter_count,
                parameters_raw: e.parameters_raw,
                min_ram_gb: e.min_ram_gb,
                recommended_ram_gb: e.recommended_ram_gb,
                min_vram_gb: e.min_vram_gb,
                quantization: e.quantization,
                context_length: e.context_length,
                use_case: e.use_case,
                is_moe: e.is_moe,
                num_experts: e.num_experts,
                active_experts: e.active_experts,
                active_parameters: e.active_parameters,
                release_date: e.release_date,
                gguf_sources: e.gguf_sources,
                capabilities: e.capabilities,
                format: e.format,
                // Head counts are not present in the embedded JSON; the
                // name-based heuristics fill in at query time.
                num_attention_heads: None,
                num_key_value_heads: None,
                license: e.license,
            };
            // Merge explicit capability flags with heuristics inferred
            // from the model's name and use-case text.
            model.capabilities = Capability::infer(&model);
            model
        })
        .collect()
}
517
518impl ModelDatabase {
519 pub fn embedded() -> Self {
522 ModelDatabase {
523 models: load_embedded(),
524 }
525 }
526
527 pub fn new() -> Self {
533 let mut models = load_embedded();
534
535 let embedded_keys: std::collections::HashSet<String> =
540 models.iter().map(|m| canonical_slug(&m.name)).collect();
541
542 for cached in crate::update::load_cache() {
543 if !embedded_keys.contains(&canonical_slug(&cached.name)) {
544 models.push(cached);
545 }
546 }
547
548 ModelDatabase { models }
549 }
550
551 pub fn get_all_models(&self) -> &Vec<LlmModel> {
552 &self.models
553 }
554
555 pub fn find_model(&self, query: &str) -> Vec<&LlmModel> {
556 let query_lower = query.to_lowercase();
557 self.models
558 .iter()
559 .filter(|m| {
560 m.name.to_lowercase().contains(&query_lower)
561 || m.provider.to_lowercase().contains(&query_lower)
562 || m.parameter_count.to_lowercase().contains(&query_lower)
563 })
564 .collect()
565 }
566
567 pub fn models_fitting_system(
568 &self,
569 available_ram_gb: f64,
570 has_gpu: bool,
571 vram_gb: Option<f64>,
572 ) -> Vec<&LlmModel> {
573 self.models
574 .iter()
575 .filter(|m| {
576 let ram_ok = m.min_ram_gb <= available_ram_gb;
578
579 if let Some(min_vram) = m.min_vram_gb {
581 if has_gpu {
582 if let Some(system_vram) = vram_gb {
583 ram_ok && min_vram <= system_vram
584 } else {
585 ram_ok
587 }
588 } else {
589 ram_ok && available_ram_gb >= m.recommended_ram_gb
591 }
592 } else {
593 ram_ok
594 }
595 })
596 .collect()
597 }
598}
599
/// Guesses (attention heads, KV heads) from the model family in `name`
/// and its size in billions of parameters. Family checks run in order;
/// an unrecognized family falls through to size-only defaults. Values
/// are typical published configurations per family/size tier.
fn infer_heads_from_name(name: &str, params_b: f64) -> (u32, u32) {
    let lower = name.to_lowercase();

    if lower.contains("qwen") {
        return if params_b > 100.0 {
            (128, 16)
        } else if params_b > 50.0 {
            (64, 8)
        } else if params_b > 10.0 {
            (40, 8)
        } else if params_b > 5.0 {
            (32, 8)
        } else {
            (16, 4)
        };
    }

    if lower.contains("llama") {
        // Llama 4 Scout/Maverick use large-model head counts regardless
        // of the (MoE-deflated) parameter estimate.
        return if lower.contains("scout") || lower.contains("maverick") || params_b > 60.0 {
            (64, 8)
        } else if params_b > 20.0 {
            (48, 8)
        } else if params_b > 5.0 {
            (32, 8)
        } else {
            (16, 8)
        };
    }

    if lower.contains("deepseek") {
        return if params_b > 200.0 {
            (128, 16)
        } else if params_b > 50.0 {
            (64, 8)
        } else if params_b > 10.0 {
            (40, 8)
        } else {
            (32, 8)
        };
    }

    if lower.contains("mistral") || lower.contains("mixtral") {
        return if params_b > 100.0 { (96, 8) } else { (32, 8) };
    }

    if lower.contains("gemma") {
        return if params_b > 20.0 {
            (32, 16)
        } else if params_b > 5.0 {
            (16, 8)
        } else {
            (8, 4)
        };
    }

    if lower.contains("phi") {
        return if params_b > 10.0 { (40, 10) } else { (32, 8) };
    }

    if lower.contains("minimax") {
        return (48, 8);
    }

    // Unknown family: size-only defaults.
    if params_b > 100.0 {
        (128, 16)
    } else if params_b > 50.0 {
        (64, 8)
    } else if params_b > 5.0 {
        (32, 8)
    } else {
        (16, 4)
    }
}
701
702#[cfg(test)]
703mod tests {
704 use super::*;
705
706 #[test]
711 fn test_mlx_quant_bpp_values() {
712 assert_eq!(quant_bpp("mlx-4bit"), 0.55);
713 assert_eq!(quant_bpp("mlx-8bit"), 1.0);
714 assert_eq!(quant_speed_multiplier("mlx-4bit"), 1.15);
715 assert_eq!(quant_speed_multiplier("mlx-8bit"), 0.85);
716 assert_eq!(quant_quality_penalty("mlx-4bit"), -4.0);
717 assert_eq!(quant_quality_penalty("mlx-8bit"), 0.0);
718 }
719
720 #[test]
721 fn test_best_quant_with_mlx_hierarchy() {
722 let model = LlmModel {
723 name: "Test Model".to_string(),
724 provider: "Test".to_string(),
725 parameter_count: "7B".to_string(),
726 parameters_raw: Some(7_000_000_000),
727 min_ram_gb: 4.0,
728 recommended_ram_gb: 8.0,
729 min_vram_gb: Some(4.0),
730 quantization: "Q4_K_M".to_string(),
731 context_length: 4096,
732 use_case: "General".to_string(),
733 is_moe: false,
734 num_experts: None,
735 active_experts: None,
736 active_parameters: None,
737 release_date: None,
738 gguf_sources: vec![],
739 capabilities: vec![],
740 format: ModelFormat::default(),
741 num_attention_heads: None,
742 num_key_value_heads: None,
743 license: None,
744 };
745
746 let result = model.best_quant_for_budget_with(10.0, 4096, MLX_QUANT_HIERARCHY);
748 assert!(result.is_some());
749 let (quant, _) = result.unwrap();
750 assert_eq!(quant, "mlx-8bit");
751
752 let result = model.best_quant_for_budget_with(5.0, 4096, MLX_QUANT_HIERARCHY);
754 assert!(result.is_some());
755 let (quant, _) = result.unwrap();
756 assert_eq!(quant, "mlx-4bit");
757 }
758
759 #[test]
760 fn test_quant_bpp() {
761 assert_eq!(quant_bpp("F32"), 4.0);
762 assert_eq!(quant_bpp("F16"), 2.0);
763 assert_eq!(quant_bpp("Q8_0"), 1.05);
764 assert_eq!(quant_bpp("Q4_K_M"), 0.58);
765 assert_eq!(quant_bpp("Q2_K"), 0.37);
766 assert_eq!(quant_bpp("UNKNOWN"), 0.58);
768 }
769
770 #[test]
771 fn test_quant_speed_multiplier() {
772 assert_eq!(quant_speed_multiplier("F16"), 0.6);
773 assert_eq!(quant_speed_multiplier("Q5_K_M"), 1.0);
774 assert_eq!(quant_speed_multiplier("Q4_K_M"), 1.15);
775 assert_eq!(quant_speed_multiplier("Q2_K"), 1.35);
776 assert!(quant_speed_multiplier("Q2_K") > quant_speed_multiplier("Q8_0"));
778 }
779
780 #[test]
781 fn test_quant_quality_penalty() {
782 assert_eq!(quant_quality_penalty("F16"), 0.0);
783 assert_eq!(quant_quality_penalty("Q8_0"), 0.0);
784 assert_eq!(quant_quality_penalty("Q4_K_M"), -5.0);
785 assert_eq!(quant_quality_penalty("Q2_K"), -12.0);
786 assert!(quant_quality_penalty("Q2_K") < quant_quality_penalty("Q8_0"));
788 }
789
790 #[test]
795 fn test_params_b_from_raw() {
796 let model = LlmModel {
797 name: "Test Model".to_string(),
798 provider: "Test".to_string(),
799 parameter_count: "7B".to_string(),
800 parameters_raw: Some(7_000_000_000),
801 min_ram_gb: 4.0,
802 recommended_ram_gb: 8.0,
803 min_vram_gb: Some(4.0),
804 quantization: "Q4_K_M".to_string(),
805 context_length: 4096,
806 use_case: "General".to_string(),
807 is_moe: false,
808 num_experts: None,
809 active_experts: None,
810 active_parameters: None,
811 release_date: None,
812 gguf_sources: vec![],
813 capabilities: vec![],
814 format: ModelFormat::default(),
815 num_attention_heads: None,
816 num_key_value_heads: None,
817 license: None,
818 };
819 assert_eq!(model.params_b(), 7.0);
820 }
821
822 #[test]
823 fn test_params_b_from_string() {
824 let model = LlmModel {
825 name: "Test Model".to_string(),
826 provider: "Test".to_string(),
827 parameter_count: "13B".to_string(),
828 parameters_raw: None,
829 min_ram_gb: 8.0,
830 recommended_ram_gb: 16.0,
831 min_vram_gb: Some(8.0),
832 quantization: "Q4_K_M".to_string(),
833 context_length: 4096,
834 use_case: "General".to_string(),
835 is_moe: false,
836 num_experts: None,
837 active_experts: None,
838 active_parameters: None,
839 release_date: None,
840 gguf_sources: vec![],
841 capabilities: vec![],
842 format: ModelFormat::default(),
843 num_attention_heads: None,
844 num_key_value_heads: None,
845 license: None,
846 };
847 assert_eq!(model.params_b(), 13.0);
848 }
849
850 #[test]
851 fn test_params_b_from_millions() {
852 let model = LlmModel {
853 name: "Test Model".to_string(),
854 provider: "Test".to_string(),
855 parameter_count: "500M".to_string(),
856 parameters_raw: None,
857 min_ram_gb: 1.0,
858 recommended_ram_gb: 2.0,
859 min_vram_gb: Some(1.0),
860 quantization: "Q4_K_M".to_string(),
861 context_length: 2048,
862 use_case: "General".to_string(),
863 is_moe: false,
864 num_experts: None,
865 active_experts: None,
866 active_parameters: None,
867 release_date: None,
868 gguf_sources: vec![],
869 capabilities: vec![],
870 format: ModelFormat::default(),
871 num_attention_heads: None,
872 num_key_value_heads: None,
873 license: None,
874 };
875 assert_eq!(model.params_b(), 0.5);
876 }
877
878 #[test]
879 fn test_estimate_memory_gb() {
880 let model = LlmModel {
881 name: "Test Model".to_string(),
882 provider: "Test".to_string(),
883 parameter_count: "7B".to_string(),
884 parameters_raw: Some(7_000_000_000),
885 min_ram_gb: 4.0,
886 recommended_ram_gb: 8.0,
887 min_vram_gb: Some(4.0),
888 quantization: "Q4_K_M".to_string(),
889 context_length: 4096,
890 use_case: "General".to_string(),
891 is_moe: false,
892 num_experts: None,
893 active_experts: None,
894 active_parameters: None,
895 release_date: None,
896 gguf_sources: vec![],
897 capabilities: vec![],
898 format: ModelFormat::default(),
899 num_attention_heads: None,
900 num_key_value_heads: None,
901 license: None,
902 };
903
904 let mem = model.estimate_memory_gb("Q4_K_M", 4096);
905 assert!(mem > 4.0);
907 assert!(mem < 6.0);
908
909 let mem_q8 = model.estimate_memory_gb("Q8_0", 4096);
911 assert!(mem_q8 > mem);
912 }
913
914 #[test]
915 fn test_best_quant_for_budget() {
916 let model = LlmModel {
917 name: "Test Model".to_string(),
918 provider: "Test".to_string(),
919 parameter_count: "7B".to_string(),
920 parameters_raw: Some(7_000_000_000),
921 min_ram_gb: 4.0,
922 recommended_ram_gb: 8.0,
923 min_vram_gb: Some(4.0),
924 quantization: "Q4_K_M".to_string(),
925 context_length: 4096,
926 use_case: "General".to_string(),
927 is_moe: false,
928 num_experts: None,
929 active_experts: None,
930 active_parameters: None,
931 release_date: None,
932 gguf_sources: vec![],
933 capabilities: vec![],
934 format: ModelFormat::default(),
935 num_attention_heads: None,
936 num_key_value_heads: None,
937 license: None,
938 };
939
940 let result = model.best_quant_for_budget(10.0, 4096);
942 assert!(result.is_some());
943 let (quant, _) = result.unwrap();
944 assert_eq!(quant, "Q8_0");
945
946 let result = model.best_quant_for_budget(5.0, 4096);
948 assert!(result.is_some());
949
950 let result = model.best_quant_for_budget(1.0, 4096);
952 assert!(result.is_none());
953 }
954
955 #[test]
956 fn test_moe_active_vram_gb() {
957 let dense_model = LlmModel {
959 name: "Dense Model".to_string(),
960 provider: "Test".to_string(),
961 parameter_count: "7B".to_string(),
962 parameters_raw: Some(7_000_000_000),
963 min_ram_gb: 4.0,
964 recommended_ram_gb: 8.0,
965 min_vram_gb: Some(4.0),
966 quantization: "Q4_K_M".to_string(),
967 context_length: 4096,
968 use_case: "General".to_string(),
969 is_moe: false,
970 num_experts: None,
971 active_experts: None,
972 active_parameters: None,
973 release_date: None,
974 gguf_sources: vec![],
975 capabilities: vec![],
976 format: ModelFormat::default(),
977 num_attention_heads: None,
978 num_key_value_heads: None,
979 license: None,
980 };
981 assert!(dense_model.moe_active_vram_gb().is_none());
982
983 let moe_model = LlmModel {
985 name: "MoE Model".to_string(),
986 provider: "Test".to_string(),
987 parameter_count: "8x7B".to_string(),
988 parameters_raw: Some(46_700_000_000),
989 min_ram_gb: 25.0,
990 recommended_ram_gb: 50.0,
991 min_vram_gb: Some(25.0),
992 quantization: "Q4_K_M".to_string(),
993 context_length: 32768,
994 use_case: "General".to_string(),
995 is_moe: true,
996 num_experts: Some(8),
997 active_experts: Some(2),
998 active_parameters: Some(12_900_000_000),
999 release_date: None,
1000 gguf_sources: vec![],
1001 capabilities: vec![],
1002 format: ModelFormat::default(),
1003 num_attention_heads: None,
1004 num_key_value_heads: None,
1005 license: None,
1006 };
1007 let vram = moe_model.moe_active_vram_gb();
1008 assert!(vram.is_some());
1009 let vram_val = vram.unwrap();
1010 assert!(vram_val > 0.0);
1012 assert!(vram_val < 15.0);
1013 }
1014
1015 #[test]
1016 fn test_moe_offloaded_ram_gb() {
1017 let dense_model = LlmModel {
1019 name: "Dense Model".to_string(),
1020 provider: "Test".to_string(),
1021 parameter_count: "7B".to_string(),
1022 parameters_raw: Some(7_000_000_000),
1023 min_ram_gb: 4.0,
1024 recommended_ram_gb: 8.0,
1025 min_vram_gb: Some(4.0),
1026 quantization: "Q4_K_M".to_string(),
1027 context_length: 4096,
1028 use_case: "General".to_string(),
1029 is_moe: false,
1030 num_experts: None,
1031 active_experts: None,
1032 active_parameters: None,
1033 release_date: None,
1034 gguf_sources: vec![],
1035 capabilities: vec![],
1036 format: ModelFormat::default(),
1037 num_attention_heads: None,
1038 num_key_value_heads: None,
1039 license: None,
1040 };
1041 assert!(dense_model.moe_offloaded_ram_gb().is_none());
1042
1043 let moe_model = LlmModel {
1045 name: "MoE Model".to_string(),
1046 provider: "Test".to_string(),
1047 parameter_count: "8x7B".to_string(),
1048 parameters_raw: Some(46_700_000_000),
1049 min_ram_gb: 25.0,
1050 recommended_ram_gb: 50.0,
1051 min_vram_gb: Some(25.0),
1052 quantization: "Q4_K_M".to_string(),
1053 context_length: 32768,
1054 use_case: "General".to_string(),
1055 is_moe: true,
1056 num_experts: Some(8),
1057 active_experts: Some(2),
1058 active_parameters: Some(12_900_000_000),
1059 release_date: None,
1060 gguf_sources: vec![],
1061 capabilities: vec![],
1062 format: ModelFormat::default(),
1063 num_attention_heads: None,
1064 num_key_value_heads: None,
1065 license: None,
1066 };
1067 let offloaded = moe_model.moe_offloaded_ram_gb();
1068 assert!(offloaded.is_some());
1069 let offloaded_val = offloaded.unwrap();
1070 assert!(offloaded_val > 10.0);
1072 }
1073
1074 #[test]
1079 fn test_use_case_from_model_coding() {
1080 let model = LlmModel {
1081 name: "codellama-7b".to_string(),
1082 provider: "Meta".to_string(),
1083 parameter_count: "7B".to_string(),
1084 parameters_raw: Some(7_000_000_000),
1085 min_ram_gb: 4.0,
1086 recommended_ram_gb: 8.0,
1087 min_vram_gb: Some(4.0),
1088 quantization: "Q4_K_M".to_string(),
1089 context_length: 4096,
1090 use_case: "Coding".to_string(),
1091 is_moe: false,
1092 num_experts: None,
1093 active_experts: None,
1094 active_parameters: None,
1095 release_date: None,
1096 gguf_sources: vec![],
1097 capabilities: vec![],
1098 format: ModelFormat::default(),
1099 num_attention_heads: None,
1100 num_key_value_heads: None,
1101 license: None,
1102 };
1103 assert_eq!(UseCase::from_model(&model), UseCase::Coding);
1104 }
1105
1106 #[test]
1107 fn test_use_case_from_model_embedding() {
1108 let model = LlmModel {
1109 name: "bge-large".to_string(),
1110 provider: "BAAI".to_string(),
1111 parameter_count: "335M".to_string(),
1112 parameters_raw: Some(335_000_000),
1113 min_ram_gb: 1.0,
1114 recommended_ram_gb: 2.0,
1115 min_vram_gb: Some(1.0),
1116 quantization: "F16".to_string(),
1117 context_length: 512,
1118 use_case: "Embedding".to_string(),
1119 is_moe: false,
1120 num_experts: None,
1121 active_experts: None,
1122 active_parameters: None,
1123 release_date: None,
1124 gguf_sources: vec![],
1125 capabilities: vec![],
1126 format: ModelFormat::default(),
1127 num_attention_heads: None,
1128 num_key_value_heads: None,
1129 license: None,
1130 };
1131 assert_eq!(UseCase::from_model(&model), UseCase::Embedding);
1132 }
1133
1134 #[test]
1135 fn test_use_case_from_model_reasoning() {
1136 let model = LlmModel {
1137 name: "deepseek-r1-7b".to_string(),
1138 provider: "DeepSeek".to_string(),
1139 parameter_count: "7B".to_string(),
1140 parameters_raw: Some(7_000_000_000),
1141 min_ram_gb: 4.0,
1142 recommended_ram_gb: 8.0,
1143 min_vram_gb: Some(4.0),
1144 quantization: "Q4_K_M".to_string(),
1145 context_length: 8192,
1146 use_case: "Reasoning".to_string(),
1147 is_moe: false,
1148 num_experts: None,
1149 active_experts: None,
1150 active_parameters: None,
1151 release_date: None,
1152 gguf_sources: vec![],
1153 capabilities: vec![],
1154 format: ModelFormat::default(),
1155 num_attention_heads: None,
1156 num_key_value_heads: None,
1157 license: None,
1158 };
1159 assert_eq!(UseCase::from_model(&model), UseCase::Reasoning);
1160 }
1161
1162 #[test]
1167 fn test_model_database_new() {
1168 let db = ModelDatabase::new();
1169 let models = db.get_all_models();
1170 assert!(!models.is_empty());
1172 }
1173
1174 #[test]
1175 fn test_find_model() {
1176 let db = ModelDatabase::new();
1177
1178 let results = db.find_model("llama");
1180 assert!(!results.is_empty());
1181 assert!(
1182 results
1183 .iter()
1184 .any(|m| m.name.to_lowercase().contains("llama"))
1185 );
1186
1187 let results_upper = db.find_model("LLAMA");
1189 assert_eq!(results.len(), results_upper.len());
1190 }
1191
1192 #[test]
1193 fn test_models_fitting_system() {
1194 let db = ModelDatabase::new();
1195
1196 let fitting = db.models_fitting_system(32.0, true, Some(24.0));
1198 assert!(!fitting.is_empty());
1199
1200 let fitting_small = db.models_fitting_system(2.0, false, None);
1202 assert!(fitting_small.len() < fitting.len());
1203
1204 for model in fitting_small {
1206 assert!(model.min_ram_gb <= 2.0);
1207 }
1208 }
1209
1210 #[test]
1215 fn test_capability_infer_vision() {
1216 let model = LlmModel {
1217 name: "meta-llama/Llama-3.2-11B-Vision-Instruct".to_string(),
1218 provider: "Meta".to_string(),
1219 parameter_count: "11B".to_string(),
1220 parameters_raw: Some(11_000_000_000),
1221 min_ram_gb: 6.0,
1222 recommended_ram_gb: 10.0,
1223 min_vram_gb: Some(6.0),
1224 quantization: "Q4_K_M".to_string(),
1225 context_length: 131072,
1226 use_case: "Multimodal, vision and text".to_string(),
1227 is_moe: false,
1228 num_experts: None,
1229 active_experts: None,
1230 active_parameters: None,
1231 release_date: None,
1232 gguf_sources: vec![],
1233 capabilities: vec![],
1234 format: ModelFormat::default(),
1235 num_attention_heads: None,
1236 num_key_value_heads: None,
1237 license: None,
1238 };
1239 let caps = Capability::infer(&model);
1240 assert!(caps.contains(&Capability::Vision));
1241 assert!(caps.contains(&Capability::ToolUse));
1243 }
1244
1245 #[test]
1246 fn test_capability_infer_tool_use() {
1247 let model = LlmModel {
1248 name: "Qwen/Qwen3-8B".to_string(),
1249 provider: "Qwen".to_string(),
1250 parameter_count: "8B".to_string(),
1251 parameters_raw: Some(8_000_000_000),
1252 min_ram_gb: 4.5,
1253 recommended_ram_gb: 8.0,
1254 min_vram_gb: Some(4.0),
1255 quantization: "Q4_K_M".to_string(),
1256 context_length: 32768,
1257 use_case: "General purpose text generation".to_string(),
1258 is_moe: false,
1259 num_experts: None,
1260 active_experts: None,
1261 active_parameters: None,
1262 release_date: None,
1263 gguf_sources: vec![],
1264 capabilities: vec![],
1265 format: ModelFormat::default(),
1266 num_attention_heads: None,
1267 num_key_value_heads: None,
1268 license: None,
1269 };
1270 let caps = Capability::infer(&model);
1271 assert!(caps.contains(&Capability::ToolUse));
1272 assert!(!caps.contains(&Capability::Vision));
1273 }
1274
1275 #[test]
1276 fn test_capability_infer_none() {
1277 let model = LlmModel {
1278 name: "BAAI/bge-large-en-v1.5".to_string(),
1279 provider: "BAAI".to_string(),
1280 parameter_count: "335M".to_string(),
1281 parameters_raw: Some(335_000_000),
1282 min_ram_gb: 1.0,
1283 recommended_ram_gb: 2.0,
1284 min_vram_gb: Some(1.0),
1285 quantization: "F16".to_string(),
1286 context_length: 512,
1287 use_case: "Text embeddings for RAG".to_string(),
1288 is_moe: false,
1289 num_experts: None,
1290 active_experts: None,
1291 active_parameters: None,
1292 release_date: None,
1293 gguf_sources: vec![],
1294 capabilities: vec![],
1295 format: ModelFormat::default(),
1296 num_attention_heads: None,
1297 num_key_value_heads: None,
1298 license: None,
1299 };
1300 let caps = Capability::infer(&model);
1301 assert!(caps.is_empty());
1302 }
1303
1304 #[test]
1305 fn test_capability_preserves_explicit() {
1306 let model = LlmModel {
1307 name: "some-model".to_string(),
1308 provider: "Test".to_string(),
1309 parameter_count: "7B".to_string(),
1310 parameters_raw: Some(7_000_000_000),
1311 min_ram_gb: 4.0,
1312 recommended_ram_gb: 8.0,
1313 min_vram_gb: Some(4.0),
1314 quantization: "Q4_K_M".to_string(),
1315 context_length: 4096,
1316 use_case: "General".to_string(),
1317 is_moe: false,
1318 num_experts: None,
1319 active_experts: None,
1320 active_parameters: None,
1321 release_date: None,
1322 gguf_sources: vec![],
1323 capabilities: vec![Capability::Vision],
1324 format: ModelFormat::default(),
1325 num_attention_heads: None,
1326 num_key_value_heads: None,
1327 license: None,
1328 };
1329 let caps = Capability::infer(&model);
1330 assert_eq!(caps.iter().filter(|c| **c == Capability::Vision).count(), 1);
1332 }
1333
1334 #[test]
1335 fn test_awq_gptq_quant_values() {
1336 assert_eq!(quant_bpp("AWQ-4bit"), 0.5);
1338 assert_eq!(quant_bpp("AWQ-8bit"), 1.0);
1339 assert_eq!(quant_speed_multiplier("AWQ-4bit"), 1.2);
1340 assert_eq!(quant_speed_multiplier("AWQ-8bit"), 0.85);
1341 assert_eq!(quant_quality_penalty("AWQ-4bit"), -3.0);
1342 assert_eq!(quant_quality_penalty("AWQ-8bit"), 0.0);
1343 assert_eq!(quant_bpp("GPTQ-Int4"), 0.5);
1345 assert_eq!(quant_bpp("GPTQ-Int8"), 1.0);
1346 assert_eq!(quant_speed_multiplier("GPTQ-Int4"), 1.2);
1347 assert_eq!(quant_speed_multiplier("GPTQ-Int8"), 0.85);
1348 assert_eq!(quant_quality_penalty("GPTQ-Int4"), -3.0);
1349 assert_eq!(quant_quality_penalty("GPTQ-Int8"), 0.0);
1350 }
1351
1352 #[test]
1353 fn test_model_format_prequantized() {
1354 assert!(ModelFormat::Awq.is_prequantized());
1355 assert!(ModelFormat::Gptq.is_prequantized());
1356 assert!(!ModelFormat::Gguf.is_prequantized());
1357 assert!(!ModelFormat::Mlx.is_prequantized());
1358 assert!(!ModelFormat::Safetensors.is_prequantized());
1359 }
1360
1361 #[test]
1366 fn test_gguf_source_deserialization() {
1367 let json = r#"{"repo": "unsloth/Llama-3.1-8B-Instruct-GGUF", "provider": "unsloth"}"#;
1368 let source: GgufSource = serde_json::from_str(json).unwrap();
1369 assert_eq!(source.repo, "unsloth/Llama-3.1-8B-Instruct-GGUF");
1370 assert_eq!(source.provider, "unsloth");
1371 }
1372
1373 #[test]
1374 fn test_gguf_sources_default_to_empty() {
1375 let json = r#"{
1376 "name": "test/model",
1377 "provider": "Test",
1378 "parameter_count": "7B",
1379 "parameters_raw": 7000000000,
1380 "min_ram_gb": 4.0,
1381 "recommended_ram_gb": 8.0,
1382 "quantization": "Q4_K_M",
1383 "context_length": 4096,
1384 "use_case": "General"
1385 }"#;
1386 let entry: HfModelEntry = serde_json::from_str(json).unwrap();
1387 assert!(entry.gguf_sources.is_empty());
1388 }
1389
1390 #[test]
1391 fn test_catalog_popular_models_have_gguf_sources() {
1392 let db = ModelDatabase::new();
1393 let expected_with_gguf = [
1395 "meta-llama/Llama-3.3-70B-Instruct",
1396 "Qwen/Qwen2.5-7B-Instruct",
1397 "Qwen/Qwen2.5-Coder-7B-Instruct",
1398 "meta-llama/Meta-Llama-3-8B-Instruct",
1399 "mistralai/Mistral-7B-Instruct-v0.3",
1400 ];
1401 for name in &expected_with_gguf {
1402 let model = db.get_all_models().iter().find(|m| m.name == *name);
1403 assert!(model.is_some(), "Model {} should exist in catalog", name);
1404 let model = model.unwrap();
1405 assert!(
1406 !model.gguf_sources.is_empty(),
1407 "Model {} should have gguf_sources but has none",
1408 name
1409 );
1410 }
1411 }
1412
1413 #[test]
1414 fn test_catalog_gguf_sources_have_valid_repos() {
1415 let db = ModelDatabase::new();
1416 for model in db.get_all_models() {
1417 for source in &model.gguf_sources {
1418 assert!(
1419 source.repo.contains('/'),
1420 "GGUF source repo '{}' for model '{}' should be owner/repo format",
1421 source.repo,
1422 model.name
1423 );
1424 assert!(
1425 !source.provider.is_empty(),
1426 "GGUF source provider for model '{}' should not be empty",
1427 model.name
1428 );
1429 assert!(
1430 source.repo.to_uppercase().contains("GGUF"),
1431 "GGUF source repo '{}' for model '{}' should contain 'GGUF'",
1432 source.repo,
1433 model.name
1434 );
1435 }
1436 }
1437 }
1438
1439 #[test]
1440 #[ignore] fn test_catalog_has_significant_gguf_coverage() {
1442 let db = ModelDatabase::new();
1443 let total = db.get_all_models().len();
1444 let with_gguf = db
1445 .get_all_models()
1446 .iter()
1447 .filter(|m| !m.gguf_sources.is_empty())
1448 .count();
1449 let coverage_pct = (with_gguf as f64 / total as f64) * 100.0;
1451 assert!(
1452 coverage_pct >= 25.0,
1453 "GGUF source coverage is only {:.1}% ({}/{}), expected at least 25%",
1454 coverage_pct,
1455 with_gguf,
1456 total
1457 );
1458 }
1459
1460 fn tp_test_model(
1465 name: &str,
1466 params_b: f64,
1467 attn_heads: Option<u32>,
1468 kv_heads: Option<u32>,
1469 ) -> LlmModel {
1470 LlmModel {
1471 name: name.to_string(),
1472 provider: "Test".to_string(),
1473 parameter_count: format!("{:.0}B", params_b),
1474 parameters_raw: Some((params_b * 1_000_000_000.0) as u64),
1475 min_ram_gb: 4.0,
1476 recommended_ram_gb: 8.0,
1477 min_vram_gb: Some(4.0),
1478 quantization: "Q4_K_M".to_string(),
1479 context_length: 4096,
1480 use_case: "General".to_string(),
1481 is_moe: false,
1482 num_experts: None,
1483 active_experts: None,
1484 active_parameters: None,
1485 release_date: None,
1486 gguf_sources: vec![],
1487 capabilities: vec![],
1488 format: ModelFormat::default(),
1489 num_attention_heads: attn_heads,
1490 num_key_value_heads: kv_heads,
1491 license: None,
1492 }
1493 }
1494
1495 #[test]
1496 fn test_supports_tp_with_explicit_heads() {
1497 let model = tp_test_model("Test-8B", 8.0, Some(32), Some(8));
1498 assert!(model.supports_tp(1));
1499 assert!(model.supports_tp(2));
1500 assert!(model.supports_tp(4));
1501 assert!(model.supports_tp(8));
1502 assert!(!model.supports_tp(3)); assert!(!model.supports_tp(5));
1504 }
1505
1506 #[test]
1507 fn test_supports_tp_always_true_for_1() {
1508 let model = tp_test_model("Tiny", 1.0, None, None);
1509 assert!(model.supports_tp(1));
1510 }
1511
1512 #[test]
1513 fn test_valid_tp_sizes_32_8() {
1514 let model = tp_test_model("Test", 8.0, Some(32), Some(8));
1515 let sizes = model.valid_tp_sizes();
1516 assert!(sizes.contains(&1));
1517 assert!(sizes.contains(&2));
1518 assert!(sizes.contains(&4));
1519 assert!(sizes.contains(&8));
1520 assert!(!sizes.contains(&3));
1521 }
1522
1523 #[test]
1524 fn test_valid_tp_sizes_48_heads() {
1525 let model = tp_test_model("Llama-32B", 32.0, Some(48), Some(8));
1527 assert!(model.supports_tp(2)); assert!(!model.supports_tp(3)); assert!(model.supports_tp(4)); assert!(model.supports_tp(8)); }
1532
1533 #[test]
1534 fn test_infer_heads_from_name_qwen() {
1535 let (attn, kv) = infer_heads_from_name("Qwen2.5-72B-Instruct", 72.0);
1536 assert_eq!(attn, 64);
1537 assert_eq!(kv, 8);
1538 }
1539
1540 #[test]
1541 fn test_infer_heads_from_name_llama() {
1542 let (attn, kv) = infer_heads_from_name("Llama-3.1-8B", 8.0);
1543 assert_eq!(attn, 32);
1544 assert_eq!(kv, 8);
1545 }
1546
1547 #[test]
1548 fn test_infer_heads_from_name_deepseek() {
1549 let (attn, kv) = infer_heads_from_name("DeepSeek-V3", 671.0);
1550 assert_eq!(attn, 128);
1551 assert_eq!(kv, 16);
1552 }
1553
1554 #[test]
1555 fn test_supports_tp_with_inferred_heads() {
1556 let model = tp_test_model("Llama-3.1-70B", 70.0, None, None);
1558 assert!(model.supports_tp(2));
1559 assert!(model.supports_tp(4));
1560 assert!(model.supports_tp(8));
1561 }
1562}