causal_hub/datasets/trajectory/categorical/
weighted.rs1use ndarray::prelude::*;
2use rayon::prelude::*;
3
4use crate::{
5 datasets::{CatTrj, CatType, Dataset},
6 models::Labelled,
7 types::{Labels, States},
8};
9
10#[derive(Clone, Debug)]
12pub struct CatWtdTrj {
13 trajectory: CatTrj,
14 weight: f64,
15}
16
17impl From<(CatTrj, f64)> for CatWtdTrj {
18 fn from((trajectory, weight): (CatTrj, f64)) -> Self {
19 Self::new(trajectory, weight)
20 }
21}
22
23impl From<CatWtdTrj> for (CatTrj, f64) {
24 fn from(other: CatWtdTrj) -> Self {
25 (other.trajectory, other.weight)
26 }
27}
28
29impl CatWtdTrj {
30 pub fn new(trajectory: CatTrj, weight: f64) -> Self {
46 assert!(
48 (0.0..=1.0).contains(&weight),
49 "Weight must be in the range [0, 1], but got {weight}."
50 );
51
52 Self { trajectory, weight }
53 }
54
55 #[inline]
62 pub const fn trajectory(&self) -> &CatTrj {
63 &self.trajectory
64 }
65
66 #[inline]
73 pub const fn weight(&self) -> f64 {
74 self.weight
75 }
76
77 #[inline]
84 pub const fn states(&self) -> &States {
85 self.trajectory.states()
86 }
87
88 #[inline]
95 pub const fn shape(&self) -> &Array1<usize> {
96 self.trajectory.shape()
97 }
98
99 #[inline]
106 pub const fn times(&self) -> &Array1<f64> {
107 self.trajectory.times()
108 }
109}
110
111impl Labelled for CatWtdTrj {
112 #[inline]
113 fn labels(&self) -> &Labels {
114 self.trajectory.labels()
115 }
116}
117
118impl Dataset for CatWtdTrj {
119 type Values = Array2<CatType>;
120
121 #[inline]
122 fn values(&self) -> &Self::Values {
123 self.trajectory.values()
124 }
125
126 #[inline]
127 fn sample_size(&self) -> f64 {
128 self.weight * (self.trajectory.values().nrows() as f64)
129 }
130}
131
132#[derive(Clone, Debug)]
134pub struct CatWtdTrjs {
135 labels: Labels,
136 states: States,
137 shape: Array1<usize>,
138 values: Vec<CatWtdTrj>,
139}
140
141impl CatWtdTrjs {
142 pub fn new<I>(values: I) -> Self
162 where
163 I: IntoIterator<Item = CatWtdTrj>,
164 {
165 let values: Vec<_> = values.into_iter().collect();
167
168 assert!(
170 values
171 .windows(2)
172 .all(|trjs| trjs[0].labels().eq(trjs[1].labels())),
173 "All trajectories must have the same labels."
174 );
175 assert!(
177 values
178 .windows(2)
179 .all(|trjs| trjs[0].states().eq(trjs[1].states())),
180 "All trajectories must have the same states."
181 );
182 assert!(
184 values
185 .windows(2)
186 .all(|trjs| trjs[0].shape().eq(trjs[1].shape())),
187 "All trajectories must have the same shape."
188 );
189
190 let trj = values.first().expect("No trajectory in the dataset.");
192 let labels = trj.labels().clone();
193 let states = trj.states().clone();
194 let shape = trj.shape().clone();
195
196 Self {
197 labels,
198 states,
199 shape,
200 values,
201 }
202 }
203
204 #[inline]
211 pub fn states(&self) -> &States {
212 &self.states
213 }
214
215 #[inline]
222 pub fn shape(&self) -> &Array1<usize> {
223 &self.shape
224 }
225}
226
227impl FromIterator<CatWtdTrj> for CatWtdTrjs {
228 #[inline]
229 fn from_iter<I: IntoIterator<Item = CatWtdTrj>>(iter: I) -> Self {
230 Self::new(iter)
231 }
232}
233
234impl FromParallelIterator<CatWtdTrj> for CatWtdTrjs {
235 #[inline]
236 fn from_par_iter<I: IntoParallelIterator<Item = CatWtdTrj>>(iter: I) -> Self {
237 Self::new(iter.into_par_iter().collect::<Vec<_>>())
238 }
239}
240
241impl<'a> IntoIterator for &'a CatWtdTrjs {
242 type IntoIter = std::slice::Iter<'a, CatWtdTrj>;
243 type Item = &'a CatWtdTrj;
244
245 #[inline]
246 fn into_iter(self) -> Self::IntoIter {
247 self.values.iter()
248 }
249}
250
251impl<'a> IntoParallelRefIterator<'a> for CatWtdTrjs {
252 type Item = &'a CatWtdTrj;
253 type Iter = rayon::slice::Iter<'a, CatWtdTrj>;
254
255 #[inline]
256 fn par_iter(&'a self) -> Self::Iter {
257 self.values.par_iter()
258 }
259}
260
261impl Labelled for CatWtdTrjs {
262 #[inline]
263 fn labels(&self) -> &Labels {
264 &self.labels
265 }
266}
267
268impl Dataset for CatWtdTrjs {
269 type Values = Vec<CatWtdTrj>;
270
271 #[inline]
272 fn values(&self) -> &Self::Values {
273 &self.values
274 }
275
276 #[inline]
277 fn sample_size(&self) -> f64 {
278 self.values.iter().map(Dataset::sample_size).sum()
279 }
280}