causal_hub/datasets/table/gaussian/
dataset.rs1use csv::ReaderBuilder;
2use ndarray::prelude::*;
3
4use crate::{datasets::Dataset, io::CsvIO, models::Labelled, types::Labels};
5
6pub type GaussType = f64;
8pub type GaussSample = Array1<GaussType>;
10
11#[derive(Clone, Debug)]
13pub struct GaussTable {
14 labels: Labels,
15 values: Array2<GaussType>,
16}
17
18impl Labelled for GaussTable {
19 #[inline]
20 fn labels(&self) -> &Labels {
21 &self.labels
22 }
23}
24
25impl GaussTable {
26 pub fn new(mut labels: Labels, mut values: Array2<GaussType>) -> Self {
42 assert_eq!(
44 labels.len(),
45 values.ncols(),
46 "Number of labels must match number of columns in values."
47 );
48
49 if !labels.is_sorted() {
51 let mut indices: Vec<usize> = (0..labels.len()).collect();
53 indices.sort_by_key(|&i| &labels[i]);
55 labels.sort();
57 let mut new_values = values.clone();
59 indices.into_iter().enumerate().for_each(|(i, j)| {
61 new_values.column_mut(i).assign(&values.column(j));
62 });
63 values = new_values;
65 }
66 assert!(
68 values.iter().all(|&x| x.is_finite()),
69 "Values must have finite values."
70 );
71
72 Self { labels, values }
73 }
74}
75
76impl Dataset for GaussTable {
77 type Values = Array2<GaussType>;
78
79 #[inline]
80 fn values(&self) -> &Self::Values {
81 &self.values
82 }
83
84 #[inline]
85 fn sample_size(&self) -> f64 {
86 self.values.nrows() as f64
87 }
88}
89
90impl CsvIO for GaussTable {
91 fn from_csv(csv: &str) -> Self {
92 let mut reader = ReaderBuilder::new()
94 .has_headers(true)
95 .from_reader(csv.as_bytes());
96
97 assert!(reader.has_headers(), "Reader must have headers.");
99
100 let labels: Labels = reader
102 .headers()
103 .expect("Failed to read the headers.")
104 .into_iter()
105 .map(|x| x.to_owned())
106 .collect();
107
108 let values: Array1<_> = reader
110 .into_records()
111 .enumerate()
112 .flat_map(|(i, row)| {
113 let row = row.unwrap_or_else(|_| panic!("Malformed record on line {}.", i + 1));
115 let row: Vec<_> = row
117 .into_iter()
118 .enumerate()
119 .map(|(i, x)| {
120 assert!(!x.is_empty(), "Missing value on line {}.", i + 1);
122 x.parse::<GaussType>().unwrap()
124 })
125 .collect();
126 row
128 })
129 .collect();
130
131 let ncols = labels.len();
133 let nrows = values.len() / ncols;
134 let values = values
136 .into_shape_with_order((nrows, ncols))
137 .expect("Failed to rearrange values to the correct shape.");
138
139 Self::new(labels, values)
141 }
142
143 fn to_csv(&self) -> String {
144 todo!() }
146}