1use crate::error::{VisionError, VisionResult};
22use crate::handle::LcgRng;
23
/// Random number generator consumed by [`MbConvBlock::new`] to draw initial weights.
pub type VisionRng = LcgRng;
26
#[inline]
/// ReLU6 activation: limits `x` to the closed interval `[0, 6]`.
///
/// The lower bound is applied first and then the upper bound, so a NaN input
/// (which fails both comparisons) is returned unchanged.
fn relu6(x: f32) -> f32 {
    let floored = if x < 0.0 { 0.0 } else { x };
    if floored > 6.0 {
        6.0
    } else {
        floored
    }
}
34
#[inline]
/// Logistic sigmoid `1 / (1 + e^(-x))`, evaluated without overflow.
///
/// `exp` is only ever applied to `-|x|` (a non-positive value), so it stays in
/// `(0, 1]` and neither branch can produce `inf / inf`.
fn sigmoid(x: f32) -> f32 {
    let decay = (-x.abs()).exp();
    if x >= 0.0 {
        1.0 / (1.0 + decay)
    } else {
        decay / (1.0 + decay)
    }
}
45
/// Hyper-parameters describing a single [`MbConvBlock`].
#[derive(Debug, Clone)]
pub struct MbConvConfig {
    /// Number of channels in each input sample.
    pub in_channels: usize,
    /// Number of channels in each output sample.
    pub out_channels: usize,
    /// Multiplier on `in_channels` giving the hidden (expanded) width.
    pub expand_ratio: usize,
    /// Block stride. Within this module it only decides whether the residual
    /// connection is used (see [`MbConvConfig::has_skip`]).
    pub stride: usize,
    /// Side length of the square depthwise kernel (`kernel_size²` weights per channel).
    pub kernel_size: usize,
    /// Fraction of `in_channels` used as the squeeze-and-excitation bottleneck width.
    pub se_ratio: f32,
    /// Spatial height. In this module it is only echoed back in construction
    /// errors — NOTE(review): confirm whether other callers rely on it.
    pub h: usize,
    /// Spatial width. Same caveat as `h`.
    pub w: usize,
}

impl MbConvConfig {
    /// Width of the hidden layer: `in_channels * expand_ratio`.
    #[must_use]
    pub fn expanded_channels(&self) -> usize {
        let Self {
            in_channels,
            expand_ratio,
            ..
        } = self;
        in_channels * expand_ratio
    }

    /// Width of the squeeze-and-excitation bottleneck.
    ///
    /// Computed as `round(in_channels * se_ratio)` (half away from zero), but never
    /// less than one channel.
    #[must_use]
    pub fn se_channels(&self) -> usize {
        let scaled = self.in_channels as f32 * self.se_ratio;
        std::cmp::max(scaled.round() as usize, 1)
    }

