use crate::Dataset;
use scirs2_core::RngExt;
use tenflowers_core::{Result, Tensor, TensorError};
/// Namespace type for the synthetic benchmark dataset constructors.
///
/// The type holds no data; datasets are built through its associated functions.
pub struct BenchmarkDatasets;
impl BenchmarkDatasets {
/// Builds a synthetic MNIST-like dataset of flattened 28x28 grayscale images.
///
/// Every sample gets a class drawn from `0..10`, stored as an `f32` label.
/// Classes 0 through 5 render a fixed shape (circle, vertical line, horizontal
/// line, diagonal, cross, square); the remaining classes are filled with
/// low-amplitude random values. A small random offset is added to every pixel
/// and the result is clamped to `[0.0, 1.0]`.
///
/// # Arguments
/// * `num_samples` - number of images to generate.
/// * `seed` - RNG seed; `None` falls back to the fixed seed `12345`.
///
/// # Errors
/// Propagates any error returned by `Tensor::from_vec`.
pub fn synthetic_mnist(num_samples: usize, seed: Option<u64>) -> Result<MnistDataset<f32>> {
    use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};

    let mut rng = StdRng::seed_from_u64(seed.unwrap_or(12345));

    let side = 28;
    let pixels_per_image = side * side;
    let class_count = 10;

    let mut pixels = Vec::with_capacity(num_samples * pixels_per_image);
    let mut targets = Vec::with_capacity(num_samples);

    for _ in 0..num_samples {
        let class = rng.random_range(0..class_count);
        targets.push(class as f32);

        // Row-major traversal; the RNG is consulted in the same order for every pixel.
        for row in 0..side {
            for col in 0..side {
                let base = match class {
                    0 => create_circle_pattern(row, col, side),
                    1 => create_vertical_line_pattern(row, col, side),
                    2 => create_horizontal_line_pattern(row, col, side),
                    3 => create_diagonal_pattern(row, col, side),
                    4 => create_cross_pattern(row, col, side),
                    5 => create_square_pattern(row, col, side),
                    _ => rng.random::<f32>() * 0.3,
                };
                let jitter = rng.random::<f32>() * 0.1;
                pixels.push((base + jitter).clamp(0.0, 1.0));
            }
        }
    }

    let images = Tensor::from_vec(pixels, &[num_samples, pixels_per_image])?;
    let labels = Tensor::from_vec(targets, &[num_samples])?;

