Skip to main content

llmfit_core/
models.rs

1use serde::{Deserialize, Serialize};
2
/// Quantization levels ordered from best quality to most compressed.
/// Used for dynamic quantization selection: try the best that fits.
pub const QUANT_HIERARCHY: &[&str] = &["Q8_0", "Q6_K", "Q5_K_M", "Q4_K_M", "Q3_K_M", "Q2_K"];

/// MLX-native quantization hierarchy (best quality to most compressed).
/// Passed to `best_quant_for_budget_with` when selecting among MLX quants.
pub const MLX_QUANT_HIERARCHY: &[&str] = &["mlx-8bit", "mlx-4bit"];
9
/// Bytes per parameter for each quantization level.
///
/// Unknown quantization tags fall back to 0.58 (the Q4_K_M figure).
pub fn quant_bpp(quant: &str) -> f64 {
    // (tag, bytes-per-param) pairs; the list is tiny, a linear scan is fine.
    const BPP_TABLE: &[(&str, f64)] = &[
        ("F32", 4.0),
        ("F16", 2.0),
        ("BF16", 2.0),
        ("Q8_0", 1.05),
        ("Q6_K", 0.80),
        ("Q5_K_M", 0.68),
        ("Q4_K_M", 0.58),
        ("Q4_0", 0.58),
        ("Q3_K_M", 0.48),
        ("Q2_K", 0.37),
        ("mlx-4bit", 0.55),
        ("mlx-8bit", 1.0),
        ("AWQ-4bit", 0.5),
        ("AWQ-8bit", 1.0),
        ("GPTQ-Int4", 0.5),
        ("GPTQ-Int8", 1.0),
    ];
    BPP_TABLE
        .iter()
        .find(|(tag, _)| *tag == quant)
        .map(|&(_, bpp)| bpp)
        .unwrap_or(0.58)
}
30
/// Speed multiplier for quantization (lower quant = faster inference).
///
/// Relative to the Q5_K_M baseline of 1.0; unknown tags get 1.0.
pub fn quant_speed_multiplier(quant: &str) -> f64 {
    match quant {
        // Full/half precision moves the most bytes per token.
        "F16" | "BF16" => 0.6,
        // 8-bit tiers.
        "Q8_0" => 0.8,
        "mlx-8bit" | "AWQ-8bit" | "GPTQ-Int8" => 0.85,
        // Mid-range K-quants around the 1.0 baseline.
        "Q6_K" => 0.95,
        "Q5_K_M" => 1.0,
        // 4-bit and below trade quality for throughput.
        "Q4_K_M" | "Q4_0" | "mlx-4bit" => 1.15,
        "AWQ-4bit" | "GPTQ-Int4" => 1.2,
        "Q3_K_M" => 1.25,
        "Q2_K" => 1.35,
        _ => 1.0,
    }
}
48
/// Bytes per parameter for a given quantization format.
/// Used by the bandwidth-based tok/s estimator to compute model size in GB.
///
/// Unknown tags fall back to 0.5 (~4-bit).
pub fn quant_bytes_per_param(quant: &str) -> f64 {
    match quant {
        // Full precision is 4 bytes/param. This arm was missing, so "F32"
        // fell through to the ~4-bit default (0.5) — an 8x underestimate and
        // inconsistent with `quant_bpp`, which maps "F32" => 4.0.
        "F32" => 4.0,
        "F16" | "BF16" => 2.0,
        "Q8_0" => 1.0,
        "Q6_K" => 0.75,
        "Q5_K_M" => 0.625,
        "Q4_K_M" | "Q4_0" => 0.5,
        "Q3_K_M" => 0.375,
        "Q2_K" => 0.25,
        "mlx-4bit" => 0.5,
        "mlx-8bit" => 1.0,
        "AWQ-4bit" | "GPTQ-Int4" => 0.5,
        "AWQ-8bit" | "GPTQ-Int8" => 1.0,
        _ => 0.5, // default to ~4-bit
    }
}
67
/// Quality penalty for quantization (lower quant = lower quality).
///
/// 0.0 means "no quality loss"; more negative is worse. Unknown tags get
/// the ~4-bit penalty of -5.0.
pub fn quant_quality_penalty(quant: &str) -> f64 {
    match quant {
        // "F32" was missing and fell through to the -5.0 default, penalizing
        // full precision as if it were ~4-bit. Full/half precision carry no
        // quantization penalty.
        "F32" | "F16" | "BF16" => 0.0,
        "Q8_0" => 0.0,
        "Q6_K" => -1.0,
        "Q5_K_M" => -2.0,
        "Q4_K_M" | "Q4_0" => -5.0,
        "Q3_K_M" => -8.0,
        "Q2_K" => -12.0,
        "mlx-4bit" => -4.0,
        "mlx-8bit" => 0.0,
        "AWQ-4bit" => -3.0,
        "AWQ-8bit" => 0.0,
        "GPTQ-Int4" => -3.0,
        "GPTQ-Int8" => 0.0,
        _ => -5.0,
    }
}
87
/// Model capability flags (orthogonal to UseCase).
///
/// Serialized in snake_case (`vision`, `tool_use`) to match the scraper's
/// JSON schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Capability {
    /// Multimodal image input (detected by name/use-case heuristics in `infer`).
    Vision,
    /// Structured tool / function calling.
    ToolUse,
}
95
96impl Capability {
97    pub fn label(&self) -> &'static str {
98        match self {
99            Capability::Vision => "Vision",
100            Capability::ToolUse => "Tool Use",
101        }
102    }
103
104    pub fn all() -> &'static [Capability] {
105        &[Capability::Vision, Capability::ToolUse]
106    }
107
108    /// Infer capabilities from model metadata when not explicitly set in JSON.
109    pub fn infer(model: &LlmModel) -> Vec<Capability> {
110        let mut caps = model.capabilities.clone();
111        let name = model.name.to_lowercase();
112        let use_case = model.use_case.to_lowercase();
113
114        // Vision detection
115        if !caps.contains(&Capability::Vision)
116            && (name.contains("vision")
117                || name.contains("-vl-")
118                || name.ends_with("-vl")
119                || name.contains("llava")
120                || name.contains("onevision")
121                || name.contains("pixtral")
122                || use_case.contains("vision")
123                || use_case.contains("multimodal"))
124        {
125            caps.push(Capability::Vision);
126        }
127
128        // Tool use detection (known model families)
129        if !caps.contains(&Capability::ToolUse)
130            && (use_case.contains("tool")
131                || use_case.contains("function call")
132                || name.contains("qwen3")
133                || name.contains("qwen2.5")
134                || name.contains("command-r")
135                || (name.contains("llama-3") && name.contains("instruct"))
136                || (name.contains("mistral") && name.contains("instruct"))
137                || name.contains("hermes")
138                || (name.contains("gemma-3") && name.ends_with("-it"))
139                || (name.contains("gemma-4") && name.ends_with("-it")))
140        {
141            caps.push(Capability::ToolUse);
142        }
143
144        caps
145    }
146}
147
/// Model weight format — determines which inference runtime to use.
///
/// Serialized lowercase (`gguf`, `awq`, ...) to match the scraper JSON.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum ModelFormat {
    /// llama.cpp GGUF container; the default when JSON omits `format`.
    #[default]
    Gguf,
    /// Activation-aware weight quantization (pre-quantized, fixed bit width).
    Awq,
    /// GPTQ post-training quantization (pre-quantized, fixed bit width).
    Gptq,
    /// Apple MLX weights.
    Mlx,
    /// Raw safetensors weights.
    Safetensors,
}
160
161impl ModelFormat {
162    /// Returns true for formats that are pre-quantized at a fixed bit width
163    /// and cannot be dynamically re-quantized (AWQ, GPTQ).
164    pub fn is_prequantized(&self) -> bool {
165        matches!(self, ModelFormat::Awq | ModelFormat::Gptq)
166    }
167}
168
/// Use-case category for scoring weights.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
pub enum UseCase {
    /// Fallback when no other category matches.
    General,
    Coding,
    Reasoning,
    Chat,
    /// Vision / multimodal models.
    Multimodal,
    /// Text-embedding models.
    Embedding,
}
179
180impl UseCase {
181    pub fn label(&self) -> &'static str {
182        match self {
183            UseCase::General => "General",
184            UseCase::Coding => "Coding",
185            UseCase::Reasoning => "Reasoning",
186            UseCase::Chat => "Chat",
187            UseCase::Multimodal => "Multimodal",
188            UseCase::Embedding => "Embedding",
189        }
190    }
191
192    /// Infer use-case from the model's use_case field and name.
193    pub fn from_model(model: &LlmModel) -> Self {
194        let name = model.name.to_lowercase();
195        let use_case = model.use_case.to_lowercase();
196
197        if use_case.contains("embedding") || name.contains("embed") || name.contains("bge") {
198            UseCase::Embedding
199        } else if name.contains("code") || use_case.contains("code") {
200            UseCase::Coding
201        } else if use_case.contains("vision") || use_case.contains("multimodal") {
202            UseCase::Multimodal
203        } else if use_case.contains("reason")
204            || use_case.contains("chain-of-thought")
205            || name.contains("deepseek-r1")
206        {
207            UseCase::Reasoning
208        } else if use_case.contains("chat") || use_case.contains("instruction") {
209            UseCase::Chat
210        } else {
211            UseCase::General
212        }
213    }
214}
215
/// A catalog entry describing one locally-runnable model and its
/// resource requirements.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmModel {
    /// Model name as published (e.g. on HuggingFace).
    pub name: String,
    /// Publishing organization/provider.
    pub provider: String,
    /// Human-readable parameter count (e.g. "7B", "500M"); parsed by `params_b`.
    pub parameter_count: String,
    /// Exact parameter count; preferred over `parameter_count` when present.
    #[serde(default)]
    pub parameters_raw: Option<u64>,
    /// Minimum system RAM (GB) required to run at all.
    pub min_ram_gb: f64,
    /// RAM (GB) for comfortable operation; also the CPU-only gate in
    /// `ModelDatabase::models_fitting_system`.
    pub recommended_ram_gb: f64,
    /// Minimum GPU VRAM (GB), if the model prefers a GPU.
    pub min_vram_gb: Option<f64>,
    /// Quantization tag (e.g. "Q4_K_M"); see `quant_bpp` for known values.
    pub quantization: String,
    /// Maximum context window in tokens.
    pub context_length: u32,
    /// Free-form use-case description; parsed by `UseCase::from_model`.
    pub use_case: String,
    /// Mixture-of-experts flag; gates the `moe_*` helpers.
    #[serde(default)]
    pub is_moe: bool,
    /// Total expert count (MoE only).
    #[serde(default)]
    pub num_experts: Option<u32>,
    /// Experts active per token (MoE only).
    #[serde(default)]
    pub active_experts: Option<u32>,
    /// Parameters active per token (MoE only); drives `moe_active_vram_gb`.
    #[serde(default)]
    pub active_parameters: Option<u64>,
    /// Release date string, when known.
    #[serde(default)]
    pub release_date: Option<String>,
    /// Known GGUF download sources (e.g. unsloth, bartowski repos on HuggingFace)
    #[serde(default)]
    pub gguf_sources: Vec<GgufSource>,
    /// Model capabilities (vision, tool use, etc.)
    #[serde(default)]
    pub capabilities: Vec<Capability>,
    /// Model weight format (gguf, awq, gptq, mlx, safetensors)
    #[serde(default)]
    pub format: ModelFormat,
    /// Number of attention heads (for tensor-parallelism compatibility checks).
    #[serde(default)]
    pub num_attention_heads: Option<u32>,
    /// Number of key-value heads for GQA (defaults to num_attention_heads if None).
    #[serde(default)]
    pub num_key_value_heads: Option<u32>,
    /// Model license (e.g. "apache-2.0", "mit", "llama3.1")
    #[serde(default)]
    pub license: Option<String>,
}
258
/// Returns true if a model's license matches any entry in the comma-separated
/// `filter` string. Matching is case-insensitive and ignores whitespace
/// around each filter entry. Models without a license never match.
pub fn matches_license_filter(license: &Option<String>, filter: &str) -> bool {
    // Lowercase the license once, then compare lazily against each filter
    // entry — short-circuits on the first hit instead of collecting a Vec
    // of every lowercased entry up front.
    let lic = match license {
        Some(l) => l.to_lowercase(),
        None => return false,
    };
    filter.split(',').any(|entry| entry.trim().to_lowercase() == lic)
}
268
/// A known GGUF download source for a model on HuggingFace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GgufSource {
    /// HuggingFace repo ID (e.g. "unsloth/Llama-3.1-8B-Instruct-GGUF")
    pub repo: String,
    /// Provider who published the GGUF (e.g. "unsloth", "bartowski")
    pub provider: String,
}
277
impl LlmModel {
    /// MLX models are Apple-only — they won't run on NVIDIA/AMD/Intel hardware.
    /// We detect them by the `-MLX-` suffix that's standard on HuggingFace
    /// (e.g. `Qwen3-8B-MLX-4bit`, `LFM2-1.2B-MLX-8bit`).
    pub fn is_mlx_model(&self) -> bool {
        let name_lower = self.name.to_lowercase();
        name_lower.contains("-mlx-") || name_lower.ends_with("-mlx")
    }

