qudit_expr/expressions/utrysys.rs

1use std::ops::{Deref, DerefMut};
2
3use crate::{
4    GenerationShape, TensorExpression,
5    expressions::JittableExpression,
6    index::{IndexDirection, TensorIndex},
7};
8
9use super::NamedExpression;
10use qudit_core::QuditSystem;
11use qudit_core::Radices;
12
13#[derive(PartialEq, Eq, Debug, Clone)]
14pub struct UnitarySystemExpression {
15    inner: NamedExpression,
16    radices: Radices,
17    num_unitaries: usize,
18}
19
20impl UnitarySystemExpression {
21    pub fn new<T: AsRef<str>>(input: T) -> Self {
22        TensorExpression::new(input).try_into().unwrap()
23    }
24
25    pub fn num_qudits(&self) -> usize {
26        self.radices.num_qudits()
27    }
28}
29
30impl JittableExpression for UnitarySystemExpression {
31    fn generation_shape(&self) -> GenerationShape {
32        GenerationShape::Tensor3D(
33            self.num_unitaries,
34            self.radices.dimension(),
35            self.radices.dimension(),
36        )
37    }
38}
39
40impl AsRef<NamedExpression> for UnitarySystemExpression {
41    fn as_ref(&self) -> &NamedExpression {
42        &self.inner
43    }
44}
45
46impl From<UnitarySystemExpression> for NamedExpression {
47    fn from(value: UnitarySystemExpression) -> Self {
48        value.inner
49    }
50}
51
52impl Deref for UnitarySystemExpression {
53    type Target = NamedExpression;
54
55    fn deref(&self) -> &Self::Target {
56        &self.inner
57    }
58}
59
60impl DerefMut for UnitarySystemExpression {
61    fn deref_mut(&mut self) -> &mut Self::Target {
62        &mut self.inner
63    }
64}
65
66impl From<UnitarySystemExpression> for TensorExpression {
67    fn from(value: UnitarySystemExpression) -> Self {
68        let UnitarySystemExpression {
69            inner,
70            radices,
71            num_unitaries,
72        } = value;
73        // TODO: add a proper implementation of into_iter for QuditRadices
74        let indices = [num_unitaries]
75            .into_iter()
76            .map(|r| (IndexDirection::Batch, r))
77            .chain(
78                radices
79                    .iter()
80                    .map(|r| (IndexDirection::Output, usize::from(*r))),
81            )
82            .chain(
83                radices
84                    .iter()
85                    .map(|r| (IndexDirection::Input, usize::from(*r))),
86            )
87            .enumerate()
88            .map(|(i, (d, r))| TensorIndex::new(d, i, r))
89            .collect();
90        TensorExpression::from_raw(indices, inner)
91    }
92}
93
94impl TryFrom<TensorExpression> for UnitarySystemExpression {
95    // TODO: Come up with proper error handling
96    type Error = String;
97
98    fn try_from(value: TensorExpression) -> Result<Self, Self::Error> {
99        let mut num_unitaries = None;
100        let mut input_radices = vec![];
101        let mut output_radices = vec![];
102        for idx in value.indices() {
103            match idx.direction() {
104                IndexDirection::Batch => match &mut num_unitaries {
105                    Some(n) => *n *= idx.index_size(),
106                    None => num_unitaries = Some(idx.index_size()),
107                },
108                IndexDirection::Input => {
109                    input_radices.push(idx.index_size());
110                }
111                IndexDirection::Output => {
112                    output_radices.push(idx.index_size());
113                }
114                _ => unreachable!(),
115            }
116        }
117
118        if input_radices != output_radices {
119            return Err(String::from(
120                "Non-square matrix tensor cannot be converted to a unitary.",
121            ));
122        }
123
124        Ok(UnitarySystemExpression {
125            inner: value.into(),
126            radices: input_radices.into(),
127            num_unitaries: num_unitaries.unwrap_or(1),
128        })
129    }
130}
131
#[cfg(feature = "python")]
mod python {
    use super::*;
    use crate::python::PyExpressionRegistrar;
    use pyo3::exceptions::PyValueError;
    use pyo3::prelude::*;
    use qudit_core::Radix;

    /// Python-facing wrapper around a Rust `UnitarySystemExpression`.
    #[pyclass]
    #[pyo3(name = "UnitarySystemExpression")]
    pub struct PyUnitarySystemExpression {
        expr: UnitarySystemExpression,
    }

    #[pymethods]
    impl PyUnitarySystemExpression {
        /// Builds a unitary system expression from its source text.
        ///
        /// Raises `ValueError` when the parsed tensor's input and output index
        /// sizes differ, instead of letting a Rust panic reach Python as a
        /// `PanicException`.
        #[new]
        fn new(expr: String) -> PyResult<Self> {
            // NOTE(review): failures inside `TensorExpression::new` (parsing) are not
            // visible from here and may still panic — TODO confirm.
            let expr = UnitarySystemExpression::try_from(TensorExpression::new(expr))
                .map_err(PyValueError::new_err)?;
            Ok(Self { expr })
        }

        /// Number of parameters of the expression.
        fn num_params(&self) -> usize {
            self.expr.num_params()
        }

        /// Name of the expression.
        fn name(&self) -> String {
            self.expr.name().to_string()
        }

        /// Radix of each qudit, in index order.
        fn radices(&self) -> Vec<Radix> {
            self.expr.radices.to_vec()
        }

        /// Number of qudits in the system.
        fn num_qudits(&self) -> usize {
            self.expr.num_qudits()
        }

        /// Number of matrices in the batch.
        fn num_unitaries(&self) -> usize {
            self.expr.num_unitaries
        }

        /// Dimension reported by the system's radices.
        fn dimension(&self) -> usize {
            self.expr.radices.dimension()
        }

        fn __repr__(&self) -> String {
            format!(
                "UnitarySystemExpression(name='{}', radices={:?}, num_unitaries={}, params={})",
                self.expr.name(),
                self.expr.radices.to_vec(),
                self.expr.num_unitaries,
                self.expr.num_params()
            )
        }
    }

    impl From<UnitarySystemExpression> for PyUnitarySystemExpression {
        fn from(value: UnitarySystemExpression) -> Self {
            PyUnitarySystemExpression { expr: value }
        }
    }

    impl From<PyUnitarySystemExpression> for UnitarySystemExpression {
        fn from(value: PyUnitarySystemExpression) -> Self {
            value.expr
        }
    }

    impl<'py> IntoPyObject<'py> for UnitarySystemExpression {
        type Target = <PyUnitarySystemExpression as IntoPyObject<'py>>::Target;
        type Output = Bound<'py, Self::Target>;
        type Error = PyErr;

        /// Wraps the expression in a new Python `UnitarySystemExpression` object.
        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            let py_expr = PyUnitarySystemExpression::from(self);
            Bound::new(py, py_expr)
        }
    }

    impl<'a, 'py> FromPyObject<'a, 'py> for UnitarySystemExpression {
        type Error = PyErr;

        /// Extracts a clone of the Rust expression held by a Python
        /// `UnitarySystemExpression`. Fails if `ob` is not such an object.
        fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
            let py_expr: PyRef<PyUnitarySystemExpression> = ob.extract()?;
            Ok(py_expr.expr.clone())
        }
    }

    /// Registers the UnitarySystemExpression class with the Python module.
    fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
        parent_module.add_class::<PyUnitarySystemExpression>()?;
        Ok(())
    }
    inventory::submit!(PyExpressionRegistrar { func: register });
}