    Ok(MnistDataset {
        images,
        labels,
        num_samples,
    })
}
/// Builds a synthetic CIFAR-10-like dataset of 32x32 RGB images.
///
/// Every sample gets a class drawn from `0..10`, stored as an `f32` label.
/// Classes 0, 1 and 2 are dominated by the red, green and blue channel
/// respectively, class 3 is a colour gradient, class 4 a checkerboard, and
/// the remaining classes are uniform random colour values.
///
/// The image tensor has shape `[num_samples, 3, 32, 32]`, so the buffer is
/// filled channel plane by channel plane (all red values of a sample, then
/// all green, then all blue).
///
/// # Arguments
/// * `num_samples` - number of images to generate.
/// * `seed` - RNG seed; `None` falls back to the fixed seed `12345`.
///
/// # Errors
/// Propagates any error returned by `Tensor::from_vec`.
pub fn synthetic_cifar10(num_samples: usize, seed: Option<u64>) -> Result<CifarDataset<f32>> {
    use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};

    let mut rng = StdRng::seed_from_u64(seed.unwrap_or(12345));

    let side = 32;
    let plane_len = side * side;
    let image_size = 3 * plane_len;
    let num_classes = 10;

    let mut image_data = vec![0.0_f32; num_samples * image_size];
    let mut labels = Vec::with_capacity(num_samples);

    for sample in image_data.chunks_exact_mut(image_size) {
        let class = rng.random_range(0..num_classes);
        labels.push(class as f32);

        // Colours are generated as (r, g, b) triples per pixel, but the tensor
        // shape is [N, 3, 32, 32]. Each channel therefore goes into its own
        // 32x32 plane instead of being interleaved per pixel.
        // NOTE(review): assumes `Tensor::from_vec` reads the buffer in
        // row-major order - confirm against tenflowers_core.
        let (red_plane, rest) = sample.split_at_mut(plane_len);
        let (green_plane, blue_plane) = rest.split_at_mut(plane_len);

        for row in 0..side {
            for col in 0..side {
                let (r, g, b) = match class {
                    0 => (
                        rng.random::<f32>() * 0.5 + 0.5,
                        rng.random::<f32>() * 0.3,
                        rng.random::<f32>() * 0.3,
                    ),
                    1 => (
                        rng.random::<f32>() * 0.3,
                        rng.random::<f32>() * 0.5 + 0.5,
                        rng.random::<f32>() * 0.3,
                    ),
                    2 => (
                        rng.random::<f32>() * 0.3,
                        rng.random::<f32>() * 0.3,
                        rng.random::<f32>() * 0.5 + 0.5,
                    ),
                    3 => create_rgb_gradient(row, col, side),
                    4 => create_rgb_checkerboard(row, col),
                    _ => (
                        rng.random::<f32>(),
                        rng.random::<f32>(),
                        rng.random::<f32>(),
                    ),
                };

                let offset = row * side + col;
                red_plane[offset] = r;
                green_plane[offset] = g;
                blue_plane[offset] = b;
            }
        }
    }

    let images = Tensor::from_vec(image_data, &[num_samples, 3, side, side])?;
    let labels_tensor = Tensor::from_vec(labels, &[num_samples])?;

    Ok(CifarDataset {
        images,
        labels: labels_tensor,
        num_samples,
    })
}
/// Builds a synthetic Iris-like dataset with four features and three classes.
///
/// Every sample gets a class drawn from `0..3`, stored as an `f32` label.
/// Each of the four features (sepal length, sepal width, petal length, petal
/// width, in that order) is a random draw multiplied by a class-specific span
/// and shifted by a class-specific offset.
///
/// # Arguments
/// * `num_samples` - number of rows to generate.
/// * `seed` - RNG seed; `None` falls back to the fixed seed `12345`.
///
/// # Errors
/// Propagates any error returned by `Tensor::from_vec`.
pub fn synthetic_iris(num_samples: usize, seed: Option<u64>) -> Result<IrisDataset<f32>> {
    use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};

    let mut rng = StdRng::seed_from_u64(seed.unwrap_or(12345));

    let feature_count = 4;
    let class_count = 3;

    let mut rows = Vec::with_capacity(num_samples * feature_count);
    let mut targets = Vec::with_capacity(num_samples);

    for _ in 0..num_samples {
        let class = rng.random_range(0..class_count);
        targets.push(class as f32);

        // (span, offset) per feature, in the order:
        // sepal length, sepal width, petal length, petal width.
        let profile: [(f32, f32); 4] = match class {
            0 => [(1.0, 4.5), (1.0, 3.0), (0.5, 1.0), (0.3, 0.1)],
            1 => [(1.0, 5.5), (0.8, 2.2), (1.0, 3.5), (0.5, 1.0)],
            2 => [(1.0, 6.0), (0.8, 2.5), (1.5, 4.5), (0.8, 1.5)],
            _ => unreachable!(),
        };

        for (span, offset) in profile {
            rows.push(rng.random::<f32>() * span + offset);
        }
    }

    let features = Tensor::from_vec(rows, &[num_samples, feature_count])?;
    let labels = Tensor::from_vec(targets, &[num_samples])?;

    Ok(IrisDataset {
        features,
        labels,
        num_samples,
    })
}
}
/// Intensity of a ring centred in a `size` x `size` image.
///
/// Returns `0.8` when the pixel's distance from the centre lies strictly
/// between `size / 3 - 3` and `size / 3`, and `0.1` otherwise.
fn create_circle_pattern(row: usize, col: usize, size: usize) -> f32 {
    let extent = size as f32;
    let middle = extent / 2.0;
    let outer = extent / 3.0;
    let inner = outer - 3.0;

    let dy = row as f32 - middle;
    let dx = col as f32 - middle;
    let dist = (dy.powi(2) + dx.powi(2)).sqrt();

    let on_ring = dist < outer && dist > inner;
    if on_ring {
        0.8
    } else {
        0.1
    }
}
/// Intensity of a vertical bar through the middle column of the image.
///
/// Returns `0.8` for columns within one pixel of `size / 2`, else `0.1`.
/// The row is ignored.
fn create_vertical_line_pattern(_row: usize, col: usize, size: usize) -> f32 {
    match col.abs_diff(size / 2) {
        0 | 1 => 0.8,
        _ => 0.1,
    }
}
/// Intensity of a horizontal bar through the middle row of the image.
///
/// Returns `0.8` for rows within one pixel of `size / 2`, else `0.1`.
/// The column is ignored.
fn create_horizontal_line_pattern(row: usize, _col: usize, size: usize) -> f32 {
    match row.abs_diff(size / 2) {
        0 | 1 => 0.8,
        _ => 0.1,
    }
}
/// Intensity of a main-diagonal stripe.
///
/// Returns `0.8` when row and column differ by at most one, else `0.1`.
/// The image size is not used.
fn create_diagonal_pattern(row: usize, col: usize, _size: usize) -> f32 {
    let off_diagonal = row.abs_diff(col);
    if off_diagonal <= 1 {
        0.8
    } else {
        0.1
    }
}
/// Intensity of a plus-shaped cross through the image centre.
///
/// Returns `0.8` when either the row or the column is within one pixel of
/// `size / 2`, else `0.1`.
fn create_cross_pattern(row: usize, col: usize, size: usize) -> f32 {
    let middle = size / 2;
    let near_middle = |coord: usize| coord.abs_diff(middle) < 2;

    if near_middle(row) || near_middle(col) {
        0.8
    } else {
        0.1
    }
}
/// Intensity of a hollow square outline inset by `size / 4` from each edge.
///
/// Returns `0.8` for pixels inside the inset square that lie within two
/// pixels of one of its sides, and `0.1` everywhere else.
fn create_square_pattern(row: usize, col: usize, size: usize) -> f32 {
    let margin = size / 4;
    let far_edge = size - margin;

    let inside = |coord: usize| coord >= margin && coord < far_edge;
    if !(inside(row) && inside(col)) {
        return 0.1;
    }

    // The subtraction is only reached when `coord >= margin + 2` and
    // `coord < far_edge`, so `far_edge - 2` cannot underflow.
    let on_border = |coord: usize| coord < margin + 2 || coord >= far_edge - 2;
    if on_border(row) || on_border(col) {
        0.8
    } else {
        0.1
    }
}
/// Colour of a two-axis gradient at the given pixel.
///
/// Red is `row / size`, green is `col / size`, and blue is `1.0 - red`.
fn create_rgb_gradient(row: usize, col: usize, size: usize) -> (f32, f32, f32) {
    let extent = size as f32;
    let red = row as f32 / extent;
    let green = col as f32 / extent;
    (red, green, 1.0 - red)
}
/// Grey level of a checkerboard made of 4x4-pixel tiles.
///
/// Tiles whose row and column tile indices sum to an even number are light
/// (`0.8` on all channels); the others are dark (`0.2`).
fn create_rgb_checkerboard(row: usize, col: usize) -> (f32, f32, f32) {
    let tile_sum = row / 4 + col / 4;
    let shade = if tile_sum % 2 == 0 { 0.8 } else { 0.2 };
    (shade, shade, shade)
}
/// In-memory MNIST-style dataset: a tensor of images plus a tensor of labels.
#[derive(Debug, Clone)]
pub struct MnistDataset<T> {
    /// All images; `get` reads it as one row of `28 * 28` values per sample.
    images: Tensor<T>,
    /// One label per sample; `get` indexes it along its only dimension.
    labels: Tensor<T>,
    /// Number of samples reported by `len` and used for bounds checks.
    num_samples: usize,
}

