qudit_expr/expressions/
ketsys.rs1use std::ops::{Deref, DerefMut};
2
3use crate::{
4 GenerationShape, TensorExpression,
5 expressions::JittableExpression,
6 index::{IndexDirection, TensorIndex},
7};
8
9use super::NamedExpression;
10use qudit_core::QuditSystem;
11use qudit_core::Radices;
12
13#[derive(PartialEq, Eq, Debug, Clone)]
14pub struct KetSystemExpression {
15 inner: NamedExpression,
16 radices: Radices,
17 num_states: usize,
18}
19
20impl JittableExpression for KetSystemExpression {
21 fn generation_shape(&self) -> GenerationShape {
22 GenerationShape::Tensor3D(self.num_states, self.radices.dimension(), 1)
23 }
24}
25
26impl AsRef<NamedExpression> for KetSystemExpression {
27 fn as_ref(&self) -> &NamedExpression {
28 &self.inner
29 }
30}
31
32impl From<KetSystemExpression> for NamedExpression {
33 fn from(value: KetSystemExpression) -> Self {
34 value.inner
35 }
36}
37
38impl Deref for KetSystemExpression {
39 type Target = NamedExpression;
40
41 fn deref(&self) -> &Self::Target {
42 &self.inner
43 }
44}
45
46impl DerefMut for KetSystemExpression {
47 fn deref_mut(&mut self) -> &mut Self::Target {
48 &mut self.inner
49 }
50}
51
52impl From<KetSystemExpression> for TensorExpression {
53 fn from(value: KetSystemExpression) -> Self {
54 let KetSystemExpression {
55 inner,
56 radices,
57 num_states,
58 } = value;
59 let indices = [num_states]
61 .into_iter()
62 .map(|r| (IndexDirection::Batch, r))
63 .chain(
64 radices
65 .iter()
66 .map(|r| (IndexDirection::Output, usize::from(*r))),
67 )
68 .enumerate()
69 .map(|(i, (d, r))| TensorIndex::new(d, i, r))
70 .collect();
71 TensorExpression::from_raw(indices, inner)
72 }
73}
74
75impl TryFrom<TensorExpression> for KetSystemExpression {
76 type Error = String;
78
79 fn try_from(value: TensorExpression) -> Result<Self, Self::Error> {
80 let mut num_states = None;
81 let mut radices = vec![];
82 for idx in value.indices() {
83 match idx.direction() {
84 IndexDirection::Batch => match num_states {
85 Some(n) => num_states = Some(n * idx.index_size()),
86 None => num_states = Some(idx.index_size()),
87 },
88 IndexDirection::Output => {
89 radices.push(idx.index_size());
90 }
91 _ => {
92 return Err(String::from(
93 "Cannot convert a tensor with non-output or batch indices to a ket system.",
94 ));
95 }
96 }
97 }
98
99 Ok(KetSystemExpression {
100 inner: value.into(),
101 radices: radices.into(),
102 num_states: num_states.unwrap_or(1),
103 })
104 }
105}