/// Computes the output length of a single spatial axis of a 3-D convolution.
///
/// The axis is padded by `pad` elements on *both* sides. The result is
/// `floor((input + 2 * pad - kernel) / stride) + 1`.
///
/// Returns:
/// - `None` when `stride` or `kernel` is zero, or when `input + 2 * pad`
///   overflows `usize`.
/// - `Some(0)` when the padded input is shorter than the kernel, meaning no
///   window fits.
pub fn conv3d_output_dim(input: usize, kernel: usize, pad: usize, stride: usize) -> Option<usize> {
    // A zero-sized kernel or a zero step has no meaningful output size.
    if kernel == 0 || stride == 0 {
        return None;
    }

    // Total padding is applied on both ends of the axis; any overflow yields `None`.
    let padded_len = pad
        .checked_mul(2)
        .and_then(|total_pad| total_pad.checked_add(input))?;

    // `checked_sub` fails exactly when the kernel is longer than the padded axis,
    // in which case no window fits and the output is empty.
    let windows = padded_len
        .checked_sub(kernel)
        .map_or(0, |slack| slack / stride + 1);
    Some(windows)
}
/// Input extents, kernel extents, padding and strides for computing the output
/// shape of a 3-D convolution with [`conv3d_output_shape`].
///
/// Tuple-valued fields are ordered `(d, h, w)`, matching the scalar fields.
/// `Default` is intentionally not derived: an all-zero stride is rejected by
/// [`conv3d_output_dim`], so there is no meaningful default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Conv3dOutputShapeArgs {
    /// Input extent along the first spatial axis (`d`).
    pub d: usize,
    /// Input extent along the second spatial axis (`h`).
    pub h: usize,
    /// Input extent along the third spatial axis (`w`).
    pub w: usize,
    /// Kernel extent along the `d` axis; must be non-zero.
    pub kd: usize,
    /// Kernel extent along the `h` axis; must be non-zero.
    pub kh: usize,
    /// Kernel extent along the `w` axis; must be non-zero.
    pub kw: usize,
    /// Per-axis padding `(d, h, w)`, applied to both sides of each axis.
    pub pad: (usize, usize, usize),
    /// Per-axis stride `(d, h, w)`; each component must be non-zero.
    pub stride: (usize, usize, usize),
}
pub fn conv3d_output_shape(args: Conv3dOutputShapeArgs) -> Option<(usize, usize, usize)> {
Some((
conv3d_output_dim(args.d, args.kd, args.pad.0, args.stride.0)?,
conv3d_output_dim(args.h, args.kh, args.pad.1, args.stride.1)?,
conv3d_output_dim(args.w, args.kw, args.pad.2, args.stride.2)?,
))
}