impl<T> Dataset<T> for MnistDataset<T>
where
    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
    /// Number of samples in the dataset.
    fn len(&self) -> usize {
        self.num_samples
    }

    /// Returns sample `index` as `(image, label)`.
    ///
    /// The image is reshaped to `[784]` and the label to the empty shape `[]`.
    ///
    /// # Errors
    /// Returns an invalid-argument error when `index >= len()`, and
    /// propagates any error from the slice or reshape operations.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        use tenflowers_core::ops::{reshape, slice};

        let total = self.num_samples;
        if index >= total {
            return Err(TensorError::invalid_argument(format!(
                "Index {} out of bounds for dataset of length {}",
                index, total
            )));
        }

        let pixel_count = 28 * 28;
        let sample = index..index + 1;
        let image_ranges = vec![sample.clone(), 0..pixel_count];
        // A one-element vector of ranges is intended here, not a vector of indices.
        #[allow(clippy::single_range_in_vec_init)]
        let label_ranges = vec![sample];

        let raw_image = slice(&self.images, &image_ranges)?;
        let raw_label = slice(&self.labels, &label_ranges)?;

        // Drop the leading sample dimension left behind by the slice.
        let image = reshape(&raw_image, &[pixel_count])?;
        let label = reshape(&raw_label, &[])?;
        Ok((image, label))
    }
}
/// In-memory CIFAR-style dataset: a tensor of images plus a tensor of labels.
#[derive(Debug, Clone)]
pub struct CifarDataset<T> {
    /// All images; `get` reads it as a 4-D tensor with a `3 x 32 x 32` block per sample.
    images: Tensor<T>,
    /// One label per sample; `get` indexes it along its only dimension.
    labels: Tensor<T>,
    /// Number of samples reported by `len` and used for bounds checks.
    num_samples: usize,
}

