// qip_iterators/iterators/qubit_multi_iterator.rs

1use num_traits::One;
2use std::ops::Mul;
3
/// Iterator which provides the indices of nonzero columns for a given row for a collection of ops.
///
/// Each item is a `(column, value)` pair: the column is assembled from one column of every
/// sub iterator and the value is the product of the matching sub iterator values, with `P`
/// being the value type.
#[derive(Debug)]
pub struct MultiOpIterator<'a, P> {
    /// Number of qubits in each sub iterator; a sub iterator's column is shifted into the
    /// combined column using this many bits.
    iter_ns: &'a [usize],
    /// Outputs of each sub iterator on the row being expanded, as `(column, value)` pairs.
    iter_outputs: &'a [&'a [(usize, P)]],
    /// Index of the currently selected entry within each element of `iter_outputs`.
    curr_poss: Vec<usize>,
    /// Set once the last combination has been yielded, so the following call returns `None`.
    overflow: bool,
}
12
13impl<'a, P> MultiOpIterator<'a, P> {
14    /// Build a new iterator using the number of qubits in each sub iterator, and the outputs of
15    /// said iterators on a given row.
16    pub fn new(
17        iter_ns: &'a [usize],
18        iter_outputs: &'a [&'a [(usize, P)]],
19    ) -> MultiOpIterator<'a, P> {
20        let curr_poss: Vec<usize> = iter_ns.iter().map(|_| 0).collect();
21        MultiOpIterator {
22            iter_ns,
23            iter_outputs,
24            curr_poss,
25            overflow: false,
26        }
27    }
28}
29
30impl<'a, P> Iterator for MultiOpIterator<'a, P>
31where
32    P: One + Clone + Mul<P>,
33{
34    type Item = (usize, P);
35
36    fn next(&mut self) -> Option<Self::Item> {
37        if self.overflow {
38            self.overflow = false;
39            None
40        } else {
41            let init = (0usize, P::one());
42            let ret_val = self
43                .curr_poss
44                .iter()
45                .cloned()
46                .zip(self.iter_ns.iter().cloned())
47                .zip(self.iter_outputs.iter())
48                .fold(init, |(acc_col, acc_val), ((cur_pos, n_pos), outs)| {
49                    let (col, val) = outs[cur_pos].clone();
50                    let acc_col = (acc_col << n_pos) | col;
51                    (acc_col, acc_val * val)
52                });
53
54            // Iterate through the current positions and increment when needed.
55            let mut broke_early = false;
56            let pos_iter = self
57                .curr_poss
58                .iter_mut()
59                .rev()
60                .zip(self.iter_outputs.iter().rev());
61
62            for (cur_pos, iter_n) in pos_iter {
63                *cur_pos += 1;
64                if *cur_pos == iter_n.len() {
65                    *cur_pos = 0;
66                } else {
67                    broke_early = true;
68                    break;
69                }
70            }
71
72            // If all poss overflowed, then next output should be None.
73            if !broke_early {
74                self.overflow = true;
75            }
76            Some(ret_val)
77        }
78    }
79}
80
#[cfg(test)]
mod multi_iter_tests {
    use super::*;
    use num_complex::Complex;

    /// A single `(column, value)` output of a sub iterator.
    type Entry = (usize, Complex<f64>);

    /// Drives a `MultiOpIterator` over `rows` and returns a dense vector of length `width`
    /// holding `1.0` at every column the iterator produced and `0.0` elsewhere.
    fn mark_columns(ns: &[usize], rows: &[&[Entry]], width: usize) -> Vec<f64> {
        let mut dense = vec![0.0; width];
        for (col, _) in MultiOpIterator::new(ns, rows) {
            dense[col] = 1.0;
        }
        dense
    }

    #[test]
    fn test_trivial() {
        let one: Complex<f64> = Complex::one();
        let first = [(1, one)];
        let second = [(0, one)];
        let rows: [&[Entry]; 2] = [&first, &second];
        let ns = [1, 1];

        let produced: Vec<Entry> = MultiOpIterator::new(&ns, &rows).collect();

        assert_eq!(produced, vec![(2, Complex { re: 1.0, im: 0.0 })]);
    }

    #[test]
    fn test_nontrivial() {
        let one: Complex<f64> = Complex::one();
        let first = [(0, one), (1, one)];
        let second = [(0, one)];
        let rows: [&[Entry]; 2] = [&first, &second];
        let ns = [1, 1];

        let produced: Vec<Entry> = MultiOpIterator::new(&ns, &rows).collect();

        assert_eq!(produced, vec![(0, Complex::one()), (2, Complex::one())]);
    }

    #[test]
    fn test_nontrivial_other() {
        let one: Complex<f64> = Complex::one();
        let first = [(0, one)];
        let second = [(0, one), (1, one)];
        let rows: [&[Entry]; 2] = [&first, &second];
        let ns = [1, 1];

        let produced: Vec<Entry> = MultiOpIterator::new(&ns, &rows).collect();

        assert_eq!(produced, vec![(0, Complex::one()), (1, Complex::one())]);
    }

    #[test]
    fn test_mat_iterator() {
        let n = 1usize;
        let width = 1usize << n;
        let one: Complex<f64> = Complex::one();

        // Each row has a single entry in the opposite column (a bit flip).
        let mat: Vec<Vec<f64>> = (0..width)
            .map(|i| {
                let flipped = [(1 - i, one)];
                let rows: [&[Entry]; 1] = [&flipped];
                let ns = [n];
                mark_columns(&ns, &rows, width)
            })
            .collect();

        let expected = vec![vec![0.0, 1.0], vec![1.0, 0.0]];

        assert_eq!(mat, expected);
    }

    #[test]
    fn test_double_mat_identity() {
        let n = 2usize;
        let width = 1usize << n;
        let one: Complex<f64> = Complex::one();

        // Both single-qubit sub ops map each bit of the row index to itself.
        let mat: Vec<Vec<f64>> = (0..width)
            .map(|i| {
                let upper = [((i & 2) >> 1, one)];
                let lower = [(i & 1, one)];
                let rows: [&[Entry]; 2] = [&upper, &lower];
                let ns = [1, 1];
                mark_columns(&ns, &rows, width)
            })
            .collect();

        let expected = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0],
        ];

        assert_eq!(mat, expected);
    }

    #[test]
    fn test_double_mat_swap() {
        let n = 2usize;
        let width = 1usize << n;
        let one: Complex<f64> = Complex::one();

        // Both single-qubit sub ops map each bit of the row index to its complement.
        let mat: Vec<Vec<f64>> = (0..width)
            .map(|i| {
                let upper = [((!i & 2) >> 1, one)];
                let lower = [(!i & 1, one)];
                let rows: [&[Entry]; 2] = [&upper, &lower];
                let ns = [1, 1];
                mark_columns(&ns, &rows, width)
            })
            .collect();

        let expected = vec![
            vec![0.0, 0.0, 0.0, 1.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![1.0, 0.0, 0.0, 0.0],
        ];

        assert_eq!(mat, expected);
    }
}