// oxicuda_vision/convnext/block.rs

//! ConvNeXt block: depthwise 7×7 convolution + channel LayerNorm +
//! inverted bottleneck (1×1 expansion → GELU → 1×1 projection) + layer
//! scale + residual.
//!
//! Reference: Liu et al. 2022 CVPR, *"A ConvNet for the 2020s"*.
//!
//! ## Tensor layout
//! Activations use a channel-major row-major layout: `x` is
//! `(channels, height, width)` flattened so that channel `c`'s pixel `(h, w)`
//! lives at `c·H·W + h·W + w`.
//!
//! ## Forward pass
//! ```text
//! y = depthwise_conv(x)              (C,H,W) same-pad 7×7, per-channel
//! y = channel_layernorm(y)           LN across channels at each (h,w)
//! y = pointwise1(y)                  1×1 conv  C → expansion·C
//! y = GELU(y)
//! y = pointwise2(y)                  1×1 conv  expansion·C → C
//! y = layer_scale ⊙ y                per-channel scale γ
//! out = x + y                        residual
//! ```

use crate::{
    error::{VisionError, VisionResult},
    handle::LcgRng,
};

/// LayerNorm epsilon for the channel-wise normalisation.
///
/// Added to the per-pixel channel variance before the reciprocal square root
/// so the normalisation stays finite when the variance is zero.
const LN_EPS: f32 = 1e-6;

// ─── Config ──────────────────────────────────────────────────────────────────

/// Configuration for a single ConvNeXt block.
///
/// All fields are public, so a value can be built with a struct literal and
/// skip validation; use [`ConvNextConfig::new`] to get a checked config.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvNextConfig {
    /// Number of feature channels `C` (input == output).
    pub channels: usize,
    /// Spatial height `H` in pixels.
    pub height: usize,
    /// Spatial width `W` in pixels.
    pub width: usize,
    /// Depthwise convolution kernel edge length (must be odd, e.g. 7).
    pub kernel: usize,
    /// Inverted-bottleneck expansion factor (hidden = `expansion · C`).
    pub expansion: usize,
    /// Initial value for the layer-scale gamma (e.g. `1e-6`).
    pub layer_scale_init: f32,
}

50impl ConvNextConfig {
51    /// Create and validate a `ConvNextConfig`.
52    ///
53    /// # Errors
54    /// - `channels == 0` / `height == 0` / `width == 0` → `InvalidImageSize`
55    /// - `kernel == 0` or `kernel` even → `InvalidPatchSize`
56    /// - `expansion == 0` → `Internal`
57    pub fn new(
58        channels: usize,
59        height: usize,
60        width: usize,
61        kernel: usize,
62        expansion: usize,
63        layer_scale_init: f32,
64    ) -> VisionResult<Self> {
65        if channels == 0 || height == 0 || width == 0 {
66            return Err(VisionError::InvalidImageSize {
67                height,
68                width,
69                channels,
70            });
71        }
72        if kernel == 0 || kernel % 2 == 0 {
73            return Err(VisionError::InvalidPatchSize {
74                patch_size: kernel,
75                img_size: height,
76            });
77        }
78        if expansion == 0 {
79            return Err(VisionError::Internal("expansion must be >= 1".to_string()));
80        }
81        Ok(Self {
82            channels,
83            height,
84            width,
85            kernel,
86            expansion,
87            layer_scale_init,
88        })
89    }
90
91    /// Number of spatial pixels `H·W`.
92    #[must_use]
93    #[inline]
94    pub fn spatial(&self) -> usize {
95        self.height * self.width
96    }
97
98    /// Inverted-bottleneck hidden width `expansion · C`.
99    #[must_use]
100    #[inline]
101    pub fn hidden(&self) -> usize {
102        self.expansion * self.channels
103    }
104
105    /// Symmetric "same" zero padding `(kernel - 1) / 2`.
106    #[must_use]
107    #[inline]
108    pub fn pad(&self) -> usize {
109        (self.kernel - 1) / 2
110    }
111}
112
// ─── ConvNextBlock ─────────────────────────────────────────────────────────────

