qudit_tensor/network/
tensor.rs1use qudit_core::ParamInfo;
2use qudit_expr::ExpressionId;
3use qudit_expr::GenerationShape;
4
5use qudit_expr::index::IndexDirection;
6use qudit_expr::index::TensorIndex;
7
8#[derive(Debug, Clone)]
9pub struct QuditTensor {
10 pub indices: Vec<TensorIndex>,
11 pub expression: ExpressionId,
12 pub param_info: ParamInfo,
13}
14
15impl QuditTensor {
16 pub fn new(indices: Vec<TensorIndex>, expression: ExpressionId, param_info: ParamInfo) -> Self {
18 QuditTensor {
19 indices,
20 expression,
21 param_info,
22 }
23 }
24
25 pub fn num_indices(&self) -> usize {
26 self.indices.len()
27 }
28
29 pub fn shape(&self) -> GenerationShape {
30 (&self.indices).into()
31 }
32
33 pub fn batch_indices(&self) -> Vec<usize> {
35 self.indices
36 .iter()
37 .filter_map(|index| match index.direction() {
38 IndexDirection::Batch => Some(index.index_id()),
39 _ => None,
40 })
41 .collect()
42 }
43
44 pub fn output_indices(&self) -> Vec<usize> {
46 self.indices
47 .iter()
48 .filter_map(|index| match index.direction() {
49 IndexDirection::Output => Some(index.index_id()),
50 _ => None,
51 })
52 .collect()
53 }
54
55 pub fn input_indices(&self) -> Vec<usize> {
57 self.indices
58 .iter()
59 .filter_map(|index| match index.direction() {
60 IndexDirection::Input => Some(index.index_id()),
61 _ => None,
62 })
63 .collect()
64 }
65
66 pub fn batch_sizes(&self) -> Vec<usize> {
68 self.indices
69 .iter()
70 .filter_map(|index| match index.direction() {
71 IndexDirection::Batch => Some(index.index_size()),
72 _ => None,
73 })
74 .collect()
75 }
76
77 pub fn output_sizes(&self) -> Vec<usize> {
79 self.indices
80 .iter()
81 .filter_map(|index| match index.direction() {
82 IndexDirection::Output => Some(index.index_size()),
83 _ => None,
84 })
85 .collect()
86 }
87
88 pub fn input_sizes(&self) -> Vec<usize> {
90 self.indices
91 .iter()
92 .filter_map(|index| match index.direction() {
93 IndexDirection::Input => Some(index.index_size()),
94 _ => None,
95 })
96 .collect()
97 }
98
99 pub fn rank(&self) -> usize {
101 self.indices.len()
102 }
103}