// llmfit_core — models.rs
1use serde::{Deserialize, Serialize};
2
/// Quantization levels ordered from best quality to most compressed.
/// Used for dynamic quantization selection: try the best that fits.
/// Order matters — `best_quant_for_budget` walks this list front to back.
pub const QUANT_HIERARCHY: &[&str] = &["Q8_0", "Q6_K", "Q5_K_M", "Q4_K_M", "Q3_K_M", "Q2_K"];
6
/// MLX-native quantization hierarchy (best quality to most compressed).
/// The Apple-Silicon counterpart of [`QUANT_HIERARCHY`]; same front-to-back order.
pub const MLX_QUANT_HIERARCHY: &[&str] = &["mlx-8bit", "mlx-4bit"];
9
/// Bytes per parameter for each quantization level.
///
/// Known formats are resolved via a lookup table; anything unrecognized
/// falls back to the Q4_K_M estimate (0.58).
pub fn quant_bpp(quant: &str) -> f64 {
    const BPP: &[(&str, f64)] = &[
        ("F32", 4.0),
        ("F16", 2.0),
        ("BF16", 2.0),
        ("Q8_0", 1.05),
        ("Q6_K", 0.80),
        ("Q5_K_M", 0.68),
        ("Q4_K_M", 0.58),
        ("Q4_0", 0.58),
        ("Q3_K_M", 0.48),
        ("Q2_K", 0.37),
        ("mlx-4bit", 0.55),
        ("mlx-8bit", 1.0),
        ("AWQ-4bit", 0.5),
        ("AWQ-8bit", 1.0),
        ("GPTQ-Int4", 0.5),
        ("GPTQ-Int8", 1.0),
    ];
    BPP.iter()
        .find(|&&(name, _)| name == quant)
        .map(|&(_, bpp)| bpp)
        .unwrap_or(0.58)
}
30
/// Speed multiplier for quantization (lower quant = faster inference).
///
/// Arms are sorted by value and grouped where formats share a multiplier;
/// Q5_K_M sits at the 1.0 baseline together with unknown formats.
pub fn quant_speed_multiplier(quant: &str) -> f64 {
    match quant {
        "F16" | "BF16" => 0.6,
        "Q8_0" => 0.8,
        "mlx-8bit" | "AWQ-8bit" | "GPTQ-Int8" => 0.85,
        "Q6_K" => 0.95,
        "Q4_K_M" | "Q4_0" | "mlx-4bit" => 1.15,
        "AWQ-4bit" | "GPTQ-Int4" => 1.2,
        "Q3_K_M" => 1.25,
        "Q2_K" => 1.35,
        // Q5_K_M and unrecognized formats: neutral baseline.
        _ => 1.0,
    }
}
48
/// Bytes per parameter for a given quantization format.
/// Used by the bandwidth-based tok/s estimator to compute model size in GB.
///
/// Unlike [`quant_bpp`], these are the nominal (overhead-free) widths.
pub fn quant_bytes_per_param(quant: &str) -> f64 {
    match quant {
        "F16" | "BF16" => 2.0,
        "Q8_0" | "mlx-8bit" | "AWQ-8bit" | "GPTQ-Int8" => 1.0,
        "Q6_K" => 0.75,
        "Q5_K_M" => 0.625,
        "Q3_K_M" => 0.375,
        "Q2_K" => 0.25,
        // All 4-bit variants (Q4_K_M, Q4_0, mlx/AWQ/GPTQ 4-bit) and any
        // unknown format default to ~4-bit.
        _ => 0.5,
    }
}
67
/// Quality penalty for quantization (lower quant = lower quality).
///
/// 8-bit and half-precision formats are considered lossless (0.0);
/// penalties grow as bit width shrinks.
pub fn quant_quality_penalty(quant: &str) -> f64 {
    match quant {
        "F16" | "BF16" | "Q8_0" | "mlx-8bit" | "AWQ-8bit" | "GPTQ-Int8" => 0.0,
        "Q6_K" => -1.0,
        "Q5_K_M" => -2.0,
        "AWQ-4bit" | "GPTQ-Int4" => -3.0,
        "mlx-4bit" => -4.0,
        "Q3_K_M" => -8.0,
        "Q2_K" => -12.0,
        // Q4_K_M / Q4_0 and unknown formats share the mid-tier penalty.
        _ => -5.0,
    }
}
87
/// Model capability flags (orthogonal to UseCase).
///
/// Serialized in snake_case (`vision`, `tool_use`) to match the scraper JSON.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Capability {
    /// Multimodal/image-input support (detected from name/use-case keywords).
    Vision,
    /// Tool / function-calling support (detected from known model families).
    ToolUse,
}
95
96impl Capability {
97    pub fn label(&self) -> &'static str {
98        match self {
99            Capability::Vision => "Vision",
100            Capability::ToolUse => "Tool Use",
101        }
102    }
103
104    pub fn all() -> &'static [Capability] {
105        &[Capability::Vision, Capability::ToolUse]
106    }
107
108    /// Infer capabilities from model metadata when not explicitly set in JSON.
109    pub fn infer(model: &LlmModel) -> Vec<Capability> {
110        let mut caps = model.capabilities.clone();
111        let name = model.name.to_lowercase();
112        let use_case = model.use_case.to_lowercase();
113
114        // Vision detection
115        if !caps.contains(&Capability::Vision)
116            && (name.contains("vision")
117                || name.contains("-vl-")
118                || name.ends_with("-vl")
119                || name.contains("llava")
120                || name.contains("onevision")
121                || name.contains("pixtral")
122                || use_case.contains("vision")
123                || use_case.contains("multimodal"))
124        {
125            caps.push(Capability::Vision);
126        }
127
128        // Tool use detection (known model families)
129        if !caps.contains(&Capability::ToolUse)
130            && (use_case.contains("tool")
131                || use_case.contains("function call")
132                || name.contains("qwen3")
133                || name.contains("qwen2.5")
134                || name.contains("command-r")
135                || (name.contains("llama-3") && name.contains("instruct"))
136                || (name.contains("mistral") && name.contains("instruct"))
137                || name.contains("hermes"))
138        {
139            caps.push(Capability::ToolUse);
140        }
141
142        caps
143    }
144}
145
/// Model weight format — determines which inference runtime to use.
///
/// Serialized lowercase (`gguf`, `awq`, …); `Gguf` is the default when the
/// JSON entry omits the field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum ModelFormat {
    #[default]
    Gguf,
    Awq,
    Gptq,
    Mlx,
    Safetensors,
}
158
impl ModelFormat {
    /// Returns true for formats that are pre-quantized at a fixed bit width
    /// and cannot be dynamically re-quantized (AWQ, GPTQ).
    ///
    /// GGUF/MLX/Safetensors return false — those can be (re)quantized via
    /// the hierarchies above.
    pub fn is_prequantized(&self) -> bool {
        matches!(self, ModelFormat::Awq | ModelFormat::Gptq)
    }
}
166
167/// Use-case category for scoring weights.
168#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
169pub enum UseCase {
170    General,
171    Coding,
172    Reasoning,
173    Chat,
174    Multimodal,
175    Embedding,
176}
177
178impl UseCase {
179    pub fn label(&self) -> &'static str {
180        match self {
181            UseCase::General => "General",
182            UseCase::Coding => "Coding",
183            UseCase::Reasoning => "Reasoning",
184            UseCase::Chat => "Chat",
185            UseCase::Multimodal => "Multimodal",
186            UseCase::Embedding => "Embedding",
187        }
188    }
189
190    /// Infer use-case from the model's use_case field and name.
191    pub fn from_model(model: &LlmModel) -> Self {
192        let name = model.name.to_lowercase();
193        let use_case = model.use_case.to_lowercase();
194
195        if use_case.contains("embedding") || name.contains("embed") || name.contains("bge") {
196            UseCase::Embedding
197        } else if name.contains("code") || use_case.contains("code") {
198            UseCase::Coding
199        } else if use_case.contains("vision") || use_case.contains("multimodal") {
200            UseCase::Multimodal
201        } else if use_case.contains("reason")
202            || use_case.contains("chain-of-thought")
203            || name.contains("deepseek-r1")
204        {
205            UseCase::Reasoning
206        } else if use_case.contains("chat") || use_case.contains("instruction") {
207            UseCase::Chat
208        } else {
209            UseCase::General
210        }
211    }
212}
213
/// A single model entry in the database: identity, resource requirements,
/// default quantization, and optional architecture metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmModel {
    /// Model name as published (may include an `org/` prefix).
    pub name: String,
    /// Publishing organization / provider string.
    pub provider: String,
    /// Human-readable size, e.g. "7B", "1.1B", "500M" (parsed by `params_b`).
    pub parameter_count: String,
    /// Exact parameter count when known; takes precedence over `parameter_count`.
    #[serde(default)]
    pub parameters_raw: Option<u64>,
    /// Minimum system RAM (GB) to run the model at all.
    pub min_ram_gb: f64,
    /// Recommended RAM (GB); also the CPU-fallback threshold in
    /// `ModelDatabase::models_fitting_system`.
    pub recommended_ram_gb: f64,
    /// Minimum VRAM (GB) when a GPU is used; `None` skips the GPU check.
    pub min_vram_gb: Option<f64>,
    /// Default quantization label (e.g. "Q4_K_M", "mlx-4bit").
    pub quantization: String,
    /// Maximum context window, in tokens.
    pub context_length: u32,
    /// Free-form use-case text; mined by `UseCase::from_model` and
    /// `Capability::infer` for keywords.
    pub use_case: String,
    /// True for Mixture-of-Experts architectures.
    #[serde(default)]
    pub is_moe: bool,
    /// Total expert count (MoE only).
    #[serde(default)]
    pub num_experts: Option<u32>,
    /// Experts active per token (MoE only).
    #[serde(default)]
    pub active_experts: Option<u32>,
    /// Active parameter count per token (MoE only); drives VRAM estimates.
    #[serde(default)]
    pub active_parameters: Option<u64>,
    /// Release date string, if known.
    #[serde(default)]
    pub release_date: Option<String>,
    /// Known GGUF download sources (e.g. unsloth, bartowski repos on HuggingFace)
    #[serde(default)]
    pub gguf_sources: Vec<GgufSource>,
    /// Model capabilities (vision, tool use, etc.)
    #[serde(default)]
    pub capabilities: Vec<Capability>,
    /// Model weight format (gguf, awq, gptq, mlx, safetensors)
    #[serde(default)]
    pub format: ModelFormat,
    /// Number of attention heads (for tensor-parallelism compatibility checks).
    #[serde(default)]
    pub num_attention_heads: Option<u32>,
    /// Number of key-value heads for GQA (defaults to num_attention_heads if None).
    #[serde(default)]
    pub num_key_value_heads: Option<u32>,
    /// Model license (e.g. "apache-2.0", "mit", "llama3.1")
    #[serde(default)]
    pub license: Option<String>,
}
256
/// Returns true if a model's license matches any entry in the comma-separated
/// `filter` string. Comparison is case-insensitive and whitespace around
/// commas is ignored. Models without a license never match.
///
/// Compares lazily with `Iterator::any` instead of collecting every
/// candidate into a `Vec` first — same behavior, short-circuits on the
/// first match and avoids the intermediate allocation.
pub fn matches_license_filter(license: &Option<String>, filter: &str) -> bool {
    match license.as_deref() {
        None => false,
        Some(license) => {
            let license = license.to_lowercase();
            filter
                .split(',')
                .any(|candidate| candidate.trim().to_lowercase() == license)
        }
    }
}
266
/// A known GGUF download source for a model on HuggingFace.
///
/// Both fields are free-form strings taken from the scraper JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GgufSource {
    /// HuggingFace repo ID (e.g. "unsloth/Llama-3.1-8B-Instruct-GGUF")
    pub repo: String,
    /// Provider who published the GGUF (e.g. "unsloth", "bartowski")
    pub provider: String,
}
275
impl LlmModel {
    /// MLX models are Apple-only — they won't run on NVIDIA/AMD/Intel hardware.
    /// We detect them by the `-MLX-` suffix that's standard on HuggingFace
    /// (e.g. `Qwen3-8B-MLX-4bit`, `LFM2-1.2B-MLX-8bit`).
    ///
    /// NOTE(review): overlaps with `is_mlx_only` below, which matches `-MLX`
    /// anywhere in the name (slightly broader) — consider consolidating.
    pub fn is_mlx_model(&self) -> bool {
        let name_lower = self.name.to_lowercase();
        name_lower.contains("-mlx-") || name_lower.ends_with("-mlx")
    }

