qudit_expr/
shape.rs

1use crate::index::{IndexDirection, TensorIndex};
2use std::ops::Add;
3
/// The physical layout of a tensor as it is generated into memory.
///
/// A tensor may conceptually have any rank. However, the OpenQudit Expression
/// library always generates it into a buffer with between zero and four
/// physical axes, and this enum names which of those layouts is used.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` rank shapes
/// first by variant (lower rank first), then by extents.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum GenerationShape {
    /// Rank 0: a lone value.
    Scalar,

    /// Rank 1: `nelems` entries.
    Vector(usize),

    /// Rank 2: a matrix of `nrows` rows by `ncols` columns.
    Matrix(usize, usize),

    /// Rank 3: `nmats` stacked matrices, each `nrows` by `ncols`.
    Tensor3D(usize, usize, usize),

    /// Rank 4, typically used for derivatives: `(ntens, nmats, nrows, ncols)`.
    Tensor4D(usize, usize, usize, usize),
}
26
27impl std::fmt::Debug for GenerationShape {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        match self {
30            GenerationShape::Scalar => write!(f, "Scalar"),
31            GenerationShape::Vector(nelems) => write!(f, "Vector({})", nelems),
32            GenerationShape::Matrix(nrows, ncols) => write!(f, "Matrix({}, {})", nrows, ncols),
33            GenerationShape::Tensor3D(nmats, nrows, ncols) => {
34                write!(f, "Tensor3D({}, {}, {})", nmats, nrows, ncols)
35            }
36            GenerationShape::Tensor4D(ntens, nmats, nrows, ncols) => {
37                write!(f, "Tensor4D({}, {}, {}, {})", ntens, nmats, nrows, ncols)
38            }
39        }
40    }
41}
42
43impl GenerationShape {
44    /// Calculates the total number of elements in a tensor with this shape.
45    ///
46    /// # Returns
47    /// The total number of elements as `usize`.
48    pub fn num_elements(&self) -> usize {
49        match self {
50            GenerationShape::Scalar => 1,
51            GenerationShape::Vector(nelems) => *nelems,
52            GenerationShape::Matrix(nrows, ncols) => nrows * ncols,
53            GenerationShape::Tensor3D(nmats, nrows, ncols) => nmats * nrows * ncols,
54            GenerationShape::Tensor4D(ntens, nmats, nrows, ncols) => ntens * nmats * nrows * ncols,
55        }
56    }
57
58    /// Determines the shape of the derivative of a tensor with respect to `num_params` parameters.
59    ///
60    /// This method effectively prepends `num_params` to the current tensor's dimensions.
61    /// For example, the derivative of a `Scalar` with respect to `num_params` becomes a `Vector(num_params)`.
62    /// The derivative of a `Matrix(R, C)` becomes a `Tensor3D(num_params, R, C)`.
63    ///
64    /// # Arguments
65    /// * `num_params` - The number of parameters in the gradient.
66    ///
67    /// # Returns
68    /// A new `GenerationShape` representing the shape of the gradient.
69    ///
70    /// # See Also
71    /// - `[hessian_shape]` For the shape of a hessian tensor.
72    pub fn gradient_shape(&self, num_params: usize) -> Self {
73        match self {
74            GenerationShape::Scalar => GenerationShape::Vector(num_params),
75            GenerationShape::Vector(nelems) => GenerationShape::Matrix(num_params, *nelems),
76            GenerationShape::Matrix(nrows, ncols) => {
77                GenerationShape::Tensor3D(num_params, *nrows, *ncols)
78            }
79            GenerationShape::Tensor3D(nmats, nrows, ncols) => {
80                GenerationShape::Tensor4D(num_params, *nmats, *nrows, *ncols)
81            }
82            GenerationShape::Tensor4D(_, _, _, _) => {
83                panic!("Unable to find shape for gradient of 4D tensor.")
84            }
85        }
86    }
87
88    /// Determine the hessian shape of a tensor with this shape that has `num_params` parameters.
89    ///
90    /// # Arguments
91    /// * `num_params` - The number of parameters in the hessian.
92    ///
93    /// # Returns
94    /// A new `GenerationShape` representing the shape of the hessian.
95    ///
96    /// # See Also
97    /// - `[gradient_shape]` For the shape of a gradient tensor.
98    pub fn hessian_shape(&self, num_params: usize) -> Self {
99        let sym_sq_size = num_params * (num_params - 1) / 2; // TODO: Is it +1 or -1
100        match self {
101            GenerationShape::Scalar => GenerationShape::Vector(sym_sq_size),
102            GenerationShape::Vector(nelems) => GenerationShape::Matrix(sym_sq_size, *nelems),
103            GenerationShape::Matrix(nrows, ncols) => {
104                GenerationShape::Tensor3D(sym_sq_size, *nrows, *ncols)
105            }
106            GenerationShape::Tensor3D(nmats, nrows, ncols) => {
107                GenerationShape::Tensor4D(sym_sq_size, *nmats, *nrows, *ncols)
108            }
109            GenerationShape::Tensor4D(_, _, _, _) => {
110                panic!("Unable to find shape for Hessian of 4D tensor.")
111            }
112        }
113    }
114
115    /// Converts the tensor shape object to a vector of integers.
116    ///
117    /// # Returns
118    ///
119    /// A `Vec<usize>` containing the dimensions of the shape.
120    ///
121    /// # Examples
122    /// ```
123    /// use qudit_expr::GenerationShape;
124    ///
125    /// let scalar_shape = GenerationShape::Scalar;
126    /// assert_eq!(scalar_shape.to_vec(), Vec::<usize>::new());
127    ///
128    /// let vector_shape = GenerationShape::Vector(5);
129    /// assert_eq!(vector_shape.to_vec(), vec![5]);
130    ///
131    /// let matrix_shape = GenerationShape::Matrix(2, 3);
132    /// assert_eq!(matrix_shape.to_vec(), vec![2, 3]);
133    /// ```
134    pub fn to_vec(&self) -> Vec<usize> {
135        match self {
136            GenerationShape::Scalar => vec![],
137            GenerationShape::Vector(nelems) => vec![*nelems],
138            GenerationShape::Matrix(nrows, ncols) => vec![*nrows, *ncols],
139            GenerationShape::Tensor3D(nmats, nrows, ncols) => vec![*nmats, *nrows, *ncols],
140            GenerationShape::Tensor4D(ntens, nmats, nrows, ncols) => {
141                vec![*ntens, *nmats, *nrows, *ncols]
142            }
143        }
144    }
145
146    /// Checks if the current `GenerationShape` is strictly a scalar variant.
147    pub fn is_scalar(&self) -> bool {
148        matches!(self, GenerationShape::Scalar)
149    }
150
151    /// Checks if the current `GenerationShape` is strictly a vector variant.
152    pub fn is_vector(&self) -> bool {
153        matches!(self, GenerationShape::Vector(_))
154    }
155
156    /// Checks if the current `GenerationShape` is strictly a matrix variant.
157    pub fn is_matrix(&self) -> bool {
158        matches!(self, GenerationShape::Matrix(_, _))
159    }
160
161    /// Checks if the current `GenerationShape` is strictly a tensor3D variant.
162    pub fn is_tensor3d(&self) -> bool {
163        matches!(self, GenerationShape::Tensor3D(_, _, _))
164    }
165
166    /// Checks if the current `GenerationShape` is strictly a tensor4D variant.
167    pub fn is_tensor4d(&self) -> bool {
168        matches!(self, GenerationShape::Tensor4D(_, _, _, _))
169    }
170
171    /// Check if there is only one element.
172    ///
173    /// # Returns
174    /// `true` if the shape can be treated as a scalar, `false` otherwise.
175    ///
176    /// # Examples
177    /// ```
178    /// use qudit_expr::GenerationShape;
179    ///
180    /// let test_scalar = GenerationShape::Scalar;
181    /// let test_vector = GenerationShape::Vector(1);
182    /// let test_matrix = GenerationShape::Matrix(1, 1);
183    /// let test_tensor3d = GenerationShape::Tensor3D(1, 1, 1);
184    ///
185    /// let test_vector_2 = GenerationShape::Vector(9);
186    /// let test_matrix_2 = GenerationShape::Matrix(9, 9);
187    /// let test_tensor3d_2 = GenerationShape::Tensor3D(1, 9, 9);
188    ///
189    /// assert!(test_scalar.is_0d());
190    /// assert!(test_vector.is_0d());
191    /// assert!(test_matrix.is_0d());
192    /// assert!(test_tensor3d.is_0d());
193    ///
194    /// assert_eq!(test_vector_2.is_0d(), false);
195    /// assert_eq!(test_matrix_2.is_0d(), false);
196    /// assert_eq!(test_tensor3d_2.is_0d(), false);
197    /// ```
198    pub fn is_0d(&self) -> bool {
199        self.num_elements() == 1
200    }
201
202    /// Check if the shape can be conceptually treated as a 1d tensor.
203    ///
204    /// # Returns
205    /// `true` if the shape has exactly one dimension with 1 or more elements.
206    ///
207    /// # Examples
208    /// ```
209    /// use qudit_expr::GenerationShape;
210    ///
211    /// let test_vector = GenerationShape::Vector(9);
212    /// let test_matrix = GenerationShape::Matrix(1, 9);
213    /// let test_tensor3d = GenerationShape::Tensor3D(1, 1, 9);
214    ///
215    /// let test_scalar = GenerationShape::Scalar;
216    /// let test_matrix_2 = GenerationShape::Matrix(9, 9);
217    /// let test_tensor3d_2 = GenerationShape::Tensor3D(1, 9, 9);
218    ///
219    /// assert!(test_vector.is_1d());
220    /// assert!(test_matrix.is_1d());
221    /// assert!(test_tensor3d.is_1d());
222    ///
223    /// assert_eq!(test_scalar.is_1d(), false);
224    /// assert_eq!(test_matrix_2.is_1d(), false);
225    /// assert_eq!(test_tensor3d_2.is_1d(), false);
226    /// ```
227    pub fn is_1d(&self) -> bool {
228        match self {
229            GenerationShape::Scalar => false,
230            GenerationShape::Vector(_) => true,
231            GenerationShape::Matrix(nrows, ncols) => *nrows == 1 || *ncols == 1,
232            GenerationShape::Tensor3D(nmats, nrows, ncols) => {
233                let non_one_count = [*nmats, *nrows, *ncols].iter().filter(|&&d| d > 1).count();
234                non_one_count == 1
235            }
236            GenerationShape::Tensor4D(ntens, nmats, nrows, ncols) => {
237                let non_one_count = [*ntens, *nmats, *nrows, *ncols]
238                    .iter()
239                    .filter(|&&d| d > 1)
240                    .count();
241                non_one_count == 1
242            }
243        }
244    }
245
246    /// Checks if the current `GenerationShape` can be conceptually treated as a 2-dimensional matrix.
247    /// This is true for `GenerationShape` variants with a dimensionality of at least 2, with
248    /// any additional dimensions having size 1.
249    ///
250    /// # Returns
251    /// `true` if the shape can be treated as a matrix, `false` otherwise.
252    ///
253    /// # Examples
254    /// ```
255    /// use qudit_expr::GenerationShape;
256    ///
257    /// let test_scalar = GenerationShape::Scalar;
258    /// let test_vector = GenerationShape::Vector(1);
259    /// let test_tensor3d_2 = GenerationShape::Tensor3D(9, 9, 9);
260    /// let test_tensor_nd_2 = GenerationShape::Tensor4D(1, 9, 9, 9);
261    ///
262    /// let test_matrix = GenerationShape::Matrix(9, 9);
263    /// let test_tensor3d = GenerationShape::Tensor3D(1, 9, 9);
264    /// let test_tensor_nd = GenerationShape::Tensor4D(1, 1, 9, 9);
265    ///
266    /// assert_eq!(test_scalar.is_2d(), false);
267    /// assert_eq!(test_vector.is_2d(), false);
268    /// assert_eq!(test_tensor3d_2.is_2d(), false);
269    /// assert_eq!(test_tensor_nd_2.is_2d(), false);
270    ///
271    /// assert_eq!(test_matrix.is_2d(), true);
272    /// assert_eq!(test_tensor3d.is_2d(), true);
273    /// assert_eq!(test_tensor_nd.is_2d(), true);
274    /// ```
275    pub fn is_2d(&self) -> bool {
276        match self {
277            GenerationShape::Scalar => false,
278            GenerationShape::Vector(_) => false,
279            GenerationShape::Matrix(_, _) => true,
280            // A Tensor3D can be seen as a matrix if it's essentially a stack of column vectors,
281            // or perhaps if it represents a single matrix (nmats=1).
282            // The current implementation checks if ncols is 1, implying it's a stack of column vectors.
283            GenerationShape::Tensor3D(nmats, _, _) => *nmats == 1,
284            GenerationShape::Tensor4D(ntens, nmats, _, _) => *ntens == 1 && *nmats == 1,
285        }
286    }
287
288    /// Checks if the current `GenerationShape` can be conceptually treated as a 3D tensor.
289    /// This is true for `GenerationShape` variants with a dimensionality of at least 3, with
290    /// any additional dimensions having size 1.
291    ///
292    /// # Returns
293    /// `true` if the shape can be treated as a 3D tensor, `false` otherwise.
294    ///
295    /// # Examples
296    /// ```
297    /// use qudit_expr::GenerationShape;
298    ///
299    /// let test_scalar = GenerationShape::Scalar;
300    /// let test_vector = GenerationShape::Vector(1);
301    /// let test_matrix = GenerationShape::Matrix(1, 1);
302    /// let test_tensor_nd_2 = GenerationShape::Tensor4D(9, 1, 9, 9);
303    ///
304    /// let test_tensor3d = GenerationShape::Tensor3D(9, 9, 9);
305    /// let test_tensor_nd = GenerationShape::Tensor4D(1, 9, 9, 9);
306    ///
307    /// assert_eq!(test_scalar.is_3d(), false);
308    /// assert_eq!(test_vector.is_3d(), false);
309    /// assert_eq!(test_matrix.is_3d(), false);
310    /// assert_eq!(test_tensor_nd_2.is_3d(), false);
311    ///
312    /// assert_eq!(test_tensor3d.is_3d(), true);
313    /// assert_eq!(test_tensor_nd.is_3d(), true);
314    /// ```
315    pub fn is_3d(&self) -> bool {
316        match self {
317            GenerationShape::Scalar => false,
318            GenerationShape::Vector(_) => false,
319            GenerationShape::Matrix(_, _) => false,
320            GenerationShape::Tensor3D(_, _, _) => true,
321            GenerationShape::Tensor4D(ntens, _, _, _) => *ntens == 1,
322        }
323    }
324
325    /// Checks if the current `GenerationShape` can be conceptually treated as a 4D tensor.
326    /// This is true for `GenerationShape` variants with a dimensionality of at least 4, with
327    /// any additional dimensions having size 1.
328    ///
329    /// # Returns
330    /// `true` if the shape can be treated as a 4D tensor, `false` otherwise.
331    ///
332    /// # Examples
333    /// ```
334    /// use qudit_expr::GenerationShape;
335    ///
336    /// let test_scalar = GenerationShape::Scalar;
337    /// let test_vector = GenerationShape::Vector(1);
338    /// let test_matrix = GenerationShape::Matrix(1, 1);
339    /// let test_tensor3d = GenerationShape::Tensor3D(1, 1, 1);
340    ///
341    /// let test_tensor4d = GenerationShape::Tensor4D(9, 9, 9, 9);
342    /// let test_tensor_nd = GenerationShape::Tensor4D(1, 9, 9, 9);
343    ///
344    /// assert_eq!(test_scalar.is_4d(), false);
345    /// assert_eq!(test_vector.is_4d(), false);
346    /// assert_eq!(test_matrix.is_4d(), false);
347    /// assert_eq!(test_tensor3d.is_4d(), false);
348    ///
349    /// assert_eq!(test_tensor4d.is_4d(), true);
350    /// assert_eq!(test_tensor_nd.is_4d(), true);
351    /// ```
352    pub fn is_4d(&self) -> bool {
353        match self {
354            GenerationShape::Scalar => false,
355            GenerationShape::Vector(_) => false,
356            GenerationShape::Matrix(_, _) => false,
357            GenerationShape::Tensor3D(_, _, _) => false,
358            GenerationShape::Tensor4D(_, _, _, _) => true,
359        }
360    }
361
362    /// Returns the number of columns for the current shape.
363    ///
364    /// # Returns
365    /// The number of columns.
366    ///
367    /// # Examples
368    /// ```
369    /// use qudit_expr::GenerationShape;
370    /// let matrix_shape = GenerationShape::Matrix(2, 3);
371    /// assert_eq!(matrix_shape.ncols(), 3);
372    /// ```
373    pub fn ncols(&self) -> usize {
374        match self {
375            GenerationShape::Scalar => 1,
376            GenerationShape::Vector(ncols) => *ncols,
377            GenerationShape::Matrix(_, ncols) => *ncols,
378            GenerationShape::Tensor3D(_, _, ncols) => *ncols,
379            GenerationShape::Tensor4D(_, _, _, ncols) => *ncols,
380        }
381    }
382
383    /// Returns the number of rows for the current shape.
384    ///
385    /// # Returns
386    /// The number of rows.
387    ///
388    /// # Examples
389    /// ```
390    /// use qudit_expr::GenerationShape;
391    /// let matrix_shape = GenerationShape::Matrix(2, 3);
392    /// assert_eq!(matrix_shape.nrows(), 2);
393    /// ```
394    pub fn nrows(&self) -> usize {
395        match self {
396            GenerationShape::Scalar => 1,
397            GenerationShape::Vector(_) => 1,
398            GenerationShape::Matrix(nrows, _) => *nrows,
399            GenerationShape::Tensor3D(_, nrows, _) => *nrows,
400            GenerationShape::Tensor4D(_, _, nrows, _) => *nrows,
401        }
402    }
403
404    /// Returns the number of matrices for the current shape.
405    ///
406    /// # Returns
407    /// The number of matrices.
408    ///
409    /// # Examples
410    /// ```
411    /// use qudit_expr::GenerationShape;
412    /// let tensor3d_shape = GenerationShape::Tensor3D(5, 2, 3);
413    /// assert_eq!(tensor3d_shape.nmats(), 5);
414    /// ```
415    pub fn nmats(&self) -> usize {
416        match self {
417            GenerationShape::Scalar => 1,
418            GenerationShape::Vector(_) => 1,
419            GenerationShape::Matrix(_, _) => 1,
420            GenerationShape::Tensor3D(nmats, _, _) => *nmats,
421            GenerationShape::Tensor4D(_, nmats, _, _) => *nmats,
422        }
423    }
424
425    /// Returns the number of tensors (in the first dimension) for the current shape.
426    ///
427    /// # Returns
428    /// The number of tensors.
429    ///
430    /// # Examples
431    /// ```
432    /// use qudit_expr::GenerationShape;
433    /// let tensor4d_shape = GenerationShape::Tensor4D(7, 5, 2, 3);
434    /// assert_eq!(tensor4d_shape.ntens(), 7);
435    /// ```
436    pub fn ntens(&self) -> usize {
437        match self {
438            GenerationShape::Scalar => 1,
439            GenerationShape::Vector(_) => 1,
440            GenerationShape::Matrix(_, _) => 1,
441            GenerationShape::Tensor3D(_, _, _) => 1,
442            GenerationShape::Tensor4D(ntens, _, _, _) => *ntens,
443        }
444    }
445
446    pub fn calculate_directions(&self, index_sizes: &[usize]) -> Vec<IndexDirection> {
447        match self {
448            GenerationShape::Scalar => vec![],
449            GenerationShape::Vector(_) => vec![IndexDirection::Input; index_sizes.len()],
450            GenerationShape::Matrix(nrows, _) => {
451                let mut index_size_acm = 1usize;
452                let mut index_iter = 0;
453                let mut index_directions = vec![];
454                while index_iter < index_sizes.len() && index_size_acm < *nrows {
455                    index_size_acm *= index_sizes[index_iter];
456                    index_directions.push(IndexDirection::Output);
457                    index_iter += 1;
458                }
459                while index_iter < index_sizes.len() {
460                    index_directions.push(IndexDirection::Input);
461                    index_iter += 1;
462                }
463                index_directions
464            }
465            GenerationShape::Tensor3D(nmats, nrows, _) => {
466                let mut index_size_acm = 1usize;
467                let mut index_iter = 0;
468                let mut index_directions = vec![];
469                while index_iter < index_sizes.len() && index_size_acm < *nmats {
470                    index_size_acm *= index_sizes[index_iter];
471                    index_directions.push(IndexDirection::Batch);
472                    index_iter += 1;
473                }
474                index_size_acm = 1usize; // Reset to calculate for nrows
475                while index_iter < index_sizes.len() && index_size_acm < *nrows {
476                    index_size_acm *= index_sizes[index_iter];
477                    index_directions.push(IndexDirection::Output);
478                    index_iter += 1;
479                }
480                while index_iter < index_sizes.len() {
481                    index_directions.push(IndexDirection::Input);
482                    index_iter += 1;
483                }
484                index_directions
485            }
486            GenerationShape::Tensor4D(_, _, _, _) => {
487                todo!()
488            }
489        }
490    }
491}
492
493// impl From<Vec<TensorIndex>> for GenerationShape {
494//     fn from(indices: Vec<TensorIndex>) -> Self {
495//         GenerationShape::from(indices.as_slice())
496//     }
497// }
498
499impl<I: AsRef<[TensorIndex]>> From<I> for GenerationShape {
500    fn from(indices: I) -> Self {
501        let indices = indices.as_ref();
502        let mut dimensions = [1, 1, 1, 1];
503        for index in indices.iter() {
504            match index.direction() {
505                IndexDirection::Derivative => dimensions[0] *= index.index_size(),
506                IndexDirection::Batch => dimensions[1] *= index.index_size(),
507                IndexDirection::Output => dimensions[2] *= index.index_size(),
508                IndexDirection::Input => dimensions[3] *= index.index_size(),
509            }
510        }
511
512        match dimensions {
513            [1, 1, 1, 1] => GenerationShape::Scalar,
514            [1, 1, 1, nelems] => GenerationShape::Vector(nelems),
515            [1, 1, nrows, ncols] => GenerationShape::Matrix(nrows, ncols),
516            [1, nmats, nrows, ncols] => GenerationShape::Tensor3D(nmats, nrows, ncols),
517            [ntens, nmats, nrows, ncols] => GenerationShape::Tensor4D(ntens, nmats, nrows, ncols),
518        }
519    }
520}
521
522impl Add for GenerationShape {
523    type Output = Self;
524
525    fn add(self, other: Self) -> Self::Output {
526        match (self, other) {
527            // TODO: re-evaluate this...
528            (GenerationShape::Scalar, other_shape) => other_shape,
529            (other_shape, GenerationShape::Scalar) => other_shape,
530            (GenerationShape::Vector(s_nelems), GenerationShape::Vector(o_nelems)) => {
531                GenerationShape::Vector(s_nelems + o_nelems)
532            }
533            (
534                GenerationShape::Matrix(s_nrows, s_ncols),
535                GenerationShape::Matrix(o_nrows, o_ncols),
536            ) => GenerationShape::Matrix(s_nrows + o_nrows, s_ncols + o_ncols),
537            (
538                GenerationShape::Tensor3D(s_nmats, s_nrows, s_ncols),
539                GenerationShape::Tensor3D(o_nmats, o_nrows, o_ncols),
540            ) => GenerationShape::Tensor3D(s_nmats + o_nmats, s_nrows + o_nrows, s_ncols + o_ncols),
541            (
542                GenerationShape::Tensor4D(s_ntens, s_nmats, s_nrows, s_ncols),
543                GenerationShape::Tensor4D(o_ntens, o_nmats, o_nrows, o_ncols),
544            ) => GenerationShape::Tensor4D(
545                s_ntens + o_ntens,
546                s_nmats + o_nmats,
547                s_nrows + o_nrows,
548                s_ncols + o_ncols,
549            ),
550            _ => {
551                panic!("Cannot add tensors of different fundamental shapes or incompatible ranks.")
552            }
553        }
554    }
555}