Skip to main content

entrenar/transformer/
config.rs

1//! Transformer configuration module
2//!
3//! This module provides configuration structures for transformer models.
4
5use serde::{Deserialize, Serialize};
6
// Well-known model architecture constants, grouped by model family.

// LLaMA 2 (the vocabulary size is also reused by the Mistral preset)
const LLAMA_VOCAB_SIZE: usize = 32_000;
const LLAMA2_7B_INTERMEDIATE_SIZE: usize = 11_008;
const LLAMA2_13B_HIDDEN_SIZE: usize = 5_120;
const LLAMA2_13B_INTERMEDIATE_SIZE: usize = 13_824;

// Mistral
const MISTRAL_INTERMEDIATE_SIZE: usize = 14_336;
const MISTRAL_MAX_SEQ_LEN: usize = 32_768;

// Qwen2 (vocabulary size and RoPE theta are also reused by the Qwen3 presets)
const QWEN2_VOCAB_SIZE: usize = 151_936;
const QWEN2_MAX_SEQ_LEN: usize = 32_768;
const QWEN2_ROPE_THETA: f32 = 1_000_000.0;
const QWEN2_0_5B_HIDDEN_SIZE: usize = 896;
const QWEN2_0_5B_INTERMEDIATE_SIZE: usize = 4_864;

// Qwen3
const QWEN3_4B_HIDDEN_SIZE: usize = 2_560;
const QWEN3_4B_INTERMEDIATE_SIZE: usize = 9_728;

// Qwen3.5
const QWEN3_5_VOCAB_SIZE: usize = 248_320;
const QWEN3_5_MAX_SEQ_LEN: usize = 262_144;
const QWEN3_5_9B_HIDDEN_SIZE: usize = 4_096;
const QWEN3_5_9B_INTERMEDIATE_SIZE: usize = 12_288;

// RoPE base used by every decoder preset that does not name its own theta
const DEFAULT_ROPE_THETA: f32 = 10_000.0;

