1use oxibonsai_core::config::Qwen3Config;
8
/// Identifies which Bonsai model a configuration / tensor type corresponds to.
///
/// The three base variants are the only ones produced by
/// [`ModelVariant::from_config`]; the `Ternary*` and `FP8*` variants are
/// produced by [`ModelVariant::from_config_and_sample_tensor_type`] when the
/// sampled tensor type reports itself as ternary or FP8 respectively.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ModelVariant {
    /// Base 8B model (detected from 36 layers, hidden size 4096).
    Bonsai8B,
    /// Base 4B model (detected from 24 layers, hidden size 2560).
    Bonsai4B,
    /// Base 1.7B model (detected from 16 layers, hidden size 1536).
    Bonsai1_7B,
    /// 8B shape whose sampled tensor type is ternary.
    TernaryBonsai8B,
    /// 4B shape whose sampled tensor type is ternary.
    TernaryBonsai4B,
    /// 1.7B shape whose sampled tensor type is ternary.
    TernaryBonsai1_7B,
    /// 8B shape whose sampled tensor type is FP8.
    FP8Bonsai8B,
    /// 4B shape whose sampled tensor type is FP8.
    FP8Bonsai4B,
    /// 1.7B shape whose sampled tensor type is FP8.
    FP8Bonsai1_7B,
    /// Any configuration whose shape matches none of the known models.
    Custom,
}
33
34impl ModelVariant {
35 pub fn from_config(config: &Qwen3Config) -> Self {
40 match (config.num_layers, config.hidden_size) {
41 (36, 4096) => ModelVariant::Bonsai8B,
42 (24, 2560) => ModelVariant::Bonsai4B,
43 (16, 1536) => ModelVariant::Bonsai1_7B,
44 _ => ModelVariant::Custom,
45 }
46 }
47
48 pub fn from_config_and_sample_tensor_type(
53 config: &Qwen3Config,
54 sample_tensor_type: oxibonsai_core::GgufTensorType,
55 ) -> Self {
56 let base = Self::from_config(config);
57 if sample_tensor_type.is_ternary() {
58 match base {
59 Self::Bonsai8B => Self::TernaryBonsai8B,
60 Self::Bonsai4B => Self::TernaryBonsai4B,
61 Self::Bonsai1_7B => Self::TernaryBonsai1_7B,
62 other => other, }
64 } else if sample_tensor_type.is_fp8() {
65 match base {
66 Self::Bonsai8B => Self::FP8Bonsai8B,
67 Self::Bonsai4B => Self::FP8Bonsai4B,
68 Self::Bonsai1_7B => Self::FP8Bonsai1_7B,
69 other => other, }
71 } else {
72 base
73 }
74 }
75
76 pub fn default_config(&self) -> Qwen3Config {
81 match self {
82 ModelVariant::Bonsai8B => Qwen3Config::bonsai_8b(),
83 ModelVariant::Bonsai4B => Qwen3Config::bonsai_4b(),
84 ModelVariant::Bonsai1_7B => Qwen3Config::bonsai_1_7b(),
85 ModelVariant::TernaryBonsai8B => Qwen3Config::ternary_bonsai_8b(),
86 ModelVariant::TernaryBonsai4B => Qwen3Config::ternary_bonsai_4b(),
87 ModelVariant::TernaryBonsai1_7B => Qwen3Config::ternary_bonsai_1_7b(),
88 ModelVariant::FP8Bonsai8B => Qwen3Config::bonsai_8b(),
90 ModelVariant::FP8Bonsai4B => Qwen3Config::bonsai_4b(),
91 ModelVariant::FP8Bonsai1_7B => Qwen3Config::bonsai_1_7b(),
92 ModelVariant::Custom => Qwen3Config::bonsai_8b(),
93 }
94 }
95
96 pub fn name(&self) -> &'static str {
98 match self {
99 ModelVariant::Bonsai8B => "Bonsai-8B",
100 ModelVariant::Bonsai4B => "Bonsai-4B",
101 ModelVariant::Bonsai1_7B => "Bonsai-1.7B",
102 ModelVariant::TernaryBonsai8B => "Ternary-Bonsai-8B",
103 ModelVariant::TernaryBonsai4B => "Ternary-Bonsai-4B",
104 ModelVariant::TernaryBonsai1_7B => "Ternary-Bonsai-1.7B",
105 ModelVariant::FP8Bonsai8B => "FP8-Bonsai-8B",
106 ModelVariant::FP8Bonsai4B => "FP8-Bonsai-4B",
107 ModelVariant::FP8Bonsai1_7B => "FP8-Bonsai-1.7B",
108 ModelVariant::Custom => "Custom",
109 }
110 }
111
112 pub fn param_count(&self) -> u64 {
119 match self {
120 ModelVariant::Bonsai8B | ModelVariant::TernaryBonsai8B | ModelVariant::FP8Bonsai8B => {
121 8_030_000_000
130 }
131 ModelVariant::Bonsai4B | ModelVariant::TernaryBonsai4B | ModelVariant::FP8Bonsai4B => {
132 4_020_000_000
137 }
138 ModelVariant::Bonsai1_7B
139 | ModelVariant::TernaryBonsai1_7B
140 | ModelVariant::FP8Bonsai1_7B => {
141 1_720_000_000
143 }
144 ModelVariant::Custom => 0,
145 }
146 }
147
148 pub fn expected_model_size_bytes(&self) -> u64 {
154 match self {
155 ModelVariant::Bonsai8B => {
156 2_200_000_000
162 }
163 ModelVariant::Bonsai4B => {
164 1_300_000_000
168 }
169 ModelVariant::Bonsai1_7B => {
170 700_000_000
174 }
175 ModelVariant::TernaryBonsai8B => {
176 1_750_000_000
183 }
184 ModelVariant::TernaryBonsai4B => {
185 900_000_000
188 }
189 ModelVariant::TernaryBonsai1_7B => {
190 390_000_000
193 }
194 ModelVariant::FP8Bonsai8B => {
195 8_500_000_000
200 }
201 ModelVariant::FP8Bonsai4B => {
202 5_000_000_000
204 }
205 ModelVariant::FP8Bonsai1_7B => {
206 2_300_000_000
208 }
209 ModelVariant::Custom => 0,
210 }
211 }
212
213 pub fn known_variants() -> &'static [ModelVariant] {
215 &[
216 ModelVariant::Bonsai8B,
217 ModelVariant::Bonsai4B,
218 ModelVariant::Bonsai1_7B,
219 ModelVariant::TernaryBonsai8B,
220 ModelVariant::TernaryBonsai4B,
221 ModelVariant::TernaryBonsai1_7B,
222 ModelVariant::FP8Bonsai8B,
223 ModelVariant::FP8Bonsai4B,
224 ModelVariant::FP8Bonsai1_7B,
225 ]
226 }
227
228 pub fn is_known(&self) -> bool {
230 !matches!(self, ModelVariant::Custom)
231 }
232}
233
234impl std::fmt::Display for ModelVariant {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 write!(f, "{}", self.name())
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::GgufTensorType;

    /// Builds a config whose `(num_layers, hidden_size)` pair matches no known shape.
    fn unrecognized_config() -> Qwen3Config {
        let mut cfg = Qwen3Config::bonsai_8b();
        cfg.num_layers = 48;
        cfg.hidden_size = 8192;
        cfg
    }

    /// Shorthand for the tensor-type-aware detection entry point.
    fn detect(cfg: &Qwen3Config, tensor_type: GgufTensorType) -> ModelVariant {
        ModelVariant::from_config_and_sample_tensor_type(cfg, tensor_type)
    }

    // ---- shape-only detection -------------------------------------------

    #[test]
    fn detect_bonsai_8b() {
        let detected = ModelVariant::from_config(&Qwen3Config::bonsai_8b());
        assert_eq!(detected, ModelVariant::Bonsai8B);
        assert_eq!(ModelVariant::Bonsai8B.name(), "Bonsai-8B");
        assert!(ModelVariant::Bonsai8B.is_known());
    }

    #[test]
    fn detect_bonsai_4b() {
        let detected = ModelVariant::from_config(&Qwen3Config::bonsai_4b());
        assert_eq!(detected, ModelVariant::Bonsai4B);
        assert_eq!(ModelVariant::Bonsai4B.name(), "Bonsai-4B");
        assert!(ModelVariant::Bonsai4B.is_known());
    }

    #[test]
    fn detect_bonsai_1_7b() {
        let detected = ModelVariant::from_config(&Qwen3Config::bonsai_1_7b());
        assert_eq!(detected, ModelVariant::Bonsai1_7B);
        assert_eq!(ModelVariant::Bonsai1_7B.name(), "Bonsai-1.7B");
        assert!(ModelVariant::Bonsai1_7B.is_known());
    }

    #[test]
    fn detect_custom() {
        let detected = ModelVariant::from_config(&unrecognized_config());
        assert_eq!(detected, ModelVariant::Custom);
        assert_eq!(ModelVariant::Custom.name(), "Custom");
        assert!(!ModelVariant::Custom.is_known());
    }

    #[test]
    fn default_configs_roundtrip() {
        // Only the base variants round-trip through `from_config`.
        let base_variants = [
            ModelVariant::Bonsai8B,
            ModelVariant::Bonsai4B,
            ModelVariant::Bonsai1_7B,
        ];
        for variant in base_variants.iter() {
            let detected = ModelVariant::from_config(&variant.default_config());
            assert_eq!(
                *variant, detected,
                "variant {:?} config should round-trip",
                variant
            );
        }
    }

    // ---- static metadata --------------------------------------------------

    #[test]
    fn param_counts_are_reasonable() {
        // (variant, exclusive lower bound, exclusive upper bound)
        let bounds: [(ModelVariant, u64, u64); 3] = [
            (ModelVariant::Bonsai8B, 7_000_000_000, 10_000_000_000),
            (ModelVariant::Bonsai4B, 3_000_000_000, 5_000_000_000),
            (ModelVariant::Bonsai1_7B, 1_000_000_000, 2_500_000_000),
        ];
        for (variant, lower, upper) in bounds.iter() {
            let count = variant.param_count();
            assert!(count > *lower);
            assert!(count < *upper);
        }

        assert_eq!(ModelVariant::Custom.param_count(), 0);
    }

    #[test]
    fn model_sizes_decrease_with_variant() {
        let size_of = |variant: ModelVariant| variant.expected_model_size_bytes();
        let size_8b = size_of(ModelVariant::Bonsai8B);
        let size_4b = size_of(ModelVariant::Bonsai4B);
        let size_1_7b = size_of(ModelVariant::Bonsai1_7B);

        assert!(size_8b > size_4b, "8B should be larger than 4B");
        assert!(size_4b > size_1_7b, "4B should be larger than 1.7B");
        assert!(size_1_7b > 0, "1.7B should have nonzero size");
    }

    #[test]
    fn display_trait() {
        assert_eq!(ModelVariant::Bonsai8B.to_string(), "Bonsai-8B");
        assert_eq!(ModelVariant::Custom.to_string(), "Custom");
    }

    #[test]
    fn known_variants_list() {
        let listed = ModelVariant::known_variants();
        assert_eq!(listed.len(), 9);

        let expected = [
            ModelVariant::Bonsai8B,
            ModelVariant::Bonsai4B,
            ModelVariant::Bonsai1_7B,
            ModelVariant::TernaryBonsai8B,
            ModelVariant::TernaryBonsai4B,
            ModelVariant::TernaryBonsai1_7B,
            ModelVariant::FP8Bonsai8B,
            ModelVariant::FP8Bonsai4B,
            ModelVariant::FP8Bonsai1_7B,
        ];
        for variant in expected.iter() {
            assert!(listed.contains(variant));
        }
    }

    // ---- ternary variants -------------------------------------------------

    #[test]
    fn detect_ternary_8b_by_tensor_type() {
        let cfg = Qwen3Config::ternary_bonsai_8b();
        let detected = detect(&cfg, GgufTensorType::TQ2_0_g128);
        assert_eq!(detected, ModelVariant::TernaryBonsai8B);
    }

    #[test]
    fn detect_bonsai_8b_stays_1bit() {
        let cfg = Qwen3Config::bonsai_8b();
        let detected = detect(&cfg, GgufTensorType::Q1_0_g128);
        assert_eq!(detected, ModelVariant::Bonsai8B);
    }

    #[test]
    fn ternary_variant_param_counts_match_bonsai() {
        let pairs = [
            (ModelVariant::TernaryBonsai8B, ModelVariant::Bonsai8B),
            (ModelVariant::TernaryBonsai4B, ModelVariant::Bonsai4B),
            (ModelVariant::TernaryBonsai1_7B, ModelVariant::Bonsai1_7B),
        ];
        for (ternary, base) in pairs.iter() {
            assert_eq!(ternary.param_count(), base.param_count());
        }
    }

    #[test]
    fn ternary_variant_expected_size_less_than_fp16() {
        let ternary_size = ModelVariant::TernaryBonsai8B.expected_model_size_bytes();
        assert!(
            ternary_size < 2_000_000_000,
            "8B ternary expected < 2 GB, got {}",
            ternary_size
        );
        assert!(
            ternary_size > 1_000_000_000,
            "8B ternary expected > 1 GB, got {}",
            ternary_size
        );
    }

    #[test]
    fn ternary_variants_are_known() {
        let ternary = [
            ModelVariant::TernaryBonsai8B,
            ModelVariant::TernaryBonsai4B,
            ModelVariant::TernaryBonsai1_7B,
        ];
        for variant in ternary.iter() {
            assert!(variant.is_known());
        }
    }

    #[test]
    fn ternary_variant_names() {
        let expected = [
            (ModelVariant::TernaryBonsai8B, "Ternary-Bonsai-8B"),
            (ModelVariant::TernaryBonsai4B, "Ternary-Bonsai-4B"),
            (ModelVariant::TernaryBonsai1_7B, "Ternary-Bonsai-1.7B"),
        ];
        for (variant, name) in expected.iter() {
            assert_eq!(variant.name(), *name);
        }
    }

    #[test]
    fn ternary_display_trait() {
        let expected = [
            (ModelVariant::TernaryBonsai8B, "Ternary-Bonsai-8B"),
            (ModelVariant::TernaryBonsai4B, "Ternary-Bonsai-4B"),
            (ModelVariant::TernaryBonsai1_7B, "Ternary-Bonsai-1.7B"),
        ];
        for (variant, rendered) in expected.iter() {
            assert_eq!(variant.to_string(), *rendered);
        }
    }

    #[test]
    fn ternary_default_configs_roundtrip() {
        // Ternary defaults must carry the same shape as the matching base model.
        let shape_of = |variant: ModelVariant| {
            let cfg = variant.default_config();
            (cfg.num_layers, cfg.hidden_size)
        };

        assert_eq!(shape_of(ModelVariant::TernaryBonsai8B), (36, 4096));
        assert_eq!(shape_of(ModelVariant::TernaryBonsai4B), (24, 2560));
        assert_eq!(shape_of(ModelVariant::TernaryBonsai1_7B), (16, 1536));
    }

    #[test]
    fn detect_ternary_4b_and_1_7b_by_tensor_type() {
        let detected_4b = detect(
            &Qwen3Config::ternary_bonsai_4b(),
            GgufTensorType::TQ2_0_g128,
        );
        assert_eq!(detected_4b, ModelVariant::TernaryBonsai4B);

        let detected_1_7b = detect(
            &Qwen3Config::ternary_bonsai_1_7b(),
            GgufTensorType::TQ2_0_g128,
        );
        assert_eq!(detected_1_7b, ModelVariant::TernaryBonsai1_7B);
    }

    #[test]
    fn custom_stays_custom_with_ternary_type() {
        let detected = detect(&unrecognized_config(), GgufTensorType::TQ2_0_g128);
        assert_eq!(detected, ModelVariant::Custom);
    }

    // ---- FP8 variants -----------------------------------------------------

    #[test]
    fn detect_fp8_e4m3_8b_by_tensor_type() {
        let detected = detect(&Qwen3Config::bonsai_8b(), GgufTensorType::F8_E4M3);
        assert_eq!(detected, ModelVariant::FP8Bonsai8B);
    }

    #[test]
    fn detect_fp8_e5m2_1_7b_by_tensor_type() {
        let detected = detect(&Qwen3Config::bonsai_1_7b(), GgufTensorType::F8_E5M2);
        assert_eq!(detected, ModelVariant::FP8Bonsai1_7B);
    }

    #[test]
    fn fp8_variant_param_counts_match_bonsai() {
        let pairs = [
            (ModelVariant::FP8Bonsai8B, ModelVariant::Bonsai8B),
            (ModelVariant::FP8Bonsai4B, ModelVariant::Bonsai4B),
            (ModelVariant::FP8Bonsai1_7B, ModelVariant::Bonsai1_7B),
        ];
        for (fp8, base) in pairs.iter() {
            assert_eq!(fp8.param_count(), base.param_count());
        }
    }

    #[test]
    fn fp8_variant_names() {
        let expected = [
            (ModelVariant::FP8Bonsai8B, "FP8-Bonsai-8B"),
            (ModelVariant::FP8Bonsai4B, "FP8-Bonsai-4B"),
            (ModelVariant::FP8Bonsai1_7B, "FP8-Bonsai-1.7B"),
        ];
        for (variant, name) in expected.iter() {
            assert_eq!(variant.name(), *name);
        }
    }

    #[test]
    fn fp8_variants_are_known() {
        let fp8 = [
            ModelVariant::FP8Bonsai8B,
            ModelVariant::FP8Bonsai4B,
            ModelVariant::FP8Bonsai1_7B,
        ];
        for variant in fp8.iter() {
            assert!(variant.is_known());
        }
    }
}