1use serde::{Deserialize, Serialize};
6
// ---- LLaMA-2 / Mistral presets -------------------------------------------
const LLAMA2_7B_INTERMEDIATE_SIZE: usize = 11_008;
const LLAMA2_13B_HIDDEN_SIZE: usize = 5_120;
const LLAMA2_13B_INTERMEDIATE_SIZE: usize = 13_824;
// Vocabulary size shared by the LLaMA-2 and Mistral presets.
const LLAMA_VOCAB_SIZE: usize = 32_000;
const MISTRAL_INTERMEDIATE_SIZE: usize = 14_336;
const MISTRAL_MAX_SEQ_LEN: usize = 32_768;

// ---- Qwen2 presets ---------------------------------------------------------
const QWEN2_0_5B_HIDDEN_SIZE: usize = 896;
const QWEN2_0_5B_INTERMEDIATE_SIZE: usize = 4_864;
// Also used as the vocabulary size of the Qwen3-4B preset.
const QWEN2_VOCAB_SIZE: usize = 151_936;
const QWEN2_MAX_SEQ_LEN: usize = 32_768;
// Also used as the RoPE base of the Qwen3 and Qwen3.5 presets.
const QWEN2_ROPE_THETA: f32 = 1_000_000.0;

// ---- Qwen3 / Qwen3.5 presets ----------------------------------------------
const QWEN3_4B_HIDDEN_SIZE: usize = 2_560;
const QWEN3_4B_INTERMEDIATE_SIZE: usize = 9_728;
const QWEN3_5_9B_HIDDEN_SIZE: usize = 4_096;
const QWEN3_5_9B_INTERMEDIATE_SIZE: usize = 12_288;
const QWEN3_5_VOCAB_SIZE: usize = 248_320;
const QWEN3_5_MAX_SEQ_LEN: usize = 262_144;

// RoPE base used whenever a preset or metadata source does not supply one.
const DEFAULT_ROPE_THETA: f32 = 10_000.0;