impl<T> Dataset<T> for CifarDataset<T>
where
    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
    /// Number of samples in the dataset.
    fn len(&self) -> usize {
        self.num_samples
    }

    /// Returns sample `index` as `(image, label)`.
    ///
    /// The image is reshaped to `[3, 32, 32]` and the label to the empty shape `[]`.
    ///
    /// # Errors
    /// Returns an invalid-argument error when `index >= len()`, and
    /// propagates any error from the slice or reshape operations.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        use tenflowers_core::ops::{reshape, slice};

        let total = self.num_samples;
        if index >= total {
            return Err(TensorError::invalid_argument(format!(
                "Index {} out of bounds for dataset of length {}",
                index, total
            )));
        }

        let sample = index..index + 1;
        let image_ranges = vec![sample.clone(), 0..3, 0..32, 0..32];
        // A one-element vector of ranges is intended here, not a vector of indices.
        #[allow(clippy::single_range_in_vec_init)]
        let label_ranges = vec![sample];

        let raw_image = slice(&self.images, &image_ranges)?;
        let raw_label = slice(&self.labels, &label_ranges)?;

        // Drop the leading sample dimension left behind by the slice.
        let image = reshape(&raw_image, &[3, 32, 32])?;
        let label = reshape(&raw_label, &[])?;
        Ok((image, label))
    }
}
/// In-memory Iris-style dataset: a tensor of feature rows plus a tensor of labels.
#[derive(Debug, Clone)]
pub struct IrisDataset<T> {
    /// All feature rows; `get` reads it as one row of 4 values per sample.
    features: Tensor<T>,
    /// One label per sample; `get` indexes it along its only dimension.
    labels: Tensor<T>,
    /// Number of samples reported by `len` and used for bounds checks.
    num_samples: usize,
}

