use crate::error::{DatasetsError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use crate::generators::low_rank::{make_low_rank as low_rank_impl, LowRankConfig};
use crate::generators::sparse_classification::{
make_sparse_classification as sparse_class_impl, SparseClassConfig,
};
/// Generates a low-rank matrix completion dataset.
///
/// Returns `(X_full, X_obs)`: the dense low-rank matrix plus noise, and the
/// same matrix with unobserved entries replaced by `f64::NAN` (the underlying
/// generator is configured to mark ~50% of entries as observed).
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when `n_samples` or `n_features`
/// is zero, or when the generator output cannot be reshaped.
pub fn make_low_rank(
    n_samples: usize,
    n_features: usize,
    rank: usize,
    noise: f64,
    seed: u64,
) -> Result<(Array2<f64>, Array2<f64>)> {
    // Guard clauses: reject degenerate dimensions up front.
    if n_samples == 0 {
        return Err(DatasetsError::InvalidFormat("n_samples must be > 0".into()));
    }
    if n_features == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_features must be > 0".into(),
        ));
    }
    // Rank is forced into [1, min(n_samples, n_features)]; the upper bound is
    // >= 1 here because both dimensions were validated above.
    let effective_rank = rank.clamp(1, n_samples.min(n_features));
    let config = LowRankConfig {
        n_rows: n_samples,
        n_cols: n_features,
        rank: effective_rank,
        noise_std: noise,
        observation_fraction: 0.5, // half of the entries are flagged observed
        seed,
    };
    let ds = low_rank_impl(&config);
    let n_rows = ds.matrix.len();
    let n_cols = ds.matrix.first().map_or(0, |row| row.len());
    // Dense matrix, flattened row-major.
    let flat_full: Vec<f64> = ds.matrix.iter().flatten().copied().collect();
    // Observed view: entries whose mask bit is false become NaN.
    let flat_obs: Vec<f64> = ds
        .matrix
        .iter()
        .zip(&ds.observed_mask)
        .flat_map(|(vals, mask)| {
            vals.iter()
                .zip(mask)
                .map(|(&v, &seen)| if seen { v } else { f64::NAN })
        })
        .collect();
    let x_full = Array2::from_shape_vec((n_rows, n_cols), flat_full)
        .map_err(|e| DatasetsError::InvalidFormat(format!("Array2 shape error: {e}")))?;
    let x_obs = Array2::from_shape_vec((n_rows, n_cols), flat_obs)
        .map_err(|e| DatasetsError::InvalidFormat(format!("Array2 shape error: {e}")))?;
    Ok((x_full, x_obs))
}
/// Generates a sparse classification dataset as ndarray `(X, y)`.
///
/// `_density` is accepted for API compatibility but is not forwarded:
/// `SparseClassConfig` exposes no density knob here — NOTE(review): confirm
/// whether the underlying generator controls density internally.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when any of `n_samples`,
/// `n_features`, or `n_classes` is zero, or on a reshape failure.
pub fn make_sparse_classification(
    n_samples: usize,
    n_features: usize,
    n_informative: usize,
    _density: f64,
    n_classes: usize,
    seed: u64,
) -> Result<(Array2<f64>, Array1<usize>)> {
    // Validate all required-positive counts with identical error wording.
    for (value, name) in [
        (n_samples, "n_samples"),
        (n_features, "n_features"),
        (n_classes, "n_classes"),
    ] {
        if value == 0 {
            return Err(DatasetsError::InvalidFormat(format!("{name} must be > 0")));
        }
    }
    let config = SparseClassConfig {
        n_samples,
        n_features,
        // Informative features cannot outnumber total features.
        n_informative: n_informative.min(n_features),
        n_classes,
        class_sep: 1.0, // fixed default class separation
        seed,
    };
    let ds = sparse_class_impl(&config);
    let n_rows = ds.x.len();
    let n_cols = ds.x.first().map_or(0, |row| row.len());
    // Flatten the row-of-rows representation row-major for Array2.
    let flat: Vec<f64> = ds.x.iter().flatten().copied().collect();
    let x = Array2::from_shape_vec((n_rows, n_cols), flat)
        .map_err(|e| DatasetsError::InvalidFormat(format!("Array2 shape error: {e}")))?;
    Ok((x, Array1::from_vec(ds.y)))
}
/// Generates a multilabel classification dataset as ndarray `(X, Y)`,
/// where `Y` is a binary indicator matrix (`u8` in {0, 1}).
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when `n_samples`, `n_features`,
/// or `n_classes` is zero, or when reshaping the generator output fails.
pub fn make_multilabel_classification_nd(
    n_samples: usize,
    n_features: usize,
    n_classes: usize,
    n_labels: usize,
    seed: u64,
) -> Result<(Array2<f64>, Array2<u8>)> {
    use crate::generators::classification::{make_multilabel_classification, MultilabelConfig};
    if n_samples == 0 {
        return Err(DatasetsError::InvalidFormat("n_samples must be > 0".into()));
    }
    if n_features == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_features must be > 0".into(),
        ));
    }
    if n_classes == 0 {
        return Err(DatasetsError::InvalidFormat("n_classes must be > 0".into()));
    }
    // Labels-per-sample forced into [1, n_classes]; n_classes >= 1 is
    // guaranteed by the guard above, so clamp cannot panic.
    let effective_labels = n_labels.clamp(1, n_classes);
    let config = MultilabelConfig {
        n_samples,
        n_features,
        n_classes,
        n_labels: effective_labels,
        allow_unlabeled: false,
        random_state: Some(seed),
    };
    let ds = make_multilabel_classification(config)?;
    let nrows = ds.target.nrows();
    let ncols = ds.target.ncols();
    // Threshold soft targets at 0.5 into {0, 1}; `.iter()` walks elements
    // in logical (row-major) order, matching the reshape below.
    let flat_y: Vec<u8> = ds.target.iter().map(|&t| u8::from(t > 0.5)).collect();
    let x_flat: Vec<f64> = ds.data.iter().copied().collect();
    let x = Array2::from_shape_vec((n_samples, n_features), x_flat)
        .map_err(|e| DatasetsError::InvalidFormat(format!("Array2 shape error: {e}")))?;
    let y = Array2::from_shape_vec((nrows, ncols), flat_y)
        .map_err(|e| DatasetsError::InvalidFormat(format!("Array2 shape error: {e}")))?;
    Ok((x, y))
}
/// Generates a mixed-type tabular dataset, flattened to a dense `f64` matrix.
///
/// Columns are ordered numeric-first, then categorical; categorical cells
/// are emitted as integer-valued floats and booleans as 0.0/1.0.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when `n_samples` is zero, when no
/// feature columns are requested, or when reshaping fails.
pub fn make_heterogeneous_nd(
    n_samples: usize,
    n_numeric: usize,
    n_categorical: usize,
    n_categories: usize,
    seed: u64,
) -> Result<(Array2<f64>, Array1<usize>)> {
    use crate::generators::heterogeneous::{
        make_heterogeneous, FeatureType, HeteroConfig, HeteroFeatureValue,
    };
    if n_samples == 0 {
        return Err(DatasetsError::InvalidFormat("n_samples must be > 0".into()));
    }
    let n_features = n_numeric + n_categorical;
    if n_features == 0 {
        return Err(DatasetsError::InvalidFormat(
            "at least one feature column required".into(),
        ));
    }
    // Categorical features need at least two levels to be meaningful.
    let n_cats = n_categories.max(2);
    // Numeric columns first, then categorical — the flattening below
    // relies on this ordering.
    let mut feature_types = Vec::with_capacity(n_features);
    feature_types.extend((0..n_numeric).map(|_| FeatureType::Continuous(0.0, 1.0)));
    feature_types.extend((0..n_categorical).map(|_| FeatureType::Categorical(n_cats)));
    let config = HeteroConfig {
        n_samples,
        feature_types,
        n_features,
        n_classes: 2, // binary target by construction
        seed,
    };
    let ds = make_heterogeneous(&config);
    // Encode every heterogeneous cell as f64.
    let mut flat = Vec::with_capacity(n_samples * n_features);
    for row in &ds.features {
        for cell in row {
            flat.push(match cell {
                HeteroFeatureValue::Float(v) => *v,
                HeteroFeatureValue::Int(k) => *k as f64,
                HeteroFeatureValue::Bool(b) => {
                    if *b {
                        1.0
                    } else {
                        0.0
                    }
                }
                // Future-proof against new variants; mapped to 0.0.
                #[allow(unreachable_patterns)]
                _ => 0.0,
            });
        }
    }
    let x = Array2::from_shape_vec((n_samples, n_features), flat)
        .map_err(|e| DatasetsError::InvalidFormat(format!("Array2 shape error: {e}")))?;
    Ok((x, Array1::from_vec(ds.labels)))
}
/// Generates a binary-class stream with abrupt concept drift.
///
/// Classes alternate by time step (`y[t] = t % 2`); class means swap between
/// 0.0 and 1.0 at each valid drift point. Returns `(X, y, kept_drift_points)`
/// where kept points are those strictly inside `(0, n_samples)`.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when `n_samples` or `n_features`
/// is zero, or when reshaping fails.
pub fn make_concept_drift_nd(
    n_samples: usize,
    n_features: usize,
    drift_points: Vec<usize>,
    seed: u64,
) -> Result<(Array2<f64>, Array1<usize>, Vec<usize>)> {
    if n_samples == 0 {
        return Err(DatasetsError::InvalidFormat("n_samples must be > 0".into()));
    }
    if n_features == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_features must be > 0".into(),
        ));
    }
    // Keep only drift points strictly inside the stream, preserving order.
    let mut valid_points = drift_points;
    valid_points.retain(|&p| p > 0 && p < n_samples);
    // Deterministic 64-bit LCG seeded from `seed`.
    let mut rng_state = seed.wrapping_add(1);
    let mut draw_u64 = || -> u64 {
        rng_state = rng_state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        rng_state
    };
    // Uniform in [0, 1) from the top 53 bits of the LCG state.
    let mut draw_uniform = || -> f64 { (draw_u64() >> 11) as f64 / (1u64 << 53) as f64 };
    // Standard normal draw via the cosine branch of Box-Muller; u1 is
    // clamped away from zero so ln() stays finite.
    let mut draw_normal = || -> f64 {
        let u1 = draw_uniform().max(1e-10);
        let u2 = draw_uniform();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    };
    let mut flat_x = Vec::with_capacity(n_samples * n_features);
    let mut y_vec = Vec::with_capacity(n_samples);
    for t in 0..n_samples {
        // Parity of drift points already passed decides whether means swap.
        let drifts_passed = valid_points.iter().filter(|&&p| t >= p).count();
        let swapped = drifts_passed % 2 == 1;
        let class: usize = t % 2;
        y_vec.push(class);
        // Unswapped: class 0 -> mean 0.0, class 1 -> mean 1.0; swapped is
        // the mirror image. Single expression replaces the nested if/else.
        let mean = if swapped == (class == 0) { 1.0 } else { 0.0 };
        flat_x.extend((0..n_features).map(|_| mean + draw_normal()));
    }
    let x = Array2::from_shape_vec((n_samples, n_features), flat_x)
        .map_err(|e| DatasetsError::InvalidFormat(format!("Array2 shape error: {e}")))?;
    Ok((x, Array1::from_vec(y_vec), valid_points))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Requested shape must be reflected in both the full and observed matrices.
    #[test]
    fn test_make_low_rank_shape() {
        let (x_full, x_obs) =
            make_low_rank(80, 60, 5, 0.1, 42).expect("make_low_rank should succeed");
        assert_eq!(x_full.nrows(), 80, "X_full rows");
        assert_eq!(x_full.ncols(), 60, "X_full cols");
        assert_eq!(x_obs.nrows(), 80, "X_obs rows");
        assert_eq!(x_obs.ncols(), 60, "X_obs cols");
    }

    /// The full matrix is non-degenerate and roughly half of the observed
    /// matrix is masked as NaN (observation_fraction = 0.5).
    #[test]
    fn test_make_low_rank_rank_property() {
        // One call suffices: it yields both the full and the observed matrix.
        // (Previously this called the generator twice with identical args.)
        let (x_full, x_obs) =
            make_low_rank(50, 50, 3, 0.0, 7).expect("make_low_rank should succeed");
        let n = (50 * 50) as f64;
        let mean: f64 = x_full.iter().sum::<f64>() / n;
        let var: f64 = x_full.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / n;
        assert!(
            var > 1e-6,
            "X_full should have non-trivial variance, got {var}"
        );
        let nan_count = x_obs.iter().filter(|v| v.is_nan()).count();
        let total = 50 * 50;
        let nan_fraction = nan_count as f64 / total as f64;
        assert!(
            (0.3..=0.7).contains(&nan_fraction),
            "Expected ~50% NaN in X_obs, got {nan_fraction:.2}"
        );
    }

    /// Shapes match the request and the feature matrix is actually sparse.
    #[test]
    fn test_make_sparse_classification_sparsity() {
        let (x, y) = make_sparse_classification(200, 1000, 10, 0.01, 2, 42)
            .expect("make_sparse_classification should succeed");
        assert_eq!(x.nrows(), 200);
        assert_eq!(x.ncols(), 1000);
        assert_eq!(y.len(), 200);
        let total = x.len() as f64;
        let nonzero = x.iter().filter(|&&v| v != 0.0).count() as f64;
        let density = nonzero / total;
        assert!(
            density < 0.05,
            "Expected sparse features (density < 0.05), got {density:.4}"
        );
    }

    /// Average active labels per sample should be near the requested n_labels.
    #[test]
    fn test_make_multilabel_avg_labels() {
        let n_samples = 200;
        let n_labels = 3;
        let (x, y) = make_multilabel_classification_nd(n_samples, 10, 6, n_labels, 42)
            .expect("multilabel should succeed");
        assert_eq!(x.nrows(), n_samples);
        assert_eq!(y.nrows(), n_samples);
        assert_eq!(y.ncols(), 6);
        let total_active: usize = y.iter().map(|&b| b as usize).sum();
        let avg = total_active as f64 / n_samples as f64;
        assert!(
            avg >= n_labels as f64 * 0.5 && avg <= n_labels as f64 * 1.5,
            "Expected avg labels ≈ {n_labels}, got {avg:.2}"
        );
    }

    /// Categorical columns (indices 3..7: the 4 categorical features after the
    /// 3 numeric ones) must hold integer values in [0, n_categories).
    #[test]
    fn test_make_heterogeneous_categorical_range() {
        let n_categories = 5usize;
        let (x, y) = make_heterogeneous_nd(100, 3, 4, n_categories, 42)
            .expect("heterogeneous should succeed");
        assert_eq!(x.nrows(), 100);
        assert_eq!(x.ncols(), 7); // 3 numeric + 4 categorical
        assert_eq!(y.len(), 100);
        for i in 0..100 {
            for j in 3..7 {
                let v = x[[i, j]];
                assert!(
                    v >= 0.0 && v < n_categories as f64,
                    "Categorical feature {j} out of range: {v}"
                );
                assert_eq!(
                    v.fract(),
                    0.0,
                    "Categorical feature {j} should be integer, got {v}"
                );
            }
        }
    }

    /// Class-0 feature means should sit near 0 before the drift point and
    /// near 1 after it (means swap at the drift).
    #[test]
    fn test_make_concept_drift_distributions() {
        let n_samples = 1000;
        let n_features = 4;
        let drift_at = vec![500usize];
        let (x, _y, actual) = make_concept_drift_nd(n_samples, n_features, drift_at.clone(), 42)
            .expect("concept_drift should succeed");
        assert_eq!(x.nrows(), n_samples);
        assert_eq!(x.ncols(), n_features);
        assert_eq!(actual, drift_at, "Drift points should be preserved");
        let mut before_sum = 0.0f64;
        let mut before_count = 0usize;
        let mut after_sum = 0.0f64;
        let mut after_count = 0usize;
        for t in 0..n_samples {
            // Even t is class 0 by construction.
            if t % 2 == 0 {
                let v = x[[t, 0]];
                if t < 500 {
                    before_sum += v;
                    before_count += 1;
                } else {
                    after_sum += v;
                    after_count += 1;
                }
            }
        }
        let before_mean = if before_count > 0 {
            before_sum / before_count as f64
        } else {
            0.0
        };
        let after_mean = if after_count > 0 {
            after_sum / after_count as f64
        } else {
            0.0
        };
        assert!(
            before_mean.abs() < 0.5,
            "Before-drift class-0 mean should be ≈ 0, got {before_mean:.3}"
        );
        assert!(
            (after_mean - 1.0).abs() < 0.5,
            "After-drift class-0 mean should be ≈ 1, got {after_mean:.3}"
        );
    }

    /// Every sample index appears in exactly one shard, with no loss.
    #[test]
    fn test_data_shard_coverage() {
        use crate::sharding::{shard_by_index, ShardingConfig};
        let n_shards = 5;
        let n_samples = 97; // deliberately not divisible by n_shards
        // Compile-only check that ShardingConfig exposes these fields and
        // implements Default; shard_by_index takes raw args instead.
        let config = ShardingConfig {
            n_shards,
            shuffle: false,
            seed: 0,
            ..Default::default()
        };
        let _ = config;
        let shards = shard_by_index(n_samples, n_shards, false, 0);
        assert_eq!(shards.len(), n_shards);
        let mut seen = vec![false; n_samples];
        for shard in &shards {
            for &idx in &shard.indices {
                assert!(!seen[idx], "index {idx} seen in multiple shards");
                seen[idx] = true;
            }
        }
        assert!(seen.iter().all(|&v| v), "Not all samples covered");
        let total: usize = shards.iter().map(|s| s.indices.len()).sum();
        assert_eq!(total, n_samples, "Total samples mismatch");
    }

    /// Shuffled sharding is deterministic per seed and varies across seeds.
    #[test]
    fn test_data_shard_shuffled_consistency() {
        use crate::sharding::shard_by_index;
        let s1 = shard_by_index(100, 4, true, 999);
        let s2 = shard_by_index(100, 4, true, 999);
        for (a, b) in s1.iter().zip(s2.iter()) {
            assert_eq!(a.indices, b.indices, "Same seed must give same permutation");
        }
        let s3 = shard_by_index(100, 4, true, 12345);
        let differs = s1
            .iter()
            .zip(s3.iter())
            .any(|(a, b)| a.indices != b.indices);
        assert!(
            differs,
            "Different seeds should give different shard indices"
        );
    }
}