    /// Whether the block adds its input to its output (identity residual).
    ///
    /// That requires a unit stride and identical input/output channel counts.
    #[must_use]
    pub fn has_skip(&self) -> bool {
        let unit_stride = self.stride == 1;
        let same_width = self.in_channels == self.out_channels;
        unit_stride && same_width
    }
}
89
90pub struct MbConvBlock {
101 expand_w: Vec<f32>,
102 expand_b: Vec<f32>,
103 dw_w: Vec<f32>,
104 dw_b: Vec<f32>,
105 se_fc1_w: Vec<f32>,
106 se_fc1_b: Vec<f32>,
107 se_fc2_w: Vec<f32>,
108 se_fc2_b: Vec<f32>,
109 proj_w: Vec<f32>,
110 proj_b: Vec<f32>,
111 config: MbConvConfig,
112 has_skip: bool,
113}
114
115impl MbConvBlock {
116 pub fn new(config: MbConvConfig, rng: &mut VisionRng) -> VisionResult<Self> {
124 if config.in_channels == 0 || config.out_channels == 0 {
125 return Err(VisionError::InvalidImageSize {
126 height: config.h,
127 width: config.w,
128 channels: config.in_channels,
129 });
130 }
131 if config.expand_ratio == 0 {
132 return Err(VisionError::InvalidEmbedDim(0));
133 }
134 if config.se_ratio <= 0.0 {
135 return Err(VisionError::NonPositiveTemperature(config.se_ratio));
136 }
137
138 let in_ch = config.in_channels;
139 let exp_ch = config.expanded_channels();
140 let out_ch = config.out_channels;
141 let k = config.kernel_size;
142 let se_ch = config.se_channels();
143 let has_skip = config.has_skip();
144
145 let xavier = |fan_in: usize, fan_out: usize, rng: &mut VisionRng| -> Vec<f32> {
146 let limit = (6.0_f32 / (fan_in + fan_out) as f32).sqrt();
147 let n = fan_out * fan_in;
148 (0..n)
149 .map(|_| (rng.next_f32() * 2.0 - 1.0) * limit)
150 .collect()
151 };
152
153 let expand_w = xavier(in_ch, exp_ch, rng);
155 let expand_b = vec![0.0_f32; exp_ch];
156
157 let dw_w = xavier(k * k, exp_ch, rng);
159 let dw_b = vec![0.0_f32; exp_ch];
160
161 let se_fc1_w = xavier(exp_ch, se_ch, rng);
163 let se_fc1_b = vec![0.0_f32; se_ch];
164
165 let se_fc2_w = xavier(se_ch, exp_ch, rng);
167 let se_fc2_b = vec![0.0_f32; exp_ch];
168
169 let proj_w = xavier(exp_ch, out_ch, rng);
171 let proj_b = vec![0.0_f32; out_ch];
172
173 Ok(Self {
174 expand_w,
175 expand_b,
176 dw_w,
177 dw_b,
178 se_fc1_w,
179 se_fc1_b,
180 se_fc2_w,
181 se_fc2_b,
182 proj_w,
183 proj_b,
184 config,
185 has_skip,
186 })
187 }
188
189 #[must_use]
191 pub fn has_skip(&self) -> bool {
192 self.has_skip
193 }
194
195 pub fn forward(&self, x: &[f32], batch_size: usize) -> VisionResult<Vec<f32>> {
218 let in_ch = self.config.in_channels;
219 let exp_ch = self.config.expanded_channels();
220 let out_ch = self.config.out_channels;
221 let se_ch = self.config.se_channels();
222 let k = self.config.kernel_size;
223
224 if x.len() != batch_size * in_ch {
225 return Err(VisionError::DimensionMismatch {
226 expected: batch_size * in_ch,
227 got: x.len(),
228 });
229 }
230
231 let mut out = vec![0.0_f32; batch_size * out_ch];
232
233 for b in 0..batch_size {
234 let x_row = &x[b * in_ch..(b + 1) * in_ch];
235
236 let h_exp: Vec<f32> = (0..exp_ch)
238 .map(|i| {
239 let acc = self.expand_b[i]
240 + x_row
241 .iter()
242 .enumerate()
243 .map(|(j, &xj)| self.expand_w[i * in_ch + j] * xj)
244 .sum::<f32>();
245 relu6(acc)
246 })
247 .collect();
248
249 let h_dw: Vec<f32> = (0..exp_ch)
252 .map(|c| {
253 let w_slice = &self.dw_w[c * k * k..(c + 1) * k * k];
254 let w_mean: f32 = w_slice.iter().sum::<f32>() / (k * k) as f32;
255 relu6(h_exp[c] * w_mean + self.dw_b[c])
256 })
257 .collect();
258
259 let pooled = &h_dw;
262
263 let se_h1: Vec<f32> = (0..se_ch)
265 .map(|i| {
266 let acc = self.se_fc1_b[i]
267 + pooled
268 .iter()
269 .enumerate()
270 .map(|(j, &pj)| self.se_fc1_w[i * exp_ch + j] * pj)
271 .sum::<f32>();
272 acc.max(0.0)
273 })
274 .collect();
275
276 let se_gate: Vec<f32> = (0..exp_ch)
278 .map(|i| {
279 let acc = self.se_fc2_b[i]
280 + se_h1
281 .iter()
282 .enumerate()
283 .map(|(j, &sj)| self.se_fc2_w[i * se_ch + j] * sj)
284 .sum::<f32>();
285 sigmoid(acc)
286 })
287 .collect();
288
289 let h_se: Vec<f32> = h_dw
291 .iter()
292 .zip(se_gate.iter())
293 .map(|(&hd, &sg)| hd * sg)
294 .collect();
295
296 let mut y: Vec<f32> = (0..out_ch)
298 .map(|i| {
299 self.proj_b[i]
300 + h_se
301 .iter()
302 .enumerate()
303 .map(|(j, &hj)| self.proj_w[i * exp_ch + j] * hj)
304 .sum::<f32>()
305 })
307 .collect();
308
309 if self.has_skip {
311 for (yi, &xi) in y.iter_mut().zip(x_row.iter()) {
312 *yi += xi;
313 }
314 }
315
316 out[b * out_ch..(b + 1) * out_ch].copy_from_slice(&y);
317 }
318
319 Ok(out)
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Fixed-seed RNG so weight initialisation is reproducible across tests.
    fn rng() -> LcgRng {
        LcgRng::new(42)
    }

    /// 16 → 16 channels, stride 1: a configuration that uses the residual path.
    fn default_config() -> MbConvConfig {
        MbConvConfig {
            in_channels: 16,
            out_channels: 16,
            expand_ratio: 6,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.25,
            h: 8,
            w: 8,
        }
    }

    /// Deterministic pseudo-random input of `batch * channels` values.
    fn make_input(batch: usize, channels: usize, seed: u64) -> Vec<f32> {
        let mut r = LcgRng::new(seed);
        (0..batch * channels).map(|_| r.next_f32()).collect()
    }

    #[test]
    fn output_shape() {
        let cfg = default_config();
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        let x = make_input(4, 16, 1);
        let out = block.forward(&x, 4).expect("forward should succeed");
        assert_eq!(out.len(), 4 * 16);
    }

    #[test]
    fn output_finite() {
        let cfg = default_config();
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        let x = make_input(2, 16, 2);
        let out = block.forward(&x, 2).expect("forward should succeed");
        for (i, &v) in out.iter().enumerate() {
            assert!(v.is_finite(), "out[{i}] = {v}");
        }
    }

    #[test]
    fn expand_ratio_1_works() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 8,
            expand_ratio: 1,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.25,
            h: 4,
            w: 4,
        };
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        let x = make_input(3, 8, 3);
        let out = block.forward(&x, 3).expect("forward should succeed");
        assert_eq!(out.len(), 3 * 8);
    }

    #[test]
    fn has_skip_correct_same_channels() {
        let cfg = default_config();
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        assert!(block.has_skip());
    }

    #[test]
    fn no_skip_different_channels() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 16,
            expand_ratio: 6,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.25,
            h: 4,
            w: 4,
        };
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        assert!(!block.has_skip());
    }

    #[test]
    fn relu6_clamps_at_6() {
        assert!((relu6(10.0) - 6.0).abs() < 1e-7);
        assert!((relu6(-1.0) - 0.0).abs() < 1e-7);
        assert!((relu6(3.0) - 3.0).abs() < 1e-7);
    }

    #[test]
    fn sigmoid_is_bounded_and_symmetric() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-7);
        for &v in &[-1000.0_f32, -5.0, -0.5, 0.5, 5.0, 1000.0] {
            let s = sigmoid(v);
            assert!((0.0..=1.0).contains(&s), "sigmoid({v}) = {s}");
            assert!((s + sigmoid(-v) - 1.0).abs() < 1e-6, "asymmetric at {v}");
        }
    }

    #[test]
    fn batch_size_varies() {
        let cfg = default_config();
        for &bs in &[1_usize, 2, 8] {
            let mut r = LcgRng::new(bs as u64);
            let block = MbConvBlock::new(cfg.clone(), &mut r).expect("new should succeed");
            let x = make_input(bs, 16, bs as u64);
            let out = block.forward(&x, bs).expect("forward should succeed");
            assert_eq!(out.len(), bs * 16);
        }
    }

    #[test]
    fn same_seed_gives_same_output() {
        let x = make_input(2, 16, 7);
        let first = MbConvBlock::new(default_config(), &mut rng())
            .expect("new should succeed")
            .forward(&x, 2)
            .expect("forward should succeed");
        let second = MbConvBlock::new(default_config(), &mut rng())
            .expect("new should succeed")
            .forward(&x, 2)
            .expect("forward should succeed");
        assert_eq!(first, second);
    }

    #[test]
    fn skip_passes_input_through_when_projection_is_zero() {
        let mut block =
            MbConvBlock::new(default_config(), &mut rng()).expect("new should succeed");
        // With zero projection weights (biases start at zero), the branch output
        // is 0, so only the residual contributes.
        block.proj_w.iter_mut().for_each(|w| *w = 0.0);
        let x = make_input(2, 16, 5);
        let out = block.forward(&x, 2).expect("forward should succeed");
        assert_eq!(out, x);
    }

    #[test]
    fn stride_2_config_accepted() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 16,
            expand_ratio: 6,
            stride: 2,
            kernel_size: 5,
            se_ratio: 0.25,
            h: 8,
            w: 8,
        };
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        assert!(!block.has_skip());
        let x = make_input(2, 8, 9);
        let out = block.forward(&x, 2).expect("forward should succeed");
        assert_eq!(out.len(), 2 * 16);
    }

    #[test]
    fn expand_ratio_0_error() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 8,
            expand_ratio: 0,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.25,
            h: 4,
            w: 4,
        };
        let mut r = rng();
        let result = MbConvBlock::new(cfg, &mut r);
        assert!(result.is_err());
    }

    #[test]
    fn se_ratio_zero_error() {
        let cfg = MbConvConfig {
            in_channels: 8,
            out_channels: 8,
            expand_ratio: 6,
            stride: 1,
            kernel_size: 3,
            se_ratio: 0.0,
            h: 4,
            w: 4,
        };
        let mut r = rng();
        let result = MbConvBlock::new(cfg, &mut r);
        assert!(result.is_err());
    }

    #[test]
    fn se_ratio_affects_se_channels() {
        let cfg1 = MbConvConfig {
            se_ratio: 0.25,
            ..default_config()
        };
        let cfg2 = MbConvConfig {
            se_ratio: 0.5,
            ..default_config()
        };
        assert!(cfg1.se_channels() < cfg2.se_channels());
    }

    #[test]
    fn dimension_mismatch_error() {
        let cfg = default_config();
        let mut r = rng();
        let block = MbConvBlock::new(cfg, &mut r).expect("new should succeed");
        let wrong_x = vec![0.0_f32; 2 * 8];
        let result = block.forward(&wrong_x, 2);
        assert!(result.is_err());
    }
}