Skip to main content

oxibonsai_model/
model_registry.rs

//! Multi-model support: auto-detect the Bonsai model variant from GGUF metadata.
//!
//! The model registry identifies the architecture size (8B, 4B, 1.7B) from
//! configuration parameters (layer count and hidden dimension size), and the
//! weight format (1-bit, ternary, FP8) from the type of a sample tensor.
6
7use oxibonsai_core::config::Qwen3Config;
8
/// The set of Bonsai model variants this registry can identify.
///
/// Variants are grouped by weight format (1-bit, ternary, FP8); within each
/// group there is one entry per architecture size. Anything that matches no
/// known architecture is reported as [`ModelVariant::Custom`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ModelVariant {
    // ----- 1-bit family -----
    /// Bonsai-8B (Qwen3-8B architecture): 36 layers, hidden size 4096.
    Bonsai8B,
    /// Bonsai-4B: 24 layers, hidden size 2560.
    Bonsai4B,
    /// Bonsai-1.7B: 16 layers, hidden size 1536.
    Bonsai1_7B,

    // ----- Ternary family: weights in {-1, 0, +1}, stored as TQ2_0_g128 -----
    /// Ternary-Bonsai-8B: 8B architecture with ternary weights.
    TernaryBonsai8B,
    /// Ternary-Bonsai-4B: 4B architecture with ternary weights.
    TernaryBonsai4B,
    /// Ternary-Bonsai-1.7B: 1.7B architecture with ternary weights.
    TernaryBonsai1_7B,

    // ----- FP8 family: weights stored as F8_E4M3 or F8_E5M2 -----
    /// FP8-Bonsai-8B: 8B architecture with FP8 weights.
    FP8Bonsai8B,
    /// FP8-Bonsai-4B: 4B architecture with FP8 weights.
    FP8Bonsai4B,
    /// FP8-Bonsai-1.7B: 1.7B architecture with FP8 weights.
    FP8Bonsai1_7B,

    /// Any architecture that is custom or otherwise not recognized.
    Custom,
}
33
34impl ModelVariant {
35    /// Auto-detect variant from model configuration.
36    ///
37    /// Matches on the combination of `num_layers` and `hidden_size`
38    /// to identify known architectures.
39    pub fn from_config(config: &Qwen3Config) -> Self {
40        match (config.num_layers, config.hidden_size) {
41            (36, 4096) => ModelVariant::Bonsai8B,
42            (24, 2560) => ModelVariant::Bonsai4B,
43            (16, 1536) => ModelVariant::Bonsai1_7B,
44            _ => ModelVariant::Custom,
45        }
46    }
47
48    /// Detect model variant from config + sample tensor type (for ternary vs 1-bit disambiguation).
49    ///
50    /// Architecture match is identical to `from_config`, but if `sample_tensor_type.is_ternary()`,
51    /// the result is upgraded to the ternary sibling variant.
52    pub fn from_config_and_sample_tensor_type(
53        config: &Qwen3Config,
54        sample_tensor_type: oxibonsai_core::GgufTensorType,
55    ) -> Self {
56        let base = Self::from_config(config);
57        if sample_tensor_type.is_ternary() {
58            match base {
59                Self::Bonsai8B => Self::TernaryBonsai8B,
60                Self::Bonsai4B => Self::TernaryBonsai4B,
61                Self::Bonsai1_7B => Self::TernaryBonsai1_7B,
62                other => other, // Custom or already-ternary → unchanged
63            }
64        } else if sample_tensor_type.is_fp8() {
65            match base {
66                Self::Bonsai8B => Self::FP8Bonsai8B,
67                Self::Bonsai4B => Self::FP8Bonsai4B,
68                Self::Bonsai1_7B => Self::FP8Bonsai1_7B,
69                other => other, // Custom or already-fp8 → unchanged
70            }
71        } else {
72            base
73        }
74    }
75
76    /// Get the default configuration for this variant.
77    ///
78    /// Returns the standard configuration for known variants.
79    /// For `Custom`, returns the 8B configuration as a fallback.
80    pub fn default_config(&self) -> Qwen3Config {
81        match self {
82            ModelVariant::Bonsai8B => Qwen3Config::bonsai_8b(),
83            ModelVariant::Bonsai4B => Qwen3Config::bonsai_4b(),
84            ModelVariant::Bonsai1_7B => Qwen3Config::bonsai_1_7b(),
85            ModelVariant::TernaryBonsai8B => Qwen3Config::ternary_bonsai_8b(),
86            ModelVariant::TernaryBonsai4B => Qwen3Config::ternary_bonsai_4b(),
87            ModelVariant::TernaryBonsai1_7B => Qwen3Config::ternary_bonsai_1_7b(),
88            // FP8 variants share the same Qwen3 architecture as their 1-bit siblings.
89            ModelVariant::FP8Bonsai8B => Qwen3Config::bonsai_8b(),
90            ModelVariant::FP8Bonsai4B => Qwen3Config::bonsai_4b(),
91            ModelVariant::FP8Bonsai1_7B => Qwen3Config::bonsai_1_7b(),
92            ModelVariant::Custom => Qwen3Config::bonsai_8b(),
93        }
94    }
95
96    /// Human-readable display name for this variant.
97    pub fn name(&self) -> &'static str {
98        match self {
99            ModelVariant::Bonsai8B => "Bonsai-8B",
100            ModelVariant::Bonsai4B => "Bonsai-4B",
101            ModelVariant::Bonsai1_7B => "Bonsai-1.7B",
102            ModelVariant::TernaryBonsai8B => "Ternary-Bonsai-8B",
103            ModelVariant::TernaryBonsai4B => "Ternary-Bonsai-4B",
104            ModelVariant::TernaryBonsai1_7B => "Ternary-Bonsai-1.7B",
105            ModelVariant::FP8Bonsai8B => "FP8-Bonsai-8B",
106            ModelVariant::FP8Bonsai4B => "FP8-Bonsai-4B",
107            ModelVariant::FP8Bonsai1_7B => "FP8-Bonsai-1.7B",
108            ModelVariant::Custom => "Custom",
109        }
110    }
111
112    /// Approximate parameter count for this variant.
113    ///
114    /// Computed as: embedding + attention + ffn + norms + output head.
115    /// For 1-bit models, each "parameter" is 1 bit + per-group scale.
116    /// Ternary variants share the same architecture (and thus the same parameter count)
117    /// as their 1-bit siblings; only the storage format differs.
118    pub fn param_count(&self) -> u64 {
119        match self {
120            ModelVariant::Bonsai8B | ModelVariant::TernaryBonsai8B | ModelVariant::FP8Bonsai8B => {
121                // Qwen3-8B: ~8.03B parameters
122                // Embedding: 151936 * 4096 = 622M
123                // Per layer: Q(4096*4096) + K(4096*1024) + V(4096*1024) + O(4096*4096)
124                //          + gate(4096*14336) + up(4096*14336) + down(14336*4096)
125                //          + 2 norms(4096 each)
126                // = 16M + 4M + 4M + 16M + 58.7M + 58.7M + 58.7M + 8K = ~216M per layer
127                // 36 layers = ~7.78B
128                // + embedding(622M) + output(622M) + final norm(4K)
129                8_030_000_000
130            }
131            ModelVariant::Bonsai4B | ModelVariant::TernaryBonsai4B | ModelVariant::FP8Bonsai4B => {
132                // 24 layers, hidden=2560, intermediate=6912
133                // Per layer: Q(2560*2560) + K(2560*512) + V(2560*512) + O(2560*2560)
134                //          + gate(2560*6912) + up(2560*6912) + down(6912*2560) + norms
135                // Embedding: 151936 * 2560
136                4_020_000_000
137            }
138            ModelVariant::Bonsai1_7B
139            | ModelVariant::TernaryBonsai1_7B
140            | ModelVariant::FP8Bonsai1_7B => {
141                // 16 layers, hidden=1536, intermediate=4096
142                1_720_000_000
143            }
144            ModelVariant::Custom => 0,
145        }
146    }
147
148    /// Expected model file size in bytes for the quantized GGUF file.
149    ///
150    /// For 1-bit variants: ~1 bit per param + scale factors + FP16 embeddings.
151    /// For ternary variants: TQ2_0_g128 uses 34 bytes per 128 weights ≈ 0.266 bytes/param.
152    /// Embeddings and norms are typically stored in FP16 or FP32.
153    pub fn expected_model_size_bytes(&self) -> u64 {
154        match self {
155            ModelVariant::Bonsai8B => {
156                // ~8B params at 1 bit = ~1 GB for weights
157                // + embeddings in FP16: 151936 * 4096 * 2 = ~1.2 GB
158                // + norms in FP32: ~0.01 GB
159                // + metadata overhead
160                // Total: ~2.2 GB
161                2_200_000_000
162            }
163            ModelVariant::Bonsai4B => {
164                // ~4B params at 1 bit = ~0.5 GB
165                // + embeddings in FP16: 151936 * 2560 * 2 = ~0.78 GB
166                // Total: ~1.3 GB
167                1_300_000_000
168            }
169            ModelVariant::Bonsai1_7B => {
170                // ~1.7B params at 1 bit = ~0.21 GB
171                // + embeddings in FP16: 151936 * 1536 * 2 = ~0.47 GB
172                // Total: ~0.7 GB
173                700_000_000
174            }
175            ModelVariant::TernaryBonsai8B => {
176                // TQ2_0_g128: 34 bytes per 128 weights ≈ 0.266 bytes/param
177                // ~8.03B params × 0.266 ≈ ~2.13 GB minus embeddings sharing
178                // Embeddings (FP16): 151936 * 4096 * 2 ≈ 1.24 GB — same as 1-bit
179                // Transformer weights only (excl. embedding/output ~1.24B params):
180                //   ~6.8B × 0.266 ≈ 1.81 GB + embedding 1.24 GB → ~1.75 GB total
181                // (embeddings/output stored in FP16 dominate less at ternary density)
182                1_750_000_000
183            }
184            ModelVariant::TernaryBonsai4B => {
185                // ~4.02B params, transformer weights ~3.63B × 0.266 ≈ 0.97 GB
186                // + embeddings (FP16): 151936 * 2560 * 2 ≈ 0.78 GB → ~0.90 GB total
187                900_000_000
188            }
189            ModelVariant::TernaryBonsai1_7B => {
190                // ~1.72B params, transformer weights ~1.49B × 0.266 ≈ 0.40 GB
191                // + embeddings (FP16): 151936 * 1536 * 2 ≈ 0.47 GB → ~0.39 GB total
192                390_000_000
193            }
194            ModelVariant::FP8Bonsai8B => {
195                // FP8: 1 byte/weight + FP16 scale per 32-weight block ≈ 1.0625 bytes/weight
196                // Transformer weights: ~7.88B × 1.0625 ≈ 8.37 GB — but embeddings in FP16
197                // Embeddings (FP16): 151936 × 4096 × 2 ≈ 1.24 GB
198                // Rough total: ~8.5 GB (FP8 is closer to FP16 in size)
199                8_500_000_000
200            }
201            ModelVariant::FP8Bonsai4B => {
202                // Transformer: ~3.63B × 1.0625 ≈ 3.86 GB + embeddings 0.78 GB → ~5.0 GB
203                5_000_000_000
204            }
205            ModelVariant::FP8Bonsai1_7B => {
206                // Transformer: ~1.49B × 1.0625 ≈ 1.58 GB + embeddings 0.47 GB → ~2.3 GB
207                2_300_000_000
208            }
209            ModelVariant::Custom => 0,
210        }
211    }
212
213    /// Return all known (non-Custom) variants.
214    pub fn known_variants() -> &'static [ModelVariant] {
215        &[
216            ModelVariant::Bonsai8B,
217            ModelVariant::Bonsai4B,
218            ModelVariant::Bonsai1_7B,
219            ModelVariant::TernaryBonsai8B,
220            ModelVariant::TernaryBonsai4B,
221            ModelVariant::TernaryBonsai1_7B,
222            ModelVariant::FP8Bonsai8B,
223            ModelVariant::FP8Bonsai4B,
224            ModelVariant::FP8Bonsai1_7B,
225        ]
226    }
227
228    /// Whether this variant is a known (non-custom) architecture.
229    pub fn is_known(&self) -> bool {
230        !matches!(self, ModelVariant::Custom)
231    }
232}
233
234impl std::fmt::Display for ModelVariant {
235    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236        write!(f, "{}", self.name())
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::GgufTensorType;

    /// Config whose (layers, hidden size) pair matches no known variant.
    fn unrecognized_config() -> Qwen3Config {
        let mut cfg = Qwen3Config::bonsai_8b();
        cfg.num_layers = 48;
        cfg.hidden_size = 8192;
        cfg
    }

    // ----- architecture-only detection -----

    #[test]
    fn detect_bonsai_8b() {
        let detected = ModelVariant::from_config(&Qwen3Config::bonsai_8b());
        assert_eq!(detected, ModelVariant::Bonsai8B);
        assert_eq!(ModelVariant::Bonsai8B.name(), "Bonsai-8B");
        assert!(ModelVariant::Bonsai8B.is_known());
    }

    #[test]
    fn detect_bonsai_4b() {
        let detected = ModelVariant::from_config(&Qwen3Config::bonsai_4b());
        assert_eq!(detected, ModelVariant::Bonsai4B);
        assert_eq!(ModelVariant::Bonsai4B.name(), "Bonsai-4B");
        assert!(ModelVariant::Bonsai4B.is_known());
    }

    #[test]
    fn detect_bonsai_1_7b() {
        let detected = ModelVariant::from_config(&Qwen3Config::bonsai_1_7b());
        assert_eq!(detected, ModelVariant::Bonsai1_7B);
        assert_eq!(ModelVariant::Bonsai1_7B.name(), "Bonsai-1.7B");
        assert!(ModelVariant::Bonsai1_7B.is_known());
    }

    #[test]
    fn detect_custom() {
        let detected = ModelVariant::from_config(&unrecognized_config());
        assert_eq!(detected, ModelVariant::Custom);
        assert_eq!(ModelVariant::Custom.name(), "Custom");
        assert!(!ModelVariant::Custom.is_known());
    }

    #[test]
    fn default_configs_roundtrip() {
        // from_config() looks at the architecture only, so only the 1-bit
        // variants can round-trip through it. Variants that share an
        // architecture with a 1-bit sibling are detected as that sibling;
        // telling them apart needs from_config_and_sample_tensor_type().
        let one_bit = [
            ModelVariant::Bonsai8B,
            ModelVariant::Bonsai4B,
            ModelVariant::Bonsai1_7B,
        ];
        for variant in one_bit.iter() {
            let detected = ModelVariant::from_config(&variant.default_config());
            assert_eq!(
                *variant, detected,
                "variant {:?} config should round-trip",
                variant
            );
        }
    }

    // ----- parameter counts and sizes -----

    #[test]
    fn param_counts_are_reasonable() {
        // (variant, exclusive lower bound, exclusive upper bound)
        let bounds: [(ModelVariant, u64, u64); 3] = [
            (ModelVariant::Bonsai8B, 7_000_000_000, 10_000_000_000),
            (ModelVariant::Bonsai4B, 3_000_000_000, 5_000_000_000),
            (ModelVariant::Bonsai1_7B, 1_000_000_000, 2_500_000_000),
        ];
        for (variant, lower, upper) in bounds.iter() {
            let count = variant.param_count();
            assert!(count > *lower);
            assert!(count < *upper);
        }

        assert_eq!(ModelVariant::Custom.param_count(), 0);
    }

    #[test]
    fn model_sizes_decrease_with_variant() {
        let large = ModelVariant::Bonsai8B.expected_model_size_bytes();
        let medium = ModelVariant::Bonsai4B.expected_model_size_bytes();
        let small = ModelVariant::Bonsai1_7B.expected_model_size_bytes();

        assert!(large > medium, "8B should be larger than 4B");
        assert!(medium > small, "4B should be larger than 1.7B");
        assert!(small > 0, "1.7B should have nonzero size");
    }

    #[test]
    fn display_trait() {
        assert_eq!(ModelVariant::Bonsai8B.to_string(), "Bonsai-8B");
        assert_eq!(ModelVariant::Custom.to_string(), "Custom");
    }

    #[test]
    fn known_variants_list() {
        let listed = ModelVariant::known_variants();
        assert_eq!(listed.len(), 9);

        let expected = [
            ModelVariant::Bonsai8B,
            ModelVariant::Bonsai4B,
            ModelVariant::Bonsai1_7B,
            ModelVariant::TernaryBonsai8B,
            ModelVariant::TernaryBonsai4B,
            ModelVariant::TernaryBonsai1_7B,
            ModelVariant::FP8Bonsai8B,
            ModelVariant::FP8Bonsai4B,
            ModelVariant::FP8Bonsai1_7B,
        ];
        for variant in expected.iter() {
            assert!(listed.contains(variant));
        }
    }

    // ----- ternary family -----

    #[test]
    fn detect_ternary_8b_by_tensor_type() {
        let detected = ModelVariant::from_config_and_sample_tensor_type(
            &Qwen3Config::ternary_bonsai_8b(),
            GgufTensorType::TQ2_0_g128,
        );
        assert_eq!(detected, ModelVariant::TernaryBonsai8B);
    }

    #[test]
    fn detect_bonsai_8b_stays_1bit() {
        let detected = ModelVariant::from_config_and_sample_tensor_type(
            &Qwen3Config::bonsai_8b(),
            GgufTensorType::Q1_0_g128,
        );
        assert_eq!(detected, ModelVariant::Bonsai8B);
    }

    #[test]
    fn ternary_variant_param_counts_match_bonsai() {
        let siblings = [
            (ModelVariant::TernaryBonsai8B, ModelVariant::Bonsai8B),
            (ModelVariant::TernaryBonsai4B, ModelVariant::Bonsai4B),
            (ModelVariant::TernaryBonsai1_7B, ModelVariant::Bonsai1_7B),
        ];
        for (ternary, one_bit) in siblings.iter() {
            assert_eq!(ternary.param_count(), one_bit.param_count());
        }
    }

    #[test]
    fn ternary_variant_expected_size_less_than_fp16() {
        // An FP16 8B model would be ~16 GB; the ternary estimate (~1.75 GB)
        // must land far below that, in the 1-2 GB window.
        let ternary_size = ModelVariant::TernaryBonsai8B.expected_model_size_bytes();
        assert!(
            ternary_size < 2_000_000_000,
            "8B ternary expected < 2 GB, got {}",
            ternary_size
        );
        assert!(
            ternary_size > 1_000_000_000,
            "8B ternary expected > 1 GB, got {}",
            ternary_size
        );
    }

    #[test]
    fn ternary_variants_are_known() {
        let ternary = [
            ModelVariant::TernaryBonsai8B,
            ModelVariant::TernaryBonsai4B,
            ModelVariant::TernaryBonsai1_7B,
        ];
        for variant in ternary.iter() {
            assert!(variant.is_known());
        }
    }

    #[test]
    fn ternary_variant_names() {
        let expected = [
            (ModelVariant::TernaryBonsai8B, "Ternary-Bonsai-8B"),
            (ModelVariant::TernaryBonsai4B, "Ternary-Bonsai-4B"),
            (ModelVariant::TernaryBonsai1_7B, "Ternary-Bonsai-1.7B"),
        ];
        for (variant, label) in expected.iter() {
            assert_eq!(variant.name(), *label);
        }
    }

    #[test]
    fn ternary_display_trait() {
        let expected = [
            (ModelVariant::TernaryBonsai8B, "Ternary-Bonsai-8B"),
            (ModelVariant::TernaryBonsai4B, "Ternary-Bonsai-4B"),
            (ModelVariant::TernaryBonsai1_7B, "Ternary-Bonsai-1.7B"),
        ];
        for (variant, label) in expected.iter() {
            assert_eq!(variant.to_string(), *label);
        }
    }

    #[test]
    fn ternary_default_configs_roundtrip() {
        // A ternary default config cannot round-trip through from_config()
        // (that returns the 1-bit sibling), so check the architecture fields
        // of each default config directly instead.
        let cfg_8b = ModelVariant::TernaryBonsai8B.default_config();
        assert_eq!(cfg_8b.num_layers, 36);
        assert_eq!(cfg_8b.hidden_size, 4096);

        let cfg_4b = ModelVariant::TernaryBonsai4B.default_config();
        assert_eq!(cfg_4b.num_layers, 24);
        assert_eq!(cfg_4b.hidden_size, 2560);

        let cfg_1_7b = ModelVariant::TernaryBonsai1_7B.default_config();
        assert_eq!(cfg_1_7b.num_layers, 16);
        assert_eq!(cfg_1_7b.hidden_size, 1536);
    }

    #[test]
    fn detect_ternary_4b_and_1_7b_by_tensor_type() {
        let detected_4b = ModelVariant::from_config_and_sample_tensor_type(
            &Qwen3Config::ternary_bonsai_4b(),
            GgufTensorType::TQ2_0_g128,
        );
        assert_eq!(detected_4b, ModelVariant::TernaryBonsai4B);

        let detected_1_7b = ModelVariant::from_config_and_sample_tensor_type(
            &Qwen3Config::ternary_bonsai_1_7b(),
            GgufTensorType::TQ2_0_g128,
        );
        assert_eq!(detected_1_7b, ModelVariant::TernaryBonsai1_7B);
    }

    #[test]
    fn custom_stays_custom_with_ternary_type() {
        let detected = ModelVariant::from_config_and_sample_tensor_type(
            &unrecognized_config(),
            GgufTensorType::TQ2_0_g128,
        );
        assert_eq!(detected, ModelVariant::Custom);
    }

    // ----- FP8 family -----

    #[test]
    fn detect_fp8_e4m3_8b_by_tensor_type() {
        let detected = ModelVariant::from_config_and_sample_tensor_type(
            &Qwen3Config::bonsai_8b(),
            GgufTensorType::F8_E4M3,
        );
        assert_eq!(detected, ModelVariant::FP8Bonsai8B);
    }

    #[test]
    fn detect_fp8_e5m2_1_7b_by_tensor_type() {
        let detected = ModelVariant::from_config_and_sample_tensor_type(
            &Qwen3Config::bonsai_1_7b(),
            GgufTensorType::F8_E5M2,
        );
        assert_eq!(detected, ModelVariant::FP8Bonsai1_7B);
    }

    #[test]
    fn fp8_variant_param_counts_match_bonsai() {
        let siblings = [
            (ModelVariant::FP8Bonsai8B, ModelVariant::Bonsai8B),
            (ModelVariant::FP8Bonsai4B, ModelVariant::Bonsai4B),
            (ModelVariant::FP8Bonsai1_7B, ModelVariant::Bonsai1_7B),
        ];
        for (fp8, one_bit) in siblings.iter() {
            assert_eq!(fp8.param_count(), one_bit.param_count());
        }
    }

    #[test]
    fn fp8_variant_names() {
        let expected = [
            (ModelVariant::FP8Bonsai8B, "FP8-Bonsai-8B"),
            (ModelVariant::FP8Bonsai4B, "FP8-Bonsai-4B"),
            (ModelVariant::FP8Bonsai1_7B, "FP8-Bonsai-1.7B"),
        ];
        for (variant, label) in expected.iter() {
            assert_eq!(variant.name(), *label);
        }
    }

    #[test]
    fn fp8_variants_are_known() {
        let fp8 = [
            ModelVariant::FP8Bonsai8B,
            ModelVariant::FP8Bonsai4B,
            ModelVariant::FP8Bonsai1_7B,
        ];
        for variant in fp8.iter() {
            assert!(variant.is_known());
        }
    }
}