//! llmfit_core/models.rs — model catalog, quantization math, and KV cache sizing.
1use serde::{Deserialize, Serialize};
2
/// Quantization levels ordered from best quality to most compressed.
/// Used for dynamic quantization selection: try the best that fits.
/// Every entry here has a matching arm in `quant_bpp`, `quant_speed_multiplier`,
/// `quant_bytes_per_param`, and `quant_quality_penalty`.
pub const QUANT_HIERARCHY: &[&str] = &["Q8_0", "Q6_K", "Q5_K_M", "Q4_K_M", "Q3_K_M", "Q2_K"];

/// MLX-native quantization hierarchy (best quality to most compressed).
/// Used instead of `QUANT_HIERARCHY` for MLX-format models (Apple Silicon).
pub const MLX_QUANT_HIERARCHY: &[&str] = &["mlx-8bit", "mlx-4bit"];
9
/// Bytes per parameter for each quantization level.
///
/// GGUF entries (Q8_0 … Q2_K) carry a small overhead above their nominal
/// bit width (scales/metadata), hence e.g. 1.05 for Q8_0 rather than 1.0.
pub fn quant_bpp(quant: &str) -> f64 {
    const BPP_TABLE: &[(&str, f64)] = &[
        ("F32", 4.0),
        ("F16", 2.0),
        ("BF16", 2.0),
        ("Q8_0", 1.05),
        ("Q6_K", 0.80),
        ("Q5_K_M", 0.68),
        ("Q4_K_M", 0.58),
        ("Q4_0", 0.58),
        ("Q3_K_M", 0.48),
        ("Q2_K", 0.37),
        ("mlx-4bit", 0.55),
        ("mlx-8bit", 1.0),
        ("AWQ-4bit", 0.5),
        ("AWQ-8bit", 1.0),
        ("GPTQ-Int4", 0.5),
        ("GPTQ-Int8", 1.0),
    ];
    BPP_TABLE
        .iter()
        .find(|(name, _)| *name == quant)
        .map(|&(_, bpp)| bpp)
        // Unknown formats: assume roughly 4-bit weights.
        .unwrap_or(0.58)
}
30
/// Speed multiplier for quantization (lower quant = faster inference).
///
/// Q5_K_M is the 1.0 reference point; heavier formats (F16) fall below it,
/// lighter formats (Q2_K) rise above it.
pub fn quant_speed_multiplier(quant: &str) -> f64 {
    if matches!(quant, "F16" | "BF16") {
        0.6
    } else if quant == "Q8_0" {
        0.8
    } else if quant == "Q6_K" {
        0.95
    } else if quant == "Q5_K_M" {
        1.0
    } else if matches!(quant, "Q4_K_M" | "Q4_0") {
        1.15
    } else if quant == "Q3_K_M" {
        1.25
    } else if quant == "Q2_K" {
        1.35
    } else if quant == "mlx-4bit" {
        1.15
    } else if quant == "mlx-8bit" {
        0.85
    } else if matches!(quant, "AWQ-4bit" | "GPTQ-Int4") {
        1.2
    } else if matches!(quant, "AWQ-8bit" | "GPTQ-Int8") {
        0.85
    } else {
        // Unknown formats: neutral multiplier.
        1.0
    }
}
48
/// Bytes per parameter for a given quantization format.
/// Used by the bandwidth-based tok/s estimator to compute model size in GB.
///
/// These are nominal bit widths (8-bit = 1.0, 4-bit = 0.5, …), unlike
/// `quant_bpp` which bakes in GGUF container overhead.
pub fn quant_bytes_per_param(quant: &str) -> f64 {
    let groups: &[(&[&str], f64)] = &[
        (&["F16", "BF16"], 2.0),
        (&["Q8_0", "mlx-8bit", "AWQ-8bit", "GPTQ-Int8"], 1.0),
        (&["Q6_K"], 0.75),
        (&["Q5_K_M"], 0.625),
        (&["Q4_K_M", "Q4_0", "mlx-4bit", "AWQ-4bit", "GPTQ-Int4"], 0.5),
        (&["Q3_K_M"], 0.375),
        (&["Q2_K"], 0.25),
    ];
    for (names, bytes) in groups {
        if names.contains(&quant) {
            return *bytes;
        }
    }
    // Default to ~4-bit for unrecognized formats.
    0.5
}
67
/// Quality penalty for quantization (lower quant = lower quality).
///
/// Zero means effectively lossless; more negative means more degradation.
pub fn quant_quality_penalty(quant: &str) -> f64 {
    match quant {
        // 16-bit and 8-bit formats: effectively lossless.
        "F16" | "BF16" | "Q8_0" | "mlx-8bit" | "AWQ-8bit" | "GPTQ-Int8" => 0.0,
        "Q6_K" => -1.0,
        "Q5_K_M" => -2.0,
        // Calibrated 4-bit formats (AWQ/GPTQ) degrade less than plain GGUF 4-bit.
        "AWQ-4bit" | "GPTQ-Int4" => -3.0,
        "mlx-4bit" => -4.0,
        "Q4_K_M" | "Q4_0" => -5.0,
        "Q3_K_M" => -8.0,
        "Q2_K" => -12.0,
        // Unknown formats: assume ~4-bit quality.
        _ => -5.0,
    }
}
87
/// Model capability flags (orthogonal to UseCase).
/// Serialized as snake_case strings ("vision", "tool_use") in the catalog JSON.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Capability {
    // Image/multimodal input support (VLM-style models).
    Vision,
    // Function/tool calling support.
    ToolUse,
}
95
96impl Capability {
97    pub fn label(&self) -> &'static str {
98        match self {
99            Capability::Vision => "Vision",
100            Capability::ToolUse => "Tool Use",
101        }
102    }
103
104    pub fn all() -> &'static [Capability] {
105        &[Capability::Vision, Capability::ToolUse]
106    }
107
108    /// Infer capabilities from model metadata when not explicitly set in JSON.
109    pub fn infer(model: &LlmModel) -> Vec<Capability> {
110        let mut caps = model.capabilities.clone();
111        let name = model.name.to_lowercase();
112        let use_case = model.use_case.to_lowercase();
113
114        // Vision detection
115        if !caps.contains(&Capability::Vision)
116            && (name.contains("vision")
117                || name.contains("-vl-")
118                || name.ends_with("-vl")
119                || name.contains("llava")
120                || name.contains("onevision")
121                || name.contains("pixtral")
122                || use_case.contains("vision")
123                || use_case.contains("multimodal"))
124        {
125            caps.push(Capability::Vision);
126        }
127
128        // Tool use detection (known model families)
129        if !caps.contains(&Capability::ToolUse)
130            && (use_case.contains("tool")
131                || use_case.contains("function call")
132                || name.contains("qwen3")
133                || name.contains("qwen2.5")
134                || name.contains("command-r")
135                || (name.contains("llama-3") && name.contains("instruct"))
136                || (name.contains("mistral") && name.contains("instruct"))
137                || name.contains("hermes")
138                || (name.contains("gemma-3") && name.ends_with("-it"))
139                || (name.contains("gemma-4") && name.ends_with("-it")))
140        {
141            caps.push(Capability::ToolUse);
142        }
143
144        caps
145    }
146}
147
148/// Model weight format — determines which inference runtime to use.
149#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
150#[serde(rename_all = "lowercase")]
151#[derive(Default)]
152pub enum ModelFormat {
153    #[default]
154    Gguf,
155    Awq,
156    Gptq,
157    Mlx,
158    Safetensors,
159}
160
161impl ModelFormat {
162    /// Returns true for formats that are pre-quantized at a fixed bit width
163    /// and cannot be dynamically re-quantized (AWQ, GPTQ).
164    pub fn is_prequantized(&self) -> bool {
165        matches!(self, ModelFormat::Awq | ModelFormat::Gptq)
166    }
167}
168
169/// Use-case category for scoring weights.
170#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
171pub enum UseCase {
172    General,
173    Coding,
174    Reasoning,
175    Chat,
176    Multimodal,
177    Embedding,
178}
179
180impl UseCase {
181    pub fn label(&self) -> &'static str {
182        match self {
183            UseCase::General => "General",
184            UseCase::Coding => "Coding",
185            UseCase::Reasoning => "Reasoning",
186            UseCase::Chat => "Chat",
187            UseCase::Multimodal => "Multimodal",
188            UseCase::Embedding => "Embedding",
189        }
190    }
191
192    /// Infer use-case from the model's use_case field and name.
193    pub fn from_model(model: &LlmModel) -> Self {
194        let name = model.name.to_lowercase();
195        let use_case = model.use_case.to_lowercase();
196
197        if use_case.contains("embedding") || name.contains("embed") || name.contains("bge") {
198            UseCase::Embedding
199        } else if name.contains("code") || use_case.contains("code") {
200            UseCase::Coding
201        } else if use_case.contains("vision") || use_case.contains("multimodal") {
202            UseCase::Multimodal
203        } else if use_case.contains("reason")
204            || use_case.contains("chain-of-thought")
205            || name.contains("deepseek-r1")
206        {
207            UseCase::Reasoning
208        } else if use_case.contains("chat") || use_case.contains("instruction") {
209            UseCase::Chat
210        } else {
211            UseCase::General
212        }
213    }
214}
215
/// One catalog entry: a model plus the metadata needed to estimate whether
/// and how it fits on a given machine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmModel {
    /// Model name; may carry an `org/` prefix (normalized by `canonical_slug`).
    pub name: String,
    /// Publisher / organization string (searchable via `find_model`).
    pub provider: String,
    /// Human-readable count like "7B", "1.1B", "137M" — parsed by `params_b`.
    pub parameter_count: String,
    /// Exact parameter count; takes priority over `parameter_count` in `params_b`.
    #[serde(default)]
    pub parameters_raw: Option<u64>,
    /// Minimum system RAM (GB) — compared against available RAM when filtering.
    pub min_ram_gb: f64,
    /// RAM (GB) required for CPU-only fallback of GPU-preferring models.
    pub recommended_ram_gb: f64,
    /// Minimum GPU VRAM (GB); None means no GPU requirement.
    pub min_vram_gb: Option<f64>,
    /// Catalog quantization label (e.g. "Q4_K_M") — fed to `quant_bpp`.
    pub quantization: String,
    /// Maximum context window in tokens.
    pub context_length: u32,
    /// Free-text use-case description; mined by `UseCase::from_model` and
    /// `Capability::infer`.
    pub use_case: String,
    #[serde(default)]
    pub is_moe: bool,
    #[serde(default)]
    pub num_experts: Option<u32>,
    #[serde(default)]
    pub active_experts: Option<u32>,
    /// Active (routed) parameter count for MoE models; used by
    /// `moe_active_vram_gb` / `moe_offloaded_ram_gb`.
    #[serde(default)]
    pub active_parameters: Option<u64>,
    #[serde(default)]
    pub release_date: Option<String>,
    /// Known GGUF download sources (e.g. unsloth, bartowski repos on HuggingFace)
    #[serde(default)]
    pub gguf_sources: Vec<GgufSource>,
    /// Model capabilities (vision, tool use, etc.)
    #[serde(default)]
    pub capabilities: Vec<Capability>,
    /// Model weight format (gguf, awq, gptq, mlx, safetensors)
    #[serde(default)]
    pub format: ModelFormat,
    /// Number of attention heads (for tensor-parallelism compatibility checks).
    #[serde(default)]
    pub num_attention_heads: Option<u32>,
    /// Number of key-value heads for GQA (defaults to num_attention_heads if None).
    #[serde(default)]
    pub num_key_value_heads: Option<u32>,
    /// Total number of transformer layers. Used by the precise KV cache formula.
    #[serde(default)]
    pub num_hidden_layers: Option<u32>,
    /// Per-head dimension. Used by the precise KV cache formula. When absent,
    /// derived as `hidden_size / num_attention_heads` if both are known, or
    /// a name based heuristic otherwise.
    #[serde(default)]
    pub head_dim: Option<u32>,
    /// Attention layer composition for hybrid models (full attention + linear /
    /// Mamba style layers). When None, all layers are assumed to be full
    /// attention. Used by KV cache compression schemes (e.g. TurboQuant) that
    /// only apply to full attention layers.
    #[serde(default)]
    pub attention_layout: Option<AttentionLayout>,
    /// Model license (e.g. "apache-2.0", "mit", "llama3.1")
    #[serde(default)]
    pub license: Option<String>,
}
272
/// Composition of attention layers in a hybrid model.
///
/// Some recent architectures (Qwen3-Next, Jamba, Mamba style hybrids) mix
/// full attention layers with cheaper linear / state space layers. KV cache
/// compression schemes like TurboQuant only apply to the full attention
/// fraction, so we track the split here to compute honest savings.
///
/// Serialized as plain `{ full, linear }` layer counts.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub struct AttentionLayout {
    /// Number of full self attention layers (compressible).
    pub full: u32,
    /// Number of linear / state space layers (not compressible by KV quant).
    pub linear: u32,
}
286
287impl AttentionLayout {
288    pub fn total(&self) -> u32 {
289        self.full + self.linear
290    }
291
292    /// Fraction of layers that are full attention (and therefore compressible
293    /// by KV quant schemes). Returns 1.0 for an all-full model.
294    pub fn compressible_fraction(&self) -> f64 {
295        let total = self.total();
296        if total == 0 {
297            1.0
298        } else {
299            self.full as f64 / total as f64
300        }
301    }
302}
303
/// KV cache element representation. Controls bytes per element for the
/// precise KV cache formula and (for TurboQuant) gates on runtime support.
///
/// Serialized with the short lowercase names also accepted by `parse`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum KvQuant {
    /// fp16 / bf16, the inference default for most runtimes.
    #[default]
    #[serde(rename = "fp16")]
    Fp16,
    /// fp8 KV cache (vLLM, llama.cpp via --cache-type-k fp8 on supported builds).
    #[serde(rename = "fp8")]
    Fp8,
    /// 8 bit integer KV cache (llama.cpp `q8_0`, vLLM int8).
    #[serde(rename = "q8_0")]
    Q8_0,
    /// 4 bit integer KV cache (llama.cpp `q4_0`, vLLM int4).
    #[serde(rename = "q4_0")]
    Q4_0,
    /// TurboQuant (3 bit keys + 2 bit values + Pi/S overhead). Research
    /// integration, vLLM + CUDA only, not in upstream vLLM yet. Compression
    /// only applies to full attention layers, so hybrid models see less.
    /// See https://github.com/0xSero/turboquant
    #[serde(rename = "tq")]
    TurboQuant,
}
328
329impl KvQuant {
330    pub fn label(&self) -> &'static str {
331        match self {
332            KvQuant::Fp16 => "fp16",
333            KvQuant::Fp8 => "fp8",
334            KvQuant::Q8_0 => "q8_0",
335            KvQuant::Q4_0 => "q4_0",
336            KvQuant::TurboQuant => "tq",
337        }
338    }
339
340    /// Bytes per KV element for non-TurboQuant variants. TurboQuant is handled
341    /// per layer because it only affects the full attention slice.
342    pub fn bytes_per_element(&self) -> f64 {
343        match self {
344            KvQuant::Fp16 => 2.0,
345            KvQuant::Fp8 => 1.0,
346            KvQuant::Q8_0 => 1.0,
347            KvQuant::Q4_0 => 0.5,
348            // For the bookkeeping path that doesn't know about layout, assume
349            // ~2.7 bits per element on the compressible slice. The real
350            // computation in `precise_kv_cache_gb` handles the layout split.
351            KvQuant::TurboQuant => 0.34,
352        }
353    }
354
355    pub fn parse(s: &str) -> Option<Self> {
356        match s.trim().to_lowercase().as_str() {
357            "fp16" | "f16" | "bf16" | "default" => Some(KvQuant::Fp16),
358            "fp8" | "f8" => Some(KvQuant::Fp8),
359            "q8" | "q8_0" | "int8" => Some(KvQuant::Q8_0),
360            "q4" | "q4_0" | "int4" => Some(KvQuant::Q4_0),
361            "tq" | "turboquant" => Some(KvQuant::TurboQuant),
362            _ => None,
363        }
364    }
365
366    /// All KV quant options llmfit knows how to estimate. Order is best
367    /// quality (fp16) to most compressed.
368    pub fn all() -> &'static [KvQuant] {
369        &[
370            KvQuant::Fp16,
371            KvQuant::Fp8,
372            KvQuant::Q8_0,
373            KvQuant::Q4_0,
374            KvQuant::TurboQuant,
375        ]
376    }
377}
378
379impl std::fmt::Display for KvQuant {
380    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381        f.write_str(self.label())
382    }
383}
384
/// Returns true if a model's license matches any entry in the comma-separated
/// `filter` string (case-insensitive, whitespace around entries ignored).
/// Models without a license never match.
pub fn matches_license_filter(license: &Option<String>, filter: &str) -> bool {
    // No license on the model: never matches any filter.
    let Some(license) = license else {
        return false;
    };
    let license = license.to_lowercase();
    // Compare entries lazily instead of collecting the whole allowed list
    // into a Vec first (clippy: needless_collect) — short-circuits on the
    // first match.
    filter
        .split(',')
        .any(|entry| entry.trim().to_lowercase() == license)
}
394
/// A known GGUF download source for a model on HuggingFace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GgufSource {
    /// HuggingFace repo ID (e.g. "unsloth/Llama-3.1-8B-Instruct-GGUF")
    pub repo: String,
    /// Provider who published the GGUF (e.g. "unsloth", "bartowski")
    pub provider: String,
}
403
impl LlmModel {
    /// MLX models are Apple-only — they won't run on NVIDIA/AMD/Intel hardware.
    /// We detect them by the `-MLX-` suffix that's standard on HuggingFace
    /// (e.g. `Qwen3-8B-MLX-4bit`, `LFM2-1.2B-MLX-8bit`).
    ///
    /// NOTE(review): overlaps with `is_mlx_only` below, which matches the
    /// looser `-MLX` substring; the two heuristics can disagree on names like
    /// `Foo-MLX8bit` — consider consolidating.
    pub fn is_mlx_model(&self) -> bool {
        let name_lower = self.name.to_lowercase();
        name_lower.contains("-mlx-") || name_lower.ends_with("-mlx")
    }

