// causal_hub/datasets/table/categorical/dataset.rs

use std::fmt::Display;

use csv::{ReaderBuilder, Writer};
use itertools::Itertools;
use log::debug;
use ndarray::prelude::*;

use crate::{
    datasets::Dataset,
    io::CsvIO,
    models::Labelled,
    types::{Labels, Set, States},
};
14
15/// A type alias for a categorical variable.
16pub type CatType = u8;
17/// A type alias for a categorical sample.
18pub type CatSample = Array1<CatType>;
19
20/// A struct representing a categorical dataset.
21#[derive(Clone, Debug)]
22pub struct CatTable {
23    labels: Labels,
24    states: States,
25    shape: Array1<usize>,
26    values: Array2<CatType>,
27}
28
29impl Labelled for CatTable {
30    #[inline]
31    fn labels(&self) -> &Labels {
32        &self.labels
33    }
34}
35
36impl CatTable {
37    /// Creates a new categorical dataset.
38    ///
39    /// # Arguments
40    ///
41    /// * `states` - The variables states.
42    /// * `values` - The values of the variables.
43    ///
44    /// # Notes
45    ///
46    /// * Labels and states will be sorted in alphabetical order.
47    ///
48    /// # Panics
49    ///
50    /// * If the variable labels are not unique.
51    /// * If the variable states are not unique.
52    /// * If the number of variable states is higher than `CatType::MAX`.
53    /// * If the number of variables is different from the number of values columns.
54    /// * If the variables values are not smaller than the number of states.
55    ///
56    /// # Returns
57    ///
58    /// A new categorical dataset instance.
59    ///
60    pub fn new(mut states: States, mut values: Array2<CatType>) -> Self {
61        // Get the labels of the variables.
62        let mut labels: Labels = states.keys().cloned().collect();
63        // Get the shape of the states.
64        let mut shape = Array::from_iter(states.values().map(Set::len));
65
66        // Log the creation of the categorical dataset.
67        debug!(
68            "Creating a new categorical dataset with {} variables and {} samples.",
69            states.len(),
70            values.nrows()
71        );
72
73        // Check if the number of states is less than `CatType::MAX`.
74        states.iter().for_each(|(label, state)| {
75            assert!(
76                state.len() < CatType::MAX as usize,
77                "Variable '{label}' should have less than 256 states: \n\
78                \t expected:    |states| <  256 , \n\
79                \t found:       |states| == {} .",
80                state.len()
81            );
82        });
83        // Check if the number of variables is equal to the number of columns.
84        assert_eq!(
85            states.len(),
86            values.ncols(),
87            "Number of variables must be equal to the number of columns: \n\
88            \t expected:    |states| == |values.columns()| , \n\
89            \t found:       |states| == {} and |values.columns()| == {} .",
90            states.len(),
91            values.ncols()
92        );
93        // Check if the maximum value of the values is less than the number of states.
94        values
95            .fold_axis(Axis(0), 0, |&a, &b| if a > b { a } else { b })
96            .into_iter()
97            .enumerate()
98            .for_each(|(i, x)| {
99                assert!(
100                    x < states[i].len() as CatType,
101                    "Values of variable '{label}' must be less than the number of states: \n\
102                    \t expected: values[.., '{label}'] < |states['{label}']| , \n\
103                    \t found:    values[.., '{label}'] == {x} and |states['{label}']| == {} .",
104                    states[i].len(),
105                    label = labels[i],
106                );
107            });
108
109        // Check if the values are already sorted.
110        if !states.keys().is_sorted() || !states.values().all(|x| x.iter().is_sorted()) {
111            // Clone the states.
112            let mut new_states = states.clone();
113            // Sort the states.
114            new_states.sort_keys();
115            new_states.values_mut().for_each(Set::sort);
116            // Clone the values.
117            let mut new_values = values.clone();
118            // Update the values according to the sorted states.
119            new_states
120                .iter()
121                .enumerate()
122                .for_each(|(i, (new_label, new_states))| {
123                    // Get the index of the new label in the old states.
124                    let (j, _, states_j) = states
125                        .get_full(new_label)
126                        .expect("Failed to get full old states.");
127                    // Update the values.
128                    new_values
129                        .column_mut(i)
130                        .iter_mut()
131                        .zip(values.column(j))
132                        .for_each(|(new_val, old_val)| {
133                            // Get the old state label.
134                            let old_val = &states_j[*old_val as usize];
135                            // Get the new state index.
136                            *new_val = new_states
137                                .get_index_of(old_val)
138                                .expect("Failed to get new state index.")
139                                as CatType;
140                        });
141                });
142            // Update the values.
143            values = new_values;
144            // Update the states.
145            states = new_states;
146            // Update the labels.
147            labels = states.keys().cloned().collect();
148            // Update the shape.
149            shape = Array::from_iter(states.values().map(Set::len));
150        }
151
152        // Debug assert labels are unique.
153        debug_assert_eq!(
154            labels.iter().unique().count(),
155            labels.len(),
156            "Labels must be unique."
157        );
158        // Debug assert labels are sorted.
159        debug_assert!(labels.iter().is_sorted(), "Labels must be sorted.");
160        // Debug assert states keys are unique.
161        debug_assert_eq!(
162            states.keys().unique().count(),
163            states.len(),
164            "States keys must be unique."
165        );
166        // Debug assert states keys are sorted.
167        debug_assert!(states.keys().is_sorted(), "States keys must be sorted.");
168        // Debug assert states values are unique.
169        debug_assert_eq!(
170            states
171                .values()
172                .map(|x| x.iter().unique().count())
173                .sum::<usize>(),
174            states.values().map(Set::len).sum::<usize>(),
175            "States values must be unique."
176        );
177        // Debug assert states values are sorted.
178        debug_assert!(
179            states.values().all(|x| x.iter().is_sorted()),
180            "States values must be sorted."
181        );
182        // Debug assert labels and states keys are the same.
183        debug_assert!(
184            labels.iter().eq(states.keys()),
185            "Labels and states keys must be the same."
186        );
187        // Debug assert shape must match the number of states.
188        debug_assert!(
189            shape
190                .iter()
191                .zip(states.values())
192                .all(|(&a, b)| a == b.len()),
193            "Shape must match the number of states values."
194        );
195
196        Self {
197            labels,
198            states,
199            shape,
200            values,
201        }
202    }
203
204    /// Returns the states of the variables in the categorical distribution.
205    ///
206    /// # Returns
207    ///
208    /// A reference to the vector of states.
209    ///
210    #[inline]
211    pub const fn states(&self) -> &States {
212        &self.states
213    }
214
215    /// Returns the shape of the set of states in the categorical distribution.
216    ///
217    /// # Returns
218    ///
219    /// A reference to the array of shape.
220    ///
221    #[inline]
222    pub const fn shape(&self) -> &Array1<usize> {
223        &self.shape
224    }
225}
226
227impl Display for CatTable {
228    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229        // Get the maximum length of the labels and states.
230        let n = self
231            .labels()
232            .iter()
233            .chain(self.states().values().flatten())
234            .map(|x| x.len())
235            .max()
236            .unwrap_or(0);
237
238        // Write the top line.
239        let hline = std::iter::repeat_n("-", (n + 3) * self.labels().len() + 1).join("");
240        writeln!(f, "{hline}")?;
241        // Write the header.
242        let header = self.labels().iter().map(|x| format!("{x:n$}")).join(" | ");
243        writeln!(f, "| {header} |")?;
244        // Write the separator.
245        let separator = (0..self.labels().len()).map(|_| "-".repeat(n)).join(" | ");
246        writeln!(f, "| {separator} |")?;
247        // Write the values.
248        for row in self.values.rows() {
249            // Get the state corresponding to the value.
250            let row = row
251                .iter()
252                .enumerate()
253                .map(|(i, &x)| &self.states()[i][x as usize])
254                .map(|x| format!("{x:n$}"))
255                .join(" | ");
256            writeln!(f, "| {row} |")?;
257        }
258        // Write the bottom line.
259        writeln!(f, "{hline}")
260    }
261}
262
263impl Dataset for CatTable {
264    type Values = Array2<CatType>;
265
266    #[inline]
267    fn values(&self) -> &Self::Values {
268        &self.values
269    }
270
271    #[inline]
272    fn sample_size(&self) -> f64 {
273        self.values.nrows() as f64
274    }
275}
276
277impl CsvIO for CatTable {
278    fn from_csv(csv: &str) -> Self {
279        // Create a CSV reader from the string.
280        let mut reader = ReaderBuilder::new()
281            .has_headers(true)
282            .from_reader(csv.as_bytes());
283
284        // Assert that the reader has headers.
285        assert!(reader.has_headers(), "Reader must have headers.");
286
287        // Read the headers.
288        let labels: Labels = reader
289            .headers()
290            .expect("Failed to read the headers.")
291            .into_iter()
292            .map(|x| x.to_owned())
293            .collect();
294
295        // Get the states of the variables.
296        let mut states: States = labels
297            .iter()
298            .map(|x| (x.clone(), Default::default()))
299            .collect();
300
301        // Read the records.
302        let values: Array1<_> = reader
303            .into_records()
304            .enumerate()
305            .flat_map(|(i, row)| {
306                // Get the record row.
307                let row = row.unwrap_or_else(|_| panic!("Malformed record on line {}.", i + 1));
308                // Get the record values and convert to indices.
309                let row: Vec<_> = row
310                    .into_iter()
311                    .enumerate()
312                    .map(|(i, x)| {
313                        // Assert no missing values.
314                        assert!(!x.is_empty(), "Missing value on line {}.", i + 1);
315                        // Insert the value into the states, if not present.
316                        let (x, _) = states[i].insert_full(x.to_owned());
317                        // Cast the value.
318                        x as CatType
319                    })
320                    .collect();
321                // Collect the values.
322                row
323            })
324            .collect();
325
326        // Get the number of rows and columns.
327        let ncols = labels.len();
328        let nrows = values.len() / ncols;
329        // Reshape the values to the correct shape.
330        let values = values
331            .into_shape_with_order((nrows, ncols))
332            .expect("Failed to rearrange values to the correct shape.");
333
334        // Construct the dataset.
335        Self::new(states, values)
336    }
337
338    fn to_csv(&self) -> String {
339        todo!() // FIXME:
340    }
341}