use rten_gemm::{ColOffsets, Im2Col, RowOffsets};
use rten_tensor::NdTensorView;
use rten_tensor::prelude::*;
use crate::ops::Padding;
use crate::ops::pooling::{RoundMode, calc_output_size_and_padding};
/// Build the row/column offset tables for a virtual im2col matrix over `image`.
///
/// `image` is destructured as `[channels, height, width]`. The resulting
/// [`Im2Col`] reports `channels * kernel_h * kernel_w` rows and
/// `y_patches * x_patches` columns, where the patch counts come from
/// `calc_output_size_and_padding`.
///
/// * `kernel` - `[height, width]` of the kernel.
/// * `padding` - `[top, left, bottom, right]`. Only `top` and `left` are used
///   directly here; the whole array is forwarded to the output size
///   calculation.
/// * `strides` - `[vertical, horizontal]` step between patches.
/// * `dilations` - `[vertical, horizontal]` spacing between kernel taps.
/// * `col_count_step` - the column offset tables are padded to a multiple of
///   this value.
/// * `row_count_step` - the row offset tables are padded to a multiple of
///   this value.
///
/// Padding rows get `x`/`y` offsets one past the maximum valid offset
/// (presumably so that consumers treat them as out of bounds - confirm against
/// the code that reads `Im2Col`). Padding columns simply continue the
/// row-major patch numbering past the last real patch.
///
/// # Panics
///
/// Panics if `image` is empty, if the output size calculation fails, if an
/// image stride or the maximum row/column offset does not fit in an `i32`, or
/// if either step argument is zero.
pub fn build_im2col<T>(
    image: NdTensorView<T, 3>,
    kernel: [usize; 2],
    padding: [usize; 4],
    strides: [usize; 2],
    dilations: [usize; 2],
    col_count_step: usize,
    row_count_step: usize,
) -> Im2Col<T> {
    // Append `count` copies of `value` to the end of `dest`.
    fn push_repeated(dest: &mut Vec<i32>, value: i32, count: usize) {
        dest.resize(dest.len() + count, value);
    }

    // The maximum-offset computation below subtracts one from the image
    // height and width, so every dimension must be non-zero.
    assert!(image.len() > 0);

    let [chans, h, w] = image.shape();
    let [k_h, k_w] = kernel;
    let [stride_h, stride_w] = strides;
    let [dilation_y, dilation_x] = dilations;
    let [pad_top, pad_left, ..] = padding;

    let (y_patches, x_patches, _) = calc_output_size_and_padding(
        (h, w),
        (k_h, k_w),
        (stride_h, stride_w),
        Padding::Fixed(padding.into()),
        Some((dilation_y, dilation_x)),
        RoundMode::default(),
    )
    .expect("invalid im2col params");

    // All offsets are stored as `i32`, so the element strides must fit too.
    let [im_stride_c, im_stride_h, im_stride_w]: [i32; 3] =
        image.strides().map(|stride| i32::try_from(stride).unwrap());

    // Row tables: one entry per (channel, kernel row, kernel column), with the
    // kernel column varying fastest.
    let n_rows = chans * k_h * k_w;
    let n_rows_padded = n_rows.next_multiple_of(row_count_step);

    let mut row_chan_offsets = Vec::<i32>::with_capacity(n_rows_padded);
    let mut row_y_offsets = Vec::<i32>::with_capacity(n_rows_padded);
    let mut row_x_offsets = Vec::<i32>::with_capacity(n_rows_padded);

    for chan in 0..chans {
        push_repeated(&mut row_chan_offsets, chan as i32 * im_stride_c, k_h * k_w);
        for k_y in 0..k_h {
            push_repeated(
                &mut row_y_offsets,
                im_stride_h * k_y as i32 * dilation_y as i32,
                k_w,
            );
            for k_x in 0..k_w as i32 {
                row_x_offsets.push(im_stride_w * k_x * dilation_x as i32);
            }
        }
    }

    // Largest valid offsets along the Y and X axes of the image.
    let max_y_offset: i32 = ((h - 1) * image.stride(1))
        .try_into()
        .expect("invalid im2col params");
    let max_x_offset: i32 = ((w - 1) * image.stride(2))
        .try_into()
        .expect("invalid im2col params");

    // Padding rows use offsets one past the valid maximum. The closures keep
    // the `+ 1` lazy so it is only evaluated when padding rows exist.
    row_chan_offsets.extend((n_rows..n_rows_padded).map(|_| 0));
    row_x_offsets.extend((n_rows..n_rows_padded).map(|_| max_x_offset + 1));
    row_y_offsets.extend((n_rows..n_rows_padded).map(|_| max_y_offset + 1));

    // Column tables: one entry per output patch in row-major order.
    let n_cols = x_patches * y_patches;
    let n_cols_padded = n_cols.next_multiple_of(col_count_step);

    let mut col_y_offsets = Vec::<i32>::with_capacity(n_cols_padded);
    let mut col_x_offsets = Vec::<i32>::with_capacity(n_cols_padded);

    // Image offset of the top-left corner of a patch along each axis. These
    // can be negative when the patch starts inside the top/left padding.
    let y_offset_of = |patch_y: i32| (patch_y * stride_h as i32 - pad_top as i32) * im_stride_h;
    let x_offset_of = |patch_x: i32| (patch_x * stride_w as i32 - pad_left as i32) * im_stride_w;

    for patch_y in 0..y_patches {
        push_repeated(&mut col_y_offsets, y_offset_of(patch_y as i32), x_patches);
        col_x_offsets.extend((0..x_patches).map(|patch_x| x_offset_of(patch_x as i32)));
    }

    // Padding columns continue the patch numbering beyond the last real patch.
    let patches_per_row = x_patches as i32;
    for col in (n_cols..n_cols_padded).map(|col| col as i32) {
        col_y_offsets.push(y_offset_of(col / patches_per_row));
        col_x_offsets.push(x_offset_of(col % patches_per_row));
    }

    Im2Col {
        image,
        n_rows,
        n_cols,
        row_offsets: RowOffsets {
            chan: row_chan_offsets,
            y: row_y_offsets,
            x: row_x_offsets,
        },
        col_offsets: ColOffsets {
            y: col_y_offsets,
            x: col_x_offsets,
        },
        max_y_offset,
        max_x_offset,
    }
}