Skip to main content

oxicuda_vision/vit/
vit_block.rs

//! ViT block: pre-norm multi-head self-attention + MLP with GELU.
//!
//! Layout for all weight tensors is row-major C-contiguous `Vec<f32>`.
//!
//! ## Forward pass (pre-norm variant)
//! ```text
//! h  = LN1(x)
//! h  = MHSA(h) + x           (residual 1)
//! h2 = LN2(h)
//! out = MLP(h2) + h           (residual 2)
//! ```
12
13use crate::{
14    error::{VisionError, VisionResult},
15    handle::LcgRng,
16};
17
18// ─── Config ──────────────────────────────────────────────────────────────────
19
/// Configuration for a single ViT transformer block.
///
/// All fields are public, so a struct literal bypasses the checks performed by
/// [`ViTBlockConfig::new`]; prefer the constructor when the values are not
/// known to be consistent.
///
/// Every field is a plain `usize`, so equality is total and the type can be
/// hashed; `Eq` and `Hash` are derived so a configuration can serve as a key
/// in `HashMap`/`HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ViTBlockConfig {
    /// Token embedding dimension.
    pub embed_dim: usize,
    /// Number of attention heads. Must divide `embed_dim`.
    pub n_heads: usize,
    /// MLP hidden-dim multiplier: `mlp_dim = mlp_ratio * embed_dim`.
    pub mlp_ratio: usize,
}
30
31impl ViTBlockConfig {
32    /// Create and validate a `ViTBlockConfig`.
33    ///
34    /// # Errors
35    /// - `embed_dim == 0` → `InvalidEmbedDim`
36    /// - `n_heads == 0`   → `InvalidNumHeads`
37    /// - `embed_dim % n_heads != 0` → `HeadDimMismatch`
38    pub fn new(embed_dim: usize, n_heads: usize, mlp_ratio: usize) -> VisionResult<Self> {
39        if embed_dim == 0 {
40            return Err(VisionError::InvalidEmbedDim(embed_dim));
41        }
42        if n_heads == 0 {
43            return Err(VisionError::InvalidNumHeads(n_heads));
44        }
45        if embed_dim % n_heads != 0 {
46            return Err(VisionError::HeadDimMismatch { n_heads, embed_dim });
47        }
48        Ok(Self {
49            embed_dim,
50            n_heads,
51            mlp_ratio,
52        })
53    }
54
55    /// Dimension per attention head.
56    #[must_use]
57    #[inline]
58    pub fn head_dim(&self) -> usize {
59        self.embed_dim / self.n_heads
60    }
61
62    /// MLP hidden dimension.
63    #[must_use]
64    #[inline]
65    pub fn mlp_dim(&self) -> usize {
66        self.mlp_ratio * self.embed_dim
67    }
68}
69
70// ─── Weights ─────────────────────────────────────────────────────────────────
71
/// Learnable weights for one ViT transformer block.
///
/// All tensors are flat row-major `Vec<f32>`. The shapes quoted on each field
/// are the logical shapes; the struct itself does not enforce them.
///
/// `Debug`, `Clone` and `PartialEq` are derived (matching `ViTBlockConfig`) so
/// weights can be logged, snapshotted and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct ViTBlockWeights {
    /// QKV projection kernel: `[3 * embed_dim, embed_dim]` stored flat.
    /// Equivalent to concatenating W_Q, W_K, W_V along the output-dim axis.
    pub qkv_weight: Vec<f32>,
    /// QKV projection bias: `[3 * embed_dim]`.
    pub qkv_bias: Vec<f32>,

    /// Output projection kernel: `[embed_dim, embed_dim]`.
    pub out_weight: Vec<f32>,
    /// Output projection bias: `[embed_dim]`.
    pub out_bias: Vec<f32>,

    /// MLP first linear kernel: `[mlp_dim, embed_dim]`.
    pub mlp1_weight: Vec<f32>,
    /// MLP first linear bias: `[mlp_dim]`.
    pub mlp1_bias: Vec<f32>,

    /// MLP second linear kernel: `[embed_dim, mlp_dim]`.
    pub mlp2_weight: Vec<f32>,
    /// MLP second linear bias: `[embed_dim]`.
    pub mlp2_bias: Vec<f32>,

    /// LayerNorm 1 scale: `[embed_dim]` (init 1).
    pub ln1_weight: Vec<f32>,
    /// LayerNorm 1 bias: `[embed_dim]` (init 0).
    pub ln1_bias: Vec<f32>,

    /// LayerNorm 2 scale: `[embed_dim]` (init 1).
    pub ln2_weight: Vec<f32>,
    /// LayerNorm 2 bias: `[embed_dim]` (init 0).
    pub ln2_bias: Vec<f32>,
}
107
108impl ViTBlockWeights {
109    /// Xavier-style default initialisation.
110    ///
111    /// - Attention & MLP weights: N(0, 1/√embed_dim)
112    /// - Biases: zeros
113    /// - LayerNorm weights: ones; biases: zeros
114    pub fn default_init(cfg: &ViTBlockConfig, rng: &mut LcgRng) -> Self {
115        let e = cfg.embed_dim;
116        let mlp = cfg.mlp_dim();
117        let scale = 1.0 / (e as f32).sqrt();
118
119        let fill_scaled = |rng: &mut LcgRng, n: usize, sc: f32| -> Vec<f32> {
120            let mut v = vec![0.0f32; n];
121            rng.fill_normal(&mut v);
122            for x in &mut v {
123                *x *= sc;
124            }
125            v
126        };
127
128        let qkv_weight = fill_scaled(rng, 3 * e * e, scale);
129        let qkv_bias = vec![0.0f32; 3 * e];
130        let out_weight = fill_scaled(rng, e * e, scale);
131        let out_bias = vec![0.0f32; e];
132        let mlp1_weight = fill_scaled(rng, mlp * e, scale);
133        let mlp1_bias = vec![0.0f32; mlp];
134        let mlp2_weight = fill_scaled(rng, e * mlp, scale);
135        let mlp2_bias = vec![0.0f32; e];
136        let ln1_weight = vec![1.0f32; e];
137        let ln1_bias = vec![0.0f32; e];
138        let ln2_weight = vec![1.0f32; e];
139        let ln2_bias = vec![0.0f32; e];
140
141        Self {
142            qkv_weight,
143            qkv_bias,
144            out_weight,
145            out_bias,
146            mlp1_weight,
147            mlp1_bias,
148            mlp2_weight,
149            mlp2_bias,
150            ln1_weight,
151            ln1_bias,
152            ln2_weight,
153            ln2_bias,
154        }
155    }
156}
157
158// ─── ViTBlock ─────────────────────────────────────────────────────────────────
159
/// A single pre-norm ViT transformer block.
///
/// Pairs the block's hyper-parameters with the weight tensors its forward
/// pass reads. Both fields are public, so callers may replace them directly;
/// nothing here keeps the two in sync.
pub struct ViTBlock {
    /// Hyper-parameters (embedding width, head count, MLP ratio) of this block.
    pub config: ViTBlockConfig,
    /// Learnable tensors used by the attention, MLP and LayerNorm stages.
    pub weights: ViTBlockWeights,
}
165
166impl ViTBlock {
167    /// Construct a new block with Xavier-initialised weights.
168    pub fn new(cfg: ViTBlockConfig, rng: &mut LcgRng) -> Self {
169        let weights = ViTBlockWeights::default_init(&cfg, rng);
170        Self {
171            config: cfg,
172            weights,
173        }
174    }
175
176    /// Forward pass.
177    ///
178    /// `tokens` is flat `[n_tokens, embed_dim]`.
179    /// Returns `[n_tokens * embed_dim]`.
180    ///
181    /// Pipeline:
182    /// ```text
183    /// h   = LN1(tokens)
184    /// h   = MHSA(h)  + tokens   (residual)
185    /// h2  = LN2(h)
186    /// out = MLP(h2)  + h        (residual)
187    /// ```
188    pub fn forward(&self, tokens: &[f32], n_tokens: usize) -> VisionResult<Vec<f32>> {
189        let e = self.config.embed_dim;
190        if tokens.len() != n_tokens * e {
191            return Err(VisionError::DimensionMismatch {
192                expected: n_tokens * e,
193                got: tokens.len(),
194            });
195        }
196        if n_tokens == 0 {
197            return Err(VisionError::EmptyInput("tokens"));
198        }
199
200        let w = &self.weights;
201        let cfg = &self.config;
202
203        // Pre-norm 1
204        let h = layer_norm(tokens, &w.ln1_weight, &w.ln1_bias, n_tokens, e, 1e-5);
205
206        // Multi-head self-attention
207        let attn_out = mhsa(
208            &h,
209            n_tokens,
210            e,
211            cfg.n_heads,
212            cfg.head_dim(),
213            &w.qkv_weight,
214            &w.qkv_bias,
215            &w.out_weight,
216            &w.out_bias,
217        )?;
218
219        // Residual 1: attn_out + tokens
220        let mut h: Vec<f32> = attn_out
221            .iter()
222            .zip(tokens.iter())
223            .map(|(a, b)| a + b)
224            .collect();
225
226        // Pre-norm 2
227        let h2 = layer_norm(&h, &w.ln2_weight, &w.ln2_bias, n_tokens, e, 1e-5);
228
229        // MLP: Linear → GELU → Linear
230        let mlp_dim = cfg.mlp_dim();
231        let mid = linear(&h2, &w.mlp1_weight, &w.mlp1_bias, e, mlp_dim);
232        let mid: Vec<f32> = mid.into_iter().map(gelu_exact).collect();
233        let mlp_out = linear(&mid, &w.mlp2_weight, &w.mlp2_bias, mlp_dim, e);
234
235        // Residual 2: mlp_out + h
236        for (o, m) in h.iter_mut().zip(mlp_out.iter()) {
237            *o += m;
238        }
239
240        Ok(h)
241    }
242}
243
244// ─── Internal helpers ────────────────────────────────────────────────────────
245
/// Row-wise layer normalisation with a learned affine transform.
///
/// `x` is read as `n` consecutive rows of `d` values. Each row is normalised
/// independently using its biased (divide-by-`d`) variance:
/// ```text
/// mean = Σx / d
/// var  = Σ(x - mean)² / d
/// out  = (x - mean) / sqrt(var + eps) * weight + bias
/// ```
/// `weight` and `bias` are indexed per column, so each needs at least `d`
/// entries. Returns a freshly allocated vector of `n * d` values.
///
/// # Panics
/// Panics if `x` holds fewer than `n * d` values or if `weight`/`bias` are
/// shorter than `d` (when `n > 0`).
pub(crate) fn layer_norm(
    x: &[f32],
    weight: &[f32],
    bias: &[f32],
    n: usize,
    d: usize,
    eps: f32,
) -> Vec<f32> {
    let width = d as f32;
    let mut normalised = vec![0.0f32; n * d];

    for r in 0..n {
        let span = r * d..(r + 1) * d;
        let src = &x[span.clone()];

        // Two passes: the mean first, then the variance around that mean.
        let mean = src.iter().sum::<f32>() / width;
        let var = src
            .iter()
            .map(|&v| {
                let centred = v - mean;
                centred * centred
            })
            .sum::<f32>()
            / width;
        let inv_std = 1.0 / (var + eps).sqrt();

        let dst = &mut normalised[span];
        for (j, slot) in dst.iter_mut().enumerate() {
            *slot = (src[j] - mean) * inv_std * weight[j] + bias[j];
        }
    }

    normalised
}
275
/// Dense linear transform: `y = x W^T + b`.
///
/// `x`:  `[batch, n_in]` flat
/// `w`:  `[n_out, n_in]` flat (each row is one output filter)
/// `b`:  `[n_out]`
/// Returns `[batch, n_out]` flat.
///
/// `batch` is inferred as `x.len() / n_in`; any trailing partial row of `x`
/// is ignored.
///
/// When `n_in == 0` the batch size cannot be inferred from `x`, so an empty
/// vector is returned instead of dividing by zero. (This situation arises for
/// the second MLP layer when the hidden width is zero.)
///
/// # Panics
/// Panics if `w` holds fewer than `n_out * n_in` values or `b` fewer than
/// `n_out` values while `batch > 0`.
pub(crate) fn linear(x: &[f32], w: &[f32], b: &[f32], n_in: usize, n_out: usize) -> Vec<f32> {
    if n_in == 0 {
        // No feature axis: there are no rows to recover, hence nothing to emit.
        return Vec::new();
    }

    let batch = x.len() / n_in;
    let mut out = vec![0.0f32; batch * n_out];

    for bi in 0..batch {
        let xrow = &x[bi * n_in..(bi + 1) * n_in];
        let orow = &mut out[bi * n_out..(bi + 1) * n_out];
        for (oi, slot) in orow.iter_mut().enumerate() {
            let wrow = &w[oi * n_in..(oi + 1) * n_in];
            // Start from the bias and accumulate left to right so the
            // floating-point summation order stays fixed.
            *slot = xrow
                .iter()
                .zip(wrow)
                .fold(b[oi], |acc, (xv, wv)| acc + xv * wv);
        }
    }

    out
}
299
/// GELU activation, computed with the tanh approximation.
///
/// Despite the function's name this is *not* the erf-based closed form; it
/// evaluates the usual tanh surrogate:
///
/// ```text
/// GELU(x) ≈ x * 0.5 * (1 + tanh(0.7978845608 * (x + 0.044715 * x³)))
/// ```
///
/// where `0.7978845608 ≈ sqrt(2/π)`.
#[inline]
pub(crate) fn gelu_exact(x: f32) -> f32 {
    /// sqrt(2/π), rounded to `f32`.
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    /// Coefficient of the cubic term.
    const CUBIC_COEFF: f32 = 0.044_715;

    let cubic = CUBIC_COEFF * x * x * x;
    let gate = (SQRT_2_OVER_PI * (x + cubic)).tanh();
    let half_x = x * 0.5;
    half_x * (1.0 + gate)
}
312
313/// Multi-head self-attention.
314///
315/// # Steps
316/// 1. Project `tokens` → Q, K, V via a single fused `[3*embed, embed]` matrix.
317/// 2. Split into three `[n_tokens, embed]` tensors.
318/// 3. Reshape each to `[n_heads, n_tokens, head_dim]`.
319/// 4. Scaled dot-product per head: `S = Q @ K^T / sqrt(head_dim)`.
320/// 5. Row-wise softmax of `S`.
321/// 6. Weighted sum: `A = S @ V`.
322/// 7. Reshape `A` → `[n_tokens, embed]`, apply output projection.
323#[allow(clippy::too_many_arguments)]
324pub(crate) fn mhsa(
325    tokens: &[f32],
326    n_tokens: usize,
327    embed_dim: usize,
328    n_heads: usize,
329    head_dim: usize,
330    qkv_weight: &[f32],
331    qkv_bias: &[f32],
332    out_weight: &[f32],
333    out_bias: &[f32],
334) -> VisionResult<Vec<f32>> {
335    // Step 1: fused QKV projection → [n_tokens, 3*embed]
336    let qkv = linear(tokens, qkv_weight, qkv_bias, embed_dim, 3 * embed_dim);
337
338    // Step 2: split into Q, K, V each [n_tokens, embed_dim]
339    let mut q = vec![0.0f32; n_tokens * embed_dim];
340    let mut k = vec![0.0f32; n_tokens * embed_dim];
341    let mut v = vec![0.0f32; n_tokens * embed_dim];
342    for t in 0..n_tokens {
343        let src = &qkv[t * 3 * embed_dim..(t + 1) * 3 * embed_dim];
344        let qd = &mut q[t * embed_dim..(t + 1) * embed_dim];
345        let kd = &mut k[t * embed_dim..(t + 1) * embed_dim];
346        let vd = &mut v[t * embed_dim..(t + 1) * embed_dim];
347        qd.copy_from_slice(&src[..embed_dim]);
348        kd.copy_from_slice(&src[embed_dim..2 * embed_dim]);
349        vd.copy_from_slice(&src[2 * embed_dim..]);
350    }
351
352    // Steps 3-6: per-head scaled dot-product attention
353    // We compute head-by-head to avoid one giant allocation.
354    let scale = 1.0 / (head_dim as f32).sqrt();
355    // Output buffer: [n_tokens, embed_dim]
356    let mut concat = vec![0.0f32; n_tokens * embed_dim];
357
358    // attn_scores: [n_tokens, n_tokens] scratch
359    let mut scores = vec![0.0f32; n_tokens * n_tokens];
360
361    for h in 0..n_heads {
362        let hd_off = h * head_dim; // column offset within embed_dim
363
364        // Compute scores[i, j] = scale * Σ_d Q[i, h*hd + d] * K[j, h*hd + d]
365        for i in 0..n_tokens {
366            for j in 0..n_tokens {
367                let mut dot = 0.0f32;
368                for d in 0..head_dim {
369                    dot += q[i * embed_dim + hd_off + d] * k[j * embed_dim + hd_off + d];
370                }
371                scores[i * n_tokens + j] = dot * scale;
372            }
373        }
374
375        // In-place row-wise softmax
376        softmax_rows(&mut scores, n_tokens, n_tokens);
377
378        // Weighted sum: A[i, d] = Σ_j scores[i,j] * V[j, h*hd + d]
379        for i in 0..n_tokens {
380            for d in 0..head_dim {
381                let mut acc = 0.0f32;
382                for j in 0..n_tokens {
383                    acc += scores[i * n_tokens + j] * v[j * embed_dim + hd_off + d];
384                }
385                concat[i * embed_dim + hd_off + d] = acc;
386            }
387        }
388    }
389
390    // Step 7: output projection [embed, embed]
391    let out = linear(&concat, out_weight, out_bias, embed_dim, embed_dim);
392
393    // Validate all outputs are finite (catch NaN from degenerate inputs)
394    if out.iter().any(|v| !v.is_finite()) {
395        return Err(VisionError::NonFinite("mhsa output"));
396    }
397
398    Ok(out)
399}
400
/// Softmax applied independently to each row of a flat matrix, in place.
///
/// `logits` is read as `n_rows` rows of `n_cols` values. The row maximum is
/// subtracted before exponentiating so large logits do not overflow. If the
/// exponentials of a row do not sum to a value greater than zero, that row is
/// left un-normalised (scaled by `1.0`).
///
/// # Panics
/// Panics if `logits` holds fewer than `n_rows * n_cols` values.
pub(crate) fn softmax_rows(logits: &mut [f32], n_rows: usize, n_cols: usize) {
    for r in 0..n_rows {
        let start = r * n_cols;
        let row = &mut logits[start..start + n_cols];

        // Largest entry of the row; NEG_INFINITY for an empty row.
        let peak = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        // Exponentiate the shifted values while accumulating their total.
        let mut total = 0.0f32;
        row.iter_mut().for_each(|cell| {
            let e = (*cell - peak).exp();
            *cell = e;
            total += e;
        });

        // Scale so the row sums to one (skipped when the total is not > 0).
        let norm = if total > 0.0 { 1.0 / total } else { 1.0 };
        row.iter_mut().for_each(|cell| *cell *= norm);
    }
}
420
421// ─── Tests ───────────────────────────────────────────────────────────────────
422
// Unit tests for the ViT block: config validation, the `layer_norm`, `mhsa`
// and `gelu_exact` helpers, and the end-to-end `ViTBlock::forward` pass.
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: embed_dim = 64, 4 heads, MLP ratio 4
    // (so head_dim = 16 and mlp_dim = 256, as `config_valid` asserts).
    fn make_cfg() -> ViTBlockConfig {
        ViTBlockConfig::new(64, 4, 4).expect("valid config")
    }

    // ── Config validation ─────────────────────────────────────────────────────

    // A valid config exposes the derived head and MLP widths.
    #[test]
    fn config_valid() {
        let cfg = make_cfg();
        assert_eq!(cfg.head_dim(), 16);
        assert_eq!(cfg.mlp_dim(), 256);
    }

    // embed_dim = 0 is rejected, and the error carries the offending value.
    #[test]
    fn config_invalid_embed_zero() {
        let r = ViTBlockConfig::new(0, 4, 4);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    // n_heads = 0 is rejected, and the error carries the offending value.
    #[test]
    fn config_invalid_heads_zero() {
        let r = ViTBlockConfig::new(64, 0, 4);
        assert!(matches!(r, Err(VisionError::InvalidNumHeads(0))));
    }

    // A head count that does not divide embed_dim is rejected with both values.
    #[test]
    fn config_head_dim_mismatch() {
        let r = ViTBlockConfig::new(64, 3, 4); // 64 % 3 != 0
        assert!(matches!(
            r,
            Err(VisionError::HeadDimMismatch {
                n_heads: 3,
                embed_dim: 64
            })
        ));
    }

    // ── layer_norm ────────────────────────────────────────────────────────────

    #[test]
    fn layer_norm_zero_input_with_identity_affine() {
        // LN of all-zeros with weight=1, bias=0 → output is all 0.
        // mean=0, var=0 → normalized = 0/sqrt(eps) * 1 + 0 = 0.
        let x = vec![0.0f32; 8];
        let w = vec![1.0f32; 8];
        let b = vec![0.0f32; 8];
        let out = layer_norm(&x, &w, &b, 1, 8, 1e-5);
        assert!(
            out.iter().all(|&v| v.abs() < 1e-4),
            "expected near-zero: {out:?}"
        );
    }

    #[test]
    fn layer_norm_constant_row_normalises_to_zero() {
        // All same value → mean = value, var = 0 → normalised = 0.
        let x = vec![5.0f32; 16];
        let w = vec![1.0f32; 16];
        let b = vec![0.0f32; 16];
        let out = layer_norm(&x, &w, &b, 1, 16, 1e-5);
        assert!(out.iter().all(|&v| v.abs() < 1e-4));
    }

    // Output has one value per input value (4 rows × 64 columns).
    #[test]
    fn layer_norm_output_shape() {
        let x = vec![1.0f32; 4 * 64];
        let w = vec![1.0f32; 64];
        let b = vec![0.0f32; 64];
        let out = layer_norm(&x, &w, &b, 4, 64, 1e-5);
        assert_eq!(out.len(), 4 * 64);
    }

    #[test]
    fn layer_norm_standard_normal_output() {
        // After LN, each row should have ~mean 0 and ~var 1.
        let mut rng = LcgRng::new(77);
        let mut x = vec![0.0f32; 128];
        rng.fill_normal(&mut x);
        let w = vec![1.0f32; 128];
        let b = vec![0.0f32; 128];
        let out = layer_norm(&x, &w, &b, 1, 128, 1e-5);
        // Recompute the statistics of the output with the same biased
        // (divide-by-N) variance that `layer_norm` uses.
        let mean: f32 = out.iter().sum::<f32>() / 128.0;
        let var: f32 = out.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / 128.0;
        assert!(mean.abs() < 1e-4, "mean too large: {mean}");
        assert!((var - 1.0).abs() < 1e-3, "var not ~1: {var}");
    }

    // ── MHSA ──────────────────────────────────────────────────────────────────

    // Attention keeps the `[n_tokens, embed_dim]` shape (constant input here).
    #[test]
    fn mhsa_output_shape() {
        let cfg = make_cfg();
        let e = cfg.embed_dim;
        let n_tokens = 17;
        let mut rng = LcgRng::new(1);
        let w = ViTBlockWeights::default_init(&cfg, &mut rng);
        let tokens = vec![0.1f32; n_tokens * e];
        let out = mhsa(
            &tokens,
            n_tokens,
            e,
            cfg.n_heads,
            cfg.head_dim(),
            &w.qkv_weight,
            &w.qkv_bias,
            &w.out_weight,
            &w.out_bias,
        )
        .expect("mhsa ok");
        assert_eq!(out.len(), n_tokens * e);
    }

    // Random tokens through default-initialised weights give finite outputs.
    #[test]
    fn mhsa_output_finite() {
        let cfg = make_cfg();
        let e = cfg.embed_dim;
        let n_tokens = 10;
        let mut rng = LcgRng::new(2);
        let w = ViTBlockWeights::default_init(&cfg, &mut rng);
        // The same generator, already advanced by the weight init, fills the tokens.
        let mut tokens = vec![0.0f32; n_tokens * e];
        rng.fill_normal(&mut tokens);
        let out = mhsa(
            &tokens,
            n_tokens,
            e,
            cfg.n_heads,
            cfg.head_dim(),
            &w.qkv_weight,
            &w.qkv_bias,
            &w.out_weight,
            &w.out_bias,
        )
        .expect("mhsa ok");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "non-finite in mhsa output"
        );
    }

    // ── Forward ───────────────────────────────────────────────────────────────

    // The full block preserves the `[n_tokens, embed_dim]` shape (all-zero input).
    #[test]
    fn forward_output_shape() {
        let cfg = make_cfg();
        let e = cfg.embed_dim;
        let n_tokens = 17; // 16 patches + 1 CLS
        let mut rng = LcgRng::new(3);
        let block = ViTBlock::new(cfg, &mut rng);
        let tokens = vec![0.0f32; n_tokens * e];
        let out = block.forward(&tokens, n_tokens).expect("forward ok");
        assert_eq!(out.len(), n_tokens * e);
    }

    // Random tokens through a default-initialised block give finite outputs.
    #[test]
    fn forward_output_finite() {
        let cfg = make_cfg();
        let e = cfg.embed_dim;
        let n_tokens = 17;
        let mut rng = LcgRng::new(4);
        let block = ViTBlock::new(cfg, &mut rng);
        let mut tokens = vec![0.0f32; n_tokens * e];
        rng.fill_normal(&mut tokens);
        let out = block.forward(&tokens, n_tokens).expect("forward ok");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "non-finite in block output"
        );
    }

    // A token buffer whose length is not n_tokens * embed_dim is rejected.
    #[test]
    fn forward_dimension_mismatch_errors() {
        let cfg = make_cfg();
        let n_tokens = 5;
        let mut rng = LcgRng::new(5);
        let block = ViTBlock::new(cfg, &mut rng);
        // Deliberately wrong length
        let tokens = vec![0.0f32; n_tokens * 32]; // embed=64 expected
        let r = block.forward(&tokens, n_tokens);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn forward_residual_not_trivially_zero() {
        // All-zero tokens would be a degenerate probe: LayerNorm maps them to
        // zero and every bias is initialised to zero, so the whole block would
        // output zeros. Random tokens are used instead so that the attention
        // and MLP branches add a non-zero update on top of the residual input.
        let cfg = make_cfg();
        let e = cfg.embed_dim;
        let n_tokens = 8;
        let mut rng = LcgRng::new(6);
        let block = ViTBlock::new(cfg, &mut rng);
        let mut tokens = vec![0.0f32; n_tokens * e];
        rng.fill_normal(&mut tokens);
        let out = block.forward(&tokens, n_tokens).expect("forward ok");
        // Output should differ from input (the block is not identity-initialised).
        // `diff` is the L1 distance between output and input.
        let diff: f32 = out
            .iter()
            .zip(tokens.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff > 1e-6, "block did not change tokens (diff={diff})");
    }

    // ── gelu_exact ────────────────────────────────────────────────────────────

    #[test]
    fn gelu_zero() {
        // GELU(0) = 0
        assert!((gelu_exact(0.0) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn gelu_large_positive_approx_identity() {
        // For large x, GELU(x) ≈ x
        let x = 10.0f32;
        assert!(
            (gelu_exact(x) - x).abs() < 1e-3,
            "GELU({x}) = {}",
            gelu_exact(x)
        );
    }

    #[test]
    fn gelu_large_negative_approx_zero() {
        // For large negative x, GELU(x) ≈ 0
        let x = -10.0f32;
        assert!(gelu_exact(x).abs() < 1e-3, "GELU({x}) = {}", gelu_exact(x));
    }
}