causal_hub/datasets/table/categorical/
dataset.rs1use std::fmt::Display;
2
3use csv::ReaderBuilder;
4use itertools::Itertools;
5use log::debug;
6use ndarray::prelude::*;
7
8use crate::{
9 datasets::Dataset,
10 io::CsvIO,
11 models::Labelled,
12 types::{Labels, Set, States},
13};
14
15pub type CatType = u8;
17pub type CatSample = Array1<CatType>;
19
20#[derive(Clone, Debug)]
22pub struct CatTable {
23 labels: Labels,
24 states: States,
25 shape: Array1<usize>,
26 values: Array2<CatType>,
27}
28
29impl Labelled for CatTable {
30 #[inline]
31 fn labels(&self) -> &Labels {
32 &self.labels
33 }
34}
35
36impl CatTable {
37 pub fn new(mut states: States, mut values: Array2<CatType>) -> Self {
61 let mut labels: Labels = states.keys().cloned().collect();
63 let mut shape = Array::from_iter(states.values().map(Set::len));
65
66 debug!(
68 "Creating a new categorical dataset with {} variables and {} samples.",
69 states.len(),
70 values.nrows()
71 );
72
73 states.iter().for_each(|(label, state)| {
75 assert!(
76 state.len() < CatType::MAX as usize,
77 "Variable '{label}' should have less than 256 states: \n\
78 \t expected: |states| < 256 , \n\
79 \t found: |states| == {} .",
80 state.len()
81 );
82 });
83 assert_eq!(
85 states.len(),
86 values.ncols(),
87 "Number of variables must be equal to the number of columns: \n\
88 \t expected: |states| == |values.columns()| , \n\
89 \t found: |states| == {} and |values.columns()| == {} .",
90 states.len(),
91 values.ncols()
92 );
93 values
95 .fold_axis(Axis(0), 0, |&a, &b| if a > b { a } else { b })
96 .into_iter()
97 .enumerate()
98 .for_each(|(i, x)| {
99 assert!(
100 x < states[i].len() as CatType,
101 "Values of variable '{label}' must be less than the number of states: \n\
102 \t expected: values[.., '{label}'] < |states['{label}']| , \n\
103 \t found: values[.., '{label}'] == {x} and |states['{label}']| == {} .",
104 states[i].len(),
105 label = labels[i],
106 );
107 });
108
109 if !states.keys().is_sorted() || !states.values().all(|x| x.iter().is_sorted()) {
111 let mut new_states = states.clone();
113 new_states.sort_keys();
115 new_states.values_mut().for_each(Set::sort);
116 let mut new_values = values.clone();
118 new_states
120 .iter()
121 .enumerate()
122 .for_each(|(i, (new_label, new_states))| {
123 let (j, _, states_j) = states
125 .get_full(new_label)
126 .expect("Failed to get full old states.");
127 new_values
129 .column_mut(i)
130 .iter_mut()
131 .zip(values.column(j))
132 .for_each(|(new_val, old_val)| {
133 let old_val = &states_j[*old_val as usize];
135 *new_val = new_states
137 .get_index_of(old_val)
138 .expect("Failed to get new state index.")
139 as CatType;
140 });
141 });
142 values = new_values;
144 states = new_states;
146 labels = states.keys().cloned().collect();
148 shape = Array::from_iter(states.values().map(Set::len));
150 }
151
152 debug_assert_eq!(
154 labels.iter().unique().count(),
155 labels.len(),
156 "Labels must be unique."
157 );
158 debug_assert!(labels.iter().is_sorted(), "Labels must be sorted.");
160 debug_assert_eq!(
162 states.keys().unique().count(),
163 states.len(),
164 "States keys must be unique."
165 );
166 debug_assert!(states.keys().is_sorted(), "States keys must be sorted.");
168 debug_assert_eq!(
170 states
171 .values()
172 .map(|x| x.iter().unique().count())
173 .sum::<usize>(),
174 states.values().map(Set::len).sum::<usize>(),
175 "States values must be unique."
176 );
177 debug_assert!(
179 states.values().all(|x| x.iter().is_sorted()),
180 "States values must be sorted."
181 );
182 debug_assert!(
184 labels.iter().eq(states.keys()),
185 "Labels and states keys must be the same."
186 );
187 debug_assert!(
189 shape
190 .iter()
191 .zip(states.values())
192 .all(|(&a, b)| a == b.len()),
193 "Shape must match the number of states values."
194 );
195
196 Self {
197 labels,
198 states,
199 shape,
200 values,
201 }
202 }
203
204 #[inline]
211 pub const fn states(&self) -> &States {
212 &self.states
213 }
214
215 #[inline]
222 pub const fn shape(&self) -> &Array1<usize> {
223 &self.shape
224 }
225}
226
227impl Display for CatTable {
228 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 let n = self
231 .labels()
232 .iter()
233 .chain(self.states().values().flatten())
234 .map(|x| x.len())
235 .max()
236 .unwrap_or(0);
237
238 let hline = std::iter::repeat_n("-", (n + 3) * self.labels().len() + 1).join("");
240 writeln!(f, "{hline}")?;
241 let header = self.labels().iter().map(|x| format!("{x:n$}")).join(" | ");
243 writeln!(f, "| {header} |")?;
244 let separator = (0..self.labels().len()).map(|_| "-".repeat(n)).join(" | ");
246 writeln!(f, "| {separator} |")?;
247 for row in self.values.rows() {
249 let row = row
251 .iter()
252 .enumerate()
253 .map(|(i, &x)| &self.states()[i][x as usize])
254 .map(|x| format!("{x:n$}"))
255 .join(" | ");
256 writeln!(f, "| {row} |")?;
257 }
258 writeln!(f, "{hline}")
260 }
261}
262
263impl Dataset for CatTable {
264 type Values = Array2<CatType>;
265
266 #[inline]
267 fn values(&self) -> &Self::Values {
268 &self.values
269 }
270
271 #[inline]
272 fn sample_size(&self) -> f64 {
273 self.values.nrows() as f64
274 }
275}
276
277impl CsvIO for CatTable {
278 fn from_csv(csv: &str) -> Self {
279 let mut reader = ReaderBuilder::new()
281 .has_headers(true)
282 .from_reader(csv.as_bytes());
283
284 assert!(reader.has_headers(), "Reader must have headers.");
286
287 let labels: Labels = reader
289 .headers()
290 .expect("Failed to read the headers.")
291 .into_iter()
292 .map(|x| x.to_owned())
293 .collect();
294
295 let mut states: States = labels
297 .iter()
298 .map(|x| (x.clone(), Default::default()))
299 .collect();
300
301 let values: Array1<_> = reader
303 .into_records()
304 .enumerate()
305 .flat_map(|(i, row)| {
306 let row = row.unwrap_or_else(|_| panic!("Malformed record on line {}.", i + 1));
308 let row: Vec<_> = row
310 .into_iter()
311 .enumerate()
312 .map(|(i, x)| {
313 assert!(!x.is_empty(), "Missing value on line {}.", i + 1);
315 let (x, _) = states[i].insert_full(x.to_owned());
317 x as CatType
319 })
320 .collect();
321 row
323 })
324 .collect();
325
326 let ncols = labels.len();
328 let nrows = values.len() / ncols;
329 let values = values
331 .into_shape_with_order((nrows, ncols))
332 .expect("Failed to rearrange values to the correct shape.");
333
334 Self::new(states, values)
336 }
337
338 fn to_csv(&self) -> String {
339 todo!() }
341}