    /// Returns true if this model uses a pre-quantized format (AWQ/GPTQ)
    /// that cannot be dynamically re-quantized.
    pub fn is_prequantized(&self) -> bool {
        self.format.is_prequantized()
    }

    /// Returns true if the model's attention/KV heads are evenly divisible
    /// by `tp_size`, meaning it can be split across that many devices.
    /// TP=1 always returns true.
    pub fn supports_tp(&self, tp_size: u32) -> bool {
        if tp_size <= 1 {
            return true;
        }
        let (attn, kv) = self.infer_head_counts();
        attn % tp_size == 0 && kv % tp_size == 0
    }

    /// Returns all valid TP degrees in [1..=8] for this model.
    pub fn valid_tp_sizes(&self) -> Vec<u32> {
        (1..=8).filter(|&tp| self.supports_tp(tp)).collect()
    }

    /// Infer attention and KV head counts from metadata or model name heuristics.
    ///
    /// Precedence: explicit (attn, kv) metadata → attn only (kv assumed equal,
    /// i.e. plain MHA) → name/size heuristics via `infer_heads_from_name`.
    fn infer_head_counts(&self) -> (u32, u32) {
        if let (Some(attn), Some(kv)) = (self.num_attention_heads, self.num_key_value_heads) {
            return (attn, kv);
        }
        if let Some(attn) = self.num_attention_heads {
            return (attn, attn);
        }
        // Heuristic: infer from model name
        infer_heads_from_name(&self.name, self.params_b())
    }

    /// Bytes-per-parameter for the model's quantization level.
    fn quant_bpp(&self) -> f64 {
        quant_bpp(&self.quantization)
    }

    /// Parameter count in billions, extracted from parameters_raw or parameter_count.
    ///
    /// Fallbacks: an unparseable "…B" string or a string with no recognized
    /// suffix yields 7.0 (treated as a typical mid-size model); an
    /// unparseable "…M" string yields 0.0.
    pub fn params_b(&self) -> f64 {
        if let Some(raw) = self.parameters_raw {
            raw as f64 / 1_000_000_000.0
        } else {
            // Parse from string like "7B", "1.1B", "137M"
            let s = self.parameter_count.trim().to_uppercase();
            if let Some(num_str) = s.strip_suffix('B') {
                num_str.parse::<f64>().unwrap_or(7.0)
            } else if let Some(num_str) = s.strip_suffix('M') {
                // Millions → fractional billions.
                num_str.parse::<f64>().unwrap_or(0.0) / 1000.0
            } else {
                7.0
            }
        }
    }

