1use crate::iterators::{sum_for_ops_cols, MatrixOp};
2use crate::utils::{get_bit, set_bit};
3use crate::{iter, iter_mut};
4use num_traits::{One, Zero};
5use std::iter::Sum;
6use std::ops::{AddAssign, Mul};
7
8#[cfg(feature = "parallel")]
9use rayon::prelude::*;
10
11pub fn full_to_sub(n: usize, mat_indices: &[usize], full_index: usize) -> usize {
13 let nindices = mat_indices.len();
14 mat_indices
15 .iter()
16 .enumerate()
17 .fold(0, |acc, (j, indx)| -> usize {
18 let bit = get_bit(full_index, n - 1 - *indx);
19 set_bit(acc, nindices - 1 - j, bit)
20 })
21}
22
23pub fn sub_to_full(n: usize, mat_indices: &[usize], sub_index: usize, base: usize) -> usize {
25 let nindices = mat_indices.len();
26 mat_indices.iter().enumerate().fold(base, |acc, (j, indx)| {
27 let bit = get_bit(sub_index, nindices - 1 - j);
28 set_bit(acc, n - 1 - *indx, bit)
29 })
30}
31
32pub fn get_index<P>(op: &MatrixOp<P>, i: usize) -> usize {
34 match &op {
35 MatrixOp::Matrix(indices, _) => indices[i],
36 MatrixOp::SparseMatrix(indices, _) => indices[i],
37 MatrixOp::Swap(a, b) => {
38 if i < a.len() {
39 a[i]
40 } else {
41 b[i - a.len()]
42 }
43 }
44 MatrixOp::Control(cs, os, _) => {
45 if i < cs.len() {
46 cs[i]
47 } else {
48 os[i - cs.len()]
49 }
50 }
51 }
52}
53
54pub fn apply_op_row<P>(
56 n: usize,
57 op: &MatrixOp<P>,
58 input: &[P],
59 outputrow: usize,
60 input_offset: usize,
61 output_offset: usize,
62) -> P
63where
64 P: Clone + One + Zero + Sum + Mul<Output = P> + Send + Sync,
65{
66 let mat_indices: Vec<usize> = (0..op.num_indices()).map(|i| get_index(op, i)).collect();
67 apply_op_row_indices(
68 n,
69 op,
70 input,
71 outputrow,
72 input_offset,
73 output_offset,
74 &mat_indices,
75 )
76}
77
78pub fn apply_op_row_indices<P>(
80 n: usize,
81 op: &MatrixOp<P>,
82 input: &[P],
83 outputrow: usize,
84 input_offset: usize,
85 output_offset: usize,
86 mat_indices: &[usize],
87) -> P
88where
89 P: Clone + One + Zero + Sum + Mul<Output = P> + Send + Sync,
90{
91 let row = output_offset + (outputrow);
92 let matrow = full_to_sub(n, mat_indices, row);
93 let f = |(i, val): (usize, P)| -> P {
96 let colbits = sub_to_full(n, mat_indices, i, row);
97 if colbits < input_offset {
98 P::zero()
99 } else {
100 let vecrow = colbits - input_offset;
101 if vecrow >= input.len() {
102 P::zero()
103 } else {
104 val * input[vecrow].clone()
105 }
106 }
107 };
108
109 op.sum_for_op_cols(mat_indices.len(), matrow, f)
111}
112
113pub fn apply_op<P>(
116 n: usize,
117 op: &MatrixOp<P>,
118 input: &[P],
119 output: &mut [P],
120 input_offset: usize,
121 output_offset: usize,
122) where
123 P: AddAssign + Clone + One + Zero + Sum + Mul<Output = P> + Send + Sync,
124{
125 let mat_indices: Vec<usize> = (0..op.num_indices()).map(|i| get_index(op, i)).collect();
126 let row_fn = |(outputrow, outputloc): (usize, &mut P)| {
127 *outputloc += apply_op_row_indices(
128 n,
129 op,
130 input,
131 outputrow,
132 input_offset,
133 output_offset,
134 &mat_indices,
135 );
136 };
137
138 iter_mut!(output).enumerate().for_each(row_fn);
140}
141
142pub fn apply_op_overwrite<P>(
145 n: usize,
146 op: &MatrixOp<P>,
147 input: &[P],
148 output: &mut [P],
149 input_offset: usize,
150 output_offset: usize,
151) where
152 P: AddAssign + Clone + One + Zero + Sum + Mul<Output = P> + Send + Sync,
153{
154 let mat_indices: Vec<usize> = (0..op.num_indices()).map(|i| get_index(op, i)).collect();
155 let row_fn = |(outputrow, outputloc): (usize, &mut P)| {
156 *outputloc = apply_op_row_indices(
157 n,
158 op,
159 input,
160 outputrow,
161 input_offset,
162 output_offset,
163 &mat_indices,
164 );
165 };
166
167 iter_mut!(output).enumerate().for_each(row_fn);
169}
170
171pub fn apply_ops<P>(
176 n: usize,
177 ops: &[MatrixOp<P>],
178 input: &[P],
179 output: &mut [P],
180 input_offset: usize,
181 output_offset: usize,
182) where
183 P: AddAssign + Clone + One + Zero + Sum + Mul<Output = P> + Send + Sync,
184{
185 match ops {
186 [op] => apply_op(n, op, input, output, input_offset, output_offset),
187 [] => {
188 let lower = input_offset.max(output_offset);
189 let upper = (input_offset + input.len()).min(output_offset + output.len());
190 let input_lower = lower - input_offset;
191 let input_upper = upper - input_offset;
192 let output_lower = lower - output_offset;
193 let output_upper = upper - output_offset;
194
195 let input_iter = iter!(input[input_lower..input_upper]);
196 let output_iter = iter_mut!(output[output_lower..output_upper]);
197 input_iter
198 .zip(output_iter)
199 .for_each(|(input, out)| *out = input.clone());
200 }
201 _ => {
202 let mat_indices: Vec<usize> = ops
203 .iter()
204 .flat_map(|op| -> Vec<usize> {
205 (0..op.num_indices()).map(|i| get_index(op, i)).collect()
206 })
207 .collect();
208
209 let row_fn = |(outputrow, outputloc): (usize, &mut P)| {
210 let row = output_offset + (outputrow);
211 let matrow = full_to_sub(n, &mat_indices, row);
212 let f = |(i, val): (usize, P)| -> P {
215 let colbits = sub_to_full(n, &mat_indices, i, row);
216 if colbits < input_offset {
217 P::zero()
218 } else {
219 let vecrow = colbits - input_offset;
220 if vecrow >= input.len() {
221 P::zero()
222 } else {
223 val * input[vecrow].clone()
224 }
225 }
226 };
227
228 *outputloc += sum_for_ops_cols(matrow, ops, f);
230 };
231
232 iter_mut!(output).enumerate().for_each(row_fn);
234 }
235 }
236}
237
#[cfg(test)]
mod test {
    // Checks the matrix-application routines against explicit Kronecker products built
    // with ndarray. Plain `assert_eq!`/`assert_ne!` are used (not the `debug_` variants)
    // so the checks are not compiled out when tests run without debug assertions, e.g.
    // `cargo test --release`.
    use super::*;
    use ndarray::{Array2, ShapeError};
    use std::ops::{Add, Div, Sub};

