qudit_expr/expressions/
ket.rs1use std::ops::{Deref, DerefMut};
2
3use crate::{
4 GenerationShape, TensorExpression,
5 expressions::JittableExpression,
6 index::{IndexDirection, TensorIndex},
7};
8
9use super::ComplexExpression;
10use super::NamedExpression;
11use qudit_core::QuditSystem;
12use qudit_core::Radices;
13
14#[derive(PartialEq, Eq, Debug, Clone)]
15pub struct KetExpression {
16 inner: NamedExpression,
17 radices: Radices,
18}
19
20impl KetExpression {
21 pub fn new<T: AsRef<str>>(input: T) -> Self {
22 TensorExpression::new(input).try_into().unwrap()
23 }
24
25 pub fn zero<R: Into<Radices>>(radices: R) -> Self {
26 let name = "zero";
27 let radices = radices.into();
28 let mut body = vec![ComplexExpression::zero(); radices.dimension()];
29 body[0] = ComplexExpression::one();
30 let variables = vec![];
31 let inner = NamedExpression::new(name, variables, body);
32 KetExpression { inner, radices }
33 }
34}
35
36impl JittableExpression for KetExpression {
37 fn generation_shape(&self) -> GenerationShape {
38 GenerationShape::Matrix(self.radices.dimension(), 1)
39 }
40}
41
42impl AsRef<NamedExpression> for KetExpression {
43 fn as_ref(&self) -> &NamedExpression {
44 &self.inner
45 }
46}
47
48impl From<KetExpression> for NamedExpression {
49 fn from(value: KetExpression) -> Self {
50 value.inner
51 }
52}
53
54impl Deref for KetExpression {
55 type Target = NamedExpression;
56
57 fn deref(&self) -> &Self::Target {
58 &self.inner
59 }
60}
61
62impl DerefMut for KetExpression {
63 fn deref_mut(&mut self) -> &mut Self::Target {
64 &mut self.inner
65 }
66}
67
68impl QuditSystem for KetExpression {
69 fn radices(&self) -> Radices {
70 self.radices.clone()
71 }
72
73 fn num_qudits(&self) -> usize {
74 self.radices().num_qudits()
75 }
76}
77
78impl From<KetExpression> for TensorExpression {
79 fn from(value: KetExpression) -> Self {
80 let KetExpression { inner, radices } = value;
81 let indices = radices
83 .iter()
84 .enumerate()
85 .map(|(i, r)| TensorIndex::new(IndexDirection::Output, i, usize::from(*r)))
86 .collect();
87 TensorExpression::from_raw(indices, inner)
88 }
89}
90
91impl TryFrom<TensorExpression> for KetExpression {
92 type Error = String;
94
95 fn try_from(value: TensorExpression) -> Result<Self, Self::Error> {
96 if value
97 .indices()
98 .iter()
99 .any(|idx| idx.direction() != IndexDirection::Output && idx.index_size() != 1)
100 {
101 return Err(String::from(
102 "Cannot convert a tensor with non-output indices to a ket.",
103 ));
104 }
105 let radices = Radices::from_iter(
106 value
107 .indices()
108 .iter()
109 .filter(|idx| idx.index_size() > 1)
110 .map(|idx| idx.index_size()),
111 );
112 Ok(KetExpression {
113 inner: value.into(),
114 radices,
115 })
116 }
117}
118
#[cfg(feature = "python")]
mod python {
    use super::*;
    use crate::python::PyExpressionRegistrar;
    use pyo3::exceptions::PyValueError;
    use pyo3::prelude::*;
    use qudit_core::Radix;

    /// Python-facing wrapper around [`KetExpression`], exposed to Python as
    /// `KetExpression`.
    #[pyclass]
    #[pyo3(name = "KetExpression")]
    pub struct PyKetExpression {
        expr: KetExpression,
    }

    #[pymethods]
    impl PyKetExpression {
        /// Parses `expr` into a ket expression.
        ///
        /// Raises `ValueError` when the parsed tensor cannot be converted into
        /// a ket, instead of panicking. Panics raised inside
        /// `TensorExpression::new` (if any) are not intercepted here.
        #[new]
        fn new(expr: String) -> PyResult<Self> {
            let tensor = TensorExpression::new(expr);
            let expr: KetExpression = tensor.try_into().map_err(PyValueError::new_err)?;
            Ok(Self { expr })
        }

        /// Number of parameters of the underlying expression.
        fn num_params(&self) -> usize {
            self.expr.num_params()
        }

        /// Name of the underlying expression.
        fn name(&self) -> String {
            self.expr.name().to_string()
        }

        /// Radix of each qudit, in index order.
        fn radices(&self) -> Vec<Radix> {
            self.expr.radices().to_vec()
        }

        /// Dimension of the ket.
        fn dimension(&self) -> usize {
            self.expr.dimension()
        }

        fn __repr__(&self) -> String {
            format!(
                "KetExpression(name='{}', radices={:?}, params={})",
                self.expr.name(),
                self.expr.radices().to_vec(),
                self.expr.num_params()
            )
        }
    }

    impl From<KetExpression> for PyKetExpression {
        fn from(value: KetExpression) -> Self {
            PyKetExpression { expr: value }
        }
    }

    impl From<PyKetExpression> for KetExpression {
        fn from(value: PyKetExpression) -> Self {
            value.expr
        }
    }

    impl<'py> IntoPyObject<'py> for KetExpression {
        type Target = <PyKetExpression as IntoPyObject<'py>>::Target;
        type Output = Bound<'py, Self::Target>;
        type Error = PyErr;

        /// Wraps the ket in a new Python `KetExpression` object.
        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            let py_expr = PyKetExpression::from(self);
            Bound::new(py, py_expr)
        }
    }

    impl<'a, 'py> FromPyObject<'a, 'py> for KetExpression {
        type Error = PyErr;

        /// Extracts a clone of the ket held by a Python `KetExpression` object.
        fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
            let py_expr: PyRef<PyKetExpression> = ob.extract()?;
            Ok(py_expr.expr.clone())
        }
    }

    /// Adds the `KetExpression` class to the parent Python module.
    fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
        parent_module.add_class::<PyKetExpression>()?;
        Ok(())
    }
    inventory::submit!(PyExpressionRegistrar { func: register });
}