    /// Returns true if this model uses a pre-quantized format (AWQ/GPTQ)
    /// that cannot be dynamically re-quantized.
    pub fn is_prequantized(&self) -> bool {
        self.format.is_prequantized()
    }

    /// Returns true if the model's attention/KV heads are evenly divisible
    /// by `tp_size`, meaning it can be split across that many devices.
    /// TP=1 always returns true.
    pub fn supports_tp(&self, tp_size: u32) -> bool {
        if tp_size <= 1 {
            return true;
        }
        let (attn, kv) = self.infer_head_counts();
        attn % tp_size == 0 && kv % tp_size == 0
    }

    /// Returns all valid TP degrees in [1..=8] for this model.
    pub fn valid_tp_sizes(&self) -> Vec<u32> {
        (1..=8).filter(|&tp| self.supports_tp(tp)).collect()
    }

    /// Infer attention and KV head counts from metadata or model name heuristics.
    fn infer_head_counts(&self) -> (u32, u32) {
        if let (Some(attn), Some(kv)) = (self.num_attention_heads, self.num_key_value_heads) {
            return (attn, kv);
        }
        // Only attention heads known: assume MHA (KV heads == attention heads).
        if let Some(attn) = self.num_attention_heads {
            return (attn, attn);
        }
        // Heuristic: infer from model name
        infer_heads_from_name(&self.name, self.params_b())
    }

    /// Bytes-per-parameter for the model's quantization level.
    fn quant_bpp(&self) -> f64 {
        quant_bpp(&self.quantization)
    }

    /// Parameter count in billions, extracted from parameters_raw or parameter_count.
    ///
    /// Falls back to 7.0 (a "typical" model size) when the string is
    /// unparseable or has no B/M suffix.
    pub fn params_b(&self) -> f64 {
        if let Some(raw) = self.parameters_raw {
            raw as f64 / 1_000_000_000.0
        } else {
            // Parse from string like "7B", "1.1B", "137M"
            let s = self.parameter_count.trim().to_uppercase();
            if let Some(num_str) = s.strip_suffix('B') {
                num_str.parse::<f64>().unwrap_or(7.0)
            } else if let Some(num_str) = s.strip_suffix('M') {
                // Millions → billions.
                num_str.parse::<f64>().unwrap_or(0.0) / 1000.0
            } else {
                7.0
            }
        }
    }

