causal_hub/datasets/trajectory/categorical/
dataset.rs1use itertools::Itertools;
2use ndarray::prelude::*;
3use rayon::prelude::*;
4
5use crate::{
6 datasets::{CatTable, CatType, Dataset},
7 models::Labelled,
8 types::{Labels, States},
9};
10
11#[derive(Clone, Debug)]
13pub struct CatTrj {
14 events: CatTable,
15 times: Array1<f64>,
16}
17
18impl CatTrj {
19 pub fn new(states: States, mut events: Array2<CatType>, mut times: Array1<f64>) -> Self {
32 assert_eq!(
34 events.nrows(),
35 times.len(),
36 "Trajectory events and times must have the same length."
37 );
38 times.iter().for_each(|&t| {
40 assert!(
41 t.is_finite() && t >= 0.,
42 "Trajectory times must be finite and positive: \n\
43 \t expected: time >= 0 , \n\
44 \t found: time == {t} ."
45 );
46 });
47
48 let mut sorted_idx: Vec<_> = (0..events.nrows()).collect();
50 sorted_idx.sort_by(|&a, &b| {
51 times[a]
52 .partial_cmp(×[b])
53 .unwrap_or_else(|| unreachable!())
55 });
56
57 if !sorted_idx.iter().is_sorted() {
59 let mut new_times = times.clone();
61 new_times
62 .iter_mut()
63 .enumerate()
64 .for_each(|(i, new_time)| *new_time = times[sorted_idx[i]]);
65 times = new_times;
67
68 let mut new_events = events.clone();
70 new_events
72 .rows_mut()
73 .into_iter()
74 .enumerate()
75 .for_each(|(i, mut new_events_row)| {
76 new_events_row.assign(&events.row(sorted_idx[i]));
77 });
78 events = new_events;
80 }
81
82 {
84 let count = times.iter().dedup().count();
86 let length = times.len();
88 assert_eq!(
90 count, length,
91 "Trajectory times must be unique: \n\
92 \t expected: {count} deduplicated time-points, \n\
93 \t found: {length} non-deduplicated time-points, \n\
94 \t for: {times}."
95 );
96 }
97
98 events
100 .rows()
101 .into_iter()
102 .zip(×)
103 .tuple_windows()
104 .for_each(|((e_i, t_i), (e_j, t_j))| {
105 let count = e_i.iter().zip(e_j).filter(|(a, b)| a != b).count();
107 assert!(
109 count <= 1,
110 "Trajectory events must contain at max one change per transition: \n\
111 \t expected: count <= 1 state change, \n\
112 \t found: count == {count} state changes, \n\
113 \t for: {e_i} event with time {t_i}, \n\
114 \t and: {e_j} event with time {t_j}."
115 );
116 });
117
118 let events = CatTable::new(states, events);
120
121 Self { events, times }
123 }
124
125 #[inline]
132 pub const fn states(&self) -> &States {
133 self.events.states()
134 }
135
136 #[inline]
143 pub const fn shape(&self) -> &Array1<usize> {
144 self.events.shape()
145 }
146
147 #[inline]
154 pub const fn times(&self) -> &Array1<f64> {
155 &self.times
156 }
157}
158
159impl Labelled for CatTrj {
160 #[inline]
161 fn labels(&self) -> &Labels {
162 self.events.labels()
163 }
164}
165
166impl Dataset for CatTrj {
167 type Values = Array2<CatType>;
168
169 #[inline]
170 fn values(&self) -> &Self::Values {
171 self.events.values()
172 }
173
174 #[inline]
175 fn sample_size(&self) -> f64 {
176 self.events.values().nrows() as f64
177 }
178}
179
180#[derive(Clone, Debug)]
182pub struct CatTrjs {
183 labels: Labels,
184 states: States,
185 shape: Array1<usize>,
186 values: Vec<CatTrj>,
187}
188
189impl CatTrjs {
190 pub fn new<I>(values: I) -> Self
209 where
210 I: IntoIterator<Item = CatTrj>,
211 {
212 let values: Vec<_> = values.into_iter().collect();
214
215 assert!(
217 values
218 .windows(2)
219 .all(|trjs| trjs[0].labels().eq(trjs[1].labels())),
220 "All trajectories must have the same labels."
221 );
222 assert!(
224 values
225 .windows(2)
226 .all(|trjs| trjs[0].states().eq(trjs[1].states())),
227 "All trajectories must have the same states."
228 );
229 assert!(
231 values
232 .windows(2)
233 .all(|trjs| trjs[0].shape().eq(trjs[1].shape())),
234 "All trajectories must have the same shape."
235 );
236
237 let (labels, states, shape) = match values.first() {
239 None => (Labels::default(), States::default(), Array1::default((0,))),
240 Some(x) => (x.labels().clone(), x.states().clone(), x.shape().clone()),
241 };
242
243 Self {
244 labels,
245 states,
246 shape,
247 values,
248 }
249 }
250
251 #[inline]
258 pub fn states(&self) -> &States {
259 &self.states
260 }
261
262 #[inline]
269 pub fn shape(&self) -> &Array1<usize> {
270 &self.shape
271 }
272}
273
274impl FromIterator<CatTrj> for CatTrjs {
275 #[inline]
276 fn from_iter<I: IntoIterator<Item = CatTrj>>(iter: I) -> Self {
277 Self::new(iter)
278 }
279}
280
281impl FromParallelIterator<CatTrj> for CatTrjs {
282 #[inline]
283 fn from_par_iter<I: IntoParallelIterator<Item = CatTrj>>(iter: I) -> Self {
284 Self::new(iter.into_par_iter().collect::<Vec<_>>())
285 }
286}
287
288impl<'a> IntoIterator for &'a CatTrjs {
289 type IntoIter = std::slice::Iter<'a, CatTrj>;
290 type Item = &'a CatTrj;
291
292 #[inline]
293 fn into_iter(self) -> Self::IntoIter {
294 self.values.iter()
295 }
296}
297
298impl<'a> IntoParallelRefIterator<'a> for CatTrjs {
299 type Item = &'a CatTrj;
300 type Iter = rayon::slice::Iter<'a, CatTrj>;
301
302 #[inline]
303 fn par_iter(&'a self) -> Self::Iter {
304 self.values.par_iter()
305 }
306}
307
308impl Labelled for CatTrjs {
309 #[inline]
310 fn labels(&self) -> &Labels {
311 &self.labels
312 }
313}
314
315impl Dataset for CatTrjs {
316 type Values = Vec<CatTrj>;
317
318 #[inline]
319 fn values(&self) -> &Self::Values {
320 &self.values
321 }
322
323 #[inline]
324 fn sample_size(&self) -> f64 {
325 self.values.iter().map(Dataset::sample_size).sum()
326 }
327}