1use ndarray::Array2;
4
5use crate::{InvalidInput, Scalar};
6
7pub fn circulant<A>(a: &[A]) -> Array2<A>
19where
20 A: Copy,
21{
22 let mut x = Array2::<A>::uninit((a.len(), a.len()));
23 unsafe {
24 for (i, a_elem) in a.iter().enumerate() {
25 for j in 0..a.len() {
26 *(x[[(i + j) % a.len(), j]].as_mut_ptr()) = *a_elem;
27 }
28 }
29 x.assume_init()
30 }
31}
32
33pub fn companion<A>(a: &[A]) -> Result<Array2<A>, InvalidInput>
53where
54 A: Scalar,
55{
56 if a.len() < 2 {
57 return Err(InvalidInput::Shape(format!(
58 "input polynomial has {} coefficient; expected at least two",
59 a.len()
60 )));
61 }
62 if a[0] == A::zero() {
63 return Err(InvalidInput::Value(format!(
64 "invalid first coefficient {}, expected a non-zero value",
65 a[0]
66 )));
67 }
68
69 let mut matrix = Array2::<A>::zeros((a.len() - 1, a.len() - 1));
70 for (mv, av) in matrix.row_mut(0).into_iter().zip(a.iter().skip(1)) {
71 *mv = -*av / a[0];
72 }
73 for i in 1..a.len() - 1 {
74 matrix[[i, i - 1]] = A::one();
75 }
76 Ok(matrix)
77}
78
#[cfg(test)]
mod tests {
    use ndarray::aview2;
    use num_complex::{Complex32, Complex64};

    /// Checks `circulant` on empty, single-element and three-element input.
    #[test]
    fn circulant() {
        // Empty input gives a 0x0 matrix.
        let empty: Vec<f32> = Vec::new();
        assert_eq!(super::circulant(&empty).shape(), [0, 0]);

        // A single element gives a 1x1 matrix holding that element.
        let single = [Complex32::new(1., 2.)];
        assert_eq!(
            super::circulant(&single),
            aview2(&[[Complex32::new(1., 2.)]])
        );

        // The input forms the first column; later columns rotate it downwards.
        let first_column = [1., 2., 4.];
        let expected = [[1., 4., 2.], [2., 1., 4.], [4., 2., 1.]];
        assert_eq!(super::circulant(&first_column), aview2(&expected));
    }

    /// Checks `companion` on rejected inputs and on two valid polynomials.
    #[test]
    fn companion() {
        // Fewer than two coefficients is rejected.
        let empty: Vec<f32> = Vec::new();
        assert!(super::companion(&empty).is_err());

        let constant = [Complex64::new(1., 2.)];
        assert!(super::companion(&constant).is_err());

        // A zero leading coefficient is rejected.
        let zero_leading = [
            Complex32::new(0., 0.),
            Complex32::new(1., 1.),
            Complex32::new(2., 2.),
        ];
        assert!(super::companion(&zero_leading).is_err());

        // Two coefficients give a 1x1 matrix containing -a[1] / a[0].
        let linear = [Complex32::new(1., 2.), Complex32::new(3., 4.)];
        let linear_companion = super::companion(&linear).expect("valid input");
        assert_eq!(
            linear_companion,
            aview2(&[[Complex32::new(-2.2, 0.4)]])
        );

        // Four coefficients give a 3x3 matrix with ones on the sub-diagonal.
        let cubic = [2., -4., 8., -10.];
        let cubic_companion = super::companion(&cubic).expect("valid input");
        let expected = [[2.0, -4.0, 5.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        assert_eq!(cubic_companion, aview2(&expected));
    }
}