    /// Approximate on-disk size (GB) for a given quantization level.
    /// This is just the model weights: params_b * bytes_per_param.
    /// (Decimal GB: params in billions × bytes/param.)
    pub fn estimate_disk_gb(&self, quant: &str) -> f64 {
        self.params_b() * quant_bpp(quant)
    }

    /// Estimate memory required (GB) at a given quantization and context length.
    /// Defaults to fp16 KV cache. Use `estimate_memory_gb_with_kv` to override.
    pub fn estimate_memory_gb(&self, quant: &str, ctx: u32) -> f64 {
        self.estimate_memory_gb_with_kv(quant, ctx, KvQuant::Fp16)
    }

    /// Estimate memory required (GB) with an explicit KV cache quantization.
    /// Formula: model_weights + KV_cache + runtime_overhead
    pub fn estimate_memory_gb_with_kv(&self, quant: &str, ctx: u32, kv: KvQuant) -> f64 {
        let bpp = quant_bpp(quant);
        let params = self.params_b();
        let model_mem = params * bpp;
        let kv_cache = self.kv_cache_gb(ctx, kv);
        // Runtime overhead (CUDA/Metal context, buffers)
        let overhead = 0.5;
        model_mem + kv_cache + overhead
    }

    /// KV cache size in GB at the given context length and KV quant.
    ///
    /// Uses the precise per layer formula when `num_hidden_layers`,
    /// `num_key_value_heads`, and `head_dim` are known:
    ///
    /// `kv_bytes = 2 * n_layers * n_kv_heads * head_dim * ctx * dtype_bytes`
    ///
    /// Falls back to a coarse `params * ctx` approximation when the metadata
    /// is missing so older catalog entries don't regress.
    ///
    /// For TurboQuant, only the full attention slice (per `attention_layout`)
    /// is compressed. Linear / state space layers stay at fp16.
    ///
    /// NOTE(review): this path divides by 2^30 (GiB) while the weight
    /// estimates above use decimal GB — the mix is small relative to the
    /// heuristics' error bars but worth unifying.
    pub fn kv_cache_gb(&self, ctx: u32, kv: KvQuant) -> f64 {
        let params = self.params_b();
        let layout = self.effective_attention_layout();

        // Precise path: requires layer count, KV head count, head dim.
        if let (Some(n_layers), Some(head_dim)) = (self.num_hidden_layers, self.head_dim) {
            let n_kv_heads = self
                .num_key_value_heads
                .or(self.num_attention_heads)
                .unwrap_or(8);

            // Per-layer bytes at a given bytes-per-element; the leading 2.0
            // is the K and V tensors.
            let bytes_per_layer =
                |bpe: f64| -> f64 { 2.0 * n_kv_heads as f64 * head_dim as f64 * ctx as f64 * bpe };

            let total_bytes = match kv {
                KvQuant::TurboQuant => {
                    // Compressed slice (full attention) at TQ rate, rest stay fp16.
                    let full_layers = match layout {
                        Some(l) => l.full.min(n_layers),
                        None => n_layers,
                    };
                    let linear_layers = n_layers.saturating_sub(full_layers);
                    bytes_per_layer(KvQuant::TurboQuant.bytes_per_element()) * full_layers as f64
                        + bytes_per_layer(KvQuant::Fp16.bytes_per_element()) * linear_layers as f64
                }
                _ => bytes_per_layer(kv.bytes_per_element()) * n_layers as f64,
            };

            return total_bytes / 1_073_741_824.0;
        }

        // Fallback: coarse linear approximation, scaled by KV quant ratio.
        // Historical formula was 0.000008 * params_b * ctx (assumes fp16).
        let baseline_fp16 = 0.000008 * params * ctx as f64;
        let scale = match kv {
            KvQuant::Fp16 => 1.0,
            KvQuant::Fp8 | KvQuant::Q8_0 => 0.5,
            KvQuant::Q4_0 => 0.25,
            KvQuant::TurboQuant => {
                // Without layer counts we can't separate full vs linear, so
                // weight the savings by the layout if available, otherwise
                // assume an all-full dense transformer.
                let frac = layout.map(|l| l.compressible_fraction()).unwrap_or(1.0);
                let tq_ratio = KvQuant::TurboQuant.bytes_per_element() / 2.0;
                frac * tq_ratio + (1.0 - frac)
            }
        };
        baseline_fp16 * scale
    }

    /// Select the best quantization level that fits within a memory budget.
    /// Returns the quant name and estimated memory in GB, or None if nothing fits.
    pub fn best_quant_for_budget(&self, budget_gb: f64, ctx: u32) -> Option<(&'static str, f64)> {
        self.best_quant_for_budget_with(budget_gb, ctx, QUANT_HIERARCHY)
    }

    /// Select the best quantization from a custom hierarchy that fits within a memory budget.
    ///
    /// Walks the hierarchy best-quality-first at full context, then retries
    /// the whole hierarchy once at half context (if still >= 1024 tokens).
    pub fn best_quant_for_budget_with(
        &self,
        budget_gb: f64,
        ctx: u32,
        hierarchy: &[&'static str],
    ) -> Option<(&'static str, f64)> {
        // Try best quality first
        for &q in hierarchy {
            let mem = self.estimate_memory_gb(q, ctx);
            if mem <= budget_gb {
                return Some((q, mem));
            }
        }
        // Try halving context once
        let half_ctx = ctx / 2;
        if half_ctx >= 1024 {
            for &q in hierarchy {
                let mem = self.estimate_memory_gb(q, half_ctx);
                if mem <= budget_gb {
                    return Some((q, mem));
                }
            }
        }
        None
    }

    /// Resolved attention layout: explicit metadata if present, otherwise a
    /// best effort heuristic based on the model name. Returns `None` for
    /// plain dense transformers (which the KV estimator should treat as
    /// "all layers compressible").
    pub fn effective_attention_layout(&self) -> Option<AttentionLayout> {
        self.attention_layout
            .or_else(|| infer_attention_layout_from_name(&self.name))
    }

    /// For MoE models, compute estimated VRAM for active experts only.
    /// Returns None for dense models.
    ///
    /// Result is GiB (2^30 divisor) with a 10% margin, floored at 0.5.
    pub fn moe_active_vram_gb(&self) -> Option<f64> {
        if !self.is_moe {
            return None;
        }
        let active_params = self.active_parameters? as f64;
        let bpp = self.quant_bpp();
        let size_gb = (active_params * bpp) / (1024.0 * 1024.0 * 1024.0);
        Some((size_gb * 1.1).max(0.5))
    }

    /// Returns true if this model is MLX-specific (Apple Silicon only).
    /// MLX models are identified by having "-MLX" in their name.
    ///
    /// NOTE(review): looser than `is_mlx_model` above (plain substring match).
    pub fn is_mlx_only(&self) -> bool {
        self.name.to_uppercase().contains("-MLX")
    }

    /// For MoE models, compute RAM needed for offloaded (inactive) experts.
    /// Returns None for dense models.
    pub fn moe_offloaded_ram_gb(&self) -> Option<f64> {
        if !self.is_moe {
            return None;
        }
        let active = self.active_parameters? as f64;
        let total = self.parameters_raw? as f64;
        let inactive = total - active;
        // Guard against catalog entries where active >= total.
        if inactive <= 0.0 {
            return Some(0.0);
        }
        let bpp = self.quant_bpp();
        Some((inactive * bpp) / (1024.0 * 1024.0 * 1024.0))
    }
}
631
/// Intermediate struct matching the JSON schema from the scraper.
/// Extra fields are ignored when mapping to LlmModel.
#[derive(Debug, Clone, Deserialize)]
struct HfModelEntry {
    name: String,
    provider: String,
    parameter_count: String,
    #[serde(default)]
    parameters_raw: Option<u64>,
    min_ram_gb: f64,
    recommended_ram_gb: f64,
    min_vram_gb: Option<f64>,
    quantization: String,
    context_length: u32,
    use_case: String,
    #[serde(default)]
    is_moe: bool,
    #[serde(default)]
    num_experts: Option<u32>,
    #[serde(default)]
    active_experts: Option<u32>,
    #[serde(default)]
    active_parameters: Option<u64>,
    #[serde(default)]
    release_date: Option<String>,
    #[serde(default)]
    gguf_sources: Vec<GgufSource>,
    #[serde(default)]
    capabilities: Vec<Capability>,
    #[serde(default)]
    format: ModelFormat,
    // Popularity metrics: deserialized from the scraper JSON but not carried
    // into LlmModel by load_embedded.
    #[serde(default)]
    hf_downloads: u64,
    #[serde(default)]
    hf_likes: u64,
    #[serde(default)]
    num_attention_heads: Option<u32>,
    #[serde(default)]
    num_key_value_heads: Option<u32>,
    #[serde(default)]
    num_hidden_layers: Option<u32>,
    #[serde(default)]
    head_dim: Option<u32>,
    #[serde(default)]
    license: Option<String>,
}
678
/// Catalog JSON produced by the scraper, embedded at compile time.
const HF_MODELS_JSON: &str = include_str!("../data/hf_models.json");

/// In-memory model catalog: the embedded list, optionally merged with a
/// locally cached update (see `ModelDatabase::new`).
pub struct ModelDatabase {
    models: Vec<LlmModel>,
}
684
685impl Default for ModelDatabase {
686    fn default() -> Self {
687        Self::new()
688    }
689}
690
/// Normalize a model name/ID to a canonical slug for deduplication.
///
/// Strips the `org/` prefix, lowercases, and collapses `-`/`_`/`.` so that
/// `meta-llama/Llama-3.1-8B` and `meta-llama/llama-3.1-8b` compare equal.
pub(crate) fn canonical_slug(name: &str) -> String {
    // Last path segment; a name without '/' is its own segment.
    let tail = name.rsplit('/').next().unwrap_or(name);
    let mut slug = String::with_capacity(tail.len());
    for ch in tail.to_lowercase().chars() {
        if !matches!(ch, '-' | '_' | '.') {
            slug.push(ch);
        }
    }
    slug
}
699
/// Parse the compile-time embedded JSON into a flat `Vec<LlmModel>`.
///
/// Panics if the embedded catalog is malformed — acceptable because the
/// JSON is baked in at build time, so a failure is a build artifact bug.
fn load_embedded() -> Vec<LlmModel> {
    let entries: Vec<HfModelEntry> =
        serde_json::from_str(HF_MODELS_JSON).expect("Failed to parse embedded hf_models.json");
    entries
        .into_iter()
        .map(|e| {
            let mut model = LlmModel {
                name: e.name,
                provider: e.provider,
                parameter_count: e.parameter_count,
                parameters_raw: e.parameters_raw,
                min_ram_gb: e.min_ram_gb,
                recommended_ram_gb: e.recommended_ram_gb,
                min_vram_gb: e.min_vram_gb,
                quantization: e.quantization,
                context_length: e.context_length,
                use_case: e.use_case,
                is_moe: e.is_moe,
                num_experts: e.num_experts,
                active_experts: e.active_experts,
                active_parameters: e.active_parameters,
                release_date: e.release_date,
                gguf_sources: e.gguf_sources,
                capabilities: e.capabilities,
                format: e.format,
                num_attention_heads: e.num_attention_heads,
                num_key_value_heads: e.num_key_value_heads,
                num_hidden_layers: e.num_hidden_layers,
                head_dim: e.head_dim,
                attention_layout: None,
                license: e.license,
            };
            // Merge explicit capabilities with name/use-case heuristics.
            model.capabilities = Capability::infer(&model);
            // Auto-populate attention_layout from name heuristic for known
            // hybrid families. Explicit metadata still wins (model.attention_layout
            // stays None until the scraper is taught to read it from config.json).
            // NOTE(review): this `is_none()` check is currently always true —
            // attention_layout is initialized to None a few lines above and
            // HfModelEntry has no such field yet.
            if model.attention_layout.is_none() {
                model.attention_layout = infer_attention_layout_from_name(&model.name);
            }
            model
        })
        .collect()
}
744
impl ModelDatabase {
    /// Load only the compile-time embedded model list (no cache).
    /// Used internally by the updater to determine which models are already known.
    pub fn embedded() -> Self {
        ModelDatabase {
            models: load_embedded(),
        }
    }

