/// Configuration for [`make_sparse_classification`].
#[derive(Debug, Clone)]
pub struct SparseClassConfig {
    /// Number of samples (rows) to generate.
    pub n_samples: usize,
    /// Total number of feature columns.
    pub n_features: usize,
    /// Number of columns that carry class signal (capped at `n_features`).
    pub n_informative: usize,
    /// Number of distinct class labels.
    pub n_classes: usize,
    /// Scale factor for the distance between class centroids.
    pub class_sep: f64,
    /// Seed for the deterministic internal PRNG.
    pub seed: u64,
}

impl Default for SparseClassConfig {
    /// Defaults: 1000 samples, 10 000 features of which 20 are informative,
    /// binary labels, unit class separation, seed 42.
    fn default() -> Self {
        SparseClassConfig {
            seed: 42,
            class_sep: 1.0,
            n_classes: 2,
            n_informative: 20,
            n_features: 10_000,
            n_samples: 1000,
        }
    }
}
/// Output of [`make_sparse_classification`].
#[derive(Debug, Clone)]
pub struct SparseClassDataset {
    // Feature matrix, row-major: x[i][j] is feature j of sample i.
    pub x: Vec<Vec<f64>>,
    // Zero-based class label for each row of `x`.
    pub y: Vec<usize>,
    // Sorted indices of the columns that carry class signal.
    pub informative_features: Vec<usize>,
    // One weight per feature column (zero for non-informative columns).
    pub feature_weights: Vec<f64>,
}
/// Minimal linear congruential generator (Knuth's MMIX constants), used so
/// dataset generation is deterministic without external dependencies.
struct Lcg {
    state: u64,
}

impl Lcg {
    /// Creates a generator whose stream is fully determined by `seed`.
    fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advances the state and returns the next raw 64-bit value.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Uniform `f64` in `[0, 1)`, built from the top 53 bits of the raw
    /// output (the high bits of an LCG are the statistically strongest).
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Standard normal variate via the Box-Muller transform.
    /// `u1` is clamped away from 0 so `ln(u1)` stays finite.
    fn next_normal(&mut self) -> f64 {
        let u1 = self.next_f64().max(1e-10);
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }

