use crate::error::{DatasetsError, Result};
use scirs2_core::ndarray::{s, Array1, Array3, Array4, ArrayView3};
use scirs2_core::random::prelude::*;
use scirs2_core::random::rand_distributions::Distribution;
/// Number of classes in the synthetic ImageNet-100 dataset.
pub const IMAGENET100_N_CLASSES: usize = 100;

/// Parameters controlling `ImageNet100Dataset::generate`.
///
/// The default configuration requests 10 samples per class of 64x64 pixels,
/// with RNG seed 42.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ImageNet100Config {
    /// Number of images generated for each class; generation rejects 0.
    pub n_samples_per_class: usize,
    /// Side length, in pixels, of the square images; generation rejects 0.
    pub image_size: usize,
    /// Seed for the random number generator, so equal configs yield equal data.
    pub seed: u64,
}

impl Default for ImageNet100Config {
    fn default() -> Self {
        Self {
            n_samples_per_class: 10,
            image_size: 64,
            seed: 42,
        }
    }
}
/// A deterministic, synthetic stand-in for an ImageNet-100-style dataset.
///
/// Each class gets a fixed mean RGB colour. Every image of that class is the
/// colour plus independent Gaussian noise per pixel (`Normal::new(0.0, 0.1)`),
/// clamped to `[0, 1]`. Images are stored channel-first with shape
/// `(n_samples, 3, image_size, image_size)`. Samples are grouped by class:
/// sample `i` has label `i / n_samples_per_class`.
#[derive(Debug, Clone)]
pub struct ImageNet100Dataset {
    /// Pixel data, shape `(n_samples, 3, image_size, image_size)`, values in `[0, 1]`.
    images: Array4<f32>,
    /// Label of each sample, parallel to axis 0 of `images`.
    labels: Array1<u32>,
    /// Class names `class_000`, `class_001`, ..., indexed by label.
    class_names: Vec<String>,
    /// The configuration this dataset was generated from.
    config: ImageNet100Config,
}

impl ImageNet100Dataset {
    /// Generates the dataset described by `config`.
    ///
    /// Generation is deterministic: the RNG is seeded from `config.seed`, so
    /// two calls with equal configs produce identical images and labels.
    ///
    /// # Errors
    ///
    /// - `DatasetsError::InvalidFormat` if `n_samples_per_class` or
    ///   `image_size` is zero, or if the requested image tensor is too large
    ///   to address (its size overflows `usize` or exceeds `isize::MAX` bytes).
    /// - `DatasetsError::ComputationError` if the noise distribution cannot
    ///   be constructed.
    pub fn generate(config: ImageNet100Config) -> Result<Self> {
        if config.n_samples_per_class == 0 {
            return Err(DatasetsError::InvalidFormat(
                "ImageNet100Config: n_samples_per_class must be > 0".to_string(),
            ));
        }
        if config.image_size == 0 {
            return Err(DatasetsError::InvalidFormat(
                "ImageNet100Config: image_size must be > 0".to_string(),
            ));
        }
        let per_class = config.n_samples_per_class;
        let h = config.image_size;
        let w = config.image_size;
        // Reject oversized requests up front. Otherwise the multiplication
        // wraps (release) or panics (debug), and ndarray / Vec panic on the
        // shape.
        let n_total = Self::checked_sample_count(per_class, config.image_size).ok_or_else(|| {
            DatasetsError::InvalidFormat(format!(
                "ImageNet100Config: dataset with n_samples_per_class = {per_class} and \
                 image_size = {} is too large",
                config.image_size
            ))
        })?;

        let mut rng = StdRng::seed_from_u64(config.seed);
        let noise_dist = Normal::new(0.0_f32, 0.1_f32).map_err(|e| {
            DatasetsError::ComputationError(format!("Normal dist construction failed: {e}"))
        })?;

        // One colour per class, each channel in [0.1, 0.9]. The channels are
        // phase-shifted |sin| curves over the class index.
        // NOTE(review): |sin| has period pi, so classes ~52 apart
        // (6 * 0.52 ~= pi) get nearly identical means and are hard to
        // separate under the noise. Changing this would change the generated
        // data, so it is left as is.
        let class_means: Vec<[f32; 3]> = (0..IMAGENET100_N_CLASSES)
            .map(|c| {
                let hue = c as f32 / IMAGENET100_N_CLASSES as f32;
                let r = (hue * 6.0).sin().abs() * 0.8 + 0.1;
                let g = ((hue + 0.333) * 6.0).sin().abs() * 0.8 + 0.1;
                let b = ((hue + 0.667) * 6.0).sin().abs() * 0.8 + 0.1;
                [r, g, b]
            })
            .collect();
        let class_names: Vec<String> = (0..IMAGENET100_N_CLASSES)
            .map(|c| format!("class_{c:03}"))
            .collect();

        // Samples are laid out class-major, so sample i belongs to class i / per_class.
        let labels = Array1::from_shape_fn(n_total, |i| (i / per_class) as u32);
        let mut images = Array4::<f32>::zeros((n_total, 3, h, w));
        for (class_id, mean_rgb) in class_means.iter().enumerate() {
            for sample_in_class in 0..per_class {
                let sample_idx = class_id * per_class + sample_in_class;
                for (c, &mean) in mean_rgb.iter().enumerate() {
                    // iter_mut visits the (h, w) plane in logical row-major
                    // order. That matches a row-then-column loop, so the RNG
                    // draw order (and thus the output) is preserved.
                    for pixel in images.slice_mut(s![sample_idx, c, .., ..]).iter_mut() {
                        let noise: f32 = noise_dist.sample(&mut rng);
                        *pixel = (mean + noise).clamp(0.0, 1.0);
                    }
                }
            }
        }

        Ok(Self {
            images,
            labels,
            class_names,
            config,
        })
    }

