qudit_expr/expressions/
kraus.rs1use std::ops::{Deref, DerefMut};
2
3use crate::{
4 GenerationShape, TensorExpression,
5 expressions::JittableExpression,
6 index::{IndexDirection, TensorIndex},
7};
8
9use super::NamedExpression;
10use qudit_core::QuditSystem;
11use qudit_core::Radices;
12
/// An expression describing a batch of Kraus operators.
///
/// Wraps a [`NamedExpression`] together with the radices of the input and
/// output indices and the number of operators in the batch. Conversions to
/// and from [`TensorExpression`] are provided below.
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct KrausOperatorsExpression {
    /// The underlying named symbolic expression.
    inner: NamedExpression,
    /// Sizes of the input-direction indices, in index order.
    input_radices: Radices,
    /// Sizes of the output-direction indices, in index order.
    output_radices: Radices,
    /// Number of operators: the product of all batch index sizes when built
    /// from a tensor expression, or 1 if that tensor had no batch index.
    num_operators: usize,
}
20
21impl KrausOperatorsExpression {
22 pub fn new<T: AsRef<str>>(input: T) -> Self {
23 TensorExpression::new(input).try_into().unwrap()
24 }
25
26 pub fn num_qudits(&self) -> usize {
27 if self.input_radices == self.output_radices {
28 self.input_radices.num_qudits()
29 } else {
30 panic!("Input and output number of qudits are different for kraus operator.")
31 }
32 }
33}
34
35impl JittableExpression for KrausOperatorsExpression {
36 fn generation_shape(&self) -> GenerationShape {
37 GenerationShape::Tensor3D(
38 self.num_operators,
39 self.output_radices.dimension(),
40 self.input_radices.dimension(),
41 )
42 }
43}
44
45impl AsRef<NamedExpression> for KrausOperatorsExpression {
46 fn as_ref(&self) -> &NamedExpression {
47 &self.inner
48 }
49}
50
51impl From<KrausOperatorsExpression> for NamedExpression {
52 fn from(value: KrausOperatorsExpression) -> Self {
53 value.inner
54 }
55}
56
57impl Deref for KrausOperatorsExpression {
58 type Target = NamedExpression;
59
60 fn deref(&self) -> &Self::Target {
61 &self.inner
62 }
63}
64
65impl DerefMut for KrausOperatorsExpression {
66 fn deref_mut(&mut self) -> &mut Self::Target {
67 &mut self.inner
68 }
69}
70
71impl From<KrausOperatorsExpression> for TensorExpression {
72 fn from(value: KrausOperatorsExpression) -> Self {
73 let KrausOperatorsExpression {
74 inner,
75 input_radices,
76 output_radices,
77 num_operators,
78 } = value;
79 let indices = [num_operators]
81 .into_iter()
82 .map(|r| (IndexDirection::Batch, r))
83 .chain(
84 output_radices
85 .iter()
86 .map(|r| (IndexDirection::Output, usize::from(*r))),
87 )
88 .chain(
89 input_radices
90 .iter()
91 .map(|r| (IndexDirection::Input, usize::from(*r))),
92 )
93 .enumerate()
94 .map(|(i, (d, r))| TensorIndex::new(d, i, r))
95 .collect();
96 TensorExpression::from_raw(indices, inner)
97 }
98}
99
100impl TryFrom<TensorExpression> for KrausOperatorsExpression {
101 type Error = String;
103
104 fn try_from(value: TensorExpression) -> Result<Self, Self::Error> {
105 let mut num_operators = None;
106 let mut input_radices = vec![];
107 let mut output_radices = vec![];
108 for idx in value.indices() {
109 match idx.direction() {
110 IndexDirection::Batch => match num_operators {
111 Some(n) => num_operators = Some(n * idx.index_size()),
112 None => num_operators = Some(idx.index_size()),
113 },
114 IndexDirection::Input => {
115 input_radices.push(idx.index_size());
116 }
117 IndexDirection::Output => {
118 output_radices.push(idx.index_size());
119 }
120 _ => unreachable!(),
121 }
122 }
123
124 Ok(KrausOperatorsExpression {
125 inner: value.into(),
126 input_radices: input_radices.into(),
127 output_radices: output_radices.into(),
128 num_operators: num_operators.unwrap_or(1),
129 })
130 }
131}
132
#[cfg(feature = "python")]
mod python {
    use super::*;
    use crate::python::PyExpressionRegistrar;
    use pyo3::prelude::*;
    use qudit_core::Radix;