    /// Returns true if this model uses a pre-quantized format (AWQ/GPTQ)
    /// that cannot be dynamically re-quantized.
    pub fn is_prequantized(&self) -> bool {
        self.format.is_prequantized()
    }

    /// Returns true if the model's attention/KV heads are evenly divisible
    /// by `tp_size`, meaning it can be split across that many devices.
    /// TP=1 always returns true.
    pub fn supports_tp(&self, tp_size: u32) -> bool {
        if tp_size <= 1 {
            return true;
        }
        let (attn, kv) = self.infer_head_counts();
        attn % tp_size == 0 && kv % tp_size == 0
    }

    /// Returns all valid TP degrees in [1..=8] for this model.
    pub fn valid_tp_sizes(&self) -> Vec<u32> {
        (1..=8).filter(|&tp| self.supports_tp(tp)).collect()
    }

    /// Infer attention and KV head counts from metadata or model name heuristics.
    ///
    /// Precedence: explicit (attn, kv) metadata > attn-only metadata (KV
    /// assumed equal to attn, i.e. no GQA) > name-based heuristics.
    fn infer_head_counts(&self) -> (u32, u32) {
        if let (Some(attn), Some(kv)) = (self.num_attention_heads, self.num_key_value_heads) {
            return (attn, kv);
        }
        if let Some(attn) = self.num_attention_heads {
            return (attn, attn);
        }
        // Heuristic: infer from model name
        infer_heads_from_name(&self.name, self.params_b())
    }

    /// Bytes-per-parameter for the model's quantization level.
    fn quant_bpp(&self) -> f64 {
        quant_bpp(&self.quantization)
    }

