causal_hub/datasets/table/categorical/weighted.rs
1use ndarray::prelude::*;
2
3use crate::{
4 datasets::{CatSample, CatTable, Dataset},
5 models::Labelled,
6 types::{Labels, States},
7};
8
9/// A type alias for a categorical weighted sample.
10pub type CatWtdSample = (CatSample, f64);
11
12/// A multivariate categorical weighted dataset.
13#[derive(Clone, Debug)]
14pub struct CatWtdTable {
15 dataset: CatTable,
16 weights: Array1<f64>,
17}
18
19impl Labelled for CatWtdTable {
20 #[inline]
21 fn labels(&self) -> &Labels {
22 self.dataset.labels()
23 }
24}
25
26impl CatWtdTable {
27 /// Creates a new categorical weighted dataset.
28 ///
29 /// # Arguments
30 ///
31 /// * `dataset` - The categorical dataset.
32 /// * `weights` - The weights of the samples.
33 ///
34 /// # Panics
35 ///
36 /// * Panics if the number of weights is not equal to the number of samples.
37 /// * Panics if any weight is not in the range [0, 1].
38 ///
39 /// # Returns
40 ///
41 /// A new categorical weighted dataset instance.
42 ///
43 pub fn new(dataset: CatTable, weights: Array1<f64>) -> Self {
44 assert_eq!(
45 dataset.values().nrows(),
46 weights.len(),
47 "The number of weights must be equal to the number of samples."
48 );
49 assert!(
50 weights.iter().all(|&w| (0.0..=1.0).contains(&w)),
51 "All weights must be in the range [0, 1]."
52 );
53
54 Self { dataset, weights }
55 }
56
57 /// Returns the states of the variables in the categorical distribution.
58 ///
59 /// # Returns
60 ///
61 /// A reference to the vector of states.
62 ///
63 #[inline]
64 pub const fn states(&self) -> &States {
65 self.dataset.states()
66 }
67
68 /// Returns the shape of the set of states in the categorical distribution.
69 ///
70 /// # Returns
71 ///
72 /// A reference to the array of shape.
73 ///
74 #[inline]
75 pub const fn shape(&self) -> &Array1<usize> {
76 self.dataset.shape()
77 }
78
79 /// Returns the weights of the samples in the categorical distribution.
80 ///
81 /// # Returns
82 ///
83 /// A reference to the array of weights.
84 ///
85 #[inline]
86 pub const fn weights(&self) -> &Array1<f64> {
87 &self.weights
88 }
89}
90
91impl Dataset for CatWtdTable {
92 type Values = CatTable;
93
94 #[inline]
95 fn values(&self) -> &Self::Values {
96 &self.dataset
97 }
98
99 #[inline]
100 fn sample_size(&self) -> f64 {
101 self.weights.sum()
102 }
103}