causal_hub/utils/
multi_index.rs

1use ndarray::prelude::*;
2use serde::{Deserialize, Serialize};
3
4/// A structure to compute the ravel index of a multi-dimensional array.
5#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
6pub struct MI {
7    shape: Array1<usize>,
8    strides: Array1<usize>,
9}
10
11impl MI {
12    /// Construct a new `MI` from the shape of each dimension.
13    ///
14    /// # Arguments
15    ///
16    /// * `shape` - An iterator over the shape of each dimension.
17    ///
18    /// # Returns
19    ///
20    /// A new `MI` instance.
21    ///
22    pub fn new<I>(shape: I) -> Self
23    where
24        I: IntoIterator<Item = usize>,
25    {
26        // Collect the multi index.
27        let shape: Array1<_> = shape.into_iter().collect();
28        // Allocate the strides of the parameters.
29        let mut strides = Array1::from_elem(shape.len(), 1);
30        // Compute cumulative product in reverse order (row-major strides).
31        for i in (0..shape.len().saturating_sub(1)).rev() {
32            strides[i] = strides[i + 1] * shape[i + 1];
33        }
34
35        Self { shape, strides }
36    }
37
38    /// Return the number of dimensions.
39    ///
40    /// # Returns
41    ///
42    /// The number of dimensions.
43    ///
44    #[inline]
45    pub const fn shape(&self) -> &Array1<usize> {
46        &self.shape
47    }
48
49    /// Compute the ravel index from a multi-dimensional index.
50    ///
51    /// # Arguments
52    ///
53    /// * `multi_index` - An iterator over the multi-dimensional index.
54    ///
55    /// # Returns
56    ///
57    /// The ravelled index.
58    ///
59    pub fn ravel<I>(&self, multi_index: I) -> usize
60    where
61        I: IntoIterator<Item = usize>,
62    {
63        self.strides
64            .iter()
65            .zip(multi_index)
66            .map(|(i, j)| i * j)
67            .sum()
68    }
69
70    /// Compute the multi-dimensional index from a ravelled index.
71    ///
72    /// # Arguments
73    ///
74    /// * `index` - The ravelled index.
75    ///
76    /// # Returns
77    ///
78    /// A vector containing the multi-dimensional index.
79    ///
80    pub fn unravel(&self, index: usize) -> Vec<usize> {
81        let mut multi_index = Vec::with_capacity(self.shape.len());
82        let mut remaining_index = index;
83
84        for &stride in &self.strides {
85            let value = remaining_index / stride;
86            multi_index.push(value);
87            remaining_index -= value * stride;
88        }
89
90        multi_index
91    }
92}