qudit_expr/expressions/
utrysys.rs1use std::ops::{Deref, DerefMut};
2
3use crate::{
4 GenerationShape, TensorExpression,
5 expressions::JittableExpression,
6 index::{IndexDirection, TensorIndex},
7};
8
9use super::NamedExpression;
10use qudit_core::QuditSystem;
11use qudit_core::Radices;
12
13#[derive(PartialEq, Eq, Debug, Clone)]
14pub struct UnitarySystemExpression {
15 inner: NamedExpression,
16 radices: Radices,
17 num_unitaries: usize,
18}
19
20impl UnitarySystemExpression {
21 pub fn new<T: AsRef<str>>(input: T) -> Self {
22 TensorExpression::new(input).try_into().unwrap()
23 }
24
25 pub fn num_qudits(&self) -> usize {
26 self.radices.num_qudits()
27 }
28}
29
30impl JittableExpression for UnitarySystemExpression {
31 fn generation_shape(&self) -> GenerationShape {
32 GenerationShape::Tensor3D(
33 self.num_unitaries,
34 self.radices.dimension(),
35 self.radices.dimension(),
36 )
37 }
38}
39
40impl AsRef<NamedExpression> for UnitarySystemExpression {
41 fn as_ref(&self) -> &NamedExpression {
42 &self.inner
43 }
44}
45
46impl From<UnitarySystemExpression> for NamedExpression {
47 fn from(value: UnitarySystemExpression) -> Self {
48 value.inner
49 }
50}
51
52impl Deref for UnitarySystemExpression {
53 type Target = NamedExpression;
54
55 fn deref(&self) -> &Self::Target {
56 &self.inner
57 }
58}
59
60impl DerefMut for UnitarySystemExpression {
61 fn deref_mut(&mut self) -> &mut Self::Target {
62 &mut self.inner
63 }
64}
65
66impl From<UnitarySystemExpression> for TensorExpression {
67 fn from(value: UnitarySystemExpression) -> Self {
68 let UnitarySystemExpression {
69 inner,
70 radices,
71 num_unitaries,
72 } = value;
73 let indices = [num_unitaries]
75 .into_iter()
76 .map(|r| (IndexDirection::Batch, r))
77 .chain(
78 radices
79 .iter()
80 .map(|r| (IndexDirection::Output, usize::from(*r))),
81 )
82 .chain(
83 radices
84 .iter()
85 .map(|r| (IndexDirection::Input, usize::from(*r))),
86 )
87 .enumerate()
88 .map(|(i, (d, r))| TensorIndex::new(d, i, r))
89 .collect();
90 TensorExpression::from_raw(indices, inner)
91 }
92}
93
94impl TryFrom<TensorExpression> for UnitarySystemExpression {
95 type Error = String;
97
98 fn try_from(value: TensorExpression) -> Result<Self, Self::Error> {
99 let mut num_unitaries = None;
100 let mut input_radices = vec![];
101 let mut output_radices = vec![];
102 for idx in value.indices() {
103 match idx.direction() {
104 IndexDirection::Batch => match &mut num_unitaries {
105 Some(n) => *n *= idx.index_size(),
106 None => num_unitaries = Some(idx.index_size()),
107 },
108 IndexDirection::Input => {
109 input_radices.push(idx.index_size());
110 }
111 IndexDirection::Output => {
112 output_radices.push(idx.index_size());
113 }
114 _ => unreachable!(),
115 }
116 }
117
118 if input_radices != output_radices {
119 return Err(String::from(
120 "Non-square matrix tensor cannot be converted to a unitary.",
121 ));
122 }
123
124 Ok(UnitarySystemExpression {
125 inner: value.into(),
126 radices: input_radices.into(),
127 num_unitaries: num_unitaries.unwrap_or(1),
128 })
129 }
130}
131
#[cfg(feature = "python")]
mod python {
    use super::*;
    use crate::python::PyExpressionRegistrar;
    use pyo3::exceptions::PyValueError;
    use pyo3::prelude::*;
    use qudit_core::Radix;

    /// Python-facing wrapper around [`UnitarySystemExpression`], exported to
    /// Python under the name `UnitarySystemExpression`.
    #[pyclass]
    #[pyo3(name = "UnitarySystemExpression")]
    pub struct PyUnitarySystemExpression {
        expr: UnitarySystemExpression,
    }

    #[pymethods]
    impl PyUnitarySystemExpression {
        /// Builds a unitary system from expression source text.
        ///
        /// Raises `ValueError` if the parsed tensor cannot be converted into a
        /// unitary system (for example, when it is non-square).
        #[new]
        fn new(expr: String) -> PyResult<Self> {
            // Go through `TryFrom` rather than `UnitarySystemExpression::new`,
            // which panics and would surface in Python as a `PanicException`.
            // NOTE(review): `TensorExpression::new` may itself panic on
            // malformed source text — confirm and map that to an error too.
            let expr = UnitarySystemExpression::try_from(TensorExpression::new(expr))
                .map_err(PyValueError::new_err)?;
            Ok(Self { expr })
        }

        /// Number of parameters reported by the underlying expression.
        fn num_params(&self) -> usize {
            self.expr.num_params()
        }

        /// Name of the underlying expression.
        fn name(&self) -> String {
            self.expr.name().to_string()
        }

        /// Radix of each qudit in the system.
        fn radices(&self) -> Vec<Radix> {
            self.expr.radices.to_vec()
        }

        /// Number of qudits in the system.
        fn num_qudits(&self) -> usize {
            self.expr.num_qudits()
        }

        /// Number of unitaries in the system.
        fn num_unitaries(&self) -> usize {
            self.expr.num_unitaries
        }

        /// Dimension reported by the system's radices.
        fn dimension(&self) -> usize {
            self.expr.radices.dimension()
        }

        fn __repr__(&self) -> String {
            format!(
                "UnitarySystemExpression(name='{}', radices={:?}, num_unitaries={}, params={})",
                self.expr.name(),
                self.expr.radices.to_vec(),
                self.expr.num_unitaries,
                self.expr.num_params()
            )
        }
    }

    impl From<UnitarySystemExpression> for PyUnitarySystemExpression {
        fn from(value: UnitarySystemExpression) -> Self {
            PyUnitarySystemExpression { expr: value }
        }
    }

    impl From<PyUnitarySystemExpression> for UnitarySystemExpression {
        fn from(value: PyUnitarySystemExpression) -> Self {
            value.expr
        }
    }

    impl<'py> IntoPyObject<'py> for UnitarySystemExpression {
        type Target = <PyUnitarySystemExpression as IntoPyObject<'py>>::Target;
        type Output = Bound<'py, Self::Target>;
        type Error = PyErr;

        /// Wraps the expression in a new Python `UnitarySystemExpression` object.
        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            let py_expr = PyUnitarySystemExpression::from(self);
            Bound::new(py, py_expr)
        }
    }

    impl<'a, 'py> FromPyObject<'a, 'py> for UnitarySystemExpression {
        type Error = PyErr;

        /// Extracts a clone of the expression held by a Python
        /// `UnitarySystemExpression` object.
        fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
            let py_expr: PyRef<PyUnitarySystemExpression> = ob.extract()?;
            Ok(py_expr.expr.clone())
        }
    }

    /// Adds the `UnitarySystemExpression` class to `parent_module`.
    fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
        parent_module.add_class::<PyUnitarySystemExpression>()?;
        Ok(())
    }
    inventory::submit!(PyExpressionRegistrar { func: register });
}