    /// Parameter count in billions, extracted from parameters_raw or parameter_count.
    ///
    /// Falls back to 7.0 (a "typical" size) when the string has no B/M suffix
    /// or an unparseable "NNNB" number; an unparseable "NNNM" yields 0.0.
    pub fn params_b(&self) -> f64 {
        if let Some(raw) = self.parameters_raw {
            raw as f64 / 1_000_000_000.0
        } else {
            // Parse from string like "7B", "1.1B", "137M"
            let s = self.parameter_count.trim().to_uppercase();
            if let Some(num_str) = s.strip_suffix('B') {
                num_str.parse::<f64>().unwrap_or(7.0)
            } else if let Some(num_str) = s.strip_suffix('M') {
                num_str.parse::<f64>().unwrap_or(0.0) / 1000.0
            } else {
                7.0
            }
        }
    }

    /// Estimate memory required (GB) at a given quantization and context length.
    /// Formula: model_weights + KV_cache + runtime_overhead
    ///
    /// NOTE(review): this uses decimal GB (billions of params * bytes/param),
    /// while the MoE helpers below divide by 1024^3 (GiB) — confirm the unit
    /// mismatch is intentional.
    pub fn estimate_memory_gb(&self, quant: &str, ctx: u32) -> f64 {
        let bpp = quant_bpp(quant);
        let params = self.params_b();
        let model_mem = params * bpp;
        // KV cache: ~0.000008 GB per billion params per context token
        let kv_cache = 0.000008 * params * ctx as f64;
        // Runtime overhead (CUDA/Metal context, buffers)
        let overhead = 0.5;
        model_mem + kv_cache + overhead
    }

    /// Select the best quantization level that fits within a memory budget.
    /// Returns the quant name and estimated memory in GB, or None if nothing fits.
    pub fn best_quant_for_budget(&self, budget_gb: f64, ctx: u32) -> Option<(&'static str, f64)> {
        self.best_quant_for_budget_with(budget_gb, ctx, QUANT_HIERARCHY)
    }

