1use crate::error::{PhopError, Result};
4use crate::rng::SplitMix64;
5use scirs2_core::ndarray::{Array1, Array2};
6
/// Statistics needed to move values between the original scale and the
/// standardized (zero-mean, unit-scale) frame.
///
/// Produced by `DataSet::standardized`, which divides centred feature column
/// `j` by `feat_std[j]` and centred targets by `y_std`.
#[derive(Debug, Clone, PartialEq)]
pub struct Standardizer {
    /// Mean of each feature column, indexed like the columns of `DataSet::x`.
    pub feat_mean: Vec<f64>,
    /// Divisor applied to each centred feature column, indexed like `feat_mean`.
    pub feat_std: Vec<f64>,
    /// Mean of the target values.
    pub y_mean: f64,
    /// Divisor applied to the centred target values.
    pub y_std: f64,
}
23
24impl Standardizer {
25 #[must_use]
27 pub fn inverse_target(&self, y_std_space: f64) -> f64 {
28 y_std_space * self.y_std + self.y_mean
29 }
30
31 #[must_use]
33 pub fn inverse_targets(&self, ys: &Array1<f64>) -> Array1<f64> {
34 ys.mapv(|v| self.inverse_target(v))
35 }
36}
37
/// Returns the mean and population standard deviation (divide by count) of `values`.
///
/// `n` is the number of items the caller expects `values` to yield. It is only
/// used as a capacity hint: the statistics are always computed over the items
/// actually yielded, so a wrong `n` cannot skew them.
///
/// The returned scale is meant to be divided by, so degenerate inputs get a
/// unit scale instead of a tiny one. This is the same convention scikit-learn's
/// `StandardScaler` uses for zero-variance features:
/// - an empty input yields `(0.0, 1.0)`;
/// - a (numerically) constant input yields `(mean, 1.0)`;
/// - a NaN spread (non-finite data) yields `(mean, 1.0)`.
///
/// "Constant" is judged relative to the magnitude of the data, not against an
/// absolute floor. Genuinely small-scale data (e.g. values around `1e-20`)
/// therefore keeps its true spread and is actually standardized.
fn mean_std(values: impl Iterator<Item = f64>, n: usize) -> (f64, f64) {
    let mut vals = Vec::with_capacity(n);
    vals.extend(values);
    if vals.is_empty() {
        return (0.0, 1.0);
    }
    let count = vals.len() as f64;
    let mean = vals.iter().sum::<f64>() / count;
    let var = vals.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / count;
    let sd = var.sqrt();
    // Naive summation can shift the computed mean of constant data by less
    // than `count * EPSILON * |mean|`. A spread at or below that is rounding
    // noise, not signal, and dividing by it would only amplify that noise.
    let rounding_floor = count * f64::EPSILON * mean.abs();
    if sd.is_nan() || sd <= rounding_floor {
        (mean, 1.0)
    } else {
        (mean, sd)
    }
}
49
50#[derive(Debug, Clone)]
52pub struct DataSet {
53 pub x: Array2<f64>,
55 pub y: Array1<f64>,
57 pub feature_names: Vec<String>,
59 pub target_name: String,
61}
62
63impl DataSet {
64 pub fn from_arrays(x: Array2<f64>, y: Array1<f64>) -> Result<Self> {
69 if x.nrows() != y.len() {
70 return Err(PhopError::ShapeMismatch(format!(
71 "x has {} rows but y has {} entries",
72 x.nrows(),
73 y.len()
74 )));
75 }
76 let n_vars = x.ncols();
77 let feature_names = (0..n_vars).map(|i| format!("x{i}")).collect();
78 Ok(Self {
79 x,
80 y,
81 feature_names,
82 target_name: "y".to_string(),
83 })
84 }
85
86 #[must_use]
88 pub fn n_vars(&self) -> usize {
89 self.x.ncols()
90 }
91
92 #[must_use]
94 pub fn len(&self) -> usize {
95 self.y.len()
96 }
97
98 #[must_use]
100 pub fn is_empty(&self) -> bool {
101 self.y.is_empty()
102 }
103
104 #[must_use]
110 pub fn standardized(&self) -> (DataSet, Standardizer) {
111 let n = self.len();
112 let nv = self.n_vars();
113 let mut feat_mean = vec![0.0; nv];
114 let mut feat_std = vec![1.0; nv];
115 for j in 0..nv {
116 let (m, s) = mean_std(self.x.column(j).iter().copied(), n);
117 feat_mean[j] = m;
118 feat_std[j] = s;
119 }
120 let (y_mean, y_std) = mean_std(self.y.iter().copied(), n);
121
122 let mut xz = self.x.clone();
123 for j in 0..nv {
124 let (m, s) = (feat_mean[j], feat_std[j]);
125 xz.column_mut(j).mapv_inplace(|v| (v - m) / s);
126 }
127 let yz = self.y.mapv(|v| (v - y_mean) / y_std);
128
129 let std = Standardizer {
130 feat_mean,
131 feat_std,
132 y_mean,
133 y_std,
134 };
135 let ds = DataSet {
136 x: xz,
137 y: yz,
138 feature_names: self.feature_names.clone(),
139 target_name: self.target_name.clone(),
140 };
141 (ds, std)
142 }
143
144 pub fn select(&self, rows: &[usize]) -> Result<DataSet> {
149 let nv = self.n_vars();
150 let mut x_flat = Vec::with_capacity(rows.len() * nv);
151 let mut y_vec = Vec::with_capacity(rows.len());
152 for &r in rows {
153 if r >= self.len() {
154 return Err(PhopError::ShapeMismatch(format!(
155 "row index {r} out of range (len {})",
156 self.len()
157 )));
158 }
159 for j in 0..nv {
160 x_flat.push(self.x[[r, j]]);
161 }
162 y_vec.push(self.y[r]);
163 }
164 let x = Array2::from_shape_vec((rows.len(), nv), x_flat)
165 .map_err(|e| PhopError::ShapeMismatch(e.to_string()))?;
166 Ok(DataSet {
167 x,
168 y: Array1::from(y_vec),
169 feature_names: self.feature_names.clone(),
170 target_name: self.target_name.clone(),
171 })
172 }
173
174 #[must_use]
180 pub fn minibatches(&self, size: usize, seed: u64) -> Vec<DataSet> {
181 let n = self.len();
182 if n == 0 {
183 return Vec::new();
184 }
185 let mut idx: Vec<usize> = (0..n).collect();
186 SplitMix64::new(seed).shuffle(&mut idx);
187 let chunk = if size == 0 { n } else { size.min(n) };
188 idx.chunks(chunk)
189 .map(|rows| self.select(rows).expect("indices in range"))
190 .collect()
191 }
192
193 pub fn from_csv<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
201 Self::from_csv_with_target(path, None)
202 }
203
204 pub fn from_csv_with_target<P: AsRef<std::path::Path>>(
213 path: P,
214 target: Option<usize>,
215 ) -> Result<Self> {
216 let (headers, rows) = parse_csv(path)?;
217 let n_cols = headers.len();
218 let target_col = target.unwrap_or(n_cols - 1);
219 if target_col >= n_cols {
220 return Err(PhopError::ShapeMismatch(format!(
221 "target column {target_col} out of range (CSV has {n_cols} columns)"
222 )));
223 }
224 let features: Vec<usize> = (0..n_cols).filter(|&j| j != target_col).collect();
225 Self::assemble(&headers, &rows, &features, target_col)
226 }
227
228 pub fn from_csv_columns<P: AsRef<std::path::Path>>(
235 path: P,
236 features: &[usize],
237 target: usize,
238 ) -> Result<Self> {
239 let (headers, rows) = parse_csv(path)?;
240 let n_cols = headers.len();
241 if target >= n_cols || features.iter().any(|&j| j >= n_cols) {
242 return Err(PhopError::ShapeMismatch(format!(
243 "column index out of range (CSV has {n_cols} columns)"
244 )));
245 }
246 if features.is_empty() {
247 return Err(PhopError::ShapeMismatch(
248 "at least one feature column is required".to_string(),
249 ));
250 }
251 if features.contains(&target) {
252 return Err(PhopError::ShapeMismatch(
253 "target column cannot also be a feature".to_string(),
254 ));
255 }
256 Self::assemble(&headers, &rows, features, target)
257 }
258
259 fn assemble(
261 headers: &[String],
262 rows: &[Vec<f64>],
263 features: &[usize],
264 target: usize,
265 ) -> Result<Self> {
266 let n_rows = rows.len();
267 let n_vars = features.len();
268 let mut x_flat = Vec::with_capacity(n_rows * n_vars);
269 let mut y_vec = Vec::with_capacity(n_rows);
270 for row in rows {
271 for &j in features {
272 x_flat.push(row[j]);
273 }
274 y_vec.push(row[target]);
275 }
276 let x = Array2::from_shape_vec((n_rows, n_vars), x_flat)
277 .map_err(|e| PhopError::ShapeMismatch(e.to_string()))?;
278 let feature_names = features.iter().map(|&j| headers[j].clone()).collect();
279 Ok(Self {
280 x,
281 y: Array1::from(y_vec),
282 feature_names,
283 target_name: headers[target].clone(),
284 })
285 }
286}
287
288fn parse_csv<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<String>, Vec<Vec<f64>>)> {
290 let mut rdr = csv::ReaderBuilder::new()
291 .has_headers(true)
292 .from_path(path)?;
293 let headers: Vec<String> = rdr.headers()?.iter().map(str::to_string).collect();
294 let n_cols = headers.len();
295 if n_cols < 2 {
296 return Err(PhopError::ShapeMismatch(
297 "CSV must have at least two columns (>=1 feature + target)".to_string(),
298 ));
299 }
300 let mut rows: Vec<Vec<f64>> = Vec::new();
301 for rec in rdr.records() {
302 let rec = rec?;
303 if rec.len() != n_cols {
304 return Err(PhopError::ShapeMismatch(format!(
305 "row {} has {} fields, expected {n_cols}",
306 rows.len(),
307 rec.len()
308 )));
309 }
310 let mut row = Vec::with_capacity(n_cols);
311 for field in rec.iter() {
312 row.push(
313 field
314 .trim()
315 .parse::<f64>()
316 .map_err(|_| PhopError::Parse(format!("cannot parse '{field}' as f64")))?,
317 );
318 }
319 rows.push(row);
320 }
321 Ok((headers, rows))
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Deletes the wrapped file on drop, so temp CSVs are removed even when an
    /// assertion panics.
    struct TempFile(std::path::PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            std::fs::remove_file(&self.0).ok();
        }
    }

    /// Writes `contents` to a temp file named after `tag` and the current process id.
    ///
    /// The process id keeps concurrently running test binaries from clobbering
    /// each other's fixture, which a single fixed filename would do.
    fn temp_csv(tag: &str, contents: &str) -> TempFile {
        let name = format!("phop_test_{tag}_{}.csv", std::process::id());
        let path = std::env::temp_dir().join(name);
        std::fs::write(&path, contents).unwrap();
        TempFile(path)
    }

    /// `n` samples of one feature where both `x` and `y` equal the row index.
    fn ramp(n: u32) -> DataSet {
        let vals: Vec<f64> = (0..n).map(f64::from).collect();
        let x = Array2::from_shape_vec((vals.len(), 1), vals.clone()).unwrap();
        DataSet::from_arrays(x, Array1::from(vals)).unwrap()
    }

    #[test]
    fn from_arrays_checks_shape() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from(vec![1.0, 2.0, 3.0]);
        let ds = DataSet::from_arrays(x, y).unwrap();
        assert_eq!(ds.n_vars(), 2);
        assert_eq!(ds.len(), 3);
        assert!(!ds.is_empty());
        assert_eq!(ds.feature_names, vec!["x0".to_string(), "x1".to_string()]);
        assert_eq!(ds.target_name, "y");
    }

    #[test]
    fn from_arrays_rejects_mismatch() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = Array1::from(vec![1.0]);
        assert!(DataSet::from_arrays(x, y).is_err());
    }

    #[test]
    fn standardize_centers_and_scales() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from(vec![10.0, 20.0, 30.0, 40.0]);
        let ds = DataSet::from_arrays(x, y).unwrap();
        let (z, std) = ds.standardized();

        let col = z.x.column(0);
        let col_mean = col.iter().sum::<f64>() / 4.0;
        let col_var = col.iter().map(|v| (v - col_mean) * (v - col_mean)).sum::<f64>() / 4.0;
        assert!(col_mean.abs() < 1e-9, "mean = {col_mean}");
        assert!((col_var - 1.0).abs() < 1e-9, "var = {col_var}");
        for (orig, zz) in ds.y.iter().zip(z.y.iter()) {
            assert!((std.inverse_target(*zz) - orig).abs() < 1e-9);
        }
    }

    #[test]
    fn minibatches_partition_all_rows() {
        let ds = ramp(10);
        let batches = ds.minibatches(3, 123);
        assert_eq!(batches.len(), 4);
        assert!(batches.iter().all(|b| b.len() <= 3));

        // Every row appears in exactly one batch.
        let mut seen: Vec<f64> = batches.iter().flat_map(|b| b.y.iter().copied()).collect();
        seen.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(seen, (0..10u32).map(f64::from).collect::<Vec<_>>());

        // The same seed reproduces the same split.
        let again = ds.minibatches(3, 123);
        assert_eq!(batches.len(), again.len());
        for (a, b) in batches.iter().zip(&again) {
            assert_eq!(a.y, b.y);
        }
    }

    #[test]
    fn minibatches_edge_cases() {
        let ds = ramp(5);

        let whole = ds.minibatches(0, 7);
        assert_eq!(whole.len(), 1);
        assert_eq!(whole[0].len(), 5);

        let oversized = ds.minibatches(100, 7);
        assert_eq!(oversized.len(), 1);
        assert_eq!(oversized[0].len(), 5);

        assert!(ramp(0).minibatches(3, 7).is_empty());
    }

    #[test]
    fn select_picks_rows_and_rejects_out_of_range() {
        let ds = ramp(3);
        let sub = ds.select(&[2, 0]).unwrap();
        assert_eq!(sub.y, Array1::from(vec![2.0, 0.0]));
        assert!((sub.x[[0, 0]] - 2.0).abs() < 1e-12);
        assert_eq!(sub.feature_names, ds.feature_names);
        assert!(ds.select(&[0, 3]).is_err());
    }

    #[test]
    fn from_csv_with_target_selects_column() {
        let file = temp_csv("target", "y,a,b\n10,1,2\n20,3,4\n");
        let path = &file.0;

        let ds = DataSet::from_csv_with_target(path, Some(0)).unwrap();
        assert_eq!(ds.target_name, "y");
        assert_eq!(ds.feature_names, vec!["a".to_string(), "b".to_string()]);
        assert_eq!(ds.n_vars(), 2);
        assert!((ds.y[0] - 10.0).abs() < 1e-12);
        assert!((ds.x[[1, 0]] - 3.0).abs() < 1e-12);
        assert!(DataSet::from_csv_with_target(path, Some(3)).is_err());

        // Default target is the last column.
        let last = DataSet::from_csv(path).unwrap();
        assert_eq!(last.target_name, "b");
        assert_eq!(last.feature_names, vec!["y".to_string(), "a".to_string()]);

        let ds2 = DataSet::from_csv_columns(path, &[2], 0).unwrap();
        assert_eq!(ds2.n_vars(), 1);
        assert_eq!(ds2.feature_names, vec!["b".to_string()]);
        assert!((ds2.x[[0, 0]] - 2.0).abs() < 1e-12);
        assert!(DataSet::from_csv_columns(path, &[0], 0).is_err());
        assert!(DataSet::from_csv_columns(path, &[], 0).is_err());
        assert!(DataSet::from_csv_columns(path, &[7], 0).is_err());
    }
}