use crate::tensor::TensorView;
use crate::scratch::Scratch;
/// Argument bundle for `conv5d_backward`, grouping the tensors and
/// hyper-parameters of one 5-D convolution backward pass.
///
/// Shapes (as read by the kernels): `input`/`dinput` are indexed as
/// `[N, C, D, H, W]`, `kernel`/`dkernel` as `[Cout, Cin, KD, KH, KW]`,
/// and `doutput` as `[N, Cout, OD, OH, OW]`.
pub struct Conv5dBackwardArgs<'t> {
/// Forward-pass input activations.
pub input: &'t TensorView<'t>,
/// Forward-pass convolution weights.
pub kernel: &'t TensorView<'t>,
/// Upstream gradient w.r.t. the forward output.
pub doutput: &'t TensorView<'t>,
/// Output: gradient w.r.t. `input`; overwritten by the backward pass.
pub dinput: &'t mut TensorView<'t>,
/// Output: gradient w.r.t. `kernel`; overwritten by the backward pass.
pub dkernel: &'t mut TensorView<'t>,
/// Optional output: per-output-channel bias gradient (must hold at least
/// `Cout` entries when present).
pub dbias: Option<&'t mut [f32]>,
/// Zero-padding along the (depth, height, width) axes.
pub pad: (usize, usize, usize),
/// Stride along the (depth, height, width) axes; all components must be > 0.
pub stride: (usize, usize, usize),
/// Scratch workspace. NOTE(review): the backward pass currently only calls
/// `base_ptr()` and discards the result — confirm what this is for.
pub scratch: &'t mut Scratch<'t>,
}
/// Direct (naive) forward pass of a 5-D convolution.
///
/// `input` is indexed as `[N, C, D, H, W]`, `kernel` as
/// `[Cout, Cin, KD, KH, KW]`, and `output` as `[N, Cout, OD, OH, OW]`.
/// Taps that fall into the zero-padding region contribute nothing.
/// Invalid arguments (a zero stride component, an invalid layout, a channel
/// mismatch, or a bias slice shorter than `Cout`) make the function return
/// silently without writing anything.
pub fn conv5d_forward(
    input: &TensorView<'_>,
    kernel: &TensorView<'_>,
    bias: Option<&[f32]>,
    output: &mut TensorView<'_>,
    pad: (usize, usize, usize),
    stride: (usize, usize, usize),
) {
    // Best-effort validation: bail out quietly on anything inconsistent.
    let (sd, sh, sw) = stride;
    if sd == 0 || sh == 0 || sw == 0 {
        return;
    }
    if !(input.is_valid_layout() && kernel.is_valid_layout() && output.is_valid_layout()) {
        return;
    }
    let batches = input.shape[0];
    let in_c = input.shape[1];
    let (in_d, in_h, in_w) = (input.shape[2], input.shape[3], input.shape[4]);
    let out_c = kernel.shape[0];
    if kernel.shape[1] != in_c {
        return;
    }
    if output.shape[0] != batches || output.shape[1] != out_c {
        return;
    }
    if bias.map_or(false, |b| b.len() < out_c) {
        return;
    }
    let (pd, ph, pw) = (pad.0 as isize, pad.1 as isize, pad.2 as isize);
    let (d_lim, h_lim, w_lim) = (in_d as isize, in_h as isize, in_w as isize);
    for b in 0..batches {
        for oc in 0..out_c {
            // Safe to hoist: `bias.len() >= out_c` was checked above.
            let bias_add = bias.map_or(0.0, |v| v[oc]);
            for z in 0..output.shape[2] {
                for y in 0..output.shape[3] {
                    for x in 0..output.shape[4] {
                        let mut sum = 0f32;
                        for ic in 0..in_c {
                            for fz in 0..kernel.shape[2] {
                                for fy in 0..kernel.shape[3] {
                                    for fx in 0..kernel.shape[4] {
                                        // Map the output tap back to input space;
                                        // negative / overflowing coordinates are padding.
                                        let iz = (z * sd + fz) as isize - pd;
                                        let iy = (y * sh + fy) as isize - ph;
                                        let ix = (x * sw + fx) as isize - pw;
                                        let inside = iz >= 0
                                            && iy >= 0
                                            && ix >= 0
                                            && iz < d_lim
                                            && iy < h_lim
                                            && ix < w_lim;
                                        if !inside {
                                            continue;
                                        }
                                        let Some(src) = input.idx_linear(
                                            b,
                                            ic,
                                            iz as usize,
                                            iy as usize,
                                            ix as usize,
                                        ) else {
                                            continue;
                                        };
                                        let Some(wgt) = kernel.idx_linear(oc, ic, fz, fy, fx)
                                        else {
                                            continue;
                                        };
                                        sum += input.data[src] * kernel.data[wgt];
                                    }
                                }
                            }
                        }
                        // Skip (rather than panic) if the output index is unmappable.
                        if let Some(dst) = output.idx_linear(b, oc, z, y, x) {
                            output.data[dst] = sum + bias_add;
                        }
                    }
                }
            }
        }
    }
}
/// Direct (naive) backward pass of a 5-D convolution.
///
/// Computes `dinput`, `dkernel`, and (when requested) `dbias` from the
/// upstream gradient `doutput`, mirroring the index mapping used by
/// `conv5d_forward`. All gradient buffers are fully overwritten.
///
/// Invalid arguments (a zero stride component, an invalid layout, a shape
/// mismatch, or a `dbias` slice shorter than `Cout`) make the function
/// return silently — WITHOUT touching the gradient buffers. (The previous
/// version zeroed the buffers before validating, so a rejected call
/// destroyed caller data; validation now comes first, matching
/// `conv5d_forward`'s never-write-on-invalid behavior.)
pub fn conv5d_backward(args: Conv5dBackwardArgs<'_>) {
    let Conv5dBackwardArgs {
        input,
        kernel,
        doutput,
        dinput,
        dkernel,
        mut dbias,
        pad,
        stride,
        scratch,
    } = args;
    if stride.0 == 0 || stride.1 == 0 || stride.2 == 0 {
        return;
    }
    if !input.is_valid_layout()
        || !kernel.is_valid_layout()
        || !doutput.is_valid_layout()
        || !dinput.is_valid_layout()
        || !dkernel.is_valid_layout()
    {
        return;
    }
    let n = input.shape[0];
    let c = input.shape[1];
    let d = input.shape[2];
    let h = input.shape[3];
    let w = input.shape[4];
    let cout = kernel.shape[0];
    if kernel.shape[1] != c {
        return;
    }
    if doutput.shape[0] != n || doutput.shape[1] != cout {
        return;
    }
    if dinput.shape != input.shape {
        return;
    }
    if dkernel.shape != kernel.shape {
        return;
    }
    if let Some(b) = dbias.as_ref() {
        if b.len() < cout {
            return;
        }
    }
    // All validation passed — only now is it safe to clobber the outputs.
    // (The full `dbias` slice is cleared, not just the first `cout` entries,
    // preserving the original behavior for oversized slices.)
    for v in dinput.data.iter_mut() {
        *v = 0.0;
    }
    for v in dkernel.data.iter_mut() {
        *v = 0.0;
    }
    if let Some(b) = dbias.as_mut() {
        for x in b.iter_mut() {
            *x = 0.0;
        }
    }
    let pad_d = pad.0 as isize;
    let pad_h = pad.1 as isize;
    let pad_w = pad.2 as isize;
    let d_isize = d as isize;
    let h_isize = h as isize;
    let w_isize = w as isize;
    // dbias: dL/db[co] is the sum of doutput over batch and spatial positions.
    if let Some(b) = dbias.as_mut() {
        for co in 0..cout {
            let mut sum = 0f32;
            for n in 0..n {
                for od in 0..doutput.shape[2] {
                    for oh in 0..doutput.shape[3] {
                        for ow in 0..doutput.shape[4] {
                            let idx = match doutput.idx_linear(n, co, od, oh, ow) {
                                Some(v) => v,
                                None => continue,
                            };
                            sum += doutput.data[idx];
                        }
                    }
                }
            }
            b[co] = sum;
        }
    }
    // NOTE(review): the result of this call is discarded and `scratch` is not
    // used anywhere else in this function — confirm whether `base_ptr()` has a
    // required side effect or whether the scratch argument can be dropped.
    scratch.base_ptr();
    // dkernel: dL/dk[co,ci,kd,kh,kw] = sum over (n, od, oh, ow) of
    // input[n, ci, od*s+kd-pad, ...] * doutput[n, co, od, oh, ow],
    // skipping taps that fell in the zero-padding region.
    for co in 0..cout {
        for ci in 0..c {
            for kd in 0..kernel.shape[2] {
                for kh in 0..kernel.shape[3] {
                    for kw in 0..kernel.shape[4] {
                        let mut acc = 0f32;
                        for n in 0..n {
                            for od in 0..doutput.shape[2] {
                                for oh in 0..doutput.shape[3] {
                                    for ow in 0..doutput.shape[4] {
                                        let id = (od * stride.0 + kd) as isize - pad_d;
                                        let ih = (oh * stride.1 + kh) as isize - pad_h;
                                        let iw = (ow * stride.2 + kw) as isize - pad_w;
                                        if id >= 0
                                            && ih >= 0
                                            && iw >= 0
                                            && id < d_isize
                                            && ih < h_isize
                                            && iw < w_isize
                                        {
                                            let idu = id as usize;
                                            let ihu = ih as usize;
                                            let iwu = iw as usize;
                                            let in_idx =
                                                match input.idx_linear(n, ci, idu, ihu, iwu) {
                                                    Some(v) => v,
                                                    None => continue,
                                                };
                                            let dout_idx =
                                                match doutput.idx_linear(n, co, od, oh, ow) {
                                                    Some(v) => v,
                                                    None => continue,
                                                };
                                            acc += input.data[in_idx] * doutput.data[dout_idx];
                                        }
                                    }
                                }
                            }
                        }
                        if let Some(k_idx) = dkernel.idx_linear(co, ci, kd, kh, kw) {
                            dkernel.data[k_idx] = acc;
                        }
                    }
                }
            }
        }
    }
    // dinput: for each input site, gather the output positions that used it.
    // Inverting id = od*stride + kd - pad gives od = (id + pad - kd) / stride,
    // valid only when non-negative, exactly divisible, and within doutput.
    for n in 0..n {
        for ci in 0..c {
            for id in 0..d {
                for ih in 0..h {
                    for iw in 0..w {
                        let mut acc = 0f32;
                        for co in 0..cout {
                            for kd in 0..kernel.shape[2] {
                                for kh in 0..kernel.shape[3] {
                                    for kw in 0..kernel.shape[4] {
                                        let od_num = (id + pad.0) as isize - kd as isize;
                                        let oh_num = (ih + pad.1) as isize - kh as isize;
                                        let ow_num = (iw + pad.2) as isize - kw as isize;
                                        if od_num < 0 || oh_num < 0 || ow_num < 0 {
                                            continue;
                                        }
                                        let od = od_num as usize;
                                        let oh = oh_num as usize;
                                        let ow = ow_num as usize;
                                        if !od.is_multiple_of(stride.0)
                                            || !oh.is_multiple_of(stride.1)
                                            || !ow.is_multiple_of(stride.2)
                                        {
                                            continue;
                                        }
                                        let od = od / stride.0;
                                        let oh = oh / stride.1;
                                        let ow = ow / stride.2;
                                        if od < doutput.shape[2]
                                            && oh < doutput.shape[3]
                                            && ow < doutput.shape[4]
                                        {
                                            let dout_idx =
                                                match doutput.idx_linear(n, co, od, oh, ow) {
                                                    Some(v) => v,
                                                    None => continue,
                                                };
                                            let k_idx =
                                                match kernel.idx_linear(co, ci, kd, kh, kw) {
                                                    Some(v) => v,
                                                    None => continue,
                                                };
                                            acc += doutput.data[dout_idx] * kernel.data[k_idx];
                                        }
                                    }
                                }
                            }
                        }
                        if let Some(i_idx) = dinput.idx_linear(n, ci, id, ih, iw) {
                            dinput.data[i_idx] = acc;
                        }
                    }
                }
            }
        }
    }
}