    /// Returns the columns of `op`'s full `2^n x 2^n` matrix. Entry `i` is the output of
    /// applying `op` (via `apply_op`) to the `i`-th basis vector, written into zeros.
    fn make_op_matrix<P>(n: usize, op: &MatrixOp<P>) -> Vec<Vec<P>>
    where
        P: AddAssign + Clone + One + Zero + Sum + Mul<Output = P> + Send + Sync,
    {
        let zeros: Vec<P> = (0..1 << n).map(|_| P::zero()).collect();
        (0..1 << n)
            .map(|i| {
                let mut input = zeros.clone();
                let mut output = zeros.clone();
                input[i] = P::one();
                apply_op(n, op, &input, &mut output, 0, 0);
                output
            })
            .collect()
    }

    /// Same as `make_op_matrix`, packed into an `Array2` indexed as `[row, col]`. The
    /// columns are laid out as rows first and then transposed.
    fn make_op_flat_matrix<P>(n: usize, op: &MatrixOp<P>) -> Result<Array2<P>, ShapeError>
    where
        P: AddAssign + Clone + One + Zero + Sum + Mul<Output = P> + Send + Sync,
    {
        let v = make_op_matrix(n, op)
            .into_iter()
            .flat_map(|v| v.into_iter())
            .collect::<Vec<_>>();
        let arr = Array2::from_shape_vec((1 << n, 1 << n), v)?;
        Ok(arr.reversed_axes())
    }

    /// Builds `I2^(x before) (x) mat (x) I2^(x after)`, where `(x)` is the Kronecker
    /// product and `I2` is the 2x2 identity.
    fn ndarray_kron_helper<P>(before: usize, mut mat: Array2<P>, after: usize) -> Array2<P>
    where
        P: Copy + One + Zero + Add<Output = P> + Sub<Output = P> + Div<Output = P> + 'static,
    {
        let eye = Array2::eye(2);
        for _ in 0..before {
            mat = ndarray::linalg::kron(&eye, &mat);
        }
        for _ in 0..after {
            mat = ndarray::linalg::kron(&mat, &eye);
        }
        mat
    }