115/// A single ConvNeXt block.
116///
117/// All weight tensors are flat row-major `Vec<f32>`.
118pub struct ConvNextBlock {
119    /// Depthwise kernel `[channels, kernel, kernel]`.
120    dw_kernel: Vec<f32>,
121    /// Depthwise bias `[channels]`.
122    dw_bias: Vec<f32>,
123    /// Channel-LayerNorm scale `[channels]` (init 1).
124    ln_gamma: Vec<f32>,
125    /// Channel-LayerNorm bias `[channels]` (init 0).
126    ln_beta: Vec<f32>,
127    /// Pointwise-1 kernel `[expansion·C, C]` (1×1 conv).
128    pw1_weight: Vec<f32>,
129    /// Pointwise-1 bias `[expansion·C]`.
130    pw1_bias: Vec<f32>,
131    /// Pointwise-2 kernel `[C, expansion·C]` (1×1 conv).
132    pw2_weight: Vec<f32>,
133    /// Pointwise-2 bias `[C]`.
134    pw2_bias: Vec<f32>,
135    /// Layer-scale gamma `[channels]` (init `layer_scale_init`).
136    layer_scale: Vec<f32>,
137    /// Block configuration.
138    cfg: ConvNextConfig,
139}
140
141impl ConvNextBlock {
142    /// Construct a new block with random depthwise / pointwise weights, zeroed
143    /// depthwise bias, identity LayerNorm affine, and layer-scale gamma set to
144    /// `cfg.layer_scale_init`.
145    ///
146    /// Kernels and pointwise weights use a Kaiming-ish scaled normal
147    /// initialisation drawn from the deterministic LCG RNG.
148    ///
149    /// # Errors
150    /// Propagates configuration validation from [`ConvNextConfig::new`].
151    pub fn new(cfg: ConvNextConfig, rng: &mut LcgRng) -> VisionResult<Self> {
152        let cfg = ConvNextConfig::new(
153            cfg.channels,
154            cfg.height,
155            cfg.width,
156            cfg.kernel,
157            cfg.expansion,
158            cfg.layer_scale_init,
159        )?;
160        let c = cfg.channels;
161        let hidden = cfg.hidden();
162
163        let fill_scaled = |rng: &mut LcgRng, n: usize, sc: f32| -> Vec<f32> {
164            let mut v = vec![0.0f32; n];
165            rng.fill_normal(&mut v);
166            for x in &mut v {
167                *x *= sc;
168            }
169            v
170        };
171
172        // Depthwise: fan_in = kernel² (one channel per filter).
173        let dw_fan_in = cfg.kernel * cfg.kernel;
174        let dw_scale = (2.0 / dw_fan_in as f32).sqrt();
175        let dw_kernel = fill_scaled(rng, c * dw_fan_in, dw_scale);
176        let dw_bias = vec![0.0f32; c];
177
178        let ln_gamma = vec![1.0f32; c];
179        let ln_beta = vec![0.0f32; c];
180
181        // Pointwise-1: fan_in = C.
182        let pw1_scale = (2.0 / c as f32).sqrt();
183        let pw1_weight = fill_scaled(rng, hidden * c, pw1_scale);
184        let pw1_bias = vec![0.0f32; hidden];
185
186        // Pointwise-2: fan_in = hidden.
187        let pw2_scale = (2.0 / hidden as f32).sqrt();
188        let pw2_weight = fill_scaled(rng, c * hidden, pw2_scale);
189        let pw2_bias = vec![0.0f32; c];
190
191        let layer_scale = vec![cfg.layer_scale_init; c];
192
193        Ok(Self {
194            dw_kernel,
195            dw_bias,
196            ln_gamma,
197            ln_beta,
198            pw1_weight,
199            pw1_bias,
200            pw2_weight,
201            pw2_bias,
202            layer_scale,
203            cfg,
204        })
205    }
206
207    /// Read-only access to the block configuration.
208    #[must_use]
209    #[inline]
210    pub fn config(&self) -> &ConvNextConfig {
211        &self.cfg
212    }
213
214    /// Mutable access to the depthwise kernel (used for tests that install an
215    /// identity / delta kernel).
216    #[inline]
217    pub fn dw_kernel_mut(&mut self) -> &mut [f32] {
218        &mut self.dw_kernel
219    }
220
221    /// Mutable access to the depthwise bias.
222    #[inline]
223    pub fn dw_bias_mut(&mut self) -> &mut [f32] {
224        &mut self.dw_bias
225    }
226
227    /// Validate that `x` has the expected `C·H·W` flat length.
228    fn check_input_len(&self, x: &[f32]) -> VisionResult<()> {
229        let expected = self.cfg.channels * self.cfg.spatial();
230        if x.len() != expected {
231            return Err(VisionError::DimensionMismatch {
232                expected,
233                got: x.len(),
234            });
235        }
236        Ok(())
237    }
238
239    /// Depthwise convolution: each channel `c` is convolved with its own
240    /// `kernel × kernel` filter, zero-padded "same" (`pad = (k-1)/2`),
241    /// stride 1, plus per-channel bias.
242    ///
243    /// Input / output are `(C, H, W)` flat.
244    ///
245    /// # Errors
246    /// `DimensionMismatch` if `x.len() != C·H·W`.
247    pub fn depthwise_conv(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
248        self.check_input_len(x)?;
249        let c = self.cfg.channels;
250        let h = self.cfg.height;
251        let w = self.cfg.width;
252        let k = self.cfg.kernel;
253        let pad = self.cfg.pad();
254        let hw = h * w;
255        let k2 = k * k;
256
257        let mut out = vec![0.0f32; c * hw];
258        for ch in 0..c {
259            let in_base = ch * hw;
260            let ker_base = ch * k2;
261            let bias = self.dw_bias[ch];
262            for oh in 0..h {
263                for ow in 0..w {
264                    let mut acc = bias;
265                    for ki in 0..k {
266                        // Signed input row for this kernel tap.
267                        let ih = oh as isize + ki as isize - pad as isize;
268                        if ih < 0 || ih >= h as isize {
269                            continue;
270                        }
271                        let ih = ih as usize;
272                        for kj in 0..k {
273                            let iw = ow as isize + kj as isize - pad as isize;
274                            if iw < 0 || iw >= w as isize {
275                                continue;
276                            }
277                            let iw = iw as usize;
278                            acc +=
279                                self.dw_kernel[ker_base + ki * k + kj] * x[in_base + ih * w + iw];
280                        }
281                    }
282                    out[in_base + oh * w + ow] = acc;
283                }
284            }
285        }
286        Ok(out)
287    }
288
289    /// Channel LayerNorm: at each spatial position `(h, w)`, gather the `C`
290    /// channel values, normalise to zero-mean / unit-variance (ε = `1e-6`),
291    /// then apply per-channel affine `γ · x̂ + β`.
292    ///
293    /// Input / output are `(C, H, W)` flat.
294    ///
295    /// # Errors
296    /// `DimensionMismatch` if `x.len() != C·H·W`.
297    pub fn channel_layernorm(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
298        self.check_input_len(x)?;
299        let c = self.cfg.channels;
300        let hw = self.cfg.spatial();
301
302        let mut out = vec![0.0f32; c * hw];
303        for p in 0..hw {
304            // Mean across channels at pixel p.
305            let mut mean = 0.0f32;
306            for ch in 0..c {
307                mean += x[ch * hw + p];
308            }
309            mean /= c as f32;
310            // Variance across channels.
311            let mut var = 0.0f32;
312            for ch in 0..c {
313                let d = x[ch * hw + p] - mean;
314                var += d * d;
315            }
316            var /= c as f32;
317            let inv_std = 1.0 / (var + LN_EPS).sqrt();
318            for ch in 0..c {
319                let norm = (x[ch * hw + p] - mean) * inv_std;
320                out[ch * hw + p] = norm * self.ln_gamma[ch] + self.ln_beta[ch];
321            }
322        }
323        Ok(out)
324    }
325
326    /// Apply a 1×1 convolution `W·x + b` independently at every spatial
327    /// position. `weight` is `[out_c, in_c]` row-major (one filter per output
328    /// channel). Input is `(in_c, H, W)`, output is `(out_c, H, W)`.
329    fn pointwise(
330        &self,
331        x: &[f32],
332        weight: &[f32],
333        bias: &[f32],
334        in_c: usize,
335        out_c: usize,
336    ) -> Vec<f32> {
337        let hw = self.cfg.spatial();
338        let mut out = vec![0.0f32; out_c * hw];
339        for p in 0..hw {
340            for oc in 0..out_c {
341                let wrow = &weight[oc * in_c..(oc + 1) * in_c];
342                let mut acc = bias[oc];
343                for ic in 0..in_c {
344                    acc += wrow[ic] * x[ic * hw + p];
345                }
346                out[oc * hw + p] = acc;
347            }
348        }
349        out
350    }
351
352    /// Forward pass: `(C·H·W) → (C·H·W)`.
353    ///
354    /// See the module docs for the full pipeline. When `layer_scale_init` is
355    /// `0.0`, the residual branch is multiplied by zero so `forward(x) == x`.
356    ///
357    /// # Errors
358    /// - `DimensionMismatch` if `x.len() != C·H·W`.
359    /// - `NonFinite` if the output contains non-finite values.
360    pub fn forward(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
361        self.check_input_len(x)?;
362        let c = self.cfg.channels;
363        let hidden = self.cfg.hidden();
364        let hw = self.cfg.spatial();
365
366        // Depthwise conv → channel LayerNorm.
367        let y = self.depthwise_conv(x)?;
368        let y = self.channel_layernorm(&y)?;
369
370        // Inverted bottleneck: 1×1 expand → GELU → 1×1 project.
371        let y = self.pointwise(&y, &self.pw1_weight, &self.pw1_bias, c, hidden);
372        let y: Vec<f32> = y.into_iter().map(gelu).collect();
373        let mut y = self.pointwise(&y, &self.pw2_weight, &self.pw2_bias, hidden, c);
374
375        // Per-channel layer scale.
376        for ch in 0..c {
377            let gamma = self.layer_scale[ch];
378            for p in 0..hw {
379                y[ch * hw + p] *= gamma;
380            }
381        }
382
383        // Residual.
384        let out: Vec<f32> = x.iter().zip(y.iter()).map(|(a, b)| a + b).collect();
385        if out.iter().any(|v| !v.is_finite()) {
386            return Err(VisionError::NonFinite("convnext block output"));
387        }
388        Ok(out)
389    }
390
391    /// Total number of learnable parameters in this block.
392    ///
393    /// ```text
394    /// dw_kernel  : C · k²
395    /// dw_bias    : C
396    /// ln_gamma   : C
397    /// ln_beta    : C
398    /// pw1_weight : hidden · C
399    /// pw1_bias   : hidden
400    /// pw2_weight : C · hidden
401    /// pw2_bias   : C
402    /// layer_scale: C
403    /// ```
404    #[must_use]
405    pub fn n_params(&self) -> usize {
406        let c = self.cfg.channels;
407        let hidden = self.cfg.hidden();
408        let k2 = self.cfg.kernel * self.cfg.kernel;
409        c * k2  // dw_kernel
410            + c // dw_bias
411            + c // ln_gamma
412            + c // ln_beta
413            + hidden * c // pw1_weight
414            + hidden // pw1_bias
415            + c * hidden // pw2_weight
416            + c // pw2_bias
417            + c // layer_scale
418    }
419}
420
// ─── GELU ──────────────────────────────────────────────────────────────────────

