use crate::DType;
use numr::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// How the result of a convolution/correlation is trimmed relative to the full
/// linear convolution, whose length is `signal_len + kernel_len - 1`.
///
/// If either operand is empty, every mode yields an empty output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ConvMode {
    /// Keep the whole linear convolution: `signal_len + kernel_len - 1` elements.
    #[default]
    Full,
    /// Keep `max(signal_len, kernel_len)` elements, centred within the full result.
    Same,
    /// Keep only positions where the shorter operand lies entirely inside the
    /// longer one: `max - min + 1` elements.
    Valid,
}

impl ConvMode {
    /// Number of output elements along one axis for this mode.
    ///
    /// Returns 0 when either `signal_len` or `kernel_len` is 0.
    pub fn output_len(&self, signal_len: usize, kernel_len: usize) -> usize {
        // An empty operand makes the convolution empty; bailing out here also
        // keeps the `- 1` terms below from underflowing.
        if signal_len == 0 || kernel_len == 0 {
            return 0;
        }
        let short = signal_len.min(kernel_len);
        let long = signal_len.max(kernel_len);
        match self {
            ConvMode::Full => signal_len + kernel_len - 1,
            ConvMode::Same => long,
            ConvMode::Valid => long - short + 1,
        }
    }

    /// Offset into the full convolution at which this mode's output begins.
    ///
    /// For non-empty inputs, `slice_start(s, k) + output_len(s, k)` never
    /// exceeds the full length `s + k - 1`, whichever operand is longer.
    /// Returns 0 when either length is 0.
    pub fn slice_start(&self, signal_len: usize, kernel_len: usize) -> usize {
        if signal_len == 0 || kernel_len == 0 {
            return 0;
        }
        // First index of the full result at which the shorter operand overlaps
        // the longer one completely. This depends on the *shorter* length, not
        // on `kernel_len` specifically, so it stays correct when the kernel is
        // longer than the signal.
        let first_complete = signal_len.min(kernel_len) - 1;
        match self {
            ConvMode::Full => 0,
            // Centre a `max`-long window in the full result:
            // (full_len - max) / 2 == (min - 1) / 2.
            ConvMode::Same => first_complete / 2,
            ConvMode::Valid => first_complete,
        }
    }

    /// Output shape `(rows, cols)` of a 2-D operation, applying
    /// [`ConvMode::output_len`] independently to each axis.
    pub fn output_shape_2d(
        &self,
        signal_shape: (usize, usize),
        kernel_shape: (usize, usize),
    ) -> (usize, usize) {
        (
            self.output_len(signal_shape.0, kernel_shape.0),
            self.output_len(signal_shape.1, kernel_shape.1),
        )
    }
}
/// Convolution and cross-correlation operations on tensors of runtime `R`.
///
/// Every method takes a [`ConvMode`] that selects how the result is trimmed.
/// NOTE(review): implementations presumably size their output with
/// [`ConvMode::output_len`] / [`ConvMode::output_shape_2d`]; confirm against
/// the implementing backends.
pub trait ConvolutionAlgorithms<R>
where
    R: Runtime<DType = DType>,
{
    /// Convolves `signal` with `kernel` and trims the result according to `mode`.
    ///
    /// This is the non-`2d` counterpart of [`convolve2d`](Self::convolve2d),
    /// presumably operating along a single axis; confirm in implementations.
    ///
    /// # Errors
    /// Propagates any `numr` error raised by the implementation.
    fn convolve(
        &self,
        signal: &Tensor<R>,
        kernel: &Tensor<R>,
        mode: ConvMode,
    ) -> Result<Tensor<R>>;

    /// Two-dimensional convolution of `signal` with `kernel`, trimmed per `mode`.
    ///
    /// # Errors
    /// Propagates any `numr` error raised by the implementation.
    fn convolve2d(
        &self,
        signal: &Tensor<R>,
        kernel: &Tensor<R>,
        mode: ConvMode,
    ) -> Result<Tensor<R>>;

    /// Cross-correlates `signal` with `kernel`, trimmed per `mode`.
    ///
    /// # Errors
    /// Propagates any `numr` error raised by the implementation.
    fn correlate(
        &self,
        signal: &Tensor<R>,
        kernel: &Tensor<R>,
        mode: ConvMode,
    ) -> Result<Tensor<R>>;

    /// Two-dimensional cross-correlation of `signal` with `kernel`, trimmed per `mode`.
    ///
    /// # Errors
    /// Propagates any `numr` error raised by the implementation.
    fn correlate2d(
        &self,
        signal: &Tensor<R>,
        kernel: &Tensor<R>,
        mode: ConvMode,
    ) -> Result<Tensor<R>>;
}
#[cfg(test)]
mod tests {
    use super::ConvMode;

    /// `(mode, signal_len, kernel_len, expected_len)` cases for `output_len`.
    const OUTPUT_LEN_CASES: [(ConvMode, usize, usize, usize); 4] = [
        (ConvMode::Full, 10, 3, 12),
        (ConvMode::Same, 10, 3, 10),
        (ConvMode::Valid, 10, 3, 8),
        (ConvMode::Same, 3, 10, 10),
    ];

    /// `(mode, signal_len, kernel_len, expected_start)` cases for `slice_start`.
    const SLICE_START_CASES: [(ConvMode, usize, usize, usize); 3] = [
        (ConvMode::Full, 10, 3, 0),
        (ConvMode::Same, 10, 3, 1),
        (ConvMode::Valid, 10, 3, 2),
    ];

    #[test]
    fn test_conv_mode_output_len() {
        for &(mode, signal_len, kernel_len, expected) in OUTPUT_LEN_CASES.iter() {
            assert_eq!(mode.output_len(signal_len, kernel_len), expected);
        }
    }

    #[test]
    fn test_conv_mode_slice_start() {
        for &(mode, signal_len, kernel_len, expected) in SLICE_START_CASES.iter() {
            assert_eq!(mode.slice_start(signal_len, kernel_len), expected);
        }
    }
}