/// Kernel, padding and stride triples consumed by [`conv5d_output_shape`].
///
/// `kernel` is unpacked as `(kd, kh, kw)`; `pad` and `stride` are forwarded
/// unchanged to the native output-shape computation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Conv5DShapeParams {
    /// Kernel extents, unpacked as `(kd, kh, kw)`.
    pub kernel: (usize, usize, usize),
    /// Padding triple, forwarded as-is to the native routine.
    pub pad: (usize, usize, usize),
    /// Stride triple, forwarded as-is to the native routine.
    pub stride: (usize, usize, usize),
}
pub fn conv5d_output_shape(
d: usize,
h: usize,
w: usize,
params: Conv5DShapeParams,
) -> Option<(usize, usize, usize)> {
let (kd, kh, kw) = params.kernel;
let args = native_neural_network::conv5d::Conv5dOutputShapeArgs {
d,
h,
w,
kd,
kh,
kw,
pad: params.pad,
stride: params.stride,
};
native_neural_network::conv5d::conv5d_output_shape(args)
}
/// Runs a 5-D convolution forward pass through
/// `native_neural_network::conv5d::conv5d_forward`.
///
/// Each tensor is wrapped in a `TensorView` over its data slice and its
/// 5-element shape. `bias`, `pad` and `stride` are forwarded unchanged, and
/// `output` is passed to the native routine as the mutable destination.
/// `input` and `kernel` are borrowed mutably because their data is obtained
/// with `as_mut_slice()` when building the views.
///
/// # Errors
///
/// Returns `Err("tensor rank must be 5")` if `input`, `kernel` or `output`
/// does not have exactly five dimensions. No other compatibility checks are
/// made here. NOTE(review): the native call's return value is not inspected;
/// if it can report failure, that is not surfaced to the caller.
pub fn conv5d_forward(
    input: &mut crate::std::tensor_std::TensorStd,
    kernel: &mut crate::std::tensor_std::TensorStd,
    bias: Option<&[f32]>,
    output: &mut crate::std::tensor_std::TensorStd,
    pad: (usize, usize, usize),
    stride: (usize, usize, usize),
) -> Result<(), &'static str> {
    /// Copies a rank-5 shape into a fixed-size array, rejecting any other rank.
    fn rank5(shape: &[usize]) -> Result<[usize; 5], &'static str> {
        if shape.len() != 5 {
            return Err("tensor rank must be 5");
        }
        let mut dims = [0usize; 5];
        dims.copy_from_slice(shape);
        Ok(dims)
    }

    // Validate every rank before touching any tensor data.
    let in_shape = rank5(&input.shape)?;
    let k_shape = rank5(&kernel.shape)?;
    let out_shape = rank5(&output.shape)?;

    let in_view = native_neural_network::tensor::TensorView {
        data: input.as_mut_slice(),
        shape: in_shape,
    };
    let k_view = native_neural_network::tensor::TensorView {
        data: kernel.as_mut_slice(),
        shape: k_shape,
    };
    let mut out_view = native_neural_network::tensor::TensorView {
        data: output.as_mut_slice(),
        shape: out_shape,
    };
    native_neural_network::conv5d::conv5d_forward(
        &in_view,
        &k_view,
        bias,
        &mut out_view,
        pad,
        stride,
    );
    Ok(())
}
/// Extra arguments for [`conv5d_backward`] beyond the tensors themselves.
#[derive(Debug)]
pub struct Conv5DBackwardParams<'a> {
/// Optional mutable buffer for the bias gradient; forwarded as-is to the
/// native backward routine (`None` is passed through unchanged).
pub dbias: Option<&'a mut [f32]>,
/// Padding triple, forwarded unchanged (presumably one entry per d/h/w axis — confirm).
pub pad: (usize, usize, usize),
/// Stride triple, forwarded unchanged (presumably one entry per d/h/w axis — confirm).
pub stride: (usize, usize, usize),
/// Caller-owned byte buffer wrapped in `native_neural_network::scratch::Scratch`
/// for the native call. NOTE(review): presumably temporary workspace; the
/// required size is not checked here — confirm against the native routine.
pub scratch_bytes: &'a mut [u8],
}
/// Runs a 5-D convolution backward pass through
/// `native_neural_network::conv5d::conv5d_backward`.
///
/// `input`, `kernel` and `doutput` are passed as read views. `dinput`,
/// `dkernel` and `params.dbias` (when `Some`) are passed as mutable
/// destinations. `params.pad` and `params.stride` are forwarded unchanged, and
/// `params.scratch_bytes` is wrapped in a native `Scratch`.
///
/// # Errors
///
/// Returns `Err("tensor rank must be 5")` if any of the five tensors does not
/// have exactly five dimensions. No other compatibility checks are made here.
/// NOTE(review): the native call's return value is not inspected; if it can
/// report failure, that is not surfaced to the caller.
pub fn conv5d_backward(
    input: &mut crate::std::tensor_std::TensorStd,
    kernel: &mut crate::std::tensor_std::TensorStd,
    doutput: &crate::std::tensor_std::TensorStd,
    dinput: &mut crate::std::tensor_std::TensorStd,
    dkernel: &mut crate::std::tensor_std::TensorStd,
    params: Conv5DBackwardParams,
) -> Result<(), &'static str> {
    /// Copies a rank-5 shape into a fixed-size array, rejecting any other rank.
    fn rank5(shape: &[usize]) -> Result<[usize; 5], &'static str> {
        if shape.len() != 5 {
            return Err("tensor rank must be 5");
        }
        let mut dims = [0usize; 5];
        dims.copy_from_slice(shape);
        Ok(dims)
    }

    // Validate every rank before touching any tensor data.
    let in_shape = rank5(&input.shape)?;
    let k_shape = rank5(&kernel.shape)?;
    let dout_shape = rank5(&doutput.shape)?;
    let din_shape = rank5(&dinput.shape)?;
    let dk_shape = rank5(&dkernel.shape)?;

    let in_view = native_neural_network::tensor::TensorView {
        data: input.as_mut_slice(),
        shape: in_shape,
    };
    let k_view = native_neural_network::tensor::TensorView {
        data: kernel.as_mut_slice(),
        shape: k_shape,
    };
    // The views are built over mutable slices, but `doutput` is only borrowed
    // immutably, so its data is copied into a temporary buffer. This costs one
    // allocation of `doutput`'s length per call.
    let mut dout_buf = doutput.as_slice().to_vec();
    let dout_view = native_neural_network::tensor::TensorView {
        data: &mut dout_buf[..],
        shape: dout_shape,
    };
    let mut din_view = native_neural_network::tensor::TensorView {
        data: dinput.as_mut_slice(),
        shape: din_shape,
    };
    let mut dk_view = native_neural_network::tensor::TensorView {
        data: dkernel.as_mut_slice(),
        shape: dk_shape,
    };
    let mut scratch = native_neural_network::scratch::Scratch::new(params.scratch_bytes);
    let args = native_neural_network::conv5d::Conv5dBackwardArgs {
        input: &in_view,
        kernel: &k_view,
        doutput: &dout_view,
        dinput: &mut din_view,
        dkernel: &mut dk_view,
        dbias: params.dbias,
        pad: params.pad,
        stride: params.stride,
        scratch: &mut scratch,
    };
    native_neural_network::conv5d::conv5d_backward(args);
    Ok(())
}