// qudit_expr/expressions/ketsys.rs

1use std::ops::{Deref, DerefMut};
2
3use crate::{
4    GenerationShape, TensorExpression,
5    expressions::JittableExpression,
6    index::{IndexDirection, TensorIndex},
7};
8
9use super::NamedExpression;
10use qudit_core::QuditSystem;
11use qudit_core::Radices;
12
13#[derive(PartialEq, Eq, Debug, Clone)]
14pub struct KetSystemExpression {
15    inner: NamedExpression,
16    radices: Radices,
17    num_states: usize,
18}
19
20impl JittableExpression for KetSystemExpression {
21    fn generation_shape(&self) -> GenerationShape {
22        GenerationShape::Tensor3D(self.num_states, self.radices.dimension(), 1)
23    }
24}
25
26impl AsRef<NamedExpression> for KetSystemExpression {
27    fn as_ref(&self) -> &NamedExpression {
28        &self.inner
29    }
30}
31
32impl From<KetSystemExpression> for NamedExpression {
33    fn from(value: KetSystemExpression) -> Self {
34        value.inner
35    }
36}
37
38impl Deref for KetSystemExpression {
39    type Target = NamedExpression;
40
41    fn deref(&self) -> &Self::Target {
42        &self.inner
43    }
44}
45
46impl DerefMut for KetSystemExpression {
47    fn deref_mut(&mut self) -> &mut Self::Target {
48        &mut self.inner
49    }
50}
51
52impl From<KetSystemExpression> for TensorExpression {
53    fn from(value: KetSystemExpression) -> Self {
54        let KetSystemExpression {
55            inner,
56            radices,
57            num_states,
58        } = value;
59        // TODO: add a proper implementation of into_iter for QuditRadices
60        let indices = [num_states]
61            .into_iter()
62            .map(|r| (IndexDirection::Batch, r))
63            .chain(
64                radices
65                    .iter()
66                    .map(|r| (IndexDirection::Output, usize::from(*r))),
67            )
68            .enumerate()
69            .map(|(i, (d, r))| TensorIndex::new(d, i, r))
70            .collect();
71        TensorExpression::from_raw(indices, inner)
72    }
73}
74
75impl TryFrom<TensorExpression> for KetSystemExpression {
76    // TODO: Come up with proper error handling
77    type Error = String;
78
79    fn try_from(value: TensorExpression) -> Result<Self, Self::Error> {
80        let mut num_states = None;
81        let mut radices = vec![];
82        for idx in value.indices() {
83            match idx.direction() {
84                IndexDirection::Batch => match num_states {
85                    Some(n) => num_states = Some(n * idx.index_size()),
86                    None => num_states = Some(idx.index_size()),
87                },
88                IndexDirection::Output => {
89                    radices.push(idx.index_size());
90                }
91                _ => {
92                    return Err(String::from(
93                        "Cannot convert a tensor with non-output or batch indices to a ket system.",
94                    ));
95                }
96            }
97        }
98
99        Ok(KetSystemExpression {
100            inner: value.into(),
101            radices: radices.into(),
102            num_states: num_states.unwrap_or(1),
103        })
104    }
105}