#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
pub mod boston;
pub mod breast_cancer;
pub mod diabetes;
pub mod digits;
pub mod generator;
pub mod iris;
#[cfg(not(target_arch = "wasm32"))]
use crate::numbers::{basenum::Number, realnum::RealNumber};
#[cfg(not(target_arch = "wasm32"))]
use std::fs::File;
use std::io;
#[cfg(not(target_arch = "wasm32"))]
use std::io::prelude::*;
/// An in-memory dataset: a flat buffer of feature values plus targets and
/// descriptive metadata.
#[derive(Debug)]
pub struct Dataset<X, Y> {
    /// Feature values stored row-major in one flat buffer: the value of feature
    /// `c` for sample `r` is `data[r * num_features + c]` (see
    /// [`Dataset::as_matrix`]).
    pub data: Vec<X>,
    /// Target values. Presumably one per sample; this type does not enforce it.
    pub target: Vec<Y>,
    /// Number of samples (rows) stored in `data`.
    pub num_samples: usize,
    /// Number of features (columns) per sample in `data`.
    pub num_features: usize,
    /// Feature names. Presumably one per column; not enforced here.
    pub feature_names: Vec<String>,
    /// Names associated with the targets. The exact meaning depends on the
    /// dataset (e.g. class labels); TODO confirm per dataset.
    pub target_names: Vec<String>,
    /// Free-form text describing the dataset.
    pub description: String,
}
impl<X, Y> Dataset<X, Y> {
    /// Views the flat, row-major `data` buffer as a vector of rows of references.
    ///
    /// Exactly `num_samples` rows are produced. Row `r` borrows the elements
    /// `data[r * num_features .. (r + 1) * num_features]`. Elements past
    /// `num_samples * num_features` are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `data` holds fewer than `num_samples * num_features` elements.
    pub fn as_matrix(&self) -> Vec<Vec<&X>> {
        let width = self.num_features;
        (0..self.num_samples)
            .map(|row| {
                let start = row * width;
                self.data[start..start + width].iter().collect()
            })
            .collect()
    }
}
/// Writes `dataset` to `filename` in the little-endian binary layout that
/// `deserialize_data` parses:
///
/// 1. `num_features` as a `u64`,
/// 2. `num_samples` as a `u64`,
/// 3. every element of `data`, encoded via `to_f32_bits().to_le_bytes()`,
/// 4. every element of `target`, encoded the same way.
///
/// # Errors
///
/// * [`io::ErrorKind::InvalidInput`] if `data.len()` is not
///   `num_samples * num_features`, or `target.len()` is not `num_samples`.
///   Such a file could not be read back correctly, so nothing is written.
/// * Any I/O error from creating, writing, or flushing the file. A creation
///   failure keeps its original kind and its message names `filename`.
#[cfg(not(target_arch = "wasm32"))]
// NOTE(review): presumably only used when (re)generating bundled dataset files,
// hence `dead_code` is allowed; confirm.
#[allow(dead_code)]
pub(crate) fn serialize_data<X: Number + RealNumber, Y: RealNumber>(
    dataset: &Dataset<X, Y>,
    filename: &str,
) -> Result<(), io::Error> {
    // Validate before touching the filesystem so that an inconsistent dataset
    // never produces a (partially written) file.
    let expected_values = dataset.num_samples.checked_mul(dataset.num_features);
    if expected_values != Some(dataset.data.len()) {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "serialize_data: data has {} values, expected num_samples ({}) * num_features ({})",
                dataset.data.len(),
                dataset.num_samples,
                dataset.num_features
            ),
        ));
    }
    if dataset.target.len() != dataset.num_samples {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "serialize_data: target has {} values, expected num_samples ({})",
                dataset.target.len(),
                dataset.num_samples
            ),
        ));
    }

    let file = File::create(filename).map_err(|why| {
        io::Error::new(why.kind(), format!("couldn't create {filename}: {why}"))
    })?;
    // Many 4-byte writes follow; buffer them instead of issuing one syscall each.
    let mut out = io::BufWriter::new(file);

    // Header order (features, then samples) must match `deserialize_data`.
    out.write_all(&(dataset.num_features as u64).to_le_bytes())?;
    out.write_all(&(dataset.num_samples as u64).to_le_bytes())?;
    for value in dataset.data.iter().copied() {
        out.write_all(&value.to_f32_bits().to_le_bytes())?;
    }
    for value in dataset.target.iter().copied() {
        out.write_all(&value.to_f32_bits().to_le_bytes())?;
    }

    // Flush explicitly: dropping a `BufWriter` silently discards flush errors.
    out.flush()
}
/// Parses a buffer in the binary layout written by `serialize_data`.
///
/// Layout (all fields little-endian):
///
/// 1. bytes `0..8`: `num_features` as `u64`,
/// 2. bytes `8..16`: `num_samples` as `u64`,
/// 3. `num_samples * num_features` feature values as `f32` bits,
/// 4. `num_samples` target values as `f32` bits.
///
/// Any bytes after the target section are ignored.
///
/// Returns `(x, y, num_samples, num_features)`. `x` holds the feature values
/// in file order and `y` holds the target values.
///
/// # Errors
///
/// Returns an error of kind [`io::ErrorKind::InvalidData`] if:
/// - the buffer is shorter than the 16-byte header or than the sizes the
///   header declares;
/// - a header count does not fit in `usize` on this platform (possible on
///   32-bit targets such as `wasm32`);
/// - computing the expected byte length overflows `usize`;
/// - any feature or target value is NaN or infinite.
pub(crate) fn deserialize_data(
    bytes: &[u8],
) -> Result<(Vec<f32>, Vec<f32>, usize, usize), io::Error> {
    const FIELD_SIZE: usize = std::mem::size_of::<u64>();
    const HEADER_LEN: usize = 2 * FIELD_SIZE;
    const VALUE_SIZE: usize = std::mem::size_of::<f32>();

    // Builds an `InvalidData` error carrying `msg`.
    fn invalid_data(msg: impl Into<String>) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidData, msg.into())
    }

    // Converts a header count to `usize`. On targets where `usize` is narrower
    // than 64 bits, a plain `as` cast would silently truncate, so fail instead.
    fn header_count(value: u64, field: &str) -> Result<usize, io::Error> {
        <usize as std::convert::TryFrom<u64>>::try_from(value).map_err(|_| {
            invalid_data(format!(
                "deserialize_data: {field} ({value}) does not fit in usize on this platform"
            ))
        })
    }

    // Decodes consecutive little-endian `f32` words from `raw`, rejecting NaN
    // and infinities. `section` names the region in error messages.
    fn decode_finite_f32s(raw: &[u8], section: &str) -> Result<Vec<f32>, io::Error> {
        let mut values = Vec::with_capacity(raw.len() / VALUE_SIZE);
        for word in raw.chunks_exact(VALUE_SIZE) {
            let bits = u32::from_le_bytes([word[0], word[1], word[2], word[3]]);
            let value = f32::from_bits(bits);
            if !value.is_finite() {
                return Err(invalid_data(format!(
                    "deserialize_data: non-finite value in {section} data (bits: {bits:#010x})"
                )));
            }
            values.push(value);
        }
        Ok(values)
    }

    if bytes.len() < HEADER_LEN {
        return Err(invalid_data(format!(
            "deserialize_data: buffer too small for header (need {HEADER_LEN} bytes, got {})",
            bytes.len()
        )));
    }

    let read_u64 = |offset: usize| {
        let mut field = [0u8; FIELD_SIZE];
        field.copy_from_slice(&bytes[offset..offset + FIELD_SIZE]);
        u64::from_le_bytes(field)
    };
    // `serialize_data` writes `num_features` first, then `num_samples`.
    let num_features = header_count(read_u64(0), "num_features")?;
    let num_samples = header_count(read_u64(FIELD_SIZE), "num_samples")?;

    // All size arithmetic is checked: the header comes from untrusted bytes.
    let num_x_values = num_samples.checked_mul(num_features).ok_or_else(|| {
        invalid_data("deserialize_data: num_samples * num_features overflows usize")
    })?;
    let x_bytes = num_x_values
        .checked_mul(VALUE_SIZE)
        .ok_or_else(|| invalid_data("deserialize_data: x byte range overflows usize"))?;
    let y_bytes = num_samples
        .checked_mul(VALUE_SIZE)
        .ok_or_else(|| invalid_data("deserialize_data: y byte range overflows usize"))?;
    let x_end = HEADER_LEN.checked_add(x_bytes).ok_or_else(|| {
        invalid_data("deserialize_data: total expected length overflows usize")
    })?;
    let expected_len = x_end.checked_add(y_bytes).ok_or_else(|| {
        invalid_data("deserialize_data: total expected length overflows usize")
    })?;

    if bytes.len() < expected_len {
        return Err(invalid_data(format!(
            "deserialize_data: buffer too short (expected {expected_len} bytes, got {})",
            bytes.len()
        )));
    }

    // The length check above bounds both slices and the capacities reserved in
    // `decode_finite_f32s`, so a bogus header cannot trigger a huge allocation.
    let x = decode_finite_f32s(&bytes[HEADER_LEN..x_end], "feature")?;
    let y = decode_finite_f32s(&bytes[x_end..expected_len], "target")?;
    Ok((x, y, num_samples, num_features))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a buffer in the layout `deserialize_data` parses: `num_features`
    /// then `num_samples` as little-endian `u64`, followed by the little-endian
    /// bits of each entry of `values`.
    fn encode(num_features: u64, num_samples: u64, values: &[f32]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(16 + 4 * values.len());
        buf.extend_from_slice(&num_features.to_le_bytes());
        buf.extend_from_slice(&num_samples.to_le_bytes());
        for value in values {
            buf.extend_from_slice(&value.to_bits().to_le_bytes());
        }
        buf
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn as_matrix() {
        let dataset = Dataset {
            data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            target: vec![1, 2],
            num_samples: 2,
            num_features: 5,
            feature_names: vec![],
            target_names: vec![],
            description: String::new(),
        };
        let m = dataset.as_matrix();
        assert_eq!(m.len(), 2);
        assert_eq!(m[0].len(), 5);
        assert_eq!(*m[1][3], 9);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn deserialize_data_too_short() {
        let result = deserialize_data(&[0u8; 4]);
        assert!(result.is_err());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn deserialize_data_truncated_body() {
        // The header promises one feature value and one target value, but no body follows.
        let result = deserialize_data(&encode(1, 1, &[]));
        assert!(result.is_err());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn deserialize_data_nan_rejected() {
        let result = deserialize_data(&encode(1, 1, &[f32::NAN, 1.0]));
        assert!(result.is_err());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn deserialize_data_roundtrip_1x1() {
        let x_val = 2.5f32;
        let y_val = 1.0f32;
        let (x, y, ns, nf) =
            deserialize_data(&encode(1, 1, &[x_val, y_val])).expect("roundtrip must succeed");
        assert_eq!(ns, 1);
        assert_eq!(nf, 1);
        assert_eq!(x.len(), 1);
        assert_eq!(y.len(), 1);
        // The encoding stores raw bits, so the roundtrip must be bit-exact.
        assert_eq!(x[0].to_bits(), x_val.to_bits());
        assert_eq!(y[0].to_bits(), y_val.to_bits());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn deserialize_data_header_order() {
        // 3 features x 2 samples: the dims differ, so a swapped header is caught.
        let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let (x, y, ns, nf) =
            deserialize_data(&encode(3, 2, &values)).expect("valid buffer must parse");
        assert_eq!(ns, 2);
        assert_eq!(nf, 3);
        assert_eq!(x, values[..6].to_vec());
        assert_eq!(y, values[6..].to_vec());
    }
}