impl<T> Dataset<T> for IrisDataset<T>
where
    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
    /// Number of samples in the dataset.
    fn len(&self) -> usize {
        self.num_samples
    }

    /// Returns sample `index` as `(features, label)`.
    ///
    /// The features are reshaped to `[4]` and the label to the empty shape `[]`.
    ///
    /// # Errors
    /// Returns an invalid-argument error when `index >= len()`, and
    /// propagates any error from the slice or reshape operations.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        use tenflowers_core::ops::{reshape, slice};

        let total = self.num_samples;
        if index >= total {
            return Err(TensorError::invalid_argument(format!(
                "Index {} out of bounds for dataset of length {}",
                index, total
            )));
        }

        let sample = index..index + 1;
        let features_ranges = vec![sample.clone(), 0..4];
        // A one-element vector of ranges is intended here, not a vector of indices.
        #[allow(clippy::single_range_in_vec_init)]
        let label_ranges = vec![sample];

        let raw_features = slice(&self.features, &features_ranges)?;
        let raw_label = slice(&self.labels, &label_ranges)?;

        // Drop the leading sample dimension left behind by the slice.
        let features = reshape(&raw_features, &[4])?;
        let label = reshape(&raw_label, &[])?;
        Ok((features, label))
    }
}
impl<T> MnistDataset<T> {
    /// Borrows the tensor holding every image in the dataset.
    pub fn images(&self) -> &Tensor<T> {
        &self.images
    }

    /// Borrows the tensor holding every label in the dataset.
    pub fn labels(&self) -> &Tensor<T> {
        &self.labels
    }

    /// Describes the dataset: name, sample count, 10 classes, 28x28 images, 1 channel.
    pub fn info(&self) -> DatasetInfo {
        let name = String::from("Synthetic MNIST");
        let image_shape = vec![28, 28];
        DatasetInfo {
            name,
            num_samples: self.num_samples,
            num_classes: 10,
            image_shape,
            num_channels: 1,
        }
    }
}
impl<T> CifarDataset<T> {
    /// Borrows the tensor holding every image in the dataset.
    pub fn images(&self) -> &Tensor<T> {
        &self.images
    }

    /// Borrows the tensor holding every label in the dataset.
    pub fn labels(&self) -> &Tensor<T> {
        &self.labels
    }

    /// Describes the dataset: name, sample count, 10 classes, 32x32 images, 3 channels.
    pub fn info(&self) -> DatasetInfo {
        let name = String::from("Synthetic CIFAR-10");
        let image_shape = vec![32, 32];
        DatasetInfo {
            name,
            num_samples: self.num_samples,
            num_classes: 10,
            image_shape,
            num_channels: 3,
        }
    }
}
impl<T> IrisDataset<T> {
    /// Borrows the tensor holding every feature row in the dataset.
    pub fn features(&self) -> &Tensor<T> {
        &self.features
    }

    /// Borrows the tensor holding every label in the dataset.
    pub fn labels(&self) -> &Tensor<T> {
        &self.labels
    }

    /// Describes the dataset: name, sample count, 3 classes, shape `[4]`, 1 channel.
    pub fn info(&self) -> DatasetInfo {
        let name = String::from("Synthetic Iris");
        let image_shape = vec![4];
        DatasetInfo {
            name,
            num_samples: self.num_samples,
            num_classes: 3,
            image_shape,
            num_channels: 1,
        }
    }

    /// Names of the four feature columns, in column order.
    pub fn feature_names(&self) -> Vec<String> {
        ["sepal_length", "sepal_width", "petal_length", "petal_width"]
            .iter()
            .copied()
            .map(String::from)
            .collect()
    }

    /// Names of the three classes, in label order.
    pub fn class_names(&self) -> Vec<String> {
        ["setosa", "versicolor", "virginica"]
            .iter()
            .copied()
            .map(String::from)
            .collect()
    }
}
/// Summary metadata describing a dataset.
#[derive(Debug, Clone)]
pub struct DatasetInfo {
    /// Human-readable dataset name.
    pub name: String,
    /// Number of samples in the dataset.
    pub num_samples: usize,
    /// Number of distinct classes.
    pub num_classes: usize,
    /// Shape of a single sample, excluding channels.
    pub image_shape: Vec<usize>,
    /// Number of channels per sample.
    pub num_channels: usize,
}