// CodeBERT / RoBERTa constants
const CODEBERT_HIDDEN_SIZE: usize = 768;
const CODEBERT_INTERMEDIATE_SIZE: usize = 3_072;
const CODEBERT_VOCAB_SIZE: usize = 50_265;
// 512 positions + 2 special tokens
const CODEBERT_MAX_POSITION: usize = 512 + 2;
32
/// Model architecture family.
///
/// Determines position encoding, normalization, FFN activation, and pooling.
///
/// Serialized in `snake_case` (`"decoder"` / `"encoder"`). `Decoder` is the
/// `Default` variant, so a `TransformerConfig` whose serialized form omits
/// the `architecture` field deserializes as decoder-only.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelArchitecture {
    /// Decoder-only (LLaMA, Qwen, Mistral): RoPE, RMSNorm, SwiGLU, last-token pooling
    #[default]
    Decoder,
    /// Encoder-only (BERT, RoBERTa, CodeBERT): learned positions, LayerNorm, GELU, CLS pooling
    Encoder,
}
45
/// Configuration for transformer models.
///
/// A single struct describes both decoder-only and encoder-only models; the
/// `architecture` field selects the family. The first ten fields are required
/// when deserializing; the remaining fields carry `#[serde(default)]` and may
/// be omitted from serialized configs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerConfig {
    /// Hidden dimension (embedding size)
    pub hidden_size: usize,
    /// Number of attention heads
    pub num_attention_heads: usize,
    /// Number of key-value heads (for grouped-query attention).
    /// Equal to `num_attention_heads` for plain multi-head attention.
    pub num_kv_heads: usize,
    /// Feed-forward network intermediate dimension
    pub intermediate_size: usize,
    /// Number of transformer layers
    pub num_hidden_layers: usize,
    /// Vocabulary size
    pub vocab_size: usize,
    /// Maximum sequence length
    pub max_position_embeddings: usize,
    /// RMS normalization epsilon.
    /// The CodeBERT preset stores its LayerNorm epsilon in this field.
    pub rms_norm_eps: f32,
    /// RoPE theta base.
    /// The CodeBERT preset sets this to 0.0 because it uses learned positions.
    pub rope_theta: f32,
    /// Whether to use bias in linear layers
    pub use_bias: bool,
    /// Explicit per-head dimension (overrides hidden_size / num_heads).
    /// Required for Qwen3 where head_dim=128 but hidden_size/num_heads=80.
    /// Missing in serialized form means `None` (derive from hidden/heads).
    #[serde(default)]
    pub head_dim_override: Option<usize>,
    /// Architecture family: encoder (BERT/RoBERTa) or decoder (LLaMA/Qwen).
    /// Determines position encoding, normalization, activation, and pooling strategy.
    /// Missing in serialized form means `ModelArchitecture::Decoder`.
    #[serde(default)]
    pub architecture: ModelArchitecture,
    /// HuggingFace architecture class name (e.g., "Qwen2ForCausalLM", "LlamaForCausalLM").
    /// Used for checkpoint config.json compatibility.
    /// When `None`, `hf_architecture_name()` infers a name from other fields.
    #[serde(default)]
    pub hf_architecture: Option<String>,
    /// HuggingFace model type (e.g., "qwen2", "llama").
    /// Used for checkpoint config.json compatibility.
    /// When `None`, `hf_model_type_str()` infers a value from other fields.
    #[serde(default)]
    pub hf_model_type: Option<String>,
    /// Whether to tie input/output embeddings (embed_tokens and lm_head).
    /// The Qwen2.5 0.5B/1.5B presets set this to true; the Qwen2.5 7B and
    /// LLaMA presets set it to false. Missing in serialized form means false.
    // NOTE(review): `ties_embeddings()` also returns true whenever
    // `use_bias && vocab_size > 150000`, so a `false` here is not the final
    // word for Qwen2-shaped configs — confirm which one callers should trust.
    #[serde(default)]
    pub tie_word_embeddings: bool,
}
90
91impl TransformerConfig {
92    /// LLaMA 2 7B configuration
93    pub fn llama2_7b() -> Self {
94        Self {
95            hidden_size: 4096,
96            num_attention_heads: 32,
97            num_kv_heads: 32,
98            intermediate_size: LLAMA2_7B_INTERMEDIATE_SIZE,
99            num_hidden_layers: 32,
100            vocab_size: LLAMA_VOCAB_SIZE,
101            max_position_embeddings: 4096,
102            rms_norm_eps: 1e-6,
103            rope_theta: DEFAULT_ROPE_THETA,
104            use_bias: false,
105            head_dim_override: None,
106            architecture: ModelArchitecture::Decoder,
107            hf_architecture: None,
108            hf_model_type: None,
109            tie_word_embeddings: false,
110        }
111    }
112
113    /// LLaMA 2 13B configuration
114    pub fn llama2_13b() -> Self {
115        Self {
116            hidden_size: LLAMA2_13B_HIDDEN_SIZE,
117            num_attention_heads: 40,
118            num_kv_heads: 40,
119            intermediate_size: LLAMA2_13B_INTERMEDIATE_SIZE,
120            num_hidden_layers: 40,
121            vocab_size: LLAMA_VOCAB_SIZE,
122            max_position_embeddings: 4096,
123            rms_norm_eps: 1e-6,
124            rope_theta: DEFAULT_ROPE_THETA,
125            use_bias: false,
126            head_dim_override: None,
127            architecture: ModelArchitecture::Decoder,
128            hf_architecture: None,
129            hf_model_type: None,
130            tie_word_embeddings: false,
131        }
132    }
133
134    /// Mistral 7B configuration
135    pub fn mistral_7b() -> Self {
136        Self {
137            hidden_size: 4096,
138            num_attention_heads: 32,
139            num_kv_heads: 8, // Grouped-query attention
140            intermediate_size: MISTRAL_INTERMEDIATE_SIZE,
141            num_hidden_layers: 32,
142            vocab_size: LLAMA_VOCAB_SIZE,
143            max_position_embeddings: MISTRAL_MAX_SEQ_LEN,
144            rms_norm_eps: 1e-5,
145            rope_theta: DEFAULT_ROPE_THETA,
146            use_bias: false,
147            head_dim_override: None,
148            architecture: ModelArchitecture::Decoder,
149            hf_architecture: None,
150            hf_model_type: None,
151            tie_word_embeddings: false,
152        }
153    }
154
155    /// Qwen2 0.5B configuration (good for testing).
156    ///
157    /// Empirically verified against
158    /// `~/.cache/huggingface/hub/models--Qwen--Qwen2.5-Coder-0.5B-Instruct/.../config.json`
159    /// 2026-05-04. Pinned by
160    /// `contracts/apr-pretrain-arch-polymorphic-v1.yaml` FALSIFY-001.
161    ///
162    /// Note: `tie_word_embeddings: true` is the Qwen2.5 0.5B/1.5B convention
163    /// (the 7B variant turns this OFF; see `qwen2_7b()`). This is a Qwen
164    /// scaling-law quirk — small Qwen models reuse embedding+lm_head weights
165    /// to save params, but the larger variants pay the param cost for
166    /// untied weights. Drift-prevention: keeping this `true` is required
167    /// for SHIP-TWO-001 §49 MODEL-2 fine-tune from a Qwen2.5-Coder-0.5B
168    /// checkpoint.
169    pub fn qwen2_0_5b() -> Self {
170        Self {
171            hidden_size: QWEN2_0_5B_HIDDEN_SIZE,
172            num_attention_heads: 14,
173            num_kv_heads: 2,
174            intermediate_size: QWEN2_0_5B_INTERMEDIATE_SIZE,
175            num_hidden_layers: 24,
176            vocab_size: QWEN2_VOCAB_SIZE,
177            max_position_embeddings: QWEN2_MAX_SEQ_LEN,
178            rms_norm_eps: 1e-6,
179            rope_theta: QWEN2_ROPE_THETA,
180            use_bias: true,
181            head_dim_override: None,
182            architecture: ModelArchitecture::Decoder,
183            hf_architecture: None,
184            hf_model_type: None,
185            tie_word_embeddings: true,
186        }
187    }
188
189    /// Qwen2.5-Coder-1.5B-Instruct: 28 layers, 12 heads, 2 KV heads, hidden=1536
190    #[rustfmt::skip]
191    pub fn qwen2_1_5b() -> Self { Self { hidden_size: 1536, num_attention_heads: 12, intermediate_size: 8960, num_hidden_layers: 28, vocab_size: 151936, ..Self::qwen2_0_5b() } }
192
193    /// Qwen2.5-Coder 7B configuration (GH-371)
194    ///
195    /// Qwen2.5-Coder-7B-Instruct: 28 layers, 28 heads, 4 KV heads, hidden=3584
196    /// Contract: contracts/model-families/qwen2.yaml
197    pub fn qwen2_7b() -> Self {
198        Self {
199            hidden_size: 3584,
200            num_attention_heads: 28,
201            num_kv_heads: 4,
202            intermediate_size: 18944,
203            num_hidden_layers: 28,
204            vocab_size: 152064,
205            max_position_embeddings: QWEN2_MAX_SEQ_LEN,
206            rms_norm_eps: 1e-6,
207            rope_theta: QWEN2_ROPE_THETA,
208            use_bias: true,
209            head_dim_override: None,
210            architecture: ModelArchitecture::Decoder,
211            hf_architecture: None,
212            hf_model_type: None,
213            tie_word_embeddings: false,
214        }
215    }
216
217    /// Qwen3 4B configuration
218    ///
219    /// Qwen3-4B: 36 layers, 32 heads, 8 KV heads, hidden=2560, head_dim=128.
220    /// Same vocab_size as Qwen2 (151936). No attention bias (Qwen3 family).
221    pub fn qwen3_4b() -> Self {
222        Self {
223            hidden_size: QWEN3_4B_HIDDEN_SIZE,
224            num_attention_heads: 32,
225            num_kv_heads: 8,
226            intermediate_size: QWEN3_4B_INTERMEDIATE_SIZE,
227            num_hidden_layers: 36,
228            vocab_size: QWEN2_VOCAB_SIZE, // 151936, same as Qwen2
229            max_position_embeddings: 40960,
230            rms_norm_eps: 1e-6,
231            rope_theta: QWEN2_ROPE_THETA, // 1M theta
232            use_bias: false,              // Qwen3: no attention bias
233            head_dim_override: Some(128), // Contract: qwen3.yaml §4b.head_dim=128
234            architecture: ModelArchitecture::Decoder,
235            hf_architecture: None,
236            hf_model_type: None,
237            tie_word_embeddings: false,
238        }
239    }
240
241    /// Qwen3.5 9B configuration
242    ///
243    /// Key differences from Qwen2: no attention bias, head_dim=256 (explicit),
244    /// vocab_size=248320, hybrid attention (standard + linear layers).
245    /// Contract: contracts/model-families/qwen3_5.yaml
246    pub fn qwen3_5_9b() -> Self {
247        Self {
248            hidden_size: QWEN3_5_9B_HIDDEN_SIZE,
249            num_attention_heads: 16,
250            num_kv_heads: 4,
251            intermediate_size: QWEN3_5_9B_INTERMEDIATE_SIZE,
252            num_hidden_layers: 32,
253            vocab_size: QWEN3_5_VOCAB_SIZE,
254            max_position_embeddings: QWEN3_5_MAX_SEQ_LEN,
255            rms_norm_eps: 1e-6,
256            rope_theta: QWEN2_ROPE_THETA, // Same 1M theta as Qwen2
257            use_bias: false,              // KEY: no attention bias (unlike Qwen2)
258            head_dim_override: None,      // 4096/16=256, no override needed
259            architecture: ModelArchitecture::Decoder,
260            hf_architecture: None,
261            hf_model_type: None,
262            tie_word_embeddings: false,
263        }
264    }
265
266    /// Construct from APR v2 metadata fields.
267    ///
268    /// CONTRACT: The `.apr` file is the single source of truth for model
269    /// architecture. These fields were validated at import time by the
270    /// `tensor-layout-v1` contract. This function propagates that contract
271    /// to the training pipeline — no hardcoded lookups, no silent fallbacks.
272    ///
273    /// Returns None if any required field is missing, forcing the caller to
274    /// handle the error explicitly rather than silently degrading to tiny().
275    ///
276    /// GH-376: Fixes instruct pipeline ignoring .apr architecture metadata.
277    pub fn from_apr_metadata(
278        hidden_size: Option<usize>,
279        num_heads: Option<usize>,
280        num_kv_heads: Option<usize>,
281        intermediate_size: Option<usize>,
282        num_layers: Option<usize>,
283        vocab_size: Option<usize>,
284        max_position_embeddings: Option<usize>,
285        rms_norm_eps: Option<f32>,
286        rope_theta: Option<f32>,
287        architecture: Option<&str>,
288    ) -> Option<Self> {
289        let hidden = hidden_size?;
290        let heads = num_heads?;
291        let layers = num_layers?;
292        let vocab = vocab_size?;
293        let intermediate = intermediate_size?;
294
295        // Qwen3 family: head_dim=128 is explicit, not hidden/heads
296        // Qwen2 family: use_bias=true
297        let (use_bias, head_dim_override) = match architecture {
298            Some(a) if a.starts_with("qwen3") => {
299                // Qwen3: no bias, explicit head_dim=128 when hidden/heads != 128
300                let computed = hidden / heads;
301                let override_dim = if computed == 128 { None } else { Some(128) };
302                (false, override_dim)
303            }
304            Some(a) if a.starts_with("qwen2") => (true, None),
305            _ => (false, None),
306        };
307
308        Some(Self {
309            hidden_size: hidden,
310            num_attention_heads: heads,
311            num_kv_heads: num_kv_heads.unwrap_or(heads),
312            intermediate_size: intermediate,
313            num_hidden_layers: layers,
314            vocab_size: vocab,
315            max_position_embeddings: max_position_embeddings.unwrap_or(32768),
316            rms_norm_eps: rms_norm_eps.unwrap_or(1e-6),
317            rope_theta: rope_theta.unwrap_or(DEFAULT_ROPE_THETA),
318            use_bias,
319            head_dim_override,
320            architecture: match architecture {
321                Some(a) if a.contains("bert") || a.contains("roberta") => {
322                    ModelArchitecture::Encoder
323                }
324                _ => ModelArchitecture::Decoder,
325            },
326            hf_architecture: None,
327            hf_model_type: None,
328            tie_word_embeddings: false,
329        })
330    }
331
332    /// Resolve config from a model size string. Errors on unknown sizes.
333    ///
334    /// GH-377: Replaces `_ => TransformerConfig::tiny()` catch-all pattern.
335    /// This is the single canonical mapping from size strings to configs.
336    /// Every callsite that previously had its own match table should use this.
337    pub fn from_size_str(size: &str) -> Result<Self, String> {
338        match size {
339            "codebert" | "codebert-base" | "125M" => Ok(Self::codebert()),
340            "0.5B" | "500M" | "qwen2-0.5b" => Ok(Self::qwen2_0_5b()),
341            "1.5B" | "qwen2.5-1.5b" | "qwen2-1.5b" => Ok(Self::qwen2_1_5b()),
342            "7B" | "qwen2.5-7b" => Ok(Self::qwen2_7b()),
343            "4B" | "qwen3-4b" | "qwen3" => Ok(Self::qwen3_4b()),
344            "9B" | "qwen3.5-9b" | "qwen3_5" | "qwen3.5" => Ok(Self::qwen3_5_9b()),
345            unknown => Err(format!(
346                "Unknown model size '{unknown}'. Known sizes: codebert, 0.5B, 4B, 7B, 9B"
347            )),
348        }
349    }
350
351    /// CodeBERT (microsoft/codebert-base) encoder configuration.
352    ///
353    /// RoBERTa architecture: 12 layers, 768 hidden, 12 heads, GELU, LayerNorm, learned positions.
354    /// SSC v11 Section 4: 125M params, ~20ms CPU inference, WASM-deployable.
355    pub fn codebert() -> Self {
356        Self {
357            hidden_size: CODEBERT_HIDDEN_SIZE,
358            num_attention_heads: 12,
359            num_kv_heads: 12, // No GQA in BERT
360            intermediate_size: CODEBERT_INTERMEDIATE_SIZE,
361            num_hidden_layers: 12,
362            vocab_size: CODEBERT_VOCAB_SIZE,
363            max_position_embeddings: CODEBERT_MAX_POSITION,
364            rms_norm_eps: 1e-5, // LayerNorm eps for RoBERTa
365            rope_theta: 0.0,    // Not used (learned positions)
366            use_bias: true,
367            head_dim_override: None,
368            architecture: ModelArchitecture::Encoder,
369            hf_architecture: None,
370            hf_model_type: None,
371            tie_word_embeddings: false,
372        }
373    }
374
375    /// Tiny configuration for testing
376    pub fn tiny() -> Self {
377        Self {
378            hidden_size: 64,
379            num_attention_heads: 2,
380            num_kv_heads: 2,
381            intermediate_size: 256,
382            num_hidden_layers: 2,
383            vocab_size: 1000,
384            max_position_embeddings: 512,
385            rms_norm_eps: 1e-6,
386            rope_theta: DEFAULT_ROPE_THETA,
387            use_bias: false,
388            head_dim_override: None,
389            architecture: ModelArchitecture::Decoder,
390            hf_architecture: None,
391            hf_model_type: None,
392            tie_word_embeddings: false,
393        }
394    }
395
396    /// Whether this config describes an encoder (BERT/RoBERTa) architecture.
397    pub fn is_encoder(&self) -> bool {
398        self.architecture == ModelArchitecture::Encoder
399    }
400
401    /// HuggingFace architecture class name for checkpoint config.json.
402    /// Uses explicit override if set, otherwise infers from config.
403    pub fn hf_architecture_name(&self) -> &str {
404        if let Some(ref name) = self.hf_architecture {
405            return name;
406        }
407        // Infer from model characteristics
408        if self.is_encoder() {
409            "BertModel"
410        } else if self.use_bias && self.vocab_size > 150000 {
411            // Qwen2 family: has attention biases + large vocab
412            "Qwen2ForCausalLM"
413        } else {
414            "LlamaForCausalLM"
415        }
416    }
417
418    /// HuggingFace model_type string for checkpoint config.json.
419    pub fn hf_model_type_str(&self) -> &str {
420        if let Some(ref mt) = self.hf_model_type {
421            return mt;
422        }
423        if self.is_encoder() {
424            "roberta"
425        } else if self.use_bias && self.vocab_size > 150000 {
426            "qwen2"
427        } else {
428            "llama"
429        }
430    }
431
432    /// Whether embeddings are tied (embed_tokens == lm_head).
433    /// Uses explicit flag if set, otherwise infers from architecture.
434    pub fn ties_embeddings(&self) -> bool {
435        if self.tie_word_embeddings {
436            return true;
437        }
438        // Qwen2 ties embeddings by default
439        self.use_bias && self.vocab_size > 150000
440    }
441
442    /// Per-head dimension.
443    ///
444    /// Uses explicit override when set (Qwen3: head_dim=128 with hidden=2560, 32 heads).
445    /// Falls back to hidden_size / num_heads for standard architectures.
446    pub fn head_dim(&self) -> usize {
447        self.head_dim_override.unwrap_or(self.hidden_size / self.num_attention_heads)
448    }
449
450    /// Total Q/O projection dimension = num_heads * head_dim.
451    ///
452    /// Equals hidden_size for standard architectures but differs when head_dim
453    /// is explicitly overridden (e.g. Qwen3-4B: 32 * 128 = 4096 != 2560).
454    pub fn q_dim(&self) -> usize {
455        self.num_attention_heads * self.head_dim()
456    }
457
458    // =========================================================================
459    // VRAM Budget Solver (Provable Design-by-Contract)
460    //
461    // Contract: contracts/model-families/qwen3.yaml §CUDA TRAINING RESOURCE BUDGET
462    // Meyer (1992) "No Hidden Clauses": every term maps 1:1 to a GpuBuffer::new()
463    // call in cuda_block.rs. No magic numbers — all derived from model dims.
464    // =========================================================================
465
466    /// KV hidden dimension = num_kv_heads * head_dim.
467    fn kv_dim(&self) -> usize {
468        self.num_kv_heads * self.head_dim()
469    }
470
471    /// Per-layer weight VRAM in f32 elements (constant, independent of seq_len).
472    ///
473    /// Maps to cuda_block.rs lines 212-220: `GpuBuffer::from_host()` uploads.
474    pub fn per_layer_weight_elements(&self) -> usize {
475        let h = self.hidden_size;
476        let q = self.q_dim();
477        let kv = self.kv_dim();
478        let i = self.intermediate_size;
479        // w_q: q*h, w_k: kv*h, w_v: kv*h, w_o: h*q, w_gate: i*h, w_up: i*h, w_down: h*i
480        // input_norm: h, post_attn_norm: h
481        q * h + kv * h * 2 + h * q + i * h * 3 + h * 2
482    }
483
484    /// Per-layer gradient weight VRAM in f32 elements (constant, independent of seq_len).
485    ///
486    /// Maps to cuda_block.rs CudaGradWorkspace: constant-size gradient buffers.
487    /// grad_w_q uses q_dim*hidden, grad_w_o uses hidden*q_dim (#262).
488    fn per_layer_grad_weight_elements(&self) -> usize {
489        let h = self.hidden_size;
490        let q = self.q_dim();
491        let kv = self.kv_dim();
492        let i = self.intermediate_size;
493        // grad_input_norm: h, grad_post_attn_norm: h
494        // grad_gate: h*i, grad_up: h*i, grad_down: i*h
495        // grad_w_q: q*h, grad_w_k: h*kv, grad_w_v: h*kv, grad_w_o: h*q
496        h * 2 + h * i * 3 + q * h + h * q + h * kv * 2
497    }
498
499    /// Per-layer scratch elements that scale linearly with seq_len.
500    ///
501    /// Maps to cuda_block.rs lines 224-236, 243-248: `GpuBuffer::new(_, S * dim)`.
502    fn per_layer_scratch_linear_coeff(&self) -> usize {
503        let h = self.hidden_size;
504        let kv = self.kv_dim();
505        let i = self.intermediate_size;
506        let n = self.num_attention_heads;
507        let hd = self.head_dim();
508        // Forward: norm1(h) + q(h) + k(kv) + v(kv) + attn_out(h) + o_proj(h)
509        //          + residual1(h) + norm2(h) + gate(i) + up(i) + swiglu(i) + ffn(h)
510        // Backward: grad_hidden(h) + grad_swiglu(i)
511        // Attention reshape: q_batched(N*hd) + kv_temp(N*hd) + kv_temp2(N*hd)
512        h * 8 + kv * 2 + i * 4 + n * hd * 3
513    }
514
515    /// Per-layer scratch elements that scale quadratically with seq_len.
516    ///
517    /// Returns (quadratic_coeff, linear_fallback_coeff) because:
518    ///   attn_scores = N * S * S
519    ///   grad_attn_scores = N * S * max(S, hd)
520    ///
521    /// When S >= hd: total = 2 * N * S^2 (pure quadratic)
522    /// When S < hd:  total = N * S^2 + N * S * hd (mixed)
523    fn per_layer_scratch_quadratic_coeff(&self) -> (usize, usize) {
524        let n = self.num_attention_heads;
525        let hd = self.head_dim();
526        // attn_scores: N * S * S (always quadratic)
527        // grad_attn_scores: N * S * max(S, hd)
528        //   When S >= hd → N * S * S (quadratic)
529        //   When S < hd  → N * S * hd (linear)
530        (n, n * hd) // (quadratic_coeff, linear_fallback for grad when S < hd)
531    }
532
533    /// Total VRAM in bytes for all layers at a given max_seq_len.
534    ///
535    /// Postcondition: result is exact for the current cuda_block.rs buffer layout.
536    pub fn total_training_vram_bytes(&self, max_seq_len: usize) -> usize {
537        let l = self.num_hidden_layers;
538        let s = max_seq_len;
539        let hd = self.head_dim();
540
541        let constant_per_layer =
542            self.per_layer_weight_elements() + self.per_layer_grad_weight_elements();
543        let linear_per_layer = self.per_layer_scratch_linear_coeff() * s;
544
545        let (n_quad, n_hd_linear) = self.per_layer_scratch_quadratic_coeff();
546        let quadratic_per_layer =
547            if s >= hd { 2 * n_quad * s * s } else { n_quad * s * s + n_hd_linear * s };
548
549        let elements_per_layer = constant_per_layer + linear_per_layer + quadratic_per_layer;
550        l * elements_per_layer * 4 // f32 = 4 bytes
551    }
552
553    /// Total VRAM in bytes with SHARED scratch workspace (1 per model, not per layer).
554    ///
555    /// This is the correct budget formula when gradient buffers are shared across
556    /// layers (canonical in PyTorch/JAX). Only weights are truly per-layer.
557    ///
558    /// Postcondition: result < total_training_vram_bytes(s) for L > 1
559    pub fn total_training_vram_bytes_shared(&self, max_seq_len: usize) -> usize {
560        let l = self.num_hidden_layers;
561        let s = max_seq_len;
562        let hd = self.head_dim();
563
564        // Weights are per-layer (unavoidable — must all be resident on GPU)
565        let weights_total = l * self.per_layer_weight_elements();
566
567        // Gradient weight buffers: SHARED (one set, reused across layers)
568        let grad_weights_shared = self.per_layer_grad_weight_elements();
569
570        // Seq-len-dependent scratch: SHARED (one set)
571        let linear_shared = self.per_layer_scratch_linear_coeff() * s;
572        let (n_quad, n_hd_linear) = self.per_layer_scratch_quadratic_coeff();
573        let quadratic_shared =
574            if s >= hd { 2 * n_quad * s * s } else { n_quad * s * s + n_hd_linear * s };
575
576        let total_elements = weights_total + grad_weights_shared + linear_shared + quadratic_shared;
577        total_elements * 4 // f32 = 4 bytes
578    }
579
580    /// Solve for the maximum seq_len that fits in the given VRAM budget (bytes),
581    /// using shared scratch workspace.
582    ///
583    /// This is the solver to use with the shared-scratch architecture.
584    /// Returns None if even seq_len=1 exceeds the budget.
585    pub fn max_seq_len_for_vram_shared(&self, vram_bytes: usize) -> Option<usize> {
586        if self.total_training_vram_bytes_shared(1) > vram_bytes {
587            return None;
588        }
589
590        let mut lo: usize = 1;
591        let mut hi: usize = self.max_position_embeddings;
592
593        while lo < hi {
594            let mid = lo + (hi - lo).div_ceil(2);
595            if self.total_training_vram_bytes_shared(mid) <= vram_bytes {
596                lo = mid;
597            } else {
598                hi = mid - 1;
599            }
600        }
601
602        Some(lo)
603    }
604
605    /// Solve for the maximum seq_len that fits in the given VRAM budget (bytes).
606    ///
607    /// Binary search over [1, max_position_embeddings].
608    /// Returns None if even seq_len=1 exceeds the budget.
609    ///
610    /// Precondition: vram_bytes > 0
611    /// Postcondition: total_training_vram_bytes(result) <= vram_bytes
612    pub fn max_seq_len_for_vram(&self, vram_bytes: usize) -> Option<usize> {
613        if self.total_training_vram_bytes(1) > vram_bytes {
614            return None;
615        }
616
617        let mut lo: usize = 1;
618        let mut hi: usize = self.max_position_embeddings;
619
620        while lo < hi {
621            let mid = lo + (hi - lo).div_ceil(2);
622            if self.total_training_vram_bytes(mid) <= vram_bytes {
623                lo = mid;
624            } else {
625                hi = mid - 1;
626            }
627        }
628
629        Some(lo)
630    }
631}
632
633#[cfg(test)]
634mod tests {
635    use super::*;
636
// LLaMA 2 7B preset: hidden size, head count, and the derived head_dim (4096 / 32).
#[test]
fn test_transformer_config_llama2() {
    let cfg = TransformerConfig::llama2_7b();
    let observed = (cfg.hidden_size, cfg.num_attention_heads, cfg.head_dim());
    assert_eq!(observed, (4096, 32, 128));
}
644
// Tiny preset: hidden size, head count, and the derived head_dim (64 / 2).
#[test]
fn test_transformer_config_tiny() {
    let cfg = TransformerConfig::tiny();
    let observed = (cfg.hidden_size, cfg.num_attention_heads, cfg.head_dim());
    assert_eq!(observed, (64, 2, 32));
}
652
// A config survives a JSON round trip with its core dimensions intact.
#[test]
fn test_config_serialization() {
    let original = TransformerConfig::llama2_7b();
    let encoded = serde_json::to_string(&original).expect("JSON serialization should succeed");
    let decoded: TransformerConfig =
        serde_json::from_str(&encoded).expect("JSON deserialization should succeed");
    assert_eq!(
        (decoded.hidden_size, decoded.num_attention_heads),
        (original.hidden_size, original.num_attention_heads)
    );
}
662
// Mistral 7B uses grouped-query attention: 32 query heads over 8 KV heads
// (4 query heads per KV head).
#[test]
fn test_mistral_config() {
    let TransformerConfig { num_kv_heads, num_attention_heads, .. } =
        TransformerConfig::mistral_7b();
    assert_eq!(num_kv_heads, 8);
    assert_eq!(num_attention_heads, 32);
}
670
/// FALSIFY-APR-PRETRAIN-ARCH-001 — `qwen2_0_5b()` constructor matches
/// HF config byte-for-byte.
///
/// Empirically verified against
/// `~/.cache/huggingface/hub/models--Qwen--Qwen2.5-Coder-0.5B-Instruct/.../config.json`
/// on 2026-05-04 for SHIP-TWO-001 §50.4 step 5b. If any of these 11
/// fields drift from HF, the §49 fine-tune path's init weights will
/// load into the wrong-shape optimizer and produce silent gibberish.
#[test]
fn qwen2_0_5b_matches_hf_config_2026_05_04() {
    let cfg = TransformerConfig::qwen2_0_5b();

    // Integer dimensions
    assert_eq!(cfg.hidden_size, 896, "hidden_size");
    assert_eq!(cfg.num_attention_heads, 14, "num_attention_heads");
    assert_eq!(cfg.num_kv_heads, 2, "num_kv_heads (GQA-7:1)");
    assert_eq!(cfg.intermediate_size, 4864, "intermediate_size");
    assert_eq!(cfg.num_hidden_layers, 24, "num_hidden_layers");
    assert_eq!(cfg.vocab_size, 151_936, "vocab_size");
    assert_eq!(cfg.max_position_embeddings, 32_768, "max_position_embeddings");

    // Float hyperparameters
    let eps_error = (cfg.rms_norm_eps - 1e-6).abs();
    assert!(eps_error < f32::EPSILON, "rms_norm_eps={}, want 1e-6", cfg.rms_norm_eps);
    let theta_error = (cfg.rope_theta - 1_000_000.0).abs();
    assert!(theta_error < f32::EPSILON, "rope_theta={}, want 1_000_000.0", cfg.rope_theta);

    // Family flags
    assert!(cfg.use_bias, "use_bias must be true (Qwen2 quirk)");
    assert!(
        cfg.tie_word_embeddings,
        "tie_word_embeddings must be true for Qwen2.5 0.5B (HF config 2026-05-04)"
    );
    assert_eq!(cfg.architecture, ModelArchitecture::Decoder);

    // GQA ratio = 14/2 = 7, the canonical Qwen2.5-0.5B GQA-7:1.
    let gqa_ratio = cfg.num_attention_heads / cfg.num_kv_heads;
    assert_eq!(gqa_ratio, 7);
}
708
/// Drift-prevention: `qwen2_1_5b()` takes `tie_word_embeddings` from
/// `qwen2_0_5b()` through struct-update syntax. If the two presets are ever
/// decoupled, this test catches a silent flip of the flag.
#[test]
fn qwen2_1_5b_inherits_tie_word_embeddings_from_0_5b() {
    let base_ties = TransformerConfig::qwen2_0_5b().tie_word_embeddings;
    let derived_ties = TransformerConfig::qwen2_1_5b().tie_word_embeddings;
    assert_eq!(
        derived_ties, base_ties,
        "qwen2_1_5b must inherit tie_word_embeddings from qwen2_0_5b — both are HF tie=true"
    );
    assert!(derived_ties, "qwen2_1_5b tie_word_embeddings must be true (HF config 2026-05-04)");
}
725
/// Pins the Qwen size split: the 0.5B and 1.5B presets tie embeddings, the
/// 7B preset does not. Should the 7B ever be rebuilt on top of the 0.5B
/// preset, this fails before an operator fine-tunes a 7B with the wrong
/// head shape.
#[test]
fn qwen2_7b_does_not_tie_embeddings() {
    let ties = TransformerConfig::qwen2_7b().tie_word_embeddings;
    assert!(
        !ties,
        "qwen2_7b tie_word_embeddings MUST be false per HF config 2026-05-04 — \
         larger Qwen variants pay param cost for untied weights"
    );
}
738
// Qwen2 0.5B preset: linear biases enabled and the Qwen2 vocabulary size.
#[test]
fn test_qwen2_config() {
    let TransformerConfig { use_bias, vocab_size, .. } = TransformerConfig::qwen2_0_5b();
    assert!(use_bias);
    assert_eq!(vocab_size, 151936);
}
745
// LLaMA 2 13B preset: dimensions plus the derived head_dim (5120 / 40 = 128).
#[test]
fn test_llama2_13b_config() {
    let cfg = TransformerConfig::llama2_13b();
    let observed =
        (cfg.hidden_size, cfg.num_attention_heads, cfg.num_hidden_layers, cfg.head_dim());
    assert_eq!(observed, (5120, 40, 40, 128));
}
754
755    #[test]
756    fn test_config_yaml_serialization() {
757        let config = TransformerConfig::tiny();
758        let yaml = serde_yaml::to_string(&config).expect("config should be valid");
759        let restored: TransformerConfig =
760            serde_yaml::from_str(&yaml).expect("config should be valid");
761        assert_eq!(restored.hidden_size, config.hidden_size);
762        assert_eq!(restored.num_hidden_layers, config.num_hidden_layers);
763    }
764
765    #[test]
766    fn test_grouped_query_attention_ratio() {
767        let config = TransformerConfig::mistral_7b();
768        let heads_per_kv = config.num_attention_heads / config.num_kv_heads;
769        assert_eq!(heads_per_kv, 4); // 32 / 8 = 4
770    }
771
772    #[test]
773    fn test_config_clone() {
774        let config = TransformerConfig::llama2_7b();
775        let cloned = config.clone();
776        assert_eq!(config.hidden_size, cloned.hidden_size);
777        assert_eq!(config.vocab_size, cloned.vocab_size);
778    }
779
780    #[test]
781    fn test_qwen3_5_9b_config() {
782        let config = TransformerConfig::qwen3_5_9b();
783        assert_eq!(config.hidden_size, 4096);
784        assert_eq!(config.num_attention_heads, 16);
785        assert_eq!(config.num_kv_heads, 4);
786        assert_eq!(config.intermediate_size, 12288);
787        assert_eq!(config.num_hidden_layers, 32);
788        assert_eq!(config.vocab_size, 248320);
789        assert_eq!(config.max_position_embeddings, 262144);
790        assert!(!config.use_bias);
791    }
792
793    #[test]
794    fn test_qwen3_5_9b_head_dim() {
795        let config = TransformerConfig::qwen3_5_9b();
796        // 4096 / 16 = 256 (explicit head_dim, not derived from hidden/heads ratio)
797        assert_eq!(config.head_dim(), 256);
798    }
799
800    #[test]
801    fn test_qwen3_5_9b_gqa_ratio() {
802        let config = TransformerConfig::qwen3_5_9b();
803        let heads_per_kv = config.num_attention_heads / config.num_kv_heads;
804        assert_eq!(heads_per_kv, 4); // 16 / 4 = 4 Q heads per KV head
805    }
806
807    // =========================================================================
808    // from_apr_metadata contract tests (GH-376)
809    // =========================================================================
810
811    #[test]
812    fn test_from_apr_metadata_qwen3_8b() {
813        // Qwen3-8B: 36 layers, 32 heads, 8 KV heads, hidden=4096, head_dim=128
814        let config = TransformerConfig::from_apr_metadata(
815            Some(4096),   // hidden_size
816            Some(32),     // num_heads
817            Some(8),      // num_kv_heads
818            Some(12288),  // intermediate_size
819            Some(36),     // num_layers
820            Some(151936), // vocab_size
821            Some(40960),  // max_position_embeddings
822            Some(1e-6),   // rms_norm_eps
823            Some(1e6),    // rope_theta
824            Some("qwen3"),
825        )
826        .expect("all required fields present");
827
828        assert_eq!(config.hidden_size, 4096);
829        assert_eq!(config.num_attention_heads, 32);
830        assert_eq!(config.num_kv_heads, 8);
831        assert_eq!(config.num_hidden_layers, 36);
832        assert_eq!(config.vocab_size, 151936);
833        assert_eq!(config.head_dim(), 128); // 4096/32=128, no override needed
834        assert!(!config.use_bias); // Qwen3: no bias
835    }
836
837    #[test]
838    fn test_from_apr_metadata_qwen2_7b() {
839        // Qwen2.5 should get use_bias=true
840        let config = TransformerConfig::from_apr_metadata(
841            Some(3584),
842            Some(28),
843            Some(4),
844            Some(18944),
845            Some(28),
846            Some(152064),
847            Some(32768),
848            Some(1e-6),
849            Some(1e6),
850            Some("qwen2"),
851        )
852        .expect("all required fields present");
853
854        assert!(config.use_bias); // Qwen2: has bias
855        assert_eq!(config.head_dim(), 128); // 3584/28=128
856    }
857
858    #[test]
859    fn test_from_apr_metadata_missing_required_returns_none() {
860        // Missing hidden_size — must return None, not silently degrade
861        assert!(TransformerConfig::from_apr_metadata(
862            None,
863            Some(32),
864            Some(8),
865            Some(12288),
866            Some(36),
867            Some(151936),
868            Some(40960),
869            Some(1e-6),
870            Some(1e6),
871            Some("qwen3"),
872        )
873        .is_none());
874
875        // Missing num_layers
876        assert!(TransformerConfig::from_apr_metadata(
877            Some(4096),
878            Some(32),
879            Some(8),
880            Some(12288),
881            None,
882            Some(151936),
883            Some(40960),
884            Some(1e-6),
885            Some(1e6),
886            Some("qwen3"),
887        )
888        .is_none());
889    }
890
891    // =========================================================================
892    // VRAM Budget Solver Falsification Tests
893    //
894    // Popperian: each test attempts to BREAK a mathematical invariant.
895    // If any test fails, the budget formula disagrees with cuda_block.rs.
896    // =========================================================================
897
898    #[test]
899    fn falsify_vram_monotonic_in_seq_len() {
900        // Prediction: VRAM is strictly monotonically increasing in seq_len
901        let config = TransformerConfig::qwen3_4b();
902        let mut prev = config.total_training_vram_bytes(1);
903        for s in [2, 4, 8, 16, 32, 64, 128, 256, 512] {
904            let cur = config.total_training_vram_bytes(s);
905            assert!(
906                cur > prev,
907                "VRAM must increase: seq_len={s} ({cur}) should exceed prev ({prev})"
908            );
909            prev = cur;
910        }
911    }
912
913    #[test]
914    fn falsify_vram_solver_postcondition() {
915        // Prediction: solver result satisfies total_vram <= budget
916        let config = TransformerConfig::qwen3_4b();
917        let budget = 24 * 1024 * 1024 * 1024_usize; // 24 GB (RTX 4090)
918        if let Some(max_s) = config.max_seq_len_for_vram(budget) {
919            let used = config.total_training_vram_bytes(max_s);
920            assert!(
921                used <= budget,
922                "Solver returned seq_len={max_s} using {used} bytes > budget {budget}"
923            );
924            // And seq_len+1 should exceed budget (tightness)
925            if max_s < config.max_position_embeddings {
926                let over = config.total_training_vram_bytes(max_s + 1);
927                assert!(
928                    over > budget,
929                    "Solver not tight: seq_len={} uses {over} <= budget {budget}",
930                    max_s + 1
931                );
932            }
933        }
934    }
935
936    #[test]
937    fn falsify_vram_solver_returns_none_when_impossible() {
938        // Prediction: if even seq_len=1 exceeds budget, solver returns None
939        let config = TransformerConfig::qwen3_4b();
940        let tiny_budget = 1024; // 1 KB — impossible for any model
941        assert!(
942            config.max_seq_len_for_vram(tiny_budget).is_none(),
943            "Solver should return None when budget is too small"
944        );
945    }
946
947    #[test]
948    fn falsify_qwen3_4b_vram_matches_oom_observation() {
949        // Observation: Qwen3-4B OOM'd on 24 GB 4090 at seq_len=512.
950        // The formula MUST agree: seq_len=512 should exceed ~23 GB usable VRAM.
951        let config = TransformerConfig::qwen3_4b();
952        let vram_512 = config.total_training_vram_bytes(512);
953        let usable_vram = 23 * 1024 * 1024 * 1024_usize; // ~23 GB after CUDA runtime
954
955        // Diagnostic: print the budget breakdown
956        let vram_1 = config.total_training_vram_bytes(1);
957        let shared_128 = config.total_training_vram_bytes_shared(128);
958        let shared_512 = config.total_training_vram_bytes_shared(512);
959        let solved = config.max_seq_len_for_vram_shared(24 * 1024 * 1024 * 1024);
960        eprintln!("=== Qwen3-4B VRAM Budget ===");
961        eprintln!(
962            "  Per-layer weights:    {:.1} MB",
963            config.per_layer_weight_elements() as f64 * 4.0 / 1e6
964        );
965        eprintln!(
966            "  Per-layer grad scratch: {:.1} MB",
967            config.per_layer_grad_weight_elements() as f64 * 4.0 / 1e6
968        );
969        eprintln!("  Per-layer (S=512): {:.1} MB", (vram_512 / 36) as f64 / 1e6);
970        eprintln!("  36 layers S=1 (per-layer scratch): {:.1} GB", vram_1 as f64 / 1e9);
971        eprintln!("  36 layers S=512 (per-layer scratch): {:.1} GB", vram_512 as f64 / 1e9);
972        eprintln!("  36 layers S=128 (SHARED scratch):    {:.1} GB", shared_128 as f64 / 1e9);
973        eprintln!("  36 layers S=512 (SHARED scratch):    {:.1} GB", shared_512 as f64 / 1e9);
974        eprintln!("  Max seq_len for 24 GB (shared):      {solved:?}");
975
976        assert!(
977            vram_512 > usable_vram,
978            "Formula says {:.1} GB for seq_len=512, but we OOM'd on 23 GB — formula is wrong",
979            vram_512 as f64 / 1e9
980        );
981    }
982
983    #[test]
984    fn falsify_qwen2_0_5b_fits_on_4090() {
985        // Observation: Qwen2-0.5B trained successfully on 4090 at seq_len=512.
986        // The formula MUST agree: it should fit in 24 GB.
987        let config = TransformerConfig::qwen2_0_5b();
988        let vram_512 = config.total_training_vram_bytes(512);
989        let total_vram = 24 * 1024 * 1024 * 1024_usize;
990        assert!(
991            vram_512 < total_vram,
992            "Formula says {:.1} GB for Qwen2-0.5B at seq_len=512, but it fit on 4090",
993            vram_512 as f64 / 1e9
994        );
995    }
996
997    #[test]
998    fn falsify_vram_budget_concrete_values() {
999        // Verify concrete VRAM numbers for Qwen3-4B to catch formula drift.
1000        let config = TransformerConfig::qwen3_4b();
1001
1002        // Per-layer weights: q(4096*2560) + k(1024*2560) + v(1024*2560)
1003        //   + o(2560*4096) + gate(9728*2560) + up(9728*2560) + down(2560*9728)
1004        //   + norms(2560*2)
1005        let expected_weights =
1006            4096 * 2560 + 1024 * 2560 * 2 + 2560 * 4096 + 9728 * 2560 * 3 + 2560 * 2;
1007        assert_eq!(config.per_layer_weight_elements(), expected_weights);
1008
1009        // With PER-LAYER gradient scratch (current cuda_block.rs layout),
1010        // Qwen3-4B's constant overhead alone exceeds 24 GB:
1011        // 36 layers × 776 MB = 27.9 GB. Solver correctly returns None.
1012        let budget_24gb = 24 * 1024 * 1024 * 1024_usize;
1013        assert!(
1014            config.max_seq_len_for_vram(budget_24gb).is_none(),
1015            "Qwen3-4B per-layer scratch CANNOT fit 24 GB — proves shared scratch needed"
1016        );
1017
1018        // With SHARED scratch (weight-only per-layer), budget check uses
1019        // total_training_vram_bytes_shared(). Qwen3-4B weights-only = 14.5 GB,
1020        // leaves ~9 GB for one shared scratch set + seq_len-dependent buffers.
1021        let shared_budget = config.total_training_vram_bytes_shared(128);
1022        assert!(
1023            shared_budget < budget_24gb,
1024            "Qwen3-4B shared scratch at seq_len=128 should fit 24 GB, got {:.1} GB",
1025            shared_budget as f64 / 1e9
1026        );
1027    }
1028
1029    // ── Additional coverage tests ─────────────────────────────────
1030
1031    #[test]
1032    fn test_model_architecture_default() {
1033        let arch: ModelArchitecture = Default::default();
1034        assert_eq!(arch, ModelArchitecture::Decoder);
1035    }
1036
1037    #[test]
1038    fn test_model_architecture_serialization() {
1039        let encoder = ModelArchitecture::Encoder;
1040        let json = serde_json::to_string(&encoder).expect("serialize");
1041        assert_eq!(json, "\"encoder\"");
1042        let decoder = ModelArchitecture::Decoder;
1043        let json = serde_json::to_string(&decoder).expect("serialize");
1044        assert_eq!(json, "\"decoder\"");
1045
1046        let restored: ModelArchitecture = serde_json::from_str("\"encoder\"").expect("deserialize");
1047        assert_eq!(restored, ModelArchitecture::Encoder);
1048    }
1049
1050    #[test]
1051    fn test_codebert_config() {
1052        let config = TransformerConfig::codebert();
1053        assert_eq!(config.hidden_size, 768);
1054        assert_eq!(config.num_attention_heads, 12);
1055        assert_eq!(config.num_kv_heads, 12);
1056        assert_eq!(config.intermediate_size, 3072);
1057        assert_eq!(config.num_hidden_layers, 12);
1058        assert_eq!(config.vocab_size, 50265);
1059        assert_eq!(config.max_position_embeddings, 514);
1060        assert!(config.use_bias);
1061        assert_eq!(config.architecture, ModelArchitecture::Encoder);
1062        assert!(config.is_encoder());
1063        assert_eq!(config.head_dim(), 64); // 768 / 12
1064    }
1065
1066    #[test]
1067    fn test_is_encoder() {
1068        assert!(TransformerConfig::codebert().is_encoder());
1069        assert!(!TransformerConfig::llama2_7b().is_encoder());
1070        assert!(!TransformerConfig::tiny().is_encoder());
1071        assert!(!TransformerConfig::qwen2_0_5b().is_encoder());
1072    }
1073
1074    #[test]
1075    fn test_hf_architecture_name_inferred() {
1076        // Encoder
1077        assert_eq!(TransformerConfig::codebert().hf_architecture_name(), "BertModel");
1078        // Qwen2 (bias + large vocab)
1079        assert_eq!(TransformerConfig::qwen2_0_5b().hf_architecture_name(), "Qwen2ForCausalLM");
1080        // LLaMA (no bias)
1081        assert_eq!(TransformerConfig::llama2_7b().hf_architecture_name(), "LlamaForCausalLM");
1082    }
1083
1084    #[test]
1085    fn test_hf_architecture_name_override() {
1086        let mut config = TransformerConfig::tiny();
1087        config.hf_architecture = Some("CustomModel".to_string());
1088        assert_eq!(config.hf_architecture_name(), "CustomModel");
1089    }
1090
1091    #[test]
1092    fn test_hf_model_type_str_inferred() {
1093        assert_eq!(TransformerConfig::codebert().hf_model_type_str(), "roberta");
1094        assert_eq!(TransformerConfig::qwen2_0_5b().hf_model_type_str(), "qwen2");
1095        assert_eq!(TransformerConfig::llama2_7b().hf_model_type_str(), "llama");
1096    }
1097
1098    #[test]
1099    fn test_hf_model_type_str_override() {
1100        let mut config = TransformerConfig::tiny();
1101        config.hf_model_type = Some("custom_type".to_string());
1102        assert_eq!(config.hf_model_type_str(), "custom_type");
1103    }
1104
1105    #[test]
1106    fn test_ties_embeddings() {
1107        // Qwen2 ties embeddings (bias + large vocab)
1108        assert!(TransformerConfig::qwen2_0_5b().ties_embeddings());
1109        // LLaMA does not
1110        assert!(!TransformerConfig::llama2_7b().ties_embeddings());
1111        // Explicit flag override
1112        let mut config = TransformerConfig::llama2_7b();
1113        config.tie_word_embeddings = true;
1114        assert!(config.ties_embeddings());
1115    }
1116
1117    #[test]
1118    fn test_head_dim_override() {
1119        let config = TransformerConfig::qwen3_4b();
1120        assert_eq!(config.head_dim_override, Some(128));
1121        assert_eq!(config.head_dim(), 128);
1122        // Without override: 2560 / 32 = 80 (but override gives 128)
1123        assert_ne!(config.hidden_size / config.num_attention_heads, 128);
1124    }
1125
1126    #[test]
1127    fn test_head_dim_no_override() {
1128        let config = TransformerConfig::llama2_7b();
1129        assert!(config.head_dim_override.is_none());
1130        assert_eq!(config.head_dim(), 128); // 4096 / 32
1131    }
1132
1133    #[test]
1134    fn test_q_dim() {
1135        let config = TransformerConfig::qwen3_4b();
1136        // 32 heads * 128 head_dim = 4096
1137        assert_eq!(config.q_dim(), 4096);
1138
1139        let config = TransformerConfig::llama2_7b();
1140        // 32 heads * 128 = 4096 = hidden_size
1141        assert_eq!(config.q_dim(), 4096);
1142    }
1143
1144    #[test]
1145    fn test_q_dim_differs_from_hidden() {
1146        let config = TransformerConfig::qwen3_4b();
1147        // Qwen3-4B: q_dim = 4096 but hidden_size = 2560
1148        assert_ne!(config.q_dim(), config.hidden_size);
1149    }
1150
1151    /// GH-262: Verify Qwen3-4B projection weight shapes match HuggingFace config.json.
1152    ///
1153    /// config.json fields: hidden_size=2560, num_attention_heads=32,
1154    /// num_key_value_heads=8, head_dim=128.
1155    ///
1156    /// Expected shapes (element counts):
1157    ///   q_proj: [4096, 2560] = 10,485,760
1158    ///   k_proj: [1024, 2560] = 2,621,440
1159    ///   v_proj: [1024, 2560] = 2,621,440
1160    ///   o_proj: [2560, 4096] = 10,485,760
1161    #[test]
1162    fn test_qwen3_4b_projection_shapes() {
1163        let config = TransformerConfig::qwen3_4b();
1164
1165        // Verify base dimensions
1166        assert_eq!(config.hidden_size, 2560);
1167        assert_eq!(config.num_attention_heads, 32);
1168        assert_eq!(config.num_kv_heads, 8);
1169        assert_eq!(config.head_dim(), 128);
1170        assert_eq!(config.head_dim_override, Some(128));
1171
1172        // Derived projection dimensions
1173        let q_dim = config.q_dim();
1174        let kv_dim = config.kv_dim();
1175        assert_eq!(q_dim, 4096); // 32 * 128
1176        assert_eq!(kv_dim, 1024); // 8 * 128
1177
1178        // Weight element counts (what validate_weight_shapes checks)
1179        let hidden = config.hidden_size;
1180        assert_eq!(q_dim * hidden, 10_485_760); // q_proj [4096, 2560]
1181        assert_eq!(kv_dim * hidden, 2_621_440); // k_proj [1024, 2560]
1182        assert_eq!(kv_dim * hidden, 2_621_440); // v_proj [1024, 2560]
1183        assert_eq!(hidden * q_dim, 10_485_760); // o_proj [2560, 4096]
1184    }
1185
1186    /// GH-262: Verify VRAM estimation uses q_dim (not hidden_size) for Q/O grads.
1187    #[test]
1188    fn test_qwen3_4b_grad_weight_elements_uses_q_dim() {
1189        let config = TransformerConfig::qwen3_4b();
1190        let h = config.hidden_size; // 2560
1191        let q = config.q_dim(); // 4096
1192        let kv = config.kv_dim(); // 1024
1193        let i = config.intermediate_size; // 9728
1194
1195        // per_layer_grad_weight_elements must use q*h for Q/O, not h*h
1196        let expected = h * 2          // norms
1197            + h * i * 3              // gate, up, down
1198            + q * h                  // grad_w_q
1199            + h * q                  // grad_w_o
1200            + h * kv * 2; // grad_w_k, grad_w_v
1201        assert_eq!(config.per_layer_grad_weight_elements(), expected);
1202
1203        // Sanity: h*h would be smaller (2560^2 = 6.5M) vs q*h (4096*2560 = 10.5M)
1204        assert!(q * h > h * h, "q_dim*hidden > hidden*hidden for Qwen3-4B");
1205    }
1206
1207    #[test]
1208    fn test_from_size_str_known_sizes() {
1209        assert!(TransformerConfig::from_size_str("codebert").is_ok());
1210        assert!(TransformerConfig::from_size_str("codebert-base").is_ok());
1211        assert!(TransformerConfig::from_size_str("125M").is_ok());
1212        assert!(TransformerConfig::from_size_str("0.5B").is_ok());
1213        assert!(TransformerConfig::from_size_str("500M").is_ok());
1214        assert!(TransformerConfig::from_size_str("qwen2-0.5b").is_ok());
1215        assert!(TransformerConfig::from_size_str("7B").is_ok());
1216        assert!(TransformerConfig::from_size_str("qwen2.5-7b").is_ok());
1217        assert!(TransformerConfig::from_size_str("4B").is_ok());
1218        assert!(TransformerConfig::from_size_str("qwen3-4b").is_ok());
1219        assert!(TransformerConfig::from_size_str("qwen3").is_ok());
1220        assert!(TransformerConfig::from_size_str("9B").is_ok());
1221        assert!(TransformerConfig::from_size_str("qwen3.5-9b").is_ok());
1222        assert!(TransformerConfig::from_size_str("qwen3_5").is_ok());
1223        assert!(TransformerConfig::from_size_str("qwen3.5").is_ok());
1224    }
1225
1226    #[test]
1227    fn test_from_size_str_unknown() {
1228        let err = TransformerConfig::from_size_str("99B").unwrap_err();
1229        assert!(err.contains("Unknown model size"));
1230        assert!(err.contains("99B"));
1231    }
1232
1233    #[test]
1234    fn test_from_size_str_configs_correct() {
1235        let codebert = TransformerConfig::from_size_str("codebert").unwrap();
1236        assert_eq!(codebert.hidden_size, 768);
1237        assert!(codebert.is_encoder());
1238
1239        let qwen2 = TransformerConfig::from_size_str("0.5B").unwrap();
1240        assert_eq!(qwen2.hidden_size, 896);
1241        assert!(qwen2.use_bias);
1242
1243        let qwen3 = TransformerConfig::from_size_str("4B").unwrap();
1244        assert_eq!(qwen3.hidden_size, 2560);
1245        assert!(!qwen3.use_bias);
1246    }
1247
1248    #[test]
1249    fn test_from_apr_metadata_missing_num_heads() {
1250        assert!(TransformerConfig::from_apr_metadata(
1251            Some(4096),
1252            None, // missing heads
1253            Some(8),
1254            Some(12288),
1255            Some(36),
1256            Some(151936),
1257            None,
1258            None,
1259            None,
1260            None,
1261        )
1262        .is_none());
1263    }
1264
1265    #[test]
1266    fn test_from_apr_metadata_missing_vocab_size() {
1267        assert!(TransformerConfig::from_apr_metadata(
1268            Some(4096),
1269            Some(32),
1270            Some(8),
1271            Some(12288),
1272            Some(36),
1273            None, // missing vocab
1274            None,
1275            None,
1276            None,
1277            None,
1278        )
1279        .is_none());
1280    }
1281
1282    #[test]
1283    fn test_from_apr_metadata_missing_intermediate_size() {
1284        assert!(TransformerConfig::from_apr_metadata(
1285            Some(4096),
1286            Some(32),
1287            Some(8),
1288            None, // missing intermediate
1289            Some(36),
1290            Some(151936),
1291            None,
1292            None,
1293            None,
1294            None,
1295        )
1296        .is_none());
1297    }
1298
1299    #[test]
1300    fn test_from_apr_metadata_defaults() {
1301        let config = TransformerConfig::from_apr_metadata(
1302            Some(512),
1303            Some(8),
1304            None, // defaults to num_heads
1305            Some(2048),
1306            Some(6),
1307            Some(32000),
1308            None, // defaults to 32768
1309            None, // defaults to 1e-6
1310            None, // defaults to 10000.0
1311            None, // defaults to Decoder
1312        )
1313        .unwrap();
1314
1315        assert_eq!(config.num_kv_heads, 8); // defaults to num_heads
1316        assert_eq!(config.max_position_embeddings, 32768);
1317        assert!((config.rms_norm_eps - 1e-6).abs() < 1e-10);
1318        assert!((config.rope_theta - 10000.0).abs() < 0.1);
1319        assert_eq!(config.architecture, ModelArchitecture::Decoder);
1320        assert!(!config.use_bias);
1321    }
1322
1323    #[test]
1324    fn test_from_apr_metadata_encoder_architecture() {
1325        let config = TransformerConfig::from_apr_metadata(
1326            Some(768),
1327            Some(12),
1328            Some(12),
1329            Some(3072),
1330            Some(12),
1331            Some(50265),
1332            Some(514),
1333            Some(1e-5),
1334            Some(0.0),
1335            Some("codebert"),
1336        )
1337        .unwrap();
1338        assert_eq!(config.architecture, ModelArchitecture::Encoder);
1339    }
1340
1341    #[test]
1342    fn test_from_apr_metadata_roberta_architecture() {
1343        let config = TransformerConfig::from_apr_metadata(
1344            Some(768),
1345            Some(12),
1346            Some(12),
1347            Some(3072),
1348            Some(12),
1349            Some(50265),
1350            None,
1351            None,
1352            None,
1353            Some("roberta"),
1354        )
1355        .unwrap();
1356        assert_eq!(config.architecture, ModelArchitecture::Encoder);
1357    }
1358
1359    #[test]
1360    fn test_from_apr_metadata_qwen3_head_dim_override() {
1361        // Qwen3-4B: hidden=2560, 32 heads → 2560/32=80 != 128, so override needed
1362        let config = TransformerConfig::from_apr_metadata(
1363            Some(2560),
1364            Some(32),
1365            Some(8),
1366            Some(9728),
1367            Some(36),
1368            Some(151936),
1369            Some(40960),
1370            Some(1e-6),
1371            Some(1e6),
1372            Some("qwen3-4b"),
1373        )
1374        .unwrap();
1375        assert_eq!(config.head_dim_override, Some(128));
1376        assert_eq!(config.head_dim(), 128);
1377        assert!(!config.use_bias);
1378    }
1379
1380    #[test]
1381    fn test_from_apr_metadata_qwen3_no_override_needed() {
1382        // If hidden/heads = 128, no override needed
1383        let config = TransformerConfig::from_apr_metadata(
1384            Some(4096),
1385            Some(32),
1386            Some(8),
1387            Some(12288),
1388            Some(36),
1389            Some(151936),
1390            None,
1391            None,
1392            None,
1393            Some("qwen3-8b"),
1394        )
1395        .unwrap();
1396        assert!(config.head_dim_override.is_none());
1397        assert_eq!(config.head_dim(), 128);
1398    }
1399
1400    #[test]
1401    fn test_qwen2_7b_config() {
1402        let config = TransformerConfig::qwen2_7b();
1403        assert_eq!(config.hidden_size, 3584);
1404        assert_eq!(config.num_attention_heads, 28);
1405        assert_eq!(config.num_kv_heads, 4);
1406        assert_eq!(config.intermediate_size, 18944);
1407        assert_eq!(config.num_hidden_layers, 28);
1408        assert_eq!(config.vocab_size, 152064);
1409        assert!(config.use_bias);
1410        assert_eq!(config.head_dim(), 128); // 3584 / 28
1411    }
1412
1413    #[test]
1414    fn test_qwen3_4b_config() {
1415        let config = TransformerConfig::qwen3_4b();
1416        assert_eq!(config.hidden_size, 2560);
1417        assert_eq!(config.num_attention_heads, 32);
1418        assert_eq!(config.num_kv_heads, 8);
1419        assert_eq!(config.intermediate_size, 9728);
1420        assert_eq!(config.num_hidden_layers, 36);
1421        assert!(!config.use_bias);
1422        assert_eq!(config.head_dim(), 128);
1423    }
1424
1425    #[test]
1426    fn test_per_layer_weight_elements_positive() {
1427        for config in [
1428            TransformerConfig::tiny(),
1429            TransformerConfig::codebert(),
1430            TransformerConfig::qwen2_0_5b(),
1431            TransformerConfig::qwen3_4b(),
1432        ] {
1433            assert!(config.per_layer_weight_elements() > 0);
1434        }
1435    }
1436
1437    #[test]
1438    fn test_vram_shared_less_than_per_layer() {
1439        let config = TransformerConfig::qwen2_0_5b();
1440        let per_layer = config.total_training_vram_bytes(128);
1441        let shared = config.total_training_vram_bytes_shared(128);
1442        // Shared should be less for multi-layer models
1443        assert!(
1444            shared < per_layer,
1445            "Shared ({shared}) should be less than per-layer ({per_layer})"
1446        );
1447    }
1448
1449    #[test]
1450    fn test_vram_shared_monotonic() {
1451        let config = TransformerConfig::qwen2_0_5b();
1452        let mut prev = config.total_training_vram_bytes_shared(1);
1453        for s in [2, 4, 8, 16, 32, 64, 128] {
1454            let cur = config.total_training_vram_bytes_shared(s);
1455            assert!(cur > prev, "Shared VRAM must increase: seq_len={s}");
1456            prev = cur;
1457        }
1458    }
1459
1460    #[test]
1461    fn test_max_seq_len_for_vram_shared() {
1462        let config = TransformerConfig::qwen2_0_5b();
1463        let budget = 8 * 1024 * 1024 * 1024_usize; // 8 GB
1464        let max_s = config.max_seq_len_for_vram_shared(budget);
1465        assert!(max_s.is_some());
1466        let s = max_s.unwrap();
1467        assert!(config.total_training_vram_bytes_shared(s) <= budget);
1468    }
1469
1470    #[test]
1471    fn test_max_seq_len_for_vram_shared_impossible() {
1472        let config = TransformerConfig::qwen3_4b();
1473        let tiny_budget = 1024; // 1 KB
1474        assert!(config.max_seq_len_for_vram_shared(tiny_budget).is_none());
1475    }
1476
1477    #[test]
1478    fn test_max_seq_len_for_vram_shared_tightness() {
1479        let config = TransformerConfig::tiny();
1480        let budget = 10 * 1024 * 1024_usize; // 10 MB
1481        if let Some(s) = config.max_seq_len_for_vram_shared(budget) {
1482            assert!(config.total_training_vram_bytes_shared(s) <= budget);
1483            if s < config.max_position_embeddings {
1484                assert!(config.total_training_vram_bytes_shared(s + 1) > budget);
1485            }
1486        }
1487    }
1488
1489    #[test]
1490    fn test_kv_dim() {
1491        assert_eq!(TransformerConfig::qwen3_4b().kv_dim(), 1024);
1492        assert_eq!(TransformerConfig::llama2_7b().kv_dim(), 4096);
1493    }
1494
1495    #[test]
1496    fn test_per_layer_scratch_coefficients() {
1497        let config = TransformerConfig::tiny();
1498        assert!(config.per_layer_scratch_linear_coeff() > 0);
1499        let (n_quad, n_hd_linear) = config.per_layer_scratch_quadratic_coeff();
1500        assert!(n_quad > 0 && n_hd_linear > 0);
1501        assert!(config.per_layer_grad_weight_elements() > 0);
1502    }
1503}