1use num_traits::One;
2use std::ops::Mul;
3
/// Iterator over the Cartesian product of several lists of `(column, value)` entries.
///
/// Each yielded item picks one entry from every slice in `iter_outputs`:
/// - The running column index is shifted left by `iter_ns[i]` bits, and entry `i`'s column is OR-ed in.
/// - The selected values are multiplied together.
///
/// The last list varies fastest.
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct MultiOpIterator<'a, P> {
    /// Left-shift (in bits) applied to the running column index before OR-ing in
    /// the column chosen from the corresponding slice of `iter_outputs`.
    iter_ns: &'a [usize],
    /// One slice of `(column, value)` candidates per operand.
    iter_outputs: &'a [&'a [(usize, P)]],
    /// Current index into each slice of `iter_outputs`, advanced like an odometer.
    curr_poss: Vec<usize>,
    /// Set after the final combination has been yielded so the next call reports `None`.
    overflow: bool,
}
12
13impl<'a, P> MultiOpIterator<'a, P> {
14 pub fn new(
17 iter_ns: &'a [usize],
18 iter_outputs: &'a [&'a [(usize, P)]],
19 ) -> MultiOpIterator<'a, P> {
20 let curr_poss: Vec<usize> = iter_ns.iter().map(|_| 0).collect();
21 MultiOpIterator {
22 iter_ns,
23 iter_outputs,
24 curr_poss,
25 overflow: false,
26 }
27 }
28}
29
30impl<'a, P> Iterator for MultiOpIterator<'a, P>
31where
32 P: One + Clone + Mul<P>,
33{
34 type Item = (usize, P);
35
36 fn next(&mut self) -> Option<Self::Item> {
37 if self.overflow {
38 self.overflow = false;
39 None
40 } else {
41 let init = (0usize, P::one());
42 let ret_val = self
43 .curr_poss
44 .iter()
45 .cloned()
46 .zip(self.iter_ns.iter().cloned())
47 .zip(self.iter_outputs.iter())
48 .fold(init, |(acc_col, acc_val), ((cur_pos, n_pos), outs)| {
49 let (col, val) = outs[cur_pos].clone();
50 let acc_col = (acc_col << n_pos) | col;
51 (acc_col, acc_val * val)
52 });
53
54 let mut broke_early = false;
56 let pos_iter = self
57 .curr_poss
58 .iter_mut()
59 .rev()
60 .zip(self.iter_outputs.iter().rev());
61
62 for (cur_pos, iter_n) in pos_iter {
63 *cur_pos += 1;
64 if *cur_pos == iter_n.len() {
65 *cur_pos = 0;
66 } else {
67 broke_early = true;
68 break;
69 }
70 }
71
72 if !broke_early {
74 self.overflow = true;
75 }
76 Some(ret_val)
77 }
78 }
79}
80
#[cfg(test)]
mod multi_iter_tests {
    use super::*;
    use num_complex::Complex;

    /// A single `(column, value)` entry as consumed and produced by `MultiOpIterator`.
    type Entry = (usize, Complex<f64>);

    /// Drains a fresh iterator over `outputs` into a vector of `(column, value)` pairs.
    fn run(ns: &[usize], outputs: &[&[Entry]]) -> Vec<Entry> {
        MultiOpIterator::new(ns, outputs).collect()
    }

    /// Builds a row of `width` zeros with `1.0` written at every column the iterator yields.
    fn indicator_row(ns: &[usize], outputs: &[&[Entry]], width: usize) -> Vec<f64> {
        let mut row = vec![0.0; width];
        for (indx, _) in MultiOpIterator::new(ns, outputs) {
            row[indx] = 1.0;
        }
        row
    }

    #[test]
    fn test_trivial() {
        let one = Complex::one();
        let first = [(1, one)];
        let second = [(0, one)];
        let outputs: [&[Entry]; 2] = [&first, &second];

        assert_eq!(
            run(&[1, 1], &outputs),
            vec![(2, Complex { re: 1.0, im: 0.0 })]
        );
    }

    #[test]
    fn test_nontrivial() {
        let one = Complex::one();
        let first = [(0, one), (1, one)];
        let second = [(0, one)];
        let outputs: [&[Entry]; 2] = [&first, &second];

        assert_eq!(
            run(&[1, 1], &outputs),
            vec![(0, Complex::one()), (2, Complex::one())]
        );
    }

    #[test]
    fn test_nontrivial_other() {
        let one = Complex::one();
        let first = [(0, one)];
        let second = [(0, one), (1, one)];
        let outputs: [&[Entry]; 2] = [&first, &second];

        assert_eq!(
            run(&[1, 1], &outputs),
            vec![(0, Complex::one()), (1, Complex::one())]
        );
    }

    #[test]
    fn test_mat_iterator() {
        let n = 1usize;
        let one = Complex::one();
        let width = 1usize << n;

        let mut mat: Vec<Vec<f64>> = Vec::with_capacity(width);
        for i in 0..width {
            // Row `i` selects column `1 - i`, i.e. the two columns are exchanged.
            let entry = [(1 - i, one)];
            let outputs: [&[Entry]; 1] = [&entry];
            mat.push(indicator_row(&[n], &outputs, width));
        }

        assert_eq!(mat, vec![vec![0.0, 1.0], vec![1.0, 0.0]]);
    }

    #[test]
    fn test_double_mat_identity() {
        let n = 2usize;
        let one = Complex::one();
        let width = 1usize << n;

        let mut mat: Vec<Vec<f64>> = Vec::with_capacity(width);
        for i in 0..width {
            // Each one-bit operand reproduces its own bit of the row index.
            let high = [((i & 2) >> 1, one)];
            let low = [(i & 1, one)];
            let outputs: [&[Entry]; 2] = [&high, &low];
            mat.push(indicator_row(&[1, 1], &outputs, width));
        }

        let expected = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0],
        ];
        assert_eq!(mat, expected);
    }

    #[test]
    fn test_double_mat_swap() {
        let n = 2usize;
        let one = Complex::one();
        let width = 1usize << n;

        let mut mat: Vec<Vec<f64>> = Vec::with_capacity(width);
        for i in 0..width {
            // Each one-bit operand emits the complement of its bit of the row index.
            let high = [((!i & 2) >> 1, one)];
            let low = [(!i & 1, one)];
            let outputs: [&[Entry]; 2] = [&high, &low];
            mat.push(indicator_row(&[1, 1], &outputs, width));
        }

        let expected = vec![
            vec![0.0, 0.0, 0.0, 1.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![1.0, 0.0, 0.0, 0.0],
        ];
        assert_eq!(mat, expected);
    }
}