use bit_vec::BitVec;
use nalgebra::DVector;
use rayon::prelude::*;
use serde_derive::{Deserialize, Serialize};
use std::{ops::Index, sync::Arc};
use crate::utils::Mask;
/// A vector of values paired with a mask that marks which entries are selected.
///
/// Unselected entries keep whatever value is stored in `data`; accessors that
/// consult the mask (indexing, `masked_vector`) treat them as absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MaskedSample {
/// Raw values, including entries that the mask leaves unselected.
pub(crate) data: DVector<f64>,
/// Selection mask; expected to have one bit per entry of `data`
/// (NOTE(review): this is not enforced by the constructors — confirm callers keep them in sync).
pub(crate) mask: Mask,
}
impl MaskedSample {
    /// Builds a sample whose mask selects exactly the finite entries of `data`.
    ///
    /// Entries that are `NaN` or infinite are left unselected; their values are
    /// still stored in the sample.
    pub fn mask_non_finite(data: DVector<f64>) -> MaskedSample {
        let mask = data.iter().copied().map(f64::is_finite).collect::<BitVec>();
        MaskedSample::new(data, Mask(mask))
    }

    /// Builds a sample from raw values and an explicit mask.
    ///
    /// No check is made that `mask` has the same length as `data`; callers are
    /// responsible for keeping the two consistent.
    pub fn new(data: DVector<f64>, mask: Mask) -> MaskedSample {
        MaskedSample { data, mask }
    }

    /// Builds a sample whose mask is `Mask::unmasked(data.len())`
    /// (presumably a mask that selects every entry — confirm in `Mask`).
    pub fn unmasked(data: DVector<f64>) -> MaskedSample {
        MaskedSample {
            mask: Mask::unmasked(data.len()),
            data,
        }
    }

    /// Returns a copy of the raw values, including entries that are masked out.
    pub fn data_vector(&self) -> DVector<f64> {
        // `self.data` already is a `DVector<f64>`; wrapping it in `DVector::from`
        // was an identity conversion.
        self.data.clone()
    }

    /// Returns `true` when the mask selects no entry at all.
    pub fn is_empty(&self) -> bool {
        !self.mask.0.any()
    }

    /// Borrows the sample's mask.
    pub fn mask(&self) -> &Mask {
        &self.mask
    }

    /// Returns whether entry `idx` is selected, as reported by `Mask::is_set`.
    pub fn is_set(&self, idx: usize) -> bool {
        self.mask.is_set(idx)
    }

    /// Returns a copy of the values with every unselected entry replaced by `NaN`.
    ///
    /// Values and mask bits are paired with `zip`, so the result has
    /// `min(data.len(), mask length)` entries.
    pub fn masked_vector(&self) -> DVector<f64> {
        self.data
            .iter()
            .copied()
            .zip(&self.mask.0)
            .map(|(value, selected)| if selected { value } else { f64::NAN })
            .collect::<Vec<_>>()
            .into()
    }
}
impl Index<usize> for MaskedSample {
    type Output = f64;

