1use bit_vec::BitVec;
2use nalgebra::DVector;
3use rayon::prelude::*;
4use serde_derive::{Deserialize, Serialize};
5use std::{ops::Index, sync::Arc};
6
7use crate::utils::Mask;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct MaskedSample {
12 pub(crate) data: DVector<f64>,
13 pub(crate) mask: Mask,
14}
15
16impl MaskedSample {
17 pub fn mask_non_finite(data: DVector<f64>) -> MaskedSample {
20 let mask = data.iter().copied().map(f64::is_finite).collect::<BitVec>();
21 MaskedSample::new(data, Mask(mask))
22 }
23
24 pub fn new(data: DVector<f64>, mask: Mask) -> MaskedSample {
27 MaskedSample { data, mask }
28 }
29
30 pub fn unmasked(data: DVector<f64>) -> MaskedSample {
32 MaskedSample {
33 mask: Mask::unmasked(data.len()),
34 data,
35 }
36 }
37
38 pub fn data_vector(&self) -> DVector<f64> {
40 DVector::from(self.data.clone())
41 }
42
43 pub fn is_empty(&self) -> bool {
45 !self.mask.0.any()
46 }
47
48 pub fn mask(&self) -> &Mask {
51 &self.mask
52 }
53
54 pub fn is_set(&self, idx: usize) -> bool {
60 self.mask.is_set(idx)
61 }
62
63 pub fn masked_vector(&self) -> DVector<f64> {
65 self.data
66 .iter()
67 .copied()
68 .zip(&self.mask.0)
69 .map(|(value, selected)| if selected { value } else { f64::NAN })
70 .collect::<Vec<_>>()
71 .into()
72 }
73}
74
75impl Index<usize> for MaskedSample {
76 type Output = f64;
77 fn index(&self, index: usize) -> &Self::Output {
78 if self.is_set(index) {
79 &self.data[index]
80 } else {
81 panic!("Index out of bounds: index {index} is masked in sample")
82 }
83 }
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct Dataset {
94 pub data: Arc<Vec<MaskedSample>>,
96 pub weights: Vec<f64>,
100}
101
102impl From<Vec<MaskedSample>> for Dataset {
103 fn from(value: Vec<MaskedSample>) -> Self {
104 Dataset {
105 weights: vec![1.0; value.len()],
106 data: Arc::new(value),
107 }
108 }
109}
110
111impl FromIterator<MaskedSample> for Dataset {
112 fn from_iter<T>(iter: T) -> Self
113 where
114 T: IntoIterator<Item = MaskedSample>,
115 {
116 let data: Vec<_> = iter.into_iter().collect();
117 Self::new(data)
118 }
119}
120
121impl FromIterator<(MaskedSample, f64)> for Dataset {
122 fn from_iter<T>(iter: T) -> Self
123 where
124 T: IntoIterator<Item = (MaskedSample, f64)>,
125 {
126 let (data, weights): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
127 Self::new_with_weights(data, weights)
128 }
129}
130
131impl FromParallelIterator<MaskedSample> for Dataset {
132 fn from_par_iter<T>(iter: T) -> Self
133 where
134 T: IntoParallelIterator<Item = MaskedSample>,
135 {
136 let data: Vec<_> = iter.into_par_iter().collect();
137 Self::new(data)
138 }
139}
140
141impl FromParallelIterator<(MaskedSample, f64)> for Dataset {
142 fn from_par_iter<T>(iter: T) -> Self
143 where
144 T: IntoParallelIterator<Item = (MaskedSample, f64)>,
145 {
146 let (data, weights): (Vec<_>, Vec<_>) = iter.into_par_iter().unzip();
147 Self::new_with_weights(data, weights)
148 }
149}
150
151impl Dataset {
152 pub fn new(data: Vec<MaskedSample>) -> Dataset {
154 Dataset {
155 weights: vec![1.0; data.len()],
156 data: Arc::new(data),
157 }
158 }
159
160 pub fn new_with_weights(data: Vec<MaskedSample>, weights: Vec<f64>) -> Dataset {
162 assert_eq!(data.len(), weights.len());
163 Dataset {
164 data: Arc::new(data),
165 weights,
166 }
167 }
168
169 pub fn with_weights(&self, weights: Vec<f64>) -> Dataset {
172 Dataset {
173 data: self.data.clone(),
174 weights,
175 }
176 }
177
178 pub fn len(&self) -> usize {
180 self.data.len()
181 }
182
183 pub fn is_empty(&self) -> bool {
185 self.data.is_empty()
186 }
187
188 pub fn output_size(&self) -> Option<usize> {
190 self.data.first().map(|sample| sample.mask().0.len())
191 }
192
193 pub fn empty_dimensions(&self) -> Vec<usize> {
195 let Some(n_dimensions) = self.data.first().map(|sample| sample.mask().0.len()) else {
196 return vec![]
197 };
198 let new_mask = || BitVec::from_elem(n_dimensions, false);
199 let poormans_or = |mut this: BitVec, other: &BitVec| {
200 for (position, is_selected) in other.iter().enumerate() {
201 if is_selected {
202 this.set(position, true);
203 }
204 }
205 this
206 };
207
208 let is_not_empty_dimension = self
209 .data
210 .par_iter()
211 .fold(&new_mask, |buffer, sample| {
212 poormans_or(buffer, &sample.mask().0)
213 })
214 .reduce(&new_mask, |this, other| poormans_or(this, &other));
215
216 is_not_empty_dimension
217 .into_iter()
218 .enumerate()
219 .filter(|(_, is_not_empty)| !is_not_empty)
220 .map(|(dimension, _)| dimension)
221 .collect()
222 }
223}