    /// Returns the total number of samples. Returns `None` if that count, or
    /// the byte size of the `(n_total, 3, image_size, image_size)` `f32`
    /// tensor, cannot be represented (overflows `usize` or exceeds
    /// `isize::MAX`).
    fn checked_sample_count(n_samples_per_class: usize, image_size: usize) -> Option<usize> {
        let n_total = n_samples_per_class.checked_mul(IMAGENET100_N_CLASSES)?;
        let n_bytes = n_total
            .checked_mul(3)?
            .checked_mul(image_size)?
            .checked_mul(image_size)?
            .checked_mul(std::mem::size_of::<f32>())?;
        // Rust allocations are capped at isize::MAX bytes.
        if n_bytes <= isize::MAX as usize {
            Some(n_total)
        } else {
            None
        }
    }

    /// All images, shape `(n_samples, 3, image_size, image_size)`.
    pub fn images(&self) -> &Array4<f32> {
        &self.images
    }

    /// Labels of all samples, parallel to axis 0 of `images()`.
    pub fn labels(&self) -> &Array1<u32> {
        &self.labels
    }

    /// Class names, indexed by label.
    pub fn class_names(&self) -> &[String] {
        &self.class_names
    }

    /// Number of classes (always `IMAGENET100_N_CLASSES`).
    pub fn n_classes(&self) -> usize {
        IMAGENET100_N_CLASSES
    }

    /// Total number of samples (`n_samples_per_class * IMAGENET100_N_CLASSES`).
    pub fn n_samples(&self) -> usize {
        self.config.n_samples_per_class * IMAGENET100_N_CLASSES
    }

    /// Borrows sample `idx` as a `(3, image_size, image_size)` view, together
    /// with its label.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.n_samples()`. Use `get_sample_owned` for a
    /// checked alternative.
    pub fn get_sample(&self, idx: usize) -> (ArrayView3<'_, f32>, u32) {
        assert!(
            idx < self.n_samples(),
            "ImageNet100Dataset: index {idx} out of bounds (n_samples = {})",
            self.n_samples()
        );
        let view: ArrayView3<'_, f32> = self.images.slice(s![idx, .., .., ..]);
        (view, self.labels[idx])
    }

