causal_hub/datasets/table/gaussian/
dataset.rs

1use csv::ReaderBuilder;
2use ndarray::prelude::*;
3
4use crate::{datasets::Dataset, io::CsvIO, models::Labelled, types::Labels};
5
/// Scalar type of a single gaussian value (an alias for `f64`).
pub type GaussType = f64;
/// A gaussian sample: a one-dimensional array of [`GaussType`] values.
pub type GaussSample = Array1<GaussType>;
10
/// A struct representing a gaussian dataset.
///
/// It pairs the variable labels with a two-dimensional array of
/// [`GaussType`] values. Instances built through [`GaussTable::new`] have one
/// column per label, columns ordered by sorted label, and only finite values.
#[derive(Clone, Debug)]
pub struct GaussTable {
    /// Labels of the variables.
    labels: Labels,
    /// Values of the variables; `new` asserts one column per label.
    values: Array2<GaussType>,
}
17
18impl Labelled for GaussTable {
19    #[inline]
20    fn labels(&self) -> &Labels {
21        &self.labels
22    }
23}
24
25impl GaussTable {
26    /// Creates a new gaussian dataset.
27    ///
28    /// # Arguments
29    ///
30    /// * `labels` - The labels of the variables.
31    /// * `values` - The values of the variables.
32    ///
33    /// # Panics
34    ///
35    /// * Panics if the number of columns in `values` does not match the number of `labels`.
36    ///
37    /// # Results
38    ///
39    /// A new gaussian dataset instance.
40    ///
41    pub fn new(mut labels: Labels, mut values: Array2<GaussType>) -> Self {
42        // Assert that the number of labels matches the number of columns in values.
43        assert_eq!(
44            labels.len(),
45            values.ncols(),
46            "Number of labels must match number of columns in values."
47        );
48
49        // Sort labels and values accordingly.
50        if !labels.is_sorted() {
51            // Allocate indices to sort labels.
52            let mut indices: Vec<usize> = (0..labels.len()).collect();
53            // Sort the indices by labels.
54            indices.sort_by_key(|&i| &labels[i]);
55            // Sort the labels.
56            labels.sort();
57            // Allocate new values.
58            let mut new_values = values.clone();
59            // Sort the new values according to the sorted indices.
60            indices.into_iter().enumerate().for_each(|(i, j)| {
61                new_values.column_mut(i).assign(&values.column(j));
62            });
63            // Update values.
64            values = new_values;
65        }
66        // Assert values are finite.
67        assert!(
68            values.iter().all(|&x| x.is_finite()),
69            "Values must have finite values."
70        );
71
72        Self { labels, values }
73    }
74}
75
76impl Dataset for GaussTable {
77    type Values = Array2<GaussType>;
78
79    #[inline]
80    fn values(&self) -> &Self::Values {
81        &self.values
82    }
83
84    #[inline]
85    fn sample_size(&self) -> f64 {
86        self.values.nrows() as f64
87    }
88}
89
90impl CsvIO for GaussTable {
91    fn from_csv(csv: &str) -> Self {
92        // Create a CSV reader from the string.
93        let mut reader = ReaderBuilder::new()
94            .has_headers(true)
95            .from_reader(csv.as_bytes());
96
97        // Assert that the reader has headers.
98        assert!(reader.has_headers(), "Reader must have headers.");
99
100        // Read the headers.
101        let labels: Labels = reader
102            .headers()
103            .expect("Failed to read the headers.")
104            .into_iter()
105            .map(|x| x.to_owned())
106            .collect();
107
108        // Read the records.
109        let values: Array1<_> = reader
110            .into_records()
111            .enumerate()
112            .flat_map(|(i, row)| {
113                // Get the record row.
114                let row = row.unwrap_or_else(|_| panic!("Malformed record on line {}.", i + 1));
115                // Get the record values and convert to indices.
116                let row: Vec<_> = row
117                    .into_iter()
118                    .enumerate()
119                    .map(|(i, x)| {
120                        // Assert no missing values.
121                        assert!(!x.is_empty(), "Missing value on line {}.", i + 1);
122                        // Cast the value.
123                        x.parse::<GaussType>().unwrap()
124                    })
125                    .collect();
126                // Collect the values.
127                row
128            })
129            .collect();
130
131        // Get the number of rows and columns.
132        let ncols = labels.len();
133        let nrows = values.len() / ncols;
134        // Reshape the values to the correct shape.
135        let values = values
136            .into_shape_with_order((nrows, ncols))
137            .expect("Failed to rearrange values to the correct shape.");
138
139        // Construct the dataset.
140        Self::new(labels, values)
141    }
142
143    fn to_csv(&self) -> String {
144        todo!() // FIXME:
145    }
146}