    /// Returns a reference to the stored value at `index`.
    ///
    /// # Panics
    ///
    /// Panics when `self.is_set(index)` is `false`, i.e. the entry is not
    /// selected by the sample's mask.
    fn index(&self, index: usize) -> &Self::Output {
        if !self.is_set(index) {
            panic!("Index out of bounds: index {index} is masked in sample");
        }
        &self.data[index]
    }
}
/// A collection of masked samples together with one weight per sample.
///
/// The samples live behind an `Arc`, so cloning a `Dataset` shares the sample
/// storage instead of copying it (the weight vector is still copied).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dataset {
/// Shared sample storage.
pub data: Arc<Vec<MaskedSample>>,
/// Per-sample weights, expected to line up index-for-index with `data`.
/// `Dataset::new_with_weights` asserts equal lengths, but these public fields
/// can be modified directly without that check.
pub weights: Vec<f64>,
}
impl From<Vec<MaskedSample>> for Dataset {
fn from(value: Vec<MaskedSample>) -> Self {
Dataset {
weights: vec![1.0; value.len()],
data: Arc::new(value),
}
}
}
impl FromIterator<MaskedSample> for Dataset {
    /// Collects samples into a dataset with unit weights (via `Dataset::new`).
    fn from_iter<T>(iter: T) -> Self
    where
        T: IntoIterator<Item = MaskedSample>,
    {
        Self::new(iter.into_iter().collect())
    }
}
impl FromIterator<(MaskedSample, f64)> for Dataset {
    /// Collects `(sample, weight)` pairs into a weighted dataset.
    ///
    /// Samples and weights keep the order in which the iterator yields them.
    fn from_iter<T>(iter: T) -> Self
    where
        T: IntoIterator<Item = (MaskedSample, f64)>,
    {
        let pairs = iter.into_iter();
        let (expected, _) = pairs.size_hint();
        let mut samples = Vec::with_capacity(expected);
        let mut sample_weights = Vec::with_capacity(expected);
        for (sample, weight) in pairs {
            samples.push(sample);
            sample_weights.push(weight);
        }
        Self::new_with_weights(samples, sample_weights)
    }
}
impl FromParallelIterator<MaskedSample> for Dataset {
    /// Collects samples from a rayon parallel iterator into a dataset with unit weights.
    fn from_par_iter<T>(iter: T) -> Self
    where
        T: IntoParallelIterator<Item = MaskedSample>,
    {
        Self::new(iter.into_par_iter().collect::<Vec<MaskedSample>>())
    }
}
impl FromParallelIterator<(MaskedSample, f64)> for Dataset {
    /// Collects `(sample, weight)` pairs from a rayon parallel iterator into a
    /// weighted dataset.
    fn from_par_iter<T>(iter: T) -> Self
    where
        T: IntoParallelIterator<Item = (MaskedSample, f64)>,
    {
        let pairs = iter.into_par_iter();
        let (samples, sample_weights): (Vec<MaskedSample>, Vec<f64>) = pairs.unzip();
        Self::new_with_weights(samples, sample_weights)
    }
}
impl Dataset {
    /// Creates a dataset from `data`, giving every sample a weight of `1.0`.
    pub fn new(data: Vec<MaskedSample>) -> Dataset {
        Dataset {
            weights: vec![1.0; data.len()],
            data: Arc::new(data),
        }
    }

    /// Creates a dataset from samples and their per-sample weights.
    ///
    /// # Panics
    ///
    /// Panics if `data` and `weights` have different lengths.
    pub fn new_with_weights(data: Vec<MaskedSample>, weights: Vec<f64>) -> Dataset {
        assert_eq!(
            data.len(),
            weights.len(),
            "number of weights must match number of samples"
        );
        Dataset {
            data: Arc::new(data),
            weights,
        }
    }

    /// Returns a dataset that shares this dataset's samples but uses `weights`.
    ///
    /// Only the `Arc` around the samples is cloned; the samples themselves are
    /// not copied.
    ///
    /// # Panics
    ///
    /// Panics if `weights.len()` differs from the number of samples. This is the
    /// same invariant `new_with_weights` enforces; previously it was unchecked here.
    pub fn with_weights(&self, weights: Vec<f64>) -> Dataset {
        assert_eq!(
            self.data.len(),
            weights.len(),
            "number of weights must match number of samples"
        );
        Dataset {
            data: Arc::clone(&self.data),
            weights,
        }
    }

    /// Number of samples in the dataset.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` when the dataset contains no samples.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Mask length of the first sample, or `None` for an empty dataset.
    ///
    /// Only the first sample is inspected; presumably all samples share this
    /// length (not verified here).
    pub fn output_size(&self) -> Option<usize> {
        self.data.first().map(|sample| sample.mask().0.len())
    }

    /// Returns the indices of dimensions that no sample selects.
    ///
    /// A dimension is reported when its bit is unset in every sample's mask.
    /// The number of dimensions is taken from `output_size` (the first sample);
    /// an empty dataset yields an empty list. The masks are OR-ed together in
    /// parallel with rayon's `fold` + `reduce`.
    ///
    /// # Panics
    ///
    /// Panics (via `BitVec::set`) if some sample selects a position at or beyond
    /// the first sample's mask length.
    pub fn empty_dimensions(&self) -> Vec<usize> {
        let Some(n_dimensions) = self.output_size() else {
            return Vec::new();
        };
        let no_dimensions = || BitVec::from_elem(n_dimensions, false);
        // Hand-rolled OR instead of `BitVec::or`: the latter requires both
        // operands to have identical lengths, whereas this tolerates masks that
        // are shorter than the first sample's.
        let union = |mut acc: BitVec, other: &BitVec| {
            for (position, selected) in other.iter().enumerate() {
                if selected {
                    acc.set(position, true);
                }
            }
            acc
        };
        let selected_somewhere = self
            .data
            .par_iter()
            .fold(&no_dimensions, |acc, sample| union(acc, &sample.mask().0))
            .reduce(&no_dimensions, |lhs, rhs| union(lhs, &rhs));
        selected_somewhere
            .iter()
            .enumerate()
            .filter_map(|(dimension, selected)| (!selected).then_some(dimension))
            .collect()
    }
}