1use crate::{
24 error::{VisionError, VisionResult},
25 handle::LcgRng,
26};
27
/// Epsilon added to the per-pixel variance before taking the square root in
/// `ConvNextBlock::channel_layernorm`, so the normalisation stays finite even
/// when every channel at a pixel holds the same value (zero variance).
const LN_EPS: f32 = 1e-6;
30
/// Shape and hyper-parameters of a single [`ConvNextBlock`].
///
/// The fields are public, so a value can be assembled by hand; prefer
/// [`ConvNextConfig::new`], which validates them. [`ConvNextBlock::new`]
/// re-runs that validation on whatever config it is given.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvNextConfig {
    /// Number of channels `C`; the block's output has the same channel count.
    pub channels: usize,
    /// Spatial height `H` of the feature map.
    pub height: usize,
    /// Spatial width `W` of the feature map.
    pub width: usize,
    /// Side length `K` of the square depthwise kernel; must be odd and non-zero.
    pub kernel: usize,
    /// Width multiplier of the pointwise MLP: `hidden = expansion * channels`.
    pub expansion: usize,
    /// Initial value of every per-channel layer-scale coefficient that scales
    /// the residual branch before the skip connection is added.
    pub layer_scale_init: f32,
}
49
50impl ConvNextConfig {
51 pub fn new(
58 channels: usize,
59 height: usize,
60 width: usize,
61 kernel: usize,
62 expansion: usize,
63 layer_scale_init: f32,
64 ) -> VisionResult<Self> {
65 if channels == 0 || height == 0 || width == 0 {
66 return Err(VisionError::InvalidImageSize {
67 height,
68 width,
69 channels,
70 });
71 }
72 if kernel == 0 || kernel % 2 == 0 {
73 return Err(VisionError::InvalidPatchSize {
74 patch_size: kernel,
75 img_size: height,
76 });
77 }
78 if expansion == 0 {
79 return Err(VisionError::Internal("expansion must be >= 1".to_string()));
80 }
81 Ok(Self {
82 channels,
83 height,
84 width,
85 kernel,
86 expansion,
87 layer_scale_init,
88 })
89 }
90
91 #[must_use]
93 #[inline]
94 pub fn spatial(&self) -> usize {
95 self.height * self.width
96 }
97
98 #[must_use]
100 #[inline]
101 pub fn hidden(&self) -> usize {
102 self.expansion * self.channels
103 }
104
105 #[must_use]
107 #[inline]
108 pub fn pad(&self) -> usize {
109 (self.kernel - 1) / 2
110 }
111}
112
/// A ConvNeXt-style residual block over one `C x H x W` feature map stored
/// channel-major, i.e. element `(ch, row, col)` lives at `ch * H * W + row * W + col`.
///
/// [`ConvNextBlock::forward`] computes
/// `x + layer_scale ⊙ pw2(gelu(pw1(layernorm(depthwise_conv(x)))))`.
pub struct ConvNextBlock {
    /// Depthwise kernels: `C` consecutive `K x K` row-major filters.
    dw_kernel: Vec<f32>,
    /// Per-channel bias added by the depthwise convolution (length `C`).
    dw_bias: Vec<f32>,
    /// Per-channel LayerNorm scale (length `C`, initialised to 1).
    ln_gamma: Vec<f32>,
    /// Per-channel LayerNorm shift (length `C`, initialised to 0).
    ln_beta: Vec<f32>,
    /// First pointwise (1x1) projection, `hidden x C`, row-major.
    pw1_weight: Vec<f32>,
    /// Bias of the first pointwise projection (length `hidden`).
    pw1_bias: Vec<f32>,
    /// Second pointwise (1x1) projection, `C x hidden`, row-major.
    pw2_weight: Vec<f32>,
    /// Bias of the second pointwise projection (length `C`).
    pw2_bias: Vec<f32>,
    /// Per-channel scale applied to the residual branch before the skip add (length `C`).
    layer_scale: Vec<f32>,
    /// Validated configuration the buffers above were sized from.
    cfg: ConvNextConfig,
}
140
141impl ConvNextBlock {
142 pub fn new(cfg: ConvNextConfig, rng: &mut LcgRng) -> VisionResult<Self> {
152 let cfg = ConvNextConfig::new(
153 cfg.channels,
154 cfg.height,
155 cfg.width,
156 cfg.kernel,
157 cfg.expansion,
158 cfg.layer_scale_init,
159 )?;
160 let c = cfg.channels;
161 let hidden = cfg.hidden();
162
163 let fill_scaled = |rng: &mut LcgRng, n: usize, sc: f32| -> Vec<f32> {
164 let mut v = vec![0.0f32; n];
165 rng.fill_normal(&mut v);
166 for x in &mut v {
167 *x *= sc;
168 }
169 v
170 };
171
172 let dw_fan_in = cfg.kernel * cfg.kernel;
174 let dw_scale = (2.0 / dw_fan_in as f32).sqrt();
175 let dw_kernel = fill_scaled(rng, c * dw_fan_in, dw_scale);
176 let dw_bias = vec![0.0f32; c];
177
178 let ln_gamma = vec![1.0f32; c];
179 let ln_beta = vec![0.0f32; c];
180
181 let pw1_scale = (2.0 / c as f32).sqrt();
183 let pw1_weight = fill_scaled(rng, hidden * c, pw1_scale);
184 let pw1_bias = vec![0.0f32; hidden];
185
186 let pw2_scale = (2.0 / hidden as f32).sqrt();
188 let pw2_weight = fill_scaled(rng, c * hidden, pw2_scale);
189 let pw2_bias = vec![0.0f32; c];
190
191 let layer_scale = vec![cfg.layer_scale_init; c];
192
193 Ok(Self {
194 dw_kernel,
195 dw_bias,
196 ln_gamma,
197 ln_beta,
198 pw1_weight,
199 pw1_bias,
200 pw2_weight,
201 pw2_bias,
202 layer_scale,
203 cfg,
204 })
205 }
206
207 #[must_use]
209 #[inline]
210 pub fn config(&self) -> &ConvNextConfig {
211 &self.cfg
212 }
213
214 #[inline]
217 pub fn dw_kernel_mut(&mut self) -> &mut [f32] {
218 &mut self.dw_kernel
219 }
220
221 #[inline]
223 pub fn dw_bias_mut(&mut self) -> &mut [f32] {
224 &mut self.dw_bias
225 }
226
227 fn check_input_len(&self, x: &[f32]) -> VisionResult<()> {
229 let expected = self.cfg.channels * self.cfg.spatial();
230 if x.len() != expected {
231 return Err(VisionError::DimensionMismatch {
232 expected,
233 got: x.len(),
234 });
235 }
236 Ok(())
237 }
238
239 pub fn depthwise_conv(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
248 self.check_input_len(x)?;
249 let c = self.cfg.channels;
250 let h = self.cfg.height;
251 let w = self.cfg.width;
252 let k = self.cfg.kernel;
253 let pad = self.cfg.pad();
254 let hw = h * w;
255 let k2 = k * k;
256
257 let mut out = vec![0.0f32; c * hw];
258 for ch in 0..c {
259 let in_base = ch * hw;
260 let ker_base = ch * k2;
261 let bias = self.dw_bias[ch];
262 for oh in 0..h {
263 for ow in 0..w {
264 let mut acc = bias;
265 for ki in 0..k {
266 let ih = oh as isize + ki as isize - pad as isize;
268 if ih < 0 || ih >= h as isize {
269 continue;
270 }
271 let ih = ih as usize;
272 for kj in 0..k {
273 let iw = ow as isize + kj as isize - pad as isize;
274 if iw < 0 || iw >= w as isize {
275 continue;
276 }
277 let iw = iw as usize;
278 acc +=
279 self.dw_kernel[ker_base + ki * k + kj] * x[in_base + ih * w + iw];
280 }
281 }
282 out[in_base + oh * w + ow] = acc;
283 }
284 }
285 }
286 Ok(out)
287 }
288
289 pub fn channel_layernorm(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
298 self.check_input_len(x)?;
299 let c = self.cfg.channels;
300 let hw = self.cfg.spatial();
301
302 let mut out = vec![0.0f32; c * hw];
303 for p in 0..hw {
304 let mut mean = 0.0f32;
306 for ch in 0..c {
307 mean += x[ch * hw + p];
308 }
309 mean /= c as f32;
310 let mut var = 0.0f32;
312 for ch in 0..c {
313 let d = x[ch * hw + p] - mean;
314 var += d * d;
315 }
316 var /= c as f32;
317 let inv_std = 1.0 / (var + LN_EPS).sqrt();
318 for ch in 0..c {
319 let norm = (x[ch * hw + p] - mean) * inv_std;
320 out[ch * hw + p] = norm * self.ln_gamma[ch] + self.ln_beta[ch];
321 }
322 }
323 Ok(out)
324 }
325
326 fn pointwise(
330 &self,
331 x: &[f32],
332 weight: &[f32],
333 bias: &[f32],
334 in_c: usize,
335 out_c: usize,
336 ) -> Vec<f32> {
337 let hw = self.cfg.spatial();
338 let mut out = vec![0.0f32; out_c * hw];
339 for p in 0..hw {
340 for oc in 0..out_c {
341 let wrow = &weight[oc * in_c..(oc + 1) * in_c];
342 let mut acc = bias[oc];
343 for ic in 0..in_c {
344 acc += wrow[ic] * x[ic * hw + p];
345 }
346 out[oc * hw + p] = acc;
347 }
348 }
349 out
350 }
351
352 pub fn forward(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
361 self.check_input_len(x)?;
362 let c = self.cfg.channels;
363 let hidden = self.cfg.hidden();
364 let hw = self.cfg.spatial();
365
366 let y = self.depthwise_conv(x)?;
368 let y = self.channel_layernorm(&y)?;
369
370 let y = self.pointwise(&y, &self.pw1_weight, &self.pw1_bias, c, hidden);
372 let y: Vec<f32> = y.into_iter().map(gelu).collect();
373 let mut y = self.pointwise(&y, &self.pw2_weight, &self.pw2_bias, hidden, c);
374
375 for ch in 0..c {
377 let gamma = self.layer_scale[ch];
378 for p in 0..hw {
379 y[ch * hw + p] *= gamma;
380 }
381 }
382
383 let out: Vec<f32> = x.iter().zip(y.iter()).map(|(a, b)| a + b).collect();
385 if out.iter().any(|v| !v.is_finite()) {
386 return Err(VisionError::NonFinite("convnext block output"));
387 }
388 Ok(out)
389 }
390
391 #[must_use]
405 pub fn n_params(&self) -> usize {
406 let c = self.cfg.channels;
407 let hidden = self.cfg.hidden();
408 let k2 = self.cfg.kernel * self.cfg.kernel;
409 c * k2 + c + c + c + hidden * c + hidden + c * hidden + c + c }
419}
420
/// GELU activation, tanh approximation:
/// `0.5 * v * (1 + tanh(sqrt(2 / pi) * (v + 0.044715 * v^3)))`.
#[inline]
fn gelu(v: f32) -> f32 {
    // sqrt(2 / pi) rounded to f32 precision.
    const SQRT_TWO_OVER_PI: f32 = 0.797_884_6;
    // Coefficient of the cubic correction term.
    const CUBIC_COEFF: f32 = 0.044_715;

    let cubic = CUBIC_COEFF * v * v * v;
    let t = (SQRT_TWO_OVER_PI * (v + cubic)).tanh();
    let half_v = 0.5 * v;
    half_v * (1.0 + t)
}
432
#[cfg(test)]
mod tests {
    // Unit tests for ConvNextConfig validation, the individual ConvNextBlock
    // stages (depthwise conv, channel LayerNorm), the full forward pass,
    // parameter counting and the GELU helper.
    use super::*;

