// qudit_tensor/network/tensor.rs

1use qudit_core::ParamInfo;
2use qudit_expr::ExpressionId;
3use qudit_expr::GenerationShape;
4
5use qudit_expr::index::IndexDirection;
6use qudit_expr::index::TensorIndex;
7
8#[derive(Debug, Clone)]
9pub struct QuditTensor {
10    pub indices: Vec<TensorIndex>,
11    pub expression: ExpressionId,
12    pub param_info: ParamInfo,
13}
14
15impl QuditTensor {
16    /// Construct a new tensor object from a tensor expression and param indices.
17    pub fn new(indices: Vec<TensorIndex>, expression: ExpressionId, param_info: ParamInfo) -> Self {
18        QuditTensor {
19            indices,
20            expression,
21            param_info,
22        }
23    }
24
25    pub fn num_indices(&self) -> usize {
26        self.indices.len()
27    }
28
29    pub fn shape(&self) -> GenerationShape {
30        (&self.indices).into()
31    }
32
33    /// Returns a vector of index IDs for all batch legs of the tensor.
34    pub fn batch_indices(&self) -> Vec<usize> {
35        self.indices
36            .iter()
37            .filter_map(|index| match index.direction() {
38                IndexDirection::Batch => Some(index.index_id()),
39                _ => None,
40            })
41            .collect()
42    }
43
44    /// Returns a vector of index IDs for all output legs of the tensor.
45    pub fn output_indices(&self) -> Vec<usize> {
46        self.indices
47            .iter()
48            .filter_map(|index| match index.direction() {
49                IndexDirection::Output => Some(index.index_id()),
50                _ => None,
51            })
52            .collect()
53    }
54
55    /// Returns a vector of index IDs for all input legs of the tensor.
56    pub fn input_indices(&self) -> Vec<usize> {
57        self.indices
58            .iter()
59            .filter_map(|index| match index.direction() {
60                IndexDirection::Input => Some(index.index_id()),
61                _ => None,
62            })
63            .collect()
64    }
65
66    /// Returns a vector of sizes for all batch legs of the tensor.
67    pub fn batch_sizes(&self) -> Vec<usize> {
68        self.indices
69            .iter()
70            .filter_map(|index| match index.direction() {
71                IndexDirection::Batch => Some(index.index_size()),
72                _ => None,
73            })
74            .collect()
75    }
76
77    /// Returns a vector of sizes for all output legs of the tensor.
78    pub fn output_sizes(&self) -> Vec<usize> {
79        self.indices
80            .iter()
81            .filter_map(|index| match index.direction() {
82                IndexDirection::Output => Some(index.index_size()),
83                _ => None,
84            })
85            .collect()
86    }
87
88    /// Returns a vector of sizes for all input legs of the tensor.
89    pub fn input_sizes(&self) -> Vec<usize> {
90        self.indices
91            .iter()
92            .filter_map(|index| match index.direction() {
93                IndexDirection::Input => Some(index.index_size()),
94                _ => None,
95            })
96            .collect()
97    }
98
99    /// Returns the rank of the tensor.
100    pub fn rank(&self) -> usize {
101        self.indices.len()
102    }
103}