qudit_expr/expressions/
isometry.rs1use std::ops::{Deref, DerefMut};
2
3use crate::{
4 GenerationShape, TensorExpression,
5 expressions::JittableExpression,
6 index::{IndexDirection, TensorIndex},
7};
8
9use super::NamedExpression;
10use qudit_core::QuditSystem;
11use qudit_core::Radices;
12
13#[derive(PartialEq, Eq, Debug, Clone)]
14pub struct IsometryExpression {
15 inner: NamedExpression,
16 input_radices: Radices,
17 output_radices: Radices,
18}
19
20impl JittableExpression for IsometryExpression {
21 fn generation_shape(&self) -> GenerationShape {
22 GenerationShape::Matrix(
23 self.output_radices.dimension(),
24 self.input_radices.dimension(),
25 )
26 }
27}
28
29impl AsRef<NamedExpression> for IsometryExpression {
30 fn as_ref(&self) -> &NamedExpression {
31 &self.inner
32 }
33}
34
35impl From<IsometryExpression> for NamedExpression {
36 fn from(value: IsometryExpression) -> Self {
37 value.inner
38 }
39}
40
41impl Deref for IsometryExpression {
42 type Target = NamedExpression;
43
44 fn deref(&self) -> &Self::Target {
45 &self.inner
46 }
47}
48
49impl DerefMut for IsometryExpression {
50 fn deref_mut(&mut self) -> &mut Self::Target {
51 &mut self.inner
52 }
53}
54
55impl From<IsometryExpression> for TensorExpression {
56 fn from(value: IsometryExpression) -> Self {
57 let IsometryExpression {
58 inner,
59 input_radices,
60 output_radices,
61 } = value;
62 let indices = output_radices
64 .iter()
65 .map(|r| (IndexDirection::Output, usize::from(*r)))
66 .chain(
67 input_radices
68 .iter()
69 .map(|r| (IndexDirection::Input, usize::from(*r))),
70 )
71 .enumerate()
72 .map(|(i, (d, r))| TensorIndex::new(d, i, r))
73 .collect();
74 TensorExpression::from_raw(indices, inner)
75 }
76}
77
78impl TryFrom<TensorExpression> for IsometryExpression {
79 type Error = String;
81
82 fn try_from(value: TensorExpression) -> Result<Self, Self::Error> {
83 let mut input_radices = vec![];
84 let mut output_radices = vec![];
85 for idx in value.indices() {
86 match idx.direction() {
87 IndexDirection::Input => {
88 input_radices.push(idx.index_size());
89 }
90 IndexDirection::Output => {
91 output_radices.push(idx.index_size());
92 }
93 _ => {
94 return Err(String::from(
95 "Cannot convert a tensor with non-input, non-output indices to an isometry.",
96 ));
97 }
98 }
99 }
100
101 Ok(IsometryExpression {
102 inner: value.into(),
103 input_radices: input_radices.into(),
104 output_radices: output_radices.into(),
105 })
106 }
107}