    /// Uniform integer in `[0, n)` via a 128-bit widening multiply
    /// (Lemire's method) rather than `next_u64() % n`: the low-order bits
    /// of an LCG are weak (bit k has period 2^(k+1)), so reducing modulo a
    /// small `n` would sample almost entirely from those bits and also add
    /// modulo bias. The multiply keys off the strong high bits instead.
    ///
    /// # Panics
    /// Panics with a clear message if `n == 0` (instead of the raw
    /// divide-by-zero panic the modulo form produced).
    fn next_usize_below(&mut self, n: usize) -> usize {
        assert!(n > 0, "next_usize_below: n must be positive");
        ((self.next_u64() as u128 * n as u128) >> 64) as usize
    }
}
/// Generates a synthetic sparse classification dataset.
///
/// Only `n_informative` randomly chosen columns carry signal: each class
/// gets a Gaussian centroid over those columns (scaled by `class_sep`) and
/// samples are drawn around it with a fixed noise std of 0.5. Every other
/// column is exactly zero, so the matrix is highly sparse.
///
/// Samples are split as evenly as possible across classes — the last class
/// absorbs the division remainder — and then shuffled so class order is not
/// apparent in the row order. Fully deterministic for a given `config.seed`.
///
/// Returns the feature matrix `x` (row-major), labels `y`, the sorted
/// indices of the informative columns, and per-feature weights (mean of the
/// class centroids for informative columns, 0.0 elsewhere).
///
/// # Panics
/// Panics if `config.n_classes == 0`.
pub fn make_sparse_classification(config: &SparseClassConfig) -> SparseClassDataset {
    // Fail fast with a clear message; otherwise n_classes == 0 produced NaN
    // feature weights and then a raw integer divide-by-zero panic below.
    assert!(config.n_classes > 0, "n_classes must be at least 1");
    let mut rng = Lcg::new(config.seed);
    let n_inf = config.n_informative.min(config.n_features);
    // Partial Fisher-Yates: after n_inf swaps the first n_inf slots hold a
    // uniform random subset of feature indices, without replacement.
    let mut informative_features: Vec<usize> = {
        let mut indices: Vec<usize> = (0..config.n_features).collect();
        for i in 0..n_inf {
            let j = i + rng.next_usize_below(config.n_features - i);
            indices.swap(i, j);
        }
        indices[..n_inf].to_vec()
    };
    informative_features.sort_unstable();
    // One centroid per class over the informative subspace; class_sep scales
    // how far apart the class means sit.
    let centroids: Vec<Vec<f64>> = (0..config.n_classes)
        .map(|_| {
            (0..n_inf)
                .map(|_| rng.next_normal() * config.class_sep)
                .collect()
        })
        .collect();
    // Feature weights: mean centroid value per informative column, 0 elsewhere.
    let mut feature_weights = vec![0.0f64; config.n_features];
    for (idx, &fi) in informative_features.iter().enumerate() {
        let mean_val: f64 =
            centroids.iter().map(|c| c[idx]).sum::<f64>() / config.n_classes as f64;
        feature_weights[fi] = mean_val;
    }
    let n_per_class = config.n_samples / config.n_classes;
    let mut x: Vec<Vec<f64>> = Vec::with_capacity(config.n_samples);
    let mut y: Vec<usize> = Vec::with_capacity(config.n_samples);
    for (class_idx, centroid) in centroids.iter().enumerate() {
        // The last class picks up the remainder so counts total n_samples.
        let count = if class_idx == config.n_classes - 1 {
            config.n_samples - n_per_class * (config.n_classes - 1)
        } else {
            n_per_class
        };
        for _ in 0..count {
            let mut sample = vec![0.0f64; config.n_features];
            for (inf_idx, &fi) in informative_features.iter().enumerate() {
                // Centroid plus Gaussian noise (std 0.5) on signal columns only.
                sample[fi] = centroid[inf_idx] + rng.next_normal() * 0.5;
            }
            x.push(sample);
            y.push(class_idx);
        }
    }
    // Joint Fisher-Yates shuffle of rows and labels, keeping them aligned.
    let n = x.len();
    for i in (1..n).rev() {
        let j = rng.next_usize_below(i + 1);
        x.swap(i, j);
        y.swap(i, j);
    }
    SparseClassDataset {
        x,
        y,
        informative_features,
        feature_weights,
    }
}
/// Fraction of entries in `x` that are exactly `0.0`.
///
/// Counts against the true total number of entries (sum of row lengths),
/// so ragged inputs are handled correctly — the previous rectangular
/// assumption (`x.len() * x[0].len()`) miscounted ragged data and returned
/// 1.0 whenever only the *first* row happened to be empty. For rectangular
/// input the result is unchanged.
///
/// Returns `1.0` (fully sparse by convention) when `x` contains no entries.
pub fn sparsity_ratio(x: &[Vec<f64>]) -> f64 {
    let total: usize = x.iter().map(|row| row.len()).sum();
    if total == 0 {
        return 1.0;
    }
    let zeros = x
        .iter()
        .flat_map(|row| row.iter())
        .filter(|&&v| v == 0.0)
        .count();
    zeros as f64 / total as f64
}
#[cfg(test)]
mod tests {
    use super::*;
    // With 10 informative columns out of 1000, at least 99% of entries are
    // zero by construction, comfortably above the 0.98 threshold.
    #[test]
    fn test_sparsity_high() {
        let config = SparseClassConfig {
            n_samples: 200,
            n_features: 1000,
            n_informative: 10,
            n_classes: 2,
            class_sep: 1.0,
            seed: 42,
        };
        let ds = make_sparse_classification(&config);
        let ratio = sparsity_ratio(&ds.x);
        assert!(ratio > 0.98, "Sparsity ratio should be > 0.98, got {ratio}");
    }
    // 100 samples over 2 classes split 50/50 by construction; the 40..=60
    // range is deliberately loose.
    #[test]
    fn test_label_balance() {
        let config = SparseClassConfig {
            n_samples: 100,
            n_features: 500,
            n_informative: 5,
            n_classes: 2,
            class_sep: 1.0,
            seed: 7,
        };
        let ds = make_sparse_classification(&config);
        assert_eq!(ds.y.len(), 100);
        let class0 = ds.y.iter().filter(|&&c| c == 0).count();
        let class1 = ds.y.iter().filter(|&&c| c == 1).count();
        assert!((40..=60).contains(&class0), "Class 0 count: {class0}");
        assert!((40..=60).contains(&class1), "Class 1 count: {class1}");
    }
    // Exactly n_informative distinct indices are reported, all in range.
    #[test]
    fn test_informative_feature_count() {
        let config = SparseClassConfig {
            n_samples: 50,
            n_features: 200,
            n_informative: 8,
            n_classes: 3,
            class_sep: 1.5,
            seed: 99,
        };
        let ds = make_sparse_classification(&config);
        assert_eq!(ds.informative_features.len(), 8);
        for &fi in &ds.informative_features {
            assert!(fi < 200, "Informative feature index out of range: {fi}");
        }
    }
    // Columns not listed as informative must be exactly 0.0 in every row.
    #[test]
    fn test_non_informative_are_zero() {
        let config = SparseClassConfig {
            n_samples: 20,
            n_features: 100,
            n_informative: 5,
            n_classes: 2,
            class_sep: 1.0,
            seed: 13,
        };
        let ds = make_sparse_classification(&config);
        let inf_set: std::collections::HashSet<usize> =
            ds.informative_features.iter().copied().collect();
        for row in &ds.x {
            for (j, &val) in row.iter().enumerate() {
                if !inf_set.contains(&j) {
                    assert_eq!(val, 0.0, "Non-informative feature {j} should be zero");
                }
            }
        }
    }
    // Struct-update syntax with Default fills the unspecified fields; the
    // output shape must follow the overridden sample/feature counts.
    #[test]
    fn test_default_config_shape() {
        let config = SparseClassConfig {
            n_samples: 50,
            n_features: 200,
            n_informative: 10,
            ..Default::default()
        };
        let ds = make_sparse_classification(&config);
        assert_eq!(ds.x.len(), 50);
        assert_eq!(ds.x[0].len(), 200);
    }
}