use alloc::vec::Vec;
/// Computes the output length of a 1-D convolution along a single axis.
///
/// The span covered by the (dilated) kernel is `dilation * (kernel_size - 1) + 1`,
/// and the input is padded by `padding` on both sides. The result is
/// `floor((input_size + 2 * padding - span) / stride) + 1` when the padded input
/// is at least as long as the kernel span.
///
/// Returns `None` when that count would be zero, i.e. the kernel never fits
/// inside the padded input.
///
/// # Panics
///
/// Panics if `input_size`, `kernel_size`, `stride` or `dilation` is zero.
pub fn maybe_conv1d_output_size(
    input_size: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> Option<usize> {
    assert!(input_size > 0);
    assert!(kernel_size > 0);
    assert!(stride > 0);
    assert!(dilation > 0);

    let padded = input_size + 2 * padding;
    let span = dilation * (kernel_size - 1) + 1;

    // Adding `stride` before subtracting folds the trailing `+ 1` of the formula
    // into the division and keeps the subtraction non-negative whenever a
    // result is possible.
    let windows = (padded + stride).checked_sub(span)? / stride;

    (windows > 0).then_some(windows)
}
/// Like [`maybe_conv1d_output_size`], but panics instead of returning `None`.
///
/// # Panics
///
/// Panics (listing every parameter in the message) when no output position
/// exists for the given configuration. It also panics under the same
/// conditions as [`maybe_conv1d_output_size`].
pub fn expect_conv1d_output_size(
    input_size: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> usize {
    maybe_conv1d_output_size(input_size, kernel_size, stride, padding, dilation).unwrap_or_else(
        || {
            panic!(
                "No legal output size for conv1d with:\n input_size:{input_size}\n kernel_size:{kernel_size}\n stride:{stride}\n dilation:{dilation}\n padding:{padding}",
            )
        },
    )
}
/// Computes the per-axis convolution output shape for runtime-ranked inputs.
///
/// Each axis is handled independently by [`maybe_conv1d_output_size`]. The
/// result is `None` as soon as any axis has no legal output size; later axes
/// are then not evaluated.
///
/// # Panics
///
/// Panics if `kernel_shape`, `stride`, `dilation` or `padding` has a different
/// length than `input_shape`. It also panics under the same per-axis conditions
/// as [`maybe_conv1d_output_size`].
pub fn maybe_conv_output_shape_dyn(
    input_shape: &[usize],
    kernel_shape: &[usize],
    stride: &[usize],
    padding: &[usize],
    dilation: &[usize],
) -> Option<Vec<usize>> {
    let rank = input_shape.len();
    assert_eq!(kernel_shape.len(), rank);
    assert_eq!(stride.len(), rank);
    assert_eq!(dilation.len(), rank);
    assert_eq!(padding.len(), rank);

    // Collecting an iterator of `Option`s into `Option<Vec<_>>` stops at the
    // first `None`, mirroring an early `?` return.
    (0..rank)
        .map(|axis| {
            maybe_conv1d_output_size(
                input_shape[axis],
                kernel_shape[axis],
                stride[axis],
                padding[axis],
                dilation[axis],
            )
        })
        .collect()
}
/// Like [`maybe_conv_output_shape_dyn`], but panics instead of returning `None`.
///
/// # Panics
///
/// Panics (listing every parameter in the message) when any axis has no legal
/// output size. It also panics under the same conditions as
/// [`maybe_conv_output_shape_dyn`].
pub fn expect_conv_output_shape_dyn(
    input_shape: &[usize],
    kernel_shape: &[usize],
    stride: &[usize],
    padding: &[usize],
    dilation: &[usize],
) -> Vec<usize> {
    maybe_conv_output_shape_dyn(input_shape, kernel_shape, stride, padding, dilation)
        .unwrap_or_else(|| {
            panic!(
                "No legal output size for conv with:\n input_shape:{input_shape:?}\n kernel_shape:{kernel_shape:?}\n stride:{stride:?}\n dilation:{dilation:?}\n padding:{padding:?}",
            )
        })
}
/// Computes the per-axis convolution output shape for a compile-time rank `D`.
///
/// Each axis is handled independently by [`maybe_conv1d_output_size`]. The
/// result is `None` as soon as any axis has no legal output size; later axes
/// are then not evaluated.
///
/// # Panics
///
/// Panics under the same per-axis conditions as [`maybe_conv1d_output_size`].
pub fn maybe_conv_output_shape<const D: usize>(
    input_shape: [usize; D],
    kernel_shape: [usize; D],
    stride: [usize; D],
    padding: [usize; D],
    dilation: [usize; D],
) -> Option<[usize; D]> {
    let mut shape = [0usize; D];
    for (axis, extent) in shape.iter_mut().enumerate() {
        *extent = maybe_conv1d_output_size(
            input_shape[axis],
            kernel_shape[axis],
            stride[axis],
            padding[axis],
            dilation[axis],
        )?;
    }
    Some(shape)
}
/// Like [`maybe_conv_output_shape`], but panics instead of returning `None`.
///
/// # Panics
///
/// Panics (listing every parameter in the message) when any axis has no legal
/// output size. It also panics under the same conditions as
/// [`maybe_conv_output_shape`].
pub fn expect_conv_output_shape<const D: usize>(
    input_shape: [usize; D],
    kernel_shape: [usize; D],
    stride: [usize; D],
    padding: [usize; D],
    dilation: [usize; D],
) -> [usize; D] {
    maybe_conv_output_shape(input_shape, kernel_shape, stride, padding, dilation).unwrap_or_else(
        || {
            panic!(
                "No legal output size for conv with:\n input_shape:{input_shape:?}\n kernel_shape:{kernel_shape:?}\n stride:{stride:?}\n dilation:{dilation:?}\n padding:{padding:?}",
            )
        },
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exhaustively compares `maybe_conv1d_output_size` against a floating-point
    /// reference over small parameter ranges.
    #[test]
    fn test_conv1d_output_shape() {
        // Reference formula: floor((input + 2*padding - span) / stride + 1) with
        // span = 1 + dilation * (kernel - 1); results below 1 map to `None`.
        // All values are small integers, so the f64 arithmetic is exact.
        fn conv1d_output_size_reference(
            input_shape: usize,
            kernel_shape: usize,
            stride: usize,
            padding: usize,
            dilation: usize,
        ) -> Option<usize> {
            let input_shape = input_shape as f64;
            let kernel_shape = kernel_shape as f64;
            let stride = stride as f64;
            let dilation = dilation as f64;
            let padding = padding as f64;
            let effective_shape = input_shape + 2.0 * padding;
            let kernel_width = 1.0 + dilation * (kernel_shape - 1.0);
            let x = (((effective_shape - kernel_width) / stride) + 1.0).floor();
            if x < 1.0 { None } else { Some(x as usize) }
        }

        // Dilation values above 1 are included so the `dilation * (kernel - 1)`
        // term is actually exercised.
        for input_shape in 1..10 {
            for stride in 1..4 {
                for kernel_shape in 1..5 {
                    for dilation in 1..4 {
                        for padding in 0..10 {
                            assert_eq!(
                                maybe_conv1d_output_size(
                                    input_shape,
                                    kernel_shape,
                                    stride,
                                    padding,
                                    dilation,
                                ),
                                conv1d_output_size_reference(
                                    input_shape,
                                    kernel_shape,
                                    stride,
                                    padding,
                                    dilation,
                                ),
                                "input={input_shape} kernel={kernel_shape} stride={stride} \
                                 padding={padding} dilation={dilation}",
                            );
                        }
                    }
                }
            }
        }
    }

    /// The const-generic and slice-based shape helpers agree with per-axis results.
    #[test]
    fn test_conv_output_shape_matches_per_axis() {
        let input = [7, 5, 9];
        let kernel = [3, 2, 4];
        let stride = [1, 2, 3];
        let padding = [0, 1, 2];
        let dilation = [1, 2, 1];
        let expected: [usize; 3] = [5, 3, 4];

        assert_eq!(
            expect_conv_output_shape(input, kernel, stride, padding, dilation),
            expected
        );
        assert_eq!(
            expect_conv_output_shape_dyn(&input, &kernel, &stride, &padding, &dilation),
            expected
        );
    }

    /// A single illegal axis makes the whole shape `None`.
    #[test]
    fn test_conv_output_shape_none_when_any_axis_is_illegal() {
        // Second axis: padded input length 2 is shorter than the kernel span 5.
        assert_eq!(
            maybe_conv_output_shape([4, 2], [3, 5], [1, 1], [0, 0], [1, 1]),
            None
        );
        assert_eq!(
            maybe_conv_output_shape_dyn(&[4, 2], &[3, 5], &[1, 1], &[0, 0], &[1, 1]),
            None
        );
    }

    /// The `expect_` variant panics with a descriptive message on illegal input.
    #[test]
    #[should_panic(expected = "No legal output size for conv1d")]
    fn test_expect_conv1d_output_size_panics_on_illegal_shape() {
        let _ = expect_conv1d_output_size(2, 5, 1, 0, 1);
    }
}