impl DatasetInfo {
    /// Renders the metadata as five newline-terminated `Key: value` lines.
    pub fn display(&self) -> String {
        [
            format!("Dataset: {}\n", self.name),
            format!("Samples: {}\n", self.num_samples),
            format!("Classes: {}\n", self.num_classes),
            format!("Shape: {:?}\n", self.image_shape),
            format!("Channels: {}\n", self.num_channels),
        ]
        .concat()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The MNIST generator yields the requested sample count and per-sample shapes.
    #[test]
    fn test_synthetic_mnist() {
        let dataset = BenchmarkDatasets::synthetic_mnist(100, Some(42))
            .expect("test: operation should succeed");
        assert_eq!(dataset.len(), 100);

        let (image, label) = dataset.get(0).expect("index should be in bounds");
        assert_eq!(image.shape().dims(), &[784]);
        assert_eq!(label.shape().dims(), &[] as &[usize]);

        let info = dataset.info();
        assert_eq!(info.name, "Synthetic MNIST");
        assert_eq!(info.num_classes, 10);
    }

    /// The CIFAR generator yields the requested sample count and per-sample shapes.
    #[test]
    fn test_synthetic_cifar10() {
        let dataset = BenchmarkDatasets::synthetic_cifar10(50, Some(42))
            .expect("test: operation should succeed");
        assert_eq!(dataset.len(), 50);

        let (image, label) = dataset.get(0).expect("index should be in bounds");
        assert_eq!(image.shape().dims(), &[3, 32, 32]);
        assert_eq!(label.shape().dims(), &[] as &[usize]);

        let info = dataset.info();
        assert_eq!(info.name, "Synthetic CIFAR-10");
        assert_eq!(info.num_classes, 10);
        assert_eq!(info.num_channels, 3);
    }

    /// The Iris generator yields the requested sample count, shapes and name lists.
    #[test]
    fn test_synthetic_iris() {
        let dataset = BenchmarkDatasets::synthetic_iris(150, Some(42))
            .expect("test: operation should succeed");
        assert_eq!(dataset.len(), 150);

        let (features, label) = dataset.get(0).expect("index should be in bounds");
        assert_eq!(features.shape().dims(), &[4]);
        assert_eq!(label.shape().dims(), &[] as &[usize]);

        let info = dataset.info();
        assert_eq!(info.name, "Synthetic Iris");
        assert_eq!(info.num_classes, 3);

        let feature_names = dataset.feature_names();
        assert_eq!(feature_names.len(), 4);
        assert!(feature_names.contains(&"sepal_length".to_string()));

        let class_names = dataset.class_names();
        assert_eq!(class_names.len(), 3);
        assert!(class_names.contains(&"setosa".to_string()));
    }

    /// Two datasets built from the same seed must contain identical samples.
    #[test]
    fn test_dataset_reproducibility() {
        let dataset1 = BenchmarkDatasets::synthetic_mnist(10, Some(123))
            .expect("test: operation should succeed");
        let dataset2 = BenchmarkDatasets::synthetic_mnist(10, Some(123))
            .expect("test: operation should succeed");

        let (img1, label1) = dataset1.get(0).expect("index should be in bounds");
        let (img2, label2) = dataset2.get(0).expect("index should be in bounds");
        assert_eq!(img1.shape(), img2.shape());
        assert_eq!(label1.shape(), label2.shape());

        // Compare every value, not just the first one, so a divergence anywhere
        // in the sample is caught. The comparison is skipped when the tensor
        // does not expose its data as a slice.
        if let (Some(data1), Some(data2)) = (img1.as_slice(), img2.as_slice()) {
            assert_eq!(data1, data2);
        }
        if let (Some(data1), Some(data2)) = (label1.as_slice(), label2.as_slice()) {
            assert_eq!(data1, data2);
        }
    }

    /// Pattern helpers stay inside the `[0, 1]` intensity range.
    #[test]
    fn test_pattern_generation() {
        let circle = create_circle_pattern(14, 14, 28);
        assert!((0.0..=1.0).contains(&circle));

        let line = create_vertical_line_pattern(10, 14, 28);
        assert!((0.0..=1.0).contains(&line));

        let (r, g, b) = create_rgb_gradient(16, 16, 32);
        assert!((0.0..=1.0).contains(&r));
        assert!((0.0..=1.0).contains(&g));
        assert!((0.0..=1.0).contains(&b));
    }
}