    /// Returns an owned copy of sample `idx`, shape
    /// `(3, image_size, image_size)`, together with its label.
    ///
    /// # Errors
    ///
    /// Returns `DatasetsError::InvalidFormat` if `idx >= self.n_samples()`.
    pub fn get_sample_owned(&self, idx: usize) -> Result<(Array3<f32>, u32)> {
        if idx >= self.n_samples() {
            return Err(DatasetsError::InvalidFormat(format!(
                "ImageNet100Dataset: index {idx} out of bounds (n_samples = {})",
                self.n_samples()
            )));
        }
        let view = self.images.slice(s![idx, .., .., ..]);
        Ok((view.to_owned(), self.labels[idx]))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A tiny configuration (2 samples per class, 8x8 images) that keeps tests fast.
    fn small_config() -> ImageNet100Config {
        ImageNet100Config {
            n_samples_per_class: 2,
            image_size: 8,
            seed: 42,
        }
    }

    #[test]
    fn test_imagenet100_shape() {
        let cfg = small_config();
        let ds = ImageNet100Dataset::generate(cfg.clone()).expect("generate failed");
        let n = cfg.n_samples_per_class * IMAGENET100_N_CLASSES;
        assert_eq!(ds.n_samples(), n);
        assert_eq!(ds.n_classes(), IMAGENET100_N_CLASSES);
        let imgs = ds.images();
        assert_eq!(imgs.shape(), &[n, 3, cfg.image_size, cfg.image_size]);
        assert_eq!(ds.labels().len(), n);
    }

    #[test]
    fn test_imagenet100_deterministic() {
        let cfg = small_config();
        let ds1 = ImageNet100Dataset::generate(cfg.clone()).expect("generate failed");
        let ds2 = ImageNet100Dataset::generate(cfg).expect("generate failed");
        assert_eq!(ds1.images(), ds2.images());
        assert_eq!(ds1.labels(), ds2.labels());
    }

    #[test]
    fn test_imagenet100_seed_changes_images() {
        let ds1 = ImageNet100Dataset::generate(small_config()).expect("generate failed");
        let ds2 = ImageNet100Dataset::generate(ImageNet100Config {
            seed: 43,
            ..small_config()
        })
        .expect("generate failed");
        assert_ne!(ds1.images(), ds2.images());
        // Labels depend only on the layout, not on the seed.
        assert_eq!(ds1.labels(), ds2.labels());
    }

    #[test]
    fn test_imagenet100_pixel_range() {
        let ds = ImageNet100Dataset::generate(small_config()).expect("generate failed");
        for &v in ds.images().iter() {
            // Check NaN first: a NaN also fails the range check, which would hide the cause.
            assert!(!v.is_nan(), "NaN pixel found");
            assert!((0.0..=1.0).contains(&v), "pixel value {v} out of [0,1]");
        }
    }

    #[test]
    fn test_imagenet100_labels_in_range() {
        let ds = ImageNet100Dataset::generate(small_config()).expect("generate failed");
        for &label in ds.labels().iter() {
            assert!(
                (label as usize) < IMAGENET100_N_CLASSES,
                "label {label} out of range"
            );
        }
    }

    #[test]
    fn test_imagenet100_labels_grouped_by_class() {
        let cfg = small_config();
        let ds = ImageNet100Dataset::generate(cfg.clone()).expect("generate failed");
        for (i, &label) in ds.labels().iter().enumerate() {
            assert_eq!(
                label as usize,
                i / cfg.n_samples_per_class,
                "sample {i} has unexpected label"
            );
        }
    }

    #[test]
    fn test_imagenet100_class_names() {
        let ds = ImageNet100Dataset::generate(small_config()).expect("generate failed");
        let names = ds.class_names();
        assert_eq!(names.len(), IMAGENET100_N_CLASSES);
        assert_eq!(names[0], "class_000");
        assert_eq!(names[99], "class_099");
    }

    #[test]
    fn test_imagenet100_get_sample() {
        let cfg = small_config();
        let ds = ImageNet100Dataset::generate(cfg.clone()).expect("generate failed");
        let (view, label) = ds.get_sample(0);
        assert_eq!(view.shape(), &[3, cfg.image_size, cfg.image_size]);
        assert_eq!(label, 0u32);
        let (_, last_label) = ds.get_sample(ds.n_samples() - 1);
        assert_eq!(last_label as usize, IMAGENET100_N_CLASSES - 1);
    }

    #[test]
    #[should_panic]
    fn test_imagenet100_get_sample_out_of_bounds_panics() {
        let ds = ImageNet100Dataset::generate(small_config()).expect("generate failed");
        let _ = ds.get_sample(ds.n_samples());
    }

    #[test]
    fn test_imagenet100_get_sample_owned() {
        let cfg = small_config();
        let ds = ImageNet100Dataset::generate(cfg.clone()).expect("generate failed");
        let (arr, label) = ds.get_sample_owned(0).expect("get_sample_owned failed");
        assert_eq!(arr.shape(), &[3, cfg.image_size, cfg.image_size]);
        assert_eq!(label, 0u32);
        let (view, _) = ds.get_sample(0);
        assert_eq!(arr, view.to_owned());
    }

    #[test]
    fn test_imagenet100_error_zero_samples_per_class() {
        let cfg = ImageNet100Config {
            n_samples_per_class: 0,
            ..ImageNet100Config::default()
        };
        assert!(ImageNet100Dataset::generate(cfg).is_err());
    }

    #[test]
    fn test_imagenet100_error_zero_image_size() {
        let cfg = ImageNet100Config {
            image_size: 0,
            ..ImageNet100Config::default()
        };
        assert!(ImageNet100Dataset::generate(cfg).is_err());
    }

    #[test]
    fn test_imagenet100_get_sample_owned_out_of_bounds() {
        let ds = ImageNet100Dataset::generate(small_config()).expect("generate failed");
        assert!(ds.get_sample_owned(ds.n_samples()).is_err());
    }

    #[test]
    fn test_imagenet100_labels_balanced() {
        let cfg = ImageNet100Config {
            n_samples_per_class: 3,
            image_size: 4,
            seed: 1,
        };
        let ds = ImageNet100Dataset::generate(cfg).expect("generate failed");
        let mut counts = vec![0u32; IMAGENET100_N_CLASSES];
        for &lbl in ds.labels().iter() {
            counts[lbl as usize] += 1;
        }
        for (cls, &cnt) in counts.iter().enumerate() {
            assert_eq!(cnt, 3, "class {cls} should have exactly 3 samples");
        }
    }
}