    /// Load the embedded model list **and** merge any locally cached models.
    ///
    /// Cached models are appended after the embedded ones; if an ID already
    /// exists in the embedded list it is skipped to avoid duplication.
    /// Silently ignores a missing or corrupt cache file.
    pub fn new() -> Self {
        let mut models = load_embedded();

        // Merge cached models (from `llmfit update`) without duplicating.
        // canonical_slug normalizes org/ prefix, case, and separators so that
        // e.g. `meta-llama/Llama-3.1-8B` and `meta-llama/llama-3.1-8b` are
        // treated as the same model.
        let embedded_keys: std::collections::HashSet<String> =
            models.iter().map(|m| canonical_slug(&m.name)).collect();

        for cached in crate::update::load_cache() {
            if !embedded_keys.contains(&canonical_slug(&cached.name)) {
                models.push(cached);
            }
        }

        ModelDatabase { models }
    }

    /// Borrow the full catalog (embedded + any merged cache entries).
    pub fn get_all_models(&self) -> &Vec<LlmModel> {
        &self.models
    }

    /// Case-insensitive substring search over name, provider, and
    /// parameter-count fields.
    pub fn find_model(&self, query: &str) -> Vec<&LlmModel> {
        let query_lower = query.to_lowercase();
        self.models
            .iter()
            .filter(|m| {
                m.name.to_lowercase().contains(&query_lower)
                    || m.provider.to_lowercase().contains(&query_lower)
                    || m.parameter_count.to_lowercase().contains(&query_lower)
            })
            .collect()
    }

