use crate::primitives::{Matrix, Vector};
/// Small, dependency-free SplitMix64 pseudo-random generator.
///
/// The synthetic dataset generators in this module draw all of their
/// randomness from it, so their output is reproducible from one `u64` seed.
struct SplitMix64 {
    state: u64,
}
impl SplitMix64 {
    /// Increment added to the state before every draw.
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    /// First multiplier of the output mixing function.
    const MIX1: u64 = 0xBF58_476D_1CE4_E5B9;
    /// Second multiplier of the output mixing function.
    const MIX2: u64 = 0x94D0_49BB_1331_11EB;

    /// Creates a generator whose whole output sequence is fixed by `seed`.
    fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Advances the state and returns the next 64 pseudo-random bits.
    fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::GAMMA);
        let mut bits = self.state;
        bits ^= bits >> 30;
        bits = bits.wrapping_mul(Self::MIX1);
        bits ^= bits >> 27;
        bits = bits.wrapping_mul(Self::MIX2);
        bits ^ (bits >> 31)
    }

    /// Returns a float in `[0, 1)` built from the top 24 bits of one draw.
    ///
    /// 24 bits fit exactly in an `f32` significand, so the conversion is lossless.
    fn next_f32(&mut self) -> f32 {
        let top24 = self.next_u64() >> 40;
        let scale = (1u64 << 24) as f32;
        top24 as f32 / scale
    }

    /// Draws one standard-normal sample using the cosine branch of Box–Muller.
    ///
    /// Only one of the two Box–Muller outputs is used. The sine partner is
    /// discarded.
    fn next_gaussian(&mut self) -> f32 {
        // Keep `u1` strictly positive: `ln(0)` would be -inf.
        let u1 = self.next_f32().max(1e-9);
        let u2 = self.next_f32();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = core::f32::consts::TAU * u2;
        radius * angle.cos()
    }
}
/// Generates isotropic Gaussian blobs around the given `centers`.
///
/// Samples are assigned to centers round-robin: sample `i` belongs to center
/// `i % centers.len()`. Each coordinate is `center[f] + cluster_std * N(0, 1)`.
/// All randomness comes from `seed`, so identical arguments give identical output.
///
/// Returns an `n_samples × n_features` matrix (one sample per row, where
/// `n_features` is the length of each center) and each sample's center index
/// as its label.
///
/// # Panics
/// Panics if `centers` is empty or if the centers do not all have the same length.
#[must_use]
pub fn make_blobs(
    n_samples: usize,
    centers: &[Vec<f32>],
    cluster_std: f32,
    seed: u64,
) -> (Matrix<f32>, Vec<usize>) {
    assert!(!centers.is_empty(), "make_blobs: needs >= 1 center");
    let dim = centers[0].len();
    assert!(
        centers.iter().all(|c| c.len() == dim),
        "make_blobs: all centers must have the same length"
    );
    let mut rng = SplitMix64::new(seed);
    let labels: Vec<usize> = (0..n_samples).map(|i| i % centers.len()).collect();
    let mut data = Vec::with_capacity(n_samples * dim);
    // Row by row, feature by feature: the same RNG draw order as a nested loop.
    for &label in &labels {
        data.extend(
            centers[label]
                .iter()
                .map(|&mu| mu + cluster_std * rng.next_gaussian()),
        );
    }
    let x = Matrix::from_vec(n_samples, dim, data).expect("make_blobs: valid dims");
    (x, labels)
}
/// Generates a random linear-regression problem.
///
/// First draws a hidden weight vector `w` with one standard-normal entry per
/// feature. Then each sample row `x` is drawn from a standard normal, with target
/// `y = w · x + noise * N(0, 1)`. All randomness comes from `seed`, so identical
/// arguments give identical output.
///
/// Returns the `n_samples × n_features` design matrix and the target vector.
#[must_use]
pub fn make_regression(
    n_samples: usize,
    n_features: usize,
    noise: f32,
    seed: u64,
) -> (Matrix<f32>, Vector<f32>) {
    let mut rng = SplitMix64::new(seed);
    let weights: Vec<f32> = (0..n_features).map(|_| rng.next_gaussian()).collect();
    let mut data = Vec::with_capacity(n_samples * n_features);
    let mut targets = Vec::with_capacity(n_samples);
    for _ in 0..n_samples {
        let row_start = data.len();
        data.extend((0..n_features).map(|_| rng.next_gaussian()));
        // An explicit left fold starting at +0.0 matches a running `y += w * x`
        // exactly: same summation order and same sign of zero.
        let signal = weights
            .iter()
            .zip(&data[row_start..])
            .fold(0.0f32, |acc, (&w, &x)| acc + w * x);
        targets.push(signal + noise * rng.next_gaussian());
    }
    let x = Matrix::from_vec(n_samples, n_features, data).expect("make_regression: valid dims");
    (x, Vector::from_vec(targets))
}
/// Generates a balanced multi-class classification problem.
///
/// Each class gets a random center in the first `n_informative` dimensions. The
/// center is a standard normal scaled by a fixed class separation of `2.0`.
/// Sample `i` belongs to class `i % n_classes`. Its informative features are the
/// class center plus unit Gaussian noise. The remaining
/// `n_features - n_informative` features are pure standard-normal noise with no
/// class signal. All randomness comes from `seed`.
///
/// Returns the `n_samples × n_features` feature matrix and the class labels.
///
/// # Panics
/// Panics if `n_informative > n_features` or if `n_classes == 0`.
#[must_use]
pub fn make_classification(
    n_samples: usize,
    n_features: usize,
    n_informative: usize,
    n_classes: usize,
    seed: u64,
) -> (Matrix<f32>, Vec<usize>) {
    assert!(
        n_informative <= n_features,
        "make_classification: n_informative > n_features"
    );
    assert!(n_classes > 0, "make_classification: n_classes must be > 0");
    const CLASS_SEP: f32 = 2.0;
    let mut rng = SplitMix64::new(seed);
    let mut centers: Vec<Vec<f32>> = Vec::with_capacity(n_classes);
    for _ in 0..n_classes {
        let center: Vec<f32> = (0..n_informative)
            .map(|_| CLASS_SEP * rng.next_gaussian())
            .collect();
        centers.push(center);
    }
    // Cannot underflow: guarded by the first assertion above.
    let n_noise = n_features - n_informative;
    let mut data = Vec::with_capacity(n_samples * n_features);
    let mut labels = Vec::with_capacity(n_samples);
    for sample in 0..n_samples {
        let class = sample % n_classes;
        data.extend(centers[class].iter().map(|&mu| mu + rng.next_gaussian()));
        data.extend((0..n_noise).map(|_| rng.next_gaussian()));
        labels.push(class);
    }
    let x = Matrix::from_vec(n_samples, n_features, data).expect("make_classification: valid dims");
    (x, labels)
}
/// Raw Iris data embedded at compile time.
///
/// Each non-blank line is parsed as four numeric features followed by an
/// integer class label.
const IRIS_CSV: &str = include_str!("iris.csv");
/// Loads the Iris dataset bundled with the crate.
///
/// Returns the feature matrix (four features per sample) and the class labels.
/// This module's tests expect 150 samples, 50 per class, with labels in `{0, 1, 2}`.
///
/// # Panics
/// Panics only if the embedded CSV is malformed.
#[must_use]
pub fn load_iris() -> (Matrix<f32>, Vec<usize>) {
    const IRIS_FEATURES: usize = 4;
    parse_csv_dataset(IRIS_CSV, IRIS_FEATURES)
}
/// Parses a CSV of `n_features` numeric columns followed by one integer label column.
///
/// Each line is split on `,` and every field is trimmed before parsing.
/// Whitespace-only lines are skipped. Any fields after the label are ignored.
/// There is no header handling: a header row would fail to parse as `f32`.
///
/// Returns the `rows × n_features` feature matrix and the labels, in file order.
///
/// # Panics
/// Panics if a row is missing a feature or label field, or if a field fails to
/// parse. The message includes the 1-based file line number, the offending text,
/// and the parse error, so a corrupt dataset file can be located quickly.
fn parse_csv_dataset(csv: &str, n_features: usize) -> (Matrix<f32>, Vec<usize>) {
    let mut data = Vec::new();
    let mut labels = Vec::new();
    // Enumerate before filtering so reported line numbers match the file,
    // blank lines included.
    let rows = csv
        .lines()
        .enumerate()
        .filter(|(_, line)| !line.trim().is_empty());
    for (idx, line) in rows {
        let line_no = idx + 1;
        let mut fields = line.split(',');
        for col in 0..n_features {
            let raw = fields.next().unwrap_or_else(|| {
                panic!("dataset row: missing feature field (line {line_no}, col {col})")
            });
            let v = raw.trim().parse::<f32>().unwrap_or_else(|e| {
                panic!("dataset row: feature not f32 (line {line_no}, col {col}, {raw:?}): {e}")
            });
            data.push(v);
        }
        let raw = fields
            .next()
            .unwrap_or_else(|| panic!("dataset row: missing label field (line {line_no})"));
        let label = raw.trim().parse::<usize>().unwrap_or_else(|e| {
            panic!("dataset row: label not usize (line {line_no}, {raw:?}): {e}")
        });
        labels.push(label);
    }
    // Exactly one label is pushed per accepted row.
    let n_samples = labels.len();
    let x = Matrix::from_vec(n_samples, n_features, data).expect("dataset: valid dims");
    (x, labels)
}
#[cfg(test)]
mod tests {
    // Unit tests for the bundled Iris loader and the seeded synthetic generators.
    // They cover shapes, label balance, determinism under a fixed seed, and basic
    // separability/learnability of the generated data.
    use super::*;
    #[test]
    fn load_iris_is_canonical() {
        let (x, y) = load_iris();
        assert_eq!(x.n_rows(), 150);
        assert_eq!(x.n_cols(), 4);
        assert_eq!(y.len(), 150);
        for c in 0..3 {
            assert_eq!(
                y.iter().filter(|&&v| v == c).count(),
                50,
                "Iris class {c} must have 50 samples"
            );
        }
        assert!(y.iter().all(|&c| c < 3), "Iris labels in {{0,1,2}}");
        assert!((x.get(0, 0) - 5.1).abs() < 1e-4);
        assert!((x.get(0, 3) - 0.2).abs() < 1e-4);
        assert_eq!(y[0], 0);
    }
    #[test]
    fn make_blobs_is_deterministic() {
        let centers = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
        let (x1, y1) = make_blobs(20, &centers, 0.5, 42);
        let (x2, y2) = make_blobs(20, &centers, 0.5, 42);
        assert_eq!(
            x1.as_slice(),
            x2.as_slice(),
            "same seed must give identical X"
        );
        assert_eq!(y1, y2, "same seed must give identical labels");
        let (x3, _) = make_blobs(20, &centers, 0.5, 43);
        assert_ne!(x1.as_slice(), x3.as_slice(), "different seed must differ");
    }
    #[test]
    fn make_blobs_shapes_and_labels() {
        let centers = vec![vec![0.0, 0.0, 0.0], vec![5.0, 5.0, 5.0]];
        let (x, y) = make_blobs(10, &centers, 0.3, 7);
        assert_eq!(x.n_rows(), 10);
        assert_eq!(x.n_cols(), 3);
        assert_eq!(y.len(), 10);
        assert!(y.iter().all(|&c| c < 2));
        assert_eq!(y.iter().filter(|&&c| c == 0).count(), 5);
    }
    #[test]
    fn make_blobs_clusters_are_separable() {
        let centers = vec![vec![0.0, 0.0], vec![20.0, 20.0]];
        let (x, y) = make_blobs(40, &centers, 0.5, 99);
        for i in 0..x.n_rows() {
            let p0 = (x.get(i, 0).powi(2) + x.get(i, 1).powi(2)).sqrt();
            let p1 = ((x.get(i, 0) - 20.0).powi(2) + (x.get(i, 1) - 20.0).powi(2)).sqrt();
            let nearest = usize::from(p1 < p0);
            assert_eq!(nearest, y[i], "sample {i} must be nearest its own center");
        }
    }
    #[test]
    fn make_regression_shapes_and_signal() {
        let (x, y) = make_regression(100, 4, 0.1, 5);
        assert_eq!(x.n_rows(), 100);
        assert_eq!(x.n_cols(), 4);
        assert_eq!(y.len(), 100);
        let (x2, y2) = make_regression(100, 4, 0.1, 5);
        assert_eq!(x.as_slice(), x2.as_slice());
        assert_eq!(y.as_slice(), y2.as_slice());
        let mean = y.as_slice().iter().sum::<f32>() / y.len() as f32;
        let var = y.as_slice().iter().map(|v| (v - mean).powi(2)).sum::<f32>() / y.len() as f32;
        assert!(
            var > 0.1,
            "regression target must carry signal, var = {var}"
        );
    }
    #[test]
    fn make_classification_shape_balance_and_learnable() {
        let (x, y) = make_classification(120, 10, 4, 3, 42);
        assert_eq!(x.n_rows(), 120);
        assert_eq!(x.n_cols(), 10);
        assert_eq!(y.len(), 120);
        assert!(y.iter().all(|&c| c < 3));
        for c in 0..3 {
            assert_eq!(
                y.iter().filter(|&&v| v == c).count(),
                40,
                "class {c} balance"
            );
        }
        let (x2, _) = make_classification(120, 10, 4, 3, 42);
        assert_eq!(x.as_slice(), x2.as_slice());
        // Nearest-centroid classifier on the informative features only.
        let n_inf = 4;
        let mut means = vec![[0.0f32; 4]; 3];
        let mut counts = [0usize; 3];
        for i in 0..x.n_rows() {
            for f in 0..n_inf {
                means[y[i]][f] += x.get(i, f);
            }
            counts[y[i]] += 1;
        }
        for c in 0..3 {
            for f in 0..n_inf {
                means[c][f] /= counts[c] as f32;
            }
        }
        let mut correct = 0;
        for i in 0..x.n_rows() {
            let nearest = (0..3)
                .min_by(|&a, &b| {
                    let da: f32 = (0..n_inf)
                        .map(|f| (x.get(i, f) - means[a][f]).powi(2))
                        .sum();
                    let db: f32 = (0..n_inf)
                        .map(|f| (x.get(i, f) - means[b][f]).powi(2))
                        .sum();
                    da.partial_cmp(&db).unwrap()
                })
                .unwrap();
            if nearest == y[i] {
                correct += 1;
            }
        }
        let acc = correct as f32 / x.n_rows() as f32;
        assert!(
            acc > 0.85,
            "make_classification not learnable: nearest-centroid acc {acc} <= 0.85"
        );
    }
}