qudit_expr/shape.rs
1use crate::index::{IndexDirection, TensorIndex};
2use std::ops::Add;
3
/// The physical layout of a tensor once it has been generated.
///
/// A tensor may conceptually have any rank (even an unbounded one), but the
/// OpenQudit Expression library always generates tensors into a buffer with
/// 0, 1, 2, 3, or 4 physical dimensions. There is one variant per rank.
///
/// The derived ordering compares variants in declaration order (lower rank
/// first) and then compares the dimensions of equal variants lexicographically.
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum GenerationShape {
    /// Rank 0: a single value.
    Scalar,

    /// Rank 1: `(nelems)` elements.
    Vector(usize),

    /// Rank 2: `(nrows, ncols)`, a matrix.
    Matrix(usize, usize),

    /// Rank 3: `(nmats, nrows, ncols)`, i.e. `nmats` matrices of `nrows x ncols`.
    Tensor3D(usize, usize, usize),

    /// Rank 4: `(ntens, nmats, nrows, ncols)`, usually produced for derivatives.
    Tensor4D(usize, usize, usize, usize),
}
26
27impl std::fmt::Debug for GenerationShape {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 match self {
30 GenerationShape::Scalar => write!(f, "Scalar"),
31 GenerationShape::Vector(nelems) => write!(f, "Vector({})", nelems),
32 GenerationShape::Matrix(nrows, ncols) => write!(f, "Matrix({}, {})", nrows, ncols),
33 GenerationShape::Tensor3D(nmats, nrows, ncols) => {
34 write!(f, "Tensor3D({}, {}, {})", nmats, nrows, ncols)
35 }
36 GenerationShape::Tensor4D(ntens, nmats, nrows, ncols) => {
37 write!(f, "Tensor4D({}, {}, {}, {})", ntens, nmats, nrows, ncols)
38 }
39 }
40 }
41}
42
43impl GenerationShape {
44 /// Calculates the total number of elements in a tensor with this shape.
45 ///
46 /// # Returns
47 /// The total number of elements as `usize`.
48 pub fn num_elements(&self) -> usize {
49 match self {
50 GenerationShape::Scalar => 1,
51 GenerationShape::Vector(nelems) => *nelems,
52 GenerationShape::Matrix(nrows, ncols) => nrows * ncols,
53 GenerationShape::Tensor3D(nmats, nrows, ncols) => nmats * nrows * ncols,
54 GenerationShape::Tensor4D(ntens, nmats, nrows, ncols) => ntens * nmats * nrows * ncols,
55 }
56 }
57
58 /// Determines the shape of the derivative of a tensor with respect to `num_params` parameters.
59 ///
60 /// This method effectively prepends `num_params` to the current tensor's dimensions.
61 /// For example, the derivative of a `Scalar` with respect to `num_params` becomes a `Vector(num_params)`.
62 /// The derivative of a `Matrix(R, C)` becomes a `Tensor3D(num_params, R, C)`.
63 ///
64 /// # Arguments
65 /// * `num_params` - The number of parameters in the gradient.
66 ///
67 /// # Returns
68 /// A new `GenerationShape` representing the shape of the gradient.
69 ///
70 /// # See Also
71 /// - `[hessian_shape]` For the shape of a hessian tensor.
72 pub fn gradient_shape(&self, num_params: usize) -> Self {
73 match self {
74 GenerationShape::Scalar => GenerationShape::Vector(num_params),
75 GenerationShape::Vector(nelems) => GenerationShape::Matrix(num_params, *nelems),
76 GenerationShape::Matrix(nrows, ncols) => {
77 GenerationShape::Tensor3D(num_params, *nrows, *ncols)
78 }
79 GenerationShape::Tensor3D(nmats, nrows, ncols) => {
80 GenerationShape::Tensor4D(num_params, *nmats, *nrows, *ncols)
81 }
82 GenerationShape::Tensor4D(_, _, _, _) => {
83 panic!("Unable to find shape for gradient of 4D tensor.")
84 }
85 }
86 }
87
88 /// Determine the hessian shape of a tensor with this shape that has `num_params` parameters.
89 ///
90 /// # Arguments
91 /// * `num_params` - The number of parameters in the hessian.
92 ///
93 /// # Returns
94 /// A new `GenerationShape` representing the shape of the hessian.
95 ///
96 /// # See Also
97 /// - `[gradient_shape]` For the shape of a gradient tensor.
98 pub fn hessian_shape(&self, num_params: usize) -> Self {
99 let sym_sq_size = num_params * (num_params - 1) / 2; // TODO: Is it +1 or -1
100 match self {
101 GenerationShape::Scalar => GenerationShape::Vector(sym_sq_size),
102 GenerationShape::Vector(nelems) => GenerationShape::Matrix(sym_sq_size, *nelems),
103 GenerationShape::Matrix(nrows, ncols) => {
104 GenerationShape::Tensor3D(sym_sq_size, *nrows, *ncols)
105 }
106 GenerationShape::Tensor3D(nmats, nrows, ncols) => {
107 GenerationShape::Tensor4D(sym_sq_size, *nmats, *nrows, *ncols)
108 }
109 GenerationShape::Tensor4D(_, _, _, _) => {
110 panic!("Unable to find shape for Hessian of 4D tensor.")
111 }
112 }
113 }
114
115 /// Converts the tensor shape object to a vector of integers.
116 ///
117 /// # Returns
118 ///
119 /// A `Vec<usize>` containing the dimensions of the shape.
120 ///
121 /// # Examples
122 /// ```
123 /// use qudit_expr::GenerationShape;
124 ///
125 /// let scalar_shape = GenerationShape::Scalar;
126 /// assert_eq!(scalar_shape.to_vec(), Vec::<usize>::new());
127 ///
128 /// let vector_shape = GenerationShape::Vector(5);
129 /// assert_eq!(vector_shape.to_vec(), vec![5]);
130 ///
131 /// let matrix_shape = GenerationShape::Matrix(2, 3);
132 /// assert_eq!(matrix_shape.to_vec(), vec![2, 3]);
133 /// ```
134 pub fn to_vec(&self) -> Vec<usize> {
135 match self {
136 GenerationShape::Scalar => vec![],
137 GenerationShape::Vector(nelems) => vec![*nelems],
138 GenerationShape::Matrix(nrows, ncols) => vec![*nrows, *ncols],
139 GenerationShape::Tensor3D(nmats, nrows, ncols) => vec![*nmats, *nrows, *ncols],
140 GenerationShape::Tensor4D(ntens, nmats, nrows, ncols) => {
141 vec![*ntens, *nmats, *nrows, *ncols]
142 }
143 }
144 }
145
146 /// Checks if the current `GenerationShape` is strictly a scalar variant.
147 pub fn is_scalar(&self) -> bool {
148 matches!(self, GenerationShape::Scalar)
149 }
150
151 /// Checks if the current `GenerationShape` is strictly a vector variant.
152 pub fn is_vector(&self) -> bool {
153 matches!(self, GenerationShape::Vector(_))
154 }
155
156 /// Checks if the current `GenerationShape` is strictly a matrix variant.
157 pub fn is_matrix(&self) -> bool {
158 matches!(self, GenerationShape::Matrix(_, _))
159 }
160
161 /// Checks if the current `GenerationShape` is strictly a tensor3D variant.
162 pub fn is_tensor3d(&self) -> bool {
163 matches!(self, GenerationShape::Tensor3D(_, _, _))
164 }
165
166 /// Checks if the current `GenerationShape` is strictly a tensor4D variant.
167 pub fn is_tensor4d(&self) -> bool {
168 matches!(self, GenerationShape::Tensor4D(_, _, _, _))
169 }
170
171 /// Check if there is only one element.
172 ///
173 /// # Returns
174 /// `true` if the shape can be treated as a scalar, `false` otherwise.
175 ///
176 /// # Examples
177 /// ```
178 /// use qudit_expr::GenerationShape;
179 ///
180 /// let test_scalar = GenerationShape::Scalar;
181 /// let test_vector = GenerationShape::Vector(1);
182 /// let test_matrix = GenerationShape::Matrix(1, 1);
183 /// let test_tensor3d = GenerationShape::Tensor3D(1, 1, 1);
184 ///
185 /// let test_vector_2 = GenerationShape::Vector(9);
186 /// let test_matrix_2 = GenerationShape::Matrix(9, 9);
187 /// let test_tensor3d_2 = GenerationShape::Tensor3D(1, 9, 9);
188 ///
189 /// assert!(test_scalar.is_0d());
190 /// assert!(test_vector.is_0d());
191 /// assert!(test_matrix.is_0d());
192 /// assert!(test_tensor3d.is_0d());
193 ///
194 /// assert_eq!(test_vector_2.is_0d(), false);
195 /// assert_eq!(test_matrix_2.is_0d(), false);
196 /// assert_eq!(test_tensor3d_2.is_0d(), false);
197 /// ```
198 pub fn is_0d(&self) -> bool {
199 self.num_elements() == 1
200 }
201
202 /// Check if the shape can be conceptually treated as a 1d tensor.
203 ///
204 /// # Returns
205 /// `true` if the shape has exactly one dimension with 1 or more elements.
206 ///
207 /// # Examples
208 /// ```
209 /// use qudit_expr::GenerationShape;
210 ///
211 /// let test_vector = GenerationShape::Vector(9);
212 /// let test_matrix = GenerationShape::Matrix(1, 9);
213 /// let test_tensor3d = GenerationShape::Tensor3D(1, 1, 9);
214 ///
215 /// let test_scalar = GenerationShape::Scalar;
216 /// let test_matrix_2 = GenerationShape::Matrix(9, 9);
217 /// let test_tensor3d_2 = GenerationShape::Tensor3D(1, 9, 9);
218 ///
219 /// assert!(test_vector.is_1d());
220 /// assert!(test_matrix.is_1d());
221 /// assert!(test_tensor3d.is_1d());
222 ///
223 /// assert_eq!(test_scalar.is_1d(), false);
224 /// assert_eq!(test_matrix_2.is_1d(), false);
225 /// assert_eq!(test_tensor3d_2.is_1d(), false);
226 /// ```
227 pub fn is_1d(&self) -> bool {
228 match self {
229 GenerationShape::Scalar => false,
230 GenerationShape::Vector(_) => true,
231 GenerationShape::Matrix(nrows, ncols) => *nrows == 1 || *ncols == 1,
232 GenerationShape::Tensor3D(nmats, nrows, ncols) => {
233 let non_one_count = [*nmats, *nrows, *ncols].iter().filter(|&&d| d > 1).count();
234 non_one_count == 1
235 }
236 GenerationShape::Tensor4D(ntens, nmats, nrows, ncols) => {
237 let non_one_count = [*ntens, *nmats, *nrows, *ncols]
238 .iter()
239 .filter(|&&d| d > 1)
240 .count();
241 non_one_count == 1
242 }
243 }
244 }
245
246 /// Checks if the current `GenerationShape` can be conceptually treated as a 2-dimensional matrix.
247 /// This is true for `GenerationShape` variants with a dimensionality of at least 2, with
248 /// any additional dimensions having size 1.
249 ///
250 /// # Returns
251 /// `true` if the shape can be treated as a matrix, `false` otherwise.
252 ///
253 /// # Examples
254 /// ```
255 /// use qudit_expr::GenerationShape;
256 ///
257 /// let test_scalar = GenerationShape::Scalar;
258 /// let test_vector = GenerationShape::Vector(1);
259 /// let test_tensor3d_2 = GenerationShape::Tensor3D(9, 9, 9);
260 /// let test_tensor_nd_2 = GenerationShape::Tensor4D(1, 9, 9, 9);
261 ///
262 /// let test_matrix = GenerationShape::Matrix(9, 9);
263 /// let test_tensor3d = GenerationShape::Tensor3D(1, 9, 9);
264 /// let test_tensor_nd = GenerationShape::Tensor4D(1, 1, 9, 9);
265 ///
266 /// assert_eq!(test_scalar.is_2d(), false);
267 /// assert_eq!(test_vector.is_2d(), false);
268 /// assert_eq!(test_tensor3d_2.is_2d(), false);
269 /// assert_eq!(test_tensor_nd_2.is_2d(), false);
270 ///
271 /// assert_eq!(test_matrix.is_2d(), true);
272 /// assert_eq!(test_tensor3d.is_2d(), true);
273 /// assert_eq!(test_tensor_nd.is_2d(), true);
274 /// ```
275 pub fn is_2d(&self) -> bool {
276 match self {
277 GenerationShape::Scalar => false,
278 GenerationShape::Vector(_) => false,
279 GenerationShape::Matrix(_, _) => true,
280 // A Tensor3D can be seen as a matrix if it's essentially a stack of column vectors,
281 // or perhaps if it represents a single matrix (nmats=1).
282 // The current implementation checks if ncols is 1, implying it's a stack of column vectors.
283 GenerationShape::Tensor3D(nmats, _, _) => *nmats == 1,
284 GenerationShape::Tensor4D(ntens, nmats, _, _) => *ntens == 1 && *nmats == 1,
285 }
286 }
287
288 /// Checks if the current `GenerationShape` can be conceptually treated as a 3D tensor.
289 /// This is true for `GenerationShape` variants with a dimensionality of at least 3, with
290 /// any additional dimensions having size 1.
291 ///
292 /// # Returns
293 /// `true` if the shape can be treated as a 3D tensor, `false` otherwise.
294 ///
295 /// # Examples
296 /// ```
297 /// use qudit_expr::GenerationShape;
298 ///
299 /// let test_scalar = GenerationShape::Scalar;
300 /// let test_vector = GenerationShape::Vector(1);
301 /// let test_matrix = GenerationShape::Matrix(1, 1);
302 /// let test_tensor_nd_2 = GenerationShape::Tensor4D(9, 1, 9, 9);
303 ///
304 /// let test_tensor3d = GenerationShape::Tensor3D(9, 9, 9);
305 /// let test_tensor_nd = GenerationShape::Tensor4D(1, 9, 9, 9);
306 ///
307 /// assert_eq!(test_scalar.is_3d(), false);
308 /// assert_eq!(test_vector.is_3d(), false);
309 /// assert_eq!(test_matrix.is_3d(), false);
310 /// assert_eq!(test_tensor_nd_2.is_3d(), false);
311 ///
312 /// assert_eq!(test_tensor3d.is_3d(), true);
313 /// assert_eq!(test_tensor_nd.is_3d(), true);
314 /// ```
315 pub fn is_3d(&self) -> bool {
316 match self {
317 GenerationShape::Scalar => false,
318 GenerationShape::Vector(_) => false,
319 GenerationShape::Matrix(_, _) => false,
320 GenerationShape::Tensor3D(_, _, _) => true,
321 GenerationShape::Tensor4D(ntens, _, _, _) => *ntens == 1,
322 }
323 }
324
325 /// Checks if the current `GenerationShape` can be conceptually treated as a 4D tensor.
326 /// This is true for `GenerationShape` variants with a dimensionality of at least 4, with
327 /// any additional dimensions having size 1.
328 ///
329 /// # Returns
330 /// `true` if the shape can be treated as a 4D tensor, `false` otherwise.
331 ///
332 /// # Examples
333 /// ```
334 /// use qudit_expr::GenerationShape;
335 ///
336 /// let test_scalar = GenerationShape::Scalar;
337 /// let test_vector = GenerationShape::Vector(1);
338 /// let test_matrix = GenerationShape::Matrix(1, 1);
339 /// let test_tensor3d = GenerationShape::Tensor3D(1, 1, 1);
340 ///
341 /// let test_tensor4d = GenerationShape::Tensor4D(9, 9, 9, 9);
342 /// let test_tensor_nd = GenerationShape::Tensor4D(1, 9, 9, 9);
343 ///
344 /// assert_eq!(test_scalar.is_4d(), false);
345 /// assert_eq!(test_vector.is_4d(), false);
346 /// assert_eq!(test_matrix.is_4d(), false);
347 /// assert_eq!(test_tensor3d.is_4d(), false);
348 ///
349 /// assert_eq!(test_tensor4d.is_4d(), true);
350 /// assert_eq!(test_tensor_nd.is_4d(), true);
351 /// ```
352 pub fn is_4d(&self) -> bool {
353 match self {
354 GenerationShape::Scalar => false,
355 GenerationShape::Vector(_) => false,
356 GenerationShape::Matrix(_, _) => false,
357 GenerationShape::Tensor3D(_, _, _) => false,
358 GenerationShape::Tensor4D(_, _, _, _) => true,
359 }
360 }
361
362 /// Returns the number of columns for the current shape.
363 ///
364 /// # Returns
365 /// The number of columns.
366 ///
367 /// # Examples
368 /// ```
369 /// use qudit_expr::GenerationShape;
370 /// let matrix_shape = GenerationShape::Matrix(2, 3);
371 /// assert_eq!(matrix_shape.ncols(), 3);
372 /// ```
373 pub fn ncols(&self) -> usize {
374 match self {
375 GenerationShape::Scalar => 1,
376 GenerationShape::Vector(ncols) => *ncols,
377 GenerationShape::Matrix(_, ncols) => *ncols,
378 GenerationShape::Tensor3D(_, _, ncols) => *ncols,
379 GenerationShape::Tensor4D(_, _, _, ncols) => *ncols,
380 }
381 }
382
383 /// Returns the number of rows for the current shape.
384 ///
385 /// # Returns
386 /// The number of rows.
387 ///
388 /// # Examples
389 /// ```
390 /// use qudit_expr::GenerationShape;
391 /// let matrix_shape = GenerationShape::Matrix(2, 3);
392 /// assert_eq!(matrix_shape.nrows(), 2);
393 /// ```
394 pub fn nrows(&self) -> usize {
395 match self {
396 GenerationShape::Scalar => 1,
397 GenerationShape::Vector(_) => 1,
398 GenerationShape::Matrix(nrows, _) => *nrows,
399 GenerationShape::Tensor3D(_, nrows, _) => *nrows,
400 GenerationShape::Tensor4D(_, _, nrows, _) => *nrows,
401 }
402 }
403
404 /// Returns the number of matrices for the current shape.
405 ///
406 /// # Returns
407 /// The number of matrices.
408 ///
409 /// # Examples
410 /// ```
411 /// use qudit_expr::GenerationShape;
412 /// let tensor3d_shape = GenerationShape::Tensor3D(5, 2, 3);
413 /// assert_eq!(tensor3d_shape.nmats(), 5);
414 /// ```
415 pub fn nmats(&self) -> usize {
416 match self {
417 GenerationShape::Scalar => 1,
418 GenerationShape::Vector(_) => 1,
419 GenerationShape::Matrix(_, _) => 1,
420 GenerationShape::Tensor3D(nmats, _, _) => *nmats,
421 GenerationShape::Tensor4D(_, nmats, _, _) => *nmats,
422 }
423 }
424
425 /// Returns the number of tensors (in the first dimension) for the current shape.
426 ///
427 /// # Returns
428 /// The number of tensors.
429 ///
430 /// # Examples
431 /// ```
432 /// use qudit_expr::GenerationShape;
433 /// let tensor4d_shape = GenerationShape::Tensor4D(7, 5, 2, 3);
434 /// assert_eq!(tensor4d_shape.ntens(), 7);
435 /// ```
436 pub fn ntens(&self) -> usize {
437 match self {
438 GenerationShape::Scalar => 1,
439 GenerationShape::Vector(_) => 1,
440 GenerationShape::Matrix(_, _) => 1,
441 GenerationShape::Tensor3D(_, _, _) => 1,
442 GenerationShape::Tensor4D(ntens, _, _, _) => *ntens,
443 }
444 }
445
446 pub fn calculate_directions(&self, index_sizes: &[usize]) -> Vec<IndexDirection> {
447 match self {
448 GenerationShape::Scalar => vec![],
449 GenerationShape::Vector(_) => vec![IndexDirection::Input; index_sizes.len()],
450 GenerationShape::Matrix(nrows, _) => {
451 let mut index_size_acm = 1usize;
452 let mut index_iter = 0;
453 let mut index_directions = vec![];
454 while index_iter < index_sizes.len() && index_size_acm < *nrows {
455 index_size_acm *= index_sizes[index_iter];
456 index_directions.push(IndexDirection::Output);
457 index_iter += 1;
458 }
459 while index_iter < index_sizes.len() {
460 index_directions.push(IndexDirection::Input);
461 index_iter += 1;
462 }
463 index_directions
464 }
465 GenerationShape::Tensor3D(nmats, nrows, _) => {
466 let mut index_size_acm = 1usize;
467 let mut index_iter = 0;
468 let mut index_directions = vec![];
469 while index_iter < index_sizes.len() && index_size_acm < *nmats {
470 index_size_acm *= index_sizes[index_iter];
471 index_directions.push(IndexDirection::Batch);
472 index_iter += 1;
473 }
474 index_size_acm = 1usize; // Reset to calculate for nrows
475 while index_iter < index_sizes.len() && index_size_acm < *nrows {
476 index_size_acm *= index_sizes[index_iter];
477 index_directions.push(IndexDirection::Output);
478 index_iter += 1;
479 }
480 while index_iter < index_sizes.len() {
481 index_directions.push(IndexDirection::Input);
482 index_iter += 1;
483 }
484 index_directions
485 }
486 GenerationShape::Tensor4D(_, _, _, _) => {
487 todo!()
488 }
489 }
490 }
491}
492
493// impl From<Vec<TensorIndex>> for GenerationShape {
494// fn from(indices: Vec<TensorIndex>) -> Self {
495// GenerationShape::from(indices.as_slice())
496// }
497// }
498
499impl<I: AsRef<[TensorIndex]>> From<I> for GenerationShape {
500 fn from(indices: I) -> Self {
501 let indices = indices.as_ref();
502 let mut dimensions = [1, 1, 1, 1];
503 for index in indices.iter() {
504 match index.direction() {
505 IndexDirection::Derivative => dimensions[0] *= index.index_size(),
506 IndexDirection::Batch => dimensions[1] *= index.index_size(),
507 IndexDirection::Output => dimensions[2] *= index.index_size(),
508 IndexDirection::Input => dimensions[3] *= index.index_size(),
509 }
510 }
511
512 match dimensions {
513 [1, 1, 1, 1] => GenerationShape::Scalar,
514 [1, 1, 1, nelems] => GenerationShape::Vector(nelems),
515 [1, 1, nrows, ncols] => GenerationShape::Matrix(nrows, ncols),
516 [1, nmats, nrows, ncols] => GenerationShape::Tensor3D(nmats, nrows, ncols),
517 [ntens, nmats, nrows, ncols] => GenerationShape::Tensor4D(ntens, nmats, nrows, ncols),
518 }
519 }
520}
521
522impl Add for GenerationShape {
523 type Output = Self;
524
525 fn add(self, other: Self) -> Self::Output {
526 match (self, other) {
527 // TODO: re-evaluate this...
528 (GenerationShape::Scalar, other_shape) => other_shape,
529 (other_shape, GenerationShape::Scalar) => other_shape,
530 (GenerationShape::Vector(s_nelems), GenerationShape::Vector(o_nelems)) => {
531 GenerationShape::Vector(s_nelems + o_nelems)
532 }
533 (
534 GenerationShape::Matrix(s_nrows, s_ncols),
535 GenerationShape::Matrix(o_nrows, o_ncols),
536 ) => GenerationShape::Matrix(s_nrows + o_nrows, s_ncols + o_ncols),
537 (
538 GenerationShape::Tensor3D(s_nmats, s_nrows, s_ncols),
539 GenerationShape::Tensor3D(o_nmats, o_nrows, o_ncols),
540 ) => GenerationShape::Tensor3D(s_nmats + o_nmats, s_nrows + o_nrows, s_ncols + o_ncols),
541 (
542 GenerationShape::Tensor4D(s_ntens, s_nmats, s_nrows, s_ncols),
543 GenerationShape::Tensor4D(o_ntens, o_nmats, o_nrows, o_ncols),
544 ) => GenerationShape::Tensor4D(
545 s_ntens + o_ntens,
546 s_nmats + o_nmats,
547 s_nrows + o_nrows,
548 s_ncols + o_ncols,
549 ),
550 _ => {
551 panic!("Cannot add tensors of different fundamental shapes or incompatible ranks.")
552 }
553 }
554 }
555}