causal_hub/datasets/trajectory/categorical/weighted.rs

1use ndarray::prelude::*;
2use rayon::prelude::*;
3
4use crate::{
5    datasets::{CatTrj, CatType, Dataset},
6    models::Labelled,
7    types::{Labels, States},
8};
9
10/// A multivariate weighted trajectory.
11#[derive(Clone, Debug)]
12pub struct CatWtdTrj {
13    trajectory: CatTrj,
14    weight: f64,
15}
16
17impl From<(CatTrj, f64)> for CatWtdTrj {
18    fn from((trajectory, weight): (CatTrj, f64)) -> Self {
19        Self::new(trajectory, weight)
20    }
21}
22
23impl From<CatWtdTrj> for (CatTrj, f64) {
24    fn from(other: CatWtdTrj) -> Self {
25        (other.trajectory, other.weight)
26    }
27}
28
29impl CatWtdTrj {
30    /// Creates a new categorical weighted trajectory.
31    ///
32    /// # Arguments
33    ///
34    /// * `trajectory` - The trajectory.
35    /// * `weight` - The weight of the trajectory.
36    ///
37    /// # Panics
38    ///
39    /// Panics if the weight is not in the range [0, 1].
40    ///
41    /// # Returns
42    ///
43    /// A new categorical weighted trajectory.
44    ///
45    pub fn new(trajectory: CatTrj, weight: f64) -> Self {
46        // Assert that the weight is in the range [0, 1].
47        assert!(
48            (0.0..=1.0).contains(&weight),
49            "Weight must be in the range [0, 1], but got {weight}."
50        );
51
52        Self { trajectory, weight }
53    }
54
55    /// Returns the trajectory.
56    ///
57    /// # Returns
58    ///
59    /// A reference to the trajectory.
60    ///
61    #[inline]
62    pub const fn trajectory(&self) -> &CatTrj {
63        &self.trajectory
64    }
65
66    /// Returns the weight of the trajectory.
67    ///
68    /// # Returns
69    ///
70    /// The weight of the trajectory.
71    ///
72    #[inline]
73    pub const fn weight(&self) -> f64 {
74        self.weight
75    }
76
77    /// Returns the states of the trajectory.
78    ///
79    /// # Returns
80    ///
81    /// A reference to the states of the trajectory.
82    ///
83    #[inline]
84    pub const fn states(&self) -> &States {
85        self.trajectory.states()
86    }
87
88    /// Returns the shape of the trajectory.
89    ///
90    /// # Returns
91    ///
92    /// A reference to the shape of the trajectory.
93    ///
94    #[inline]
95    pub const fn shape(&self) -> &Array1<usize> {
96        self.trajectory.shape()
97    }
98
99    /// Returns the times of the trajectory.
100    ///
101    /// # Returns
102    ///
103    /// A reference to the times of the trajectory.
104    ///
105    #[inline]
106    pub const fn times(&self) -> &Array1<f64> {
107        self.trajectory.times()
108    }
109}
110
111impl Labelled for CatWtdTrj {
112    #[inline]
113    fn labels(&self) -> &Labels {
114        self.trajectory.labels()
115    }
116}
117
118impl Dataset for CatWtdTrj {
119    type Values = Array2<CatType>;
120
121    #[inline]
122    fn values(&self) -> &Self::Values {
123        self.trajectory.values()
124    }
125
126    #[inline]
127    fn sample_size(&self) -> f64 {
128        self.weight * (self.trajectory.values().nrows() as f64)
129    }
130}
131
132/// A collection of weighted trajectories.
133#[derive(Clone, Debug)]
134pub struct CatWtdTrjs {
135    labels: Labels,
136    states: States,
137    shape: Array1<usize>,
138    values: Vec<CatWtdTrj>,
139}
140
141impl CatWtdTrjs {
142    /// Constructs a new collection of trajectories.
143    ///
144    /// # Arguments
145    ///
146    /// * `trajectories` - An iterator of `CategoricalTrajectory` instances.
147    ///
148    /// # Panics
149    ///
150    /// Panics if:
151    ///
152    /// * The trajectories have different labels.
153    /// * The trajectories have different states.
154    /// * The trajectories have different shape.
155    /// * The trajectories are empty.
156    ///
157    /// # Returns
158    ///
159    /// A new instance of `CategoricalTrajectories`.
160    ///
161    pub fn new<I>(values: I) -> Self
162    where
163        I: IntoIterator<Item = CatWtdTrj>,
164    {
165        // Collect the trajectories into a vector.
166        let values: Vec<_> = values.into_iter().collect();
167
168        // Assert every trajectory has the same labels.
169        assert!(
170            values
171                .windows(2)
172                .all(|trjs| trjs[0].labels().eq(trjs[1].labels())),
173            "All trajectories must have the same labels."
174        );
175        // Assert every trajectory has the same states.
176        assert!(
177            values
178                .windows(2)
179                .all(|trjs| trjs[0].states().eq(trjs[1].states())),
180            "All trajectories must have the same states."
181        );
182        // Assert every trajectory has the same shape.
183        assert!(
184            values
185                .windows(2)
186                .all(|trjs| trjs[0].shape().eq(trjs[1].shape())),
187            "All trajectories must have the same shape."
188        );
189
190        // Get the labels, states and shape from the first trajectory.
191        let trj = values.first().expect("No trajectory in the dataset.");
192        let labels = trj.labels().clone();
193        let states = trj.states().clone();
194        let shape = trj.shape().clone();
195
196        Self {
197            labels,
198            states,
199            shape,
200            values,
201        }
202    }
203
204    /// Returns the states of the trajectories.
205    ///
206    /// # Returns
207    ///
208    /// A reference to the states of the trajectories.
209    ///
210    #[inline]
211    pub fn states(&self) -> &States {
212        &self.states
213    }
214
215    /// Returns the shape of the trajectories.
216    ///
217    /// # Returns
218    ///
219    /// A reference to the shape of the trajectories.
220    ///
221    #[inline]
222    pub fn shape(&self) -> &Array1<usize> {
223        &self.shape
224    }
225}
226
227impl FromIterator<CatWtdTrj> for CatWtdTrjs {
228    #[inline]
229    fn from_iter<I: IntoIterator<Item = CatWtdTrj>>(iter: I) -> Self {
230        Self::new(iter)
231    }
232}
233
234impl FromParallelIterator<CatWtdTrj> for CatWtdTrjs {
235    #[inline]
236    fn from_par_iter<I: IntoParallelIterator<Item = CatWtdTrj>>(iter: I) -> Self {
237        Self::new(iter.into_par_iter().collect::<Vec<_>>())
238    }
239}
240
241impl<'a> IntoIterator for &'a CatWtdTrjs {
242    type IntoIter = std::slice::Iter<'a, CatWtdTrj>;
243    type Item = &'a CatWtdTrj;
244
245    #[inline]
246    fn into_iter(self) -> Self::IntoIter {
247        self.values.iter()
248    }
249}
250
251impl<'a> IntoParallelRefIterator<'a> for CatWtdTrjs {
252    type Item = &'a CatWtdTrj;
253    type Iter = rayon::slice::Iter<'a, CatWtdTrj>;
254
255    #[inline]
256    fn par_iter(&'a self) -> Self::Iter {
257        self.values.par_iter()
258    }
259}
260
261impl Labelled for CatWtdTrjs {
262    #[inline]
263    fn labels(&self) -> &Labels {
264        &self.labels
265    }
266}
267
268impl Dataset for CatWtdTrjs {
269    type Values = Vec<CatWtdTrj>;
270
271    #[inline]
272    fn values(&self) -> &Self::Values {
273        &self.values
274    }
275
276    #[inline]
277    fn sample_size(&self) -> f64 {
278        self.values.iter().map(Dataset::sample_size).sum()
279    }
280}