    /// Filter the catalog to models that fit the given system resources.
    ///
    /// RAM is always checked against `min_ram_gb`. VRAM is checked only when
    /// the model declares a requirement AND both GPU presence and VRAM amount
    /// are known; otherwise the model falls back to RAM-based rules.
    pub fn models_fitting_system(
        &self,
        available_ram_gb: f64,
        has_gpu: bool,
        vram_gb: Option<f64>,
    ) -> Vec<&LlmModel> {
        self.models
            .iter()
            .filter(|m| {
                // Check RAM requirement
                let ram_ok = m.min_ram_gb <= available_ram_gb;

                // If model requires GPU and system has GPU, check VRAM
                if let Some(min_vram) = m.min_vram_gb {
                    if has_gpu {
                        if let Some(system_vram) = vram_gb {
                            ram_ok && min_vram <= system_vram
                        } else {
                            // GPU detected but VRAM unknown, allow but warn
                            ram_ok
                        }
                    } else {
                        // Model prefers GPU but can run on CPU with enough RAM
                        ram_ok && available_ram_gb >= m.recommended_ram_gb
                    }
                } else {
                    ram_ok
                }
            })
            .collect()
    }
}
826
827/// Infer an attention layout from the model name for known hybrid families.
828/// Returns `None` for plain dense / all-full transformers (which is the safe
829/// default for the KV cache estimator: assume all layers are compressible).
830///
831/// The numbers here come from the published configs of each family as of
832/// 2026 Q1. They're a best effort starting point and should be replaced
833/// with values scraped from `config.json` whenever the metadata is available.
834pub fn infer_attention_layout_from_name(name: &str) -> Option<AttentionLayout> {
835    let lower = name.to_lowercase();
836
837    // Qwen3-Next series: roughly 1 full attention layer per 4 layers,
838    // remainder are linear / gated DeltaNet style. The A3B (35B total)
839    // variant ships with 10 full out of 40 according to the TurboQuant
840    // benchmark in 0xSero/turboquant.
841    if lower.contains("qwen3-next") || lower.contains("qwen3.5-next") {
842        return Some(AttentionLayout {
843            full: 10,
844            linear: 30,
845        });
846    }
847
848    // Jamba (Mamba + Transformer hybrid). Jamba 1.5 Mini and Large both
849    // use a 1:7 attention to mamba ratio in their 32 layer blocks.
850    if lower.contains("jamba") {
851        return Some(AttentionLayout {
852            full: 4,
853            linear: 28,
854        });
855    }
856
857    // Zamba2 (Mamba2 + shared attention). Zamba2-7B has 2 shared attention
858    // blocks and 54 mamba layers per the model card.
859    if lower.contains("zamba") {
860        return Some(AttentionLayout {
861            full: 2,
862            linear: 54,
863        });
864    }
865
866    // RWKV / Mamba pure SSM models: no full attention at all. We still
867    // report them so the KV estimator can short circuit. Compressible
868    // fraction is 0, so KV quant savings will correctly show as zero.
869    if lower.contains("mamba") || lower.contains("rwkv") {
870        return Some(AttentionLayout { full: 0, linear: 1 });
871    }
872
873    None
874}
875
/// Infer attention and KV head counts from the model name and parameter count.
/// Used as a fallback when explicit head counts are not available in the model
/// metadata.
///
/// Returns `(num_attention_heads, num_key_value_heads)`. Families are matched
/// by case-insensitive substring, first match wins; the per-family thresholds
/// are heuristics, so treat the result as an estimate, not ground truth.
fn infer_heads_from_name(name: &str, params_b: f64) -> (u32, u32) {
    let name_lower = name.to_lowercase();

    // Qwen family
    if name_lower.contains("qwen") {
        return if params_b > 100.0 {
            (128, 16)
        } else if params_b > 50.0 {
            (64, 8)
        } else if params_b > 10.0 {
            // Covers both the ~14B and ~32B tiers (identical head layout).
            (40, 8)
        } else if params_b > 5.0 {
            (32, 8)
        } else {
            (16, 4)
        };
    }

    // Llama family
    if name_lower.contains("llama") {
        // Scout/Maverick (Llama 4 MoE) share the 70B-class head layout.
        return if name_lower.contains("scout")
            || name_lower.contains("maverick")
            || params_b > 60.0
        {
            (64, 8)
        } else if params_b > 20.0 {
            (48, 8)
        } else if params_b > 5.0 {
            (32, 8)
        } else {
            (16, 8)
        };
    }

    // DeepSeek family
    if name_lower.contains("deepseek") {
        return if params_b > 200.0 {
            (128, 16)
        } else if params_b > 50.0 {
            (64, 8)
        } else if params_b > 10.0 {
            // Covers both the ~14B and ~32B tiers (identical head layout).
            (40, 8)
        } else {
            (32, 8)
        };
    }

    // Mistral/Mixtral
    if name_lower.contains("mistral") || name_lower.contains("mixtral") {
        return if params_b > 100.0 { (96, 8) } else { (32, 8) };
    }

    // Gemma
    if name_lower.contains("gemma") {
        return if params_b > 20.0 {
            (32, 16)
        } else if params_b > 5.0 {
            (16, 8)
        } else {
            (8, 4)
        };
    }

    // Phi
    if name_lower.contains("phi") {
        return if params_b > 10.0 { (40, 10) } else { (32, 8) };
    }

    // MiniMax
    if name_lower.contains("minimax") {
        return (48, 8);
    }

    // Default: common pattern based on param count
    if params_b > 100.0 {
        (128, 16)
    } else if params_b > 50.0 {
        (64, 8)
    } else if params_b > 5.0 {
        (32, 8)
    } else {
        (16, 4)
    }
}
977
978#[cfg(test)]
979mod tests {
980    use super::*;
981
982    // ────────────────────────────────────────────────────────────────────
983    // Quantization function tests
984    // ────────────────────────────────────────────────────────────────────
985
986    #[test]
987    fn test_mlx_quant_bpp_values() {
988        assert_eq!(quant_bpp("mlx-4bit"), 0.55);
989        assert_eq!(quant_bpp("mlx-8bit"), 1.0);
990        assert_eq!(quant_speed_multiplier("mlx-4bit"), 1.15);
991        assert_eq!(quant_speed_multiplier("mlx-8bit"), 0.85);
992        assert_eq!(quant_quality_penalty("mlx-4bit"), -4.0);
993        assert_eq!(quant_quality_penalty("mlx-8bit"), 0.0);
994    }
995
    #[test]
    fn test_best_quant_with_mlx_hierarchy() {
        // A plain 7B dense model; dynamic quant selection should walk
        // MLX_QUANT_HIERARCHY from best quality down until one fits.
        let model = LlmModel {
            name: "Test Model".to_string(),
            provider: "Test".to_string(),
            parameter_count: "7B".to_string(),
            parameters_raw: Some(7_000_000_000),
            min_ram_gb: 4.0,
            recommended_ram_gb: 8.0,
            min_vram_gb: Some(4.0),
            quantization: "Q4_K_M".to_string(),
            context_length: 4096,
            use_case: "General".to_string(),
            is_moe: false,
            num_experts: None,
            active_experts: None,
            active_parameters: None,
            release_date: None,
            gguf_sources: vec![],
            capabilities: vec![],
            format: ModelFormat::default(),
            num_attention_heads: None,
            num_key_value_heads: None,
            num_hidden_layers: None,
            head_dim: None,
            attention_layout: None,
            license: None,
        };

        // Large budget should return mlx-8bit (best in MLX hierarchy)
        let result = model.best_quant_for_budget_with(10.0, 4096, MLX_QUANT_HIERARCHY);
        assert!(result.is_some());
        let (quant, _) = result.unwrap();
        assert_eq!(quant, "mlx-8bit");

        // Tighter budget should fall to mlx-4bit
        let result = model.best_quant_for_budget_with(5.0, 4096, MLX_QUANT_HIERARCHY);
        assert!(result.is_some());
        let (quant, _) = result.unwrap();
        assert_eq!(quant, "mlx-4bit");
    }
1037
1038    #[test]
1039    fn test_quant_bpp() {
1040        assert_eq!(quant_bpp("F32"), 4.0);
1041        assert_eq!(quant_bpp("F16"), 2.0);
1042        assert_eq!(quant_bpp("Q8_0"), 1.05);
1043        assert_eq!(quant_bpp("Q4_K_M"), 0.58);
1044        assert_eq!(quant_bpp("Q2_K"), 0.37);
1045        // Unknown quant defaults to Q4_K_M
1046        assert_eq!(quant_bpp("UNKNOWN"), 0.58);
1047    }
1048
1049    #[test]
1050    fn test_quant_speed_multiplier() {
1051        assert_eq!(quant_speed_multiplier("F16"), 0.6);
1052        assert_eq!(quant_speed_multiplier("Q5_K_M"), 1.0);
1053        assert_eq!(quant_speed_multiplier("Q4_K_M"), 1.15);
1054        assert_eq!(quant_speed_multiplier("Q2_K"), 1.35);
1055        // Lower quant = faster inference
1056        assert!(quant_speed_multiplier("Q2_K") > quant_speed_multiplier("Q8_0"));
1057    }
1058
1059    #[test]
1060    fn test_quant_quality_penalty() {
1061        assert_eq!(quant_quality_penalty("F16"), 0.0);
1062        assert_eq!(quant_quality_penalty("Q8_0"), 0.0);
1063        assert_eq!(quant_quality_penalty("Q4_K_M"), -5.0);
1064        assert_eq!(quant_quality_penalty("Q2_K"), -12.0);
1065        // Lower quant = higher quality penalty
1066        assert!(quant_quality_penalty("Q2_K") < quant_quality_penalty("Q8_0"));
1067    }
1068
1069    // ────────────────────────────────────────────────────────────────────
1070    // LlmModel tests
1071    // ────────────────────────────────────────────────────────────────────
1072
1073    #[test]
1074    fn test_params_b_from_raw() {
1075        let model = LlmModel {
1076            name: "Test Model".to_string(),
1077            provider: "Test".to_string(),
1078            parameter_count: "7B".to_string(),
1079            parameters_raw: Some(7_000_000_000),
1080            min_ram_gb: 4.0,
1081            recommended_ram_gb: 8.0,
1082            min_vram_gb: Some(4.0),
1083            quantization: "Q4_K_M".to_string(),
1084            context_length: 4096,
1085            use_case: "General".to_string(),
1086            is_moe: false,
1087            num_experts: None,
1088            active_experts: None,
1089            active_parameters: None,
1090            release_date: None,
1091            gguf_sources: vec![],
1092            capabilities: vec![],
1093            format: ModelFormat::default(),
1094            num_attention_heads: None,
1095            num_key_value_heads: None,
1096            num_hidden_layers: None,
1097            head_dim: None,
1098            attention_layout: None,
1099            license: None,
1100        };
1101        assert_eq!(model.params_b(), 7.0);
1102    }
1103
1104    #[test]
1105    fn test_params_b_from_string() {
1106        let model = LlmModel {
1107            name: "Test Model".to_string(),
1108            provider: "Test".to_string(),
1109            parameter_count: "13B".to_string(),
1110            parameters_raw: None,
1111            min_ram_gb: 8.0,
1112            recommended_ram_gb: 16.0,
1113            min_vram_gb: Some(8.0),
1114            quantization: "Q4_K_M".to_string(),
1115            context_length: 4096,
1116            use_case: "General".to_string(),
1117            is_moe: false,
1118            num_experts: None,
1119            active_experts: None,
1120            active_parameters: None,
1121            release_date: None,
1122            gguf_sources: vec![],
1123            capabilities: vec![],
1124            format: ModelFormat::default(),
1125            num_attention_heads: None,
1126            num_key_value_heads: None,
1127            num_hidden_layers: None,
1128            head_dim: None,
1129            attention_layout: None,
1130            license: None,
1131        };
1132        assert_eq!(model.params_b(), 13.0);
1133    }
1134
1135    #[test]
1136    fn test_params_b_from_millions() {
1137        let model = LlmModel {
1138            name: "Test Model".to_string(),
1139            provider: "Test".to_string(),
1140            parameter_count: "500M".to_string(),
1141            parameters_raw: None,
1142            min_ram_gb: 1.0,
1143            recommended_ram_gb: 2.0,
1144            min_vram_gb: Some(1.0),
1145            quantization: "Q4_K_M".to_string(),
1146            context_length: 2048,
1147            use_case: "General".to_string(),
1148            is_moe: false,
1149            num_experts: None,
1150            active_experts: None,
1151            active_parameters: None,
1152            release_date: None,
1153            gguf_sources: vec![],
1154            capabilities: vec![],
1155            format: ModelFormat::default(),
1156            num_attention_heads: None,
1157            num_key_value_heads: None,
1158            num_hidden_layers: None,
1159            head_dim: None,
1160            attention_layout: None,
1161            license: None,
1162        };
1163        assert_eq!(model.params_b(), 0.5);
1164    }
1165
1166    #[test]
1167    fn test_estimate_memory_gb() {
1168        let model = LlmModel {
1169            name: "Test Model".to_string(),
1170            provider: "Test".to_string(),
1171            parameter_count: "7B".to_string(),
1172            parameters_raw: Some(7_000_000_000),
1173            min_ram_gb: 4.0,
1174            recommended_ram_gb: 8.0,
1175            min_vram_gb: Some(4.0),
1176            quantization: "Q4_K_M".to_string(),
1177            context_length: 4096,
1178            use_case: "General".to_string(),
1179            is_moe: false,
1180            num_experts: None,
1181            active_experts: None,
1182            active_parameters: None,
1183            release_date: None,
1184            gguf_sources: vec![],
1185            capabilities: vec![],
1186            format: ModelFormat::default(),
1187            num_attention_heads: None,
1188            num_key_value_heads: None,
1189            num_hidden_layers: None,
1190            head_dim: None,
1191            attention_layout: None,
1192            license: None,
1193        };
1194
1195        let mem = model.estimate_memory_gb("Q4_K_M", 4096);
1196        // 7B params * 0.58 bytes = 4.06 GB + KV cache + overhead
1197        assert!(mem > 4.0);
1198        assert!(mem < 6.0);
1199
1200        // Q8_0 should require more memory
1201        let mem_q8 = model.estimate_memory_gb("Q8_0", 4096);
1202        assert!(mem_q8 > mem);
1203    }
1204
1205    #[test]
1206    fn test_best_quant_for_budget() {
1207        let model = LlmModel {
1208            name: "Test Model".to_string(),
1209            provider: "Test".to_string(),
1210            parameter_count: "7B".to_string(),
1211            parameters_raw: Some(7_000_000_000),
1212            min_ram_gb: 4.0,
1213            recommended_ram_gb: 8.0,
1214            min_vram_gb: Some(4.0),
1215            quantization: "Q4_K_M".to_string(),
1216            context_length: 4096,
1217            use_case: "General".to_string(),
1218            is_moe: false,
1219            num_experts: None,
1220            active_experts: None,
1221            active_parameters: None,
1222            release_date: None,
1223            gguf_sources: vec![],
1224            capabilities: vec![],
1225            format: ModelFormat::default(),
1226            num_attention_heads: None,
1227            num_key_value_heads: None,
1228            num_hidden_layers: None,
1229            head_dim: None,
1230            attention_layout: None,
1231            license: None,
1232        };
1233
1234        // Large budget should return best quant
1235        let result = model.best_quant_for_budget(10.0, 4096);
1236        assert!(result.is_some());
1237        let (quant, _) = result.unwrap();
1238        assert_eq!(quant, "Q8_0");
1239
1240        // Medium budget should find acceptable quant
1241        let result = model.best_quant_for_budget(5.0, 4096);
1242        assert!(result.is_some());
1243
1244        // Tiny budget should return None
1245        let result = model.best_quant_for_budget(1.0, 4096);
1246        assert!(result.is_none());
1247    }
1248
1249    #[test]
1250    fn test_moe_active_vram_gb() {
1251        // Dense model should return None
1252        let dense_model = LlmModel {
1253            name: "Dense Model".to_string(),
1254            provider: "Test".to_string(),
1255            parameter_count: "7B".to_string(),
1256            parameters_raw: Some(7_000_000_000),
1257            min_ram_gb: 4.0,
1258            recommended_ram_gb: 8.0,
1259            min_vram_gb: Some(4.0),
1260            quantization: "Q4_K_M".to_string(),
1261            context_length: 4096,
1262            use_case: "General".to_string(),
1263            is_moe: false,
1264            num_experts: None,
1265            active_experts: None,
1266            active_parameters: None,
1267            release_date: None,
1268            gguf_sources: vec![],
1269            capabilities: vec![],
1270            format: ModelFormat::default(),
1271            num_attention_heads: None,
1272            num_key_value_heads: None,
1273            num_hidden_layers: None,
1274            head_dim: None,
1275            attention_layout: None,
1276            license: None,
1277        };
1278        assert!(dense_model.moe_active_vram_gb().is_none());
1279
1280        // MoE model should calculate active VRAM
1281        let moe_model = LlmModel {
1282            name: "MoE Model".to_string(),
1283            provider: "Test".to_string(),
1284            parameter_count: "8x7B".to_string(),
1285            parameters_raw: Some(46_700_000_000),
1286            min_ram_gb: 25.0,
1287            recommended_ram_gb: 50.0,
1288            min_vram_gb: Some(25.0),
1289            quantization: "Q4_K_M".to_string(),
1290            context_length: 32768,
1291            use_case: "General".to_string(),
1292            is_moe: true,
1293            num_experts: Some(8),
1294            active_experts: Some(2),
1295            active_parameters: Some(12_900_000_000),
1296            release_date: None,
1297            gguf_sources: vec![],
1298            capabilities: vec![],
1299            format: ModelFormat::default(),
1300            num_attention_heads: None,
1301            num_key_value_heads: None,
1302            num_hidden_layers: None,
1303            head_dim: None,
1304            attention_layout: None,
1305            license: None,
1306        };
1307        let vram = moe_model.moe_active_vram_gb();
1308        assert!(vram.is_some());
1309        let vram_val = vram.unwrap();
1310        // Should be significantly less than full model
1311        assert!(vram_val > 0.0);
1312        assert!(vram_val < 15.0);
1313    }
1314
1315    #[test]
1316    fn test_moe_offloaded_ram_gb() {
1317        // Dense model should return None
1318        let dense_model = LlmModel {
1319            name: "Dense Model".to_string(),
1320            provider: "Test".to_string(),
1321            parameter_count: "7B".to_string(),
1322            parameters_raw: Some(7_000_000_000),
1323            min_ram_gb: 4.0,
1324            recommended_ram_gb: 8.0,
1325            min_vram_gb: Some(4.0),
1326            quantization: "Q4_K_M".to_string(),
1327            context_length: 4096,
1328            use_case: "General".to_string(),
1329            is_moe: false,
1330            num_experts: None,
1331            active_experts: None,
1332            active_parameters: None,
1333            release_date: None,
1334            gguf_sources: vec![],
1335            capabilities: vec![],
1336            format: ModelFormat::default(),
1337            num_attention_heads: None,
1338            num_key_value_heads: None,
1339            num_hidden_layers: None,
1340            head_dim: None,
1341            attention_layout: None,
1342            license: None,
1343        };
1344        assert!(dense_model.moe_offloaded_ram_gb().is_none());
1345
1346        // MoE model should calculate offloaded RAM
1347        let moe_model = LlmModel {
1348            name: "MoE Model".to_string(),
1349            provider: "Test".to_string(),
1350            parameter_count: "8x7B".to_string(),
1351            parameters_raw: Some(46_700_000_000),
1352            min_ram_gb: 25.0,
1353            recommended_ram_gb: 50.0,
1354            min_vram_gb: Some(25.0),
1355            quantization: "Q4_K_M".to_string(),
1356            context_length: 32768,
1357            use_case: "General".to_string(),
1358            is_moe: true,
1359            num_experts: Some(8),
1360            active_experts: Some(2),
1361            active_parameters: Some(12_900_000_000),
1362            release_date: None,
1363            gguf_sources: vec![],
1364            capabilities: vec![],
1365            format: ModelFormat::default(),
1366            num_attention_heads: None,
1367            num_key_value_heads: None,
1368            num_hidden_layers: None,
1369            head_dim: None,
1370            attention_layout: None,
1371            license: None,
1372        };
1373        let offloaded = moe_model.moe_offloaded_ram_gb();
1374        assert!(offloaded.is_some());
1375        let offloaded_val = offloaded.unwrap();
1376        // Should be substantial
1377        assert!(offloaded_val > 10.0);
1378    }
1379
1380    // ────────────────────────────────────────────────────────────────────
1381    // UseCase tests
1382    // ────────────────────────────────────────────────────────────────────
1383
1384    #[test]
1385    fn test_use_case_from_model_coding() {
1386        let model = LlmModel {
1387            name: "codellama-7b".to_string(),
1388            provider: "Meta".to_string(),
1389            parameter_count: "7B".to_string(),
1390            parameters_raw: Some(7_000_000_000),
1391            min_ram_gb: 4.0,
1392            recommended_ram_gb: 8.0,
1393            min_vram_gb: Some(4.0),
1394            quantization: "Q4_K_M".to_string(),
1395            context_length: 4096,
1396            use_case: "Coding".to_string(),
1397            is_moe: false,
1398            num_experts: None,
1399            active_experts: None,
1400            active_parameters: None,
1401            release_date: None,
1402            gguf_sources: vec![],
1403            capabilities: vec![],
1404            format: ModelFormat::default(),
1405            num_attention_heads: None,
1406            num_key_value_heads: None,
1407            num_hidden_layers: None,
1408            head_dim: None,
1409            attention_layout: None,
1410            license: None,
1411        };
1412        assert_eq!(UseCase::from_model(&model), UseCase::Coding);
1413    }
1414
1415    #[test]
1416    fn test_use_case_from_model_embedding() {
1417        let model = LlmModel {
1418            name: "bge-large".to_string(),
1419            provider: "BAAI".to_string(),
1420            parameter_count: "335M".to_string(),
1421            parameters_raw: Some(335_000_000),
1422            min_ram_gb: 1.0,
1423            recommended_ram_gb: 2.0,
1424            min_vram_gb: Some(1.0),
1425            quantization: "F16".to_string(),
1426            context_length: 512,
1427            use_case: "Embedding".to_string(),
1428            is_moe: false,
1429            num_experts: None,
1430            active_experts: None,
1431            active_parameters: None,
1432            release_date: None,
1433            gguf_sources: vec![],
1434            capabilities: vec![],
1435            format: ModelFormat::default(),
1436            num_attention_heads: None,
1437            num_key_value_heads: None,
1438            num_hidden_layers: None,
1439            head_dim: None,
1440            attention_layout: None,
1441            license: None,
1442        };
1443        assert_eq!(UseCase::from_model(&model), UseCase::Embedding);
1444    }
1445
1446    #[test]
1447    fn test_use_case_from_model_reasoning() {
1448        let model = LlmModel {
1449            name: "deepseek-r1-7b".to_string(),
1450            provider: "DeepSeek".to_string(),
1451            parameter_count: "7B".to_string(),
1452            parameters_raw: Some(7_000_000_000),
1453            min_ram_gb: 4.0,
1454            recommended_ram_gb: 8.0,
1455            min_vram_gb: Some(4.0),
1456            quantization: "Q4_K_M".to_string(),
1457            context_length: 8192,
1458            use_case: "Reasoning".to_string(),
1459            is_moe: false,
1460            num_experts: None,
1461            active_experts: None,
1462            active_parameters: None,
1463            release_date: None,
1464            gguf_sources: vec![],
1465            capabilities: vec![],
1466            format: ModelFormat::default(),
1467            num_attention_heads: None,
1468            num_key_value_heads: None,
1469            num_hidden_layers: None,
1470            head_dim: None,
1471            attention_layout: None,
1472            license: None,
1473        };
1474        assert_eq!(UseCase::from_model(&model), UseCase::Reasoning);
1475    }
1476
1477    // ────────────────────────────────────────────────────────────────────
1478    // ModelDatabase tests
1479    // ────────────────────────────────────────────────────────────────────
1480
1481    #[test]
1482    fn test_model_database_new() {
1483        let db = ModelDatabase::new();
1484        let models = db.get_all_models();
1485        // Should have loaded models from embedded JSON
1486        assert!(!models.is_empty());
1487    }
1488
1489    #[test]
1490    fn test_find_model() {
1491        let db = ModelDatabase::new();
1492
1493        // Search by name substring (case insensitive)
1494        let results = db.find_model("llama");
1495        assert!(!results.is_empty());
1496        assert!(
1497            results
1498                .iter()
1499                .any(|m| m.name.to_lowercase().contains("llama"))
1500        );
1501
1502        // Search should be case insensitive
1503        let results_upper = db.find_model("LLAMA");
1504        assert_eq!(results.len(), results_upper.len());
1505    }
1506
1507    #[test]
1508    fn test_models_fitting_system() {
1509        let db = ModelDatabase::new();
1510
1511        // Large system should fit many models
1512        let fitting = db.models_fitting_system(32.0, true, Some(24.0));
1513        assert!(!fitting.is_empty());
1514
1515        // Very small system should fit fewer or no models
1516        let fitting_small = db.models_fitting_system(2.0, false, None);
1517        assert!(fitting_small.len() < fitting.len());
1518
1519        // All fitting models should meet RAM requirements
1520        for model in fitting_small {
1521            assert!(model.min_ram_gb <= 2.0);
1522        }
1523    }
1524
1525    // ────────────────────────────────────────────────────────────────────
1526    // Capability tests
1527    // ────────────────────────────────────────────────────────────────────
1528
1529    #[test]
1530    fn test_capability_infer_vision() {
1531        let model = LlmModel {
1532            name: "meta-llama/Llama-3.2-11B-Vision-Instruct".to_string(),
1533            provider: "Meta".to_string(),
1534            parameter_count: "11B".to_string(),
1535            parameters_raw: Some(11_000_000_000),
1536            min_ram_gb: 6.0,
1537            recommended_ram_gb: 10.0,
1538            min_vram_gb: Some(6.0),
1539            quantization: "Q4_K_M".to_string(),
1540            context_length: 131072,
1541            use_case: "Multimodal, vision and text".to_string(),
1542            is_moe: false,
1543            num_experts: None,
1544            active_experts: None,
1545            active_parameters: None,
1546            release_date: None,
1547            gguf_sources: vec![],
1548            capabilities: vec![],
1549            format: ModelFormat::default(),
1550            num_attention_heads: None,
1551            num_key_value_heads: None,
1552            num_hidden_layers: None,
1553            head_dim: None,
1554            attention_layout: None,
1555            license: None,
1556        };
1557        let caps = Capability::infer(&model);
1558        assert!(caps.contains(&Capability::Vision));
1559        // Also gets ToolUse because "llama-3" + "instruct"
1560        assert!(caps.contains(&Capability::ToolUse));
1561    }
1562
1563    #[test]
1564    fn test_capability_infer_tool_use() {
1565        let model = LlmModel {
1566            name: "Qwen/Qwen3-8B".to_string(),
1567            provider: "Qwen".to_string(),
1568            parameter_count: "8B".to_string(),
1569            parameters_raw: Some(8_000_000_000),
1570            min_ram_gb: 4.5,
1571            recommended_ram_gb: 8.0,
1572            min_vram_gb: Some(4.0),
1573            quantization: "Q4_K_M".to_string(),
1574            context_length: 32768,
1575            use_case: "General purpose text generation".to_string(),
1576            is_moe: false,
1577            num_experts: None,
1578            active_experts: None,
1579            active_parameters: None,
1580            release_date: None,
1581            gguf_sources: vec![],
1582            capabilities: vec![],
1583            format: ModelFormat::default(),
1584            num_attention_heads: None,
1585            num_key_value_heads: None,
1586            num_hidden_layers: None,
1587            head_dim: None,
1588            attention_layout: None,
1589            license: None,
1590        };
1591        let caps = Capability::infer(&model);
1592        assert!(caps.contains(&Capability::ToolUse));
1593        assert!(!caps.contains(&Capability::Vision));
1594    }
1595
1596    #[test]
1597    fn test_capability_infer_none() {
1598        let model = LlmModel {
1599            name: "BAAI/bge-large-en-v1.5".to_string(),
1600            provider: "BAAI".to_string(),
1601            parameter_count: "335M".to_string(),
1602            parameters_raw: Some(335_000_000),
1603            min_ram_gb: 1.0,
1604            recommended_ram_gb: 2.0,
1605            min_vram_gb: Some(1.0),
1606            quantization: "F16".to_string(),
1607            context_length: 512,
1608            use_case: "Text embeddings for RAG".to_string(),
1609            is_moe: false,
1610            num_experts: None,
1611            active_experts: None,
1612            active_parameters: None,
1613            release_date: None,
1614            gguf_sources: vec![],
1615            capabilities: vec![],
1616            format: ModelFormat::default(),
1617            num_attention_heads: None,
1618            num_key_value_heads: None,
1619            num_hidden_layers: None,
1620            head_dim: None,
1621            attention_layout: None,
1622            license: None,
1623        };
1624        let caps = Capability::infer(&model);
1625        assert!(caps.is_empty());
1626    }
1627
1628    #[test]
1629    fn test_capability_preserves_explicit() {
1630        let model = LlmModel {
1631            name: "some-model".to_string(),
1632            provider: "Test".to_string(),
1633            parameter_count: "7B".to_string(),
1634            parameters_raw: Some(7_000_000_000),
1635            min_ram_gb: 4.0,
1636            recommended_ram_gb: 8.0,
1637            min_vram_gb: Some(4.0),
1638            quantization: "Q4_K_M".to_string(),
1639            context_length: 4096,
1640            use_case: "General".to_string(),
1641            is_moe: false,
1642            num_experts: None,
1643            active_experts: None,
1644            active_parameters: None,
1645            release_date: None,
1646            gguf_sources: vec![],
1647            capabilities: vec![Capability::Vision],
1648            format: ModelFormat::default(),
1649            num_attention_heads: None,
1650            num_key_value_heads: None,
1651            num_hidden_layers: None,
1652            head_dim: None,
1653            attention_layout: None,
1654            license: None,
1655        };
1656        let caps = Capability::infer(&model);
1657        // Should keep the explicit Vision and not duplicate it
1658        assert_eq!(caps.iter().filter(|c| **c == Capability::Vision).count(), 1);
1659    }
1660
1661    #[test]
1662    fn test_awq_gptq_quant_values() {
1663        // AWQ
1664        assert_eq!(quant_bpp("AWQ-4bit"), 0.5);
1665        assert_eq!(quant_bpp("AWQ-8bit"), 1.0);
1666        assert_eq!(quant_speed_multiplier("AWQ-4bit"), 1.2);
1667        assert_eq!(quant_speed_multiplier("AWQ-8bit"), 0.85);
1668        assert_eq!(quant_quality_penalty("AWQ-4bit"), -3.0);
1669        assert_eq!(quant_quality_penalty("AWQ-8bit"), 0.0);
1670        // GPTQ
1671        assert_eq!(quant_bpp("GPTQ-Int4"), 0.5);
1672        assert_eq!(quant_bpp("GPTQ-Int8"), 1.0);
1673        assert_eq!(quant_speed_multiplier("GPTQ-Int4"), 1.2);
1674        assert_eq!(quant_speed_multiplier("GPTQ-Int8"), 0.85);
1675        assert_eq!(quant_quality_penalty("GPTQ-Int4"), -3.0);
1676        assert_eq!(quant_quality_penalty("GPTQ-Int8"), 0.0);
1677    }
1678
1679    #[test]
1680    fn test_model_format_prequantized() {
1681        assert!(ModelFormat::Awq.is_prequantized());
1682        assert!(ModelFormat::Gptq.is_prequantized());
1683        assert!(!ModelFormat::Gguf.is_prequantized());
1684        assert!(!ModelFormat::Mlx.is_prequantized());
1685        assert!(!ModelFormat::Safetensors.is_prequantized());
1686    }
1687
1688    // ────────────────────────────────────────────────────────────────────
1689    // GGUF source catalog tests
1690    // ────────────────────────────────────────────────────────────────────
1691
1692    #[test]
1693    fn test_gguf_source_deserialization() {
1694        let json = r#"{"repo": "unsloth/Llama-3.1-8B-Instruct-GGUF", "provider": "unsloth"}"#;
1695        let source: GgufSource = serde_json::from_str(json).unwrap();
1696        assert_eq!(source.repo, "unsloth/Llama-3.1-8B-Instruct-GGUF");
1697        assert_eq!(source.provider, "unsloth");
1698    }
1699
1700    #[test]
1701    fn test_gguf_sources_default_to_empty() {
1702        let json = r#"{
1703            "name": "test/model",
1704            "provider": "Test",
1705            "parameter_count": "7B",
1706            "parameters_raw": 7000000000,
1707            "min_ram_gb": 4.0,
1708            "recommended_ram_gb": 8.0,
1709            "quantization": "Q4_K_M",
1710            "context_length": 4096,
1711            "use_case": "General"
1712        }"#;
1713        let entry: HfModelEntry = serde_json::from_str(json).unwrap();
1714        assert!(entry.gguf_sources.is_empty());
1715    }
1716
1717    #[test]
1718    fn test_catalog_popular_models_have_gguf_sources() {
1719        let db = ModelDatabase::new();
1720        // These popular models should have gguf_sources populated in the catalog
1721        let expected_with_gguf = [
1722            "meta-llama/Llama-3.3-70B-Instruct",
1723            "Qwen/Qwen2.5-7B-Instruct",
1724            "Qwen/Qwen2.5-Coder-7B-Instruct",
1725            "meta-llama/Meta-Llama-3-8B-Instruct",
1726            "mistralai/Mistral-7B-Instruct-v0.3",
1727        ];
1728        for name in &expected_with_gguf {
1729            let model = db.get_all_models().iter().find(|m| m.name == *name);
1730            assert!(model.is_some(), "Model {} should exist in catalog", name);
1731            let model = model.unwrap();
1732            assert!(
1733                !model.gguf_sources.is_empty(),
1734                "Model {} should have gguf_sources but has none",
1735                name
1736            );
1737        }
1738    }
1739
1740    #[test]
1741    fn test_catalog_gguf_sources_have_valid_repos() {
1742        let db = ModelDatabase::new();
1743        for model in db.get_all_models() {
1744            for source in &model.gguf_sources {
1745                assert!(
1746                    source.repo.contains('/'),
1747                    "GGUF source repo '{}' for model '{}' should be owner/repo format",
1748                    source.repo,
1749                    model.name
1750                );
1751                assert!(
1752                    !source.provider.is_empty(),
1753                    "GGUF source provider for model '{}' should not be empty",
1754                    model.name
1755                );
1756                assert!(
1757                    source.repo.to_uppercase().contains("GGUF"),
1758                    "GGUF source repo '{}' for model '{}' should contain 'GGUF'",
1759                    source.repo,
1760                    model.name
1761                );
1762            }
1763        }
1764    }
1765
1766    #[test]
1767    #[ignore] // Requires network access to populate GGUF sources at build time
1768    fn test_catalog_has_significant_gguf_coverage() {
1769        let db = ModelDatabase::new();
1770        let total = db.get_all_models().len();
1771        let with_gguf = db
1772            .get_all_models()
1773            .iter()
1774            .filter(|m| !m.gguf_sources.is_empty())
1775            .count();
1776        // We should have at least 25% coverage after enrichment
1777        let coverage_pct = (with_gguf as f64 / total as f64) * 100.0;
1778        assert!(
1779            coverage_pct >= 25.0,
1780            "GGUF source coverage is only {:.1}% ({}/{}), expected at least 25%",
1781            coverage_pct,
1782            with_gguf,
1783            total
1784        );
1785    }
1786
1787    // ────────────────────────────────────────────────────────────────────
1788    // Tensor parallelism tests
1789    // ────────────────────────────────────────────────────────────────────
1790
1791    fn tp_test_model(
1792        name: &str,
1793        params_b: f64,
1794        attn_heads: Option<u32>,
1795        kv_heads: Option<u32>,
1796    ) -> LlmModel {
1797        LlmModel {
1798            name: name.to_string(),
1799            provider: "Test".to_string(),
1800            parameter_count: format!("{:.0}B", params_b),
1801            parameters_raw: Some((params_b * 1_000_000_000.0) as u64),
1802            min_ram_gb: 4.0,
1803            recommended_ram_gb: 8.0,
1804            min_vram_gb: Some(4.0),
1805            quantization: "Q4_K_M".to_string(),
1806            context_length: 4096,
1807            use_case: "General".to_string(),
1808            is_moe: false,
1809            num_experts: None,
1810            active_experts: None,
1811            active_parameters: None,
1812            release_date: None,
1813            gguf_sources: vec![],
1814            capabilities: vec![],
1815            format: ModelFormat::default(),
1816            num_attention_heads: attn_heads,
1817            num_key_value_heads: kv_heads,
1818            num_hidden_layers: None,
1819            head_dim: None,
1820            attention_layout: None,
1821            license: None,
1822        }
1823    }
1824
1825    #[test]
1826    fn test_supports_tp_with_explicit_heads() {
1827        let model = tp_test_model("Test-8B", 8.0, Some(32), Some(8));
1828        assert!(model.supports_tp(1));
1829        assert!(model.supports_tp(2));
1830        assert!(model.supports_tp(4));
1831        assert!(model.supports_tp(8));
1832        assert!(!model.supports_tp(3)); // 32 % 3 != 0
1833        assert!(!model.supports_tp(5));
1834    }
1835
1836    #[test]
1837    fn test_supports_tp_always_true_for_1() {
1838        let model = tp_test_model("Tiny", 1.0, None, None);
1839        assert!(model.supports_tp(1));
1840    }
1841
1842    #[test]
1843    fn test_valid_tp_sizes_32_8() {
1844        let model = tp_test_model("Test", 8.0, Some(32), Some(8));
1845        let sizes = model.valid_tp_sizes();
1846        assert!(sizes.contains(&1));
1847        assert!(sizes.contains(&2));
1848        assert!(sizes.contains(&4));
1849        assert!(sizes.contains(&8));
1850        assert!(!sizes.contains(&3));
1851    }
1852
1853    #[test]
1854    fn test_valid_tp_sizes_48_heads() {
1855        // 48 attn heads, 8 kv heads — TP must divide both
1856        let model = tp_test_model("Llama-32B", 32.0, Some(48), Some(8));
1857        assert!(model.supports_tp(2)); // 48%2==0, 8%2==0
1858        assert!(!model.supports_tp(3)); // 48%3==0 but 8%3!=0
1859        assert!(model.supports_tp(4)); // 48%4==0, 8%4==0
1860        assert!(model.supports_tp(8)); // 48%8==0, 8%8==0
1861    }
1862
1863    #[test]
1864    fn test_infer_heads_from_name_qwen() {
1865        let (attn, kv) = infer_heads_from_name("Qwen2.5-72B-Instruct", 72.0);
1866        assert_eq!(attn, 64);
1867        assert_eq!(kv, 8);
1868    }
1869
1870    #[test]
1871    fn test_infer_heads_from_name_llama() {
1872        let (attn, kv) = infer_heads_from_name("Llama-3.1-8B", 8.0);
1873        assert_eq!(attn, 32);
1874        assert_eq!(kv, 8);
1875    }
1876
1877    #[test]
1878    fn test_infer_heads_from_name_deepseek() {
1879        let (attn, kv) = infer_heads_from_name("DeepSeek-V3", 671.0);
1880        assert_eq!(attn, 128);
1881        assert_eq!(kv, 16);
1882    }
1883
1884    #[test]
1885    fn test_supports_tp_with_inferred_heads() {
1886        // No explicit heads — should infer from name
1887        let model = tp_test_model("Llama-3.1-70B", 70.0, None, None);
1888        assert!(model.supports_tp(2));
1889        assert!(model.supports_tp(4));
1890        assert!(model.supports_tp(8));
1891    }
1892
1893    // ────────────────────────────────────────────────────────────────────
1894    // KV cache formula + KvQuant + AttentionLayout
1895    // ────────────────────────────────────────────────────────────────────
1896
1897    fn kv_test_model(name: &str) -> LlmModel {
1898        // Roughly modelled on Llama-3.1-8B: 32 layers, 32 heads, 8 KV heads,
1899        // head_dim 128.
1900        LlmModel {
1901            name: name.to_string(),
1902            provider: "Test".to_string(),
1903            parameter_count: "8B".to_string(),
1904            parameters_raw: Some(8_000_000_000),
1905            min_ram_gb: 4.0,
1906            recommended_ram_gb: 8.0,
1907            min_vram_gb: Some(4.0),
1908            quantization: "Q4_K_M".to_string(),
1909            context_length: 8192,
1910            use_case: "General".to_string(),
1911            is_moe: false,
1912            num_experts: None,
1913            active_experts: None,
1914            active_parameters: None,
1915            release_date: None,
1916            gguf_sources: vec![],
1917            capabilities: vec![],
1918            format: ModelFormat::default(),
1919            num_attention_heads: Some(32),
1920            num_key_value_heads: Some(8),
1921            num_hidden_layers: Some(32),
1922            head_dim: Some(128),
1923            attention_layout: None,
1924            license: None,
1925        }
1926    }
1927
1928    #[test]
1929    fn test_kv_quant_from_str_round_trip() {
1930        for kv in KvQuant::all() {
1931            let parsed = KvQuant::parse(kv.label()).expect("label should parse");
1932            assert_eq!(parsed, *kv);
1933        }
1934        assert_eq!(KvQuant::parse("FP16"), Some(KvQuant::Fp16));
1935        assert_eq!(KvQuant::parse("Q4_0"), Some(KvQuant::Q4_0));
1936        assert_eq!(KvQuant::parse("turboquant"), Some(KvQuant::TurboQuant));
1937        assert_eq!(KvQuant::parse("nope"), None);
1938    }
1939
1940    #[test]
1941    fn test_kv_cache_precise_formula_matches_hand_calc() {
1942        // 32 layers * 2 (K+V) * 8 KV heads * 128 head_dim * 8192 ctx * 2 (fp16)
1943        // = 1_073_741_824 bytes ≈ 1.0 GB
1944        let model = kv_test_model("Llama-3.1-8B");
1945        let kv = model.kv_cache_gb(8192, KvQuant::Fp16);
1946        assert!((kv - 1.0).abs() < 0.05, "expected ~1.0 GB, got {:.4}", kv);
1947    }
1948
1949    #[test]
1950    fn test_kv_cache_scales_with_quant() {
1951        let model = kv_test_model("test");
1952        let fp16 = model.kv_cache_gb(8192, KvQuant::Fp16);
1953        let q8 = model.kv_cache_gb(8192, KvQuant::Q8_0);
1954        let q4 = model.kv_cache_gb(8192, KvQuant::Q4_0);
1955        // q8 should be ~half fp16, q4 should be ~quarter
1956        assert!((q8 / fp16 - 0.5).abs() < 0.01);
1957        assert!((q4 / fp16 - 0.25).abs() < 0.01);
1958    }
1959
1960    #[test]
1961    fn test_kv_cache_fallback_when_metadata_missing() {
1962        // No layer/head_dim metadata: should fall back to the linear approx
1963        // and still scale with KvQuant.
1964        let mut model = kv_test_model("nameless");
1965        model.num_hidden_layers = None;
1966        model.head_dim = None;
1967        let fp16 = model.kv_cache_gb(8192, KvQuant::Fp16);
1968        let q4 = model.kv_cache_gb(8192, KvQuant::Q4_0);
1969        assert!(fp16 > 0.0);
1970        assert!(q4 < fp16);
1971    }
1972
1973    #[test]
1974    fn test_turboquant_full_attention_uses_compressed_rate() {
1975        // Pure dense (no layout): TQ should compress every layer.
1976        let model = kv_test_model("dense");
1977        let fp16 = model.kv_cache_gb(8192, KvQuant::Fp16);
1978        let tq = model.kv_cache_gb(8192, KvQuant::TurboQuant);
1979        let ratio = tq / fp16;
1980        // ~0.34 / 2.0 = 0.17 of fp16
1981        assert!(
1982            (0.10..=0.25).contains(&ratio),
1983            "TQ ratio on dense should be ~0.17, got {:.3}",
1984            ratio
1985        );
1986    }
1987
1988    #[test]
1989    fn test_turboquant_hybrid_only_compresses_full_attention() {
1990        // 10 full + 30 linear layers (Qwen3.5-A3B style).
1991        let mut model = kv_test_model("hybrid");
1992        model.num_hidden_layers = Some(40);
1993        model.attention_layout = Some(AttentionLayout {
1994            full: 10,
1995            linear: 30,
1996        });
1997        let fp16 = model.kv_cache_gb(8192, KvQuant::Fp16);
1998        let tq = model.kv_cache_gb(8192, KvQuant::TurboQuant);
1999        let savings = 1.0 - tq / fp16;
2000        // Honest savings should be ~0.83 * 0.25 ≈ 21% (only the 10/40 slice
2001        // is compressed by ~83%). Allow a wide tolerance because the constants
2002        // are deliberately conservative.
2003        assert!(
2004            (0.10..=0.30).contains(&savings),
2005            "expected ~20% honest savings on hybrid model, got {:.3}",
2006            savings
2007        );
2008        // And it must be far from the dense headline of ~83%.
2009        assert!(savings < 0.5);
2010    }
2011
2012    #[test]
2013    fn test_attention_layout_compressible_fraction() {
2014        let dense = AttentionLayout {
2015            full: 32,
2016            linear: 0,
2017        };
2018        assert!((dense.compressible_fraction() - 1.0).abs() < 0.0001);
2019
2020        let hybrid = AttentionLayout {
2021            full: 10,
2022            linear: 30,
2023        };
2024        assert!((hybrid.compressible_fraction() - 0.25).abs() < 0.0001);
2025
2026        let pure_ssm = AttentionLayout {
2027            full: 0,
2028            linear: 64,
2029        };
2030        assert!((pure_ssm.compressible_fraction() - 0.0).abs() < 0.0001);
2031    }
2032
2033    #[test]
2034    fn test_infer_attention_layout_qwen3_next() {
2035        let layout = infer_attention_layout_from_name("Qwen/Qwen3-Next-80B-A3B");
2036        assert!(layout.is_some());
2037        let layout = layout.unwrap();
2038        assert!(layout.full > 0 && layout.linear > 0);
2039        assert!(layout.compressible_fraction() < 0.5);
2040    }
2041
2042    #[test]
2043    fn test_infer_attention_layout_dense_returns_none() {
2044        assert!(infer_attention_layout_from_name("meta-llama/Llama-3.1-8B").is_none());
2045        assert!(infer_attention_layout_from_name("Qwen/Qwen2.5-7B").is_none());
2046    }
2047
2048    #[test]
2049    fn test_effective_attention_layout_prefers_explicit() {
2050        let mut model = kv_test_model("Qwen/Qwen3-Next-80B");
2051        // Explicit metadata should override the heuristic
2052        model.attention_layout = Some(AttentionLayout {
2053            full: 5,
2054            linear: 35,
2055        });
2056        let resolved = model.effective_attention_layout().unwrap();
2057        assert_eq!(resolved.full, 5);
2058        assert_eq!(resolved.linear, 35);
2059    }
2060
2061    #[test]
2062    fn test_estimate_memory_with_kv_q8_smaller_than_fp16() {
2063        let model = kv_test_model("Llama-3.1-8B");
2064        let fp16_total = model.estimate_memory_gb_with_kv("Q4_K_M", 32_768, KvQuant::Fp16);
2065        let q8_total = model.estimate_memory_gb_with_kv("Q4_K_M", 32_768, KvQuant::Q8_0);
2066        let q4_total = model.estimate_memory_gb_with_kv("Q4_K_M", 32_768, KvQuant::Q4_0);
2067        assert!(q8_total < fp16_total);
2068        assert!(q4_total < q8_total);
2069    }
2070}