causal_hub/utils/multi_index.rs
1use ndarray::prelude::*;
2use serde::{Deserialize, Serialize};
3
4/// A structure to compute the ravel index of a multi-dimensional array.
5#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
6pub struct MI {
7 shape: Array1<usize>,
8 strides: Array1<usize>,
9}
10
11impl MI {
12 /// Construct a new `MI` from the shape of each dimension.
13 ///
14 /// # Arguments
15 ///
16 /// * `shape` - An iterator over the shape of each dimension.
17 ///
18 /// # Returns
19 ///
20 /// A new `MI` instance.
21 ///
22 pub fn new<I>(shape: I) -> Self
23 where
24 I: IntoIterator<Item = usize>,
25 {
26 // Collect the multi index.
27 let shape: Array1<_> = shape.into_iter().collect();
28 // Allocate the strides of the parameters.
29 let mut strides = Array1::from_elem(shape.len(), 1);
30 // Compute cumulative product in reverse order (row-major strides).
31 for i in (0..shape.len().saturating_sub(1)).rev() {
32 strides[i] = strides[i + 1] * shape[i + 1];
33 }
34
35 Self { shape, strides }
36 }
37
38 /// Return the number of dimensions.
39 ///
40 /// # Returns
41 ///
42 /// The number of dimensions.
43 ///
44 #[inline]
45 pub const fn shape(&self) -> &Array1<usize> {
46 &self.shape
47 }
48
49 /// Compute the ravel index from a multi-dimensional index.
50 ///
51 /// # Arguments
52 ///
53 /// * `multi_index` - An iterator over the multi-dimensional index.
54 ///
55 /// # Returns
56 ///
57 /// The ravelled index.
58 ///
59 pub fn ravel<I>(&self, multi_index: I) -> usize
60 where
61 I: IntoIterator<Item = usize>,
62 {
63 self.strides
64 .iter()
65 .zip(multi_index)
66 .map(|(i, j)| i * j)
67 .sum()
68 }
69
70 /// Compute the multi-dimensional index from a ravelled index.
71 ///
72 /// # Arguments
73 ///
74 /// * `index` - The ravelled index.
75 ///
76 /// # Returns
77 ///
78 /// A vector containing the multi-dimensional index.
79 ///
80 pub fn unravel(&self, index: usize) -> Vec<usize> {
81 let mut multi_index = Vec::with_capacity(self.shape.len());
82 let mut remaining_index = index;
83
84 for &stride in &self.strides {
85 let value = remaining_index / stride;
86 multi_index.push(value);
87 remaining_index -= value * stride;
88 }
89
90 multi_index
91 }
92}