causal_hub/datasets/table/gaussian/
weighted.rs1use ndarray::prelude::*;
2
3use crate::{
4 datasets::{Dataset, GaussSample, GaussTable},
5 models::Labelled,
6 types::Labels,
7};
8
/// A single Gaussian sample paired with an `f64`.
///
/// NOTE(review): the `f64` is presumably the sample's weight, mirroring the
/// per-row weights stored by [`GaussWtdTable`] — confirm against callers.
pub type GaussWtdSample = (GaussSample, f64);
11
/// A Gaussian table in which every sample (row) carries a weight.
///
/// Instances built through [`GaussWtdTable::new`] hold exactly one weight per
/// row of the underlying dataset, and each weight lies in the closed interval
/// `[0, 1]`.
#[derive(Clone, Debug)]
pub struct GaussWtdTable {
    /// The underlying, unweighted Gaussian table.
    dataset: GaussTable,
    /// One weight per row of `dataset`.
    weights: Array1<f64>,
}
18
19impl Labelled for GaussWtdTable {
20 #[inline]
21 fn labels(&self) -> &Labels {
22 self.dataset.labels()
23 }
24}
25
26impl GaussWtdTable {
27 pub fn new(dataset: GaussTable, weights: Array1<f64>) -> Self {
44 assert_eq!(
45 dataset.values().nrows(),
46 weights.len(),
47 "The number of weights must be equal to the number of samples."
48 );
49 assert!(
50 weights.iter().all(|&w| (0.0..=1.0).contains(&w)),
51 "All weights must be in the range [0, 1]."
52 );
53
54 Self { dataset, weights }
55 }
56
57 #[inline]
64 pub const fn weights(&self) -> &Array1<f64> {
65 &self.weights
66 }
67}
68
69impl Dataset for GaussWtdTable {
70 type Values = GaussTable;
71
72 #[inline]
73 fn values(&self) -> &Self::Values {
74 &self.dataset
75 }
76
77 #[inline]
78 fn sample_size(&self) -> f64 {
79 self.weights.sum()
80 }
81}