/// GELU activation via the tanh approximation:
/// `0.5·v·(1 + tanh(√(2/π)·(v + 0.044715·v³)))`.
///
/// The formula is evaluated literally, so `v = -∞` produces `NaN`
/// (`-∞ · 0`) and `NaN` propagates.
#[inline]
fn gelu(v: f32) -> f32 {
    // √(2/π) written to f32 precision.
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    // Cubic-term coefficient of the tanh approximation.
    const COEFF: f32 = 0.044_715;
    let inner = SQRT_2_OVER_PI * (v + COEFF * v * v * v);
    0.5 * v * (1.0 + inner.tanh())
}

// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Shared small config used by most tests.
    fn make_cfg() -> ConvNextConfig {
        // C=8, 6×6 spatial, 3×3 depthwise, 4× expansion, layer scale 1e-6.
        ConvNextConfig::new(8, 6, 6, 3, 4, 1e-6).expect("valid config")
    }

    /// `C·H·W` input filled from a freshly seeded RNG.
    fn random_input(cfg: &ConvNextConfig, seed: u64) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        let mut x = vec![0.0f32; cfg.channels * cfg.spatial()];
        rng.fill_normal(&mut x);
        x
    }

    #[test]
    fn config_derived_quantities() {
        let cfg = make_cfg();
        assert_eq!(cfg.spatial(), 36);
        assert_eq!(cfg.hidden(), 32);
        assert_eq!(cfg.pad(), 1); // (3-1)/2
    }

    #[test]
    fn depthwise_conv_output_length() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(1);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 2);
        let y = block.depthwise_conv(&x).expect("dw");
        assert_eq!(y.len(), cfg.channels * cfg.spatial());
    }

    #[test]
    fn depthwise_identity_kernel_is_input() {
        // Center-1 delta kernel + zero bias ⇒ depthwise conv is the identity.
        let cfg = make_cfg();
        let mut rng = LcgRng::new(3);
        let mut block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let k = cfg.kernel;
        let k2 = k * k;
        let center = (k / 2) * k + (k / 2);
        {
            let ker = block.dw_kernel_mut();
            ker.fill(0.0);
            for ch in 0..cfg.channels {
                ker[ch * k2 + center] = 1.0;
            }
        }
        block.dw_bias_mut().fill(0.0);
        let x = random_input(&cfg, 4);
        let y = block.depthwise_conv(&x).expect("dw");
        for (a, b) in y.iter().zip(x.iter()) {
            assert!((a - b).abs() < 1e-5, "identity kernel mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn channel_layernorm_zero_mean_unit_var() {
        // With γ=1, β=0 (default init), per-pixel channel stats are ~N(0,1).
        let cfg = make_cfg();
        let mut rng = LcgRng::new(5);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 6);
        let y = block.channel_layernorm(&x).expect("ln");
        let c = cfg.channels;
        let hw = cfg.spatial();
        for &p in &[0usize, 7, hw - 1] {
            let mut mean = 0.0f32;
            for ch in 0..c {
                mean += y[ch * hw + p];
            }
            mean /= c as f32;
            let mut var = 0.0f32;
            for ch in 0..c {
                let d = y[ch * hw + p] - mean;
                var += d * d;
            }
            var /= c as f32;
            assert!(mean.abs() < 1e-4, "pixel {p} mean not ~0: {mean}");
            assert!((var - 1.0).abs() < 1e-2, "pixel {p} var not ~1: {var}");
        }
    }

    #[test]
    fn forward_output_length() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(7);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 8);
        let out = block.forward(&x).expect("forward");
        assert_eq!(out.len(), cfg.channels * cfg.spatial());
    }

    #[test]
    fn forward_finite() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(9);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 10);
        let out = block.forward(&x).expect("forward");
        assert!(out.iter().all(|v| v.is_finite()), "non-finite output");
    }

    #[test]
    fn layer_scale_zero_makes_identity() {
        // layer_scale_init = 0 ⇒ residual branch zeroed ⇒ forward(x) == x.
        let cfg = ConvNextConfig::new(8, 6, 6, 3, 4, 0.0).expect("cfg");
        let mut rng = LcgRng::new(11);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 12);
        let out = block.forward(&x).expect("forward");
        for (a, b) in out.iter().zip(x.iter()) {
            assert_eq!(a, b, "zero layer scale must be exact identity");
        }
    }

    #[test]
    fn n_params_formula_matches() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(13);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let c = cfg.channels;
        let hidden = cfg.hidden();
        let k2 = cfg.kernel * cfg.kernel;
        let expected = c * k2 + c + c + c + hidden * c + hidden + c * hidden + c + c;
        assert_eq!(block.n_params(), expected);
    }

    #[test]
    fn kernel_one_works() {
        // 1×1 depthwise (kernel=1, pad=0) is valid and behaves as a per-channel
        // scale-plus-bias.
        let cfg = ConvNextConfig::new(4, 5, 5, 1, 2, 1e-6).expect("cfg");
        assert_eq!(cfg.pad(), 0);
        let mut rng = LcgRng::new(14);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 15);
        let y = block.depthwise_conv(&x).expect("dw");
        assert_eq!(y.len(), cfg.channels * cfg.spatial());
        let out = block.forward(&x).expect("forward");
        assert_eq!(out.len(), cfg.channels * cfg.spatial());
    }

    #[test]
    fn expansion_grows_param_count() {
        let mut rng = LcgRng::new(16);
        let cfg2 = ConvNextConfig::new(8, 4, 4, 3, 2, 1e-6).expect("cfg");
        let cfg4 = ConvNextConfig::new(8, 4, 4, 3, 4, 1e-6).expect("cfg");
        let b2 = ConvNextBlock::new(cfg2, &mut rng).expect("block");
        let b4 = ConvNextBlock::new(cfg4, &mut rng).expect("block");
        assert!(
            b4.n_params() > b2.n_params(),
            "more expansion must mean more params"
        );
    }

    #[test]
    fn gelu_zero_is_zero() {
        assert!(gelu(0.0).abs() < 1e-6);
    }

    #[test]
    fn gelu_large_positive_approx_identity() {
        let v = 10.0f32;
        assert!((gelu(v) - v).abs() < 1e-3, "GELU({v}) = {}", gelu(v));
    }

    #[test]
    fn gelu_large_negative_approx_zero() {
        let v = -10.0f32;
        assert!(gelu(v).abs() < 1e-3, "GELU({v}) = {}", gelu(v));
    }

    #[test]
    fn err_channels_zero() {
        let r = ConvNextConfig::new(0, 6, 6, 3, 4, 1e-6);
        assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
    }

    #[test]
    fn err_kernel_even() {
        let r = ConvNextConfig::new(8, 6, 6, 4, 4, 1e-6);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn err_kernel_zero() {
        let r = ConvNextConfig::new(8, 6, 6, 0, 4, 1e-6);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn err_expansion_zero() {
        let r = ConvNextConfig::new(8, 6, 6, 3, 0, 1e-6);
        assert!(matches!(r, Err(VisionError::Internal(_))));
    }

    #[test]
    fn err_height_zero() {
        let r = ConvNextConfig::new(8, 0, 6, 3, 4, 1e-6);
        assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
    }

    #[test]
    fn err_forward_wrong_length() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(17);
        let block = ConvNextBlock::new(cfg, &mut rng).expect("block");
        let x = vec![0.0f32; 5]; // wrong
        let r = block.forward(&x);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn err_depthwise_wrong_length() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(18);
        let block = ConvNextBlock::new(cfg, &mut rng).expect("block");
        let x = vec![0.0f32; 3]; // wrong
        let r = block.depthwise_conv(&x);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn deterministic_given_seed() {
        let cfg = make_cfg();
        let mut rng_a = LcgRng::new(77);
        let mut rng_b = LcgRng::new(77);
        let block_a = ConvNextBlock::new(cfg.clone(), &mut rng_a).expect("block");
        let block_b = ConvNextBlock::new(cfg.clone(), &mut rng_b).expect("block");
        let x = random_input(&cfg, 78);
        let out_a = block_a.forward(&x).expect("forward");
        let out_b = block_b.forward(&x).expect("forward");
        assert_eq!(out_a, out_b, "same seed must give identical output");
    }
}