    #[test]
    fn test_ident() -> Result<(), String> {
        let n = 3;
        let data = [1, 0, 0, 1];
        let op = MatrixOp::new_matrix([0], data);
        let arr = Array2::from_shape_vec((2, 2), data.into()).map_err(|e| format!("{:?}", e))?;
        let mat = make_op_flat_matrix(n, &op).map_err(|e| format!("{:?}", e))?;
        let comp_mat = ndarray_kron_helper(0, arr, n - 1);

        assert_eq!(mat, comp_mat);
        Ok(())
    }

    #[test]
    fn test_flip() -> Result<(), String> {
        let n = 3;
        let data = [0, 1, 1, 0];
        let op = MatrixOp::new_matrix([0], data);
        let arr = Array2::from_shape_vec((2, 2), data.into()).map_err(|e| format!("{:?}", e))?;
        let mat = make_op_flat_matrix(n, &op).map_err(|e| format!("{:?}", e))?;
        let comp_mat = ndarray_kron_helper(0, arr, n - 1);

        assert_eq!(mat, comp_mat);
        Ok(())
    }

    #[test]
    fn test_flip_mid() -> Result<(), String> {
        let n = 3;
        let data = [0, 1, 1, 0];
        let op = MatrixOp::new_matrix([1], data);
        let arr = Array2::from_shape_vec((2, 2), data.into()).map_err(|e| format!("{:?}", e))?;
        let mat = make_op_flat_matrix(n, &op).map_err(|e| format!("{:?}", e))?;
        let comp_mat = ndarray_kron_helper(1, arr, n - 2);

        assert_eq!(mat, comp_mat);
        Ok(())
    }

    #[test]
    fn test_flip_end() -> Result<(), String> {
        let n = 3;
        let data = [0, 1, 1, 0];
        let op = MatrixOp::new_matrix([2], data);
        let arr = Array2::from_shape_vec((2, 2), data.into()).map_err(|e| format!("{:?}", e))?;
        let mat = make_op_flat_matrix(n, &op).map_err(|e| format!("{:?}", e))?;
        let comp_mat = ndarray_kron_helper(2, arr, n - 3);

        assert_eq!(mat, comp_mat);
        Ok(())
    }

    #[test]
    fn test_flip_mid_twobody() -> Result<(), String> {
        let n = 4;
        let data = [1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1];
        let op = MatrixOp::new_matrix([1, 2], data);
        let arr = Array2::from_shape_vec((4, 4), data.into()).map_err(|e| format!("{:?}", e))?;
        let mat = make_op_flat_matrix(n, &op).map_err(|e| format!("{:?}", e))?;
        let comp_mat = ndarray_kron_helper(1, arr, 1);

        assert_eq!(mat, comp_mat);
        Ok(())
    }

    #[test]
    fn test_counting() -> Result<(), String> {
        let n = 3;
        let data = [1, 2, 3, 4];
        let op = MatrixOp::new_matrix([0], data);
        let arr = Array2::from_shape_vec((2, 2), data.into()).map_err(|e| format!("{:?}", e))?;
        let mat = make_op_flat_matrix(n, &op).map_err(|e| format!("{:?}", e))?;
        let comp_mat = ndarray_kron_helper(0, arr, n - 1);

        assert_eq!(mat, comp_mat);
        Ok(())
    }

    /// With indices `[0, 1]` on a 2-bit space, the op matrix is the data itself,
    /// row-major.
    #[test]
    fn test_counting_order() -> Result<(), String> {
        let n = 2;
        let data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
        let op = MatrixOp::new_matrix([0, 1], data);
        let comp_mat = Array2::from_shape_vec((1 << n, 1 << n), data.into())
            .map_err(|e| format!("{:?}", e))?;
        let mat = make_op_flat_matrix(n, &op).map_err(|e| format!("{:?}", e))?;

        assert_eq!(mat, comp_mat);
        Ok(())
    }

    /// Reversing the index order must permute the matrix, so it no longer equals the data.
    #[test]
    fn test_counting_order_flipped() -> Result<(), String> {
        let n = 2;
        let data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
        let op = MatrixOp::new_matrix([1, 0], data);
        let comp_mat = Array2::from_shape_vec((1 << n, 1 << n), data.into())
            .map_err(|e| format!("{:?}", e))?;
        let mat = make_op_flat_matrix(n, &op).map_err(|e| format!("{:?}", e))?;

        assert_ne!(mat, comp_mat);
        Ok(())
    }
}