Skip to main content

aeon_tk/kernel/
convolution.rs

1use crate::kernel::{Border, Derivative, Kernel, SecondDerivative, Value};
2
3// ************************************
4// Convolution ************************
5// ************************************
6
7/// A N-dimensional tensor product of several seperable kernels.
8pub trait Convolution<const N: usize> {
9    fn border_width(&self, axis: usize) -> usize;
10    fn interior(&self, axis: usize) -> &[f64];
11    fn free(&self, border: Border, axis: usize) -> &[f64];
12    fn scale(&self, spacing: [f64; N]) -> f64;
13}
14
15macro_rules! impl_convolution_for_tuples {
16    ($($N:literal => $($T:ident $i:tt)*)*) => {
17        $(
18            impl<$($T: Kernel,)*> Convolution<$N> for ($($T,)*) {
19                fn border_width(&self, axis: usize) -> usize {
20                    match axis {
21                        $($i => self.$i.border_width(),)*
22                        _ => panic!("Invalid Axis")
23                    }
24                }
25
26                fn interior(&self, axis: usize) -> &[f64] {
27                    match axis {
28                        $($i => self.$i.interior(),)*
29                        _ => panic!("Invalid Axis")
30                    }
31                }
32
33                fn free(&self, border: Border, axis: usize) -> &[f64] {
34                    match axis {
35                        $($i => self.$i.free(border),)*
36                        _ => panic!("Invalid Axis")
37                    }
38                }
39
40                fn scale(&self, spacing: [f64; $N]) -> f64 {
41                    let mut result = 1.0;
42
43                    $(
44                        result *= self.$i.scale(spacing[$i]);
45                    )*
46
47                    result
48                }
49            }
50        )*
51
52    };
53}
54
55impl_convolution_for_tuples! {
56   1 => K0 0
57   2 => K0 0 K1 1
58   3 => K0 0 K1 1 K2 2
59   4 => K0 0 K1 1 K2 2 K3 3
60}
61
/// Computes the gradient along the given axis.
///
/// The wrapped `usize` is the axis that is differentiated (with the
/// `Derivative::<ORDER>` kernel); every other axis uses the `Value` kernel.
#[derive(Debug, Clone, Copy)]
pub struct Gradient<const ORDER: usize>(pub usize);

impl<const ORDER: usize> Gradient<ORDER> {
    /// Constructs a convolution which differentiates along `axis`.
    pub const fn new(axis: usize) -> Self {
        Self(axis)
    }
}
65
66impl<const N: usize, const ORDER: usize> Convolution<N> for Gradient<ORDER> {
67    fn border_width(&self, axis: usize) -> usize {
68        if axis == self.0 {
69            Derivative::<ORDER>.border_width()
70        } else {
71            Value.border_width()
72        }
73    }
74
75    fn interior(&self, axis: usize) -> &[f64] {
76        if axis == self.0 {
77            Derivative::<ORDER>.interior()
78        } else {
79            Value.interior()
80        }
81    }
82
83    fn free(&self, border: Border, axis: usize) -> &[f64] {
84        if axis == self.0 {
85            Derivative::<ORDER>.free(border)
86        } else {
87            Value.free(border)
88        }
89    }
90
91    fn scale(&self, spacing: [f64; N]) -> f64 {
92        1.0 / spacing[self.0]
93    }
94}
95
/// Computes one entry of the Hessian: the mixed derivative with respect to the two
/// given axes (a pure second derivative when both axes are the same).
#[derive(Debug, Clone, Copy)]
pub struct Hessian<const ORDER: usize>(pub usize, pub usize);

impl<const ORDER: usize> Hessian<ORDER> {
    /// Constructs a convolution which computes the `(i, j)` entry of the Hessian matrix.
    pub const fn new(i: usize, j: usize) -> Self {
        Self(i, j)
    }

    /// True only for a diagonal entry (`self.0 == self.1`) queried on that very axis.
    fn is_second(&self, axis: usize) -> bool {
        axis == self.0 && axis == self.1
    }

    /// True when `axis` is either of the two differentiated axes.
    fn is_first(&self, axis: usize) -> bool {
        [self.0, self.1].contains(&axis)
    }
}
116
117impl<const N: usize, const ORDER: usize> Convolution<N> for Hessian<ORDER> {
118    fn border_width(&self, axis: usize) -> usize {
119        if self.is_second(axis) {
120            SecondDerivative::<ORDER>.border_width()
121        } else if self.is_first(axis) {
122            Derivative::<ORDER>.border_width()
123        } else {
124            Value.border_width()
125        }
126    }
127
128    fn interior(&self, axis: usize) -> &[f64] {
129        if self.is_second(axis) {
130            SecondDerivative::<ORDER>.interior()
131        } else if self.is_first(axis) {
132            Derivative::<ORDER>.interior()
133        } else {
134            Value.interior()
135        }
136    }
137
138    fn free(&self, border: Border, axis: usize) -> &[f64] {
139        if self.is_second(axis) {
140            SecondDerivative::<ORDER>.free(border)
141        } else if self.is_first(axis) {
142            Derivative::<ORDER>.free(border)
143        } else {
144            Value.free(border)
145        }
146    }
147    fn scale(&self, spacing: [f64; N]) -> f64 {
148        1.0 / (spacing[self.0] * spacing[self.1])
149    }
150}