    /// Select the best quantization from a custom hierarchy that fits within a memory budget.
    ///
    /// Walks `hierarchy` best-quality-first at the requested context; if
    /// nothing fits, retries the whole hierarchy once at half the context,
    /// but only when the halved context is still at least 1024 tokens.
    pub fn best_quant_for_budget_with(
        &self,
        budget_gb: f64,
        ctx: u32,
        hierarchy: &[&'static str],
    ) -> Option<(&'static str, f64)> {
        // Try best quality first
        for &q in hierarchy {
            let mem = self.estimate_memory_gb(q, ctx);
            if mem <= budget_gb {
                return Some((q, mem));
            }
        }
        // Try halving context once
        let half_ctx = ctx / 2;
        if half_ctx >= 1024 {
            for &q in hierarchy {
                let mem = self.estimate_memory_gb(q, half_ctx);
                if mem <= budget_gb {
                    return Some((q, mem));
                }
            }
        }
        None
    }

    /// For MoE models, compute estimated VRAM for active experts only.
    /// Returns None for dense models (or MoE entries missing active_parameters).
    pub fn moe_active_vram_gb(&self) -> Option<f64> {
        if !self.is_moe {
            return None;
        }
        let active_params = self.active_parameters? as f64;
        let bpp = self.quant_bpp();
        // GiB conversion (1024^3); 1.1 is a runtime overhead factor, with a
        // 0.5 GB floor.
        let size_gb = (active_params * bpp) / (1024.0 * 1024.0 * 1024.0);
        Some((size_gb * 1.1).max(0.5))
    }

    /// Returns true if this model is MLX-specific (Apple Silicon only).
    /// MLX models are identified by having "-MLX" in their name.
    ///
    /// NOTE(review): broader than `is_mlx_model` above — this matches any
    /// "-MLX" substring, not just "-MLX-" or a trailing "-MLX". Consider
    /// consolidating the two checks.
    pub fn is_mlx_only(&self) -> bool {
        self.name.to_uppercase().contains("-MLX")
    }

    /// For MoE models, compute RAM needed for offloaded (inactive) experts.
    /// Returns None for dense models (or MoE entries missing parameter counts).
    pub fn moe_offloaded_ram_gb(&self) -> Option<f64> {
        if !self.is_moe {
            return None;
        }
        let active = self.active_parameters? as f64;
        let total = self.parameters_raw? as f64;
        let inactive = total - active;
        if inactive <= 0.0 {
            return Some(0.0);
        }
        let bpp = self.quant_bpp();
        // GiB conversion (1024^3); no overhead factor here, unlike
        // moe_active_vram_gb.
        Some((inactive * bpp) / (1024.0 * 1024.0 * 1024.0))
    }
}
423
/// Intermediate struct matching the JSON schema from the scraper.
/// Extra fields are ignored when mapping to LlmModel.
#[derive(Debug, Clone, Deserialize)]
struct HfModelEntry {
    name: String,
    provider: String,
    parameter_count: String,
    #[serde(default)]
    parameters_raw: Option<u64>,
    min_ram_gb: f64,
    recommended_ram_gb: f64,
    min_vram_gb: Option<f64>,
    quantization: String,
    context_length: u32,
    use_case: String,
    #[serde(default)]
    is_moe: bool,
    #[serde(default)]
    num_experts: Option<u32>,
    #[serde(default)]
    active_experts: Option<u32>,
    #[serde(default)]
    active_parameters: Option<u64>,
    #[serde(default)]
    release_date: Option<String>,
    #[serde(default)]
    gguf_sources: Vec<GgufSource>,
    #[serde(default)]
    capabilities: Vec<Capability>,
    #[serde(default)]
    format: ModelFormat,
    // HuggingFace popularity metrics: accepted from the JSON for schema
    // compatibility but not carried over into LlmModel by load_embedded.
    #[serde(default)]
    hf_downloads: u64,
    #[serde(default)]
    hf_likes: u64,
    #[serde(default)]
    license: Option<String>,
}
462
/// Model catalog JSON embedded at compile time from the scraper output.
const HF_MODELS_JSON: &str = include_str!("../data/hf_models.json");
464
/// In-memory catalog of known models: the embedded list, optionally merged
/// with locally cached entries (see `ModelDatabase::new`).
pub struct ModelDatabase {
    models: Vec<LlmModel>,
}
468
impl Default for ModelDatabase {
    /// Equivalent to [`ModelDatabase::new`]: embedded models plus local cache.
    fn default() -> Self {
        Self::new()
    }
}
474
/// Normalize a model name/ID to a canonical slug for deduplication.
///
/// Strips the `org/` prefix, lowercases, and collapses `-`/`_`/`.` so that
/// `meta-llama/Llama-3.1-8B` and `meta-llama/llama-3.1-8b` compare equal.
pub(crate) fn canonical_slug(name: &str) -> String {
    // Keep only the part after the last '/', then drop separator characters
    // before lowercasing (separators are case-insensitive anyway).
    let tail = name.rsplit('/').next().unwrap_or(name);
    tail.chars()
        .filter(|c| !matches!(c, '-' | '_' | '.'))
        .collect::<String>()
        .to_lowercase()
}
483
/// Parse the compile-time embedded JSON into a flat `Vec<LlmModel>`.
///
/// Panics if the embedded JSON is malformed — it is a build-time artifact,
/// so a parse failure is a packaging bug, not a runtime condition.
fn load_embedded() -> Vec<LlmModel> {
    let entries: Vec<HfModelEntry> =
        serde_json::from_str(HF_MODELS_JSON).expect("Failed to parse embedded hf_models.json");
    entries
        .into_iter()
        .map(|e| {
            // Map the scraper schema onto the public model type. Head counts
            // are absent from the JSON, so they start as None; hf_downloads /
            // hf_likes are intentionally dropped here.
            let mut model = LlmModel {
                name: e.name,
                provider: e.provider,
                parameter_count: e.parameter_count,
                parameters_raw: e.parameters_raw,
                min_ram_gb: e.min_ram_gb,
                recommended_ram_gb: e.recommended_ram_gb,
                min_vram_gb: e.min_vram_gb,
                quantization: e.quantization,
                context_length: e.context_length,
                use_case: e.use_case,
                is_moe: e.is_moe,
                num_experts: e.num_experts,
                active_experts: e.active_experts,
                active_parameters: e.active_parameters,
                release_date: e.release_date,
                gguf_sources: e.gguf_sources,
                capabilities: e.capabilities,
                format: e.format,
                num_attention_heads: None,
                num_key_value_heads: None,
                license: e.license,
            };
            // Augment explicit capabilities with name/use-case heuristics.
            model.capabilities = Capability::infer(&model);
            model
        })
        .collect()
}
519
520impl ModelDatabase {
521    /// Load only the compile-time embedded model list (no cache).
522    /// Used internally by the updater to determine which models are already known.
523    pub fn embedded() -> Self {
524        ModelDatabase {
525            models: load_embedded(),
526        }
527    }
528
529    /// Load the embedded model list **and** merge any locally cached models.
530    ///
531    /// Cached models are appended after the embedded ones; if an ID already
532    /// exists in the embedded list it is skipped to avoid duplication.
533    /// Silently ignores a missing or corrupt cache file.
534    pub fn new() -> Self {
535        let mut models = load_embedded();
536
537        // Merge cached models (from `llmfit update`) without duplicating.
538        // canonical_slug normalizes org/ prefix, case, and separators so that
539        // e.g. `meta-llama/Llama-3.1-8B` and `meta-llama/llama-3.1-8b` are
540        // treated as the same model.
541        let embedded_keys: std::collections::HashSet<String> =
542            models.iter().map(|m| canonical_slug(&m.name)).collect();
543
544        for cached in crate::update::load_cache() {
545            if !embedded_keys.contains(&canonical_slug(&cached.name)) {
546                models.push(cached);
547            }
548        }
549
550        ModelDatabase { models }
551    }
552
553    pub fn get_all_models(&self) -> &Vec<LlmModel> {
554        &self.models
555    }
556
557    pub fn find_model(&self, query: &str) -> Vec<&LlmModel> {
558        let query_lower = query.to_lowercase();
559        self.models
560            .iter()
561            .filter(|m| {
562                m.name.to_lowercase().contains(&query_lower)
563                    || m.provider.to_lowercase().contains(&query_lower)
564                    || m.parameter_count.to_lowercase().contains(&query_lower)
565            })
566            .collect()
567    }
568
569    pub fn models_fitting_system(
570        &self,
571        available_ram_gb: f64,
572        has_gpu: bool,
573        vram_gb: Option<f64>,
574    ) -> Vec<&LlmModel> {
575        self.models
576            .iter()
577            .filter(|m| {
578                // Check RAM requirement
579                let ram_ok = m.min_ram_gb <= available_ram_gb;
580
581                // If model requires GPU and system has GPU, check VRAM
582                if let Some(min_vram) = m.min_vram_gb {
583                    if has_gpu {
584                        if let Some(system_vram) = vram_gb {
585                            ram_ok && min_vram <= system_vram
586                        } else {
587                            // GPU detected but VRAM unknown, allow but warn
588                            ram_ok
589                        }
590                    } else {
591                        // Model prefers GPU but can run on CPU with enough RAM
592                        ram_ok && available_ram_gb >= m.recommended_ram_gb
593                    }
594                } else {
595                    ram_ok
596                }
597            })
598            .collect()
599    }
600}
601
/// Infer attention and KV head counts from the model name and parameter count.
/// Used as a fallback when explicit head counts are not available in the model metadata.
///
/// Family checks run in order (qwen, llama, deepseek, mistral/mixtral,
/// gemma, phi, minimax) with a size-tiered default at the end. Tiers that
/// returned identical values have been merged.
fn infer_heads_from_name(name: &str, params_b: f64) -> (u32, u32) {
    let lower = name.to_lowercase();
    let family = |tag: &str| lower.contains(tag);

    if family("qwen") {
        return match params_b {
            p if p > 100.0 => (128, 16),
            p if p > 50.0 => (64, 8),
            p if p > 10.0 => (40, 8),
            p if p > 5.0 => (32, 8),
            _ => (16, 4),
        };
    }

    if family("llama") {
        if family("scout") || family("maverick") || params_b > 60.0 {
            return (64, 8);
        }
        return match params_b {
            p if p > 20.0 => (48, 8),
            p if p > 5.0 => (32, 8),
            _ => (16, 8),
        };
    }

    if family("deepseek") {
        return match params_b {
            p if p > 200.0 => (128, 16),
            p if p > 50.0 => (64, 8),
            p if p > 10.0 => (40, 8),
            _ => (32, 8),
        };
    }

    if family("mistral") || family("mixtral") {
        return if params_b > 100.0 { (96, 8) } else { (32, 8) };
    }

    if family("gemma") {
        return match params_b {
            p if p > 20.0 => (32, 16),
            p if p > 5.0 => (16, 8),
            _ => (8, 4),
        };
    }

    if family("phi") {
        return if params_b > 10.0 { (40, 10) } else { (32, 8) };
    }

    if family("minimax") {
        return (48, 8);
    }

    // Default: common pattern based on param count
    match params_b {
        p if p > 100.0 => (128, 16),
        p if p > 50.0 => (64, 8),
        p if p > 5.0 => (32, 8),
        _ => (16, 4),
    }
}
703
704#[cfg(test)]
705mod tests {
706    use super::*;
707
708    // ────────────────────────────────────────────────────────────────────
709    // Quantization function tests
710    // ────────────────────────────────────────────────────────────────────
711
712    #[test]
713    fn test_mlx_quant_bpp_values() {
714        assert_eq!(quant_bpp("mlx-4bit"), 0.55);
715        assert_eq!(quant_bpp("mlx-8bit"), 1.0);
716        assert_eq!(quant_speed_multiplier("mlx-4bit"), 1.15);
717        assert_eq!(quant_speed_multiplier("mlx-8bit"), 0.85);
718        assert_eq!(quant_quality_penalty("mlx-4bit"), -4.0);
719        assert_eq!(quant_quality_penalty("mlx-8bit"), 0.0);
720    }
721
722    #[test]
723    fn test_best_quant_with_mlx_hierarchy() {
724        let model = LlmModel {
725            name: "Test Model".to_string(),
726            provider: "Test".to_string(),
727            parameter_count: "7B".to_string(),
728            parameters_raw: Some(7_000_000_000),
729            min_ram_gb: 4.0,
730            recommended_ram_gb: 8.0,
731            min_vram_gb: Some(4.0),
732            quantization: "Q4_K_M".to_string(),
733            context_length: 4096,
734            use_case: "General".to_string(),
735            is_moe: false,
736            num_experts: None,
737            active_experts: None,
738            active_parameters: None,
739            release_date: None,
740            gguf_sources: vec![],
741            capabilities: vec![],
742            format: ModelFormat::default(),
743            num_attention_heads: None,
744            num_key_value_heads: None,
745            license: None,
746        };
747
748        // Large budget should return mlx-8bit (best in MLX hierarchy)
749        let result = model.best_quant_for_budget_with(10.0, 4096, MLX_QUANT_HIERARCHY);
750        assert!(result.is_some());
751        let (quant, _) = result.unwrap();
752        assert_eq!(quant, "mlx-8bit");
753
754        // Tighter budget should fall to mlx-4bit
755        let result = model.best_quant_for_budget_with(5.0, 4096, MLX_QUANT_HIERARCHY);
756        assert!(result.is_some());
757        let (quant, _) = result.unwrap();
758        assert_eq!(quant, "mlx-4bit");
759    }
760
761    #[test]
762    fn test_quant_bpp() {
763        assert_eq!(quant_bpp("F32"), 4.0);
764        assert_eq!(quant_bpp("F16"), 2.0);
765        assert_eq!(quant_bpp("Q8_0"), 1.05);
766        assert_eq!(quant_bpp("Q4_K_M"), 0.58);
767        assert_eq!(quant_bpp("Q2_K"), 0.37);
768        // Unknown quant defaults to Q4_K_M
769        assert_eq!(quant_bpp("UNKNOWN"), 0.58);
770    }
771
772    #[test]
773    fn test_quant_speed_multiplier() {
774        assert_eq!(quant_speed_multiplier("F16"), 0.6);
775        assert_eq!(quant_speed_multiplier("Q5_K_M"), 1.0);
776        assert_eq!(quant_speed_multiplier("Q4_K_M"), 1.15);
777        assert_eq!(quant_speed_multiplier("Q2_K"), 1.35);
778        // Lower quant = faster inference
779        assert!(quant_speed_multiplier("Q2_K") > quant_speed_multiplier("Q8_0"));
780    }
781
782    #[test]
783    fn test_quant_quality_penalty() {
784        assert_eq!(quant_quality_penalty("F16"), 0.0);
785        assert_eq!(quant_quality_penalty("Q8_0"), 0.0);
786        assert_eq!(quant_quality_penalty("Q4_K_M"), -5.0);
787        assert_eq!(quant_quality_penalty("Q2_K"), -12.0);
788        // Lower quant = higher quality penalty
789        assert!(quant_quality_penalty("Q2_K") < quant_quality_penalty("Q8_0"));
790    }
791
792    // ────────────────────────────────────────────────────────────────────
793    // LlmModel tests
794    // ────────────────────────────────────────────────────────────────────
795
796    #[test]
797    fn test_params_b_from_raw() {
798        let model = LlmModel {
799            name: "Test Model".to_string(),
800            provider: "Test".to_string(),
801            parameter_count: "7B".to_string(),
802            parameters_raw: Some(7_000_000_000),
803            min_ram_gb: 4.0,
804            recommended_ram_gb: 8.0,
805            min_vram_gb: Some(4.0),
806            quantization: "Q4_K_M".to_string(),
807            context_length: 4096,
808            use_case: "General".to_string(),
809            is_moe: false,
810            num_experts: None,
811            active_experts: None,
812            active_parameters: None,
813            release_date: None,
814            gguf_sources: vec![],
815            capabilities: vec![],
816            format: ModelFormat::default(),
817            num_attention_heads: None,
818            num_key_value_heads: None,
819            license: None,
820        };
821        assert_eq!(model.params_b(), 7.0);
822    }
823
824    #[test]
825    fn test_params_b_from_string() {
826        let model = LlmModel {
827            name: "Test Model".to_string(),
828            provider: "Test".to_string(),
829            parameter_count: "13B".to_string(),
830            parameters_raw: None,
831            min_ram_gb: 8.0,
832            recommended_ram_gb: 16.0,
833            min_vram_gb: Some(8.0),
834            quantization: "Q4_K_M".to_string(),
835            context_length: 4096,
836            use_case: "General".to_string(),
837            is_moe: false,
838            num_experts: None,
839            active_experts: None,
840            active_parameters: None,
841            release_date: None,
842            gguf_sources: vec![],
843            capabilities: vec![],
844            format: ModelFormat::default(),
845            num_attention_heads: None,
846            num_key_value_heads: None,
847            license: None,
848        };
849        assert_eq!(model.params_b(), 13.0);
850    }
851
852    #[test]
853    fn test_params_b_from_millions() {
854        let model = LlmModel {
855            name: "Test Model".to_string(),
856            provider: "Test".to_string(),
857            parameter_count: "500M".to_string(),
858            parameters_raw: None,
859            min_ram_gb: 1.0,
860            recommended_ram_gb: 2.0,
861            min_vram_gb: Some(1.0),
862            quantization: "Q4_K_M".to_string(),
863            context_length: 2048,
864            use_case: "General".to_string(),
865            is_moe: false,
866            num_experts: None,
867            active_experts: None,
868            active_parameters: None,
869            release_date: None,
870            gguf_sources: vec![],
871            capabilities: vec![],
872            format: ModelFormat::default(),
873            num_attention_heads: None,
874            num_key_value_heads: None,
875            license: None,
876        };
877        assert_eq!(model.params_b(), 0.5);
878    }
879
880    #[test]
881    fn test_estimate_memory_gb() {
882        let model = LlmModel {
883            name: "Test Model".to_string(),
884            provider: "Test".to_string(),
885            parameter_count: "7B".to_string(),
886            parameters_raw: Some(7_000_000_000),
887            min_ram_gb: 4.0,
888            recommended_ram_gb: 8.0,
889            min_vram_gb: Some(4.0),
890            quantization: "Q4_K_M".to_string(),
891            context_length: 4096,
892            use_case: "General".to_string(),
893            is_moe: false,
894            num_experts: None,
895            active_experts: None,
896            active_parameters: None,
897            release_date: None,
898            gguf_sources: vec![],
899            capabilities: vec![],
900            format: ModelFormat::default(),
901            num_attention_heads: None,
902            num_key_value_heads: None,
903            license: None,
904        };
905
906        let mem = model.estimate_memory_gb("Q4_K_M", 4096);
907        // 7B params * 0.58 bytes = 4.06 GB + KV cache + overhead
908        assert!(mem > 4.0);
909        assert!(mem < 6.0);
910
911        // Q8_0 should require more memory
912        let mem_q8 = model.estimate_memory_gb("Q8_0", 4096);
913        assert!(mem_q8 > mem);
914    }
915
916    #[test]
917    fn test_best_quant_for_budget() {
918        let model = LlmModel {
919            name: "Test Model".to_string(),
920            provider: "Test".to_string(),
921            parameter_count: "7B".to_string(),
922            parameters_raw: Some(7_000_000_000),
923            min_ram_gb: 4.0,
924            recommended_ram_gb: 8.0,
925            min_vram_gb: Some(4.0),
926            quantization: "Q4_K_M".to_string(),
927            context_length: 4096,
928            use_case: "General".to_string(),
929            is_moe: false,
930            num_experts: None,
931            active_experts: None,
932            active_parameters: None,
933            release_date: None,
934            gguf_sources: vec![],
935            capabilities: vec![],
936            format: ModelFormat::default(),
937            num_attention_heads: None,
938            num_key_value_heads: None,
939            license: None,
940        };
941
942        // Large budget should return best quant
943        let result = model.best_quant_for_budget(10.0, 4096);
944        assert!(result.is_some());
945        let (quant, _) = result.unwrap();
946        assert_eq!(quant, "Q8_0");
947
948        // Medium budget should find acceptable quant
949        let result = model.best_quant_for_budget(5.0, 4096);
950        assert!(result.is_some());
951
952        // Tiny budget should return None
953        let result = model.best_quant_for_budget(1.0, 4096);
954        assert!(result.is_none());
955    }
956
957    #[test]
958    fn test_moe_active_vram_gb() {
959        // Dense model should return None
960        let dense_model = LlmModel {
961            name: "Dense Model".to_string(),
962            provider: "Test".to_string(),
963            parameter_count: "7B".to_string(),
964            parameters_raw: Some(7_000_000_000),
965            min_ram_gb: 4.0,
966            recommended_ram_gb: 8.0,
967            min_vram_gb: Some(4.0),
968            quantization: "Q4_K_M".to_string(),
969            context_length: 4096,
970            use_case: "General".to_string(),
971            is_moe: false,
972            num_experts: None,
973            active_experts: None,
974            active_parameters: None,
975            release_date: None,
976            gguf_sources: vec![],
977            capabilities: vec![],
978            format: ModelFormat::default(),
979            num_attention_heads: None,
980            num_key_value_heads: None,
981            license: None,
982        };
983        assert!(dense_model.moe_active_vram_gb().is_none());
984
985        // MoE model should calculate active VRAM
986        let moe_model = LlmModel {
987            name: "MoE Model".to_string(),
988            provider: "Test".to_string(),
989            parameter_count: "8x7B".to_string(),
990            parameters_raw: Some(46_700_000_000),
991            min_ram_gb: 25.0,
992            recommended_ram_gb: 50.0,
993            min_vram_gb: Some(25.0),
994            quantization: "Q4_K_M".to_string(),
995            context_length: 32768,
996            use_case: "General".to_string(),
997            is_moe: true,
998            num_experts: Some(8),
999            active_experts: Some(2),
1000            active_parameters: Some(12_900_000_000),
1001            release_date: None,
1002            gguf_sources: vec![],
1003            capabilities: vec![],
1004            format: ModelFormat::default(),
1005            num_attention_heads: None,
1006            num_key_value_heads: None,
1007            license: None,
1008        };
1009        let vram = moe_model.moe_active_vram_gb();
1010        assert!(vram.is_some());
1011        let vram_val = vram.unwrap();
1012        // Should be significantly less than full model
1013        assert!(vram_val > 0.0);
1014        assert!(vram_val < 15.0);
1015    }
1016
1017    #[test]
1018    fn test_moe_offloaded_ram_gb() {
1019        // Dense model should return None
1020        let dense_model = LlmModel {
1021            name: "Dense Model".to_string(),
1022            provider: "Test".to_string(),
1023            parameter_count: "7B".to_string(),
1024            parameters_raw: Some(7_000_000_000),
1025            min_ram_gb: 4.0,
1026            recommended_ram_gb: 8.0,
1027            min_vram_gb: Some(4.0),
1028            quantization: "Q4_K_M".to_string(),
1029            context_length: 4096,
1030            use_case: "General".to_string(),
1031            is_moe: false,
1032            num_experts: None,
1033            active_experts: None,
1034            active_parameters: None,
1035            release_date: None,
1036            gguf_sources: vec![],
1037            capabilities: vec![],
1038            format: ModelFormat::default(),
1039            num_attention_heads: None,
1040            num_key_value_heads: None,
1041            license: None,
1042        };
1043        assert!(dense_model.moe_offloaded_ram_gb().is_none());
1044
1045        // MoE model should calculate offloaded RAM
1046        let moe_model = LlmModel {
1047            name: "MoE Model".to_string(),
1048            provider: "Test".to_string(),
1049            parameter_count: "8x7B".to_string(),
1050            parameters_raw: Some(46_700_000_000),
1051            min_ram_gb: 25.0,
1052            recommended_ram_gb: 50.0,
1053            min_vram_gb: Some(25.0),
1054            quantization: "Q4_K_M".to_string(),
1055            context_length: 32768,
1056            use_case: "General".to_string(),
1057            is_moe: true,
1058            num_experts: Some(8),
1059            active_experts: Some(2),
1060            active_parameters: Some(12_900_000_000),
1061            release_date: None,
1062            gguf_sources: vec![],
1063            capabilities: vec![],
1064            format: ModelFormat::default(),
1065            num_attention_heads: None,
1066            num_key_value_heads: None,
1067            license: None,
1068        };
1069        let offloaded = moe_model.moe_offloaded_ram_gb();
1070        assert!(offloaded.is_some());
1071        let offloaded_val = offloaded.unwrap();
1072        // Should be substantial
1073        assert!(offloaded_val > 10.0);
1074    }
1075
1076    // ────────────────────────────────────────────────────────────────────
1077    // UseCase tests
1078    // ────────────────────────────────────────────────────────────────────
1079
1080    #[test]
1081    fn test_use_case_from_model_coding() {
1082        let model = LlmModel {
1083            name: "codellama-7b".to_string(),
1084            provider: "Meta".to_string(),
1085            parameter_count: "7B".to_string(),
1086            parameters_raw: Some(7_000_000_000),
1087            min_ram_gb: 4.0,
1088            recommended_ram_gb: 8.0,
1089            min_vram_gb: Some(4.0),
1090            quantization: "Q4_K_M".to_string(),
1091            context_length: 4096,
1092            use_case: "Coding".to_string(),
1093            is_moe: false,
1094            num_experts: None,
1095            active_experts: None,
1096            active_parameters: None,
1097            release_date: None,
1098            gguf_sources: vec![],
1099            capabilities: vec![],
1100            format: ModelFormat::default(),
1101            num_attention_heads: None,
1102            num_key_value_heads: None,
1103            license: None,
1104        };
1105        assert_eq!(UseCase::from_model(&model), UseCase::Coding);
1106    }
1107
1108    #[test]
1109    fn test_use_case_from_model_embedding() {
1110        let model = LlmModel {
1111            name: "bge-large".to_string(),
1112            provider: "BAAI".to_string(),
1113            parameter_count: "335M".to_string(),
1114            parameters_raw: Some(335_000_000),
1115            min_ram_gb: 1.0,
1116            recommended_ram_gb: 2.0,
1117            min_vram_gb: Some(1.0),
1118            quantization: "F16".to_string(),
1119            context_length: 512,
1120            use_case: "Embedding".to_string(),
1121            is_moe: false,
1122            num_experts: None,
1123            active_experts: None,
1124            active_parameters: None,
1125            release_date: None,
1126            gguf_sources: vec![],
1127            capabilities: vec![],
1128            format: ModelFormat::default(),
1129            num_attention_heads: None,
1130            num_key_value_heads: None,
1131            license: None,
1132        };
1133        assert_eq!(UseCase::from_model(&model), UseCase::Embedding);
1134    }
1135
1136    #[test]
1137    fn test_use_case_from_model_reasoning() {
1138        let model = LlmModel {
1139            name: "deepseek-r1-7b".to_string(),
1140            provider: "DeepSeek".to_string(),
1141            parameter_count: "7B".to_string(),
1142            parameters_raw: Some(7_000_000_000),
1143            min_ram_gb: 4.0,
1144            recommended_ram_gb: 8.0,
1145            min_vram_gb: Some(4.0),
1146            quantization: "Q4_K_M".to_string(),
1147            context_length: 8192,
1148            use_case: "Reasoning".to_string(),
1149            is_moe: false,
1150            num_experts: None,
1151            active_experts: None,
1152            active_parameters: None,
1153            release_date: None,
1154            gguf_sources: vec![],
1155            capabilities: vec![],
1156            format: ModelFormat::default(),
1157            num_attention_heads: None,
1158            num_key_value_heads: None,
1159            license: None,
1160        };
1161        assert_eq!(UseCase::from_model(&model), UseCase::Reasoning);
1162    }
1163
1164    // ────────────────────────────────────────────────────────────────────
1165    // ModelDatabase tests
1166    // ────────────────────────────────────────────────────────────────────
1167
1168    #[test]
1169    fn test_model_database_new() {
1170        let db = ModelDatabase::new();
1171        let models = db.get_all_models();
1172        // Should have loaded models from embedded JSON
1173        assert!(!models.is_empty());
1174    }
1175
1176    #[test]
1177    fn test_find_model() {
1178        let db = ModelDatabase::new();
1179
1180        // Search by name substring (case insensitive)
1181        let results = db.find_model("llama");
1182        assert!(!results.is_empty());
1183        assert!(
1184            results
1185                .iter()
1186                .any(|m| m.name.to_lowercase().contains("llama"))
1187        );
1188
1189        // Search should be case insensitive
1190        let results_upper = db.find_model("LLAMA");
1191        assert_eq!(results.len(), results_upper.len());
1192    }
1193
1194    #[test]
1195    fn test_models_fitting_system() {
1196        let db = ModelDatabase::new();
1197
1198        // Large system should fit many models
1199        let fitting = db.models_fitting_system(32.0, true, Some(24.0));
1200        assert!(!fitting.is_empty());
1201
1202        // Very small system should fit fewer or no models
1203        let fitting_small = db.models_fitting_system(2.0, false, None);
1204        assert!(fitting_small.len() < fitting.len());
1205
1206        // All fitting models should meet RAM requirements
1207        for model in fitting_small {
1208            assert!(model.min_ram_gb <= 2.0);
1209        }
1210    }
1211
1212    // ────────────────────────────────────────────────────────────────────
1213    // Capability tests
1214    // ────────────────────────────────────────────────────────────────────
1215
1216    #[test]
1217    fn test_capability_infer_vision() {
1218        let model = LlmModel {
1219            name: "meta-llama/Llama-3.2-11B-Vision-Instruct".to_string(),
1220            provider: "Meta".to_string(),
1221            parameter_count: "11B".to_string(),
1222            parameters_raw: Some(11_000_000_000),
1223            min_ram_gb: 6.0,
1224            recommended_ram_gb: 10.0,
1225            min_vram_gb: Some(6.0),
1226            quantization: "Q4_K_M".to_string(),
1227            context_length: 131072,
1228            use_case: "Multimodal, vision and text".to_string(),
1229            is_moe: false,
1230            num_experts: None,
1231            active_experts: None,
1232            active_parameters: None,
1233            release_date: None,
1234            gguf_sources: vec![],
1235            capabilities: vec![],
1236            format: ModelFormat::default(),
1237            num_attention_heads: None,
1238            num_key_value_heads: None,
1239            license: None,
1240        };
1241        let caps = Capability::infer(&model);
1242        assert!(caps.contains(&Capability::Vision));
1243        // Also gets ToolUse because "llama-3" + "instruct"
1244        assert!(caps.contains(&Capability::ToolUse));
1245    }
1246
1247    #[test]
1248    fn test_capability_infer_tool_use() {
1249        let model = LlmModel {
1250            name: "Qwen/Qwen3-8B".to_string(),
1251            provider: "Qwen".to_string(),
1252            parameter_count: "8B".to_string(),
1253            parameters_raw: Some(8_000_000_000),
1254            min_ram_gb: 4.5,
1255            recommended_ram_gb: 8.0,
1256            min_vram_gb: Some(4.0),
1257            quantization: "Q4_K_M".to_string(),
1258            context_length: 32768,
1259            use_case: "General purpose text generation".to_string(),
1260            is_moe: false,
1261            num_experts: None,
1262            active_experts: None,
1263            active_parameters: None,
1264            release_date: None,
1265            gguf_sources: vec![],
1266            capabilities: vec![],
1267            format: ModelFormat::default(),
1268            num_attention_heads: None,
1269            num_key_value_heads: None,
1270            license: None,
1271        };
1272        let caps = Capability::infer(&model);
1273        assert!(caps.contains(&Capability::ToolUse));
1274        assert!(!caps.contains(&Capability::Vision));
1275    }
1276
1277    #[test]
1278    fn test_capability_infer_none() {
1279        let model = LlmModel {
1280            name: "BAAI/bge-large-en-v1.5".to_string(),
1281            provider: "BAAI".to_string(),
1282            parameter_count: "335M".to_string(),
1283            parameters_raw: Some(335_000_000),
1284            min_ram_gb: 1.0,
1285            recommended_ram_gb: 2.0,
1286            min_vram_gb: Some(1.0),
1287            quantization: "F16".to_string(),
1288            context_length: 512,
1289            use_case: "Text embeddings for RAG".to_string(),
1290            is_moe: false,
1291            num_experts: None,
1292            active_experts: None,
1293            active_parameters: None,
1294            release_date: None,
1295            gguf_sources: vec![],
1296            capabilities: vec![],
1297            format: ModelFormat::default(),
1298            num_attention_heads: None,
1299            num_key_value_heads: None,
1300            license: None,
1301        };
1302        let caps = Capability::infer(&model);
1303        assert!(caps.is_empty());
1304    }
1305
1306    #[test]
1307    fn test_capability_preserves_explicit() {
1308        let model = LlmModel {
1309            name: "some-model".to_string(),
1310            provider: "Test".to_string(),
1311            parameter_count: "7B".to_string(),
1312            parameters_raw: Some(7_000_000_000),
1313            min_ram_gb: 4.0,
1314            recommended_ram_gb: 8.0,
1315            min_vram_gb: Some(4.0),
1316            quantization: "Q4_K_M".to_string(),
1317            context_length: 4096,
1318            use_case: "General".to_string(),
1319            is_moe: false,
1320            num_experts: None,
1321            active_experts: None,
1322            active_parameters: None,
1323            release_date: None,
1324            gguf_sources: vec![],
1325            capabilities: vec![Capability::Vision],
1326            format: ModelFormat::default(),
1327            num_attention_heads: None,
1328            num_key_value_heads: None,
1329            license: None,
1330        };
1331        let caps = Capability::infer(&model);
1332        // Should keep the explicit Vision and not duplicate it
1333        assert_eq!(caps.iter().filter(|c| **c == Capability::Vision).count(), 1);
1334    }
1335
1336    #[test]
1337    fn test_awq_gptq_quant_values() {
1338        // AWQ
1339        assert_eq!(quant_bpp("AWQ-4bit"), 0.5);
1340        assert_eq!(quant_bpp("AWQ-8bit"), 1.0);
1341        assert_eq!(quant_speed_multiplier("AWQ-4bit"), 1.2);
1342        assert_eq!(quant_speed_multiplier("AWQ-8bit"), 0.85);
1343        assert_eq!(quant_quality_penalty("AWQ-4bit"), -3.0);
1344        assert_eq!(quant_quality_penalty("AWQ-8bit"), 0.0);
1345        // GPTQ
1346        assert_eq!(quant_bpp("GPTQ-Int4"), 0.5);
1347        assert_eq!(quant_bpp("GPTQ-Int8"), 1.0);
1348        assert_eq!(quant_speed_multiplier("GPTQ-Int4"), 1.2);
1349        assert_eq!(quant_speed_multiplier("GPTQ-Int8"), 0.85);
1350        assert_eq!(quant_quality_penalty("GPTQ-Int4"), -3.0);
1351        assert_eq!(quant_quality_penalty("GPTQ-Int8"), 0.0);
1352    }
1353
1354    #[test]
1355    fn test_model_format_prequantized() {
1356        assert!(ModelFormat::Awq.is_prequantized());
1357        assert!(ModelFormat::Gptq.is_prequantized());
1358        assert!(!ModelFormat::Gguf.is_prequantized());
1359        assert!(!ModelFormat::Mlx.is_prequantized());
1360        assert!(!ModelFormat::Safetensors.is_prequantized());
1361    }
1362
1363    // ────────────────────────────────────────────────────────────────────
1364    // GGUF source catalog tests
1365    // ────────────────────────────────────────────────────────────────────
1366
1367    #[test]
1368    fn test_gguf_source_deserialization() {
1369        let json = r#"{"repo": "unsloth/Llama-3.1-8B-Instruct-GGUF", "provider": "unsloth"}"#;
1370        let source: GgufSource = serde_json::from_str(json).unwrap();
1371        assert_eq!(source.repo, "unsloth/Llama-3.1-8B-Instruct-GGUF");
1372        assert_eq!(source.provider, "unsloth");
1373    }
1374
1375    #[test]
1376    fn test_gguf_sources_default_to_empty() {
1377        let json = r#"{
1378            "name": "test/model",
1379            "provider": "Test",
1380            "parameter_count": "7B",
1381            "parameters_raw": 7000000000,
1382            "min_ram_gb": 4.0,
1383            "recommended_ram_gb": 8.0,
1384            "quantization": "Q4_K_M",
1385            "context_length": 4096,
1386            "use_case": "General"
1387        }"#;
1388        let entry: HfModelEntry = serde_json::from_str(json).unwrap();
1389        assert!(entry.gguf_sources.is_empty());
1390    }
1391
1392    #[test]
1393    fn test_catalog_popular_models_have_gguf_sources() {
1394        let db = ModelDatabase::new();
1395        // These popular models should have gguf_sources populated in the catalog
1396        let expected_with_gguf = [
1397            "meta-llama/Llama-3.3-70B-Instruct",
1398            "Qwen/Qwen2.5-7B-Instruct",
1399            "Qwen/Qwen2.5-Coder-7B-Instruct",
1400            "meta-llama/Meta-Llama-3-8B-Instruct",
1401            "mistralai/Mistral-7B-Instruct-v0.3",
1402        ];
1403        for name in &expected_with_gguf {
1404            let model = db.get_all_models().iter().find(|m| m.name == *name);
1405            assert!(model.is_some(), "Model {} should exist in catalog", name);
1406            let model = model.unwrap();
1407            assert!(
1408                !model.gguf_sources.is_empty(),
1409                "Model {} should have gguf_sources but has none",
1410                name
1411            );
1412        }
1413    }
1414
1415    #[test]
1416    fn test_catalog_gguf_sources_have_valid_repos() {
1417        let db = ModelDatabase::new();
1418        for model in db.get_all_models() {
1419            for source in &model.gguf_sources {
1420                assert!(
1421                    source.repo.contains('/'),
1422                    "GGUF source repo '{}' for model '{}' should be owner/repo format",
1423                    source.repo,
1424                    model.name
1425                );
1426                assert!(
1427                    !source.provider.is_empty(),
1428                    "GGUF source provider for model '{}' should not be empty",
1429                    model.name
1430                );
1431                assert!(
1432                    source.repo.to_uppercase().contains("GGUF"),
1433                    "GGUF source repo '{}' for model '{}' should contain 'GGUF'",
1434                    source.repo,
1435                    model.name
1436                );
1437            }
1438        }
1439    }
1440
1441    #[test]
1442    #[ignore] // Requires network access to populate GGUF sources at build time
1443    fn test_catalog_has_significant_gguf_coverage() {
1444        let db = ModelDatabase::new();
1445        let total = db.get_all_models().len();
1446        let with_gguf = db
1447            .get_all_models()
1448            .iter()
1449            .filter(|m| !m.gguf_sources.is_empty())
1450            .count();
1451        // We should have at least 25% coverage after enrichment
1452        let coverage_pct = (with_gguf as f64 / total as f64) * 100.0;
1453        assert!(
1454            coverage_pct >= 25.0,
1455            "GGUF source coverage is only {:.1}% ({}/{}), expected at least 25%",
1456            coverage_pct,
1457            with_gguf,
1458            total
1459        );
1460    }
1461
1462    // ────────────────────────────────────────────────────────────────────
1463    // Tensor parallelism tests
1464    // ────────────────────────────────────────────────────────────────────
1465
1466    fn tp_test_model(
1467        name: &str,
1468        params_b: f64,
1469        attn_heads: Option<u32>,
1470        kv_heads: Option<u32>,
1471    ) -> LlmModel {
1472        LlmModel {
1473            name: name.to_string(),
1474            provider: "Test".to_string(),
1475            parameter_count: format!("{:.0}B", params_b),
1476            parameters_raw: Some((params_b * 1_000_000_000.0) as u64),
1477            min_ram_gb: 4.0,
1478            recommended_ram_gb: 8.0,
1479            min_vram_gb: Some(4.0),
1480            quantization: "Q4_K_M".to_string(),
1481            context_length: 4096,
1482            use_case: "General".to_string(),
1483            is_moe: false,
1484            num_experts: None,
1485            active_experts: None,
1486            active_parameters: None,
1487            release_date: None,
1488            gguf_sources: vec![],
1489            capabilities: vec![],
1490            format: ModelFormat::default(),
1491            num_attention_heads: attn_heads,
1492            num_key_value_heads: kv_heads,
1493            license: None,
1494        }
1495    }
1496
1497    #[test]
1498    fn test_supports_tp_with_explicit_heads() {
1499        let model = tp_test_model("Test-8B", 8.0, Some(32), Some(8));
1500        assert!(model.supports_tp(1));
1501        assert!(model.supports_tp(2));
1502        assert!(model.supports_tp(4));
1503        assert!(model.supports_tp(8));
1504        assert!(!model.supports_tp(3)); // 32 % 3 != 0
1505        assert!(!model.supports_tp(5));
1506    }
1507
1508    #[test]
1509    fn test_supports_tp_always_true_for_1() {
1510        let model = tp_test_model("Tiny", 1.0, None, None);
1511        assert!(model.supports_tp(1));
1512    }
1513
1514    #[test]
1515    fn test_valid_tp_sizes_32_8() {
1516        let model = tp_test_model("Test", 8.0, Some(32), Some(8));
1517        let sizes = model.valid_tp_sizes();
1518        assert!(sizes.contains(&1));
1519        assert!(sizes.contains(&2));
1520        assert!(sizes.contains(&4));
1521        assert!(sizes.contains(&8));
1522        assert!(!sizes.contains(&3));
1523    }
1524
1525    #[test]
1526    fn test_valid_tp_sizes_48_heads() {
1527        // 48 attn heads, 8 kv heads — TP must divide both
1528        let model = tp_test_model("Llama-32B", 32.0, Some(48), Some(8));
1529        assert!(model.supports_tp(2)); // 48%2==0, 8%2==0
1530        assert!(!model.supports_tp(3)); // 48%3==0 but 8%3!=0
1531        assert!(model.supports_tp(4)); // 48%4==0, 8%4==0
1532        assert!(model.supports_tp(8)); // 48%8==0, 8%8==0
1533    }
1534
1535    #[test]
1536    fn test_infer_heads_from_name_qwen() {
1537        let (attn, kv) = infer_heads_from_name("Qwen2.5-72B-Instruct", 72.0);
1538        assert_eq!(attn, 64);
1539        assert_eq!(kv, 8);
1540    }
1541
1542    #[test]
1543    fn test_infer_heads_from_name_llama() {
1544        let (attn, kv) = infer_heads_from_name("Llama-3.1-8B", 8.0);
1545        assert_eq!(attn, 32);
1546        assert_eq!(kv, 8);
1547    }
1548
1549    #[test]
1550    fn test_infer_heads_from_name_deepseek() {
1551        let (attn, kv) = infer_heads_from_name("DeepSeek-V3", 671.0);
1552        assert_eq!(attn, 128);
1553        assert_eq!(kv, 16);
1554    }
1555
1556    #[test]
1557    fn test_supports_tp_with_inferred_heads() {
1558        // No explicit heads — should infer from name
1559        let model = tp_test_model("Llama-3.1-70B", 70.0, None, None);
1560        assert!(model.supports_tp(2));
1561        assert!(model.supports_tp(4));
1562        assert!(model.supports_tp(8));
1563    }
1564}