qudit_expr/expressions/
kraus.rs

1use std::ops::{Deref, DerefMut};
2
3use crate::{
4    GenerationShape, TensorExpression,
5    expressions::JittableExpression,
6    index::{IndexDirection, TensorIndex},
7};
8
9use super::NamedExpression;
10use qudit_core::QuditSystem;
11use qudit_core::Radices;
12
13#[derive(PartialEq, Eq, Debug, Clone)]
14pub struct KrausOperatorsExpression {
15    inner: NamedExpression,
16    input_radices: Radices,
17    output_radices: Radices,
18    num_operators: usize,
19}
20
21impl KrausOperatorsExpression {
22    pub fn new<T: AsRef<str>>(input: T) -> Self {
23        TensorExpression::new(input).try_into().unwrap()
24    }
25
26    pub fn num_qudits(&self) -> usize {
27        if self.input_radices == self.output_radices {
28            self.input_radices.num_qudits()
29        } else {
30            panic!("Input and output number of qudits are different for kraus operator.")
31        }
32    }
33}
34
35impl JittableExpression for KrausOperatorsExpression {
36    fn generation_shape(&self) -> GenerationShape {
37        GenerationShape::Tensor3D(
38            self.num_operators,
39            self.output_radices.dimension(),
40            self.input_radices.dimension(),
41        )
42    }
43}
44
45impl AsRef<NamedExpression> for KrausOperatorsExpression {
46    fn as_ref(&self) -> &NamedExpression {
47        &self.inner
48    }
49}
50
51impl From<KrausOperatorsExpression> for NamedExpression {
52    fn from(value: KrausOperatorsExpression) -> Self {
53        value.inner
54    }
55}
56
57impl Deref for KrausOperatorsExpression {
58    type Target = NamedExpression;
59
60    fn deref(&self) -> &Self::Target {
61        &self.inner
62    }
63}
64
65impl DerefMut for KrausOperatorsExpression {
66    fn deref_mut(&mut self) -> &mut Self::Target {
67        &mut self.inner
68    }
69}
70
71impl From<KrausOperatorsExpression> for TensorExpression {
72    fn from(value: KrausOperatorsExpression) -> Self {
73        let KrausOperatorsExpression {
74            inner,
75            input_radices,
76            output_radices,
77            num_operators,
78        } = value;
79        // TODO: add a proper implementation of into_iter for QuditRadices
80        let indices = [num_operators]
81            .into_iter()
82            .map(|r| (IndexDirection::Batch, r))
83            .chain(
84                output_radices
85                    .iter()
86                    .map(|r| (IndexDirection::Output, usize::from(*r))),
87            )
88            .chain(
89                input_radices
90                    .iter()
91                    .map(|r| (IndexDirection::Input, usize::from(*r))),
92            )
93            .enumerate()
94            .map(|(i, (d, r))| TensorIndex::new(d, i, r))
95            .collect();
96        TensorExpression::from_raw(indices, inner)
97    }
98}
99
100impl TryFrom<TensorExpression> for KrausOperatorsExpression {
101    // TODO: Come up with proper error handling
102    type Error = String;
103
104    fn try_from(value: TensorExpression) -> Result<Self, Self::Error> {
105        let mut num_operators = None;
106        let mut input_radices = vec![];
107        let mut output_radices = vec![];
108        for idx in value.indices() {
109            match idx.direction() {
110                IndexDirection::Batch => match num_operators {
111                    Some(n) => num_operators = Some(n * idx.index_size()),
112                    None => num_operators = Some(idx.index_size()),
113                },
114                IndexDirection::Input => {
115                    input_radices.push(idx.index_size());
116                }
117                IndexDirection::Output => {
118                    output_radices.push(idx.index_size());
119                }
120                _ => unreachable!(),
121            }
122        }
123
124        Ok(KrausOperatorsExpression {
125            inner: value.into(),
126            input_radices: input_radices.into(),
127            output_radices: output_radices.into(),
128            num_operators: num_operators.unwrap_or(1),
129        })
130    }
131}
132
133#[cfg(feature = "python")]
134mod python {
135    use super::*;
136    use crate::python::PyExpressionRegistrar;
137    use pyo3::prelude::*;
138    use qudit_core::Radix;
139
140    #[pyclass]
141    #[pyo3(name = "KrausOperatorsExpression")]
142    pub struct PyKrausOperatorsExpression {
143        expr: KrausOperatorsExpression,
144    }
145
146    #[pymethods]
147    impl PyKrausOperatorsExpression {
148        #[new]
149        fn new(expr: String) -> Self {
150            Self {
151                expr: KrausOperatorsExpression::new(expr),
152            }
153        }
154
155        fn num_params(&self) -> usize {
156            self.expr.num_params()
157        }
158
159        fn name(&self) -> String {
160            self.expr.name().to_string()
161        }
162
163        fn radices(&self) -> Vec<Radix> {
164            self.expr.input_radices.to_vec()
165        }
166
167        fn num_qudits(&self) -> usize {
168            self.expr.num_qudits()
169        }
170
171        fn num_operators(&self) -> usize {
172            self.expr.num_operators
173        }
174
175        fn dimension(&self) -> usize {
176            self.expr.input_radices.dimension()
177        }
178
179        fn __repr__(&self) -> String {
180            format!(
181                "KrausOperatorsExpression(name='{}', radices={:?}, num_operators={}, params={})",
182                self.expr.name(),
183                self.expr.input_radices.to_vec(),
184                self.expr.num_operators,
185                self.expr.num_params()
186            )
187        }
188    }
189
190    impl From<KrausOperatorsExpression> for PyKrausOperatorsExpression {
191        fn from(value: KrausOperatorsExpression) -> Self {
192            PyKrausOperatorsExpression { expr: value }
193        }
194    }
195
196    impl From<PyKrausOperatorsExpression> for KrausOperatorsExpression {
197        fn from(value: PyKrausOperatorsExpression) -> Self {
198            value.expr
199        }
200    }
201
202    impl<'py> IntoPyObject<'py> for KrausOperatorsExpression {
203        type Target = <PyKrausOperatorsExpression as IntoPyObject<'py>>::Target;
204        type Output = Bound<'py, Self::Target>;
205        type Error = PyErr;
206
207        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
208            let py_expr = PyKrausOperatorsExpression::from(self);
209            Bound::new(py, py_expr)
210        }
211    }
212
213    impl<'a, 'py> FromPyObject<'a, 'py> for KrausOperatorsExpression {
214        type Error = PyErr;
215
216        fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
217            let py_expr: PyRef<PyKrausOperatorsExpression> = ob.extract()?;
218            Ok(py_expr.expr.clone())
219        }
220    }
221
222    /// Registers the KrausOperatorsExpression class with the Python module.
223    fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
224        parent_module.add_class::<PyKrausOperatorsExpression>()?;
225        Ok(())
226    }
227    inventory::submit!(PyExpressionRegistrar { func: register });
228}