    /// Estimate memory required (GB) at a given quantization and context length.
    /// Formula: model_weights + KV_cache + runtime_overhead
    ///
    /// Uses decimal GB: params (in billions) × bytes-per-param ≈ GB of weights.
    pub fn estimate_memory_gb(&self, quant: &str, ctx: u32) -> f64 {
        let bpp = quant_bpp(quant);
        let params = self.params_b();
        let model_mem = params * bpp;
        // KV cache: ~0.000008 GB per billion params per context token
        let kv_cache = 0.000008 * params * ctx as f64;
        // Runtime overhead (CUDA/Metal context, buffers)
        let overhead = 0.5;
        model_mem + kv_cache + overhead
    }

    /// Select the best quantization level that fits within a memory budget.
    /// Returns the quant name and estimated memory in GB, or None if nothing fits.
    /// Walks `QUANT_HIERARCHY` best-quality-first.
    pub fn best_quant_for_budget(&self, budget_gb: f64, ctx: u32) -> Option<(&'static str, f64)> {
        self.best_quant_for_budget_with(budget_gb, ctx, QUANT_HIERARCHY)
    }

    /// Select the best quantization from a custom hierarchy that fits within a memory budget.
    ///
    /// If nothing fits at the requested context, retries once with the
    /// context halved (but never below 1024 tokens) before giving up.
    pub fn best_quant_for_budget_with(
        &self,
        budget_gb: f64,
        ctx: u32,
        hierarchy: &[&'static str],
    ) -> Option<(&'static str, f64)> {
        // Try best quality first
        for &q in hierarchy {
            let mem = self.estimate_memory_gb(q, ctx);
            if mem <= budget_gb {
                return Some((q, mem));
            }
        }
        // Try halving context once
        let half_ctx = ctx / 2;
        if half_ctx >= 1024 {
            for &q in hierarchy {
                let mem = self.estimate_memory_gb(q, half_ctx);
                if mem <= budget_gb {
                    return Some((q, mem));
                }
            }
        }
        None
    }

    /// For MoE models, compute estimated VRAM for active experts only.
    /// Returns None for dense models (or MoE entries missing `active_parameters`).
    ///
    /// NOTE(review): divides by 1024³ (GiB) while `estimate_memory_gb` uses
    /// decimal GB — the two "GB" figures differ by ~7%; confirm intent.
    pub fn moe_active_vram_gb(&self) -> Option<f64> {
        if !self.is_moe {
            return None;
        }
        let active_params = self.active_parameters? as f64;
        let bpp = self.quant_bpp();
        let size_gb = (active_params * bpp) / (1024.0 * 1024.0 * 1024.0);
        // 10% activation headroom, with a 0.5 GB floor.
        Some((size_gb * 1.1).max(0.5))
    }

    /// Returns true if this model is MLX-specific (Apple Silicon only).
    /// MLX models are identified by having "-MLX" in their name.
    ///
    /// NOTE(review): broader than `is_mlx_model` above (matches "-MLX"
    /// anywhere, not just as a `-MLX-`/trailing marker) — consider unifying.
    pub fn is_mlx_only(&self) -> bool {
        self.name.to_uppercase().contains("-MLX")
    }

    /// For MoE models, compute RAM needed for offloaded (inactive) experts.
    /// Returns None for dense models (or MoE entries missing parameter counts).
    /// Uses the same GiB divisor as `moe_active_vram_gb`.
    pub fn moe_offloaded_ram_gb(&self) -> Option<f64> {
        if !self.is_moe {
            return None;
        }
        let active = self.active_parameters? as f64;
        let total = self.parameters_raw? as f64;
        let inactive = total - active;
        if inactive <= 0.0 {
            return Some(0.0);
        }
        let bpp = self.quant_bpp();
        Some((inactive * bpp) / (1024.0 * 1024.0 * 1024.0))
    }
}
421
/// Intermediate struct matching the JSON schema from the scraper.
/// Extra fields are ignored when mapping to LlmModel.
#[derive(Debug, Clone, Deserialize)]
struct HfModelEntry {
    name: String,
    provider: String,
    parameter_count: String,
    #[serde(default)]
    parameters_raw: Option<u64>,
    min_ram_gb: f64,
    recommended_ram_gb: f64,
    min_vram_gb: Option<f64>,
    quantization: String,
    context_length: u32,
    use_case: String,
    #[serde(default)]
    is_moe: bool,
    #[serde(default)]
    num_experts: Option<u32>,
    #[serde(default)]
    active_experts: Option<u32>,
    #[serde(default)]
    active_parameters: Option<u64>,
    #[serde(default)]
    release_date: Option<String>,
    #[serde(default)]
    gguf_sources: Vec<GgufSource>,
    #[serde(default)]
    capabilities: Vec<Capability>,
    #[serde(default)]
    format: ModelFormat,
    // HF popularity metrics; deserialized but not copied into LlmModel by
    // `load_embedded` — presumably consumed elsewhere (NOTE(review): confirm,
    // otherwise these are dead fields).
    #[serde(default)]
    hf_downloads: u64,
    #[serde(default)]
    hf_likes: u64,
    #[serde(default)]
    license: Option<String>,
}
460
/// Scraper-generated model catalog, embedded into the binary at compile time.
const HF_MODELS_JSON: &str = include_str!("../data/hf_models.json");
462
/// In-memory model catalog: the embedded list, optionally merged with
/// locally cached updates (see [`ModelDatabase::new`]).
pub struct ModelDatabase {
    // Embedded entries first, then cached entries (deduplicated by slug).
    models: Vec<LlmModel>,
}
466
impl Default for ModelDatabase {
    /// Same as [`ModelDatabase::new`]: embedded models merged with the local cache.
    fn default() -> Self {
        Self::new()
    }
}
472
/// Normalize a model name/ID to a canonical slug for deduplication.
///
/// Strips the `org/` prefix, lowercases, and drops `-`/`_`/`.` so that
/// `meta-llama/Llama-3.1-8B` and `meta-llama/llama-3.1-8b` compare equal.
pub(crate) fn canonical_slug(name: &str) -> String {
    // rsplit always yields at least one item, so the fallback is unreachable
    // in practice — kept for safety.
    let base = name.rsplit('/').next().unwrap_or(name);
    base.chars()
        .filter(|c| !matches!(c, '-' | '_' | '.'))
        .flat_map(char::to_lowercase)
        .collect()
}
481
482/// Parse the compile-time embedded JSON into a flat `Vec<LlmModel>`.
483fn load_embedded() -> Vec<LlmModel> {
484    let entries: Vec<HfModelEntry> =
485        serde_json::from_str(HF_MODELS_JSON).expect("Failed to parse embedded hf_models.json");
486    entries
487        .into_iter()
488        .map(|e| {
489            let mut model = LlmModel {
490                name: e.name,
491                provider: e.provider,
492                parameter_count: e.parameter_count,
493                parameters_raw: e.parameters_raw,
494                min_ram_gb: e.min_ram_gb,
495                recommended_ram_gb: e.recommended_ram_gb,
496                min_vram_gb: e.min_vram_gb,
497                quantization: e.quantization,
498                context_length: e.context_length,
499                use_case: e.use_case,
500                is_moe: e.is_moe,
501                num_experts: e.num_experts,
502                active_experts: e.active_experts,
503                active_parameters: e.active_parameters,
504                release_date: e.release_date,
505                gguf_sources: e.gguf_sources,
506                capabilities: e.capabilities,
507                format: e.format,
508                num_attention_heads: None,
509                num_key_value_heads: None,
510                license: e.license,
511            };
512            model.capabilities = Capability::infer(&model);
513            model
514        })
515        .collect()
516}
517
518impl ModelDatabase {
519    /// Load only the compile-time embedded model list (no cache).
520    /// Used internally by the updater to determine which models are already known.
521    pub fn embedded() -> Self {
522        ModelDatabase {
523            models: load_embedded(),
524        }
525    }
526
527    /// Load the embedded model list **and** merge any locally cached models.
528    ///
529    /// Cached models are appended after the embedded ones; if an ID already
530    /// exists in the embedded list it is skipped to avoid duplication.
531    /// Silently ignores a missing or corrupt cache file.
532    pub fn new() -> Self {
533        let mut models = load_embedded();
534
535        // Merge cached models (from `llmfit update`) without duplicating.
536        // canonical_slug normalizes org/ prefix, case, and separators so that
537        // e.g. `meta-llama/Llama-3.1-8B` and `meta-llama/llama-3.1-8b` are
538        // treated as the same model.
539        let embedded_keys: std::collections::HashSet<String> =
540            models.iter().map(|m| canonical_slug(&m.name)).collect();
541
542        for cached in crate::update::load_cache() {
543            if !embedded_keys.contains(&canonical_slug(&cached.name)) {
544                models.push(cached);
545            }
546        }
547
548        ModelDatabase { models }
549    }
550
551    pub fn get_all_models(&self) -> &Vec<LlmModel> {
552        &self.models
553    }
554
555    pub fn find_model(&self, query: &str) -> Vec<&LlmModel> {
556        let query_lower = query.to_lowercase();
557        self.models
558            .iter()
559            .filter(|m| {
560                m.name.to_lowercase().contains(&query_lower)
561                    || m.provider.to_lowercase().contains(&query_lower)
562                    || m.parameter_count.to_lowercase().contains(&query_lower)
563            })
564            .collect()
565    }
566
567    pub fn models_fitting_system(
568        &self,
569        available_ram_gb: f64,
570        has_gpu: bool,
571        vram_gb: Option<f64>,
572    ) -> Vec<&LlmModel> {
573        self.models
574            .iter()
575            .filter(|m| {
576                // Check RAM requirement
577                let ram_ok = m.min_ram_gb <= available_ram_gb;
578
579                // If model requires GPU and system has GPU, check VRAM
580                if let Some(min_vram) = m.min_vram_gb {
581                    if has_gpu {
582                        if let Some(system_vram) = vram_gb {
583                            ram_ok && min_vram <= system_vram
584                        } else {
585                            // GPU detected but VRAM unknown, allow but warn
586                            ram_ok
587                        }
588                    } else {
589                        // Model prefers GPU but can run on CPU with enough RAM
590                        ram_ok && available_ram_gb >= m.recommended_ram_gb
591                    }
592                } else {
593                    ram_ok
594                }
595            })
596            .collect()
597    }
598}
599
/// Infer attention and KV head counts from the model name and parameter count.
/// Used as a fallback when explicit head counts are not available in the model metadata.
///
/// Values are family-typical (attention, kv) head pairs keyed by
/// parameter-count tiers; adjacent tiers with identical values are merged.
fn infer_heads_from_name(name: &str, params_b: f64) -> (u32, u32) {
    let n = name.to_lowercase();

    // Qwen family
    if n.contains("qwen") {
        return match params_b {
            p if p > 100.0 => (128, 16),
            p if p > 50.0 => (64, 8),
            p if p > 10.0 => (40, 8),
            p if p > 5.0 => (32, 8),
            _ => (16, 4),
        };
    }

    // Llama family (Scout/Maverick variants share the 70B-class layout)
    if n.contains("llama") {
        if n.contains("scout") || n.contains("maverick") || params_b > 60.0 {
            return (64, 8);
        }
        return match params_b {
            p if p > 20.0 => (48, 8),
            p if p > 5.0 => (32, 8),
            _ => (16, 8),
        };
    }

    // DeepSeek family
    if n.contains("deepseek") {
        return match params_b {
            p if p > 200.0 => (128, 16),
            p if p > 50.0 => (64, 8),
            p if p > 10.0 => (40, 8),
            _ => (32, 8),
        };
    }

    // Mistral/Mixtral
    if n.contains("mistral") || n.contains("mixtral") {
        return if params_b > 100.0 { (96, 8) } else { (32, 8) };
    }

    // Gemma
    if n.contains("gemma") {
        return match params_b {
            p if p > 20.0 => (32, 16),
            p if p > 5.0 => (16, 8),
            _ => (8, 4),
        };
    }

    // Phi
    if n.contains("phi") {
        return if params_b > 10.0 { (40, 10) } else { (32, 8) };
    }

    // MiniMax
    if n.contains("minimax") {
        return (48, 8);
    }

    // Default: common pattern based on param count
    match params_b {
        p if p > 100.0 => (128, 16),
        p if p > 50.0 => (64, 8),
        p if p > 5.0 => (32, 8),
        _ => (16, 4),
    }
}
701
702#[cfg(test)]
703mod tests {
704    use super::*;
705
706    // ────────────────────────────────────────────────────────────────────
707    // Quantization function tests
708    // ────────────────────────────────────────────────────────────────────
709
710    #[test]
711    fn test_mlx_quant_bpp_values() {
712        assert_eq!(quant_bpp("mlx-4bit"), 0.55);
713        assert_eq!(quant_bpp("mlx-8bit"), 1.0);
714        assert_eq!(quant_speed_multiplier("mlx-4bit"), 1.15);
715        assert_eq!(quant_speed_multiplier("mlx-8bit"), 0.85);
716        assert_eq!(quant_quality_penalty("mlx-4bit"), -4.0);
717        assert_eq!(quant_quality_penalty("mlx-8bit"), 0.0);
718    }
719
720    #[test]
721    fn test_best_quant_with_mlx_hierarchy() {
722        let model = LlmModel {
723            name: "Test Model".to_string(),
724            provider: "Test".to_string(),
725            parameter_count: "7B".to_string(),
726            parameters_raw: Some(7_000_000_000),
727            min_ram_gb: 4.0,
728            recommended_ram_gb: 8.0,
729            min_vram_gb: Some(4.0),
730            quantization: "Q4_K_M".to_string(),
731            context_length: 4096,
732            use_case: "General".to_string(),
733            is_moe: false,
734            num_experts: None,
735            active_experts: None,
736            active_parameters: None,
737            release_date: None,
738            gguf_sources: vec![],
739            capabilities: vec![],
740            format: ModelFormat::default(),
741            num_attention_heads: None,
742            num_key_value_heads: None,
743            license: None,
744        };
745
746        // Large budget should return mlx-8bit (best in MLX hierarchy)
747        let result = model.best_quant_for_budget_with(10.0, 4096, MLX_QUANT_HIERARCHY);
748        assert!(result.is_some());
749        let (quant, _) = result.unwrap();
750        assert_eq!(quant, "mlx-8bit");
751
752        // Tighter budget should fall to mlx-4bit
753        let result = model.best_quant_for_budget_with(5.0, 4096, MLX_QUANT_HIERARCHY);
754        assert!(result.is_some());
755        let (quant, _) = result.unwrap();
756        assert_eq!(quant, "mlx-4bit");
757    }
758
759    #[test]
760    fn test_quant_bpp() {
761        assert_eq!(quant_bpp("F32"), 4.0);
762        assert_eq!(quant_bpp("F16"), 2.0);
763        assert_eq!(quant_bpp("Q8_0"), 1.05);
764        assert_eq!(quant_bpp("Q4_K_M"), 0.58);
765        assert_eq!(quant_bpp("Q2_K"), 0.37);
766        // Unknown quant defaults to Q4_K_M
767        assert_eq!(quant_bpp("UNKNOWN"), 0.58);
768    }
769
770    #[test]
771    fn test_quant_speed_multiplier() {
772        assert_eq!(quant_speed_multiplier("F16"), 0.6);
773        assert_eq!(quant_speed_multiplier("Q5_K_M"), 1.0);
774        assert_eq!(quant_speed_multiplier("Q4_K_M"), 1.15);
775        assert_eq!(quant_speed_multiplier("Q2_K"), 1.35);
776        // Lower quant = faster inference
777        assert!(quant_speed_multiplier("Q2_K") > quant_speed_multiplier("Q8_0"));
778    }
779
780    #[test]
781    fn test_quant_quality_penalty() {
782        assert_eq!(quant_quality_penalty("F16"), 0.0);
783        assert_eq!(quant_quality_penalty("Q8_0"), 0.0);
784        assert_eq!(quant_quality_penalty("Q4_K_M"), -5.0);
785        assert_eq!(quant_quality_penalty("Q2_K"), -12.0);
786        // Lower quant = higher quality penalty
787        assert!(quant_quality_penalty("Q2_K") < quant_quality_penalty("Q8_0"));
788    }
789
790    // ────────────────────────────────────────────────────────────────────
791    // LlmModel tests
792    // ────────────────────────────────────────────────────────────────────
793
794    #[test]
795    fn test_params_b_from_raw() {
796        let model = LlmModel {
797            name: "Test Model".to_string(),
798            provider: "Test".to_string(),
799            parameter_count: "7B".to_string(),
800            parameters_raw: Some(7_000_000_000),
801            min_ram_gb: 4.0,
802            recommended_ram_gb: 8.0,
803            min_vram_gb: Some(4.0),
804            quantization: "Q4_K_M".to_string(),
805            context_length: 4096,
806            use_case: "General".to_string(),
807            is_moe: false,
808            num_experts: None,
809            active_experts: None,
810            active_parameters: None,
811            release_date: None,
812            gguf_sources: vec![],
813            capabilities: vec![],
814            format: ModelFormat::default(),
815            num_attention_heads: None,
816            num_key_value_heads: None,
817            license: None,
818        };
819        assert_eq!(model.params_b(), 7.0);
820    }
821
822    #[test]
823    fn test_params_b_from_string() {
824        let model = LlmModel {
825            name: "Test Model".to_string(),
826            provider: "Test".to_string(),
827            parameter_count: "13B".to_string(),
828            parameters_raw: None,
829            min_ram_gb: 8.0,
830            recommended_ram_gb: 16.0,
831            min_vram_gb: Some(8.0),
832            quantization: "Q4_K_M".to_string(),
833            context_length: 4096,
834            use_case: "General".to_string(),
835            is_moe: false,
836            num_experts: None,
837            active_experts: None,
838            active_parameters: None,
839            release_date: None,
840            gguf_sources: vec![],
841            capabilities: vec![],
842            format: ModelFormat::default(),
843            num_attention_heads: None,
844            num_key_value_heads: None,
845            license: None,
846        };
847        assert_eq!(model.params_b(), 13.0);
848    }
849
850    #[test]
851    fn test_params_b_from_millions() {
852        let model = LlmModel {
853            name: "Test Model".to_string(),
854            provider: "Test".to_string(),
855            parameter_count: "500M".to_string(),
856            parameters_raw: None,
857            min_ram_gb: 1.0,
858            recommended_ram_gb: 2.0,
859            min_vram_gb: Some(1.0),
860            quantization: "Q4_K_M".to_string(),
861            context_length: 2048,
862            use_case: "General".to_string(),
863            is_moe: false,
864            num_experts: None,
865            active_experts: None,
866            active_parameters: None,
867            release_date: None,
868            gguf_sources: vec![],
869            capabilities: vec![],
870            format: ModelFormat::default(),
871            num_attention_heads: None,
872            num_key_value_heads: None,
873            license: None,
874        };
875        assert_eq!(model.params_b(), 0.5);
876    }
877
878    #[test]
879    fn test_estimate_memory_gb() {
880        let model = LlmModel {
881            name: "Test Model".to_string(),
882            provider: "Test".to_string(),
883            parameter_count: "7B".to_string(),
884            parameters_raw: Some(7_000_000_000),
885            min_ram_gb: 4.0,
886            recommended_ram_gb: 8.0,
887            min_vram_gb: Some(4.0),
888            quantization: "Q4_K_M".to_string(),
889            context_length: 4096,
890            use_case: "General".to_string(),
891            is_moe: false,
892            num_experts: None,
893            active_experts: None,
894            active_parameters: None,
895            release_date: None,
896            gguf_sources: vec![],
897            capabilities: vec![],
898            format: ModelFormat::default(),
899            num_attention_heads: None,
900            num_key_value_heads: None,
901            license: None,
902        };
903
904        let mem = model.estimate_memory_gb("Q4_K_M", 4096);
905        // 7B params * 0.58 bytes = 4.06 GB + KV cache + overhead
906        assert!(mem > 4.0);
907        assert!(mem < 6.0);
908
909        // Q8_0 should require more memory
910        let mem_q8 = model.estimate_memory_gb("Q8_0", 4096);
911        assert!(mem_q8 > mem);
912    }
913
914    #[test]
915    fn test_best_quant_for_budget() {
916        let model = LlmModel {
917            name: "Test Model".to_string(),
918            provider: "Test".to_string(),
919            parameter_count: "7B".to_string(),
920            parameters_raw: Some(7_000_000_000),
921            min_ram_gb: 4.0,
922            recommended_ram_gb: 8.0,
923            min_vram_gb: Some(4.0),
924            quantization: "Q4_K_M".to_string(),
925            context_length: 4096,
926            use_case: "General".to_string(),
927            is_moe: false,
928            num_experts: None,
929            active_experts: None,
930            active_parameters: None,
931            release_date: None,
932            gguf_sources: vec![],
933            capabilities: vec![],
934            format: ModelFormat::default(),
935            num_attention_heads: None,
936            num_key_value_heads: None,
937            license: None,
938        };
939
940        // Large budget should return best quant
941        let result = model.best_quant_for_budget(10.0, 4096);
942        assert!(result.is_some());
943        let (quant, _) = result.unwrap();
944        assert_eq!(quant, "Q8_0");
945
946        // Medium budget should find acceptable quant
947        let result = model.best_quant_for_budget(5.0, 4096);
948        assert!(result.is_some());
949
950        // Tiny budget should return None
951        let result = model.best_quant_for_budget(1.0, 4096);
952        assert!(result.is_none());
953    }
954
955    #[test]
956    fn test_moe_active_vram_gb() {
957        // Dense model should return None
958        let dense_model = LlmModel {
959            name: "Dense Model".to_string(),
960            provider: "Test".to_string(),
961            parameter_count: "7B".to_string(),
962            parameters_raw: Some(7_000_000_000),
963            min_ram_gb: 4.0,
964            recommended_ram_gb: 8.0,
965            min_vram_gb: Some(4.0),
966            quantization: "Q4_K_M".to_string(),
967            context_length: 4096,
968            use_case: "General".to_string(),
969            is_moe: false,
970            num_experts: None,
971            active_experts: None,
972            active_parameters: None,
973            release_date: None,
974            gguf_sources: vec![],
975            capabilities: vec![],
976            format: ModelFormat::default(),
977            num_attention_heads: None,
978            num_key_value_heads: None,
979            license: None,
980        };
981        assert!(dense_model.moe_active_vram_gb().is_none());
982
983        // MoE model should calculate active VRAM
984        let moe_model = LlmModel {
985            name: "MoE Model".to_string(),
986            provider: "Test".to_string(),
987            parameter_count: "8x7B".to_string(),
988            parameters_raw: Some(46_700_000_000),
989            min_ram_gb: 25.0,
990            recommended_ram_gb: 50.0,
991            min_vram_gb: Some(25.0),
992            quantization: "Q4_K_M".to_string(),
993            context_length: 32768,
994            use_case: "General".to_string(),
995            is_moe: true,
996            num_experts: Some(8),
997            active_experts: Some(2),
998            active_parameters: Some(12_900_000_000),
999            release_date: None,
1000            gguf_sources: vec![],
1001            capabilities: vec![],
1002            format: ModelFormat::default(),
1003            num_attention_heads: None,
1004            num_key_value_heads: None,
1005            license: None,
1006        };
1007        let vram = moe_model.moe_active_vram_gb();
1008        assert!(vram.is_some());
1009        let vram_val = vram.unwrap();
1010        // Should be significantly less than full model
1011        assert!(vram_val > 0.0);
1012        assert!(vram_val < 15.0);
1013    }
1014
1015    #[test]
1016    fn test_moe_offloaded_ram_gb() {
1017        // Dense model should return None
1018        let dense_model = LlmModel {
1019            name: "Dense Model".to_string(),
1020            provider: "Test".to_string(),
1021            parameter_count: "7B".to_string(),
1022            parameters_raw: Some(7_000_000_000),
1023            min_ram_gb: 4.0,
1024            recommended_ram_gb: 8.0,
1025            min_vram_gb: Some(4.0),
1026            quantization: "Q4_K_M".to_string(),
1027            context_length: 4096,
1028            use_case: "General".to_string(),
1029            is_moe: false,
1030            num_experts: None,
1031            active_experts: None,
1032            active_parameters: None,
1033            release_date: None,
1034            gguf_sources: vec![],
1035            capabilities: vec![],
1036            format: ModelFormat::default(),
1037            num_attention_heads: None,
1038            num_key_value_heads: None,
1039            license: None,
1040        };
1041        assert!(dense_model.moe_offloaded_ram_gb().is_none());
1042
1043        // MoE model should calculate offloaded RAM
1044        let moe_model = LlmModel {
1045            name: "MoE Model".to_string(),
1046            provider: "Test".to_string(),
1047            parameter_count: "8x7B".to_string(),
1048            parameters_raw: Some(46_700_000_000),
1049            min_ram_gb: 25.0,
1050            recommended_ram_gb: 50.0,
1051            min_vram_gb: Some(25.0),
1052            quantization: "Q4_K_M".to_string(),
1053            context_length: 32768,
1054            use_case: "General".to_string(),
1055            is_moe: true,
1056            num_experts: Some(8),
1057            active_experts: Some(2),
1058            active_parameters: Some(12_900_000_000),
1059            release_date: None,
1060            gguf_sources: vec![],
1061            capabilities: vec![],
1062            format: ModelFormat::default(),
1063            num_attention_heads: None,
1064            num_key_value_heads: None,
1065            license: None,
1066        };
1067        let offloaded = moe_model.moe_offloaded_ram_gb();
1068        assert!(offloaded.is_some());
1069        let offloaded_val = offloaded.unwrap();
1070        // Should be substantial
1071        assert!(offloaded_val > 10.0);
1072    }
1073
1074    // ────────────────────────────────────────────────────────────────────
1075    // UseCase tests
1076    // ────────────────────────────────────────────────────────────────────
1077
1078    #[test]
1079    fn test_use_case_from_model_coding() {
1080        let model = LlmModel {
1081            name: "codellama-7b".to_string(),
1082            provider: "Meta".to_string(),
1083            parameter_count: "7B".to_string(),
1084            parameters_raw: Some(7_000_000_000),
1085            min_ram_gb: 4.0,
1086            recommended_ram_gb: 8.0,
1087            min_vram_gb: Some(4.0),
1088            quantization: "Q4_K_M".to_string(),
1089            context_length: 4096,
1090            use_case: "Coding".to_string(),
1091            is_moe: false,
1092            num_experts: None,
1093            active_experts: None,
1094            active_parameters: None,
1095            release_date: None,
1096            gguf_sources: vec![],
1097            capabilities: vec![],
1098            format: ModelFormat::default(),
1099            num_attention_heads: None,
1100            num_key_value_heads: None,
1101            license: None,
1102        };
1103        assert_eq!(UseCase::from_model(&model), UseCase::Coding);
1104    }
1105
1106    #[test]
1107    fn test_use_case_from_model_embedding() {
1108        let model = LlmModel {
1109            name: "bge-large".to_string(),
1110            provider: "BAAI".to_string(),
1111            parameter_count: "335M".to_string(),
1112            parameters_raw: Some(335_000_000),
1113            min_ram_gb: 1.0,
1114            recommended_ram_gb: 2.0,
1115            min_vram_gb: Some(1.0),
1116            quantization: "F16".to_string(),
1117            context_length: 512,
1118            use_case: "Embedding".to_string(),
1119            is_moe: false,
1120            num_experts: None,
1121            active_experts: None,
1122            active_parameters: None,
1123            release_date: None,
1124            gguf_sources: vec![],
1125            capabilities: vec![],
1126            format: ModelFormat::default(),
1127            num_attention_heads: None,
1128            num_key_value_heads: None,
1129            license: None,
1130        };
1131        assert_eq!(UseCase::from_model(&model), UseCase::Embedding);
1132    }
1133
1134    #[test]
1135    fn test_use_case_from_model_reasoning() {
1136        let model = LlmModel {
1137            name: "deepseek-r1-7b".to_string(),
1138            provider: "DeepSeek".to_string(),
1139            parameter_count: "7B".to_string(),
1140            parameters_raw: Some(7_000_000_000),
1141            min_ram_gb: 4.0,
1142            recommended_ram_gb: 8.0,
1143            min_vram_gb: Some(4.0),
1144            quantization: "Q4_K_M".to_string(),
1145            context_length: 8192,
1146            use_case: "Reasoning".to_string(),
1147            is_moe: false,
1148            num_experts: None,
1149            active_experts: None,
1150            active_parameters: None,
1151            release_date: None,
1152            gguf_sources: vec![],
1153            capabilities: vec![],
1154            format: ModelFormat::default(),
1155            num_attention_heads: None,
1156            num_key_value_heads: None,
1157            license: None,
1158        };
1159        assert_eq!(UseCase::from_model(&model), UseCase::Reasoning);
1160    }
1161
1162    // ────────────────────────────────────────────────────────────────────
1163    // ModelDatabase tests
1164    // ────────────────────────────────────────────────────────────────────
1165
1166    #[test]
1167    fn test_model_database_new() {
1168        let db = ModelDatabase::new();
1169        let models = db.get_all_models();
1170        // Should have loaded models from embedded JSON
1171        assert!(!models.is_empty());
1172    }
1173
1174    #[test]
1175    fn test_find_model() {
1176        let db = ModelDatabase::new();
1177
1178        // Search by name substring (case insensitive)
1179        let results = db.find_model("llama");
1180        assert!(!results.is_empty());
1181        assert!(
1182            results
1183                .iter()
1184                .any(|m| m.name.to_lowercase().contains("llama"))
1185        );
1186
1187        // Search should be case insensitive
1188        let results_upper = db.find_model("LLAMA");
1189        assert_eq!(results.len(), results_upper.len());
1190    }
1191
1192    #[test]
1193    fn test_models_fitting_system() {
1194        let db = ModelDatabase::new();
1195
1196        // Large system should fit many models
1197        let fitting = db.models_fitting_system(32.0, true, Some(24.0));
1198        assert!(!fitting.is_empty());
1199
1200        // Very small system should fit fewer or no models
1201        let fitting_small = db.models_fitting_system(2.0, false, None);
1202        assert!(fitting_small.len() < fitting.len());
1203
1204        // All fitting models should meet RAM requirements
1205        for model in fitting_small {
1206            assert!(model.min_ram_gb <= 2.0);
1207        }
1208    }
1209
1210    // ────────────────────────────────────────────────────────────────────
1211    // Capability tests
1212    // ────────────────────────────────────────────────────────────────────
1213
1214    #[test]
1215    fn test_capability_infer_vision() {
1216        let model = LlmModel {
1217            name: "meta-llama/Llama-3.2-11B-Vision-Instruct".to_string(),
1218            provider: "Meta".to_string(),
1219            parameter_count: "11B".to_string(),
1220            parameters_raw: Some(11_000_000_000),
1221            min_ram_gb: 6.0,
1222            recommended_ram_gb: 10.0,
1223            min_vram_gb: Some(6.0),
1224            quantization: "Q4_K_M".to_string(),
1225            context_length: 131072,
1226            use_case: "Multimodal, vision and text".to_string(),
1227            is_moe: false,
1228            num_experts: None,
1229            active_experts: None,
1230            active_parameters: None,
1231            release_date: None,
1232            gguf_sources: vec![],
1233            capabilities: vec![],
1234            format: ModelFormat::default(),
1235            num_attention_heads: None,
1236            num_key_value_heads: None,
1237            license: None,
1238        };
1239        let caps = Capability::infer(&model);
1240        assert!(caps.contains(&Capability::Vision));
1241        // Also gets ToolUse because "llama-3" + "instruct"
1242        assert!(caps.contains(&Capability::ToolUse));
1243    }
1244
1245    #[test]
1246    fn test_capability_infer_tool_use() {
1247        let model = LlmModel {
1248            name: "Qwen/Qwen3-8B".to_string(),
1249            provider: "Qwen".to_string(),
1250            parameter_count: "8B".to_string(),
1251            parameters_raw: Some(8_000_000_000),
1252            min_ram_gb: 4.5,
1253            recommended_ram_gb: 8.0,
1254            min_vram_gb: Some(4.0),
1255            quantization: "Q4_K_M".to_string(),
1256            context_length: 32768,
1257            use_case: "General purpose text generation".to_string(),
1258            is_moe: false,
1259            num_experts: None,
1260            active_experts: None,
1261            active_parameters: None,
1262            release_date: None,
1263            gguf_sources: vec![],
1264            capabilities: vec![],
1265            format: ModelFormat::default(),
1266            num_attention_heads: None,
1267            num_key_value_heads: None,
1268            license: None,
1269        };
1270        let caps = Capability::infer(&model);
1271        assert!(caps.contains(&Capability::ToolUse));
1272        assert!(!caps.contains(&Capability::Vision));
1273    }
1274
1275    #[test]
1276    fn test_capability_infer_none() {
1277        let model = LlmModel {
1278            name: "BAAI/bge-large-en-v1.5".to_string(),
1279            provider: "BAAI".to_string(),
1280            parameter_count: "335M".to_string(),
1281            parameters_raw: Some(335_000_000),
1282            min_ram_gb: 1.0,
1283            recommended_ram_gb: 2.0,
1284            min_vram_gb: Some(1.0),
1285            quantization: "F16".to_string(),
1286            context_length: 512,
1287            use_case: "Text embeddings for RAG".to_string(),
1288            is_moe: false,
1289            num_experts: None,
1290            active_experts: None,
1291            active_parameters: None,
1292            release_date: None,
1293            gguf_sources: vec![],
1294            capabilities: vec![],
1295            format: ModelFormat::default(),
1296            num_attention_heads: None,
1297            num_key_value_heads: None,
1298            license: None,
1299        };
1300        let caps = Capability::infer(&model);
1301        assert!(caps.is_empty());
1302    }
1303
1304    #[test]
1305    fn test_capability_preserves_explicit() {
1306        let model = LlmModel {
1307            name: "some-model".to_string(),
1308            provider: "Test".to_string(),
1309            parameter_count: "7B".to_string(),
1310            parameters_raw: Some(7_000_000_000),
1311            min_ram_gb: 4.0,
1312            recommended_ram_gb: 8.0,
1313            min_vram_gb: Some(4.0),
1314            quantization: "Q4_K_M".to_string(),
1315            context_length: 4096,
1316            use_case: "General".to_string(),
1317            is_moe: false,
1318            num_experts: None,
1319            active_experts: None,
1320            active_parameters: None,
1321            release_date: None,
1322            gguf_sources: vec![],
1323            capabilities: vec![Capability::Vision],
1324            format: ModelFormat::default(),
1325            num_attention_heads: None,
1326            num_key_value_heads: None,
1327            license: None,
1328        };
1329        let caps = Capability::infer(&model);
1330        // Should keep the explicit Vision and not duplicate it
1331        assert_eq!(caps.iter().filter(|c| **c == Capability::Vision).count(), 1);
1332    }
1333
1334    #[test]
1335    fn test_awq_gptq_quant_values() {
1336        // AWQ
1337        assert_eq!(quant_bpp("AWQ-4bit"), 0.5);
1338        assert_eq!(quant_bpp("AWQ-8bit"), 1.0);
1339        assert_eq!(quant_speed_multiplier("AWQ-4bit"), 1.2);
1340        assert_eq!(quant_speed_multiplier("AWQ-8bit"), 0.85);
1341        assert_eq!(quant_quality_penalty("AWQ-4bit"), -3.0);
1342        assert_eq!(quant_quality_penalty("AWQ-8bit"), 0.0);
1343        // GPTQ
1344        assert_eq!(quant_bpp("GPTQ-Int4"), 0.5);
1345        assert_eq!(quant_bpp("GPTQ-Int8"), 1.0);
1346        assert_eq!(quant_speed_multiplier("GPTQ-Int4"), 1.2);
1347        assert_eq!(quant_speed_multiplier("GPTQ-Int8"), 0.85);
1348        assert_eq!(quant_quality_penalty("GPTQ-Int4"), -3.0);
1349        assert_eq!(quant_quality_penalty("GPTQ-Int8"), 0.0);
1350    }
1351
1352    #[test]
1353    fn test_model_format_prequantized() {
1354        assert!(ModelFormat::Awq.is_prequantized());
1355        assert!(ModelFormat::Gptq.is_prequantized());
1356        assert!(!ModelFormat::Gguf.is_prequantized());
1357        assert!(!ModelFormat::Mlx.is_prequantized());
1358        assert!(!ModelFormat::Safetensors.is_prequantized());
1359    }
1360
1361    // ────────────────────────────────────────────────────────────────────
1362    // GGUF source catalog tests
1363    // ────────────────────────────────────────────────────────────────────
1364
1365    #[test]
1366    fn test_gguf_source_deserialization() {
1367        let json = r#"{"repo": "unsloth/Llama-3.1-8B-Instruct-GGUF", "provider": "unsloth"}"#;
1368        let source: GgufSource = serde_json::from_str(json).unwrap();
1369        assert_eq!(source.repo, "unsloth/Llama-3.1-8B-Instruct-GGUF");
1370        assert_eq!(source.provider, "unsloth");
1371    }
1372
    #[test]
    fn test_gguf_sources_default_to_empty() {
        // Catalog JSON with no "gguf_sources" key: serde's field default
        // must kick in and yield an empty Vec instead of a parse error.
        let json = r#"{
            "name": "test/model",
            "provider": "Test",
            "parameter_count": "7B",
            "parameters_raw": 7000000000,
            "min_ram_gb": 4.0,
            "recommended_ram_gb": 8.0,
            "quantization": "Q4_K_M",
            "context_length": 4096,
            "use_case": "General"
        }"#;
        let entry: HfModelEntry = serde_json::from_str(json).unwrap();
        assert!(entry.gguf_sources.is_empty());
    }
1389
1390    #[test]
1391    fn test_catalog_popular_models_have_gguf_sources() {
1392        let db = ModelDatabase::new();
1393        // These popular models should have gguf_sources populated in the catalog
1394        let expected_with_gguf = [
1395            "meta-llama/Llama-3.3-70B-Instruct",
1396            "Qwen/Qwen2.5-7B-Instruct",
1397            "Qwen/Qwen2.5-Coder-7B-Instruct",
1398            "meta-llama/Meta-Llama-3-8B-Instruct",
1399            "mistralai/Mistral-7B-Instruct-v0.3",
1400        ];
1401        for name in &expected_with_gguf {
1402            let model = db.get_all_models().iter().find(|m| m.name == *name);
1403            assert!(model.is_some(), "Model {} should exist in catalog", name);
1404            let model = model.unwrap();
1405            assert!(
1406                !model.gguf_sources.is_empty(),
1407                "Model {} should have gguf_sources but has none",
1408                name
1409            );
1410        }
1411    }
1412
1413    #[test]
1414    fn test_catalog_gguf_sources_have_valid_repos() {
1415        let db = ModelDatabase::new();
1416        for model in db.get_all_models() {
1417            for source in &model.gguf_sources {
1418                assert!(
1419                    source.repo.contains('/'),
1420                    "GGUF source repo '{}' for model '{}' should be owner/repo format",
1421                    source.repo,
1422                    model.name
1423                );
1424                assert!(
1425                    !source.provider.is_empty(),
1426                    "GGUF source provider for model '{}' should not be empty",
1427                    model.name
1428                );
1429                assert!(
1430                    source.repo.to_uppercase().contains("GGUF"),
1431                    "GGUF source repo '{}' for model '{}' should contain 'GGUF'",
1432                    source.repo,
1433                    model.name
1434                );
1435            }
1436        }
1437    }
1438
1439    #[test]
1440    #[ignore] // Requires network access to populate GGUF sources at build time
1441    fn test_catalog_has_significant_gguf_coverage() {
1442        let db = ModelDatabase::new();
1443        let total = db.get_all_models().len();
1444        let with_gguf = db
1445            .get_all_models()
1446            .iter()
1447            .filter(|m| !m.gguf_sources.is_empty())
1448            .count();
1449        // We should have at least 25% coverage after enrichment
1450        let coverage_pct = (with_gguf as f64 / total as f64) * 100.0;
1451        assert!(
1452            coverage_pct >= 25.0,
1453            "GGUF source coverage is only {:.1}% ({}/{}), expected at least 25%",
1454            coverage_pct,
1455            with_gguf,
1456            total
1457        );
1458    }
1459
1460    // ────────────────────────────────────────────────────────────────────
1461    // Tensor parallelism tests
1462    // ────────────────────────────────────────────────────────────────────
1463
    /// Builds a minimal `LlmModel` fixture for the tensor-parallelism tests.
    ///
    /// `params_b` is the parameter count in billions; it is rendered into
    /// `parameter_count` (e.g. 8.0 → "8B") and scaled into `parameters_raw`.
    /// `attn_heads` / `kv_heads` populate the explicit head-count fields;
    /// pass `None` to exercise name-based head inference instead.
    /// Every other field is a fixed, plausible dense-model default
    /// (Q4_K_M, 4k context, non-MoE).
    fn tp_test_model(
        name: &str,
        params_b: f64,
        attn_heads: Option<u32>,
        kv_heads: Option<u32>,
    ) -> LlmModel {
        LlmModel {
            name: name.to_string(),
            provider: "Test".to_string(),
            parameter_count: format!("{:.0}B", params_b),
            parameters_raw: Some((params_b * 1_000_000_000.0) as u64),
            min_ram_gb: 4.0,
            recommended_ram_gb: 8.0,
            min_vram_gb: Some(4.0),
            quantization: "Q4_K_M".to_string(),
            context_length: 4096,
            use_case: "General".to_string(),
            is_moe: false,
            num_experts: None,
            active_experts: None,
            active_parameters: None,
            release_date: None,
            gguf_sources: vec![],
            capabilities: vec![],
            format: ModelFormat::default(),
            num_attention_heads: attn_heads,
            num_key_value_heads: kv_heads,
            license: None,
        }
    }
1494
1495    #[test]
1496    fn test_supports_tp_with_explicit_heads() {
1497        let model = tp_test_model("Test-8B", 8.0, Some(32), Some(8));
1498        assert!(model.supports_tp(1));
1499        assert!(model.supports_tp(2));
1500        assert!(model.supports_tp(4));
1501        assert!(model.supports_tp(8));
1502        assert!(!model.supports_tp(3)); // 32 % 3 != 0
1503        assert!(!model.supports_tp(5));
1504    }
1505
1506    #[test]
1507    fn test_supports_tp_always_true_for_1() {
1508        let model = tp_test_model("Tiny", 1.0, None, None);
1509        assert!(model.supports_tp(1));
1510    }
1511
1512    #[test]
1513    fn test_valid_tp_sizes_32_8() {
1514        let model = tp_test_model("Test", 8.0, Some(32), Some(8));
1515        let sizes = model.valid_tp_sizes();
1516        assert!(sizes.contains(&1));
1517        assert!(sizes.contains(&2));
1518        assert!(sizes.contains(&4));
1519        assert!(sizes.contains(&8));
1520        assert!(!sizes.contains(&3));
1521    }
1522
1523    #[test]
1524    fn test_valid_tp_sizes_48_heads() {
1525        // 48 attn heads, 8 kv heads — TP must divide both
1526        let model = tp_test_model("Llama-32B", 32.0, Some(48), Some(8));
1527        assert!(model.supports_tp(2)); // 48%2==0, 8%2==0
1528        assert!(!model.supports_tp(3)); // 48%3==0 but 8%3!=0
1529        assert!(model.supports_tp(4)); // 48%4==0, 8%4==0
1530        assert!(model.supports_tp(8)); // 48%8==0, 8%8==0
1531    }
1532
1533    #[test]
1534    fn test_infer_heads_from_name_qwen() {
1535        let (attn, kv) = infer_heads_from_name("Qwen2.5-72B-Instruct", 72.0);
1536        assert_eq!(attn, 64);
1537        assert_eq!(kv, 8);
1538    }
1539
1540    #[test]
1541    fn test_infer_heads_from_name_llama() {
1542        let (attn, kv) = infer_heads_from_name("Llama-3.1-8B", 8.0);
1543        assert_eq!(attn, 32);
1544        assert_eq!(kv, 8);
1545    }
1546
1547    #[test]
1548    fn test_infer_heads_from_name_deepseek() {
1549        let (attn, kv) = infer_heads_from_name("DeepSeek-V3", 671.0);
1550        assert_eq!(attn, 128);
1551        assert_eq!(kv, 16);
1552    }
1553
1554    #[test]
1555    fn test_supports_tp_with_inferred_heads() {
1556        // No explicit heads — should infer from name
1557        let model = tp_test_model("Llama-3.1-70B", 70.0, None, None);
1558        assert!(model.supports_tp(2));
1559        assert!(model.supports_tp(4));
1560        assert!(model.supports_tp(8));
1561    }
1562}