linfa_linalg/
reflection.rs

1use ndarray::{ArrayBase, Data, DataMut, Ix1, Ix2, NdFloat};
2
3/// Reflection with respect to a plane
4pub struct Reflection<A, D: Data<Elem = A>> {
5    axis: ArrayBase<D, Ix1>,
6    bias: A,
7}
8
9impl<A, D: Data<Elem = A>> Reflection<A, D> {
10    /// Create a new reflection with respect to the plane orthogonal to the given axis and bias
11    ///
12    /// `axis` must be a unit vector
13    /// `bias` is the position of the plane on the axis from the origin
14    pub fn new(axis: ArrayBase<D, Ix1>, bias: A) -> Self {
15        Self { axis, bias }
16    }
17
18    pub fn axis(&self) -> &ArrayBase<D, Ix1> {
19        &self.axis
20    }
21}
22
23// XXX Can use matrix multiplication algorithm instead of iterative algorithm for both reflections
24impl<A: NdFloat, D: Data<Elem = A>> Reflection<A, D> {
25    /// Apply reflection to the columns of `rhs`
26    pub fn reflect_cols<M: DataMut<Elem = A>>(&self, rhs: &mut ArrayBase<M, Ix2>) {
27        for i in 0..rhs.ncols() {
28            let m_two = A::from(-2.0f64).unwrap();
29            let factor = (self.axis.dot(&rhs.column(i)) - self.bias) * m_two;
30            rhs.column_mut(i).scaled_add(factor, &self.axis);
31        }
32    }
33
34    /// Apply reflection to the rows of `lhs`
35    pub fn reflect_rows<M: DataMut<Elem = A>>(&self, lhs: &mut ArrayBase<M, Ix2>) {
36        self.reflect_cols(&mut lhs.view_mut().reversed_axes());
37    }
38}
39
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    use super::*;

    /// Column-wise reflection across planes orthogonal to the y axis.
    #[test]
    fn reflect_plane_col() {
        let normal = array![0., 1., 0.];

        // Plane through the origin: the y coordinate of each column changes sign.
        let through_origin = Reflection::new(normal.view(), 0.0);
        let mut points = array![[1., 2., 3.], [3., 4., 5.]].reversed_axes();
        through_origin.reflect_cols(&mut points);
        assert_abs_diff_eq!(
            points,
            array![[1., -2., 3.], [3., -4., 5.]].reversed_axes()
        );
        // Reflecting a second time restores the original points.
        through_origin.reflect_cols(&mut points);
        assert_abs_diff_eq!(points, array![[1., 2., 3.], [3., 4., 5.]].reversed_axes());

        // Plane at y = 3: y coordinates are mirrored around 3.
        let shifted = Reflection::new(normal.view(), 3.0);
        let mut points = array![[1., 2., 3.], [3., 4., 5.]].reversed_axes();
        shifted.reflect_cols(&mut points);
        assert_abs_diff_eq!(points, array![[1., 4., 3.], [3., 2., 5.]].reversed_axes());
    }

    /// Row-wise reflection across planes orthogonal to the y axis.
    #[test]
    fn reflect_plane_row() {
        let normal = array![0., 1., 0.];

        // Plane through the origin: the y coordinate of each row changes sign.
        let through_origin = Reflection::new(normal.view(), 0.0);
        let mut points = array![[1., 2., 3.], [3., 4., 5.]];
        through_origin.reflect_rows(&mut points);
        assert_abs_diff_eq!(points, array![[1., -2., 3.], [3., -4., 5.]]);
        // Reflecting a second time restores the original points.
        through_origin.reflect_rows(&mut points);
        assert_abs_diff_eq!(points, array![[1., 2., 3.], [3., 4., 5.]]);

        // Plane at y = 3: y coordinates are mirrored around 3.
        let shifted = Reflection::new(normal.view(), 3.0);
        let mut points = array![[1., 2., 3.], [3., 4., 5.]];
        shifted.reflect_rows(&mut points);
        assert_abs_diff_eq!(points, array![[1., 4., 3.], [3., 2., 5.]]);
    }
}