/// Geometry parameters for a 3-D convolution, consumed by `conv3d_output_shape`.
///
/// All three triples are forwarded unchanged to the native shape routine.
/// Equality and hashing are derived so callers can compare or cache geometries.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Conv3DShapeParams {
    /// Kernel extent, destructured as `(kd, kh, kw)` by `conv3d_output_shape`.
    pub kernel: (usize, usize, usize),
    /// Per-axis zero padding; presumably in the same `(d, h, w)` order as
    /// `kernel` — TODO confirm against the native implementation.
    pub pad: (usize, usize, usize),
    /// Per-axis stride; presumably in the same `(d, h, w)` order as `kernel`
    /// — TODO confirm against the native implementation.
    pub stride: (usize, usize, usize),
}
pub fn conv3d_output_shape(
d: usize,
h: usize,
w: usize,
params: Conv3DShapeParams,
) -> Option<(usize, usize, usize)> {
let (kd, kh, kw) = params.kernel;
let args = native_neural_network::conv3d::Conv3dOutputShapeArgs {
d,
h,
w,
kd,
kh,
kw,
pad: params.pad,
stride: params.stride,
};
native_neural_network::conv3d::conv3d_output_shape(args)
}
pub fn conv3d_forward(
input: &mut crate::std::tensor_std::TensorStd,
kernel: &mut crate::std::tensor_std::TensorStd,
bias: Option<&[f32]>,
output: &mut crate::std::tensor_std::TensorStd,
pad: (usize, usize, usize),
stride: (usize, usize, usize),
) -> Result<(), &'static str> {
if input.shape.len() != 5 || kernel.shape.len() != 5 || output.shape.len() != 5 {
return Err("tensor rank must be 5");
}
let mut in_shape = [0usize; 5];
for (i, &v) in input.shape.iter().enumerate().take(5) {
in_shape[i] = v;
}
let mut k_shape = [0usize; 5];
for (i, &v) in kernel.shape.iter().enumerate().take(5) {
k_shape[i] = v;
}
let mut out_shape = [0usize; 5];
for (i, &v) in output.shape.iter().enumerate().take(5) {
out_shape[i] = v;
}
let in_view = native_neural_network::tensor::TensorView {
data: input.as_mut_slice(),
shape: in_shape,
};
let k_view = native_neural_network::tensor::TensorView {
data: kernel.as_mut_slice(),
shape: k_shape,
};
let mut out_view = native_neural_network::tensor::TensorView {
data: output.as_mut_slice(),
shape: out_shape,
};
native_neural_network::conv3d::conv3d_forward(
&in_view,
&k_view,
bias,
&mut out_view,
pad,
stride,
);
Ok(())
}
/// Reports whether `input`, `kernel` and `output` agree well enough for a conv3d pass.
///
/// All of the following must hold:
/// - every tensor has exactly 5 dimensions;
/// - `input.shape[1] == kernel.shape[1]`;
/// - `output.shape[0] == input.shape[0]`;
/// - `output.shape[1] == kernel.shape[0]`;
/// - each tensor's `data.len()` equals the product of its shape. The product
///   saturates at `usize::MAX` rather than overflowing.
///
/// Spatial extents, padding and stride are not examined here.
pub fn conv3d_is_compatible(
    input: &crate::std::tensor_std::TensorStd,
    kernel: &crate::std::tensor_std::TensorStd,
    output: &crate::std::tensor_std::TensorStd,
) -> bool {
    let (x, k, y) = (&input.shape, &kernel.shape, &output.shape);
    let all_rank5 = [x.len(), k.len(), y.len()].iter().all(|&rank| rank == 5);
    if !all_rank5 {
        return false;
    }
    let numel = |dims: &[usize]| -> usize {
        dims.iter().fold(1usize, |acc, &dim| acc.saturating_mul(dim))
    };
    let dims_agree = x[1] == k[1] && y[0] == x[0] && y[1] == k[0];
    dims_agree
        && numel(x) == input.data.len()
        && numel(k) == kernel.data.len()
        && numel(y) == output.data.len()
}
pub fn conv3d_layout_compatible(
input: &crate::std::tensor_std::TensorStd,
kernel: &crate::std::tensor_std::TensorStd,
output: &crate::std::tensor_std::TensorStd,
) -> bool {
conv3d_is_compatible(input, kernel, output)
}