// llmfit_core/models.rs

use serde::{Deserialize, Serialize};

/// Quantization levels ordered from best quality to most compressed.
/// Used for dynamic quantization selection: try the best that fits.
pub const QUANT_HIERARCHY: &[&str] = &["Q8_0", "Q6_K", "Q5_K_M", "Q4_K_M", "Q3_K_M", "Q2_K"];

/// MLX-native quantization hierarchy (best quality to most compressed).
pub const MLX_QUANT_HIERARCHY: &[&str] = &["mlx-8bit", "mlx-4bit"];

/// Effective bytes per parameter for each quantization level, including
/// quantization metadata/scale overhead (which is why Q8_0 is 1.05 rather than 1.0).
pub fn quant_bpp(quant: &str) -> f64 {
    match quant {
        "F32" => 4.0,
        "F16" | "BF16" => 2.0,
        "Q8_0" => 1.05,
        "Q6_K" => 0.80,
        "Q5_K_M" => 0.68,
        "Q4_K_M" | "Q4_0" => 0.58,
        "Q3_K_M" => 0.48,
        "Q2_K" => 0.37,
        "UD-Q2_K_XL" | "UD-Q2_K_L" | "UD-Q2_K_M" | "UD-Q2_K_S" => 0.37,
        "UD-Q3_K_XL" | "UD-Q3_K_L" | "UD-Q3_K_M" | "UD-Q3_K_S" => 0.48,
        "UD-Q4_K_XL" | "UD-Q4_K_L" | "UD-Q4_K_M" | "UD-Q4_K_S" => 0.58,
        "UD-Q5_K_XL" | "UD-Q5_K_L" | "UD-Q5_K_M" | "UD-Q5_K_S" => 0.68,
        "UD-Q6_K_XL" | "UD-Q6_K_L" | "UD-Q6_K_M" | "UD-Q6_K_S" => 0.80,
        "UD-Q8_K_XL" | "UD-Q8_K_L" | "UD-Q8_K_M" | "UD-Q8_K_S" => 1.05,
        "mlx-4bit" => 0.55,
        "mlx-8bit" => 1.0,
        "AWQ-4bit" => 0.5,
        "AWQ-8bit" => 1.0,
        "GPTQ-Int4" => 0.5,
        "GPTQ-Int8" => 1.0,
        _ => 0.58,
    }
}
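
// Hedged, illustrative sanity check (added for clarity, not part of the original
// table): a hypothetical 7B-parameter model at Q4_K_M weighs roughly
// 7.0 * 0.58 ≈ 4.06 GB of weights, before KV cache and runtime overhead.
#[cfg(test)]
mod quant_bpp_example {
    use super::*;

    #[test]
    fn seven_b_at_q4_k_m_is_about_four_gb() {
        let weights_gb = 7.0 * quant_bpp("Q4_K_M");
        assert!((weights_gb - 4.06).abs() < 0.01);
    }
}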

/// Speed multiplier for quantization (lower quant = faster inference).
pub fn quant_speed_multiplier(quant: &str) -> f64 {
    match quant {
        "F16" | "BF16" => 0.6,
        "Q8_0" => 0.8,
        "Q6_K" => 0.95,
        "Q5_K_M" => 1.0,
        "Q4_K_M" | "Q4_0" => 1.15,
        "Q3_K_M" => 1.25,
        "Q2_K" => 1.35,
        "UD-Q2_K_XL" | "UD-Q2_K_L" | "UD-Q2_K_M" | "UD-Q2_K_S" => 1.35,
        "UD-Q3_K_XL" | "UD-Q3_K_L" | "UD-Q3_K_M" | "UD-Q3_K_S" => 1.25,
        "UD-Q4_K_XL" | "UD-Q4_K_L" | "UD-Q4_K_M" | "UD-Q4_K_S" => 1.15,
        "UD-Q5_K_XL" | "UD-Q5_K_L" | "UD-Q5_K_M" | "UD-Q5_K_S" => 1.0,
        "UD-Q6_K_XL" | "UD-Q6_K_L" | "UD-Q6_K_M" | "UD-Q6_K_S" => 0.95,
        "UD-Q8_K_XL" | "UD-Q8_K_L" | "UD-Q8_K_M" | "UD-Q8_K_S" => 0.8,
        "mlx-4bit" => 1.15,
        "mlx-8bit" => 0.85,
        "AWQ-4bit" | "GPTQ-Int4" => 1.2,
        "AWQ-8bit" | "GPTQ-Int8" => 0.85,
        _ => 1.0,
    }
}

/// Nominal bytes per parameter for a given quantization format (bit width / 8),
/// without the metadata overhead included in `quant_bpp`.
/// Used by the bandwidth-based tok/s estimator to compute model size in GB.
pub fn quant_bytes_per_param(quant: &str) -> f64 {
    match quant {
        "F16" | "BF16" => 2.0,
        "Q8_0" => 1.0,
        "Q6_K" => 0.75,
        "Q5_K_M" => 0.625,
        "Q4_K_M" | "Q4_0" => 0.5,
        "Q3_K_M" => 0.375,
        "Q2_K" => 0.25,
        "UD-Q2_K_XL" | "UD-Q2_K_L" | "UD-Q2_K_M" | "UD-Q2_K_S" => 0.25,
        "UD-Q3_K_XL" | "UD-Q3_K_L" | "UD-Q3_K_M" | "UD-Q3_K_S" => 0.375,
        "UD-Q4_K_XL" | "UD-Q4_K_L" | "UD-Q4_K_M" | "UD-Q4_K_S" => 0.5,
        "UD-Q5_K_XL" | "UD-Q5_K_L" | "UD-Q5_K_M" | "UD-Q5_K_S" => 0.625,
        "UD-Q6_K_XL" | "UD-Q6_K_L" | "UD-Q6_K_M" | "UD-Q6_K_S" => 0.75,
        "UD-Q8_K_XL" | "UD-Q8_K_L" | "UD-Q8_K_M" | "UD-Q8_K_S" => 1.0,
        "mlx-4bit" => 0.5,
        "mlx-8bit" => 1.0,
        "AWQ-4bit" | "GPTQ-Int4" => 0.5,
        "AWQ-8bit" | "GPTQ-Int8" => 1.0,
        _ => 0.5, // default to ~4-bit
    }
}

/// Quality penalty for quantization (lower quant = lower quality).
pub fn quant_quality_penalty(quant: &str) -> f64 {
    match quant {
        "F16" | "BF16" => 0.0,
        "Q8_0" => 0.0,
        "Q6_K" => -1.0,
        "Q5_K_M" => -2.0,
        "Q4_K_M" | "Q4_0" => -5.0,
        "Q3_K_M" => -8.0,
        "Q2_K" => -12.0,
        "UD-Q2_K_XL" | "UD-Q2_K_L" | "UD-Q2_K_M" | "UD-Q2_K_S" => -12.0,
        "UD-Q3_K_XL" | "UD-Q3_K_L" | "UD-Q3_K_M" | "UD-Q3_K_S" => -8.0,
        "UD-Q4_K_XL" | "UD-Q4_K_L" | "UD-Q4_K_M" | "UD-Q4_K_S" => -5.0,
        "UD-Q5_K_XL" | "UD-Q5_K_L" | "UD-Q5_K_M" | "UD-Q5_K_S" => -2.0,
        "UD-Q6_K_XL" | "UD-Q6_K_L" | "UD-Q6_K_M" | "UD-Q6_K_S" => -1.0,
        "UD-Q8_K_XL" | "UD-Q8_K_L" | "UD-Q8_K_M" | "UD-Q8_K_S" => 0.0,
        "mlx-4bit" => -4.0,
        "mlx-8bit" => 0.0,
        "AWQ-4bit" => -3.0,
        "AWQ-8bit" => 0.0,
        "GPTQ-Int4" => -3.0,
        "GPTQ-Int8" => 0.0,
        _ => -5.0,
    }
}

/// Model capability flags (orthogonal to UseCase).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Capability {
    Vision,
    ToolUse,
}

impl Capability {
    pub fn label(&self) -> &'static str {
        match self {
            Capability::Vision => "Vision",
            Capability::ToolUse => "Tool Use",
        }
    }

    pub fn all() -> &'static [Capability] {
        &[Capability::Vision, Capability::ToolUse]
    }

    /// Infer capabilities from model metadata when not explicitly set in JSON.
    pub fn infer(model: &LlmModel) -> Vec<Capability> {
        let mut caps = model.capabilities.clone();
        let name = model.name.to_lowercase();
        let use_case = model.use_case.to_lowercase();

        // Vision detection
        if !caps.contains(&Capability::Vision)
            && (name.contains("vision")
                || name.contains("-vl-")
                || name.ends_with("-vl")
                || name.contains("llava")
                || name.contains("onevision")
                || name.contains("pixtral")
                || use_case.contains("vision")
                || use_case.contains("multimodal"))
        {
            caps.push(Capability::Vision);
        }

        // Tool use detection (known model families)
        if !caps.contains(&Capability::ToolUse)
            && (use_case.contains("tool")
                || use_case.contains("function call")
                || name.contains("qwen3")
                || name.contains("qwen2.5")
                || name.contains("command-r")
                || (name.contains("llama-3") && name.contains("instruct"))
                || (name.contains("mistral") && name.contains("instruct"))
                || name.contains("hermes")
                || (name.contains("gemma-3") && name.ends_with("-it"))
                || (name.contains("gemma-4") && name.ends_with("-it")))
        {
            caps.push(Capability::ToolUse);
        }

        caps
    }
}
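
// Illustrative examples of the heuristics above (model names are hypothetical):
// an entry named "llava-1.6-mistral-7b" with no explicit capabilities picks up
// `Capability::Vision` via the "llava" check, and a "Qwen2.5-7B-Instruct" entry
// picks up `Capability::ToolUse` via the "qwen2.5" family check.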

/// Model weight format — determines which inference runtime to use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum ModelFormat {
    #[default]
    Gguf,
    Awq,
    Gptq,
    Mlx,
    Safetensors,
}

impl ModelFormat {
    /// Returns true for formats that are pre-quantized at a fixed bit width
    /// and cannot be dynamically re-quantized (AWQ, GPTQ).
    pub fn is_prequantized(&self) -> bool {
        matches!(self, ModelFormat::Awq | ModelFormat::Gptq)
    }
}

/// Use-case category for scoring weights.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
pub enum UseCase {
    General,
    Coding,
    Reasoning,
    Chat,
    Multimodal,
    Embedding,
}

impl UseCase {
    pub fn label(&self) -> &'static str {
        match self {
            UseCase::General => "General",
            UseCase::Coding => "Coding",
            UseCase::Reasoning => "Reasoning",
            UseCase::Chat => "Chat",
            UseCase::Multimodal => "Multimodal",
            UseCase::Embedding => "Embedding",
        }
    }

    /// Infer use-case from the model's use_case field and name.
    pub fn from_model(model: &LlmModel) -> Self {
        let name = model.name.to_lowercase();
        let use_case = model.use_case.to_lowercase();

        if use_case.contains("embedding") || name.contains("embed") || name.contains("bge") {
            UseCase::Embedding
        } else if name.contains("code") || use_case.contains("code") {
            UseCase::Coding
        } else if use_case.contains("vision") || use_case.contains("multimodal") {
            UseCase::Multimodal
        } else if use_case.contains("reason")
            || use_case.contains("chain-of-thought")
            || name.contains("deepseek-r1")
        {
            UseCase::Reasoning
        } else if use_case.contains("chat") || use_case.contains("instruction") {
            UseCase::Chat
        } else {
            UseCase::General
        }
    }
}
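
// Illustrative mapping under the rules above (names are hypothetical): a model
// named "DeepSeek-R1-Distill-Qwen-7B" resolves to `UseCase::Reasoning`, one whose
// use_case string contains "code" resolves to `UseCase::Coding`, and anything
// matching no branch falls back to `UseCase::General`.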

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmModel {
    pub name: String,
    pub provider: String,
    pub parameter_count: String,
    #[serde(default)]
    pub parameters_raw: Option<u64>,
    pub min_ram_gb: f64,
    pub recommended_ram_gb: f64,
    pub min_vram_gb: Option<f64>,
    pub quantization: String,
    pub context_length: u32,
    pub use_case: String,
    #[serde(default)]
    pub is_moe: bool,
    #[serde(default)]
    pub num_experts: Option<u32>,
    #[serde(default)]
    pub active_experts: Option<u32>,
    #[serde(default)]
    pub active_parameters: Option<u64>,
    #[serde(default)]
    pub release_date: Option<String>,
    /// Known GGUF download sources (e.g. unsloth, bartowski repos on HuggingFace)
    #[serde(default)]
    pub gguf_sources: Vec<GgufSource>,
    /// Model capabilities (vision, tool use, etc.)
    #[serde(default)]
    pub capabilities: Vec<Capability>,
    /// Model weight format (gguf, awq, gptq, mlx, safetensors)
    #[serde(default)]
    pub format: ModelFormat,
    /// Number of attention heads (for tensor-parallelism compatibility checks).
    #[serde(default)]
    pub num_attention_heads: Option<u32>,
    /// Number of key-value heads for GQA (defaults to num_attention_heads if None).
    #[serde(default)]
    pub num_key_value_heads: Option<u32>,
    /// Total number of transformer layers. Used by the precise KV cache formula.
    #[serde(default)]
    pub num_hidden_layers: Option<u32>,
    /// Per-head dimension. Used by the precise KV cache formula. When absent,
    /// it is derived as `hidden_size / num_attention_heads` if both are known,
    /// or falls back to a name-based heuristic.
    #[serde(default)]
    pub head_dim: Option<u32>,
    /// Attention layer composition for hybrid models (full attention + linear /
    /// Mamba style layers). When None, all layers are assumed to be full
    /// attention. Used by KV cache compression schemes (e.g. TurboQuant) that
    /// only apply to full attention layers.
    #[serde(default)]
    pub attention_layout: Option<AttentionLayout>,
    /// Model license (e.g. "apache-2.0", "mit", "llama3.1")
    #[serde(default)]
    pub license: Option<String>,
    /// Hidden dimension size (d_model). Used for MoE bandwidth decomposition.
    #[serde(default)]
    pub hidden_size: Option<u32>,
    /// Per-expert FFN intermediate size. Used for MoE bandwidth decomposition.
    #[serde(default)]
    pub moe_intermediate_size: Option<u32>,
    /// Vocabulary size. Used for lm_head + embedding bandwidth estimation.
    #[serde(default)]
    pub vocab_size: Option<u32>,
    /// Shared expert FFN intermediate size (0 if no shared experts).
    /// Present in Qwen1.5-MoE, DeepSeek-V2, Qwen3.5-MoE.
    #[serde(default)]
    pub shared_expert_intermediate_size: Option<u32>,
}

/// Composition of attention layers in a hybrid model.
///
/// Some recent architectures (Qwen3-Next, Jamba, Mamba style hybrids) mix
/// full attention layers with cheaper linear / state space layers. KV cache
/// compression schemes like TurboQuant only apply to the full attention
/// fraction, so we track the split here to compute honest savings.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub struct AttentionLayout {
    /// Number of full self attention layers (compressible).
    pub full: u32,
    /// Number of linear / state space layers (not compressible by KV quant).
    pub linear: u32,
}

impl AttentionLayout {
    pub fn total(&self) -> u32 {
        self.full + self.linear
    }

    /// Fraction of layers that are full attention (and therefore compressible
    /// by KV quant schemes). Returns 1.0 for an all-full model.
    pub fn compressible_fraction(&self) -> f64 {
        let total = self.total();
        if total == 0 {
            1.0
        } else {
            self.full as f64 / total as f64
        }
    }
}
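
// Hedged illustration (added for clarity): a hybrid layout with 10 full-attention
// layers out of 40 is 25% compressible, and an empty layout defaults to fully
// compressible, matching the doc comment above.
#[cfg(test)]
mod attention_layout_example {
    use super::*;

    #[test]
    fn compressible_fraction_examples() {
        let hybrid = AttentionLayout { full: 10, linear: 30 };
        assert!((hybrid.compressible_fraction() - 0.25).abs() < f64::EPSILON);
        let empty = AttentionLayout { full: 0, linear: 0 };
        assert_eq!(empty.compressible_fraction(), 1.0);
    }
}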

/// KV cache element representation. Controls bytes per element for the
/// precise KV cache formula and (for TurboQuant) gates on runtime support.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum KvQuant {
    /// fp16 / bf16, the inference default for most runtimes.
    #[default]
    #[serde(rename = "fp16")]
    Fp16,
    /// fp8 KV cache (vLLM, llama.cpp via --cache-type-k fp8 on supported builds).
    #[serde(rename = "fp8")]
    Fp8,
    /// 8 bit integer KV cache (llama.cpp `q8_0`, vLLM int8).
    #[serde(rename = "q8_0")]
    Q8_0,
    /// 4 bit integer KV cache (llama.cpp `q4_0`, vLLM int4).
    #[serde(rename = "q4_0")]
    Q4_0,
    /// TurboQuant (3 bit keys + 2 bit values + Pi/S overhead). Research
    /// integration, vLLM + CUDA only, not in upstream vLLM yet. Compression
    /// only applies to full attention layers, so hybrid models see less.
    /// See https://github.com/0xSero/turboquant
    #[serde(rename = "tq")]
    TurboQuant,
}

impl KvQuant {
    pub fn label(&self) -> &'static str {
        match self {
            KvQuant::Fp16 => "fp16",
            KvQuant::Fp8 => "fp8",
            KvQuant::Q8_0 => "q8_0",
            KvQuant::Q4_0 => "q4_0",
            KvQuant::TurboQuant => "tq",
        }
    }

    /// Bytes per KV element for non-TurboQuant variants. TurboQuant is handled
    /// per layer because it only affects the full attention slice.
    pub fn bytes_per_element(&self) -> f64 {
        match self {
            KvQuant::Fp16 => 2.0,
            KvQuant::Fp8 => 1.0,
            KvQuant::Q8_0 => 1.0,
            KvQuant::Q4_0 => 0.5,
            // For the bookkeeping path that doesn't know about layout, assume
            // ~2.7 bits per element on the compressible slice. The real
            // computation in `kv_cache_gb` handles the layout split.
            KvQuant::TurboQuant => 0.34,
        }
    }

    pub fn parse(s: &str) -> Option<Self> {
        match s.trim().to_lowercase().as_str() {
            "fp16" | "f16" | "bf16" | "default" => Some(KvQuant::Fp16),
            "fp8" | "f8" => Some(KvQuant::Fp8),
            "q8" | "q8_0" | "int8" => Some(KvQuant::Q8_0),
            "q4" | "q4_0" | "int4" => Some(KvQuant::Q4_0),
            "tq" | "turboquant" => Some(KvQuant::TurboQuant),
            _ => None,
        }
    }

    /// All KV quant options llmfit knows how to estimate. Order is best
    /// quality (fp16) to most compressed.
    pub fn all() -> &'static [KvQuant] {
        &[
            KvQuant::Fp16,
            KvQuant::Fp8,
            KvQuant::Q8_0,
            KvQuant::Q4_0,
            KvQuant::TurboQuant,
        ]
    }
}

impl std::fmt::Display for KvQuant {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.label())
    }
}
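
// Hedged illustration of the parser and size table above: aliases map onto the
// same variant, unknown strings are rejected, and bytes-per-element shrinks
// monotonically from fp16 down to the 4-bit cache.
#[cfg(test)]
mod kv_quant_example {
    use super::*;

    #[test]
    fn parse_aliases_and_sizes() {
        assert_eq!(KvQuant::parse("bf16"), Some(KvQuant::Fp16));
        assert_eq!(KvQuant::parse("int8"), Some(KvQuant::Q8_0));
        assert_eq!(KvQuant::parse("nvfp4"), None);
        assert!(KvQuant::Q4_0.bytes_per_element() < KvQuant::Fp8.bytes_per_element());
        assert!(KvQuant::Fp8.bytes_per_element() < KvQuant::Fp16.bytes_per_element());
    }
}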

/// Returns true if a model's license matches any in the comma-separated filter string.
/// Models without a license never match.
pub fn matches_license_filter(license: &Option<String>, filter: &str) -> bool {
    let allowed: Vec<String> = filter.split(',').map(|s| s.trim().to_lowercase()).collect();
    license
        .as_ref()
        .map(|l| allowed.contains(&l.to_lowercase()))
        .unwrap_or(false)
}
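
// Hedged illustration of the filter semantics above: matching is case-insensitive
// across a comma-separated allow-list, and models with no license are excluded.
#[cfg(test)]
mod license_filter_example {
    use super::*;

    #[test]
    fn license_filter_examples() {
        let apache = Some("Apache-2.0".to_string());
        assert!(matches_license_filter(&apache, "mit, apache-2.0"));
        assert!(!matches_license_filter(&apache, "mit"));
        assert!(!matches_license_filter(&None, "mit, apache-2.0"));
    }
}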

/// A known GGUF download source for a model on HuggingFace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GgufSource {
    /// HuggingFace repo ID (e.g. "unsloth/Llama-3.1-8B-Instruct-GGUF")
    pub repo: String,
    /// Provider who published the GGUF (e.g. "unsloth", "bartowski")
    pub provider: String,
}

impl LlmModel {
    /// MLX models are Apple-only; they won't run on NVIDIA/AMD/Intel hardware.
    /// We detect them by the `-MLX` marker that's standard in HuggingFace repo
    /// names (e.g. `Qwen3-8B-MLX-4bit`, `LFM2-1.2B-MLX-8bit`).
    pub fn is_mlx_model(&self) -> bool {
        let name_lower = self.name.to_lowercase();
        name_lower.contains("-mlx-") || name_lower.ends_with("-mlx")
    }

    /// Returns true if this model uses a pre-quantized format (AWQ/GPTQ)
    /// that cannot be dynamically re-quantized.
    pub fn is_prequantized(&self) -> bool {
        self.format.is_prequantized()
    }

    /// Returns true if the model's attention/KV heads are evenly divisible
    /// by `tp_size`, meaning it can be split across that many devices.
    /// TP=1 always returns true.
    pub fn supports_tp(&self, tp_size: u32) -> bool {
        if tp_size <= 1 {
            return true;
        }
        let (attn, kv) = self.infer_head_counts();
        attn % tp_size == 0 && kv % tp_size == 0
    }

    /// Returns all valid TP degrees in [1..=8] for this model.
    pub fn valid_tp_sizes(&self) -> Vec<u32> {
        (1..=8).filter(|&tp| self.supports_tp(tp)).collect()
    }
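
    // Illustrative example (hypothetical head counts): a model with 40 attention
    // heads and 8 KV heads divides evenly for TP = 1, 2, 4, and 8, but not for
    // TP = 3 (40 % 3 != 0) or TP = 5 (8 % 5 != 0), so valid_tp_sizes() would
    // return [1, 2, 4, 8].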

    /// Infer attention and KV head counts from metadata or model name heuristics.
    fn infer_head_counts(&self) -> (u32, u32) {
        if let (Some(attn), Some(kv)) = (self.num_attention_heads, self.num_key_value_heads) {
            return (attn, kv);
        }
        if let Some(attn) = self.num_attention_heads {
            return (attn, attn);
        }
        // Heuristic: infer from model name
        infer_heads_from_name(&self.name, self.params_b())
    }

    /// Bytes-per-parameter for the model's quantization level.
    fn quant_bpp(&self) -> f64 {
        quant_bpp(&self.quantization)
    }

    /// Parameter count in billions, extracted from parameters_raw or parameter_count.
    pub fn params_b(&self) -> f64 {
        if let Some(raw) = self.parameters_raw {
            raw as f64 / 1_000_000_000.0
        } else {
            // Parse from string like "7B", "1.1B", "137M"
            let s = self.parameter_count.trim().to_uppercase();
            if let Some(num_str) = s.strip_suffix('B') {
                num_str.parse::<f64>().unwrap_or(7.0)
            } else if let Some(num_str) = s.strip_suffix('M') {
                num_str.parse::<f64>().unwrap_or(0.0) / 1000.0
            } else {
                7.0
            }
        }
    }

    /// Approximate on-disk size (GB) for a given quantization level.
    /// This is just the model weights: params_b * bytes_per_param.
    pub fn estimate_disk_gb(&self, quant: &str) -> f64 {
        self.params_b() * quant_bpp(quant)
    }

    /// Effective bytes-per-param for the compute-bound fixed component of MoE
    /// per-token bandwidth. Captures the ratio of compute time to weight-read
    /// time for attention-sized matrix operations.
    /// Calibrated to K=3.2 from RX 6900 XT benchmarks across Q2_K, Q4_K_M, Q8_0.
    pub const MOE_FIXED_EFFECTIVE_BPP: f64 = 3.2;

    /// Decompose MoE per-token bandwidth into scalable (FFN) and fixed components.
    ///
    /// Returns (active_ffn_params_billions, fixed_params_billions) or None if
    /// insufficient architecture metadata is available.
    ///
    /// The fixed component includes: attention layers (Q,K,V,O), MoE router,
    /// shared experts (if any), output head (lm_head), and embedding table.
    /// These are compute-bound and don't scale with quantization, so we use
    /// MOE_FIXED_EFFECTIVE_BPP to convert them to bandwidth-equivalent bytes.
    pub fn moe_bandwidth_decomposition(&self) -> Option<(f64, f64)> {
        if !self.is_moe {
            return None;
        }

        let hidden = self.hidden_size? as f64;
        let layers = self.num_hidden_layers? as f64;
        let active_exp = self.active_experts? as f64;
        let expert_inter = self.moe_intermediate_size? as f64;
        let vocab = self.vocab_size? as f64;
        let n_experts = self.num_experts.unwrap_or(8) as f64;

        // Head dimensions: prefer explicit head_dim, derive from hidden/heads
        let n_heads = self.num_attention_heads.unwrap_or(1) as f64;
        let n_kv = self
            .num_key_value_heads
            .unwrap_or(self.num_attention_heads.unwrap_or(1)) as f64;
        let hd = self
            .head_dim
            .map(|h| h as f64)
            .unwrap_or_else(|| hidden / n_heads);

        // Active routed expert FFN params (SwiGLU: 3 projections per expert)
        let active_ffn = layers * active_exp * 3.0 * hidden * expert_inter;

        // Attention params per layer: Q + K + V + O
        let attn_per_layer = 2.0 * n_heads * hd * hidden + 2.0 * n_kv * hd * hidden;
        let attn_total = layers * attn_per_layer;

        // Shared expert FFN (Qwen1.5-MoE, DeepSeek-V2, Qwen3.5)
        let shared_inter = self.shared_expert_intermediate_size.unwrap_or(0) as f64;
        let shared_ffn = layers * 3.0 * hidden * shared_inter;

        // Router: one gate projection per layer
        let router = layers * n_experts * hidden;

        // Output head + embedding (both are hidden × vocab)
        let lm_head = vocab * hidden;
        let embedding = vocab * hidden;

        let fixed = attn_total + shared_ffn + router + lm_head + embedding;

        Some((active_ffn / 1_000_000_000.0, fixed / 1_000_000_000.0))
    }
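
    // Hedged worked example with illustrative (not catalog) numbers: for a MoE
    // with hidden_size = 2048, 48 layers, 8 active experts, and a per-expert FFN
    // intermediate size of 768, the scalable term is
    // 48 * 8 * 3 * 2048 * 768 ≈ 1.81B active FFN parameters; the fixed term adds
    // attention, router, shared-expert, lm_head and embedding weights on top.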

    /// Estimate memory required (GB) at a given quantization and context length.
    /// Defaults to fp16 KV cache. Use `estimate_memory_gb_with_kv` to override.
    pub fn estimate_memory_gb(&self, quant: &str, ctx: u32) -> f64 {
        self.estimate_memory_gb_with_kv(quant, ctx, KvQuant::Fp16)
    }

    /// Estimate memory required (GB) with an explicit KV cache quantization.
    /// Formula: model_weights + KV_cache + runtime_overhead
    pub fn estimate_memory_gb_with_kv(&self, quant: &str, ctx: u32, kv: KvQuant) -> f64 {
        let bpp = quant_bpp(quant);
        let params = self.params_b();
        let model_mem = params * bpp;
        let kv_cache = self.kv_cache_gb(ctx, kv);
        // Runtime overhead (CUDA/Metal context, buffers)
        let overhead = 0.5;
        model_mem + kv_cache + overhead
    }
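
    // Hedged worked example: a hypothetical 7B model at Q4_K_M with a 4096-token
    // context and no per-layer metadata uses the coarse KV fallback, giving
    // roughly 7.0 * 0.58 + 0.000008 * 7.0 * 4096 + 0.5 ≈ 4.06 + 0.23 + 0.5 ≈ 4.8 GB.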

    /// KV cache size in GB at the given context length and KV quant.
    ///
    /// Uses the precise per layer formula when `num_hidden_layers`,
    /// `num_key_value_heads`, and `head_dim` are known:
    ///
    /// `kv_bytes = 2 * n_layers * n_kv_heads * head_dim * ctx * dtype_bytes`
    ///
    /// Falls back to a coarse `params * ctx` approximation when the metadata
    /// is missing so older catalog entries don't regress.
    ///
    /// For TurboQuant, only the full attention slice (per `attention_layout`)
    /// is compressed. Linear / state space layers stay at fp16.
    pub fn kv_cache_gb(&self, ctx: u32, kv: KvQuant) -> f64 {
        let params = self.params_b();
        let layout = self.effective_attention_layout();

        // Precise path: requires layer count, KV head count, head dim.
        if let (Some(n_layers), Some(head_dim)) = (self.num_hidden_layers, self.head_dim) {
            let n_kv_heads = self
                .num_key_value_heads
                .or(self.num_attention_heads)
                .unwrap_or(8);

            let bytes_per_layer =
                |bpe: f64| -> f64 { 2.0 * n_kv_heads as f64 * head_dim as f64 * ctx as f64 * bpe };

            let total_bytes = match kv {
                KvQuant::TurboQuant => {
                    // Compressed slice (full attention) at TQ rate, rest stay fp16.
                    let full_layers = match layout {
                        Some(l) => l.full.min(n_layers),
                        None => n_layers,
                    };
                    let linear_layers = n_layers.saturating_sub(full_layers);
                    bytes_per_layer(KvQuant::TurboQuant.bytes_per_element()) * full_layers as f64
                        + bytes_per_layer(KvQuant::Fp16.bytes_per_element()) * linear_layers as f64
                }
                _ => bytes_per_layer(kv.bytes_per_element()) * n_layers as f64,
            };

            return total_bytes / 1_073_741_824.0;
        }

        // Fallback: coarse linear approximation, scaled by KV quant ratio.
        // Historical formula was 0.000008 * params_b * ctx (assumes fp16).
        let baseline_fp16 = 0.000008 * params * ctx as f64;
        let scale = match kv {
            KvQuant::Fp16 => 1.0,
            KvQuant::Fp8 | KvQuant::Q8_0 => 0.5,
            KvQuant::Q4_0 => 0.25,
            KvQuant::TurboQuant => {
                // Without layer counts we can't separate full vs linear, so
                // weight the savings by the layout if available, otherwise
                // assume an all-full dense transformer.
                let frac = layout.map(|l| l.compressible_fraction()).unwrap_or(1.0);
                let tq_ratio = KvQuant::TurboQuant.bytes_per_element() / 2.0;
                frac * tq_ratio + (1.0 - frac)
            }
        };
        baseline_fp16 * scale
    }
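
    // Hedged worked example of the precise path (illustrative 8B-class shape):
    // 32 layers, 8 KV heads, head_dim 128, ctx 8192 at fp16 gives
    // 2 * 32 * 8 * 128 * 8192 * 2 bytes = 1,073,741,824 bytes = 1 GiB of KV cache.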

    /// Select the best quantization level that fits within a memory budget.
    /// Returns the quant name and estimated memory in GB, or None if nothing fits.
    pub fn best_quant_for_budget(&self, budget_gb: f64, ctx: u32) -> Option<(&'static str, f64)> {
        self.best_quant_for_budget_with(budget_gb, ctx, QUANT_HIERARCHY)
    }

    /// Select the best quantization from a custom hierarchy that fits within a memory budget.
    pub fn best_quant_for_budget_with(
        &self,
        budget_gb: f64,
        ctx: u32,
        hierarchy: &[&'static str],
    ) -> Option<(&'static str, f64)> {
        // Try best quality first
        for &q in hierarchy {
            let mem = self.estimate_memory_gb(q, ctx);
            if mem <= budget_gb {
                return Some((q, mem));
            }
        }
        // Try halving context once
        let half_ctx = ctx / 2;
        if half_ctx >= 1024 {
            for &q in hierarchy {
                let mem = self.estimate_memory_gb(q, half_ctx);
                if mem <= budget_gb {
                    return Some((q, mem));
                }
            }
        }
        None
    }

    /// Resolved attention layout: explicit metadata if present, otherwise a
    /// best effort heuristic based on the model name. Returns `None` for
    /// plain dense transformers (which the KV estimator should treat as
    /// "all layers compressible").
    pub fn effective_attention_layout(&self) -> Option<AttentionLayout> {
        self.attention_layout
            .or_else(|| infer_attention_layout_from_name(&self.name))
    }

    /// For MoE models, compute estimated VRAM for active experts only.
    /// Returns None for dense models.
    pub fn moe_active_vram_gb(&self) -> Option<f64> {
        if !self.is_moe {
            return None;
        }
        let active_params = self.active_parameters? as f64;
        let bpp = self.quant_bpp();
        let size_gb = (active_params * bpp) / (1024.0 * 1024.0 * 1024.0);
        Some((size_gb * 1.1).max(0.5))
    }
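
    // Hedged worked example: a Mixtral-class MoE with 12.9B active parameters at
    // Q4_K_M works out to 12.9e9 * 0.58 / 1024^3 ≈ 6.97 GB, or ≈ 7.7 GB once the
    // 10% overhead factor above is applied.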

    /// Returns true if this model is MLX-specific (Apple Silicon only).
    /// MLX models are identified by having "-MLX" in their name.
    pub fn is_mlx_only(&self) -> bool {
        self.name.to_uppercase().contains("-MLX")
    }

    /// For MoE models, compute RAM needed for offloaded (inactive) experts.
    /// Returns None for dense models.
    pub fn moe_offloaded_ram_gb(&self) -> Option<f64> {
        if !self.is_moe {
            return None;
        }
        let active = self.active_parameters? as f64;
        let total = self.parameters_raw? as f64;
        let inactive = total - active;
        if inactive <= 0.0 {
            return Some(0.0);
        }
        let bpp = self.quant_bpp();
        Some((inactive * bpp) / (1024.0 * 1024.0 * 1024.0))
    }
}

/// Intermediate struct matching the JSON schema from the scraper.
/// Extra fields are ignored when mapping to LlmModel.
#[derive(Debug, Clone, Deserialize)]
struct HfModelEntry {
    name: String,
    provider: String,
    parameter_count: String,
    #[serde(default)]
    parameters_raw: Option<u64>,
    min_ram_gb: f64,
    recommended_ram_gb: f64,
    min_vram_gb: Option<f64>,
    quantization: String,
    context_length: u32,
    use_case: String,
    #[serde(default)]
    is_moe: bool,
    #[serde(default)]
    num_experts: Option<u32>,
    #[serde(default)]
    active_experts: Option<u32>,
    #[serde(default)]
    active_parameters: Option<u64>,
    #[serde(default)]
    release_date: Option<String>,
    #[serde(default)]
    gguf_sources: Vec<GgufSource>,
    #[serde(default)]
    capabilities: Vec<Capability>,
    #[serde(default)]
    format: ModelFormat,
    #[serde(default)]
    hf_downloads: u64,
    #[serde(default)]
    hf_likes: u64,
    #[serde(default)]
    num_attention_heads: Option<u32>,
    #[serde(default)]
    num_key_value_heads: Option<u32>,
    #[serde(default)]
    num_hidden_layers: Option<u32>,
    #[serde(default)]
    head_dim: Option<u32>,
    #[serde(default)]
    license: Option<String>,
}

const HF_MODELS_JSON: &str = include_str!("../data/hf_models.json");

pub struct ModelDatabase {
    models: Vec<LlmModel>,
}

impl Default for ModelDatabase {
    fn default() -> Self {
        Self::new()
    }
}

/// Normalize a model name/ID to a canonical slug for deduplication.
///
/// Strips the `org/` prefix, lowercases, and removes `-`/`_`/`.` so that
/// `meta-llama/Llama-3.1-8B` and `meta-llama/llama-3.1-8b` compare equal.
pub(crate) fn canonical_slug(name: &str) -> String {
    let slug = name.split('/').next_back().unwrap_or(name);
    slug.to_lowercase().replace(['-', '_', '.'], "")
}
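
// Hedged illustration of the slug rules above: both spellings of the same repo
// collapse to one key, so cache merging treats them as a single model.
#[cfg(test)]
mod canonical_slug_example {
    use super::*;

    #[test]
    fn slug_examples() {
        assert_eq!(canonical_slug("meta-llama/Llama-3.1-8B"), "llama318b");
        assert_eq!(
            canonical_slug("meta-llama/Llama-3.1-8B"),
            canonical_slug("meta-llama/llama-3.1-8b")
        );
    }
}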

/// Parse the compile-time embedded JSON into a flat `Vec<LlmModel>`.
fn load_embedded() -> Vec<LlmModel> {
    let entries: Vec<HfModelEntry> =
        serde_json::from_str(HF_MODELS_JSON).expect("Failed to parse embedded hf_models.json");
    entries
        .into_iter()
        .map(|e| {
            let mut model = LlmModel {
                name: e.name,
                provider: e.provider,
                parameter_count: e.parameter_count,
                parameters_raw: e.parameters_raw,
                min_ram_gb: e.min_ram_gb,
                recommended_ram_gb: e.recommended_ram_gb,
                min_vram_gb: e.min_vram_gb,
                quantization: e.quantization,
                context_length: e.context_length,
                use_case: e.use_case,
                is_moe: e.is_moe,
                num_experts: e.num_experts,
                active_experts: e.active_experts,
                active_parameters: e.active_parameters,
                release_date: e.release_date,
                gguf_sources: e.gguf_sources,
                capabilities: e.capabilities,
                format: e.format,
                num_attention_heads: e.num_attention_heads,
                num_key_value_heads: e.num_key_value_heads,
                num_hidden_layers: e.num_hidden_layers,
                head_dim: e.head_dim,
                attention_layout: None,
                hidden_size: None,
                moe_intermediate_size: None,
                vocab_size: None,
                shared_expert_intermediate_size: None,
                license: e.license,
            };
            model.capabilities = Capability::infer(&model);
            // Auto-populate attention_layout from name heuristic for known
            // hybrid families. Explicit metadata still wins (model.attention_layout
            // stays None until the scraper is taught to read it from config.json).
            if model.attention_layout.is_none() {
                model.attention_layout = infer_attention_layout_from_name(&model.name);
            }
            model
        })
        .collect()
}

impl ModelDatabase {
    /// Load only the compile-time embedded model list (no cache).
    /// Used internally by the updater to determine which models are already known.
    pub fn embedded() -> Self {
        ModelDatabase {
            models: load_embedded(),
        }
    }

    /// Load the embedded model list **and** merge any locally cached models.
    ///
    /// Cached models are appended after the embedded ones; if an ID already
    /// exists in the embedded list it is skipped to avoid duplication.
    /// Silently ignores a missing or corrupt cache file.
    pub fn new() -> Self {
        let mut models = load_embedded();

        // Merge cached models (from `llmfit update`) without duplicating.
        // canonical_slug normalizes org/ prefix, case, and separators so that
        // e.g. `meta-llama/Llama-3.1-8B` and `meta-llama/llama-3.1-8b` are
        // treated as the same model.
        let embedded_keys: std::collections::HashSet<String> =
            models.iter().map(|m| canonical_slug(&m.name)).collect();

        for cached in crate::update::load_cache() {
            if !embedded_keys.contains(&canonical_slug(&cached.name)) {
                models.push(cached);
            }
        }

        ModelDatabase { models }
    }

    pub fn get_all_models(&self) -> &Vec<LlmModel> {
        &self.models
    }

    pub fn find_model(&self, query: &str) -> Vec<&LlmModel> {
        let query_lower = query.to_lowercase();
        self.models
            .iter()
            .filter(|m| {
                m.name.to_lowercase().contains(&query_lower)
                    || m.provider.to_lowercase().contains(&query_lower)
                    || m.parameter_count.to_lowercase().contains(&query_lower)
            })
            .collect()
    }

    pub fn models_fitting_system(
        &self,
        available_ram_gb: f64,
        has_gpu: bool,
        vram_gb: Option<f64>,
    ) -> Vec<&LlmModel> {
        self.models
            .iter()
            .filter(|m| {
                // Check RAM requirement
                let ram_ok = m.min_ram_gb <= available_ram_gb;

                // If model requires GPU and system has GPU, check VRAM
                if let Some(min_vram) = m.min_vram_gb {
                    if has_gpu {
                        if let Some(system_vram) = vram_gb {
                            ram_ok && min_vram <= system_vram
                        } else {
                            // GPU detected but VRAM unknown, allow but warn
                            ram_ok
                        }
                    } else {
                        // Model prefers GPU but can run on CPU with enough RAM
                        ram_ok && available_ram_gb >= m.recommended_ram_gb
                    }
                } else {
                    ram_ok
                }
            })
            .collect()
    }
}

/// Infer an attention layout from the model name for known hybrid families.
/// Returns `None` for plain dense / all-full transformers (which is the safe
/// default for the KV cache estimator: assume all layers are compressible).
///
/// The numbers here come from the published configs of each family as of
/// 2026 Q1. They're a best effort starting point and should be replaced
/// with values scraped from `config.json` whenever the metadata is available.
pub fn infer_attention_layout_from_name(name: &str) -> Option<AttentionLayout> {
    let lower = name.to_lowercase();

    // Qwen3-Next series: roughly 1 full attention layer per 4 layers,
    // remainder are linear / gated DeltaNet style. The A3B (35B total)
    // variant ships with 10 full out of 40 according to the TurboQuant
    // benchmark in 0xSero/turboquant.
    if lower.contains("qwen3-next") || lower.contains("qwen3.5-next") {
        return Some(AttentionLayout {
            full: 10,
            linear: 30,
        });
    }

    // Qwen3.5 / Qwen3.6 hybrid models use 1 full attention per 4 layers.
    // The dense 27B variants have 64 layers → 16 full + 48 linear.
    // The MoE A3B variants have 40 layers → 10 full + 30 linear.
    if lower.contains("qwen3.5-") || lower.contains("qwen3.6-") {
        if lower.contains("-a3b") || lower.contains("-a10b") || lower.contains("-a17b") {
            return Some(AttentionLayout {
                full: 10,
                linear: 30,
            });
        }
        // Dense variants (27B) use 64 layers with the same 1 full : 3 linear split
        return Some(AttentionLayout {
            full: 16,
            linear: 48,
        });
    }

    // Jamba (Mamba + Transformer hybrid). Jamba 1.5 Mini and Large both
    // use a 1:7 attention to mamba ratio in their 32 layer blocks.
    if lower.contains("jamba") {
        return Some(AttentionLayout {
            full: 4,
            linear: 28,
        });
    }

    // Zamba2 (Mamba2 + shared attention). Zamba2-7B has 2 shared attention
    // blocks and 54 mamba layers per the model card.
    if lower.contains("zamba") {
        return Some(AttentionLayout {
            full: 2,
            linear: 54,
        });
    }

    // RWKV / Mamba pure SSM models: no full attention at all. We still
    // report them so the KV estimator can short circuit. Compressible
    // fraction is 0, so KV quant savings will correctly show as zero.
    if lower.contains("mamba") || lower.contains("rwkv") {
        return Some(AttentionLayout { full: 0, linear: 1 });
    }

    None
}
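
// Hedged illustration of the heuristic above: hybrid families get a layout,
// dense names fall through to None (treated as fully compressible).
#[cfg(test)]
mod attention_layout_heuristic_example {
    use super::*;

    #[test]
    fn layout_heuristic_examples() {
        let jamba = infer_attention_layout_from_name("ai21labs/AI21-Jamba-1.5-Mini").unwrap();
        assert_eq!((jamba.full, jamba.linear), (4, 28));
        assert!(infer_attention_layout_from_name("Llama-3.1-8B-Instruct").is_none());
    }
}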

/// Infer attention and KV head counts from the model name and parameter count.
/// Used as a fallback when explicit head counts are not available in the model metadata.
fn infer_heads_from_name(name: &str, params_b: f64) -> (u32, u32) {
    let name_lower = name.to_lowercase();

    // Qwen family
    if name_lower.contains("qwen") {
        if params_b > 100.0 {
            return (128, 16);
        } else if params_b > 50.0 {
            return (64, 8);
        } else if params_b > 25.0 {
            return (40, 8);
        } else if params_b > 10.0 {
            return (40, 8);
        } else if params_b > 5.0 {
            return (32, 8);
        } else {
            return (16, 4);
        }
    }

    // Llama family
    if name_lower.contains("llama") {
        if name_lower.contains("scout") || name_lower.contains("maverick") {
            return (64, 8);
        } else if params_b > 60.0 {
            return (64, 8);
        } else if params_b > 20.0 {
            return (48, 8);
        } else if params_b > 5.0 {
            return (32, 8);
        } else {
            return (16, 8);
        }
    }

    // DeepSeek family
    if name_lower.contains("deepseek") {
        if params_b > 200.0 {
            return (128, 16);
        } else if params_b > 50.0 {
            return (64, 8);
        } else if params_b > 25.0 {
            return (40, 8);
        } else if params_b > 10.0 {
            return (40, 8);
        } else {
            return (32, 8);
        }
    }

    // Mistral/Mixtral
    if name_lower.contains("mistral") || name_lower.contains("mixtral") {
        if params_b > 100.0 {
            return (96, 8);
        } else if params_b > 20.0 {
            return (32, 8);
        } else {
            return (32, 8);
        }
    }

    // Gemma
    if name_lower.contains("gemma") {
        if params_b > 20.0 {
            return (32, 16);
        } else if params_b > 5.0 {
            return (16, 8);
        } else {
            return (8, 4);
        }
    }

    // Phi
    if name_lower.contains("phi") {
        if params_b > 10.0 {
            return (40, 10);
        } else {
            return (32, 8);
        }
    }

    // MiniMax
    if name_lower.contains("minimax") {
        return (48, 8);
    }

    // Default: common pattern based on param count
    if params_b > 100.0 {
        (128, 16)
    } else if params_b > 50.0 {
        (64, 8)
    } else if params_b > 20.0 {
        (32, 8)
    } else if params_b > 5.0 {
        (32, 8)
    } else {
        (16, 4)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // ────────────────────────────────────────────────────────────────────
    // Quantization function tests
    // ────────────────────────────────────────────────────────────────────

    #[test]
    fn test_mlx_quant_bpp_values() {
        assert_eq!(quant_bpp("mlx-4bit"), 0.55);
        assert_eq!(quant_bpp("mlx-8bit"), 1.0);
        assert_eq!(quant_speed_multiplier("mlx-4bit"), 1.15);
        assert_eq!(quant_speed_multiplier("mlx-8bit"), 0.85);
        assert_eq!(quant_quality_penalty("mlx-4bit"), -4.0);
        assert_eq!(quant_quality_penalty("mlx-8bit"), 0.0);
    }

    #[test]
    fn test_ud_quant_mappings() {
        // UD-Q2_K_XL should match Q2_K values (not hit the default fallback)
        assert_eq!(quant_bpp("UD-Q2_K_XL"), quant_bpp("Q2_K"));
        assert_eq!(
            quant_bytes_per_param("UD-Q2_K_XL"),
            quant_bytes_per_param("Q2_K")
        );
        assert_eq!(
            quant_speed_multiplier("UD-Q2_K_XL"),
            quant_speed_multiplier("Q2_K")
        );
        assert_eq!(
            quant_quality_penalty("UD-Q2_K_XL"),
            quant_quality_penalty("Q2_K")
        );

        // UD-Q4_K_M should match Q4_K_M values
        assert_eq!(quant_bpp("UD-Q4_K_M"), quant_bpp("Q4_K_M"));
        assert_eq!(
            quant_bytes_per_param("UD-Q4_K_M"),
            quant_bytes_per_param("Q4_K_M")
        );

        // UD-Q8_K_S should match Q8_0 values (bpp table)
        assert_eq!(quant_bpp("UD-Q8_K_S"), quant_bpp("Q8_0"));
        assert_eq!(
            quant_bytes_per_param("UD-Q8_K_S"),
            quant_bytes_per_param("Q8_0")
        );

        // Verify no longer hitting defaults
        assert!(
            quant_bpp("UD-Q2_K_XL") < 0.5,
            "UD-Q2_K_XL bpp should be 0.37, not default 0.58"
        );
        assert!(
            quant_bytes_per_param("UD-Q2_K_XL") < 0.4,
            "UD-Q2_K_XL bytes should be 0.25, not default 0.5"
        );
    }

    #[test]
    fn test_best_quant_with_mlx_hierarchy() {
        let model = LlmModel {
            name: "Test Model".to_string(),
            provider: "Test".to_string(),
            parameter_count: "7B".to_string(),
            parameters_raw: Some(7_000_000_000),
            min_ram_gb: 4.0,
            recommended_ram_gb: 8.0,
            min_vram_gb: Some(4.0),
            quantization: "Q4_K_M".to_string(),
            context_length: 4096,
            use_case: "General".to_string(),
            is_moe: false,
            num_experts: None,
            active_experts: None,
            active_parameters: None,
            release_date: None,
            gguf_sources: vec![],
            capabilities: vec![],
            format: ModelFormat::default(),
            num_attention_heads: None,
            num_key_value_heads: None,
            num_hidden_layers: None,
            head_dim: None,
            attention_layout: None,
            hidden_size: None,
            moe_intermediate_size: None,
            vocab_size: None,
            shared_expert_intermediate_size: None,
            license: None,
        };

        // Large budget should return mlx-8bit (best in MLX hierarchy)
        let result = model.best_quant_for_budget_with(10.0, 4096, MLX_QUANT_HIERARCHY);
        assert!(result.is_some());
        let (quant, _) = result.unwrap();
        assert_eq!(quant, "mlx-8bit");

        // Tighter budget should fall to mlx-4bit
        let result = model.best_quant_for_budget_with(5.0, 4096, MLX_QUANT_HIERARCHY);
        assert!(result.is_some());
        let (quant, _) = result.unwrap();
        assert_eq!(quant, "mlx-4bit");
    }

    #[test]
    fn test_quant_bpp() {
        assert_eq!(quant_bpp("F32"), 4.0);
        assert_eq!(quant_bpp("F16"), 2.0);
        assert_eq!(quant_bpp("Q8_0"), 1.05);
        assert_eq!(quant_bpp("Q4_K_M"), 0.58);
        assert_eq!(quant_bpp("Q2_K"), 0.37);
        // Unknown quant defaults to Q4_K_M
        assert_eq!(quant_bpp("UNKNOWN"), 0.58);
    }

    #[test]
    fn test_quant_speed_multiplier() {
        assert_eq!(quant_speed_multiplier("F16"), 0.6);
        assert_eq!(quant_speed_multiplier("Q5_K_M"), 1.0);
        assert_eq!(quant_speed_multiplier("Q4_K_M"), 1.15);
        assert_eq!(quant_speed_multiplier("Q2_K"), 1.35);
        // Lower quant = faster inference
        assert!(quant_speed_multiplier("Q2_K") > quant_speed_multiplier("Q8_0"));
    }

    #[test]
    fn test_quant_quality_penalty() {
        assert_eq!(quant_quality_penalty("F16"), 0.0);
        assert_eq!(quant_quality_penalty("Q8_0"), 0.0);
        assert_eq!(quant_quality_penalty("Q4_K_M"), -5.0);
        assert_eq!(quant_quality_penalty("Q2_K"), -12.0);
        // Lower quant = higher quality penalty
        assert!(quant_quality_penalty("Q2_K") < quant_quality_penalty("Q8_0"));
    }

    // ────────────────────────────────────────────────────────────────────
    // LlmModel tests
    // ────────────────────────────────────────────────────────────────────

    #[test]
    fn test_params_b_from_raw() {
        let model = LlmModel {
            name: "Test Model".to_string(),
            provider: "Test".to_string(),
            parameter_count: "7B".to_string(),
            parameters_raw: Some(7_000_000_000),
            min_ram_gb: 4.0,
            recommended_ram_gb: 8.0,
            min_vram_gb: Some(4.0),
            quantization: "Q4_K_M".to_string(),
            context_length: 4096,
            use_case: "General".to_string(),
            is_moe: false,
            num_experts: None,
            active_experts: None,
            active_parameters: None,
            release_date: None,
            gguf_sources: vec![],
            capabilities: vec![],
            format: ModelFormat::default(),
            num_attention_heads: None,
            num_key_value_heads: None,
            num_hidden_layers: None,
            head_dim: None,
            attention_layout: None,
            hidden_size: None,
            moe_intermediate_size: None,
            vocab_size: None,
            shared_expert_intermediate_size: None,
            license: None,
        };
        assert_eq!(model.params_b(), 7.0);
    }

    #[test]
    fn test_params_b_from_string() {
        let model = LlmModel {
            name: "Test Model".to_string(),
            provider: "Test".to_string(),
            parameter_count: "13B".to_string(),
            parameters_raw: None,
            min_ram_gb: 8.0,
            recommended_ram_gb: 16.0,
            min_vram_gb: Some(8.0),
            quantization: "Q4_K_M".to_string(),
            context_length: 4096,
            use_case: "General".to_string(),
            is_moe: false,
            num_experts: None,
            active_experts: None,
            active_parameters: None,
            release_date: None,
            gguf_sources: vec![],
            capabilities: vec![],
            format: ModelFormat::default(),
            num_attention_heads: None,
            num_key_value_heads: None,
            num_hidden_layers: None,
            head_dim: None,
            attention_layout: None,
            hidden_size: None,
            moe_intermediate_size: None,
            vocab_size: None,
            shared_expert_intermediate_size: None,
            license: None,
        };
        assert_eq!(model.params_b(), 13.0);
    }

    #[test]
    fn test_params_b_from_millions() {
        let model = LlmModel {
            name: "Test Model".to_string(),
            provider: "Test".to_string(),
            parameter_count: "500M".to_string(),
            parameters_raw: None,
            min_ram_gb: 1.0,
            recommended_ram_gb: 2.0,
            min_vram_gb: Some(1.0),
            quantization: "Q4_K_M".to_string(),
            context_length: 2048,
            use_case: "General".to_string(),
            is_moe: false,
            num_experts: None,
            active_experts: None,
            active_parameters: None,
            release_date: None,
            gguf_sources: vec![],
            capabilities: vec![],
            format: ModelFormat::default(),
            num_attention_heads: None,
            num_key_value_heads: None,
            num_hidden_layers: None,
            head_dim: None,
            attention_layout: None,
            hidden_size: None,
            moe_intermediate_size: None,
            vocab_size: None,
            shared_expert_intermediate_size: None,
            license: None,
        };
        assert_eq!(model.params_b(), 0.5);
    }

    #[test]
    fn test_estimate_memory_gb() {
        let model = LlmModel {
            name: "Test Model".to_string(),
            provider: "Test".to_string(),
            parameter_count: "7B".to_string(),
            parameters_raw: Some(7_000_000_000),
            min_ram_gb: 4.0,
            recommended_ram_gb: 8.0,
            min_vram_gb: Some(4.0),
            quantization: "Q4_K_M".to_string(),
            context_length: 4096,
            use_case: "General".to_string(),
            is_moe: false,
            num_experts: None,
            active_experts: None,
            active_parameters: None,
            release_date: None,
            gguf_sources: vec![],
            capabilities: vec![],
            format: ModelFormat::default(),
            num_attention_heads: None,
            num_key_value_heads: None,
            num_hidden_layers: None,
            head_dim: None,
            attention_layout: None,
            hidden_size: None,
            moe_intermediate_size: None,
            vocab_size: None,
            shared_expert_intermediate_size: None,
            license: None,
        };

        let mem = model.estimate_memory_gb("Q4_K_M", 4096);
        // 7B params * 0.58 bytes = 4.06 GB + KV cache + overhead
        assert!(mem > 4.0);
        assert!(mem < 6.0);

        // Q8_0 should require more memory
        let mem_q8 = model.estimate_memory_gb("Q8_0", 4096);
        assert!(mem_q8 > mem);
    }

    #[test]
    fn test_best_quant_for_budget() {
        let model = LlmModel {
            name: "Test Model".to_string(),
            provider: "Test".to_string(),
            parameter_count: "7B".to_string(),
            parameters_raw: Some(7_000_000_000),
            min_ram_gb: 4.0,
            recommended_ram_gb: 8.0,
            min_vram_gb: Some(4.0),
            quantization: "Q4_K_M".to_string(),
            context_length: 4096,
            use_case: "General".to_string(),
            is_moe: false,
            num_experts: None,
            active_experts: None,
            active_parameters: None,
            release_date: None,
            gguf_sources: vec![],
            capabilities: vec![],
            format: ModelFormat::default(),
            num_attention_heads: None,
            num_key_value_heads: None,
            num_hidden_layers: None,
            head_dim: None,
            attention_layout: None,
            hidden_size: None,
            moe_intermediate_size: None,
            vocab_size: None,
            shared_expert_intermediate_size: None,
            license: None,
        };

        // Large budget should return best quant
        let result = model.best_quant_for_budget(10.0, 4096);
        assert!(result.is_some());
        let (quant, _) = result.unwrap();
        assert_eq!(quant, "Q8_0");

        // Medium budget should find acceptable quant
        let result = model.best_quant_for_budget(5.0, 4096);
        assert!(result.is_some());

        // Tiny budget should return None
        let result = model.best_quant_for_budget(1.0, 4096);
        assert!(result.is_none());
    }
1432
1433    #[test]
1434    fn test_moe_active_vram_gb() {
1435        // Dense model should return None
1436        let dense_model = LlmModel {
1437            name: "Dense Model".to_string(),
1438            provider: "Test".to_string(),
1439            parameter_count: "7B".to_string(),
1440            parameters_raw: Some(7_000_000_000),
1441            min_ram_gb: 4.0,
1442            recommended_ram_gb: 8.0,
1443            min_vram_gb: Some(4.0),
1444            quantization: "Q4_K_M".to_string(),
1445            context_length: 4096,
1446            use_case: "General".to_string(),
1447            is_moe: false,
1448            num_experts: None,
1449            active_experts: None,
1450            active_parameters: None,
1451            release_date: None,
1452            gguf_sources: vec![],
1453            capabilities: vec![],
1454            format: ModelFormat::default(),
1455            num_attention_heads: None,
1456            num_key_value_heads: None,
1457            num_hidden_layers: None,
1458            head_dim: None,
1459            attention_layout: None,
1460            hidden_size: None,
1461            moe_intermediate_size: None,
1462            vocab_size: None,
1463            shared_expert_intermediate_size: None,
1464            license: None,
1465        };
1466        assert!(dense_model.moe_active_vram_gb().is_none());
1467
1468        // MoE model should calculate active VRAM
1469        let moe_model = LlmModel {
1470            name: "MoE Model".to_string(),
1471            provider: "Test".to_string(),
1472            parameter_count: "8x7B".to_string(),
1473            parameters_raw: Some(46_700_000_000),
1474            min_ram_gb: 25.0,
1475            recommended_ram_gb: 50.0,
1476            min_vram_gb: Some(25.0),
1477            quantization: "Q4_K_M".to_string(),
1478            context_length: 32768,
1479            use_case: "General".to_string(),
1480            is_moe: true,
1481            num_experts: Some(8),
1482            active_experts: Some(2),
1483            active_parameters: Some(12_900_000_000),
1484            release_date: None,
1485            gguf_sources: vec![],
1486            capabilities: vec![],
1487            format: ModelFormat::default(),
1488            num_attention_heads: None,
1489            num_key_value_heads: None,
1490            num_hidden_layers: None,
1491            head_dim: None,
1492            attention_layout: None,
1493            hidden_size: None,
1494            moe_intermediate_size: None,
1495            vocab_size: None,
1496            shared_expert_intermediate_size: None,
1497            license: None,
1498        };
1499        let vram = moe_model.moe_active_vram_gb();
1500        assert!(vram.is_some());
1501        let vram_val = vram.unwrap();
1502        // Should be significantly less than full model
1503        assert!(vram_val > 0.0);
1504        assert!(vram_val < 15.0);
1505    }
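
    // Rough hand check, not the exact formula: assuming the active-VRAM estimate
    // is dominated by active_parameters * quant_bpp, the 12.9B active params at
    // Q4_K_M weigh ~7.5 GB, inside the 0..15 GB window asserted above and well
    // below the ~27 GB full-model weight footprint.
    #[test]
    fn test_moe_active_weight_term_hand_calc() {
        let active_gb = 12_900_000_000.0 * quant_bpp("Q4_K_M") / 1e9; // ≈ 7.48
        let full_gb = 46_700_000_000.0 * quant_bpp("Q4_K_M") / 1e9; // ≈ 27.09
        assert!(active_gb < 15.0);
        assert!(active_gb < full_gb / 2.0);
    }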
1506
1507    #[test]
1508    fn test_moe_offloaded_ram_gb() {
1509        // Dense model should return None
1510        let dense_model = LlmModel {
1511            name: "Dense Model".to_string(),
1512            provider: "Test".to_string(),
1513            parameter_count: "7B".to_string(),
1514            parameters_raw: Some(7_000_000_000),
1515            min_ram_gb: 4.0,
1516            recommended_ram_gb: 8.0,
1517            min_vram_gb: Some(4.0),
1518            quantization: "Q4_K_M".to_string(),
1519            context_length: 4096,
1520            use_case: "General".to_string(),
1521            is_moe: false,
1522            num_experts: None,
1523            active_experts: None,
1524            active_parameters: None,
1525            release_date: None,
1526            gguf_sources: vec![],
1527            capabilities: vec![],
1528            format: ModelFormat::default(),
1529            num_attention_heads: None,
1530            num_key_value_heads: None,
1531            num_hidden_layers: None,
1532            head_dim: None,
1533            attention_layout: None,
1534            hidden_size: None,
1535            moe_intermediate_size: None,
1536            vocab_size: None,
1537            shared_expert_intermediate_size: None,
1538            license: None,
1539        };
1540        assert!(dense_model.moe_offloaded_ram_gb().is_none());
1541
1542        // MoE model should calculate offloaded RAM
1543        let moe_model = LlmModel {
1544            name: "MoE Model".to_string(),
1545            provider: "Test".to_string(),
1546            parameter_count: "8x7B".to_string(),
1547            parameters_raw: Some(46_700_000_000),
1548            min_ram_gb: 25.0,
1549            recommended_ram_gb: 50.0,
1550            min_vram_gb: Some(25.0),
1551            quantization: "Q4_K_M".to_string(),
1552            context_length: 32768,
1553            use_case: "General".to_string(),
1554            is_moe: true,
1555            num_experts: Some(8),
1556            active_experts: Some(2),
1557            active_parameters: Some(12_900_000_000),
1558            release_date: None,
1559            gguf_sources: vec![],
1560            capabilities: vec![],
1561            format: ModelFormat::default(),
1562            num_attention_heads: None,
1563            num_key_value_heads: None,
1564            num_hidden_layers: None,
1565            head_dim: None,
1566            attention_layout: None,
1567            hidden_size: None,
1568            moe_intermediate_size: None,
1569            vocab_size: None,
1570            shared_expert_intermediate_size: None,
1571            license: None,
1572        };
1573        let offloaded = moe_model.moe_offloaded_ram_gb();
1574        assert!(offloaded.is_some());
1575        let offloaded_val = offloaded.unwrap();
1576        // Offloaded expert weights should be substantial for a ~47B MoE
1577        assert!(offloaded_val > 10.0);
1578    }
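
    // Companion hand check (again an assumption about the formula, not a copy of
    // it): if the offloaded portion is roughly "full weights minus active
    // weights", that is ~27.1 GB - ~7.5 GB ≈ 19.6 GB at Q4_K_M, comfortably
    // above the 10 GB floor asserted above.
    #[test]
    fn test_moe_offloaded_weight_term_hand_calc() {
        let full_gb = 46_700_000_000.0 * quant_bpp("Q4_K_M") / 1e9;
        let active_gb = 12_900_000_000.0 * quant_bpp("Q4_K_M") / 1e9;
        assert!(full_gb - active_gb > 10.0); // ≈ 19.6 GB of expert weights in RAM
    }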
1579
1580    // ────────────────────────────────────────────────────────────────────
1581    // UseCase tests
1582    // ────────────────────────────────────────────────────────────────────
1583
1584    #[test]
1585    fn test_use_case_from_model_coding() {
1586        let model = LlmModel {
1587            name: "codellama-7b".to_string(),
1588            provider: "Meta".to_string(),
1589            parameter_count: "7B".to_string(),
1590            parameters_raw: Some(7_000_000_000),
1591            min_ram_gb: 4.0,
1592            recommended_ram_gb: 8.0,
1593            min_vram_gb: Some(4.0),
1594            quantization: "Q4_K_M".to_string(),
1595            context_length: 4096,
1596            use_case: "Coding".to_string(),
1597            is_moe: false,
1598            num_experts: None,
1599            active_experts: None,
1600            active_parameters: None,
1601            release_date: None,
1602            gguf_sources: vec![],
1603            capabilities: vec![],
1604            format: ModelFormat::default(),
1605            num_attention_heads: None,
1606            num_key_value_heads: None,
1607            num_hidden_layers: None,
1608            head_dim: None,
1609            attention_layout: None,
1610            hidden_size: None,
1611            moe_intermediate_size: None,
1612            vocab_size: None,
1613            shared_expert_intermediate_size: None,
1614            license: None,
1615        };
1616        assert_eq!(UseCase::from_model(&model), UseCase::Coding);
1617    }
1618
1619    #[test]
1620    fn test_use_case_from_model_embedding() {
1621        let model = LlmModel {
1622            name: "bge-large".to_string(),
1623            provider: "BAAI".to_string(),
1624            parameter_count: "335M".to_string(),
1625            parameters_raw: Some(335_000_000),
1626            min_ram_gb: 1.0,
1627            recommended_ram_gb: 2.0,
1628            min_vram_gb: Some(1.0),
1629            quantization: "F16".to_string(),
1630            context_length: 512,
1631            use_case: "Embedding".to_string(),
1632            is_moe: false,
1633            num_experts: None,
1634            active_experts: None,
1635            active_parameters: None,
1636            release_date: None,
1637            gguf_sources: vec![],
1638            capabilities: vec![],
1639            format: ModelFormat::default(),
1640            num_attention_heads: None,
1641            num_key_value_heads: None,
1642            num_hidden_layers: None,
1643            head_dim: None,
1644            attention_layout: None,
1645            hidden_size: None,
1646            moe_intermediate_size: None,
1647            vocab_size: None,
1648            shared_expert_intermediate_size: None,
1649            license: None,
1650        };
1651        assert_eq!(UseCase::from_model(&model), UseCase::Embedding);
1652    }
1653
1654    #[test]
1655    fn test_use_case_from_model_reasoning() {
1656        let model = LlmModel {
1657            name: "deepseek-r1-7b".to_string(),
1658            provider: "DeepSeek".to_string(),
1659            parameter_count: "7B".to_string(),
1660            parameters_raw: Some(7_000_000_000),
1661            min_ram_gb: 4.0,
1662            recommended_ram_gb: 8.0,
1663            min_vram_gb: Some(4.0),
1664            quantization: "Q4_K_M".to_string(),
1665            context_length: 8192,
1666            use_case: "Reasoning".to_string(),
1667            is_moe: false,
1668            num_experts: None,
1669            active_experts: None,
1670            active_parameters: None,
1671            release_date: None,
1672            gguf_sources: vec![],
1673            capabilities: vec![],
1674            format: ModelFormat::default(),
1675            num_attention_heads: None,
1676            num_key_value_heads: None,
1677            num_hidden_layers: None,
1678            head_dim: None,
1679            attention_layout: None,
1680            hidden_size: None,
1681            moe_intermediate_size: None,
1682            vocab_size: None,
1683            shared_expert_intermediate_size: None,
1684            license: None,
1685        };
1686        assert_eq!(UseCase::from_model(&model), UseCase::Reasoning);
1687    }
1688
1689    // ────────────────────────────────────────────────────────────────────
1690    // ModelDatabase tests
1691    // ────────────────────────────────────────────────────────────────────
1692
1693    #[test]
1694    fn test_model_database_new() {
1695        let db = ModelDatabase::new();
1696        let models = db.get_all_models();
1697        // Should have loaded models from embedded JSON
1698        assert!(!models.is_empty());
1699    }
1700
1701    #[test]
1702    fn test_find_model() {
1703        let db = ModelDatabase::new();
1704
1705        // Search by name substring (case insensitive)
1706        let results = db.find_model("llama");
1707        assert!(!results.is_empty());
1708        assert!(
1709            results
1710                .iter()
1711                .any(|m| m.name.to_lowercase().contains("llama"))
1712        );
1713
1714        // Search should be case insensitive
1715        let results_upper = db.find_model("LLAMA");
1716        assert_eq!(results.len(), results_upper.len());
1717    }
1718
1719    #[test]
1720    fn test_models_fitting_system() {
1721        let db = ModelDatabase::new();
1722
1723        // Large system should fit many models
1724        let fitting = db.models_fitting_system(32.0, true, Some(24.0));
1725        assert!(!fitting.is_empty());
1726
1727        // Very small system should fit fewer or no models
1728        let fitting_small = db.models_fitting_system(2.0, false, None);
1729        assert!(fitting_small.len() < fitting.len());
1730
1731        // All fitting models should meet RAM requirements
1732        for model in fitting_small {
1733            assert!(model.min_ram_gb <= 2.0);
1734        }
1735    }
1736
1737    // ────────────────────────────────────────────────────────────────────
1738    // Capability tests
1739    // ────────────────────────────────────────────────────────────────────
1740
1741    #[test]
1742    fn test_capability_infer_vision() {
1743        let model = LlmModel {
1744            name: "meta-llama/Llama-3.2-11B-Vision-Instruct".to_string(),
1745            provider: "Meta".to_string(),
1746            parameter_count: "11B".to_string(),
1747            parameters_raw: Some(11_000_000_000),
1748            min_ram_gb: 6.0,
1749            recommended_ram_gb: 10.0,
1750            min_vram_gb: Some(6.0),
1751            quantization: "Q4_K_M".to_string(),
1752            context_length: 131072,
1753            use_case: "Multimodal, vision and text".to_string(),
1754            is_moe: false,
1755            num_experts: None,
1756            active_experts: None,
1757            active_parameters: None,
1758            release_date: None,
1759            gguf_sources: vec![],
1760            capabilities: vec![],
1761            format: ModelFormat::default(),
1762            num_attention_heads: None,
1763            num_key_value_heads: None,
1764            num_hidden_layers: None,
1765            head_dim: None,
1766            attention_layout: None,
1767            hidden_size: None,
1768            moe_intermediate_size: None,
1769            vocab_size: None,
1770            shared_expert_intermediate_size: None,
1771            license: None,
1772        };
1773        let caps = Capability::infer(&model);
1774        assert!(caps.contains(&Capability::Vision));
1775        // Also gets ToolUse because the name contains both "llama-3" and "instruct"
1776        assert!(caps.contains(&Capability::ToolUse));
1777    }
1778
1779    #[test]
1780    fn test_capability_infer_tool_use() {
1781        let model = LlmModel {
1782            name: "Qwen/Qwen3-8B".to_string(),
1783            provider: "Qwen".to_string(),
1784            parameter_count: "8B".to_string(),
1785            parameters_raw: Some(8_000_000_000),
1786            min_ram_gb: 4.5,
1787            recommended_ram_gb: 8.0,
1788            min_vram_gb: Some(4.0),
1789            quantization: "Q4_K_M".to_string(),
1790            context_length: 32768,
1791            use_case: "General purpose text generation".to_string(),
1792            is_moe: false,
1793            num_experts: None,
1794            active_experts: None,
1795            active_parameters: None,
1796            release_date: None,
1797            gguf_sources: vec![],
1798            capabilities: vec![],
1799            format: ModelFormat::default(),
1800            num_attention_heads: None,
1801            num_key_value_heads: None,
1802            num_hidden_layers: None,
1803            head_dim: None,
1804            attention_layout: None,
1805            hidden_size: None,
1806            moe_intermediate_size: None,
1807            vocab_size: None,
1808            shared_expert_intermediate_size: None,
1809            license: None,
1810        };
1811        let caps = Capability::infer(&model);
1812        assert!(caps.contains(&Capability::ToolUse));
1813        assert!(!caps.contains(&Capability::Vision));
1814    }
1815
1816    #[test]
1817    fn test_capability_infer_none() {
1818        let model = LlmModel {
1819            name: "BAAI/bge-large-en-v1.5".to_string(),
1820            provider: "BAAI".to_string(),
1821            parameter_count: "335M".to_string(),
1822            parameters_raw: Some(335_000_000),
1823            min_ram_gb: 1.0,
1824            recommended_ram_gb: 2.0,
1825            min_vram_gb: Some(1.0),
1826            quantization: "F16".to_string(),
1827            context_length: 512,
1828            use_case: "Text embeddings for RAG".to_string(),
1829            is_moe: false,
1830            num_experts: None,
1831            active_experts: None,
1832            active_parameters: None,
1833            release_date: None,
1834            gguf_sources: vec![],
1835            capabilities: vec![],
1836            format: ModelFormat::default(),
1837            num_attention_heads: None,
1838            num_key_value_heads: None,
1839            num_hidden_layers: None,
1840            head_dim: None,
1841            attention_layout: None,
1842            hidden_size: None,
1843            moe_intermediate_size: None,
1844            vocab_size: None,
1845            shared_expert_intermediate_size: None,
1846            license: None,
1847        };
1848        let caps = Capability::infer(&model);
1849        assert!(caps.is_empty());
1850    }
1851
1852    #[test]
1853    fn test_capability_preserves_explicit() {
1854        let model = LlmModel {
1855            name: "some-model".to_string(),
1856            provider: "Test".to_string(),
1857            parameter_count: "7B".to_string(),
1858            parameters_raw: Some(7_000_000_000),
1859            min_ram_gb: 4.0,
1860            recommended_ram_gb: 8.0,
1861            min_vram_gb: Some(4.0),
1862            quantization: "Q4_K_M".to_string(),
1863            context_length: 4096,
1864            use_case: "General".to_string(),
1865            is_moe: false,
1866            num_experts: None,
1867            active_experts: None,
1868            active_parameters: None,
1869            release_date: None,
1870            gguf_sources: vec![],
1871            capabilities: vec![Capability::Vision],
1872            format: ModelFormat::default(),
1873            num_attention_heads: None,
1874            num_key_value_heads: None,
1875            num_hidden_layers: None,
1876            head_dim: None,
1877            attention_layout: None,
1878            hidden_size: None,
1879            moe_intermediate_size: None,
1880            vocab_size: None,
1881            shared_expert_intermediate_size: None,
1882            license: None,
1883        };
1884        let caps = Capability::infer(&model);
1885        // Should keep the explicit Vision and not duplicate it
1886        assert_eq!(caps.iter().filter(|c| **c == Capability::Vision).count(), 1);
1887    }
1888
1889    #[test]
1890    fn test_awq_gptq_quant_values() {
1891        // AWQ
1892        assert_eq!(quant_bpp("AWQ-4bit"), 0.5);
1893        assert_eq!(quant_bpp("AWQ-8bit"), 1.0);
1894        assert_eq!(quant_speed_multiplier("AWQ-4bit"), 1.2);
1895        assert_eq!(quant_speed_multiplier("AWQ-8bit"), 0.85);
1896        assert_eq!(quant_quality_penalty("AWQ-4bit"), -3.0);
1897        assert_eq!(quant_quality_penalty("AWQ-8bit"), 0.0);
1898        // GPTQ
1899        assert_eq!(quant_bpp("GPTQ-Int4"), 0.5);
1900        assert_eq!(quant_bpp("GPTQ-Int8"), 1.0);
1901        assert_eq!(quant_speed_multiplier("GPTQ-Int4"), 1.2);
1902        assert_eq!(quant_speed_multiplier("GPTQ-Int8"), 0.85);
1903        assert_eq!(quant_quality_penalty("GPTQ-Int4"), -3.0);
1904        assert_eq!(quant_quality_penalty("GPTQ-Int8"), 0.0);
1905    }
1906
1907    #[test]
1908    fn test_model_format_prequantized() {
1909        assert!(ModelFormat::Awq.is_prequantized());
1910        assert!(ModelFormat::Gptq.is_prequantized());
1911        assert!(!ModelFormat::Gguf.is_prequantized());
1912        assert!(!ModelFormat::Mlx.is_prequantized());
1913        assert!(!ModelFormat::Safetensors.is_prequantized());
1914    }
1915
1916    // ────────────────────────────────────────────────────────────────────
1917    // GGUF source catalog tests
1918    // ────────────────────────────────────────────────────────────────────
1919
1920    #[test]
1921    fn test_gguf_source_deserialization() {
1922        let json = r#"{"repo": "unsloth/Llama-3.1-8B-Instruct-GGUF", "provider": "unsloth"}"#;
1923        let source: GgufSource = serde_json::from_str(json).unwrap();
1924        assert_eq!(source.repo, "unsloth/Llama-3.1-8B-Instruct-GGUF");
1925        assert_eq!(source.provider, "unsloth");
1926    }
1927
1928    #[test]
1929    fn test_gguf_sources_default_to_empty() {
1930        let json = r#"{
1931            "name": "test/model",
1932            "provider": "Test",
1933            "parameter_count": "7B",
1934            "parameters_raw": 7000000000,
1935            "min_ram_gb": 4.0,
1936            "recommended_ram_gb": 8.0,
1937            "quantization": "Q4_K_M",
1938            "context_length": 4096,
1939            "use_case": "General"
1940        }"#;
1941        let entry: HfModelEntry = serde_json::from_str(json).unwrap();
1942        assert!(entry.gguf_sources.is_empty());
1943    }
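
    // Self-contained illustration of the serde pattern the test above relies on.
    // Assumption: gguf_sources carries #[serde(default)] on HfModelEntry; the
    // Demo struct below is local to this test and not part of the crate API.
    #[test]
    fn test_serde_default_vec_sketch() {
        #[derive(serde::Deserialize)]
        struct Demo {
            #[serde(default)]
            items: Vec<String>,
        }
        let demo: Demo = serde_json::from_str("{}").unwrap();
        assert!(demo.items.is_empty());
    }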
1944
1945    #[test]
1946    fn test_catalog_popular_models_have_gguf_sources() {
1947        let db = ModelDatabase::new();
1948        // These popular models should have gguf_sources populated in the catalog
1949        let expected_with_gguf = [
1950            "meta-llama/Llama-3.3-70B-Instruct",
1951            "Qwen/Qwen2.5-7B-Instruct",
1952            "Qwen/Qwen2.5-Coder-7B-Instruct",
1953            "meta-llama/Llama-3.1-8B-Instruct",
1954            "mistralai/Mistral-7B-Instruct-v0.3",
1955        ];
1956        for name in &expected_with_gguf {
1957            let model = db.get_all_models().iter().find(|m| m.name == *name);
1958            assert!(model.is_some(), "Model {} should exist in catalog", name);
1959            let model = model.unwrap();
1960            assert!(
1961                !model.gguf_sources.is_empty(),
1962                "Model {} should have gguf_sources but has none",
1963                name
1964            );
1965        }
1966    }
1967
1968    #[test]
1969    fn test_catalog_gguf_sources_have_valid_repos() {
1970        let db = ModelDatabase::new();
1971        for model in db.get_all_models() {
1972            for source in &model.gguf_sources {
1973                assert!(
1974                    source.repo.contains('/'),
1975                    "GGUF source repo '{}' for model '{}' should be owner/repo format",
1976                    source.repo,
1977                    model.name
1978                );
1979                assert!(
1980                    !source.provider.is_empty(),
1981                    "GGUF source provider for model '{}' should not be empty",
1982                    model.name
1983                );
1984                assert!(
1985                    source.repo.to_uppercase().contains("GGUF"),
1986                    "GGUF source repo '{}' for model '{}' should contain 'GGUF'",
1987                    source.repo,
1988                    model.name
1989                );
1990            }
1991        }
1992    }
1993
1994    #[test]
1995    #[ignore] // Requires network access to populate GGUF sources at build time
1996    fn test_catalog_has_significant_gguf_coverage() {
1997        let db = ModelDatabase::new();
1998        let total = db.get_all_models().len();
1999        let with_gguf = db
2000            .get_all_models()
2001            .iter()
2002            .filter(|m| !m.gguf_sources.is_empty())
2003            .count();
2004        // We should have at least 25% coverage after enrichment
2005        let coverage_pct = (with_gguf as f64 / total as f64) * 100.0;
2006        assert!(
2007            coverage_pct >= 25.0,
2008            "GGUF source coverage is only {:.1}% ({}/{}), expected at least 25%",
2009            coverage_pct,
2010            with_gguf,
2011            total
2012        );
2013    }
2014
2015    // ────────────────────────────────────────────────────────────────────
2016    // Tensor parallelism tests
2017    // ────────────────────────────────────────────────────────────────────
2018
2019    fn tp_test_model(
2020        name: &str,
2021        params_b: f64,
2022        attn_heads: Option<u32>,
2023        kv_heads: Option<u32>,
2024    ) -> LlmModel {
2025        LlmModel {
2026            name: name.to_string(),
2027            provider: "Test".to_string(),
2028            parameter_count: format!("{:.0}B", params_b),
2029            parameters_raw: Some((params_b * 1_000_000_000.0) as u64),
2030            min_ram_gb: 4.0,
2031            recommended_ram_gb: 8.0,
2032            min_vram_gb: Some(4.0),
2033            quantization: "Q4_K_M".to_string(),
2034            context_length: 4096,
2035            use_case: "General".to_string(),
2036            is_moe: false,
2037            num_experts: None,
2038            active_experts: None,
2039            active_parameters: None,
2040            release_date: None,
2041            gguf_sources: vec![],
2042            capabilities: vec![],
2043            format: ModelFormat::default(),
2044            num_attention_heads: attn_heads,
2045            num_key_value_heads: kv_heads,
2046            num_hidden_layers: None,
2047            head_dim: None,
2048            attention_layout: None,
2049            hidden_size: None,
2050            moe_intermediate_size: None,
2051            vocab_size: None,
2052            shared_expert_intermediate_size: None,
2053            license: None,
2054        }
2055    }
2056
2057    #[test]
2058    fn test_supports_tp_with_explicit_heads() {
2059        let model = tp_test_model("Test-8B", 8.0, Some(32), Some(8));
2060        assert!(model.supports_tp(1));
2061        assert!(model.supports_tp(2));
2062        assert!(model.supports_tp(4));
2063        assert!(model.supports_tp(8));
2064        assert!(!model.supports_tp(3)); // 32 % 3 != 0
2065        assert!(!model.supports_tp(5));
2066    }
2067
2068    #[test]
2069    fn test_supports_tp_always_true_for_1() {
2070        let model = tp_test_model("Tiny", 1.0, None, None);
2071        assert!(model.supports_tp(1));
2072    }
2073
2074    #[test]
2075    fn test_valid_tp_sizes_32_8() {
2076        let model = tp_test_model("Test", 8.0, Some(32), Some(8));
2077        let sizes = model.valid_tp_sizes();
2078        assert!(sizes.contains(&1));
2079        assert!(sizes.contains(&2));
2080        assert!(sizes.contains(&4));
2081        assert!(sizes.contains(&8));
2082        assert!(!sizes.contains(&3));
2083    }
2084
2085    #[test]
2086    fn test_valid_tp_sizes_48_heads() {
2087        // 48 attn heads, 8 kv heads — TP must divide both
2088        let model = tp_test_model("Llama-32B", 32.0, Some(48), Some(8));
2089        assert!(model.supports_tp(2)); // 48%2==0, 8%2==0
2090        assert!(!model.supports_tp(3)); // 48%3==0 but 8%3!=0
2091        assert!(model.supports_tp(4)); // 48%4==0, 8%4==0
2092        assert!(model.supports_tp(8)); // 48%8==0, 8%8==0
2093    }
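
    // A standalone sketch of the divisibility rule the assertions above rely on.
    // Assumption: supports_tp reduces to "tp divides both the attention-head and
    // KV-head counts"; tp_divides is a local helper, not part of the crate API.
    #[test]
    fn test_tp_divisibility_rule_sketch() {
        fn tp_divides(attn_heads: u32, kv_heads: u32, tp: u32) -> bool {
            tp > 0 && attn_heads % tp == 0 && kv_heads % tp == 0
        }
        assert!(tp_divides(48, 8, 2));
        assert!(!tp_divides(48, 8, 3)); // 48 % 3 == 0, but 8 % 3 != 0
        assert!(tp_divides(48, 8, 4));
        assert!(tp_divides(48, 8, 8));
    }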
2094
2095    #[test]
2096    fn test_infer_heads_from_name_qwen() {
2097        let (attn, kv) = infer_heads_from_name("Qwen2.5-72B-Instruct", 72.0);
2098        assert_eq!(attn, 64);
2099        assert_eq!(kv, 8);
2100    }
2101
2102    #[test]
2103    fn test_infer_heads_from_name_llama() {
2104        let (attn, kv) = infer_heads_from_name("Llama-3.1-8B", 8.0);
2105        assert_eq!(attn, 32);
2106        assert_eq!(kv, 8);
2107    }
2108
2109    #[test]
2110    fn test_infer_heads_from_name_deepseek() {
2111        let (attn, kv) = infer_heads_from_name("DeepSeek-V3", 671.0);
2112        assert_eq!(attn, 128);
2113        assert_eq!(kv, 16);
2114    }
2115
2116    #[test]
2117    fn test_supports_tp_with_inferred_heads() {
2118        // No explicit heads — should infer from name
2119        let model = tp_test_model("Llama-3.1-70B", 70.0, None, None);
2120        assert!(model.supports_tp(2));
2121        assert!(model.supports_tp(4));
2122        assert!(model.supports_tp(8));
2123    }
2124
2125    // ────────────────────────────────────────────────────────────────────
2126    // KV cache formula + KvQuant + AttentionLayout
2127    // ────────────────────────────────────────────────────────────────────
2128
2129    fn kv_test_model(name: &str) -> LlmModel {
2130        // Roughly modelled on Llama-3.1-8B: 32 layers, 32 heads, 8 KV heads,
2131        // head_dim 128.
2132        LlmModel {
2133            name: name.to_string(),
2134            provider: "Test".to_string(),
2135            parameter_count: "8B".to_string(),
2136            parameters_raw: Some(8_000_000_000),
2137            min_ram_gb: 4.0,
2138            recommended_ram_gb: 8.0,
2139            min_vram_gb: Some(4.0),
2140            quantization: "Q4_K_M".to_string(),
2141            context_length: 8192,
2142            use_case: "General".to_string(),
2143            is_moe: false,
2144            num_experts: None,
2145            active_experts: None,
2146            active_parameters: None,
2147            release_date: None,
2148            gguf_sources: vec![],
2149            capabilities: vec![],
2150            format: ModelFormat::default(),
2151            num_attention_heads: Some(32),
2152            num_key_value_heads: Some(8),
2153            num_hidden_layers: Some(32),
2154            head_dim: Some(128),
2155            attention_layout: None,
2156            hidden_size: None,
2157            moe_intermediate_size: None,
2158            vocab_size: None,
2159            shared_expert_intermediate_size: None,
2160            license: None,
2161        }
2162    }
2163
2164    #[test]
2165    fn test_kv_quant_from_str_round_trip() {
2166        for kv in KvQuant::all() {
2167            let parsed = KvQuant::parse(kv.label()).expect("label should parse");
2168            assert_eq!(parsed, *kv);
2169        }
2170        assert_eq!(KvQuant::parse("FP16"), Some(KvQuant::Fp16));
2171        assert_eq!(KvQuant::parse("Q4_0"), Some(KvQuant::Q4_0));
2172        assert_eq!(KvQuant::parse("turboquant"), Some(KvQuant::TurboQuant));
2173        assert_eq!(KvQuant::parse("nope"), None);
2174    }
2175
2176    #[test]
2177    fn test_kv_cache_precise_formula_matches_hand_calc() {
2178        // 32 layers * 2 (K+V) * 8 KV heads * 128 head_dim * 8192 ctx * 2 (fp16)
2179        // = 1_073_741_824 bytes ≈ 1.0 GB
2180        let model = kv_test_model("Llama-3.1-8B");
2181        let kv = model.kv_cache_gb(8192, KvQuant::Fp16);
2182        assert!((kv - 1.0).abs() < 0.05, "expected ~1.0 GB, got {:.4}", kv);
2183    }
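
    // The hand calculation from the comment above as a standalone expression,
    // assuming the usual formula: 2 (K and V) * layers * kv_heads * head_dim *
    // context * bytes_per_element.
    #[test]
    fn test_kv_cache_bytes_hand_calc() {
        let bytes: u64 = 2 * 32 * 8 * 128 * 8192 * 2;
        assert_eq!(bytes, 1_073_741_824); // exactly 1 GiB for fp16 at 8k context
    }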
2184
2185    #[test]
2186    fn test_kv_cache_scales_with_quant() {
2187        let model = kv_test_model("test");
2188        let fp16 = model.kv_cache_gb(8192, KvQuant::Fp16);
2189        let q8 = model.kv_cache_gb(8192, KvQuant::Q8_0);
2190        let q4 = model.kv_cache_gb(8192, KvQuant::Q4_0);
2191        // q8 should be ~half fp16, q4 should be ~quarter
2192        assert!((q8 / fp16 - 0.5).abs() < 0.01);
2193        assert!((q4 / fp16 - 0.25).abs() < 0.01);
2194    }
2195
2196    #[test]
2197    fn test_kv_cache_fallback_when_metadata_missing() {
2198        // No layer/head_dim metadata: should fall back to the linear approx
2199        // and still scale with KvQuant.
2200        let mut model = kv_test_model("nameless");
2201        model.num_hidden_layers = None;
2202        model.head_dim = None;
2203        let fp16 = model.kv_cache_gb(8192, KvQuant::Fp16);
2204        let q4 = model.kv_cache_gb(8192, KvQuant::Q4_0);
2205        assert!(fp16 > 0.0);
2206        assert!(q4 < fp16);
2207    }
2208
2209    #[test]
2210    fn test_turboquant_full_attention_uses_compressed_rate() {
2211        // Pure dense (no layout): TQ should compress every layer.
2212        let model = kv_test_model("dense");
2213        let fp16 = model.kv_cache_gb(8192, KvQuant::Fp16);
2214        let tq = model.kv_cache_gb(8192, KvQuant::TurboQuant);
2215        let ratio = tq / fp16;
2216        // TurboQuant's ~0.34 bytes/element vs fp16's 2.0: ~0.34 / 2.0 = 0.17 of fp16
2217        assert!(
2218            (0.10..=0.25).contains(&ratio),
2219            "TQ ratio on dense should be ~0.17, got {:.3}",
2220            ratio
2221        );
2222    }
2223
2224    #[test]
2225    fn test_turboquant_hybrid_only_compresses_full_attention() {
2226        // 10 full + 30 linear layers (Qwen3-Next-A3B style).
2227        let mut model = kv_test_model("hybrid");
2228        model.num_hidden_layers = Some(40);
2229        model.attention_layout = Some(AttentionLayout {
2230            full: 10,
2231            linear: 30,
2232        });
2233        let fp16 = model.kv_cache_gb(8192, KvQuant::Fp16);
2234        let tq = model.kv_cache_gb(8192, KvQuant::TurboQuant);
2235        let savings = 1.0 - tq / fp16;
2236        // Honest savings should be ~0.83 * 0.25 ≈ 21%: only the 10 of 40 layers
2237        // with full attention are compressed, each by ~83%. Allow a wide tolerance
2238        // because the constants are deliberately conservative.
2239        assert!(
2240            (0.10..=0.30).contains(&savings),
2241            "expected ~20% honest savings on hybrid model, got {:.3}",
2242            savings
2243        );
2244        // And it must be far from the dense headline of ~83%.
2245        assert!(savings < 0.5);
2246    }
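
    // Worked arithmetic behind the two TurboQuant windows above, assuming the
    // ~0.34 bytes/element figure for TurboQuant and 2.0 for fp16 that the test
    // comments quote; the exact constants live in the implementation.
    #[test]
    fn test_turboquant_window_hand_calc() {
        let dense_ratio = 0.34_f64 / 2.0; // ≈ 0.17 of fp16 on a fully-dense model
        assert!((0.10..=0.25).contains(&dense_ratio));
        // Hybrid: only 10 of 40 layers are compressible, each saving ~83%.
        let hybrid_savings = (10.0 / 40.0) * (1.0 - dense_ratio); // ≈ 0.21
        assert!((0.10..=0.30).contains(&hybrid_savings));
        assert!(hybrid_savings < 0.5);
    }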
2247
2248    #[test]
2249    fn test_attention_layout_compressible_fraction() {
2250        let dense = AttentionLayout {
2251            full: 32,
2252            linear: 0,
2253        };
2254        assert!((dense.compressible_fraction() - 1.0).abs() < 0.0001);
2255
2256        let hybrid = AttentionLayout {
2257            full: 10,
2258            linear: 30,
2259        };
2260        assert!((hybrid.compressible_fraction() - 0.25).abs() < 0.0001);
2261
2262        let pure_ssm = AttentionLayout {
2263            full: 0,
2264            linear: 64,
2265        };
2266        assert!((pure_ssm.compressible_fraction() - 0.0).abs() < 0.0001);
2267    }
2268
2269    #[test]
2270    fn test_infer_attention_layout_qwen3_next() {
2271        let layout = infer_attention_layout_from_name("Qwen/Qwen3-Next-80B-A3B");
2272        assert!(layout.is_some());
2273        let layout = layout.unwrap();
2274        assert!(layout.full > 0 && layout.linear > 0);
2275        assert!(layout.compressible_fraction() < 0.5);
2276    }
2277
2278    #[test]
2279    fn test_infer_attention_layout_dense_returns_none() {
2280        assert!(infer_attention_layout_from_name("meta-llama/Llama-3.1-8B").is_none());
2281        assert!(infer_attention_layout_from_name("Qwen/Qwen2.5-7B").is_none());
2282    }
2283
2284    #[test]
2285    fn test_effective_attention_layout_prefers_explicit() {
2286        let mut model = kv_test_model("Qwen/Qwen3-Next-80B");
2287        // Explicit metadata should override the heuristic
2288        model.attention_layout = Some(AttentionLayout {
2289            full: 5,
2290            linear: 35,
2291        });
2292        let resolved = model.effective_attention_layout().unwrap();
2293        assert_eq!(resolved.full, 5);
2294        assert_eq!(resolved.linear, 35);
2295    }
2296
2297    #[test]
2298    fn test_estimate_memory_with_kv_q8_smaller_than_fp16() {
2299        let model = kv_test_model("Llama-3.1-8B");
2300        let fp16_total = model.estimate_memory_gb_with_kv("Q4_K_M", 32_768, KvQuant::Fp16);
2301        let q8_total = model.estimate_memory_gb_with_kv("Q4_K_M", 32_768, KvQuant::Q8_0);
2302        let q4_total = model.estimate_memory_gb_with_kv("Q4_K_M", 32_768, KvQuant::Q4_0);
2303        assert!(q8_total < fp16_total);
2304        assert!(q4_total < q8_total);
2305    }
2306}