// qudit_expr/expressions/isometry.rs

1use std::ops::{Deref, DerefMut};
2
3use crate::{
4    GenerationShape, TensorExpression,
5    expressions::JittableExpression,
6    index::{IndexDirection, TensorIndex},
7};
8
9use super::NamedExpression;
10use qudit_core::QuditSystem;
11use qudit_core::Radices;
12
13#[derive(PartialEq, Eq, Debug, Clone)]
14pub struct IsometryExpression {
15    inner: NamedExpression,
16    input_radices: Radices,
17    output_radices: Radices,
18}
19
20impl JittableExpression for IsometryExpression {
21    fn generation_shape(&self) -> GenerationShape {
22        GenerationShape::Matrix(
23            self.output_radices.dimension(),
24            self.input_radices.dimension(),
25        )
26    }
27}
28
29impl AsRef<NamedExpression> for IsometryExpression {
30    fn as_ref(&self) -> &NamedExpression {
31        &self.inner
32    }
33}
34
35impl From<IsometryExpression> for NamedExpression {
36    fn from(value: IsometryExpression) -> Self {
37        value.inner
38    }
39}
40
41impl Deref for IsometryExpression {
42    type Target = NamedExpression;
43
44    fn deref(&self) -> &Self::Target {
45        &self.inner
46    }
47}
48
49impl DerefMut for IsometryExpression {
50    fn deref_mut(&mut self) -> &mut Self::Target {
51        &mut self.inner
52    }
53}
54
55impl From<IsometryExpression> for TensorExpression {
56    fn from(value: IsometryExpression) -> Self {
57        let IsometryExpression {
58            inner,
59            input_radices,
60            output_radices,
61        } = value;
62        // TODO: add a proper implementation of into_iter for QuditRadices
63        let indices = output_radices
64            .iter()
65            .map(|r| (IndexDirection::Output, usize::from(*r)))
66            .chain(
67                input_radices
68                    .iter()
69                    .map(|r| (IndexDirection::Input, usize::from(*r))),
70            )
71            .enumerate()
72            .map(|(i, (d, r))| TensorIndex::new(d, i, r))
73            .collect();
74        TensorExpression::from_raw(indices, inner)
75    }
76}
77
78impl TryFrom<TensorExpression> for IsometryExpression {
79    // TODO: Come up with proper error handling
80    type Error = String;
81
82    fn try_from(value: TensorExpression) -> Result<Self, Self::Error> {
83        let mut input_radices = vec![];
84        let mut output_radices = vec![];
85        for idx in value.indices() {
86            match idx.direction() {
87                IndexDirection::Input => {
88                    input_radices.push(idx.index_size());
89                }
90                IndexDirection::Output => {
91                    output_radices.push(idx.index_size());
92                }
93                _ => {
94                    return Err(String::from(
95                        "Cannot convert a tensor with non-input, non-output indices to an isometry.",
96                    ));
97                }
98            }
99        }
100
101        Ok(IsometryExpression {
102            inner: value.into(),
103            input_radices: input_radices.into(),
104            output_radices: output_radices.into(),
105        })
106    }
107}