lair/
matrix.rs

1//! Matrix functions and special matrices.
2
3use ndarray::Array2;
4
5use crate::{InvalidInput, Scalar};
6
7/// Constructs a circulant matrix.
8///
9/// # Examples
10///
11/// ```
12/// use lair::matrix::circulant;
13///
14/// let a = vec![1., 2., 3.];
15/// let c = circulant(&a);
16/// assert_eq!(c, ndarray::aview2(&[[1., 3., 2.], [2., 1., 3.], [3., 2., 1.]]));
17/// ```
18pub fn circulant<A>(a: &[A]) -> Array2<A>
19where
20    A: Copy,
21{
22    let mut x = Array2::<A>::uninit((a.len(), a.len()));
23    unsafe {
24        for (i, a_elem) in a.iter().enumerate() {
25            for j in 0..a.len() {
26                *(x[[(i + j) % a.len(), j]].as_mut_ptr()) = *a_elem;
27            }
28        }
29        x.assume_init()
30    }
31}
32
33/// Constructs a companion matrix.
34///
35/// # Errors
36///
37/// * [`InvalidInput::Shape`] if `a` contains less than two coefficients.
38/// * [`InvalidInput::Value`] if `a[0]` is zero.
39///
40/// [`InvalidInput::Shape`]: ../enum.InvalidInput.html#variant.Shape
41/// [`InvalidInput::Value`]: ../enum.InvalidInput.html#variant.Value
42///
43/// # Examples
44///
45/// ```
46/// use lair::matrix::companion;
47///
48/// let a = vec![1., -10., 31., -30.];
49/// let c = companion(&a).expect("valid input");
50/// assert_eq!(c, ndarray::aview2(&[[10., -31., 30.], [1., 0., 0.], [0., 1., 0.]]));
51/// ```
52pub fn companion<A>(a: &[A]) -> Result<Array2<A>, InvalidInput>
53where
54    A: Scalar,
55{
56    if a.len() < 2 {
57        return Err(InvalidInput::Shape(format!(
58            "input polynomial has {} coefficient; expected at least two",
59            a.len()
60        )));
61    }
62    if a[0] == A::zero() {
63        return Err(InvalidInput::Value(format!(
64            "invalid first coefficient {}, expected a non-zero value",
65            a[0]
66        )));
67    }
68
69    let mut matrix = Array2::<A>::zeros((a.len() - 1, a.len() - 1));
70    for (mv, av) in matrix.row_mut(0).into_iter().zip(a.iter().skip(1)) {
71        *mv = -*av / a[0];
72    }
73    for i in 1..a.len() - 1 {
74        matrix[[i, i - 1]] = A::one();
75    }
76    Ok(matrix)
77}
78
#[cfg(test)]
mod tests {
    use ndarray::aview2;
    use num_complex::{Complex32, Complex64};

    use crate::InvalidInput;

    #[test]
    fn circulant() {
        // Empty input yields an empty matrix instead of panicking.
        let a = Vec::<f32>::new();
        let c = super::circulant(&a);
        assert_eq!(c.shape(), [0, 0]);

        let a = vec![Complex32::new(1., 2.)];
        let c = super::circulant(&a);
        assert_eq!(c, aview2(&[[Complex32::new(1., 2.)]]));

        let a = vec![1., 2.];
        let c = super::circulant(&a);
        assert_eq!(c, aview2(&[[1., 2.], [2., 1.]]));

        let a = vec![1., 2., 4.];
        let c = super::circulant(&a);
        assert_eq!(c, aview2(&[[1., 4., 2.], [2., 1., 4.], [4., 2., 1.]]));
    }

    #[test]
    fn companion() {
        // Fewer than two coefficients must be reported as a shape error.
        let a = Vec::<f32>::new();
        assert!(matches!(
            super::companion(&a),
            Err(InvalidInput::Shape(_))
        ));

        let a = vec![Complex64::new(1., 2.)];
        assert!(matches!(
            super::companion(&a),
            Err(InvalidInput::Shape(_))
        ));

        // A zero leading coefficient must be reported as a value error.
        let a = vec![
            Complex32::new(0., 0.),
            Complex32::new(1., 1.),
            Complex32::new(2., 2.),
        ];
        assert!(matches!(
            super::companion(&a),
            Err(InvalidInput::Value(_))
        ));

        let a = vec![Complex32::new(1., 2.), Complex32::new(3., 4.)];
        let c = super::companion(&a).expect("valid input");
        assert_eq!(c, aview2(&[[Complex32::new(-2.2, 0.4)]]));

        let a = vec![2., -4., 8., -10.];
        let c = super::companion(&a).expect("valid input");
        assert_eq!(
            c,
            aview2(&[[2.0, -4.0, 5.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
        );
    }
}