Skip to main content

entrenar/transformer/
config.rs

1//! Transformer configuration module
2//!
3//! This module provides configuration structures for transformer models.
4
5use serde::{Deserialize, Serialize};
6
// ---------------------------------------------------------------------------
// Published dimensions of well-known checkpoints, grouped by model family.
// The preset constructors on `TransformerConfig` read these values.
// ---------------------------------------------------------------------------

// LLaMA 2 (the vocabulary size is also used by the Mistral preset).
const LLAMA2_7B_INTERMEDIATE_SIZE: usize = 11_008;
const LLAMA2_13B_HIDDEN_SIZE: usize = 5_120;
const LLAMA2_13B_INTERMEDIATE_SIZE: usize = 13_824;
const LLAMA_VOCAB_SIZE: usize = 32_000;

// Mistral 7B.
const MISTRAL_INTERMEDIATE_SIZE: usize = 14_336;
const MISTRAL_MAX_SEQ_LEN: usize = 32 * 1024;

// Qwen2 / Qwen2.5.
const QWEN2_0_5B_HIDDEN_SIZE: usize = 896;
const QWEN2_0_5B_INTERMEDIATE_SIZE: usize = 4_864;
const QWEN2_VOCAB_SIZE: usize = 151_936;
const QWEN2_MAX_SEQ_LEN: usize = 32 * 1024;
const QWEN2_ROPE_THETA: f32 = 1_000_000.0;

// Qwen3 4B.
const QWEN3_4B_HIDDEN_SIZE: usize = 2_560;
const QWEN3_4B_INTERMEDIATE_SIZE: usize = 9_728;

// Qwen3.5 9B.
const QWEN3_5_9B_HIDDEN_SIZE: usize = 4_096;
const QWEN3_5_9B_INTERMEDIATE_SIZE: usize = 12_288;
const QWEN3_5_VOCAB_SIZE: usize = 248_320;
const QWEN3_5_MAX_SEQ_LEN: usize = 256 * 1024;

// RoPE base used by every preset that does not pick its own.
const DEFAULT_ROPE_THETA: f32 = 10_000.0;