// ---- CodeBERT preset -------------------------------------------------------
const CODEBERT_HIDDEN_SIZE: usize = 768;
const CODEBERT_INTERMEDIATE_SIZE: usize = 3_072;
const CODEBERT_VOCAB_SIZE: usize = 50_265;
31const CODEBERT_MAX_POSITION: usize = 514; #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
37#[serde(rename_all = "snake_case")]
38pub enum ModelArchitecture {
39 #[default]
41 Decoder,
42 Encoder,
44}
45
/// Hyperparameters describing one transformer model.
///
/// Serializable through serde. Fields marked `#[serde(default)]` may be
/// absent from serialized input and then take their type's default value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerConfig {
    /// Hidden-state width; numerator of the derived head dimension.
    pub hidden_size: usize,
    /// Number of attention (query) heads; denominator of the derived head
    /// dimension.
    pub num_attention_heads: usize,
    /// Number of key/value heads; `from_apr_metadata` falls back to
    /// `num_attention_heads` when the metadata omits it.
    pub num_kv_heads: usize,
    /// Intermediate width used by the per-layer size estimators.
    pub intermediate_size: usize,
    /// Number of layers; multiplies the per-layer VRAM estimates.
    pub num_hidden_layers: usize,
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Upper bound used by the `max_seq_len_for_vram*` solvers.
    pub max_position_embeddings: usize,
    /// Normalization epsilon (1e-5 or 1e-6 in the built-in presets).
    pub rms_norm_eps: f32,
    /// RoPE base; the `codebert` preset stores 0.0 here.
    pub rope_theta: f32,
    /// Bias flag; also one half of the Qwen2 detection heuristic used by the
    /// `hf_*` helpers and `ties_embeddings`.
    pub use_bias: bool,
    /// When set, `head_dim()` returns this value instead of
    /// `hidden_size / num_attention_heads`.
    #[serde(default)]
    pub head_dim_override: Option<usize>,
    /// Model family; defaults to `ModelArchitecture::Decoder` when absent.
    #[serde(default)]
    pub architecture: ModelArchitecture,
    /// Explicit architecture name returned by `hf_architecture_name()`,
    /// taking precedence over the inferred one.
    #[serde(default)]
    pub hf_architecture: Option<String>,
    /// Explicit model type returned by `hf_model_type_str()`, taking
    /// precedence over the inferred one.
    #[serde(default)]
    pub hf_model_type: Option<String>,
    /// When true, `ties_embeddings()` returns true unconditionally.
    #[serde(default)]
    pub tie_word_embeddings: bool,
}
90
91impl TransformerConfig {
92 pub fn llama2_7b() -> Self {
94 Self {
95 hidden_size: 4096,
96 num_attention_heads: 32,
97 num_kv_heads: 32,
98 intermediate_size: LLAMA2_7B_INTERMEDIATE_SIZE,
99 num_hidden_layers: 32,
100 vocab_size: LLAMA_VOCAB_SIZE,
101 max_position_embeddings: 4096,
102 rms_norm_eps: 1e-6,
103 rope_theta: DEFAULT_ROPE_THETA,
104 use_bias: false,
105 head_dim_override: None,
106 architecture: ModelArchitecture::Decoder,
107 hf_architecture: None,
108 hf_model_type: None,
109 tie_word_embeddings: false,
110 }
111 }
112
113 pub fn llama2_13b() -> Self {
115 Self {
116 hidden_size: LLAMA2_13B_HIDDEN_SIZE,
117 num_attention_heads: 40,
118 num_kv_heads: 40,
119 intermediate_size: LLAMA2_13B_INTERMEDIATE_SIZE,
120 num_hidden_layers: 40,
121 vocab_size: LLAMA_VOCAB_SIZE,
122 max_position_embeddings: 4096,
123 rms_norm_eps: 1e-6,
124 rope_theta: DEFAULT_ROPE_THETA,
125 use_bias: false,
126 head_dim_override: None,
127 architecture: ModelArchitecture::Decoder,
128 hf_architecture: None,
129 hf_model_type: None,
130 tie_word_embeddings: false,
131 }
132 }
133
134 pub fn mistral_7b() -> Self {
136 Self {
137 hidden_size: 4096,
138 num_attention_heads: 32,
139 num_kv_heads: 8, intermediate_size: MISTRAL_INTERMEDIATE_SIZE,
141 num_hidden_layers: 32,
142 vocab_size: LLAMA_VOCAB_SIZE,
143 max_position_embeddings: MISTRAL_MAX_SEQ_LEN,
144 rms_norm_eps: 1e-5,
145 rope_theta: DEFAULT_ROPE_THETA,
146 use_bias: false,
147 head_dim_override: None,
148 architecture: ModelArchitecture::Decoder,
149 hf_architecture: None,
150 hf_model_type: None,
151 tie_word_embeddings: false,
152 }
153 }
154
155 pub fn qwen2_0_5b() -> Self {
170 Self {
171 hidden_size: QWEN2_0_5B_HIDDEN_SIZE,
172 num_attention_heads: 14,
173 num_kv_heads: 2,
174 intermediate_size: QWEN2_0_5B_INTERMEDIATE_SIZE,
175 num_hidden_layers: 24,
176 vocab_size: QWEN2_VOCAB_SIZE,
177 max_position_embeddings: QWEN2_MAX_SEQ_LEN,
178 rms_norm_eps: 1e-6,
179 rope_theta: QWEN2_ROPE_THETA,
180 use_bias: true,
181 head_dim_override: None,
182 architecture: ModelArchitecture::Decoder,
183 hf_architecture: None,
184 hf_model_type: None,
185 tie_word_embeddings: true,
186 }
187 }
188
189 #[rustfmt::skip]
191 pub fn qwen2_1_5b() -> Self { Self { hidden_size: 1536, num_attention_heads: 12, intermediate_size: 8960, num_hidden_layers: 28, vocab_size: 151936, ..Self::qwen2_0_5b() } }
192
193 pub fn qwen2_7b() -> Self {
198 Self {
199 hidden_size: 3584,
200 num_attention_heads: 28,
201 num_kv_heads: 4,
202 intermediate_size: 18944,
203 num_hidden_layers: 28,
204 vocab_size: 152064,
205 max_position_embeddings: QWEN2_MAX_SEQ_LEN,
206 rms_norm_eps: 1e-6,
207 rope_theta: QWEN2_ROPE_THETA,
208 use_bias: true,
209 head_dim_override: None,
210 architecture: ModelArchitecture::Decoder,
211 hf_architecture: None,
212 hf_model_type: None,
213 tie_word_embeddings: false,
214 }
215 }
216
217 pub fn qwen3_4b() -> Self {
222 Self {
223 hidden_size: QWEN3_4B_HIDDEN_SIZE,
224 num_attention_heads: 32,
225 num_kv_heads: 8,
226 intermediate_size: QWEN3_4B_INTERMEDIATE_SIZE,
227 num_hidden_layers: 36,
228 vocab_size: QWEN2_VOCAB_SIZE, max_position_embeddings: 40960,
230 rms_norm_eps: 1e-6,
231 rope_theta: QWEN2_ROPE_THETA, use_bias: false, head_dim_override: Some(128), architecture: ModelArchitecture::Decoder,
235 hf_architecture: None,
236 hf_model_type: None,
237 tie_word_embeddings: false,
238 }
239 }
240
241 pub fn qwen3_5_9b() -> Self {
247 Self {
248 hidden_size: QWEN3_5_9B_HIDDEN_SIZE,
249 num_attention_heads: 16,
250 num_kv_heads: 4,
251 intermediate_size: QWEN3_5_9B_INTERMEDIATE_SIZE,
252 num_hidden_layers: 32,
253 vocab_size: QWEN3_5_VOCAB_SIZE,
254 max_position_embeddings: QWEN3_5_MAX_SEQ_LEN,
255 rms_norm_eps: 1e-6,
256 rope_theta: QWEN2_ROPE_THETA, use_bias: false, head_dim_override: None, architecture: ModelArchitecture::Decoder,
260 hf_architecture: None,
261 hf_model_type: None,
262 tie_word_embeddings: false,
263 }
264 }
265
266 pub fn from_apr_metadata(
278 hidden_size: Option<usize>,
279 num_heads: Option<usize>,
280 num_kv_heads: Option<usize>,
281 intermediate_size: Option<usize>,
282 num_layers: Option<usize>,
283 vocab_size: Option<usize>,
284 max_position_embeddings: Option<usize>,
285 rms_norm_eps: Option<f32>,
286 rope_theta: Option<f32>,
287 architecture: Option<&str>,
288 ) -> Option<Self> {
289 let hidden = hidden_size?;
290 let heads = num_heads?;
291 let layers = num_layers?;
292 let vocab = vocab_size?;
293 let intermediate = intermediate_size?;
294
295 let (use_bias, head_dim_override) = match architecture {
298 Some(a) if a.starts_with("qwen3") => {
299 let computed = hidden / heads;
301 let override_dim = if computed == 128 { None } else { Some(128) };
302 (false, override_dim)
303 }
304 Some(a) if a.starts_with("qwen2") => (true, None),
305 _ => (false, None),
306 };
307
308 Some(Self {
309 hidden_size: hidden,
310 num_attention_heads: heads,
311 num_kv_heads: num_kv_heads.unwrap_or(heads),
312 intermediate_size: intermediate,
313 num_hidden_layers: layers,
314 vocab_size: vocab,
315 max_position_embeddings: max_position_embeddings.unwrap_or(32768),
316 rms_norm_eps: rms_norm_eps.unwrap_or(1e-6),
317 rope_theta: rope_theta.unwrap_or(DEFAULT_ROPE_THETA),
318 use_bias,
319 head_dim_override,
320 architecture: match architecture {
321 Some(a) if a.contains("bert") || a.contains("roberta") => {
322 ModelArchitecture::Encoder
323 }
324 _ => ModelArchitecture::Decoder,
325 },
326 hf_architecture: None,
327 hf_model_type: None,
328 tie_word_embeddings: false,
329 })
330 }
331
332 pub fn from_size_str(size: &str) -> Result<Self, String> {
338 match size {
339 "codebert" | "codebert-base" | "125M" => Ok(Self::codebert()),
340 "0.5B" | "500M" | "qwen2-0.5b" => Ok(Self::qwen2_0_5b()),
341 "1.5B" | "qwen2.5-1.5b" | "qwen2-1.5b" => Ok(Self::qwen2_1_5b()),
342 "7B" | "qwen2.5-7b" => Ok(Self::qwen2_7b()),
343 "4B" | "qwen3-4b" | "qwen3" => Ok(Self::qwen3_4b()),
344 "9B" | "qwen3.5-9b" | "qwen3_5" | "qwen3.5" => Ok(Self::qwen3_5_9b()),
345 unknown => Err(format!(
346 "Unknown model size '{unknown}'. Known sizes: codebert, 0.5B, 4B, 7B, 9B"
347 )),
348 }
349 }
350
351 pub fn codebert() -> Self {
356 Self {
357 hidden_size: CODEBERT_HIDDEN_SIZE,
358 num_attention_heads: 12,
359 num_kv_heads: 12, intermediate_size: CODEBERT_INTERMEDIATE_SIZE,
361 num_hidden_layers: 12,
362 vocab_size: CODEBERT_VOCAB_SIZE,
363 max_position_embeddings: CODEBERT_MAX_POSITION,
364 rms_norm_eps: 1e-5, rope_theta: 0.0, use_bias: true,
367 head_dim_override: None,
368 architecture: ModelArchitecture::Encoder,
369 hf_architecture: None,
370 hf_model_type: None,
371 tie_word_embeddings: false,
372 }
373 }
374
375 pub fn tiny() -> Self {
377 Self {
378 hidden_size: 64,
379 num_attention_heads: 2,
380 num_kv_heads: 2,
381 intermediate_size: 256,
382 num_hidden_layers: 2,
383 vocab_size: 1000,
384 max_position_embeddings: 512,
385 rms_norm_eps: 1e-6,
386 rope_theta: DEFAULT_ROPE_THETA,
387 use_bias: false,
388 head_dim_override: None,
389 architecture: ModelArchitecture::Decoder,
390 hf_architecture: None,
391 hf_model_type: None,
392 tie_word_embeddings: false,
393 }
394 }
395
396 pub fn is_encoder(&self) -> bool {
398 self.architecture == ModelArchitecture::Encoder
399 }
400
401 pub fn hf_architecture_name(&self) -> &str {
404 if let Some(ref name) = self.hf_architecture {
405 return name;
406 }
407 if self.is_encoder() {
409 "BertModel"
410 } else if self.use_bias && self.vocab_size > 150000 {
411 "Qwen2ForCausalLM"
413 } else {
414 "LlamaForCausalLM"
415 }
416 }
417
418 pub fn hf_model_type_str(&self) -> &str {
420 if let Some(ref mt) = self.hf_model_type {
421 return mt;
422 }
423 if self.is_encoder() {
424 "roberta"
425 } else if self.use_bias && self.vocab_size > 150000 {
426 "qwen2"
427 } else {
428 "llama"
429 }
430 }
431
432 pub fn ties_embeddings(&self) -> bool {
435 if self.tie_word_embeddings {
436 return true;
437 }
438 self.use_bias && self.vocab_size > 150000
440 }
441
442 pub fn head_dim(&self) -> usize {
447 self.head_dim_override.unwrap_or(self.hidden_size / self.num_attention_heads)
448 }
449
450 pub fn q_dim(&self) -> usize {
455 self.num_attention_heads * self.head_dim()
456 }
457
458 fn kv_dim(&self) -> usize {
468 self.num_kv_heads * self.head_dim()
469 }
470
471 pub fn per_layer_weight_elements(&self) -> usize {
475 let h = self.hidden_size;
476 let q = self.q_dim();
477 let kv = self.kv_dim();
478 let i = self.intermediate_size;
479 q * h + kv * h * 2 + h * q + i * h * 3 + h * 2
482 }
483
484 fn per_layer_grad_weight_elements(&self) -> usize {
489 let h = self.hidden_size;
490 let q = self.q_dim();
491 let kv = self.kv_dim();
492 let i = self.intermediate_size;
493 h * 2 + h * i * 3 + q * h + h * q + h * kv * 2
497 }
498
499 fn per_layer_scratch_linear_coeff(&self) -> usize {
503 let h = self.hidden_size;
504 let kv = self.kv_dim();
505 let i = self.intermediate_size;
506 let n = self.num_attention_heads;
507 let hd = self.head_dim();
508 h * 8 + kv * 2 + i * 4 + n * hd * 3
513 }
514
515 fn per_layer_scratch_quadratic_coeff(&self) -> (usize, usize) {
524 let n = self.num_attention_heads;
525 let hd = self.head_dim();
526 (n, n * hd) }
532
533 pub fn total_training_vram_bytes(&self, max_seq_len: usize) -> usize {
537 let l = self.num_hidden_layers;
538 let s = max_seq_len;
539 let hd = self.head_dim();
540
541 let constant_per_layer =
542 self.per_layer_weight_elements() + self.per_layer_grad_weight_elements();
543 let linear_per_layer = self.per_layer_scratch_linear_coeff() * s;
544
545 let (n_quad, n_hd_linear) = self.per_layer_scratch_quadratic_coeff();
546 let quadratic_per_layer =
547 if s >= hd { 2 * n_quad * s * s } else { n_quad * s * s + n_hd_linear * s };
548
549 let elements_per_layer = constant_per_layer + linear_per_layer + quadratic_per_layer;
550 l * elements_per_layer * 4 }
552
553 pub fn total_training_vram_bytes_shared(&self, max_seq_len: usize) -> usize {
560 let l = self.num_hidden_layers;
561 let s = max_seq_len;
562 let hd = self.head_dim();
563
564 let weights_total = l * self.per_layer_weight_elements();
566
567 let grad_weights_shared = self.per_layer_grad_weight_elements();
569
570 let linear_shared = self.per_layer_scratch_linear_coeff() * s;
572 let (n_quad, n_hd_linear) = self.per_layer_scratch_quadratic_coeff();
573 let quadratic_shared =
574 if s >= hd { 2 * n_quad * s * s } else { n_quad * s * s + n_hd_linear * s };
575
576 let total_elements = weights_total + grad_weights_shared + linear_shared + quadratic_shared;
577 total_elements * 4 }
579
580 pub fn max_seq_len_for_vram_shared(&self, vram_bytes: usize) -> Option<usize> {
586 if self.total_training_vram_bytes_shared(1) > vram_bytes {
587 return None;
588 }
589
590 let mut lo: usize = 1;
591 let mut hi: usize = self.max_position_embeddings;
592
593 while lo < hi {
594 let mid = lo + (hi - lo).div_ceil(2);
595 if self.total_training_vram_bytes_shared(mid) <= vram_bytes {
596 lo = mid;
597 } else {
598 hi = mid - 1;
599 }
600 }
601
602 Some(lo)
603 }
604
605 pub fn max_seq_len_for_vram(&self, vram_bytes: usize) -> Option<usize> {
613 if self.total_training_vram_bytes(1) > vram_bytes {
614 return None;
615 }
616
617 let mut lo: usize = 1;
618 let mut hi: usize = self.max_position_embeddings;
619
620 while lo < hi {
621 let mid = lo + (hi - lo).div_ceil(2);
622 if self.total_training_vram_bytes(mid) <= vram_bytes {
623 lo = mid;
624 } else {
625 hi = mid - 1;
626 }
627 }
628
629 Some(lo)
630 }
631}
632
633#[cfg(test)]
634mod tests {
635 use super::*;
636
637 #[test]
638 fn test_transformer_config_llama2() {
639 let config = TransformerConfig::llama2_7b();
640 assert_eq!(config.hidden_size, 4096);
641 assert_eq!(config.num_attention_heads, 32);
642 assert_eq!(config.head_dim(), 128);
643 }
644
645 #[test]
646 fn test_transformer_config_tiny() {
647 let config = TransformerConfig::tiny();
648 assert_eq!(config.hidden_size, 64);
649 assert_eq!(config.num_attention_heads, 2);
650 assert_eq!(config.head_dim(), 32);
651 }
652
653 #[test]
654 fn test_config_serialization() {
655 let config = TransformerConfig::llama2_7b();
656 let json = serde_json::to_string(&config).expect("JSON serialization should succeed");
657 let restored: TransformerConfig =
658 serde_json::from_str(&json).expect("JSON deserialization should succeed");
659 assert_eq!(restored.hidden_size, config.hidden_size);
660 assert_eq!(restored.num_attention_heads, config.num_attention_heads);
661 }
662
663 #[test]
664 fn test_mistral_config() {
665 let config = TransformerConfig::mistral_7b();
666 assert_eq!(config.num_kv_heads, 8); assert_eq!(config.num_attention_heads, 32);
668 }
670
671 #[test]
680 fn qwen2_0_5b_matches_hf_config_2026_05_04() {
681 let config = TransformerConfig::qwen2_0_5b();
682 assert_eq!(config.hidden_size, 896, "hidden_size");
683 assert_eq!(config.num_attention_heads, 14, "num_attention_heads");
684 assert_eq!(config.num_kv_heads, 2, "num_kv_heads (GQA-7:1)");
685 assert_eq!(config.intermediate_size, 4864, "intermediate_size");
686 assert_eq!(config.num_hidden_layers, 24, "num_hidden_layers");
687 assert_eq!(config.vocab_size, 151_936, "vocab_size");
688 assert_eq!(config.max_position_embeddings, 32_768, "max_position_embeddings");
689 assert!(
690 (config.rms_norm_eps - 1e-6).abs() < f32::EPSILON,
691 "rms_norm_eps={}, want 1e-6",
692 config.rms_norm_eps
693 );
694 assert!(
695 (config.rope_theta - 1_000_000.0).abs() < f32::EPSILON,
696 "rope_theta={}, want 1_000_000.0",
697 config.rope_theta
698 );
699 assert!(config.use_bias, "use_bias must be true (Qwen2 quirk)");
700 assert!(
701 config.tie_word_embeddings,
702 "tie_word_embeddings must be true for Qwen2.5 0.5B (HF config 2026-05-04)"
703 );
704 assert_eq!(config.architecture, ModelArchitecture::Decoder);
705 assert_eq!(config.num_attention_heads / config.num_kv_heads, 7);
707 }
708
709 #[test]
713 fn qwen2_1_5b_inherits_tie_word_embeddings_from_0_5b() {
714 let parent = TransformerConfig::qwen2_0_5b();
715 let child = TransformerConfig::qwen2_1_5b();
716 assert_eq!(
717 child.tie_word_embeddings, parent.tie_word_embeddings,
718 "qwen2_1_5b must inherit tie_word_embeddings from qwen2_0_5b — both are HF tie=true"
719 );
720 assert!(
721 child.tie_word_embeddings,
722 "qwen2_1_5b tie_word_embeddings must be true (HF config 2026-05-04)"
723 );
724 }
725
726 #[test]
730 fn qwen2_7b_does_not_tie_embeddings() {
731 let config = TransformerConfig::qwen2_7b();
732 assert!(
733 !config.tie_word_embeddings,
734 "qwen2_7b tie_word_embeddings MUST be false per HF config 2026-05-04 — \
735 larger Qwen variants pay param cost for untied weights"
736 );
737 }
738
739 #[test]
740 fn test_qwen2_config() {
741 let config = TransformerConfig::qwen2_0_5b();
742 assert!(config.use_bias);
743 assert_eq!(config.vocab_size, 151936);
744 }
745
746 #[test]
747 fn test_llama2_13b_config() {
748 let config = TransformerConfig::llama2_13b();
749 assert_eq!(config.hidden_size, 5120);
750 assert_eq!(config.num_attention_heads, 40);
751 assert_eq!(config.num_hidden_layers, 40);
752 assert_eq!(config.head_dim(), 128); }
754
755 #[test]
756 fn test_config_yaml_serialization() {
757 let config = TransformerConfig::tiny();
758 let yaml = serde_yaml::to_string(&config).expect("config should be valid");
759 let restored: TransformerConfig =
760 serde_yaml::from_str(&yaml).expect("config should be valid");
761 assert_eq!(restored.hidden_size, config.hidden_size);
762 assert_eq!(restored.num_hidden_layers, config.num_hidden_layers);
763 }
764
765 #[test]
766 fn test_grouped_query_attention_ratio() {
767 let config = TransformerConfig::mistral_7b();
768 let heads_per_kv = config.num_attention_heads / config.num_kv_heads;
769 assert_eq!(heads_per_kv, 4); }
771
772 #[test]
773 fn test_config_clone() {
774 let config = TransformerConfig::llama2_7b();
775 let cloned = config.clone();
776 assert_eq!(config.hidden_size, cloned.hidden_size);
777 assert_eq!(config.vocab_size, cloned.vocab_size);
778 }
779
780 #[test]
781 fn test_qwen3_5_9b_config() {
782 let config = TransformerConfig::qwen3_5_9b();
783 assert_eq!(config.hidden_size, 4096);
784 assert_eq!(config.num_attention_heads, 16);
785 assert_eq!(config.num_kv_heads, 4);
786 assert_eq!(config.intermediate_size, 12288);
787 assert_eq!(config.num_hidden_layers, 32);
788 assert_eq!(config.vocab_size, 248320);
789 assert_eq!(config.max_position_embeddings, 262144);
790 assert!(!config.use_bias);
791 }
792
793 #[test]
794 fn test_qwen3_5_9b_head_dim() {
795 let config = TransformerConfig::qwen3_5_9b();
796 assert_eq!(config.head_dim(), 256);
798 }
799
800 #[test]
801 fn test_qwen3_5_9b_gqa_ratio() {
802 let config = TransformerConfig::qwen3_5_9b();
803 let heads_per_kv = config.num_attention_heads / config.num_kv_heads;
804 assert_eq!(heads_per_kv, 4); }
806
807 #[test]
812 fn test_from_apr_metadata_qwen3_8b() {
813 let config = TransformerConfig::from_apr_metadata(
815 Some(4096), Some(32), Some(8), Some(12288), Some(36), Some(151936), Some(40960), Some(1e-6), Some(1e6), Some("qwen3"),
825 )
826 .expect("all required fields present");
827
828 assert_eq!(config.hidden_size, 4096);
829 assert_eq!(config.num_attention_heads, 32);
830 assert_eq!(config.num_kv_heads, 8);
831 assert_eq!(config.num_hidden_layers, 36);
832 assert_eq!(config.vocab_size, 151936);
833 assert_eq!(config.head_dim(), 128); assert!(!config.use_bias); }
836
837 #[test]
838 fn test_from_apr_metadata_qwen2_7b() {
839 let config = TransformerConfig::from_apr_metadata(
841 Some(3584),
842 Some(28),
843 Some(4),
844 Some(18944),
845 Some(28),
846 Some(152064),
847 Some(32768),
848 Some(1e-6),
849 Some(1e6),
850 Some("qwen2"),
851 )
852 .expect("all required fields present");
853
854 assert!(config.use_bias); assert_eq!(config.head_dim(), 128); }
857
858 #[test]
859 fn test_from_apr_metadata_missing_required_returns_none() {
860 assert!(TransformerConfig::from_apr_metadata(
862 None,
863 Some(32),
864 Some(8),
865 Some(12288),
866 Some(36),
867 Some(151936),
868 Some(40960),
869 Some(1e-6),
870 Some(1e6),
871 Some("qwen3"),
872 )
873 .is_none());
874
875 assert!(TransformerConfig::from_apr_metadata(
877 Some(4096),
878 Some(32),
879 Some(8),
880 Some(12288),
881 None,
882 Some(151936),
883 Some(40960),
884 Some(1e-6),
885 Some(1e6),
886 Some("qwen3"),
887 )
888 .is_none());
889 }
890
891 #[test]
899 fn falsify_vram_monotonic_in_seq_len() {
900 let config = TransformerConfig::qwen3_4b();
902 let mut prev = config.total_training_vram_bytes(1);
903 for s in [2, 4, 8, 16, 32, 64, 128, 256, 512] {
904 let cur = config.total_training_vram_bytes(s);
905 assert!(
906 cur > prev,
907 "VRAM must increase: seq_len={s} ({cur}) should exceed prev ({prev})"
908 );
909 prev = cur;
910 }
911 }
912
913 #[test]
914 fn falsify_vram_solver_postcondition() {
915 let config = TransformerConfig::qwen3_4b();
917 let budget = 24 * 1024 * 1024 * 1024_usize; if let Some(max_s) = config.max_seq_len_for_vram(budget) {
919 let used = config.total_training_vram_bytes(max_s);
920 assert!(
921 used <= budget,
922 "Solver returned seq_len={max_s} using {used} bytes > budget {budget}"
923 );
924 if max_s < config.max_position_embeddings {
926 let over = config.total_training_vram_bytes(max_s + 1);
927 assert!(
928 over > budget,
929 "Solver not tight: seq_len={} uses {over} <= budget {budget}",
930 max_s + 1
931 );
932 }
933 }
934 }
935
936 #[test]
937 fn falsify_vram_solver_returns_none_when_impossible() {
938 let config = TransformerConfig::qwen3_4b();
940 let tiny_budget = 1024; assert!(
942 config.max_seq_len_for_vram(tiny_budget).is_none(),
943 "Solver should return None when budget is too small"
944 );
945 }
946
947 #[test]
948 fn falsify_qwen3_4b_vram_matches_oom_observation() {
949 let config = TransformerConfig::qwen3_4b();
952 let vram_512 = config.total_training_vram_bytes(512);
953 let usable_vram = 23 * 1024 * 1024 * 1024_usize; let vram_1 = config.total_training_vram_bytes(1);
957 let shared_128 = config.total_training_vram_bytes_shared(128);
958 let shared_512 = config.total_training_vram_bytes_shared(512);
959 let solved = config.max_seq_len_for_vram_shared(24 * 1024 * 1024 * 1024);
960 eprintln!("=== Qwen3-4B VRAM Budget ===");
961 eprintln!(
962 " Per-layer weights: {:.1} MB",
963 config.per_layer_weight_elements() as f64 * 4.0 / 1e6
964 );
965 eprintln!(
966 " Per-layer grad scratch: {:.1} MB",
967 config.per_layer_grad_weight_elements() as f64 * 4.0 / 1e6
968 );
969 eprintln!(" Per-layer (S=512): {:.1} MB", (vram_512 / 36) as f64 / 1e6);
970 eprintln!(" 36 layers S=1 (per-layer scratch): {:.1} GB", vram_1 as f64 / 1e9);
971 eprintln!(" 36 layers S=512 (per-layer scratch): {:.1} GB", vram_512 as f64 / 1e9);
972 eprintln!(" 36 layers S=128 (SHARED scratch): {:.1} GB", shared_128 as f64 / 1e9);
973 eprintln!(" 36 layers S=512 (SHARED scratch): {:.1} GB", shared_512 as f64 / 1e9);
974 eprintln!(" Max seq_len for 24 GB (shared): {solved:?}");
975
976 assert!(
977 vram_512 > usable_vram,
978 "Formula says {:.1} GB for seq_len=512, but we OOM'd on 23 GB — formula is wrong",
979 vram_512 as f64 / 1e9
980 );
981 }
982
983 #[test]
984 fn falsify_qwen2_0_5b_fits_on_4090() {
985 let config = TransformerConfig::qwen2_0_5b();
988 let vram_512 = config.total_training_vram_bytes(512);
989 let total_vram = 24 * 1024 * 1024 * 1024_usize;
990 assert!(
991 vram_512 < total_vram,
992 "Formula says {:.1} GB for Qwen2-0.5B at seq_len=512, but it fit on 4090",
993 vram_512 as f64 / 1e9
994 );
995 }
996
997 #[test]
998 fn falsify_vram_budget_concrete_values() {
999 let config = TransformerConfig::qwen3_4b();
1001
1002 let expected_weights =
1006 4096 * 2560 + 1024 * 2560 * 2 + 2560 * 4096 + 9728 * 2560 * 3 + 2560 * 2;
1007 assert_eq!(config.per_layer_weight_elements(), expected_weights);
1008
1009 let budget_24gb = 24 * 1024 * 1024 * 1024_usize;
1013 assert!(
1014 config.max_seq_len_for_vram(budget_24gb).is_none(),
1015 "Qwen3-4B per-layer scratch CANNOT fit 24 GB — proves shared scratch needed"
1016 );
1017
1018 let shared_budget = config.total_training_vram_bytes_shared(128);
1022 assert!(
1023 shared_budget < budget_24gb,
1024 "Qwen3-4B shared scratch at seq_len=128 should fit 24 GB, got {:.1} GB",
1025 shared_budget as f64 / 1e9
1026 );
1027 }
1028
1029 #[test]
1032 fn test_model_architecture_default() {
1033 let arch: ModelArchitecture = Default::default();
1034 assert_eq!(arch, ModelArchitecture::Decoder);
1035 }
1036
1037 #[test]
1038 fn test_model_architecture_serialization() {
1039 let encoder = ModelArchitecture::Encoder;
1040 let json = serde_json::to_string(&encoder).expect("serialize");
1041 assert_eq!(json, "\"encoder\"");
1042 let decoder = ModelArchitecture::Decoder;
1043 let json = serde_json::to_string(&decoder).expect("serialize");
1044 assert_eq!(json, "\"decoder\"");
1045
1046 let restored: ModelArchitecture = serde_json::from_str("\"encoder\"").expect("deserialize");
1047 assert_eq!(restored, ModelArchitecture::Encoder);
1048 }
1049
1050 #[test]
1051 fn test_codebert_config() {
1052 let config = TransformerConfig::codebert();
1053 assert_eq!(config.hidden_size, 768);
1054 assert_eq!(config.num_attention_heads, 12);
1055 assert_eq!(config.num_kv_heads, 12);
1056 assert_eq!(config.intermediate_size, 3072);
1057 assert_eq!(config.num_hidden_layers, 12);
1058 assert_eq!(config.vocab_size, 50265);
1059 assert_eq!(config.max_position_embeddings, 514);
1060 assert!(config.use_bias);
1061 assert_eq!(config.architecture, ModelArchitecture::Encoder);
1062 assert!(config.is_encoder());
1063 assert_eq!(config.head_dim(), 64); }
1065
1066 #[test]
1067 fn test_is_encoder() {
1068 assert!(TransformerConfig::codebert().is_encoder());
1069 assert!(!TransformerConfig::llama2_7b().is_encoder());
1070 assert!(!TransformerConfig::tiny().is_encoder());
1071 assert!(!TransformerConfig::qwen2_0_5b().is_encoder());
1072 }
1073
1074 #[test]
1075 fn test_hf_architecture_name_inferred() {
1076 assert_eq!(TransformerConfig::codebert().hf_architecture_name(), "BertModel");
1078 assert_eq!(TransformerConfig::qwen2_0_5b().hf_architecture_name(), "Qwen2ForCausalLM");
1080 assert_eq!(TransformerConfig::llama2_7b().hf_architecture_name(), "LlamaForCausalLM");
1082 }
1083
1084 #[test]
1085 fn test_hf_architecture_name_override() {
1086 let mut config = TransformerConfig::tiny();
1087 config.hf_architecture = Some("CustomModel".to_string());
1088 assert_eq!(config.hf_architecture_name(), "CustomModel");
1089 }
1090
1091 #[test]
1092 fn test_hf_model_type_str_inferred() {
1093 assert_eq!(TransformerConfig::codebert().hf_model_type_str(), "roberta");
1094 assert_eq!(TransformerConfig::qwen2_0_5b().hf_model_type_str(), "qwen2");
1095 assert_eq!(TransformerConfig::llama2_7b().hf_model_type_str(), "llama");
1096 }
1097
1098 #[test]
1099 fn test_hf_model_type_str_override() {
1100 let mut config = TransformerConfig::tiny();
1101 config.hf_model_type = Some("custom_type".to_string());
1102 assert_eq!(config.hf_model_type_str(), "custom_type");
1103 }
1104
1105 #[test]
1106 fn test_ties_embeddings() {
1107 assert!(TransformerConfig::qwen2_0_5b().ties_embeddings());
1109 assert!(!TransformerConfig::llama2_7b().ties_embeddings());
1111 let mut config = TransformerConfig::llama2_7b();
1113 config.tie_word_embeddings = true;
1114 assert!(config.ties_embeddings());
1115 }
1116
1117 #[test]
1118 fn test_head_dim_override() {
1119 let config = TransformerConfig::qwen3_4b();
1120 assert_eq!(config.head_dim_override, Some(128));
1121 assert_eq!(config.head_dim(), 128);
1122 assert_ne!(config.hidden_size / config.num_attention_heads, 128);
1124 }
1125
1126 #[test]
1127 fn test_head_dim_no_override() {
1128 let config = TransformerConfig::llama2_7b();
1129 assert!(config.head_dim_override.is_none());
1130 assert_eq!(config.head_dim(), 128); }
1132
1133 #[test]
1134 fn test_q_dim() {
1135 let config = TransformerConfig::qwen3_4b();
1136 assert_eq!(config.q_dim(), 4096);
1138
1139 let config = TransformerConfig::llama2_7b();
1140 assert_eq!(config.q_dim(), 4096);
1142 }
1143
1144 #[test]
1145 fn test_q_dim_differs_from_hidden() {
1146 let config = TransformerConfig::qwen3_4b();
1147 assert_ne!(config.q_dim(), config.hidden_size);
1149 }
1150
1151 #[test]
1162 fn test_qwen3_4b_projection_shapes() {
1163 let config = TransformerConfig::qwen3_4b();
1164
1165 assert_eq!(config.hidden_size, 2560);
1167 assert_eq!(config.num_attention_heads, 32);
1168 assert_eq!(config.num_kv_heads, 8);
1169 assert_eq!(config.head_dim(), 128);
1170 assert_eq!(config.head_dim_override, Some(128));
1171
1172 let q_dim = config.q_dim();
1174 let kv_dim = config.kv_dim();
1175 assert_eq!(q_dim, 4096); assert_eq!(kv_dim, 1024); let hidden = config.hidden_size;
1180 assert_eq!(q_dim * hidden, 10_485_760); assert_eq!(kv_dim * hidden, 2_621_440); assert_eq!(kv_dim * hidden, 2_621_440); assert_eq!(hidden * q_dim, 10_485_760); }
1185
1186 #[test]
1188 fn test_qwen3_4b_grad_weight_elements_uses_q_dim() {
1189 let config = TransformerConfig::qwen3_4b();
1190 let h = config.hidden_size; let q = config.q_dim(); let kv = config.kv_dim(); let i = config.intermediate_size; let expected = h * 2 + h * i * 3 + q * h + h * q + h * kv * 2; assert_eq!(config.per_layer_grad_weight_elements(), expected);
1202
1203 assert!(q * h > h * h, "q_dim*hidden > hidden*hidden for Qwen3-4B");
1205 }
1206
1207 #[test]
1208 fn test_from_size_str_known_sizes() {
1209 assert!(TransformerConfig::from_size_str("codebert").is_ok());
1210 assert!(TransformerConfig::from_size_str("codebert-base").is_ok());
1211 assert!(TransformerConfig::from_size_str("125M").is_ok());
1212 assert!(TransformerConfig::from_size_str("0.5B").is_ok());
1213 assert!(TransformerConfig::from_size_str("500M").is_ok());
1214 assert!(TransformerConfig::from_size_str("qwen2-0.5b").is_ok());
1215 assert!(TransformerConfig::from_size_str("7B").is_ok());
1216 assert!(TransformerConfig::from_size_str("qwen2.5-7b").is_ok());
1217 assert!(TransformerConfig::from_size_str("4B").is_ok());
1218 assert!(TransformerConfig::from_size_str("qwen3-4b").is_ok());
1219 assert!(TransformerConfig::from_size_str("qwen3").is_ok());
1220 assert!(TransformerConfig::from_size_str("9B").is_ok());
1221 assert!(TransformerConfig::from_size_str("qwen3.5-9b").is_ok());
1222 assert!(TransformerConfig::from_size_str("qwen3_5").is_ok());
1223 assert!(TransformerConfig::from_size_str("qwen3.5").is_ok());
1224 }
1225
1226 #[test]
1227 fn test_from_size_str_unknown() {
1228 let err = TransformerConfig::from_size_str("99B").unwrap_err();
1229 assert!(err.contains("Unknown model size"));
1230 assert!(err.contains("99B"));
1231 }
1232
1233 #[test]
1234 fn test_from_size_str_configs_correct() {
1235 let codebert = TransformerConfig::from_size_str("codebert").unwrap();
1236 assert_eq!(codebert.hidden_size, 768);
1237 assert!(codebert.is_encoder());
1238
1239 let qwen2 = TransformerConfig::from_size_str("0.5B").unwrap();
1240 assert_eq!(qwen2.hidden_size, 896);
1241 assert!(qwen2.use_bias);
1242
1243 let qwen3 = TransformerConfig::from_size_str("4B").unwrap();
1244 assert_eq!(qwen3.hidden_size, 2560);
1245 assert!(!qwen3.use_bias);
1246 }
1247
1248 #[test]
1249 fn test_from_apr_metadata_missing_num_heads() {
1250 assert!(TransformerConfig::from_apr_metadata(
1251 Some(4096),
1252 None, Some(8),
1254 Some(12288),
1255 Some(36),
1256 Some(151936),
1257 None,
1258 None,
1259 None,
1260 None,
1261 )
1262 .is_none());
1263 }
1264
1265 #[test]
1266 fn test_from_apr_metadata_missing_vocab_size() {
1267 assert!(TransformerConfig::from_apr_metadata(
1268 Some(4096),
1269 Some(32),
1270 Some(8),
1271 Some(12288),
1272 Some(36),
1273 None, None,
1275 None,
1276 None,
1277 None,
1278 )
1279 .is_none());
1280 }
1281
1282 #[test]
1283 fn test_from_apr_metadata_missing_intermediate_size() {
1284 assert!(TransformerConfig::from_apr_metadata(
1285 Some(4096),
1286 Some(32),
1287 Some(8),
1288 None, Some(36),
1290 Some(151936),
1291 None,
1292 None,
1293 None,
1294 None,
1295 )
1296 .is_none());
1297 }
1298
/// When every optional metadata entry is omitted, the resulting config must
/// carry the fallback values asserted below.
#[test]
fn test_from_apr_metadata_defaults() {
    let parsed = TransformerConfig::from_apr_metadata(
        Some(512),
        Some(8),
        None, // KV-head count omitted; the fallback is checked below
        Some(2048),
        Some(6),
        Some(32000),
        None,
        None,
        None,
        None,
    );
    let cfg = parsed.unwrap();

    // The KV-head count comes out as 8, the same value supplied for the head count above.
    assert_eq!(cfg.num_kv_heads, 8);
    assert_eq!(cfg.max_position_embeddings, 32768);
    // Float fields are compared with a tolerance rather than for exact equality.
    assert!((cfg.rms_norm_eps - 1e-6).abs() < 1e-10);
    assert!((cfg.rope_theta - 10000.0).abs() < 0.1);
    assert_eq!(cfg.architecture, ModelArchitecture::Decoder);
    assert!(!cfg.use_bias);
}
1322
/// A "codebert" name in the metadata must select the encoder architecture.
#[test]
fn test_from_apr_metadata_encoder_architecture() {
    let parsed = TransformerConfig::from_apr_metadata(
        Some(768),
        Some(12),
        Some(12),
        Some(3072),
        Some(12),
        Some(50265),
        Some(514),
        Some(1e-5),
        Some(0.0),
        Some("codebert"), // name string that should trigger encoder detection
    );
    let detected = parsed.unwrap().architecture;
    assert_eq!(detected, ModelArchitecture::Encoder);
}
1340
/// A "roberta" name in the metadata must also select the encoder architecture,
/// even when the trailing optional values are left out.
#[test]
fn test_from_apr_metadata_roberta_architecture() {
    let parsed = TransformerConfig::from_apr_metadata(
        Some(768),
        Some(12),
        Some(12),
        Some(3072),
        Some(12),
        Some(50265),
        None,
        None,
        None,
        Some("roberta"), // name string that should trigger encoder detection
    );
    let detected = parsed.unwrap().architecture;
    assert_eq!(detected, ModelArchitecture::Encoder);
}
1358
/// Qwen3-4B style metadata must yield an explicit head-dimension override of 128
/// and a config without bias.
#[test]
fn test_from_apr_metadata_qwen3_head_dim_override() {
    let parsed = TransformerConfig::from_apr_metadata(
        Some(2560),
        Some(32),
        Some(8),
        Some(9728),
        Some(36),
        Some(151936),
        Some(40960),
        Some(1e-6),
        Some(1e6),
        Some("qwen3-4b"),
    );
    let cfg = parsed.unwrap();

    // 2560 / 32 is 80, so the 128 asserted here cannot be a plain quotient of
    // the first two arguments; it has to be carried by the override field.
    assert_eq!(cfg.head_dim_override, Some(128));
    assert_eq!(cfg.head_dim(), 128);
    assert!(!cfg.use_bias);
}
1379
/// Qwen3-8B style metadata must not set a head-dimension override, while the
/// effective head dimension is still 128.
#[test]
fn test_from_apr_metadata_qwen3_no_override_needed() {
    let parsed = TransformerConfig::from_apr_metadata(
        Some(4096),
        Some(32),
        Some(8),
        Some(12288),
        Some(36),
        Some(151936),
        None,
        None,
        None,
        Some("qwen3-8b"),
    );
    let cfg = parsed.unwrap();

    // 4096 / 32 already equals 128, so no override is expected here.
    assert!(cfg.head_dim_override.is_none());
    assert_eq!(cfg.head_dim(), 128);
}
1399
/// Pins the dimensions of the Qwen2-7B preset.
#[test]
fn test_qwen2_7b_config() {
    let cfg = TransformerConfig::qwen2_7b();

    // Layer geometry.
    assert_eq!(cfg.hidden_size, 3584);
    assert_eq!(cfg.num_attention_heads, 28);
    assert_eq!(cfg.num_kv_heads, 4);
    assert_eq!(cfg.intermediate_size, 18944);
    assert_eq!(cfg.num_hidden_layers, 28);

    // Vocabulary and bias flag.
    assert_eq!(cfg.vocab_size, 152064);
    assert!(cfg.use_bias);

    // 3584 / 28 == 128.
    assert_eq!(cfg.head_dim(), 128);
}
1412
/// Pins the dimensions of the Qwen3-4B preset.
#[test]
fn test_qwen3_4b_config() {
    let cfg = TransformerConfig::qwen3_4b();

    // Layer geometry.
    assert_eq!(cfg.hidden_size, 2560);
    assert_eq!(cfg.num_attention_heads, 32);
    assert_eq!(cfg.num_kv_heads, 8);
    assert_eq!(cfg.intermediate_size, 9728);
    assert_eq!(cfg.num_hidden_layers, 36);

    // This preset has no bias, and its head dimension is 128 even though
    // 2560 / 32 would be 80.
    assert!(!cfg.use_bias);
    assert_eq!(cfg.head_dim(), 128);
}
1424
/// Every listed preset must report a non-zero per-layer weight element count.
#[test]
fn test_per_layer_weight_elements_positive() {
    let presets = [
        TransformerConfig::tiny(),
        TransformerConfig::codebert(),
        TransformerConfig::qwen2_0_5b(),
        TransformerConfig::qwen3_4b(),
    ];
    for preset in presets {
        assert!(preset.per_layer_weight_elements() > 0);
    }
}
1436
/// At the same sequence length, the shared VRAM estimate must be strictly
/// smaller than the per-layer estimate.
#[test]
fn test_vram_shared_less_than_per_layer() {
    let config = TransformerConfig::qwen2_0_5b();
    let seq_len = 128;
    let (per_layer, shared) = (
        config.total_training_vram_bytes(seq_len),
        config.total_training_vram_bytes_shared(seq_len),
    );
    assert!(
        shared < per_layer,
        "Shared ({shared}) should be less than per-layer ({per_layer})"
    );
}
1448
/// The shared VRAM estimate must grow strictly as the sequence length grows.
#[test]
fn test_vram_shared_monotonic() {
    let config = TransformerConfig::qwen2_0_5b();
    // The first entry seeds the comparison; each later entry must exceed its predecessor.
    let seq_lens: [usize; 8] = [1, 2, 4, 8, 16, 32, 64, 128];

    let mut previous = config.total_training_vram_bytes_shared(seq_lens[0]);
    for &s in &seq_lens[1..] {
        let current = config.total_training_vram_bytes_shared(s);
        assert!(current > previous, "Shared VRAM must increase: seq_len={s}");
        previous = current;
    }
}
1459
/// With an 8 GiB budget, the Qwen2-0.5B preset must admit some sequence
/// length, and the length returned must actually fit inside that budget.
#[test]
fn test_max_seq_len_for_vram_shared() {
    let config = TransformerConfig::qwen2_0_5b();
    let budget = 8 * 1024 * 1024 * 1024_usize;

    // `expect` replaces the former `is_some()` + `unwrap()` pair: one check,
    // and a failure message that states the invariant that was violated.
    let s = config
        .max_seq_len_for_vram_shared(budget)
        .expect("an 8 GiB budget should admit at least one sequence length for qwen2_0_5b");

    assert!(config.total_training_vram_bytes_shared(s) <= budget);
}
1469
/// A budget of only 1024 bytes must yield no feasible sequence length for the
/// Qwen3-4B preset.
#[test]
fn test_max_seq_len_for_vram_shared_impossible() {
    let tiny_budget = 1024;
    let answer = TransformerConfig::qwen3_4b().max_seq_len_for_vram_shared(tiny_budget);
    assert!(answer.is_none());
}
1476
/// If a sequence length is returned for a 10 MiB budget, it must fit the
/// budget and be maximal: one more token must overflow the budget, unless the
/// result already reaches the model's position limit.
///
/// NOTE(review): a `None` result is accepted silently, so this test checks
/// nothing in that case.
#[test]
fn test_max_seq_len_for_vram_shared_tightness() {
    let config = TransformerConfig::tiny();
    let budget = 10 * 1024 * 1024_usize;

    if let Some(s) = config.max_seq_len_for_vram_shared(budget) {
        // The reported length itself has to fit.
        assert!(config.total_training_vram_bytes_shared(s) <= budget);

        // Short-circuit keeps the `s + 1` probe from running when the
        // position limit has already been reached.
        let at_position_limit = s >= config.max_position_embeddings;
        assert!(at_position_limit || config.total_training_vram_bytes_shared(s + 1) > budget);
    }
}
1488
/// Pins the key/value projection width reported by two presets.
#[test]
fn test_kv_dim() {
    let qwen3 = TransformerConfig::qwen3_4b();
    assert_eq!(qwen3.kv_dim(), 1024);

    let llama2 = TransformerConfig::llama2_7b();
    assert_eq!(llama2.kv_dim(), 4096);
}
1494
/// The scratch-size coefficients and the gradient weight element count of the
/// tiny preset must all be non-zero.
#[test]
fn test_per_layer_scratch_coefficients() {
    let cfg = TransformerConfig::tiny();

    let linear = cfg.per_layer_scratch_linear_coeff();
    assert!(linear > 0);

    // The quadratic coefficient comes back as a pair; both halves must be positive.
    let quadratic = cfg.per_layer_scratch_quadratic_coeff();
    assert!(quadratic.0 > 0 && quadratic.1 > 0);

    let grad_elements = cfg.per_layer_grad_weight_elements();
    assert!(grad_elements > 0);
}
1503}