qudit_expr/expressions/
brasys.rs

1use std::ops::{Deref, DerefMut};
2
3use crate::{
4    GenerationShape, TensorExpression,
5    expressions::JittableExpression,
6    index::{IndexDirection, TensorIndex},
7};
8
9use super::NamedExpression;
10use qudit_core::QuditSystem;
11use qudit_core::Radices;
12
13#[derive(PartialEq, Eq, Debug, Clone)]
14pub struct BraSystemExpression {
15    inner: NamedExpression,
16    radices: Radices,
17    num_states: usize,
18}
19
20impl BraSystemExpression {
21    pub fn new<T: AsRef<str>>(input: T) -> Self {
22        TensorExpression::new(input).try_into().unwrap()
23    }
24
25    pub fn num_qudits(&self) -> usize {
26        self.radices.num_qudits()
27    }
28}
29
30impl JittableExpression for BraSystemExpression {
31    fn generation_shape(&self) -> GenerationShape {
32        GenerationShape::Tensor3D(self.num_states, 1, self.radices.dimension())
33    }
34}
35
36impl AsRef<NamedExpression> for BraSystemExpression {
37    fn as_ref(&self) -> &NamedExpression {
38        &self.inner
39    }
40}
41
42impl From<BraSystemExpression> for NamedExpression {
43    fn from(value: BraSystemExpression) -> Self {
44        value.inner
45    }
46}
47
48impl Deref for BraSystemExpression {
49    type Target = NamedExpression;
50
51    fn deref(&self) -> &Self::Target {
52        &self.inner
53    }
54}
55
56impl DerefMut for BraSystemExpression {
57    fn deref_mut(&mut self) -> &mut Self::Target {
58        &mut self.inner
59    }
60}
61
// TODO: replace the individual `From<X> for TensorExpression` impls with a
// single blanket impl, for example:
//
// pub trait HasIndices {
//     fn indices(&self) -> &[TensorIndex];
// }

// impl<T: HasIndices + Into<NamedExpression>> From<T> for TensorExpression {
//     fn from(value: T) -> Self {
//         let indices = value.indices().iter().cloned().collect();
//         let inner = value.into();
//         TensorExpression::from_raw(indices, inner)
//     }
// }
74
75impl From<BraSystemExpression> for TensorExpression {
76    fn from(value: BraSystemExpression) -> Self {
77        let BraSystemExpression {
78            inner,
79            radices,
80            num_states,
81        } = value;
82        // TODO: add a proper implementation of into_iter for QuditRadices
83        let indices = [num_states]
84            .into_iter()
85            .map(|r| (IndexDirection::Batch, r))
86            .chain(
87                radices
88                    .iter()
89                    .map(|r| (IndexDirection::Input, usize::from(*r))),
90            )
91            .enumerate()
92            .map(|(i, (d, r))| TensorIndex::new(d, i, r))
93            .collect();
94        TensorExpression::from_raw(indices, inner)
95    }
96}
97
98impl TryFrom<TensorExpression> for BraSystemExpression {
99    // TODO: Come up with proper error handling
100    type Error = String;
101
102    fn try_from(value: TensorExpression) -> Result<Self, Self::Error> {
103        let mut num_states = None;
104        let mut radices = vec![];
105        for idx in value.indices() {
106            match idx.direction() {
107                IndexDirection::Batch => match num_states {
108                    Some(n) => num_states = Some(n * idx.index_size()),
109                    None => num_states = Some(idx.index_size()),
110                },
111                IndexDirection::Input => {
112                    radices.push(idx.index_size());
113                }
114                _ => {
115                    if idx.index_size() > 1 {
116                        return Err(String::from(
117                            "Cannot convert a tensor with non-input or batch indices to a bra system.",
118                        ));
119                    }
120                }
121            }
122        }
123
124        Ok(BraSystemExpression {
125            inner: value.into(),
126            radices: radices.into(),
127            num_states: num_states.unwrap_or(1),
128        })
129    }
130}
131
132#[cfg(feature = "python")]
133mod python {
134    use super::*;
135    use crate::python::PyExpressionRegistrar;
136    use pyo3::prelude::*;
137    use qudit_core::Radix;
138
139    #[pyclass]
140    #[pyo3(name = "BraSystemExpression")]
141    pub struct PyBraSystemExpression {
142        expr: BraSystemExpression,
143    }
144
145    #[pymethods]
146    impl PyBraSystemExpression {
147        #[new]
148        fn new(expr: String) -> Self {
149            Self {
150                expr: BraSystemExpression::new(expr),
151            }
152        }
153
154        fn num_params(&self) -> usize {
155            self.expr.num_params()
156        }
157
158        fn name(&self) -> String {
159            self.expr.name().to_string()
160        }
161
162        fn radices(&self) -> Vec<Radix> {
163            self.expr.radices.to_vec()
164        }
165
166        fn num_qudits(&self) -> usize {
167            self.expr.num_qudits()
168        }
169
170        fn num_states(&self) -> usize {
171            self.expr.num_states
172        }
173
174        fn dimension(&self) -> usize {
175            self.expr.radices.dimension()
176        }
177
178        fn __repr__(&self) -> String {
179            format!(
180                "BraSystemExpression(name='{}', radices={:?}, num_states={}, params={})",
181                self.expr.name(),
182                self.expr.radices.to_vec(),
183                self.expr.num_states,
184                self.expr.num_params()
185            )
186        }
187    }
188
189    impl From<BraSystemExpression> for PyBraSystemExpression {
190        fn from(value: BraSystemExpression) -> Self {
191            PyBraSystemExpression { expr: value }
192        }
193    }
194
195    impl From<PyBraSystemExpression> for BraSystemExpression {
196        fn from(value: PyBraSystemExpression) -> Self {
197            value.expr
198        }
199    }
200
201    impl<'py> IntoPyObject<'py> for BraSystemExpression {
202        type Target = <PyBraSystemExpression as IntoPyObject<'py>>::Target;
203        type Output = Bound<'py, Self::Target>;
204        type Error = PyErr;
205
206        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
207            let py_expr = PyBraSystemExpression::from(self);
208            Bound::new(py, py_expr)
209        }
210    }
211
212    impl<'a, 'py> FromPyObject<'a, 'py> for BraSystemExpression {
213        type Error = PyErr;
214
215        fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
216            let py_expr: PyRef<PyBraSystemExpression> = ob.extract()?;
217            Ok(py_expr.expr.clone())
218        }
219    }
220
221    /// Registers the BraSystemExpression class with the Python module.
222    fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
223        parent_module.add_class::<PyBraSystemExpression>()?;
224        Ok(())
225    }
226    inventory::submit!(PyExpressionRegistrar { func: register });
227}