/// Reads one sample from a zero-padded view of a channel-major input buffer.
///
/// `ih` / `iw` are coordinates in the padded frame, i.e. offset by `ph` rows
/// and `pw` columns relative to the stored data. Coordinates that land in the
/// padding border (before the data or at/after `in_h` / `in_w` once the
/// padding is removed) yield `0.0`; otherwise the element at
/// `c * in_h * in_w + row * in_w + col` of `input` is returned.
#[inline]
fn sample_padded(
    input: &[f32],
    c: usize,
    ih: usize,
    iw: usize,
    in_h: usize,
    in_w: usize,
    ph: usize,
    pw: usize,
) -> f32 {
    // `checked_sub` returns `None` exactly when the coordinate lies in the
    // leading padding; the guard rejects the trailing padding.
    match (ih.checked_sub(ph), iw.checked_sub(pw)) {
        (Some(src_row), Some(src_col)) if src_row < in_h && src_col < in_w => {
            input[c * in_h * in_w + src_row * in_w + src_col]
        }
        _ => 0.0,
    }
}
/// Fills a single row of the column matrix for one channel / kernel offset.
///
/// `col` is treated as row-major with `col_w` entries per row. For every
/// output position `(oh, ow)` the entry `row * col_w + oh * out_w + ow` is set
/// to the value `sample_padded` returns for channel `c` at padded coordinates
/// `(oh * sh + y, ow * sw + x)`.
#[inline]
fn fill_col_row(
    col: &mut [f32],
    input: &[f32],
    row: usize,
    col_w: usize,
    c: usize,
    y: usize,
    x: usize,
    out_h: usize,
    out_w: usize,
    sh: usize,
    sw: usize,
    in_h: usize,
    in_w: usize,
    ph: usize,
    pw: usize,
) {
    let row_base = row * col_w;
    // Visit output positions in row-major order: all `ow` for each `oh`.
    let positions = (0..out_h).flat_map(|oh| (0..out_w).map(move |ow| (oh, ow)));
    for (oh, ow) in positions {
        let dst = row_base + oh * out_w + ow;
        col[dst] = sample_padded(input, c, oh * sh + y, ow * sw + x, in_h, in_w, ph, pw);
    }
}
/// Unfolds a channel-major 2-D input into a column matrix (im2col).
///
/// `input` is read as `in_c` channels of `in_h * in_w` elements each, with
/// `ph` / `pw` rows / columns of implicit zero padding on every side.
///
/// Returns `(col, col_h, col_w)` where `col_h = in_c * kh * kw`,
/// `col_w = out_h * out_w` and `col` is row-major with `col_h * col_w`
/// entries. Row `c * kh * kw + y * kw + x` holds, for each output position
/// `oh * out_w + ow`, the sample at padded coordinates
/// `(oh * sh + y, ow * sw + x)` of channel `c`.
///
/// # Panics
///
/// Panics if either stride is zero, or if the kernel is larger than the
/// padded input in either dimension. May also panic on out-of-bounds indexing
/// if `input` is shorter than the positions that are read.
#[must_use]
pub(crate) fn im2col_2d(
    input: &[f32],
    in_c: usize,
    in_h: usize,
    in_w: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
) -> (Vec<f32>, usize, usize) {
    // Validate up front: otherwise a zero stride is a bare division-by-zero
    // panic, and an oversized kernel underflows the subtraction below
    // (overflow panic in debug builds, a wrapped, enormous size in release).
    assert!(
        sh > 0 && sw > 0,
        "im2col_2d: strides must be non-zero (got {}x{})",
        sh,
        sw
    );
    let padded_h = in_h + 2 * ph;
    let padded_w = in_w + 2 * pw;
    assert!(
        kh <= padded_h && kw <= padded_w,
        "im2col_2d: kernel size {}x{} exceeds padded input {}x{}",
        kh,
        kw,
        padded_h,
        padded_w
    );

    let out_h = (padded_h - kh) / sh + 1;
    let out_w = (padded_w - kw) / sw + 1;
    let col_h = in_c * kh * kw;
    let col_w = out_h * out_w;
    let mut col = vec![0.0f32; col_h * col_w];

    // One row of `col` per (channel, kernel row, kernel column) triple.
    for c in 0..in_c {
        for y in 0..kh {
            for x in 0..kw {
                let row = c * kh * kw + y * kw + x;
                fill_col_row(
                    &mut col, input, row, col_w, c, y, x, out_h, out_w, sh, sw, in_h, in_w, ph, pw,
                );
            }
        }
    }
    (col, col_h, col_w)
}
/// Unfolds a channel-major 1-D input into a column matrix (im2col).
///
/// `input` is read as `in_c` channels of `in_l` elements each, with `p`
/// elements of implicit zero padding on both ends.
///
/// Returns `(col, col_h, col_w)` where `col_h = in_c * k`,
/// `col_w = (in_l + 2 * p - k) / s + 1` and `col` is row-major with
/// `col_h * col_w` entries. Row `c * k + ki` holds, for each output position
/// `ol`, the sample at padded coordinate `ol * s + ki` of channel `c`.
///
/// # Panics
///
/// Panics if `s` is zero, or if `k` is larger than the padded input length
/// `in_l + 2 * p`. May also panic on out-of-bounds indexing if `input` is
/// shorter than the positions that are read.
#[must_use]
pub(crate) fn im2col_1d(
    input: &[f32],
    in_c: usize,
    in_l: usize,
    k: usize,
    s: usize,
    p: usize,
) -> (Vec<f32>, usize, usize) {
    // Validate up front: otherwise a zero stride is a bare division-by-zero
    // panic, and an oversized kernel underflows the subtraction below
    // (overflow panic in debug builds, a wrapped, enormous size in release).
    assert!(s > 0, "im2col_1d: stride must be non-zero");
    let padded_l = in_l + 2 * p;
    assert!(
        k <= padded_l,
        "im2col_1d: kernel size {} exceeds padded input length {}",
        k,
        padded_l
    );

    let out_l = (padded_l - k) / s + 1;
    let col_h = in_c * k;
    let col_w = out_l;
    let mut col = vec![0.0f32; col_h * col_w];

    // `col_w >= 1` after the checks above, so the chunk size is never zero.
    // When `col_h` is zero the buffer is empty and the loop body never runs.
    for (row, out_row) in col.chunks_exact_mut(col_w).enumerate() {
        let c = row / k;
        let ki = row % k;
        for (ol, slot) in out_row.iter_mut().enumerate() {
            let il = ol * s + ki;
            // Padding positions keep the 0.0 the buffer was initialised with.
            if let Some(src) = il.checked_sub(p) {
                if src < in_l {
                    *slot = input[c * in_l + src];
                }
            }
        }
    }
    (col, col_h, col_w)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Each test spells out the full expected column matrix where practical:
    // one row per (channel, kernel offset), one column per output position.

    #[test]
    fn test_im2col_2d_no_padding() {
        // 1 channel, 3x3 input, 2x2 kernel, stride 1 -> 2x2 output positions.
        let input = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let (col, col_h, col_w) = im2col_2d(&input, 1, 3, 3, 2, 2, 1, 1, 0, 0);
        assert_eq!(col_h, 4); // in_c * kh * kw
        assert_eq!(col_w, 4); // out_h * out_w
        assert_eq!(
            col,
            vec![
                1.0, 2.0, 4.0, 5.0, // kernel offset (0, 0)
                2.0, 3.0, 5.0, 6.0, // kernel offset (0, 1)
                4.0, 5.0, 7.0, 8.0, // kernel offset (1, 0)
                5.0, 6.0, 8.0, 9.0, // kernel offset (1, 1)
            ]
        );
    }

    #[test]
    fn test_im2col_2d_with_padding() {
        // 1 channel, 2x2 input, 3x3 kernel, stride 1, padding 1 -> 2x2 output.
        let input = [1.0, 2.0, 3.0, 4.0];
        let (col, col_h, col_w) = im2col_2d(&input, 1, 2, 2, 3, 3, 1, 1, 1, 1);
        assert_eq!(col_h, 9);
        assert_eq!(col_w, 4);
        // Check the contents too, so zero padding is actually verified.
        assert_eq!(
            col,
            vec![
                0.0, 0.0, 0.0, 1.0, // kernel offset (0, 0)
                0.0, 0.0, 1.0, 2.0, // kernel offset (0, 1)
                0.0, 0.0, 2.0, 0.0, // kernel offset (0, 2)
                0.0, 1.0, 0.0, 3.0, // kernel offset (1, 0)
                1.0, 2.0, 3.0, 4.0, // kernel offset (1, 1)
                2.0, 0.0, 4.0, 0.0, // kernel offset (1, 2)
                0.0, 3.0, 0.0, 0.0, // kernel offset (2, 0)
                3.0, 4.0, 0.0, 0.0, // kernel offset (2, 1)
                4.0, 0.0, 0.0, 0.0, // kernel offset (2, 2)
            ]
        );
    }

    #[test]
    fn test_im2col_2d_stride_2() {
        // 1 channel, 4x4 input holding 1..=16, 2x2 kernel, stride 2 -> 2x2 output.
        let input: Vec<f32> = (1..=16).map(|x| x as f32).collect();
        let (col, col_h, col_w) = im2col_2d(&input, 1, 4, 4, 2, 2, 2, 2, 0, 0);
        assert_eq!(col_h, 4);
        assert_eq!(col_w, 4);
        assert_eq!(
            col,
            vec![
                1.0, 3.0, 9.0, 11.0, // kernel offset (0, 0)
                2.0, 4.0, 10.0, 12.0, // kernel offset (0, 1)
                5.0, 7.0, 13.0, 15.0, // kernel offset (1, 0)
                6.0, 8.0, 14.0, 16.0, // kernel offset (1, 1)
            ]
        );
    }

    #[test]
    fn test_im2col_2d_multi_channel() {
        // 2 channels of 2x2, 2x2 kernel -> a single output position.
        let input = [
            1.0, 2.0, 3.0, 4.0, // channel 0
            5.0, 6.0, 7.0, 8.0, // channel 1
        ];
        let (col, col_h, col_w) = im2col_2d(&input, 2, 2, 2, 2, 2, 1, 1, 0, 0);
        assert_eq!(col_h, 8);
        assert_eq!(col_w, 1);
        assert_eq!(col, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_im2col_1d_no_padding() {
        let input = [1.0, 2.0, 3.0, 4.0, 5.0];
        let (col, col_h, col_w) = im2col_1d(&input, 1, 5, 3, 1, 0);
        assert_eq!(col_h, 3);
        assert_eq!(col_w, 3);
        assert_eq!(col, vec![1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_im2col_1d_with_padding() {
        let input = [1.0, 2.0, 3.0];
        let (col, col_h, col_w) = im2col_1d(&input, 1, 3, 3, 1, 1);
        assert_eq!(col_h, 3);
        assert_eq!(col_w, 3);
        assert_eq!(col, vec![0.0, 1.0, 2.0, 1.0, 2.0, 3.0, 2.0, 3.0, 0.0]);
    }

    #[test]
    fn test_im2col_1d_stride_2() {
        let input = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let (col, col_h, col_w) = im2col_1d(&input, 1, 6, 3, 2, 0);
        assert_eq!(col_h, 3);
        assert_eq!(col_w, 2);
        assert_eq!(col, vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0]);
    }

    #[test]
    fn test_im2col_1d_multi_channel() {
        let input = [
            1.0, 2.0, 3.0, // channel 0
            4.0, 5.0, 6.0, // channel 1
        ];
        let (col, col_h, col_w) = im2col_1d(&input, 2, 3, 2, 1, 0);
        assert_eq!(col_h, 4);
        assert_eq!(col_w, 2);
        assert_eq!(col, vec![1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 5.0, 6.0]);
    }

    #[test]
    fn test_im2col_2d_1x1_kernel() {
        // A 1x1 kernel with stride 1 reproduces the input unchanged.
        let input = [1.0, 2.0, 3.0, 4.0];
        let (col, col_h, col_w) = im2col_2d(&input, 1, 2, 2, 1, 1, 1, 1, 0, 0);
        assert_eq!(col_h, 1);
        assert_eq!(col_w, 4);
        assert_eq!(col, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_im2col_1d_1x_kernel() {
        // A size-1 kernel with stride 1 reproduces the input unchanged.
        let input = [1.0, 2.0, 3.0];
        let (col, col_h, col_w) = im2col_1d(&input, 1, 3, 1, 1, 0);
        assert_eq!(col_h, 1);
        assert_eq!(col_w, 3);
        assert_eq!(col, vec![1.0, 2.0, 3.0]);
    }
}