// qudit_expr/expressions/ket.rs

1use std::ops::{Deref, DerefMut};
2
3use crate::{
4    GenerationShape, TensorExpression,
5    expressions::JittableExpression,
6    index::{IndexDirection, TensorIndex},
7};
8
9use super::ComplexExpression;
10use super::NamedExpression;
11use qudit_core::QuditSystem;
12use qudit_core::Radices;
13
14#[derive(PartialEq, Eq, Debug, Clone)]
15pub struct KetExpression {
16    inner: NamedExpression,
17    radices: Radices,
18}
19
20impl KetExpression {
21    pub fn new<T: AsRef<str>>(input: T) -> Self {
22        TensorExpression::new(input).try_into().unwrap()
23    }
24
25    pub fn zero<R: Into<Radices>>(radices: R) -> Self {
26        let name = "zero";
27        let radices = radices.into();
28        let mut body = vec![ComplexExpression::zero(); radices.dimension()];
29        body[0] = ComplexExpression::one();
30        let variables = vec![];
31        let inner = NamedExpression::new(name, variables, body);
32        KetExpression { inner, radices }
33    }
34}
35
36impl JittableExpression for KetExpression {
37    fn generation_shape(&self) -> GenerationShape {
38        GenerationShape::Matrix(self.radices.dimension(), 1)
39    }
40}
41
42impl AsRef<NamedExpression> for KetExpression {
43    fn as_ref(&self) -> &NamedExpression {
44        &self.inner
45    }
46}
47
48impl From<KetExpression> for NamedExpression {
49    fn from(value: KetExpression) -> Self {
50        value.inner
51    }
52}
53
54impl Deref for KetExpression {
55    type Target = NamedExpression;
56
57    fn deref(&self) -> &Self::Target {
58        &self.inner
59    }
60}
61
62impl DerefMut for KetExpression {
63    fn deref_mut(&mut self) -> &mut Self::Target {
64        &mut self.inner
65    }
66}
67
68impl QuditSystem for KetExpression {
69    fn radices(&self) -> Radices {
70        self.radices.clone()
71    }
72
73    fn num_qudits(&self) -> usize {
74        self.radices().num_qudits()
75    }
76}
77
78impl From<KetExpression> for TensorExpression {
79    fn from(value: KetExpression) -> Self {
80        let KetExpression { inner, radices } = value;
81        // TODO: add a proper implementation of into_iter for QuditRadices
82        let indices = radices
83            .iter()
84            .enumerate()
85            .map(|(i, r)| TensorIndex::new(IndexDirection::Output, i, usize::from(*r)))
86            .collect();
87        TensorExpression::from_raw(indices, inner)
88    }
89}
90
91impl TryFrom<TensorExpression> for KetExpression {
92    // TODO: Come up with proper error handling
93    type Error = String;
94
95    fn try_from(value: TensorExpression) -> Result<Self, Self::Error> {
96        if value
97            .indices()
98            .iter()
99            .any(|idx| idx.direction() != IndexDirection::Output && idx.index_size() != 1)
100        {
101            return Err(String::from(
102                "Cannot convert a tensor with non-output indices to a ket.",
103            ));
104        }
105        let radices = Radices::from_iter(
106            value
107                .indices()
108                .iter()
109                .filter(|idx| idx.index_size() > 1)
110                .map(|idx| idx.index_size()),
111        );
112        Ok(KetExpression {
113            inner: value.into(),
114            radices,
115        })
116    }
117}
118
#[cfg(feature = "python")]
mod python {
    use super::*;
    use crate::python::PyExpressionRegistrar;
    use pyo3::exceptions::PyValueError;
    use pyo3::prelude::*;
    use qudit_core::Radix;

    /// Python wrapper around [`KetExpression`], exposed to Python as `KetExpression`.
    #[pyclass]
    #[pyo3(name = "KetExpression")]
    pub struct PyKetExpression {
        expr: KetExpression,
    }

    #[pymethods]
    impl PyKetExpression {
        /// Parses `expr` and converts it into a ket.
        ///
        /// Raises `ValueError` when the parsed tensor cannot be converted into a
        /// ket, instead of panicking across the FFI boundary. Any failure inside
        /// `TensorExpression::new` itself is not intercepted here.
        #[new]
        fn new(expr: String) -> PyResult<Self> {
            let tensor = TensorExpression::new(expr);
            let expr = KetExpression::try_from(tensor).map_err(PyValueError::new_err)?;
            Ok(Self { expr })
        }

        /// Number of parameters of the underlying expression.
        fn num_params(&self) -> usize {
            self.expr.num_params()
        }

        /// Name of the underlying expression.
        fn name(&self) -> String {
            self.expr.name().to_string()
        }

        /// Radices of the qudits the ket is defined over.
        fn radices(&self) -> Vec<Radix> {
            self.expr.radices().to_vec()
        }

        /// Dimension of the ket.
        fn dimension(&self) -> usize {
            self.expr.dimension()
        }

        fn __repr__(&self) -> String {
            format!(
                "KetExpression(name='{}', radices={:?}, params={})",
                self.expr.name(),
                self.expr.radices().to_vec(),
                self.expr.num_params()
            )
        }
    }

    impl From<KetExpression> for PyKetExpression {
        fn from(value: KetExpression) -> Self {
            PyKetExpression { expr: value }
        }
    }

    impl From<PyKetExpression> for KetExpression {
        fn from(value: PyKetExpression) -> Self {
            value.expr
        }
    }

    impl<'py> IntoPyObject<'py> for KetExpression {
        type Target = <PyKetExpression as IntoPyObject<'py>>::Target;
        type Output = Bound<'py, Self::Target>;
        type Error = PyErr;

        /// Wraps the ket in a new Python `KetExpression` object.
        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            let py_expr = PyKetExpression::from(self);
            Bound::new(py, py_expr)
        }
    }

    impl<'a, 'py> FromPyObject<'a, 'py> for KetExpression {
        type Error = PyErr;

        /// Extracts a cloned ket from a Python `KetExpression` object.
        fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
            let py_expr: PyRef<PyKetExpression> = ob.extract()?;
            Ok(py_expr.expr.clone())
        }
    }

    /// Registers the KetExpression class with the Python module.
    fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
        parent_module.add_class::<PyKetExpression>()?;
        Ok(())
    }
    inventory::submit!(PyExpressionRegistrar { func: register });
}