    // Shared valid config: C=8, H=W=6, K=3, expansion 4, tiny layer scale.
    fn make_cfg() -> ConvNextConfig {
        ConvNextConfig::new(8, 6, 6, 3, 4, 1e-6).expect("valid config")
    }

    // Deterministic input of length C * H * W drawn from a seeded LcgRng.
    fn random_input(cfg: &ConvNextConfig, seed: u64) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        let mut x = vec![0.0f32; cfg.channels * cfg.spatial()];
        rng.fill_normal(&mut x);
        x
    }

    #[test]
    fn config_derived_quantities() {
        let cfg = make_cfg();
        // 6 * 6 positions, 4 * 8 hidden units, (3 - 1) / 2 padding.
        assert_eq!(cfg.spatial(), 36);
        assert_eq!(cfg.hidden(), 32);
        assert_eq!(cfg.pad(), 1);
    }

    #[test]
    fn depthwise_conv_output_length() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(1);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 2);
        let y = block.depthwise_conv(&x).expect("dw");
        // Padding keeps the output at C x H x W.
        assert_eq!(y.len(), cfg.channels * cfg.spatial());
    }

    #[test]
    fn depthwise_identity_kernel_is_input() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(3);
        let mut block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let k = cfg.kernel;
        let k2 = k * k;
        // Index of the centre tap inside one row-major K x K filter.
        let center = (k / 2) * k + (k / 2);
        {
            // Replace every filter with a centred delta (identity kernel).
            let ker = block.dw_kernel_mut();
            for v in ker.iter_mut() {
                *v = 0.0;
            }
            for ch in 0..cfg.channels {
                ker[ch * k2 + center] = 1.0;
            }
        }
        for v in block.dw_bias_mut().iter_mut() {
            *v = 0.0;
        }
        let x = random_input(&cfg, 4);
        let y = block.depthwise_conv(&x).expect("dw");
        // A delta kernel with zero bias must reproduce the input.
        for (a, b) in y.iter().zip(x.iter()) {
            assert!((a - b).abs() < 1e-5, "identity kernel mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn channel_layernorm_zero_mean_unit_var() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(5);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 6);
        let y = block.channel_layernorm(&x).expect("ln");
        let c = cfg.channels;
        let hw = cfg.spatial();
        // With freshly initialised gamma = 1 / beta = 0, each sampled pixel's
        // channel vector should have ~zero mean and ~unit (biased) variance.
        for &p in &[0usize, 7, hw - 1] {
            let mut mean = 0.0f32;
            for ch in 0..c {
                mean += y[ch * hw + p];
            }
            mean /= c as f32;
            let mut var = 0.0f32;
            for ch in 0..c {
                let d = y[ch * hw + p] - mean;
                var += d * d;
            }
            var /= c as f32;
            assert!(mean.abs() < 1e-4, "pixel {p} mean not ~0: {mean}");
            assert!((var - 1.0).abs() < 1e-2, "pixel {p} var not ~1: {var}");
        }
    }

    #[test]
    fn forward_output_length() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(7);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 8);
        let out = block.forward(&x).expect("forward");
        // Residual block: output shape equals input shape.
        assert_eq!(out.len(), cfg.channels * cfg.spatial());
    }

    #[test]
    fn forward_finite() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(9);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 10);
        let out = block.forward(&x).expect("forward");
        assert!(out.iter().all(|v| v.is_finite()), "non-finite output");
    }

    #[test]
    fn layer_scale_zero_makes_identity() {
        // layer_scale_init = 0 zeroes the residual branch, so forward must
        // return the input bit-for-bit.
        let cfg = ConvNextConfig::new(8, 6, 6, 3, 4, 0.0).expect("cfg");
        let mut rng = LcgRng::new(11);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 12);
        let out = block.forward(&x).expect("forward");
        for (a, b) in out.iter().zip(x.iter()) {
            assert_eq!(a, b, "zero layer scale must be exact identity");
        }
    }

    #[test]
    fn n_params_formula_matches() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(13);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let c = cfg.channels;
        let hidden = cfg.hidden();
        let k2 = cfg.kernel * cfg.kernel;
        // dw kernel + dw bias + ln gamma + ln beta + pw1 weight + pw1 bias
        // + pw2 weight + pw2 bias + layer scale.
        let expected = c * k2 + c + c + c + hidden * c + hidden + c * hidden + c + c;
        assert_eq!(block.n_params(), expected);
    }

    #[test]
    fn kernel_one_works() {
        // A 1x1 depthwise kernel needs no padding and must still run end to end.
        let cfg = ConvNextConfig::new(4, 5, 5, 1, 2, 1e-6).expect("cfg");
        assert_eq!(cfg.pad(), 0);
        let mut rng = LcgRng::new(14);
        let block = ConvNextBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 15);
        let y = block.depthwise_conv(&x).expect("dw");
        assert_eq!(y.len(), cfg.channels * cfg.spatial());
        let out = block.forward(&x).expect("forward");
        assert_eq!(out.len(), cfg.channels * cfg.spatial());
    }

    #[test]
    fn expansion_grows_param_count() {
        // Both blocks draw from the same rng in sequence; only their parameter
        // counts are compared, not their values.
        let mut rng = LcgRng::new(16);
        let cfg2 = ConvNextConfig::new(8, 4, 4, 3, 2, 1e-6).expect("cfg");
        let cfg4 = ConvNextConfig::new(8, 4, 4, 3, 4, 1e-6).expect("cfg");
        let b2 = ConvNextBlock::new(cfg2, &mut rng).expect("block");
        let b4 = ConvNextBlock::new(cfg4, &mut rng).expect("block");
        assert!(
            b4.n_params() > b2.n_params(),
            "more expansion must mean more params"
        );
    }

    #[test]
    fn gelu_zero_is_zero() {
        assert!(gelu(0.0).abs() < 1e-6);
    }

    #[test]
    fn gelu_large_positive_approx_identity() {
        let v = 10.0f32;
        assert!((gelu(v) - v).abs() < 1e-3, "GELU({v}) = {}", gelu(v));
    }

    #[test]
    fn gelu_large_negative_approx_zero() {
        let v = -10.0f32;
        assert!(gelu(v).abs() < 1e-3, "GELU({v}) = {}", gelu(v));
    }

    // ---- Config validation errors ----

    #[test]
    fn err_channels_zero() {
        let r = ConvNextConfig::new(0, 6, 6, 3, 4, 1e-6);
        assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
    }

    #[test]
    fn err_kernel_even() {
        let r = ConvNextConfig::new(8, 6, 6, 4, 4, 1e-6);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn err_kernel_zero() {
        let r = ConvNextConfig::new(8, 6, 6, 0, 4, 1e-6);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn err_expansion_zero() {
        let r = ConvNextConfig::new(8, 6, 6, 3, 0, 1e-6);
        assert!(matches!(r, Err(VisionError::Internal(_))));
    }

    #[test]
    fn err_height_zero() {
        let r = ConvNextConfig::new(8, 0, 6, 3, 4, 1e-6);
        assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
    }

    // ---- Input length errors ----

    #[test]
    fn err_forward_wrong_length() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(17);
        let block = ConvNextBlock::new(cfg, &mut rng).expect("block");
        let x = vec![0.0f32; 5];
        let r = block.forward(&x);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn err_depthwise_wrong_length() {
        let cfg = make_cfg();
        let mut rng = LcgRng::new(18);
        let block = ConvNextBlock::new(cfg, &mut rng).expect("block");
        let x = vec![0.0f32; 3];
        let r = block.depthwise_conv(&x);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn deterministic_given_seed() {
        // Two rngs with the same seed must yield identical weights, hence
        // identical outputs for the same input.
        let cfg = make_cfg();
        let mut rng_a = LcgRng::new(77);
        let mut rng_b = LcgRng::new(77);
        let block_a = ConvNextBlock::new(cfg.clone(), &mut rng_a).expect("block");
        let block_b = ConvNextBlock::new(cfg.clone(), &mut rng_b).expect("block");
        let x = random_input(&cfg, 78);
        let out_a = block_a.forward(&x).expect("forward");
        let out_b = block_b.forward(&x).expect("forward");
        assert_eq!(out_a, out_b, "same seed must give identical output");
    }
}