    /// Python-facing wrapper around a [`KrausOperatorsExpression`].
    #[pyclass]
    #[pyo3(name = "KrausOperatorsExpression")]
    pub struct PyKrausOperatorsExpression {
        expr: KrausOperatorsExpression,
    }

    #[pymethods]
    impl PyKrausOperatorsExpression {
        // Parses `expr` via `KrausOperatorsExpression::new`, which panics on
        // invalid input.
        #[new]
        fn new(expr: String) -> Self {
            let parsed = KrausOperatorsExpression::new(expr);
            Self::from(parsed)
        }

        /// Number of parameters in the underlying expression.
        fn num_params(&self) -> usize {
            let expr = &self.expr;
            expr.num_params()
        }

        /// Name of the underlying expression.
        fn name(&self) -> String {
            let expr = &self.expr;
            expr.name().to_string()
        }

        /// Radices of the input side of the operators.
        fn radices(&self) -> Vec<Radix> {
            let input = &self.expr.input_radices;
            input.to_vec()
        }

        /// Number of qudits; raises if the input and output radices differ.
        fn num_qudits(&self) -> usize {
            let expr = &self.expr;
            expr.num_qudits()
        }

        /// Number of Kraus operators in the batch.
        fn num_operators(&self) -> usize {
            let expr = &self.expr;
            expr.num_operators
        }

        /// Dimension of the input side of the operators.
        fn dimension(&self) -> usize {
            let input = &self.expr.input_radices;
            input.dimension()
        }

        fn __repr__(&self) -> String {
            let expr = &self.expr;
            format!(
                "KrausOperatorsExpression(name='{}', radices={:?}, num_operators={}, params={})",
                expr.name(),
                expr.input_radices.to_vec(),
                expr.num_operators,
                expr.num_params(),
            )
        }
    }

    impl From<KrausOperatorsExpression> for PyKrausOperatorsExpression {
        /// Wraps a Rust expression for exposure to Python.
        fn from(value: KrausOperatorsExpression) -> Self {
            Self { expr: value }
        }
    }

    impl From<PyKrausOperatorsExpression> for KrausOperatorsExpression {
        /// Unwraps the Rust expression held by the Python wrapper.
        fn from(value: PyKrausOperatorsExpression) -> Self {
            let PyKrausOperatorsExpression { expr } = value;
            expr
        }
    }

    impl<'py> IntoPyObject<'py> for KrausOperatorsExpression {
        type Target = <PyKrausOperatorsExpression as IntoPyObject<'py>>::Target;
        type Output = Bound<'py, Self::Target>;
        type Error = PyErr;

        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            Bound::new(py, PyKrausOperatorsExpression { expr: self })
        }
    }

    impl<'a, 'py> FromPyObject<'a, 'py> for KrausOperatorsExpression {
        type Error = PyErr;

        /// Extracts a clone of the expression held by a Python wrapper object.
        fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
            let wrapper: PyRef<PyKrausOperatorsExpression> = ob.extract()?;
            let cloned = wrapper.expr.clone();
            Ok(cloned)
        }
    }

    /// Adds the `KrausOperatorsExpression` class to `parent_module`.
    fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
        parent_module.add_class::<PyKrausOperatorsExpression>()
    }
    inventory::submit!(PyExpressionRegistrar { func: register });
}