// CodeBERT / RoBERTa encoder.
const CODEBERT_HIDDEN_SIZE: usize = 768;
const CODEBERT_INTERMEDIATE_SIZE: usize = 3_072;
const CODEBERT_VOCAB_SIZE: usize = 50_265;
const CODEBERT_MAX_POSITION: usize = 512 + 2; // 512 positions + 2 special tokens
32
33/// Model architecture family.
34///
35/// Determines position encoding, normalization, FFN activation, and pooling.
36#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
37#[serde(rename_all = "snake_case")]
38pub enum ModelArchitecture {
39    /// Decoder-only (LLaMA, Qwen, Mistral): RoPE, RMSNorm, SwiGLU, last-token pooling
40    #[default]
41    Decoder,
42    /// Encoder-only (BERT, RoBERTa, CodeBERT): learned positions, LayerNorm, GELU, CLS pooling
43    Encoder,
44}
45
46/// Configuration for transformer models
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct TransformerConfig {
49    /// Hidden dimension (embedding size)
50    pub hidden_size: usize,
51    /// Number of attention heads
52    pub num_attention_heads: usize,
53    /// Number of key-value heads (for grouped-query attention)
54    pub num_kv_heads: usize,
55    /// Feed-forward network intermediate dimension
56    pub intermediate_size: usize,
57    /// Number of transformer layers
58    pub num_hidden_layers: usize,
59    /// Vocabulary size
60    pub vocab_size: usize,
61    /// Maximum sequence length
62    pub max_position_embeddings: usize,
63    /// RMS normalization epsilon
64    pub rms_norm_eps: f32,
65    /// RoPE theta base
66    pub rope_theta: f32,
67    /// Whether to use bias in linear layers
68    pub use_bias: bool,
69    /// Explicit per-head dimension (overrides hidden_size / num_heads).
70    /// Required for Qwen3 where head_dim=128 but hidden_size/num_heads=80.
71    #[serde(default)]
72    pub head_dim_override: Option<usize>,
73    /// Architecture family: encoder (BERT/RoBERTa) or decoder (LLaMA/Qwen).
74    /// Determines position encoding, normalization, activation, and pooling strategy.
75    #[serde(default)]
76    pub architecture: ModelArchitecture,
77    /// HuggingFace architecture class name (e.g., "Qwen2ForCausalLM", "LlamaForCausalLM").
78    /// Used for checkpoint config.json compatibility.
79    #[serde(default)]
80    pub hf_architecture: Option<String>,
81    /// HuggingFace model type (e.g., "qwen2", "llama").
82    /// Used for checkpoint config.json compatibility.
83    #[serde(default)]
84    pub hf_model_type: Option<String>,
85    /// Whether to tie input/output embeddings (embed_tokens and lm_head).
86    /// Qwen2: true, LLaMA: false.
87    #[serde(default)]
88    pub tie_word_embeddings: bool,
89}
90
91impl TransformerConfig {
92    /// LLaMA 2 7B configuration
93    pub fn llama2_7b() -> Self {
94        Self {
95            hidden_size: 4096,
96            num_attention_heads: 32,
97            num_kv_heads: 32,
98            intermediate_size: LLAMA2_7B_INTERMEDIATE_SIZE,
99            num_hidden_layers: 32,
100            vocab_size: LLAMA_VOCAB_SIZE,
101            max_position_embeddings: 4096,
102            rms_norm_eps: 1e-6,
103            rope_theta: DEFAULT_ROPE_THETA,
104            use_bias: false,
105            head_dim_override: None,
106            architecture: ModelArchitecture::Decoder,
107            hf_architecture: None,
108            hf_model_type: None,
109            tie_word_embeddings: false,
110        }
111    }
112
113    /// LLaMA 2 13B configuration
114    pub fn llama2_13b() -> Self {
115        Self {
116            hidden_size: LLAMA2_13B_HIDDEN_SIZE,
117            num_attention_heads: 40,
118            num_kv_heads: 40,
119            intermediate_size: LLAMA2_13B_INTERMEDIATE_SIZE,
120            num_hidden_layers: 40,
121            vocab_size: LLAMA_VOCAB_SIZE,
122            max_position_embeddings: 4096,
123            rms_norm_eps: 1e-6,
124            rope_theta: DEFAULT_ROPE_THETA,
125            use_bias: false,
126            head_dim_override: None,
127            architecture: ModelArchitecture::Decoder,
128            hf_architecture: None,
129            hf_model_type: None,
130            tie_word_embeddings: false,
131        }
132    }
133
134    /// Mistral 7B configuration
135    pub fn mistral_7b() -> Self {
136        Self {
137            hidden_size: 4096,
138            num_attention_heads: 32,
139            num_kv_heads: 8, // Grouped-query attention
140            intermediate_size: MISTRAL_INTERMEDIATE_SIZE,
141            num_hidden_layers: 32,
142            vocab_size: LLAMA_VOCAB_SIZE,
143            max_position_embeddings: MISTRAL_MAX_SEQ_LEN,
144            rms_norm_eps: 1e-5,
145            rope_theta: DEFAULT_ROPE_THETA,
146            use_bias: false,
147            head_dim_override: None,
148            architecture: ModelArchitecture::Decoder,
149            hf_architecture: None,
150            hf_model_type: None,
151            tie_word_embeddings: false,
152        }
153    }
154
155    /// Qwen2 0.5B configuration (good for testing).
156    ///
157    /// Empirically verified against
158    /// `~/.cache/huggingface/hub/models--Qwen--Qwen2.5-Coder-0.5B-Instruct/.../config.json`
159    /// 2026-05-04. Pinned by
160    /// `contracts/apr-pretrain-arch-polymorphic-v1.yaml` FALSIFY-001.
161    ///
162    /// Note: `tie_word_embeddings: true` is the Qwen2.5 0.5B/1.5B convention
163    /// (the 7B variant turns this OFF; see `qwen2_7b()`). This is a Qwen
164    /// scaling-law quirk — small Qwen models reuse embedding+lm_head weights
165    /// to save params, but the larger variants pay the param cost for
166    /// untied weights. Drift-prevention: keeping this `true` is required
167    /// for SHIP-TWO-001 §49 MODEL-2 fine-tune from a Qwen2.5-Coder-0.5B
168    /// checkpoint.
169    pub fn qwen2_0_5b() -> Self {
170        Self {
171            hidden_size: QWEN2_0_5B_HIDDEN_SIZE,
172            num_attention_heads: 14,
173            num_kv_heads: 2,
174            intermediate_size: QWEN2_0_5B_INTERMEDIATE_SIZE,
175            num_hidden_layers: 24,
176            vocab_size: QWEN2_VOCAB_SIZE,
177            max_position_embeddings: QWEN2_MAX_SEQ_LEN,
178            rms_norm_eps: 1e-6,
179            rope_theta: QWEN2_ROPE_THETA,
180            use_bias: true,
181            head_dim_override: None,
182            architecture: ModelArchitecture::Decoder,
183            hf_architecture: None,
184            hf_model_type: None,
185            tie_word_embeddings: true,
186        }
187    }
188
189    /// Qwen2.5-Coder-1.5B-Instruct: 28 layers, 12 heads, 2 KV heads, hidden=1536
190    #[rustfmt::skip]
191    pub fn qwen2_1_5b() -> Self { Self { hidden_size: 1536, num_attention_heads: 12, intermediate_size: 8960, num_hidden_layers: 28, vocab_size: 151936, ..Self::qwen2_0_5b() } }
192
193    /// Qwen2.5-Coder 7B configuration (GH-371)
194    ///
195    /// Qwen2.5-Coder-7B-Instruct: 28 layers, 28 heads, 4 KV heads, hidden=3584
196    /// Contract: contracts/model-families/qwen2.yaml
197    pub fn qwen2_7b() -> Self {
198        Self {
199            hidden_size: 3584,
200            num_attention_heads: 28,
201            num_kv_heads: 4,
202            intermediate_size: 18944,
203            num_hidden_layers: 28,
204            vocab_size: 152064,
205            max_position_embeddings: QWEN2_MAX_SEQ_LEN,
206            rms_norm_eps: 1e-6,
207            rope_theta: QWEN2_ROPE_THETA,
208            use_bias: true,
209            head_dim_override: None,
210            architecture: ModelArchitecture::Decoder,
211            hf_architecture: None,
212            hf_model_type: None,
213            tie_word_embeddings: false,
214        }
215    }
216
217    /// Qwen3 4B configuration
218    ///
219    /// Qwen3-4B: 36 layers, 32 heads, 8 KV heads, hidden=2560, head_dim=128.
220    /// Same vocab_size as Qwen2 (151936). No attention bias (Qwen3 family).
221    pub fn qwen3_4b() -> Self {
222        Self {
223            hidden_size: QWEN3_4B_HIDDEN_SIZE,
224            num_attention_heads: 32,
225            num_kv_heads: 8,
226            intermediate_size: QWEN3_4B_INTERMEDIATE_SIZE,
227            num_hidden_layers: 36,
228            vocab_size: QWEN2_VOCAB_SIZE, // 151936, same as Qwen2
229            max_position_embeddings: 40960,
230            rms_norm_eps: 1e-6,
231            rope_theta: QWEN2_ROPE_THETA, // 1M theta
232            use_bias: false,              // Qwen3: no attention bias
233            head_dim_override: Some(128), // Contract: qwen3.yaml §4b.head_dim=128
234            architecture: ModelArchitecture::Decoder,
235            hf_architecture: None,
236            hf_model_type: None,
237            tie_word_embeddings: false,
238        }
239    }
240
241    /// Qwen3.5 9B configuration
242    ///
243    /// Key differences from Qwen2: no attention bias, head_dim=256 (explicit),
244    /// vocab_size=248320, hybrid attention (standard + linear layers).
245    /// Contract: contracts/model-families/qwen3_5.yaml
246    pub fn qwen3_5_9b() -> Self {
247        Self {
248            hidden_size: QWEN3_5_9B_HIDDEN_SIZE,
249            num_attention_heads: 16,
250            num_kv_heads: 4,
251            intermediate_size: QWEN3_5_9B_INTERMEDIATE_SIZE,
252            num_hidden_layers: 32,
253            vocab_size: QWEN3_5_VOCAB_SIZE,
254            max_position_embeddings: QWEN3_5_MAX_SEQ_LEN,
255            rms_norm_eps: 1e-6,
256            rope_theta: QWEN2_ROPE_THETA, // Same 1M theta as Qwen2
257            use_bias: false,              // KEY: no attention bias (unlike Qwen2)
258            head_dim_override: None,      // 4096/16=256, no override needed
259            architecture: ModelArchitecture::Decoder,
260            hf_architecture: None,
261            hf_model_type: None,
262            tie_word_embeddings: false,
263        }
264    }
265
266    /// Construct from APR v2 metadata fields.
267    ///
268    /// CONTRACT: The `.apr` file is the single source of truth for model
269    /// architecture. These fields were validated at import time by the
270    /// `tensor-layout-v1` contract. This function propagates that contract
271    /// to the training pipeline — no hardcoded lookups, no silent fallbacks.
272    ///
273    /// Returns None if any required field is missing, forcing the caller to
274    /// handle the error explicitly rather than silently degrading to tiny().
275    ///
276    /// GH-376: Fixes instruct pipeline ignoring .apr architecture metadata.
277    pub fn from_apr_metadata(
278        hidden_size: Option<usize>,
279        num_heads: Option<usize>,
280        num_kv_heads: Option<usize>,
281        intermediate_size: Option<usize>,
282        num_layers: Option<usize>,
283        vocab_size: Option<usize>,
284        max_position_embeddings: Option<usize>,
285        rms_norm_eps: Option<f32>,
286        rope_theta: Option<f32>,
287        architecture: Option<&str>,
288    ) -> Option<Self> {
289        let hidden = hidden_size?;
290        let heads = num_heads?;
291        let layers = num_layers?;
292        let vocab = vocab_size?;
293        let intermediate = intermediate_size?;
294
295        // Qwen3 family: head_dim=128 is explicit, not hidden/heads
296        // Qwen2 family: use_bias=true
297        let (use_bias, head_dim_override) = match architecture {
298            Some(a) if a.starts_with("qwen3") => {
299                // Qwen3: no bias, explicit head_dim=128 when hidden/heads != 128
300                let computed = hidden / heads;
301                let override_dim = if computed == 128 { None } else { Some(128) };
302                (false, override_dim)
303            }
304            Some(a) if a.starts_with("qwen2") => (true, None),
305            _ => (false, None),
306        };
307
308        Some(Self {
309            hidden_size: hidden,
310            num_attention_heads: heads,
311            num_kv_heads: num_kv_heads.unwrap_or(heads),
312            intermediate_size: intermediate,
313            num_hidden_layers: layers,
314            vocab_size: vocab,
315            max_position_embeddings: max_position_embeddings.unwrap_or(32768),
316            rms_norm_eps: rms_norm_eps.unwrap_or(1e-6),
317            rope_theta: rope_theta.unwrap_or(DEFAULT_ROPE_THETA),
318            use_bias,
319            head_dim_override,
320            architecture: match architecture {
321                Some(a) if a.contains("bert") || a.contains("roberta") => {
322                    ModelArchitecture::Encoder
323                }
324                _ => ModelArchitecture::Decoder,
325            },
326            hf_architecture: None,
327            hf_model_type: None,
328            tie_word_embeddings: false,
329        })
330    }
331
332    /// Resolve config from a model size string. Errors on unknown sizes.
333    ///
334    /// GH-377: Replaces `_ => TransformerConfig::tiny()` catch-all pattern.
335    /// This is the single canonical mapping from size strings to configs.
336    /// Every callsite that previously had its own match table should use this.
337    pub fn from_size_str(size: &str) -> Result<Self, String> {
338        match size {
339            "codebert" | "codebert-base" | "125M" => Ok(Self::codebert()),
340            "0.5B" | "500M" | "qwen2-0.5b" => Ok(Self::qwen2_0_5b()),
341            "1.5B" | "qwen2.5-1.5b" | "qwen2-1.5b" => Ok(Self::qwen2_1_5b()),
342            "7B" | "qwen2.5-7b" => Ok(Self::qwen2_7b()),
343            "4B" | "qwen3-4b" | "qwen3" => Ok(Self::qwen3_4b()),
344            "9B" | "qwen3.5-9b" | "qwen3_5" | "qwen3.5" => Ok(Self::qwen3_5_9b()),
345            unknown => Err(format!(
346                "Unknown model size '{unknown}'. Known sizes: codebert, 0.5B, 4B, 7B, 9B"
347            )),
348        }
349    }
350
351    /// CodeBERT (microsoft/codebert-base) encoder configuration.
352    ///
353    /// RoBERTa architecture: 12 layers, 768 hidden, 12 heads, GELU, LayerNorm, learned positions.
354    /// SSC v11 Section 4: 125M params, ~20ms CPU inference, WASM-deployable.
355    pub fn codebert() -> Self {
356        Self {
357            hidden_size: CODEBERT_HIDDEN_SIZE,
358            num_attention_heads: 12,
359            num_kv_heads: 12, // No GQA in BERT
360            intermediate_size: CODEBERT_INTERMEDIATE_SIZE,
361            num_hidden_layers: 12,
362            vocab_size: CODEBERT_VOCAB_SIZE,
363            max_position_embeddings: CODEBERT_MAX_POSITION,
364            rms_norm_eps: 1e-5, // LayerNorm eps for RoBERTa
365            rope_theta: 0.0,    // Not used (learned positions)
366            use_bias: true,
367            head_dim_override: None,
368            architecture: ModelArchitecture::Encoder,
369            hf_architecture: None,
370            hf_model_type: None,
371            tie_word_embeddings: false,
372        }
373    }
374
375    /// Tiny configuration for testing
376    pub fn tiny() -> Self {
377        Self {
378            hidden_size: 64,
379            num_attention_heads: 2,
380            num_kv_heads: 2,
381            intermediate_size: 256,
382            num_hidden_layers: 2,
383            vocab_size: 1000,
384            max_position_embeddings: 512,
385            rms_norm_eps: 1e-6,
386            rope_theta: DEFAULT_ROPE_THETA,
387            use_bias: false,
388            head_dim_override: None,
389            architecture: ModelArchitecture::Decoder,
390            hf_architecture: None,
391            hf_model_type: None,
392            tie_word_embeddings: false,
393        }
394    }
395
396    /// Whether this config describes an encoder (BERT/RoBERTa) architecture.
397    pub fn is_encoder(&self) -> bool {
398        self.architecture == ModelArchitecture::Encoder
399    }
400
401    /// HuggingFace architecture class name for checkpoint config.json.
402    /// Uses explicit override if set, otherwise infers from config.
403    pub fn hf_architecture_name(&self) -> &str {
404        if let Some(ref name) = self.hf_architecture {
405            return name;
406        }
407        // Infer from model characteristics
408        if self.is_encoder() {
409            "BertModel"
410        } else if self.use_bias && self.vocab_size > 150000 {
411            // Qwen2 family: has attention biases + large vocab
412            "Qwen2ForCausalLM"
413        } else {
414            "LlamaForCausalLM"
415        }
416    }
417
418    /// HuggingFace model_type string for checkpoint config.json.
419    pub fn hf_model_type_str(&self) -> &str {
420        if let Some(ref mt) = self.hf_model_type {
421            return mt;
422        }
423        if self.is_encoder() {
424            "roberta"
425        } else if self.use_bias && self.vocab_size > 150000 {
426            "qwen2"
427        } else {
428            "llama"
429        }
430    }
431
432    /// Whether embeddings are tied (embed_tokens == lm_head).
433    /// Uses explicit flag if set, otherwise infers from architecture.
434    pub fn ties_embeddings(&self) -> bool {
435        if self.tie_word_embeddings {
436            return true;
437        }
438        // Qwen2 ties embeddings by default
439        self.use_bias && self.vocab_size > 150000
440    }
441
442    /// Per-head dimension.
443    ///
444    /// Uses explicit override when set (Qwen3: head_dim=128 with hidden=2560, 32 heads).
445    /// Falls back to hidden_size / num_heads for standard architectures.
446    pub fn head_dim(&self) -> usize {
447        self.head_dim_override.unwrap_or(self.hidden_size / self.num_attention_heads)
448    }
449
450    /// Total Q/O projection dimension = num_heads * head_dim.
451    ///
452    /// Equals hidden_size for standard architectures but differs when head_dim
453    /// is explicitly overridden (e.g. Qwen3-4B: 32 * 128 = 4096 != 2560).
454    pub fn q_dim(&self) -> usize {
455        self.num_attention_heads * self.head_dim()
456    }
457
458    // =========================================================================
459    // VRAM Budget Solver (Provable Design-by-Contract)
460    //
461    // Contract: contracts/model-families/qwen3.yaml §CUDA TRAINING RESOURCE BUDGET
462    // Meyer (1992) "No Hidden Clauses": every term maps 1:1 to a GpuBuffer::new()
463    // call in cuda_block.rs. No magic numbers — all derived from model dims.
464    // =========================================================================
465
466    /// KV hidden dimension = num_kv_heads * head_dim.
467    fn kv_dim(&self) -> usize {
468        self.num_kv_heads * self.head_dim()
469    }
470
471    /// Per-layer weight VRAM in f32 elements (constant, independent of seq_len).
472    ///
473    /// Maps to cuda_block.rs lines 212-220: `GpuBuffer::from_host()` uploads.
474    pub fn per_layer_weight_elements(&self) -> usize {
475        let h = self.hidden_size;
476        let q = self.q_dim();
477        let kv = self.kv_dim();
478        let i = self.intermediate_size;
479        // w_q: q*h, w_k: kv*h, w_v: kv*h, w_o: h*q, w_gate: i*h, w_up: i*h, w_down: h*i
480        // input_norm: h, post_attn_norm: h
481        q * h + kv * h * 2 + h * q + i * h * 3 + h * 2
482    }
483
484    /// Per-layer gradient weight VRAM in f32 elements (constant, independent of seq_len).
485    ///
486    /// Maps to cuda_block.rs CudaGradWorkspace: constant-size gradient buffers.
487    /// grad_w_q uses q_dim*hidden, grad_w_o uses hidden*q_dim (#262).
488    fn per_layer_grad_weight_elements(&self) -> usize {
489        let h = self.hidden_size;
490        let q = self.q_dim();
491        let kv = self.kv_dim();
492        let i = self.intermediate_size;
493        // grad_input_norm: h, grad_post_attn_norm: h
494        // grad_gate: h*i, grad_up: h*i, grad_down: i*h
495        // grad_w_q: q*h, grad_w_k: h*kv, grad_w_v: h*kv, grad_w_o: h*q
496        h * 2 + h * i * 3 + q * h + h * q + h * kv * 2
497    }
498
499    /// Per-layer scratch elements that scale linearly with seq_len.
500    ///
501    /// Maps to cuda_block.rs lines 224-236, 243-248: `GpuBuffer::new(_, S * dim)`.
502    fn per_layer_scratch_linear_coeff(&self) -> usize {
503        let h = self.hidden_size;
504        let kv = self.kv_dim();
505        let i = self.intermediate_size;
506        let n = self.num_attention_heads;
507        let hd = self.head_dim();
508        // Forward: norm1(h) + q(h) + k(kv) + v(kv) + attn_out(h) + o_proj(h)
509        //          + residual1(h) + norm2(h) + gate(i) + up(i) + swiglu(i) + ffn(h)
510        // Backward: grad_hidden(h) + grad_swiglu(i)
511        // Attention reshape: q_batched(N*hd) + kv_temp(N*hd) + kv_temp2(N*hd)
512        h * 8 + kv * 2 + i * 4 + n * hd * 3
513    }
514
515    /// Per-layer scratch elements that scale quadratically with seq_len.
516    ///
517    /// Returns (quadratic_coeff, linear_fallback_coeff) because:
518    ///   attn_scores = N * S * S
519    ///   grad_attn_scores = N * S * max(S, hd)
520    ///
521    /// When S >= hd: total = 2 * N * S^2 (pure quadratic)
522    /// When S < hd:  total = N * S^2 + N * S * hd (mixed)
523    fn per_layer_scratch_quadratic_coeff(&self) -> (usize, usize) {
524        let n = self.num_attention_heads;
525        let hd = self.head_dim();
526        // attn_scores: N * S * S (always quadratic)
527        // grad_attn_scores: N * S * max(S, hd)
528        //   When S >= hd → N * S * S (quadratic)
529        //   When S < hd  → N * S * hd (linear)
530        (n, n * hd) // (quadratic_coeff, linear_fallback for grad when S < hd)
531    }
532
533    /// Total VRAM in bytes for all layers at a given max_seq_len.
534    ///
535    /// Postcondition: result is exact for the current cuda_block.rs buffer layout.
536    pub fn total_training_vram_bytes(&self, max_seq_len: usize) -> usize {
537        let l = self.num_hidden_layers;
538        let s = max_seq_len;
539        let hd = self.head_dim();
540
541        let constant_per_layer =
542            self.per_layer_weight_elements() + self.per_layer_grad_weight_elements();
543        let linear_per_layer = self.per_layer_scratch_linear_coeff() * s;
544
545        let (n_quad, n_hd_linear) = self.per_layer_scratch_quadratic_coeff();
546        let quadratic_per_layer =
547            if s >= hd { 2 * n_quad * s * s } else { n_quad * s * s + n_hd_linear * s };
548
549        let elements_per_layer = constant_per_layer + linear_per_layer + quadratic_per_layer;
550        l * elements_per_layer * 4 // f32 = 4 bytes
551    }
552
553    /// Total VRAM in bytes with SHARED scratch workspace (1 per model, not per layer).
554    ///
555    /// This is the correct budget formula when gradient buffers are shared across
556    /// layers (canonical in PyTorch/JAX). Only weights are truly per-layer.
557    ///
558    /// Postcondition: result < total_training_vram_bytes(s) for L > 1
559    pub fn total_training_vram_bytes_shared(&self, max_seq_len: usize) -> usize {
560        let l = self.num_hidden_layers;
561        let s = max_seq_len;
562        let hd = self.head_dim();
563
564        // Weights are per-layer (unavoidable — must all be resident on GPU)
565        let weights_total = l * self.per_layer_weight_elements();
566
567        // Gradient weight buffers: SHARED (one set, reused across layers)
568        let grad_weights_shared = self.per_layer_grad_weight_elements();
569
570        // Seq-len-dependent scratch: SHARED (one set)
571        let linear_shared = self.per_layer_scratch_linear_coeff() * s;
572        let (n_quad, n_hd_linear) = self.per_layer_scratch_quadratic_coeff();
573        let quadratic_shared =
574            if s >= hd { 2 * n_quad * s * s } else { n_quad * s * s + n_hd_linear * s };
575
576        let total_elements = weights_total + grad_weights_shared + linear_shared + quadratic_shared;
577        total_elements * 4 // f32 = 4 bytes
578    }
579
580    /// Solve for the maximum seq_len that fits in the given VRAM budget (bytes),
581    /// using shared scratch workspace.
582    ///
583    /// This is the solver to use with the shared-scratch architecture.
584    /// Returns None if even seq_len=1 exceeds the budget.
585    pub fn max_seq_len_for_vram_shared(&self, vram_bytes: usize) -> Option<usize> {
586        if self.total_training_vram_bytes_shared(1) > vram_bytes {
587            return None;
588        }
589
590        let mut lo: usize = 1;
591        let mut hi: usize = self.max_position_embeddings;
592
593        while lo < hi {
594            let mid = lo + (hi - lo).div_ceil(2);
595            if self.total_training_vram_bytes_shared(mid) <= vram_bytes {
596                lo = mid;
597            } else {
598                hi = mid - 1;
599            }
600        }
601
602        Some(lo)
603    }
604
605    /// Solve for the maximum seq_len that fits in the given VRAM budget (bytes).
606    ///
607    /// Binary search over [1, max_position_embeddings].
608    /// Returns None if even seq_len=1 exceeds the budget.
609    ///
610    /// Precondition: vram_bytes > 0
611    /// Postcondition: total_training_vram_bytes(result) <= vram_bytes
612    pub fn max_seq_len_for_vram(&self, vram_bytes: usize) -> Option<usize> {
613        if self.total_training_vram_bytes(1) > vram_bytes {
614            return None;
615        }
616
617        let mut lo: usize = 1;
618        let mut hi: usize = self.max_position_embeddings;
619
620        while lo < hi {
621            let mid = lo + (hi - lo).div_ceil(2);
622            if self.total_training_vram_bytes(mid) <= vram_bytes {
623                lo = mid;
624            } else {
625                hi = mid - 1;
626            }
627        }
628
629        Some(lo)
630    }
631}
632
633#[cfg(test)]
634mod tests {
635    use super::*;
636
637    #[test]
638    fn test_transformer_config_llama2() {
639        let config = TransformerConfig::llama2_7b();
640        assert_eq!(config.hidden_size, 4096);
641        assert_eq!(config.num_attention_heads, 32);
642        assert_eq!(config.head_dim(), 128);
643    }
644
645    #[test]
646    fn test_transformer_config_tiny() {
647        let config = TransformerConfig::tiny();
648        assert_eq!(config.hidden_size, 64);
649        assert_eq!(config.num_attention_heads, 2);
650        assert_eq!(config.head_dim(), 32);
651    }
652
653    #[test]
654    fn test_config_serialization() {
655        let config = TransformerConfig::llama2_7b();
656        let json = serde_json::to_string(&config).expect("JSON serialization should succeed");
657        let restored: TransformerConfig =
658            serde_json::from_str(&json).expect("JSON deserialization should succeed");
659        assert_eq!(restored.hidden_size, config.hidden_size);
660        assert_eq!(restored.num_attention_heads, config.num_attention_heads);
661    }
662
663    #[test]
664    fn test_mistral_config() {
665        let config = TransformerConfig::mistral_7b();
666        assert_eq!(config.num_kv_heads, 8); // Grouped-query attention
667        assert_eq!(config.num_attention_heads, 32);
668        // 32 / 8 = 4 query heads per KV head
669    }
670
671    /// FALSIFY-APR-PRETRAIN-ARCH-001 — `qwen2_0_5b()` constructor matches
672    /// HF config byte-for-byte.
673    ///
674    /// Empirically verified against
675    /// `~/.cache/huggingface/hub/models--Qwen--Qwen2.5-Coder-0.5B-Instruct/.../config.json`
676    /// on 2026-05-04 for SHIP-TWO-001 §50.4 step 5b. If any of these 11
677    /// fields drift from HF, the §49 fine-tune path's init weights will
678    /// load into the wrong-shape optimizer and produce silent gibberish.
679    #[test]
680    fn qwen2_0_5b_matches_hf_config_2026_05_04() {
681        let config = TransformerConfig::qwen2_0_5b();
682        assert_eq!(config.hidden_size, 896, "hidden_size");
683        assert_eq!(config.num_attention_heads, 14, "num_attention_heads");
684        assert_eq!(config.num_kv_heads, 2, "num_kv_heads (GQA-7:1)");
685        assert_eq!(config.intermediate_size, 4864, "intermediate_size");
686        assert_eq!(config.num_hidden_layers, 24, "num_hidden_layers");
687        assert_eq!(config.vocab_size, 151_936, "vocab_size");
688        assert_eq!(config.max_position_embeddings, 32_768, "max_position_embeddings");
689        assert!(
690            (config.rms_norm_eps - 1e-6).abs() < f32::EPSILON,
691            "rms_norm_eps={}, want 1e-6",
692            config.rms_norm_eps
693        );
694        assert!(
695            (config.rope_theta - 1_000_000.0).abs() < f32::EPSILON,
696            "rope_theta={}, want 1_000_000.0",
697            config.rope_theta
698        );
699        assert!(config.use_bias, "use_bias must be true (Qwen2 quirk)");
700        assert!(
701            config.tie_word_embeddings,
702            "tie_word_embeddings must be true for Qwen2.5 0.5B (HF config 2026-05-04)"
703        );
704        assert_eq!(config.architecture, ModelArchitecture::Decoder);
705        // GQA ratio = 14/2 = 7, the canonical Qwen2.5-0.5B GQA-7:1.
706        assert_eq!(config.num_attention_heads / config.num_kv_heads, 7);
707    }
708
709    /// FALSIFY-APR-PRETRAIN-INIT-POPULATE-COVERAGE-001 (RED-then-GREEN):
710    /// `Transformer::new(qwen2_0_5b())` MUST allocate Q/K/V projection
711    /// biases when `config.use_bias == true`. Without this invariant,
712    /// `populate_trainer_from_init_tensors` silently drops 24 layers ×
713    /// 3 biases = 72 init tensors during populate, producing a hybrid
714    /// model whose forward pass is structurally wrong.
715    ///
716    /// Discovered 2026-05-09 via the 5g.2 LIVE smoke producing
717    /// val_loss=0.0008 (implausibly low; see
718    /// `evidence/section-59-5g-2-dispatch-2026-05-09/README.md`).
719    /// Root cause: `MultiHeadAttention::new` hardcoded `b_q: None,
720    /// b_k: None, b_v: None` regardless of `config.use_bias`. The
721    /// existing FALSIFY-001 (`qwen2_0_5b_matches_hf_config_2026_05_04`)
722    /// only checked the CONFIG STRUCT FIELD VALUES; it did not
723    /// observe that `MultiHeadAttention::new(config)` ignored
724    /// `config.use_bias`. This is the gap-between-contracts class
725    /// of defect that provable-contracts can only catch when a
726    /// falsifier observes the gap.
727    ///
728    /// Methodology: we pick the canonical 290-tensor count of
729    /// Qwen2.5-Coder-0.5B-Instruct (per HF config) and assert
730    /// `Transformer::new(qwen2_0_5b()).named_parameters().len() == 290`.
731    ///   2 (embed_tokens.weight + model.norm.weight)
732    /// + 24 layers × 12 params/layer (2 norms + 4 attn weights +
733    ///                                 3 attn biases + 3 mlp weights)
734    /// = 2 + 288 = 290. Tied lm_head shares with embed_tokens, so
735    /// it does NOT appear as an extra named parameter.
736    ///
737    /// Spec: SPEC-SHIP-TWO-001 §59 (forthcoming) val_loss anomaly
738    /// → §50.4 step 5f.6 (populate-coverage cascade).
739    #[test]
740    fn falsify_qwen2_0_5b_named_parameters_count_matches_hf() {
741        use super::super::Transformer;
742        let config = TransformerConfig::qwen2_0_5b();
743        let model = Transformer::new(&config);
744        let params = model.named_parameters();
745        let actual = params.len();
746        let expected = 2 + 24 * 12; // embed + norm + 24 layers × 12 params
747        assert_eq!(
748            actual, expected,
749            "FALSIFY-APR-PRETRAIN-INIT-POPULATE-COVERAGE-001: \
750             Transformer::new(qwen2_0_5b()).named_parameters().len() = {actual}, \
751             expected {expected}. Missing params likely include Q/K/V \
752             projection biases (24 layers × 3 = 72 expected biases) — \
753             MultiHeadAttention::new must allocate them when \
754             config.use_bias == true. See evidence/section-59-5g-2-\
755             dispatch-2026-05-09/README.md for the val_loss=0.0008 \
756             anomaly that surfaced this gap.",
757        );
758    }
759
760    /// FALSIFY-APR-PRETRAIN-INIT-POPULATE-COVERAGE-002 (paired with -001):
761    /// Every layer in `Transformer::new(qwen2_0_5b())` MUST expose
762    /// `q_proj.bias`, `k_proj.bias`, `v_proj.bias` in its
763    /// `named_parameters()` output when `config.use_bias == true`.
764    /// This is a stricter form of -001 — it not only counts but
765    /// names the missing tensors so the populate path's BTreeMap
766    /// lookup hits real init keys.
767    #[test]
768    fn falsify_qwen2_0_5b_layers_expose_qkv_biases_when_use_bias_true() {
769        use super::super::Transformer;
770        let config = TransformerConfig::qwen2_0_5b();
771        assert!(config.use_bias, "qwen2_0_5b config must declare use_bias=true");
772        let model = Transformer::new(&config);
773        let params = model.named_parameters();
774        let names: std::collections::BTreeSet<&str> =
775            params.iter().map(|(name, _)| name.as_str()).collect();
776
777        for layer_idx in 0..24 {
778            for proj in &["q_proj", "k_proj", "v_proj"] {
779                let key = format!("model.layers.{layer_idx}.self_attn.{proj}.bias");
780                assert!(
781                    names.contains(key.as_str()),
782                    "FALSIFY-APR-PRETRAIN-INIT-POPULATE-COVERAGE-002: \
783                     missing named parameter `{key}` despite use_bias=true. \
784                     MultiHeadAttention::new MUST allocate b_{} when \
785                     config.use_bias is true; today it hardcodes None.",
786                    proj.split('_').next().unwrap_or(proj)
787                );
788            }
789        }
790    }
791
792    /// Drift-prevention: `qwen2_1_5b()` inherits `tie_word_embeddings` from
793    /// `qwen2_0_5b()` via `..Self::qwen2_0_5b()` spread. If someone splits
794    /// the inheritance, this test catches the silent flip.
795    #[test]
796    fn qwen2_1_5b_inherits_tie_word_embeddings_from_0_5b() {
797        let parent = TransformerConfig::qwen2_0_5b();
798        let child = TransformerConfig::qwen2_1_5b();
799        assert_eq!(
800            child.tie_word_embeddings, parent.tie_word_embeddings,
801            "qwen2_1_5b must inherit tie_word_embeddings from qwen2_0_5b — both are HF tie=true"
802        );
803        assert!(
804            child.tie_word_embeddings,
805            "qwen2_1_5b tie_word_embeddings must be true (HF config 2026-05-04)"
806        );
807    }
808
809    /// Pin the Qwen scaling-law quirk: 0.5B + 1.5B tie embeddings, 7B does not.
810    /// If the 7B is ever changed to inherit from 0.5B, this test catches it
811    /// before an operator silently fine-tunes a 7B with the wrong head shape.
812    #[test]
813    fn qwen2_7b_does_not_tie_embeddings() {
814        let config = TransformerConfig::qwen2_7b();
815        assert!(
816            !config.tie_word_embeddings,
817            "qwen2_7b tie_word_embeddings MUST be false per HF config 2026-05-04 — \
818             larger Qwen variants pay param cost for untied weights"
819        );
820    }
821
822    #[test]
823    fn test_qwen2_config() {
824        let config = TransformerConfig::qwen2_0_5b();
825        assert!(config.use_bias);
826        assert_eq!(config.vocab_size, 151936);
827    }
828
829    #[test]
830    fn test_llama2_13b_config() {
831        let config = TransformerConfig::llama2_13b();
832        assert_eq!(config.hidden_size, 5120);
833        assert_eq!(config.num_attention_heads, 40);
834        assert_eq!(config.num_hidden_layers, 40);
835        assert_eq!(config.head_dim(), 128); // 5120 / 40 = 128
836    }
837
838    #[test]
839    fn test_config_yaml_serialization() {
840        let config = TransformerConfig::tiny();
841        let yaml = serde_yaml::to_string(&config).expect("config should be valid");
842        let restored: TransformerConfig =
843            serde_yaml::from_str(&yaml).expect("config should be valid");
844        assert_eq!(restored.hidden_size, config.hidden_size);
845        assert_eq!(restored.num_hidden_layers, config.num_hidden_layers);
846    }
847
848    #[test]
849    fn test_grouped_query_attention_ratio() {
850        let config = TransformerConfig::mistral_7b();
851        let heads_per_kv = config.num_attention_heads / config.num_kv_heads;
852        assert_eq!(heads_per_kv, 4); // 32 / 8 = 4
853    }
854
855    #[test]
856    fn test_config_clone() {
857        let config = TransformerConfig::llama2_7b();
858        let cloned = config.clone();
859        assert_eq!(config.hidden_size, cloned.hidden_size);
860        assert_eq!(config.vocab_size, cloned.vocab_size);
861    }
862
863    #[test]
864    fn test_qwen3_5_9b_config() {
865        let config = TransformerConfig::qwen3_5_9b();
866        assert_eq!(config.hidden_size, 4096);
867        assert_eq!(config.num_attention_heads, 16);
868        assert_eq!(config.num_kv_heads, 4);
869        assert_eq!(config.intermediate_size, 12288);
870        assert_eq!(config.num_hidden_layers, 32);
871        assert_eq!(config.vocab_size, 248320);
872        assert_eq!(config.max_position_embeddings, 262144);
873        assert!(!config.use_bias);
874    }
875
876    #[test]
877    fn test_qwen3_5_9b_head_dim() {
878        let config = TransformerConfig::qwen3_5_9b();
879        // 4096 / 16 = 256 (explicit head_dim, not derived from hidden/heads ratio)
880        assert_eq!(config.head_dim(), 256);
881    }
882
883    #[test]
884    fn test_qwen3_5_9b_gqa_ratio() {
885        let config = TransformerConfig::qwen3_5_9b();
886        let heads_per_kv = config.num_attention_heads / config.num_kv_heads;
887        assert_eq!(heads_per_kv, 4); // 16 / 4 = 4 Q heads per KV head
888    }
889
890    // =========================================================================
891    // from_apr_metadata contract tests (GH-376)
892    // =========================================================================
893
894    #[test]
895    fn test_from_apr_metadata_qwen3_8b() {
896        // Qwen3-8B: 36 layers, 32 heads, 8 KV heads, hidden=4096, head_dim=128
897        let config = TransformerConfig::from_apr_metadata(
898            Some(4096),   // hidden_size
899            Some(32),     // num_heads
900            Some(8),      // num_kv_heads
901            Some(12288),  // intermediate_size
902            Some(36),     // num_layers
903            Some(151936), // vocab_size
904            Some(40960),  // max_position_embeddings
905            Some(1e-6),   // rms_norm_eps
906            Some(1e6),    // rope_theta
907            Some("qwen3"),
908        )
909        .expect("all required fields present");
910
911        assert_eq!(config.hidden_size, 4096);
912        assert_eq!(config.num_attention_heads, 32);
913        assert_eq!(config.num_kv_heads, 8);
914        assert_eq!(config.num_hidden_layers, 36);
915        assert_eq!(config.vocab_size, 151936);
916        assert_eq!(config.head_dim(), 128); // 4096/32=128, no override needed
917        assert!(!config.use_bias); // Qwen3: no bias
918    }
919
920    #[test]
921    fn test_from_apr_metadata_qwen2_7b() {
922        // Qwen2.5 should get use_bias=true
923        let config = TransformerConfig::from_apr_metadata(
924            Some(3584),
925            Some(28),
926            Some(4),
927            Some(18944),
928            Some(28),
929            Some(152064),
930            Some(32768),
931            Some(1e-6),
932            Some(1e6),
933            Some("qwen2"),
934        )
935        .expect("all required fields present");
936
937        assert!(config.use_bias); // Qwen2: has bias
938        assert_eq!(config.head_dim(), 128); // 3584/28=128
939    }
940
941    #[test]
942    fn test_from_apr_metadata_missing_required_returns_none() {
943        // Missing hidden_size — must return None, not silently degrade
944        assert!(TransformerConfig::from_apr_metadata(
945            None,
946            Some(32),
947            Some(8),
948            Some(12288),
949            Some(36),
950            Some(151936),
951            Some(40960),
952            Some(1e-6),
953            Some(1e6),
954            Some("qwen3"),
955        )
956        .is_none());
957
958        // Missing num_layers
959        assert!(TransformerConfig::from_apr_metadata(
960            Some(4096),
961            Some(32),
962            Some(8),
963            Some(12288),
964            None,
965            Some(151936),
966            Some(40960),
967            Some(1e-6),
968            Some(1e6),
969            Some("qwen3"),
970        )
971        .is_none());
972    }
973
974    // =========================================================================
975    // VRAM Budget Solver Falsification Tests
976    //
977    // Popperian: each test attempts to BREAK a mathematical invariant.
978    // If any test fails, the budget formula disagrees with cuda_block.rs.
979    // =========================================================================
980
981    #[test]
982    fn falsify_vram_monotonic_in_seq_len() {
983        // Prediction: VRAM is strictly monotonically increasing in seq_len
984        let config = TransformerConfig::qwen3_4b();
985        let mut prev = config.total_training_vram_bytes(1);
986        for s in [2, 4, 8, 16, 32, 64, 128, 256, 512] {
987            let cur = config.total_training_vram_bytes(s);
988            assert!(
989                cur > prev,
990                "VRAM must increase: seq_len={s} ({cur}) should exceed prev ({prev})"
991            );
992            prev = cur;
993        }
994    }
995
996    #[test]
997    fn falsify_vram_solver_postcondition() {
998        // Prediction: solver result satisfies total_vram <= budget
999        let config = TransformerConfig::qwen3_4b();
1000        let budget = 24 * 1024 * 1024 * 1024_usize; // 24 GB (RTX 4090)
1001        if let Some(max_s) = config.max_seq_len_for_vram(budget) {
1002            let used = config.total_training_vram_bytes(max_s);
1003            assert!(
1004                used <= budget,
1005                "Solver returned seq_len={max_s} using {used} bytes > budget {budget}"
1006            );
1007            // And seq_len+1 should exceed budget (tightness)
1008            if max_s < config.max_position_embeddings {
1009                let over = config.total_training_vram_bytes(max_s + 1);
1010                assert!(
1011                    over > budget,
1012                    "Solver not tight: seq_len={} uses {over} <= budget {budget}",
1013                    max_s + 1
1014                );
1015            }
1016        }
1017    }
1018
1019    #[test]
1020    fn falsify_vram_solver_returns_none_when_impossible() {
1021        // Prediction: if even seq_len=1 exceeds budget, solver returns None
1022        let config = TransformerConfig::qwen3_4b();
1023        let tiny_budget = 1024; // 1 KB — impossible for any model
1024        assert!(
1025            config.max_seq_len_for_vram(tiny_budget).is_none(),
1026            "Solver should return None when budget is too small"
1027        );
1028    }
1029
1030    #[test]
1031    fn falsify_qwen3_4b_vram_matches_oom_observation() {
1032        // Observation: Qwen3-4B OOM'd on 24 GB 4090 at seq_len=512.
1033        // The formula MUST agree: seq_len=512 should exceed ~23 GB usable VRAM.
1034        let config = TransformerConfig::qwen3_4b();
1035        let vram_512 = config.total_training_vram_bytes(512);
1036        let usable_vram = 23 * 1024 * 1024 * 1024_usize; // ~23 GB after CUDA runtime
1037
1038        // Diagnostic: print the budget breakdown
1039        let vram_1 = config.total_training_vram_bytes(1);
1040        let shared_128 = config.total_training_vram_bytes_shared(128);
1041        let shared_512 = config.total_training_vram_bytes_shared(512);
1042        let solved = config.max_seq_len_for_vram_shared(24 * 1024 * 1024 * 1024);
1043        eprintln!("=== Qwen3-4B VRAM Budget ===");
1044        eprintln!(
1045            "  Per-layer weights:    {:.1} MB",
1046            config.per_layer_weight_elements() as f64 * 4.0 / 1e6
1047        );
1048        eprintln!(
1049            "  Per-layer grad scratch: {:.1} MB",
1050            config.per_layer_grad_weight_elements() as f64 * 4.0 / 1e6
1051        );
1052        eprintln!("  Per-layer (S=512): {:.1} MB", (vram_512 / 36) as f64 / 1e6);
1053        eprintln!("  36 layers S=1 (per-layer scratch): {:.1} GB", vram_1 as f64 / 1e9);
1054        eprintln!("  36 layers S=512 (per-layer scratch): {:.1} GB", vram_512 as f64 / 1e9);
1055        eprintln!("  36 layers S=128 (SHARED scratch):    {:.1} GB", shared_128 as f64 / 1e9);
1056        eprintln!("  36 layers S=512 (SHARED scratch):    {:.1} GB", shared_512 as f64 / 1e9);
1057        eprintln!("  Max seq_len for 24 GB (shared):      {solved:?}");
1058
1059        assert!(
1060            vram_512 > usable_vram,
1061            "Formula says {:.1} GB for seq_len=512, but we OOM'd on 23 GB — formula is wrong",
1062            vram_512 as f64 / 1e9
1063        );
1064    }
1065
1066    #[test]
1067    fn falsify_qwen2_0_5b_fits_on_4090() {
1068        // Observation: Qwen2-0.5B trained successfully on 4090 at seq_len=512.
1069        // The formula MUST agree: it should fit in 24 GB.
1070        let config = TransformerConfig::qwen2_0_5b();
1071        let vram_512 = config.total_training_vram_bytes(512);
1072        let total_vram = 24 * 1024 * 1024 * 1024_usize;
1073        assert!(
1074            vram_512 < total_vram,
1075            "Formula says {:.1} GB for Qwen2-0.5B at seq_len=512, but it fit on 4090",
1076            vram_512 as f64 / 1e9
1077        );
1078    }
1079
1080    #[test]
1081    fn falsify_vram_budget_concrete_values() {
1082        // Verify concrete VRAM numbers for Qwen3-4B to catch formula drift.
1083        let config = TransformerConfig::qwen3_4b();
1084
1085        // Per-layer weights: q(4096*2560) + k(1024*2560) + v(1024*2560)
1086        //   + o(2560*4096) + gate(9728*2560) + up(9728*2560) + down(2560*9728)
1087        //   + norms(2560*2)
1088        let expected_weights =
1089            4096 * 2560 + 1024 * 2560 * 2 + 2560 * 4096 + 9728 * 2560 * 3 + 2560 * 2;
1090        assert_eq!(config.per_layer_weight_elements(), expected_weights);
1091
1092        // With PER-LAYER gradient scratch (current cuda_block.rs layout),
1093        // Qwen3-4B's constant overhead alone exceeds 24 GB:
1094        // 36 layers × 776 MB = 27.9 GB. Solver correctly returns None.
1095        let budget_24gb = 24 * 1024 * 1024 * 1024_usize;
1096        assert!(
1097            config.max_seq_len_for_vram(budget_24gb).is_none(),
1098            "Qwen3-4B per-layer scratch CANNOT fit 24 GB — proves shared scratch needed"
1099        );
1100
1101        // With SHARED scratch (weight-only per-layer), budget check uses
1102        // total_training_vram_bytes_shared(). Qwen3-4B weights-only = 14.5 GB,
1103        // leaves ~9 GB for one shared scratch set + seq_len-dependent buffers.
1104        let shared_budget = config.total_training_vram_bytes_shared(128);
1105        assert!(
1106            shared_budget < budget_24gb,
1107            "Qwen3-4B shared scratch at seq_len=128 should fit 24 GB, got {:.1} GB",
1108            shared_budget as f64 / 1e9
1109        );
1110    }
1111
1112    // ── Additional coverage tests ─────────────────────────────────
1113
1114    #[test]
1115    fn test_model_architecture_default() {
1116        let arch: ModelArchitecture = Default::default();
1117        assert_eq!(arch, ModelArchitecture::Decoder);
1118    }
1119
1120    #[test]
1121    fn test_model_architecture_serialization() {
1122        let encoder = ModelArchitecture::Encoder;
1123        let json = serde_json::to_string(&encoder).expect("serialize");
1124        assert_eq!(json, "\"encoder\"");
1125        let decoder = ModelArchitecture::Decoder;
1126        let json = serde_json::to_string(&decoder).expect("serialize");
1127        assert_eq!(json, "\"decoder\"");
1128
1129        let restored: ModelArchitecture = serde_json::from_str("\"encoder\"").expect("deserialize");
1130        assert_eq!(restored, ModelArchitecture::Encoder);
1131    }
1132
1133    #[test]
1134    fn test_codebert_config() {
1135        let config = TransformerConfig::codebert();
1136        assert_eq!(config.hidden_size, 768);
1137        assert_eq!(config.num_attention_heads, 12);
1138        assert_eq!(config.num_kv_heads, 12);
1139        assert_eq!(config.intermediate_size, 3072);
1140        assert_eq!(config.num_hidden_layers, 12);
1141        assert_eq!(config.vocab_size, 50265);
1142        assert_eq!(config.max_position_embeddings, 514);
1143        assert!(config.use_bias);
1144        assert_eq!(config.architecture, ModelArchitecture::Encoder);
1145        assert!(config.is_encoder());
1146        assert_eq!(config.head_dim(), 64); // 768 / 12
1147    }
1148
1149    #[test]
1150    fn test_is_encoder() {
1151        assert!(TransformerConfig::codebert().is_encoder());
1152        assert!(!TransformerConfig::llama2_7b().is_encoder());
1153        assert!(!TransformerConfig::tiny().is_encoder());
1154        assert!(!TransformerConfig::qwen2_0_5b().is_encoder());
1155    }
1156
1157    #[test]
1158    fn test_hf_architecture_name_inferred() {
1159        // Encoder
1160        assert_eq!(TransformerConfig::codebert().hf_architecture_name(), "BertModel");
1161        // Qwen2 (bias + large vocab)
1162        assert_eq!(TransformerConfig::qwen2_0_5b().hf_architecture_name(), "Qwen2ForCausalLM");
1163        // LLaMA (no bias)
1164        assert_eq!(TransformerConfig::llama2_7b().hf_architecture_name(), "LlamaForCausalLM");
1165    }
1166
1167    #[test]
1168    fn test_hf_architecture_name_override() {
1169        let mut config = TransformerConfig::tiny();
1170        config.hf_architecture = Some("CustomModel".to_string());
1171        assert_eq!(config.hf_architecture_name(), "CustomModel");
1172    }
1173
1174    #[test]
1175    fn test_hf_model_type_str_inferred() {
1176        assert_eq!(TransformerConfig::codebert().hf_model_type_str(), "roberta");
1177        assert_eq!(TransformerConfig::qwen2_0_5b().hf_model_type_str(), "qwen2");
1178        assert_eq!(TransformerConfig::llama2_7b().hf_model_type_str(), "llama");
1179    }
1180
1181    #[test]
1182    fn test_hf_model_type_str_override() {
1183        let mut config = TransformerConfig::tiny();
1184        config.hf_model_type = Some("custom_type".to_string());
1185        assert_eq!(config.hf_model_type_str(), "custom_type");
1186    }
1187
1188    #[test]
1189    fn test_ties_embeddings() {
1190        // Qwen2 ties embeddings (bias + large vocab)
1191        assert!(TransformerConfig::qwen2_0_5b().ties_embeddings());
1192        // LLaMA does not
1193        assert!(!TransformerConfig::llama2_7b().ties_embeddings());
1194        // Explicit flag override
1195        let mut config = TransformerConfig::llama2_7b();
1196        config.tie_word_embeddings = true;
1197        assert!(config.ties_embeddings());
1198    }
1199
1200    #[test]
1201    fn test_head_dim_override() {
1202        let config = TransformerConfig::qwen3_4b();
1203        assert_eq!(config.head_dim_override, Some(128));
1204        assert_eq!(config.head_dim(), 128);
1205        // Without override: 2560 / 32 = 80 (but override gives 128)
1206        assert_ne!(config.hidden_size / config.num_attention_heads, 128);
1207    }
1208
1209    #[test]
1210    fn test_head_dim_no_override() {
1211        let config = TransformerConfig::llama2_7b();
1212        assert!(config.head_dim_override.is_none());
1213        assert_eq!(config.head_dim(), 128); // 4096 / 32
1214    }
1215
1216    #[test]
1217    fn test_q_dim() {
1218        let config = TransformerConfig::qwen3_4b();
1219        // 32 heads * 128 head_dim = 4096
1220        assert_eq!(config.q_dim(), 4096);
1221
1222        let config = TransformerConfig::llama2_7b();
1223        // 32 heads * 128 = 4096 = hidden_size
1224        assert_eq!(config.q_dim(), 4096);
1225    }
1226
1227    #[test]
1228    fn test_q_dim_differs_from_hidden() {
1229        let config = TransformerConfig::qwen3_4b();
1230        // Qwen3-4B: q_dim = 4096 but hidden_size = 2560
1231        assert_ne!(config.q_dim(), config.hidden_size);
1232    }
1233
1234    /// GH-262: Verify Qwen3-4B projection weight shapes match HuggingFace config.json.
1235    ///
1236    /// config.json fields: hidden_size=2560, num_attention_heads=32,
1237    /// num_key_value_heads=8, head_dim=128.
1238    ///
1239    /// Expected shapes (element counts):
1240    ///   q_proj: [4096, 2560] = 10,485,760
1241    ///   k_proj: [1024, 2560] = 2,621,440
1242    ///   v_proj: [1024, 2560] = 2,621,440
1243    ///   o_proj: [2560, 4096] = 10,485,760
1244    #[test]
1245    fn test_qwen3_4b_projection_shapes() {
1246        let config = TransformerConfig::qwen3_4b();
1247
1248        // Verify base dimensions
1249        assert_eq!(config.hidden_size, 2560);
1250        assert_eq!(config.num_attention_heads, 32);
1251        assert_eq!(config.num_kv_heads, 8);
1252        assert_eq!(config.head_dim(), 128);
1253        assert_eq!(config.head_dim_override, Some(128));
1254
1255        // Derived projection dimensions
1256        let q_dim = config.q_dim();
1257        let kv_dim = config.kv_dim();
1258        assert_eq!(q_dim, 4096); // 32 * 128
1259        assert_eq!(kv_dim, 1024); // 8 * 128
1260
1261        // Weight element counts (what validate_weight_shapes checks)
1262        let hidden = config.hidden_size;
1263        assert_eq!(q_dim * hidden, 10_485_760); // q_proj [4096, 2560]
1264        assert_eq!(kv_dim * hidden, 2_621_440); // k_proj [1024, 2560]
1265        assert_eq!(kv_dim * hidden, 2_621_440); // v_proj [1024, 2560]
1266        assert_eq!(hidden * q_dim, 10_485_760); // o_proj [2560, 4096]
1267    }
1268
1269    /// GH-262: Verify VRAM estimation uses q_dim (not hidden_size) for Q/O grads.
1270    #[test]
1271    fn test_qwen3_4b_grad_weight_elements_uses_q_dim() {
1272        let config = TransformerConfig::qwen3_4b();
1273        let h = config.hidden_size; // 2560
1274        let q = config.q_dim(); // 4096
1275        let kv = config.kv_dim(); // 1024
1276        let i = config.intermediate_size; // 9728
1277
1278        // per_layer_grad_weight_elements must use q*h for Q/O, not h*h
1279        let expected = h * 2          // norms
1280            + h * i * 3              // gate, up, down
1281            + q * h                  // grad_w_q
1282            + h * q                  // grad_w_o
1283            + h * kv * 2; // grad_w_k, grad_w_v
1284        assert_eq!(config.per_layer_grad_weight_elements(), expected);
1285
1286        // Sanity: h*h would be smaller (2560^2 = 6.5M) vs q*h (4096*2560 = 10.5M)
1287        assert!(q * h > h * h, "q_dim*hidden > hidden*hidden for Qwen3-4B");
1288    }
1289
1290    #[test]
1291    fn test_from_size_str_known_sizes() {
1292        assert!(TransformerConfig::from_size_str("codebert").is_ok());
1293        assert!(TransformerConfig::from_size_str("codebert-base").is_ok());
1294        assert!(TransformerConfig::from_size_str("125M").is_ok());
1295        assert!(TransformerConfig::from_size_str("0.5B").is_ok());
1296        assert!(TransformerConfig::from_size_str("500M").is_ok());
1297        assert!(TransformerConfig::from_size_str("qwen2-0.5b").is_ok());
1298        assert!(TransformerConfig::from_size_str("7B").is_ok());
1299        assert!(TransformerConfig::from_size_str("qwen2.5-7b").is_ok());
1300        assert!(TransformerConfig::from_size_str("4B").is_ok());
1301        assert!(TransformerConfig::from_size_str("qwen3-4b").is_ok());
1302        assert!(TransformerConfig::from_size_str("qwen3").is_ok());
1303        assert!(TransformerConfig::from_size_str("9B").is_ok());
1304        assert!(TransformerConfig::from_size_str("qwen3.5-9b").is_ok());
1305        assert!(TransformerConfig::from_size_str("qwen3_5").is_ok());
1306        assert!(TransformerConfig::from_size_str("qwen3.5").is_ok());
1307    }
1308
1309    #[test]
1310    fn test_from_size_str_unknown() {
1311        let err = TransformerConfig::from_size_str("99B").unwrap_err();
1312        assert!(err.contains("Unknown model size"));
1313        assert!(err.contains("99B"));
1314    }
1315
1316    #[test]
1317    fn test_from_size_str_configs_correct() {
1318        let codebert = TransformerConfig::from_size_str("codebert").unwrap();
1319        assert_eq!(codebert.hidden_size, 768);
1320        assert!(codebert.is_encoder());
1321
1322        let qwen2 = TransformerConfig::from_size_str("0.5B").unwrap();
1323        assert_eq!(qwen2.hidden_size, 896);
1324        assert!(qwen2.use_bias);
1325
1326        let qwen3 = TransformerConfig::from_size_str("4B").unwrap();
1327        assert_eq!(qwen3.hidden_size, 2560);
1328        assert!(!qwen3.use_bias);
1329    }
1330
1331    #[test]
1332    fn test_from_apr_metadata_missing_num_heads() {
1333        assert!(TransformerConfig::from_apr_metadata(
1334            Some(4096),
1335            None, // missing heads
1336            Some(8),
1337            Some(12288),
1338            Some(36),
1339            Some(151936),
1340            None,
1341            None,
1342            None,
1343            None,
1344        )
1345        .is_none());
1346    }
1347
1348    #[test]
1349    fn test_from_apr_metadata_missing_vocab_size() {
1350        assert!(TransformerConfig::from_apr_metadata(
1351            Some(4096),
1352            Some(32),
1353            Some(8),
1354            Some(12288),
1355            Some(36),
1356            None, // missing vocab
1357            None,
1358            None,
1359            None,
1360            None,
1361        )
1362        .is_none());
1363    }
1364
1365    #[test]
1366    fn test_from_apr_metadata_missing_intermediate_size() {
1367        assert!(TransformerConfig::from_apr_metadata(
1368            Some(4096),
1369            Some(32),
1370            Some(8),
1371            None, // missing intermediate
1372            Some(36),
1373            Some(151936),
1374            None,
1375            None,
1376            None,
1377            None,
1378        )
1379        .is_none());
1380    }
1381
1382    #[test]
1383    fn test_from_apr_metadata_defaults() {
1384        let config = TransformerConfig::from_apr_metadata(
1385            Some(512),
1386            Some(8),
1387            None, // defaults to num_heads
1388            Some(2048),
1389            Some(6),
1390            Some(32000),
1391            None, // defaults to 32768
1392            None, // defaults to 1e-6
1393            None, // defaults to 10000.0
1394            None, // defaults to Decoder
1395        )
1396        .unwrap();
1397
1398        assert_eq!(config.num_kv_heads, 8); // defaults to num_heads
1399        assert_eq!(config.max_position_embeddings, 32768);
1400        assert!((config.rms_norm_eps - 1e-6).abs() < 1e-10);
1401        assert!((config.rope_theta - 10000.0).abs() < 0.1);
1402        assert_eq!(config.architecture, ModelArchitecture::Decoder);
1403        assert!(!config.use_bias);
1404    }
1405
1406    #[test]
1407    fn test_from_apr_metadata_encoder_architecture() {
1408        let config = TransformerConfig::from_apr_metadata(
1409            Some(768),
1410            Some(12),
1411            Some(12),
1412            Some(3072),
1413            Some(12),
1414            Some(50265),
1415            Some(514),
1416            Some(1e-5),
1417            Some(0.0),
1418            Some("codebert"),
1419        )
1420        .unwrap();
1421        assert_eq!(config.architecture, ModelArchitecture::Encoder);
1422    }
1423
1424    #[test]
1425    fn test_from_apr_metadata_roberta_architecture() {
1426        let config = TransformerConfig::from_apr_metadata(
1427            Some(768),
1428            Some(12),
1429            Some(12),
1430            Some(3072),
1431            Some(12),
1432            Some(50265),
1433            None,
1434            None,
1435            None,
1436            Some("roberta"),
1437        )
1438        .unwrap();
1439        assert_eq!(config.architecture, ModelArchitecture::Encoder);
1440    }
1441
1442    #[test]
1443    fn test_from_apr_metadata_qwen3_head_dim_override() {
1444        // Qwen3-4B: hidden=2560, 32 heads → 2560/32=80 != 128, so override needed
1445        let config = TransformerConfig::from_apr_metadata(
1446            Some(2560),
1447            Some(32),
1448            Some(8),
1449            Some(9728),
1450            Some(36),
1451            Some(151936),
1452            Some(40960),
1453            Some(1e-6),
1454            Some(1e6),
1455            Some("qwen3-4b"),
1456        )
1457        .unwrap();
1458        assert_eq!(config.head_dim_override, Some(128));
1459        assert_eq!(config.head_dim(), 128);
1460        assert!(!config.use_bias);
1461    }
1462
1463    #[test]
1464    fn test_from_apr_metadata_qwen3_no_override_needed() {
1465        // If hidden/heads = 128, no override needed
1466        let config = TransformerConfig::from_apr_metadata(
1467            Some(4096),
1468            Some(32),
1469            Some(8),
1470            Some(12288),
1471            Some(36),
1472            Some(151936),
1473            None,
1474            None,
1475            None,
1476            Some("qwen3-8b"),
1477        )
1478        .unwrap();
1479        assert!(config.head_dim_override.is_none());
1480        assert_eq!(config.head_dim(), 128);
1481    }
1482
1483    #[test]
1484    fn test_qwen2_7b_config() {
1485        let config = TransformerConfig::qwen2_7b();
1486        assert_eq!(config.hidden_size, 3584);
1487        assert_eq!(config.num_attention_heads, 28);
1488        assert_eq!(config.num_kv_heads, 4);
1489        assert_eq!(config.intermediate_size, 18944);
1490        assert_eq!(config.num_hidden_layers, 28);
1491        assert_eq!(config.vocab_size, 152064);
1492        assert!(config.use_bias);
1493        assert_eq!(config.head_dim(), 128); // 3584 / 28
1494    }
1495
1496    #[test]
1497    fn test_qwen3_4b_config() {
1498        let config = TransformerConfig::qwen3_4b();
1499        assert_eq!(config.hidden_size, 2560);
1500        assert_eq!(config.num_attention_heads, 32);
1501        assert_eq!(config.num_kv_heads, 8);
1502        assert_eq!(config.intermediate_size, 9728);
1503        assert_eq!(config.num_hidden_layers, 36);
1504        assert!(!config.use_bias);
1505        assert_eq!(config.head_dim(), 128);
1506    }
1507
1508    #[test]
1509    fn test_per_layer_weight_elements_positive() {
1510        for config in [
1511            TransformerConfig::tiny(),
1512            TransformerConfig::codebert(),
1513            TransformerConfig::qwen2_0_5b(),
1514            TransformerConfig::qwen3_4b(),
1515        ] {
1516            assert!(config.per_layer_weight_elements() > 0);
1517        }
1518    }
1519
1520    #[test]
1521    fn test_vram_shared_less_than_per_layer() {
1522        let config = TransformerConfig::qwen2_0_5b();
1523        let per_layer = config.total_training_vram_bytes(128);
1524        let shared = config.total_training_vram_bytes_shared(128);
1525        // Shared should be less for multi-layer models
1526        assert!(
1527            shared < per_layer,
1528            "Shared ({shared}) should be less than per-layer ({per_layer})"
1529        );
1530    }
1531
1532    #[test]
1533    fn test_vram_shared_monotonic() {
1534        let config = TransformerConfig::qwen2_0_5b();
1535        let mut prev = config.total_training_vram_bytes_shared(1);
1536        for s in [2, 4, 8, 16, 32, 64, 128] {
1537            let cur = config.total_training_vram_bytes_shared(s);
1538            assert!(cur > prev, "Shared VRAM must increase: seq_len={s}");
1539            prev = cur;
1540        }
1541    }
1542
1543    #[test]
1544    fn test_max_seq_len_for_vram_shared() {
1545        let config = TransformerConfig::qwen2_0_5b();
1546        let budget = 8 * 1024 * 1024 * 1024_usize; // 8 GB
1547        let max_s = config.max_seq_len_for_vram_shared(budget);
1548        assert!(max_s.is_some());
1549        let s = max_s.unwrap();
1550        assert!(config.total_training_vram_bytes_shared(s) <= budget);
1551    }
1552
1553    #[test]
1554    fn test_max_seq_len_for_vram_shared_impossible() {
1555        let config = TransformerConfig::qwen3_4b();
1556        let tiny_budget = 1024; // 1 KB
1557        assert!(config.max_seq_len_for_vram_shared(tiny_budget).is_none());
1558    }
1559
1560    #[test]
1561    fn test_max_seq_len_for_vram_shared_tightness() {
1562        let config = TransformerConfig::tiny();
1563        let budget = 10 * 1024 * 1024_usize; // 10 MB
1564        if let Some(s) = config.max_seq_len_for_vram_shared(budget) {
1565            assert!(config.total_training_vram_bytes_shared(s) <= budget);
1566            if s < config.max_position_embeddings {
1567                assert!(config.total_training_vram_bytes_shared(s + 1) > budget);
1568            }
1569        }
1570    }
1571
1572    #[test]
1573    fn test_kv_dim() {
1574        assert_eq!(TransformerConfig::qwen3_4b().kv_dim(), 1024);
1575        assert_eq!(TransformerConfig::llama2_7b().kv_dim(), 4096);
1576    }
1577
1578    #[test]
1579    fn test_per_layer_scratch_coefficients() {
1580        let config = TransformerConfig::tiny();
1581        assert!(config.per_layer_scratch_linear_coeff() > 0);
1582        let (n_quad, n_hd_linear) = config.per_layer_scratch_quadratic_coeff();
1583        assert!(n_quad > 0 && n_hd_linear > 0);
1584        assert!(config.per_layer_grad_weight_elements() > 0);
1585    }
1586}