ndarray_conv/lib.rs
1//! `ndarray-conv` provides N-dimensional convolution operations for `ndarray` arrays.
2//!
3//! This crate extends the `ndarray` library with both standard and
4//! FFT-accelerated convolution methods.
5//!
6//! # Getting Started
7//!
8//! To start performing convolutions, you'll interact with the following:
9//!
10//! 1. **Input Arrays:** Use `ndarray`'s [`Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)
11//! or [`ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html)
12//! as your input data and convolution kernel.
13//! 2. **Convolution Methods:** Call `array.conv(...)` or `array.conv_fft(...)`.
//!    These methods are added to `ArrayBase` types via the [`ConvExt`] and
//!    [`ConvFFTExt`] traits.
16//! 3. **Convolution Mode:** [`ConvMode`] specifies the size of the output.
17//! 4. **Padding Mode:** [`PaddingMode`] specifies how to handle array boundaries.
18//!
//! # Basic Example
20//!
21//! Here's a simple example of how to perform a 2D convolution using `ndarray-conv`:
22//!
23//! ```rust
24//! use ndarray::prelude::*;
25//! use ndarray_conv::{ConvExt, ConvFFTExt, ConvMode, PaddingMode};
26//!
27//! // Input data
28//! let input = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
29//!
30//! // Convolution kernel
31//! let kernel = array![[1, 1], [1, 1]];
32//!
33//! // Perform standard convolution with "same" output size and zero padding
34//! let output = input.conv(
35//! &kernel,
36//! ConvMode::Same,
37//! PaddingMode::Zeros,
38//! ).unwrap();
39//!
40//! println!("Standard Convolution Output:\n{:?}", output);
41//!
42//! // Perform FFT-accelerated convolution with "same" output size and zero padding
43//! let output_fft = input.map(|&x| x as f32).conv_fft(
44//! &kernel.map(|&x| x as f32),
45//! ConvMode::Same,
46//! PaddingMode::Zeros,
47//! ).unwrap();
48//!
49//! println!("FFT Convolution Output:\n{:?}", output_fft);
50//! ```
51//!
52//! # Choosing a convolution method
53//!
54//! * Use [`ConvExt::conv`] for standard convolution
//! * Use [`ConvFFTExt::conv_fft`] for FFT-accelerated convolution.
//!   FFT-accelerated convolution is generally faster for larger kernels, but
//!   standard convolution may be faster for smaller kernels.
58//!
59//! # Key Structs, Enums and Traits
60//!
61//! * [`ConvMode`]: Specifies how to determine the size of the convolution output (e.g., `Full`, `Same`, `Valid`).
62//! * [`PaddingMode`]: Specifies how to handle array boundaries (e.g., `Zeros`, `Reflect`, `Replicate`). You can also use `PaddingMode::Custom` or `PaddingMode::Explicit` to combine different [`BorderType`] strategies for each dimension or for each side of each dimension.
63//! * [`BorderType`]: Used with [`PaddingMode`] for `Custom` and `Explicit`, specifies the padding strategy (e.g., `Zeros`, `Reflect`, `Replicate`, `Circular`).
64//! * [`ConvExt`]: The trait that adds the `conv` method, extending `ndarray` arrays with standard convolution functionality.
65//! * [`ConvFFTExt`]: The trait that adds the `conv_fft` method, extending `ndarray` arrays with FFT-accelerated convolution functionality.
66
67mod conv;
68mod conv_fft;
69mod dilation;
70mod padding;
71
72pub(crate) use padding::ExplicitPadding;
73
74pub use conv::ConvExt;
75pub use conv_fft::{
76 get_processor as get_fft_processor, ConvFFTExt, GetProcessor, Processor as FftProcessor,
77};
78pub use dilation::{ReverseKernel, WithDilation};
79
/// Specifies the convolution mode, which determines the output size.
///
/// `N` is the number of dimensions of the arrays being convolved; the
/// `Custom` and `Explicit` variants carry one entry per dimension.
///
/// Besides `Debug`/`Clone`/`Copy`, the mode implements `PartialEq`, `Eq` and
/// `Hash`, so modes can be compared, asserted on, or used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConvMode<const N: usize> {
    /// The output has the largest size, including all positions where
    /// the kernel and input overlap at least partially.
    Full,
    /// The output has the same size as the input.
    Same,
    /// The output has the smallest size, including only positions
    /// where the kernel and input fully overlap.
    Valid,
    /// Specifies custom padding and strides.
    ///
    /// Only one padding amount is given per dimension; use
    /// [`ConvMode::Explicit`] to give each side of a dimension its own amount.
    Custom {
        /// The padding to use for each dimension.
        padding: [usize; N],
        /// The strides to use for each dimension.
        strides: [usize; N],
    },
    /// Specifies explicit padding and strides.
    Explicit {
        /// The padding to use for each side of each dimension.
        padding: [[usize; 2]; N],
        /// The strides to use for each dimension.
        strides: [usize; N],
    },
}
106/// Specifies the padding mode, which determines how to handle borders.
107///
108/// The padding mode can be either a single `BorderType` applied on all sides
109/// or a custom tuple of two `BorderTypes` for each dimension or a `BorderType`
110/// for each side of each dimension.
111#[derive(Debug, Clone, Copy)]
112pub enum PaddingMode<const N: usize, T: num::traits::NumAssign + Copy> {
113 /// Pads with zeros.
114 Zeros,
115 /// Pads with a constant value.
116 Const(T),
117 /// Reflects the input at the borders.
118 Reflect,
119 /// Replicates the edge values.
120 Replicate,
121 /// Treats the input as a circular buffer.
122 Circular,
123 /// Specifies a different `BorderType` for each dimension.
124 Custom([BorderType<T>; N]),
125 /// Specifies a different `BorderType` for each side of each dimension.
126 Explicit([[BorderType<T>; 2]; N]),
127}
128
129/// Used with [`PaddingMode`]. Specifies the padding mode for a single dimension
130/// or a single side of a dimension.
131#[derive(Debug, Clone, Copy)]
132pub enum BorderType<T: num::traits::NumAssign + Copy> {
133 /// Pads with zeros.
134 Zeros,
135 /// Pads with a constant value.
136 Const(T),
137 /// Reflects the input at the borders.
138 Reflect,
139 /// Replicates the edge values.
140 Replicate,
141 /// Treats the input as a circular buffer.
142 Circular,
143}
144
145use thiserror::Error;
146
147/// Error type for convolution operations.
148#[derive(Error, Debug)]
149pub enum Error<const N: usize> {
150 /// Indicates that the input data array has a dimension with zero size.
151 #[error("Data shape shouldn't have ZERO. {0:?}")]
152 DataShape(ndarray::Dim<[ndarray::Ix; N]>),
153 /// Indicates that the kernel array has a dimension with zero size.
154 #[error("Kernel shape shouldn't have ZERO. {0:?}")]
155 KernelShape(ndarray::Dim<[ndarray::Ix; N]>),
156 /// Indicates that the shape of the kernel with dilation is not compatible with the chosen `ConvMode`.
157 #[error("ConvMode {0:?} does not match KernelWithDilation Size {1:?}")]
158 MismatchShape(ConvMode<N>, [ndarray::Ix; N]),
159}