causal_hub/datasets/table/categorical/weighted.rs

1use ndarray::prelude::*;
2
3use crate::{
4    datasets::{CatSample, CatTable, Dataset},
5    models::Labelled,
6    types::{Labels, States},
7};
8
9/// A type alias for a categorical weighted sample.
10pub type CatWtdSample = (CatSample, f64);
11
12/// A multivariate categorical weighted dataset.
13#[derive(Clone, Debug)]
14pub struct CatWtdTable {
15    dataset: CatTable,
16    weights: Array1<f64>,
17}
18
19impl Labelled for CatWtdTable {
20    #[inline]
21    fn labels(&self) -> &Labels {
22        self.dataset.labels()
23    }
24}
25
26impl CatWtdTable {
27    /// Creates a new categorical weighted dataset.
28    ///
29    /// # Arguments
30    ///
31    /// * `dataset` - The categorical dataset.
32    /// * `weights` - The weights of the samples.
33    ///
34    /// # Panics
35    ///
36    /// * Panics if the number of weights is not equal to the number of samples.
37    /// * Panics if any weight is not in the range [0, 1].
38    ///
39    /// # Returns
40    ///
41    /// A new categorical weighted dataset instance.
42    ///
43    pub fn new(dataset: CatTable, weights: Array1<f64>) -> Self {
44        assert_eq!(
45            dataset.values().nrows(),
46            weights.len(),
47            "The number of weights must be equal to the number of samples."
48        );
49        assert!(
50            weights.iter().all(|&w| (0.0..=1.0).contains(&w)),
51            "All weights must be in the range [0, 1]."
52        );
53
54        Self { dataset, weights }
55    }
56
57    /// Returns the states of the variables in the categorical distribution.
58    ///
59    /// # Returns
60    ///
61    /// A reference to the vector of states.
62    ///
63    #[inline]
64    pub const fn states(&self) -> &States {
65        self.dataset.states()
66    }
67
68    /// Returns the shape of the set of states in the categorical distribution.
69    ///
70    /// # Returns
71    ///
72    /// A reference to the array of shape.
73    ///
74    #[inline]
75    pub const fn shape(&self) -> &Array1<usize> {
76        self.dataset.shape()
77    }
78
79    /// Returns the weights of the samples in the categorical distribution.
80    ///
81    /// # Returns
82    ///
83    /// A reference to the array of weights.
84    ///
85    #[inline]
86    pub const fn weights(&self) -> &Array1<f64> {
87        &self.weights
88    }
89}
90
91impl Dataset for CatWtdTable {
92    type Values = CatTable;
93
94    #[inline]
95    fn values(&self) -> &Self::Values {
96        &self.dataset
97    }
98
99    #[inline]
100    fn sample_size(&self) -> f64 {
101        self.weights.sum()
102    }
103}