1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use crate::iterators::{
    ControlledOpIterator, MatrixOpIterator, SparseMatrixOpIterator, SwapOpIterator,
};
use num_traits::{One, Zero};
use std::fmt;
use std::iter::Sum;
use std::ops::Mul;

/// Ops which can be applied to quantum states.
///
/// `P` is the type of a single matrix entry. Every variant records the indices it acts on;
/// a row of the op's matrix is expanded by `MatrixOp::sum_for_op_cols`.
#[derive(Clone)]
pub enum MatrixOp<P> {
    /// Dense op: `(indices, data)`.
    ///
    /// `data` holds the matrix entries for the op acting on `indices`.
    /// NOTE(review): presumably a flattened square matrix of side `2^indices.len()` —
    /// confirm the layout against `MatrixOpIterator`.
    Matrix(Vec<usize>, Vec<P>),
    /// Sparse op: `(indices, rows)`, where `rows[r]` lists the `(column, value)` pairs
    /// stored for row `r` (presumably only the nonzero entries).
    SparseMatrix(Vec<usize>, Vec<Vec<(usize, P)>>),
    /// Swap op: `(a_indices, b_indices)`, the two groups of indices being swapped.
    Swap(Vec<usize>, Vec<usize>),
    /// Controlled op: `(control_indices, op_indices, op)`; `op` acts on `op_indices` and is
    /// controlled by `control_indices`. A nested `Control` as `op` is also accepted.
    Control(Vec<usize>, Vec<usize>, Box<MatrixOp<P>>),
}

impl<P> MatrixOp<P> {
    pub fn new_matrix<Indx, Dat>(indices: Indx, data: Dat) -> Self
    where
        Indx: Into<Vec<usize>>,
        Dat: Into<Vec<P>>,
    {
        Self::Matrix(indices.into(), data.into())
    }
}

impl<P> MatrixOp<P>
where
    P: Clone + Zero + One + Mul,
{
    /// Sums `f` over the entries of a single row of this op's matrix.
    ///
    /// `f` is called with each `(column, value)` pair yielded for `row` by the iterator that
    /// matches this op's variant, and the results are combined through [`Sum`].
    ///
    /// * `nindices` - number of indices the op acts on. Only the `Matrix` and `Swap` arms
    ///   forward it; a `Control` op derives its sizes from its own index lists instead.
    /// * `row` - the row of the op's matrix to expand.
    /// * `f` - maps one `(column, value)` entry to a summable term.
    pub fn sum_for_op_cols<T, F>(&self, nindices: usize, row: usize, f: F) -> T
    where
        T: Sum,
        F: Fn((usize, P)) -> T,
    {
        match self {
            MatrixOp::Control(controls, targets, inner) => {
                inner.sum_for_control_iterator(row, controls.len(), targets.len(), f)
            }
            MatrixOp::Swap(_, _) => {
                let entries = SwapOpIterator::new(row, nindices);
                entries.map(f).sum()
            }
            MatrixOp::SparseMatrix(_, rows) => {
                let entries = SparseMatrixOpIterator::new(row, rows.as_slice());
                entries.map(f).sum()
            }
            MatrixOp::Matrix(_, data) => {
                let entries = MatrixOpIterator::new(row, nindices, data);
                entries.map(f).sum()
            }
        }
    }

    /// Sums `f` over a row of this op when it is the target of a control op.
    ///
    /// The row is expanded by `ControlledOpIterator`. It receives the number of control
    /// indices, the number of indices this op acts on, and a closure that builds this op's
    /// own row iterator for a given op row.
    fn sum_for_control_iterator<T, F>(
        &self,
        row: usize,
        n_control_indices: usize,
        n_op_indices: usize,
        f: F,
    ) -> T
    where
        T: Sum,
        F: Fn((usize, P)) -> T,
    {
        match self {
            // Control ops are automatically collapsed if made with the helper, but nested ones
            // are still handled here: the control count grows by the inner op's controls, and
            // only the inner op's own indices remain as op indices.
            MatrixOp::Control(inner_controls, inner_targets, inner) => inner
                .sum_for_control_iterator(
                    row,
                    n_control_indices + inner_controls.len(),
                    inner_targets.len(),
                    f,
                ),
            MatrixOp::Swap(_, _) => {
                let make_row_iter = |op_row: usize| SwapOpIterator::new(op_row, n_op_indices);
                let entries =
                    ControlledOpIterator::new(row, n_control_indices, n_op_indices, make_row_iter);
                entries.map(f).sum()
            }
            MatrixOp::SparseMatrix(_, rows) => {
                let make_row_iter = |op_row: usize| SparseMatrixOpIterator::new(op_row, rows);
                let entries =
                    ControlledOpIterator::new(row, n_control_indices, n_op_indices, make_row_iter);
                entries.map(f).sum()
            }
            MatrixOp::Matrix(_, data) => {
                let make_row_iter =
                    |op_row: usize| MatrixOpIterator::new(op_row, n_op_indices, data);
                let entries =
                    ControlledOpIterator::new(row, n_control_indices, n_op_indices, make_row_iter);
                entries.map(f).sum()
            }
        }
    }
}

impl<P> fmt::Debug for MatrixOp<P> {
    /// Formats the op as `Name[i0, i1, ...]`.
    ///
    /// * `Matrix` / `SparseMatrix` print their indices.
    /// * `Swap` prints the A indices followed by the B indices.
    /// * `Control` prints as `C(<inner op>)` followed by the control indices only.
    ///
    /// Indices are written straight into the formatter instead of cloning the index vectors
    /// and joining intermediate strings; the produced text is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Writes `[a, b, ...]` for the given indices (`[]` when there are none).
        fn write_index_list<'a, I>(f: &mut fmt::Formatter, indices: I) -> fmt::Result
        where
            I: Iterator<Item = &'a usize>,
        {
            f.write_str("[")?;
            for (position, index) in indices.enumerate() {
                if position > 0 {
                    f.write_str(", ")?;
                }
                write!(f, "{}", index)?;
            }
            f.write_str("]")
        }

        match self {
            MatrixOp::Matrix(indices, _) => {
                f.write_str("Matrix")?;
                write_index_list(f, indices.iter())
            }
            MatrixOp::SparseMatrix(indices, _) => {
                f.write_str("SparseMatrix")?;
                write_index_list(f, indices.iter())
            }
            MatrixOp::Swap(a_indices, b_indices) => {
                f.write_str("Swap")?;
                write_index_list(f, a_indices.iter().chain(b_indices.iter()))
            }
            MatrixOp::Control(control_indices, _, op) => {
                // The inner op is rendered with its own Debug impl; only the control indices
                // are listed after it.
                write!(f, "C({:?})", op)?;
                write_index_list(f, control_indices.iter())
            }
        }
    }
}

/// Returns how many indices `op` spans.
///
/// `Matrix` and `SparseMatrix` count their index list. `Swap` counts both of its index groups,
/// and `Control` counts its control indices plus the indices of the controlled op.
pub fn num_indices<P>(op: &MatrixOp<P>) -> usize {
    match op {
        MatrixOp::Matrix(indices, _) | MatrixOp::SparseMatrix(indices, _) => indices.len(),
        MatrixOp::Swap(first, second) | MatrixOp::Control(first, second, _) => {
            first.len() + second.len()
        }
    }
}