// causal_hub/datasets/trajectory/categorical/dataset.rs

1use itertools::Itertools;
2use ndarray::prelude::*;
3use rayon::prelude::*;
4
5use crate::{
6    datasets::{CatTable, CatType, Dataset},
7    models::Labelled,
8    types::{Labels, States},
9};
10
11/// A multivariate trajectory.
12#[derive(Clone, Debug)]
13pub struct CatTrj {
14    events: CatTable,
15    times: Array1<f64>,
16}
17
18impl CatTrj {
19    /// Constructs a new trajectory instance.
20    ///
21    /// # Arguments
22    ///
23    /// * `states` - An iterator of tuples containing the state labels and their corresponding values.
24    /// * `events` - A 2D array of events.
25    /// * `times` - A 1D array of times.
26    ///
27    /// # Returns
28    ///
29    /// A new instance of `CatTrj`.
30    ///
31    pub fn new(states: States, mut events: Array2<CatType>, mut times: Array1<f64>) -> Self {
32        // Assert the number of rows in values and times are equal.
33        assert_eq!(
34            events.nrows(),
35            times.len(),
36            "Trajectory events and times must have the same length."
37        );
38        // Assert times must be positive and finite.
39        times.iter().for_each(|&t| {
40            assert!(
41                t.is_finite() && t >= 0.,
42                "Trajectory times must be finite and positive: \n\
43                \t expected: time >= 0 , \n\
44                \t found:    time == {t} ."
45            );
46        });
47
48        // Sort values by times.
49        let mut sorted_idx: Vec<_> = (0..events.nrows()).collect();
50        sorted_idx.sort_by(|&a, &b| {
51            times[a]
52                .partial_cmp(&times[b])
53                // Due to previous assertions, this should never fail.
54                .unwrap_or_else(|| unreachable!())
55        });
56
57        // Check if the times are already sorted.
58        if !sorted_idx.iter().is_sorted() {
59            // Sort times.
60            let mut new_times = times.clone();
61            new_times
62                .iter_mut()
63                .enumerate()
64                .for_each(|(i, new_time)| *new_time = times[sorted_idx[i]]);
65            // Update the times with the sorted values.
66            times = new_times;
67
68            // Sort events by time.
69            let mut new_events = events.clone();
70            // Sort the events by the sorted indices.
71            new_events
72                .rows_mut()
73                .into_iter()
74                .enumerate()
75                .for_each(|(i, mut new_events_row)| {
76                    new_events_row.assign(&events.row(sorted_idx[i]));
77                });
78            // Update the events with the sorted values.
79            events = new_events;
80        }
81
82        // Assert no duplicate times.
83        {
84            // Count the number of unique times.
85            let count = times.iter().dedup().count();
86            // Get the length of the times array.
87            let length = times.len();
88            // Assert the number of unique times is equal to the length of the times array.
89            assert_eq!(
90                count, length,
91                "Trajectory times must be unique: \n\
92                \t expected: {count} deduplicated time-points, \n\
93                \t found:    {length} non-deduplicated time-points, \n\
94                \t for:      {times}."
95            );
96        }
97
98        // Assert at max one state change per transition.
99        events
100            .rows()
101            .into_iter()
102            .zip(&times)
103            .tuple_windows()
104            .for_each(|((e_i, t_i), (e_j, t_j))| {
105                // Count the number of state changes.
106                let count = e_i.iter().zip(e_j).filter(|(a, b)| a != b).count();
107                // Assert there is one and only one state change.
108                assert!(
109                    count <= 1,
110                    "Trajectory events must contain at max one change per transition: \n\
111                    \t expected: count <= 1 state change, \n\
112                    \t found:    count == {count} state changes, \n\
113                    \t for:      {e_i} event with time {t_i}, \n\
114                    \t and:      {e_j} event with time {t_j}."
115                );
116            });
117
118        // Create a new categorical dataset instance.
119        let events = CatTable::new(states, events);
120
121        // Return a new trajectory instance.
122        Self { events, times }
123    }
124
125    /// Returns the states of the trajectory.
126    ///
127    /// # Returns
128    ///
129    /// A reference to the states of the trajectory.
130    ///
131    #[inline]
132    pub const fn states(&self) -> &States {
133        self.events.states()
134    }
135
136    /// Returns the shape of the trajectory.
137    ///
138    /// # Returns
139    ///
140    /// A reference to the shape of the trajectory.
141    ///
142    #[inline]
143    pub const fn shape(&self) -> &Array1<usize> {
144        self.events.shape()
145    }
146
147    /// Returns the times of the trajectory.
148    ///
149    /// # Returns
150    ///
151    /// A reference to the times of the trajectory.
152    ///
153    #[inline]
154    pub const fn times(&self) -> &Array1<f64> {
155        &self.times
156    }
157}
158
159impl Labelled for CatTrj {
160    #[inline]
161    fn labels(&self) -> &Labels {
162        self.events.labels()
163    }
164}
165
166impl Dataset for CatTrj {
167    type Values = Array2<CatType>;
168
169    #[inline]
170    fn values(&self) -> &Self::Values {
171        self.events.values()
172    }
173
174    #[inline]
175    fn sample_size(&self) -> f64 {
176        self.events.values().nrows() as f64
177    }
178}
179
180/// A collection of multivariate trajectories.
181#[derive(Clone, Debug)]
182pub struct CatTrjs {
183    labels: Labels,
184    states: States,
185    shape: Array1<usize>,
186    values: Vec<CatTrj>,
187}
188
189impl CatTrjs {
190    /// Constructs a new collection of trajectories.
191    ///
192    /// # Arguments
193    ///
194    /// * `trajectories` - An iterator of `CategoricalTrajectory` instances.
195    ///
196    /// # Panics
197    ///
198    /// Panics if:
199    ///
200    /// * The trajectories have different labels.
201    /// * The trajectories have different states.
202    /// * The trajectories have different shape.
203    ///
204    /// # Returns
205    ///
206    /// A new instance of `CategoricalTrajectories`.
207    ///
208    pub fn new<I>(values: I) -> Self
209    where
210        I: IntoIterator<Item = CatTrj>,
211    {
212        // Collect the trajectories into a vector.
213        let values: Vec<_> = values.into_iter().collect();
214
215        // Assert every trajectory has the same labels.
216        assert!(
217            values
218                .windows(2)
219                .all(|trjs| trjs[0].labels().eq(trjs[1].labels())),
220            "All trajectories must have the same labels."
221        );
222        // Assert every trajectory has the same states.
223        assert!(
224            values
225                .windows(2)
226                .all(|trjs| trjs[0].states().eq(trjs[1].states())),
227            "All trajectories must have the same states."
228        );
229        // Assert every trajectory has the same shape.
230        assert!(
231            values
232                .windows(2)
233                .all(|trjs| trjs[0].shape().eq(trjs[1].shape())),
234            "All trajectories must have the same shape."
235        );
236
237        // Get the labels, states and shape from the first trajectory.
238        let (labels, states, shape) = match values.first() {
239            None => (Labels::default(), States::default(), Array1::default((0,))),
240            Some(x) => (x.labels().clone(), x.states().clone(), x.shape().clone()),
241        };
242
243        Self {
244            labels,
245            states,
246            shape,
247            values,
248        }
249    }
250
251    /// Returns the states of the trajectories.
252    ///
253    /// # Returns
254    ///
255    /// A reference to the states of the trajectories.
256    ///
257    #[inline]
258    pub fn states(&self) -> &States {
259        &self.states
260    }
261
262    /// Returns the shape of the trajectories.
263    ///
264    /// # Returns
265    ///
266    /// A reference to the shape of the trajectories.
267    ///
268    #[inline]
269    pub fn shape(&self) -> &Array1<usize> {
270        &self.shape
271    }
272}
273
274impl FromIterator<CatTrj> for CatTrjs {
275    #[inline]
276    fn from_iter<I: IntoIterator<Item = CatTrj>>(iter: I) -> Self {
277        Self::new(iter)
278    }
279}
280
281impl FromParallelIterator<CatTrj> for CatTrjs {
282    #[inline]
283    fn from_par_iter<I: IntoParallelIterator<Item = CatTrj>>(iter: I) -> Self {
284        Self::new(iter.into_par_iter().collect::<Vec<_>>())
285    }
286}
287
288impl<'a> IntoIterator for &'a CatTrjs {
289    type IntoIter = std::slice::Iter<'a, CatTrj>;
290    type Item = &'a CatTrj;
291
292    #[inline]
293    fn into_iter(self) -> Self::IntoIter {
294        self.values.iter()
295    }
296}
297
298impl<'a> IntoParallelRefIterator<'a> for CatTrjs {
299    type Item = &'a CatTrj;
300    type Iter = rayon::slice::Iter<'a, CatTrj>;
301
302    #[inline]
303    fn par_iter(&'a self) -> Self::Iter {
304        self.values.par_iter()
305    }
306}
307
308impl Labelled for CatTrjs {
309    #[inline]
310    fn labels(&self) -> &Labels {
311        &self.labels
312    }
313}
314
315impl Dataset for CatTrjs {
316    type Values = Vec<CatTrj>;
317
318    #[inline]
319    fn values(&self) -> &Self::Values {
320        &self.values
321    }
322
323    #[inline]
324    fn sample_size(&self) -> f64 {
325        self.values.iter().map(Dataset::sample_size).sum()
326    }
327}