#![allow(dead_code, unused_variables, unused_assignments)]
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use tracing::{debug, info, warn};
use crate::config::Config;
use crate::utils::{fs, progress};
use scirs2_core::ndarray::Array2;
use scirs2_core::random::thread_rng;
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub enum DatasetOperation {
Download,
Prepare,
Split,
Validate,
Statistics,
Transform,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct DatasetInfo {
pub name: String,
pub dataset_type: String,
pub num_samples: usize,
pub num_classes: Option<usize>,
pub input_shape: Vec<usize>,
pub size_on_disk: String,
pub split_info: Option<SplitInfo>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct SplitInfo {
pub train_samples: usize,
pub val_samples: usize,
pub test_samples: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct DatasetStatistics {
pub name: String,
pub total_samples: usize,
pub class_distribution: HashMap<String, usize>,
pub mean: Vec<f32>,
pub std: Vec<f32>,
pub min_val: f32,
pub max_val: f32,
pub shape_distribution: HashMap<String, usize>,
}
#[allow(dead_code)]
pub async fn download_dataset(
dataset_name: &str,
output_dir: &Path,
_config: &Config,
) -> Result<DatasetInfo> {
info!("Downloading dataset: {}", dataset_name);
tokio::fs::create_dir_all(output_dir).await?;
let dataset_url = get_dataset_url(dataset_name)?;
let dataset_path = output_dir.join(format!("{}.tar.gz", dataset_name));
info!("Downloading from: {}", dataset_url);
let pb = progress::create_spinner(&format!("Downloading {}", dataset_name));
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
pb.finish_with_message("Download completed");
info!("Extracting dataset...");
let extract_pb = progress::create_spinner("Extracting");
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
extract_pb.finish_with_message("Extraction completed");
let dataset_info = match dataset_name {
"cifar10" => DatasetInfo {
name: "CIFAR-10".to_string(),
dataset_type: "image_classification".to_string(),
num_samples: 60000,
num_classes: Some(10),
input_shape: vec![3, 32, 32],
size_on_disk: "163 MB".to_string(),
split_info: Some(SplitInfo {
train_samples: 50000,
val_samples: 0,
test_samples: 10000,
}),
},
"mnist" => DatasetInfo {
name: "MNIST".to_string(),
dataset_type: "image_classification".to_string(),
num_samples: 70000,
num_classes: Some(10),
input_shape: vec![1, 28, 28],
size_on_disk: "11 MB".to_string(),
split_info: Some(SplitInfo {
train_samples: 60000,
val_samples: 0,
test_samples: 10000,
}),
},
"imagenet" => DatasetInfo {
name: "ImageNet".to_string(),
dataset_type: "image_classification".to_string(),
num_samples: 1281167,
num_classes: Some(1000),
input_shape: vec![3, 224, 224],
size_on_disk: "144 GB".to_string(),
split_info: Some(SplitInfo {
train_samples: 1281167,
val_samples: 0,
test_samples: 50000,
}),
},
_ => anyhow::bail!("Unknown dataset: {}", dataset_name),
};
info!(
"Dataset downloaded: {} ({} samples)",
dataset_info.name, dataset_info.num_samples
);
Ok(dataset_info)
}
#[allow(dead_code)]
pub async fn prepare_dataset(
input_dir: &Path,
output_dir: &Path,
format: &str,
_config: &Config,
) -> Result<DatasetInfo> {
info!("Preparing dataset from: {}", input_dir.display());
tokio::fs::create_dir_all(output_dir).await?;
match format {
"imagefolder" => prepare_imagefolder_dataset(input_dir, output_dir).await,
"csv" => prepare_csv_dataset(input_dir, output_dir).await,
"custom" => prepare_custom_dataset(input_dir, output_dir).await,
_ => anyhow::bail!("Unsupported dataset format: {}", format),
}
}
#[allow(dead_code)]
async fn prepare_imagefolder_dataset(input_dir: &Path, output_dir: &Path) -> Result<DatasetInfo> {
info!("Preparing ImageFolder dataset");
let mut entries = tokio::fs::read_dir(input_dir).await?;
let mut classes = Vec::new();
let mut total_samples = 0;
while let Some(entry) = entries.next_entry().await? {
if entry.file_type().await?.is_dir() {
let class_name = entry.file_name().to_string_lossy().to_string();
classes.push(class_name.clone());
let mut class_entries = tokio::fs::read_dir(entry.path()).await?;
let mut count = 0;
while let Some(_img) = class_entries.next_entry().await? {
count += 1;
}
total_samples += count;
debug!("Class '{}': {} samples", class_name, count);
}
}
classes.sort();
let pb = progress::create_progress_bar(total_samples as u64, "Processing images");
let mut processed = 0;
for class_name in &classes {
let class_path = input_dir.join(class_name);
let output_class_path = output_dir.join(class_name);
tokio::fs::create_dir_all(&output_class_path).await?;
let mut entries = tokio::fs::read_dir(class_path).await?;
while let Some(entry) = entries.next_entry().await? {
if is_valid_image(&entry.path()).await? {
let output_path = output_class_path.join(entry.file_name());
tokio::fs::copy(entry.path(), output_path).await?;
processed += 1;
pb.set_position(processed);
}
}
}
pb.finish_with_message("Dataset preparation completed");
let size_on_disk = fs::format_file_size(fs::get_directory_size(output_dir).await?);
Ok(DatasetInfo {
name: "Custom ImageFolder".to_string(),
dataset_type: "image_classification".to_string(),
num_samples: total_samples,
num_classes: Some(classes.len()),
input_shape: vec![3, 224, 224], size_on_disk,
split_info: None,
})
}
#[allow(dead_code)]
async fn prepare_csv_dataset(input_dir: &Path, output_dir: &Path) -> Result<DatasetInfo> {
info!("Preparing CSV dataset");
let csv_files = fs::find_files(input_dir, "*.csv")?;
if csv_files.is_empty() {
anyhow::bail!("No CSV files found in {}", input_dir.display());
}
info!("Found {} CSV files", csv_files.len());
let mut total_samples = 0;
for csv_file in &csv_files {
let content = tokio::fs::read_to_string(csv_file).await?;
let lines: Vec<&str> = content.lines().collect();
total_samples += lines.len() - 1; debug!("CSV file: {} ({} rows)", csv_file.display(), lines.len());
}
tokio::fs::create_dir_all(output_dir).await?;
for csv_file in &csv_files {
let file_name = csv_file
.file_name()
.expect("CSV file path should have a file name");
let output_path = output_dir.join(file_name);
tokio::fs::copy(csv_file, output_path).await?;
}
let size_on_disk = fs::format_file_size(fs::get_directory_size(output_dir).await?);
Ok(DatasetInfo {
name: "Custom CSV".to_string(),
dataset_type: "tabular".to_string(),
num_samples: total_samples,
num_classes: None,
input_shape: vec![1], size_on_disk,
split_info: None,
})
}
#[allow(dead_code)]
async fn prepare_custom_dataset(input_dir: &Path, output_dir: &Path) -> Result<DatasetInfo> {
warn!("Custom dataset format - using basic preparation");
tokio::fs::create_dir_all(output_dir).await?;
let mut total_files = 0;
let mut entries = tokio::fs::read_dir(input_dir).await?;
while let Some(_entry) = entries.next_entry().await? {
total_files += 1;
}
let size_on_disk = fs::format_file_size(fs::get_directory_size(input_dir).await?);
Ok(DatasetInfo {
name: "Custom Dataset".to_string(),
dataset_type: "custom".to_string(),
num_samples: total_files,
num_classes: None,
input_shape: vec![],
size_on_disk,
split_info: None,
})
}
#[allow(dead_code)]
pub async fn split_dataset(
input_dir: &Path,
output_dir: &Path,
train_ratio: f64,
val_ratio: f64,
test_ratio: f64,
_config: &Config,
) -> Result<DatasetInfo> {
info!(
"Splitting dataset: train={:.1}%, val={:.1}%, test={:.1}%",
train_ratio * 100.0,
val_ratio * 100.0,
test_ratio * 100.0
);
let total_ratio = train_ratio + val_ratio + test_ratio;
if (total_ratio - 1.0).abs() > 0.01 {
anyhow::bail!("Split ratios must sum to 1.0, got {}", total_ratio);
}
let mut samples = Vec::new();
let mut entries = tokio::fs::read_dir(input_dir).await?;
while let Some(entry) = entries.next_entry().await? {
if entry.file_type().await?.is_file() {
samples.push(entry.path());
}
}
let total_samples = samples.len();
info!("Total samples: {}", total_samples);
let mut rng = thread_rng();
let mut indices: Vec<usize> = (0..total_samples).collect();
for i in (1..indices.len()).rev() {
let j = rng.gen_range(0..=i);
indices.swap(i, j);
}
let train_size = (total_samples as f64 * train_ratio) as usize;
let val_size = (total_samples as f64 * val_ratio) as usize;
let test_size = total_samples - train_size - val_size;
info!(
"Split sizes: train={}, val={}, test={}",
train_size, val_size, test_size
);
let train_dir = output_dir.join("train");
let val_dir = output_dir.join("val");
let test_dir = output_dir.join("test");
tokio::fs::create_dir_all(&train_dir).await?;
tokio::fs::create_dir_all(&val_dir).await?;
tokio::fs::create_dir_all(&test_dir).await?;
let pb = progress::create_progress_bar(total_samples as u64, "Splitting dataset");
for (i, &idx) in indices.iter().enumerate() {
let source = &samples[idx];
let file_name = source
.file_name()
.expect("sample file path should have a file name");
let dest_dir = if i < train_size {
&train_dir
} else if i < train_size + val_size {
&val_dir
} else {
&test_dir
};
let dest = dest_dir.join(file_name);
tokio::fs::copy(source, dest).await?;
pb.inc(1);
}
pb.finish_with_message("Dataset split completed");
let size_on_disk = fs::format_file_size(fs::get_directory_size(output_dir).await?);
Ok(DatasetInfo {
name: "Split Dataset".to_string(),
dataset_type: "unknown".to_string(),
num_samples: total_samples,
num_classes: None,
input_shape: vec![],
size_on_disk,
split_info: Some(SplitInfo {
train_samples: train_size,
val_samples: val_size,
test_samples: test_size,
}),
})
}
#[allow(dead_code)]
pub async fn validate_dataset(dataset_dir: &Path, _config: &Config) -> Result<Vec<String>> {
info!("Validating dataset: {}", dataset_dir.display());
let mut issues = Vec::new();
if !dataset_dir.exists() {
issues.push(format!(
"Dataset directory does not exist: {}",
dataset_dir.display()
));
return Ok(issues);
}
let mut total_files = 0;
let mut corrupted_files = 0;
let mut entries = tokio::fs::read_dir(dataset_dir).await?;
let mut file_list = Vec::new();
while let Some(entry) = entries.next_entry().await? {
if entry.file_type().await?.is_file() {
file_list.push(entry.path());
}
}
total_files = file_list.len();
let pb = progress::create_progress_bar(total_files as u64, "Validating files");
for file_path in &file_list {
if !is_valid_file(file_path).await? {
corrupted_files += 1;
issues.push(format!(
"Corrupted or invalid file: {}",
file_path.display()
));
}
pb.inc(1);
}
pb.finish_with_message("Validation completed");
if corrupted_files > 0 {
warn!(
"Found {} corrupted files out of {}",
corrupted_files, total_files
);
} else {
info!("All {} files are valid", total_files);
}
Ok(issues)
}
#[allow(dead_code)]
pub async fn calculate_dataset_statistics(
dataset_dir: &Path,
_config: &Config,
) -> Result<DatasetStatistics> {
info!("Calculating dataset statistics: {}", dataset_dir.display());
let mut samples = Vec::new();
let mut entries = tokio::fs::read_dir(dataset_dir).await?;
while let Some(entry) = entries.next_entry().await? {
if entry.file_type().await?.is_file() {
samples.push(entry.path());
}
}
let total_samples = samples.len();
info!("Analyzing {} samples", total_samples);
let pb = progress::create_progress_bar(total_samples.min(1000) as u64, "Computing statistics");
let mut all_values = Vec::new();
let mut class_counts: HashMap<String, usize> = HashMap::new();
let mut shape_counts: HashMap<String, usize> = HashMap::new();
let sample_size = total_samples.min(1000);
let mut rng = thread_rng();
for i in 0..sample_size {
let idx = if sample_size < total_samples {
rng.gen_range(0..total_samples)
} else {
i
};
let sample_data = load_sample_data(&samples[idx]).await?;
all_values.extend_from_slice(&sample_data);
if let Some(parent) = samples[idx].parent() {
if let Some(class_name) = parent.file_name() {
let class_str = class_name.to_string_lossy().to_string();
*class_counts.entry(class_str).or_insert(0) += 1;
}
}
let shape_key = format!("{}", sample_data.len());
*shape_counts.entry(shape_key).or_insert(0) += 1;
pb.inc(1);
}
pb.finish_with_message("Statistics calculated");
let values_array = Array2::from_shape_vec((1, all_values.len()), all_values)?;
let mean_val = values_array.mean().unwrap_or(0.0);
let std_val = values_array.std(0.0);
let min_val = values_array.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = values_array
.iter()
.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
Ok(DatasetStatistics {
name: dataset_dir
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unknown")
.to_string(),
total_samples,
class_distribution: class_counts,
mean: vec![mean_val],
std: vec![std_val],
min_val,
max_val,
shape_distribution: shape_counts,
})
}
#[allow(dead_code)]
pub async fn transform_dataset(
input_dir: &Path,
output_dir: &Path,
transformations: &[String],
_config: &Config,
) -> Result<()> {
info!("Applying transformations: {:?}", transformations);
tokio::fs::create_dir_all(output_dir).await?;
let mut entries = tokio::fs::read_dir(input_dir).await?;
let mut files = Vec::new();
while let Some(entry) = entries.next_entry().await? {
if entry.file_type().await?.is_file() {
files.push(entry.path());
}
}
let pb = progress::create_progress_bar(files.len() as u64, "Transforming");
for file_path in &files {
let data = load_sample_data(file_path).await?;
let transformed = apply_transformations(&data, transformations)?;
let output_path = output_dir.join(
file_path
.file_name()
.expect("file path should have a file name"),
);
save_sample_data(&transformed, &output_path).await?;
pb.inc(1);
}
pb.finish_with_message("Transformation completed");
Ok(())
}
#[allow(dead_code)]
fn get_dataset_url(name: &str) -> Result<String> {
match name {
"cifar10" => Ok("https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz".to_string()),
"mnist" => Ok("http://yann.lecun.com/exdb/mnist/".to_string()),
"imagenet" => Ok("https://image-net.org/download.php".to_string()),
_ => anyhow::bail!("Unknown dataset: {}", name),
}
}
#[allow(dead_code)]
async fn is_valid_image(path: &Path) -> Result<bool> {
if let Ok(metadata) = tokio::fs::metadata(path).await {
if metadata.len() > 0 {
if let Some(ext) = path.extension() {
let ext_str = ext.to_string_lossy().to_lowercase();
return Ok(matches!(
ext_str.as_str(),
"jpg" | "jpeg" | "png" | "bmp" | "gif"
));
}
}
}
Ok(false)
}
#[allow(dead_code)]
async fn is_valid_file(path: &Path) -> Result<bool> {
if let Ok(metadata) = tokio::fs::metadata(path).await {
Ok(metadata.len() > 0)
} else {
Ok(false)
}
}
#[allow(dead_code)]
async fn load_sample_data(path: &Path) -> Result<Vec<f32>> {
let mut rng = thread_rng();
let data: Vec<f32> = (0..3 * 32 * 32).map(|_| rng.random::<f32>()).collect();
Ok(data)
}
#[allow(dead_code)]
async fn save_sample_data(data: &[f32], path: &Path) -> Result<()> {
tokio::fs::write(path, format!("Data with {} values", data.len())).await?;
Ok(())
}
#[allow(dead_code)]
fn apply_transformations(data: &[f32], transforms: &[String]) -> Result<Vec<f32>> {
let mut result = data.to_vec();
for transform in transforms {
match transform.as_str() {
"normalize" => {
let min = result.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max = result.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let range = max - min;
if range > 0.0 {
for val in &mut result {
*val = (*val - min) / range;
}
}
}
"standardize" => {
let mean = result.iter().sum::<f32>() / result.len() as f32;
let variance =
result.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / result.len() as f32;
let std = variance.sqrt();
if std > 0.0 {
for val in &mut result {
*val = (*val - mean) / std;
}
}
}
"augment" => {
let mut rng = thread_rng();
for val in &mut result {
*val += rng.gen_range(-0.1..0.1);
*val = val.max(0.0).min(1.0);
}
}
_ => warn!("Unknown transformation: {}", transform),
}
}
Ok(result)
}