pub mod common;
pub mod download;
pub mod nlp;
pub mod vision;
pub use common::{
DatasetConfig, ProgressTracker, AG_NEWS_TEST_URL, AG_NEWS_TRAIN_URL, CIFAR10_URL,
IMAGENET_LABELS_URL, IMAGENET_VAL_URL, IMDB_URL, MNIST_TEST_IMAGES_URL, MNIST_TEST_LABELS_URL,
MNIST_TRAIN_IMAGES_URL, MNIST_TRAIN_LABELS_URL,
};
pub use download::{get_file_size, verify_checksum, Downloader};
pub use vision::{
Cifar10Config, ImageNetConfig, MnistConfig, RealCifar10Builder, RealCifar10Dataset,
RealImageNetBuilder, RealImageNetDataset, RealMnistBuilder, RealMnistDataset,
};
pub use nlp::{
AgNewsConfig, ImdbConfig, RealAgNewsBuilder, RealAgNewsDataset, RealImdbBuilder,
RealImdbDataset,
};
/// Smoke tests for the dataset types re-exported from this module.
///
/// Every builder test points its builder at a fresh, empty temporary directory
/// with downloading disabled and expects `build` to return an error. The tests
/// only check that the re-exported API is wired up and fails cleanly when no
/// data is present.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a fresh, empty temporary directory to use as a dataset root.
    ///
    /// The returned guard must be kept alive for as long as its path is used;
    /// the directory is removed when it is dropped.
    fn empty_root() -> TempDir {
        TempDir::new().expect("test: temp dir creation should succeed")
    }

    #[test]
    fn test_mnist_builder_integration() {
        let temp_dir = empty_root();
        let builder = RealMnistBuilder::new()
            .root(temp_dir.path())
            .train(true)
            .download(false);
        let result = builder.build::<f32>();
        assert!(
            result.is_err(),
            "MNIST build should fail on an empty root with download disabled"
        );
    }

    #[test]
    fn test_cifar10_builder_integration() {
        let temp_dir = empty_root();
        let builder = RealCifar10Builder::new()
            .root(temp_dir.path())
            .train(true)
            .download(false);
        let result = builder.build::<f32>();
        assert!(
            result.is_err(),
            "CIFAR-10 build should fail on an empty root with download disabled"
        );
    }

    #[test]
    fn test_imagenet_builder_integration() {
        let temp_dir = empty_root();
        let builder = RealImageNetBuilder::new()
            .root(temp_dir.path())
            .train(false)
            .download(false);
        let result = builder.build::<f32>();
        assert!(
            result.is_err(),
            "ImageNet build should fail on an empty root with download disabled"
        );
    }

    #[test]
    fn test_imdb_builder_integration() {
        let temp_dir = empty_root();
        let builder = RealImdbBuilder::new()
            .root(temp_dir.path())
            .train(true)
            .download(false);
        let result = builder.build();
        assert!(
            result.is_err(),
            "IMDB build should fail on an empty root with download disabled"
        );
    }

    #[test]
    fn test_ag_news_builder_integration() {
        let temp_dir = empty_root();
        let builder = RealAgNewsBuilder::new()
            .root(temp_dir.path())
            .train(true)
            .download(false);
        let result = builder.build();
        assert!(
            result.is_err(),
            "AG News build should fail on an empty root with download disabled"
        );
    }

    #[test]
    fn test_downloader_integration() {
        // Smoke test: constructing and dropping a `Downloader` must not panic.
        let downloader = Downloader::new();
        drop(downloader);
    }

    #[test]
    fn test_constants_available() {
        // A macro instead of an array of (name, value) pairs so the check does
        // not depend on the constants' concrete type, only on `is_empty()`,
        // and the failure message names the offending constant.
        macro_rules! assert_non_empty {
            ($($constant:ident),+ $(,)?) => {
                $(
                    assert!(
                        !$constant.is_empty(),
                        concat!(stringify!($constant), " must not be empty")
                    );
                )+
            };
        }

        assert_non_empty!(
            MNIST_TRAIN_IMAGES_URL,
            MNIST_TRAIN_LABELS_URL,
            MNIST_TEST_IMAGES_URL,
            MNIST_TEST_LABELS_URL,
            CIFAR10_URL,
            IMAGENET_VAL_URL,
            IMAGENET_LABELS_URL,
            IMDB_URL,
            AG_NEWS_TRAIN_URL,
            AG_NEWS_TEST_URL,
        );
    }

    #[test]
    fn test_config_defaults() {
        let mnist_config = MnistConfig::default();
        assert!(mnist_config.train);
        assert!(mnist_config.download);

        let cifar10_config = Cifar10Config::default();
        assert!(cifar10_config.train);
        assert!(cifar10_config.download);

        // Unlike the other datasets, ImageNet defaults to train = false and
        // download = false.
        let imagenet_config = ImageNetConfig::default();
        assert!(!imagenet_config.train);
        assert!(!imagenet_config.download);

        let imdb_config = ImdbConfig::default();
        assert!(imdb_config.train);
        assert!(imdb_config.download);

        let ag_news_config = AgNewsConfig::default();
        assert!(ag_news_config.train);
        assert!(ag_news_config.download);
    }

    #[test]
    fn test_class_names() {
        // Only datasets whose `class_names()` is exercised here are checked;
        // comparing hard-coded counts against themselves would test nothing.
        let cifar10_names = RealCifar10Dataset::<f32>::class_names();
        assert_eq!(cifar10_names.len(), 10);
        assert_eq!(cifar10_names[0], "airplane");

        let imdb_names = RealImdbDataset::class_names();
        assert_eq!(imdb_names.len(), 2);
        assert_eq!(imdb_names[0], "negative");
        assert_eq!(imdb_names[1], "positive");

        let ag_news_names = RealAgNewsDataset::class_names();
        assert_eq!(ag_news_names.len(), 4);
        assert_eq!(ag_news_names[0], "World");
        assert_eq!(ag_news_names[3], "Science/Tech");
    }
}