use crate::models::sequential::Sequential;
use crate::models::high_level::{HighLevelModel, FitConfig, TrainingHistory};
use crate::nn::{Linear, Module};
use crate::autograd::Variable;
use crate::tensor::Tensor;
use crate::data::{DataLoader, TensorDataset};
use crate::data::sampler::{RandomSampler, SequentialSampler};
use num_traits::Float;
use std::fmt::Debug;
use anyhow::Result;
/// Builds a small multi-layer perceptron and prints information about it.
///
/// A `Sequential` model named `"simple_mlp"` is assembled from three
/// `Linear` layers created with the size pairs `(784, 128)`, `(128, 64)`
/// and `(64, 10)`. The model summary and total parameter count are then
/// written to standard output.
///
/// # Errors
///
/// No step in the body is fallible; the function always returns `Ok(())`.
pub fn simple_mlp_example<T>() -> Result<()>
where
    T: Float + Send + Sync + 'static + Debug + Clone,
{
    println!("=== Simple MLP Example ===");

    // Grow the network one layer at a time; each `add` hands back the model.
    let stack = Sequential::<T>::with_name("simple_mlp");
    let stack = stack.add(Linear::new(784, 128));
    let stack = stack.add(Linear::new(128, 64));
    let mut mlp = stack.add(Linear::new(64, 10));

    let overview = mlp.summary();
    println!("{}", overview);
    println!("Model created successfully!");

    let parameter_count = mlp.total_parameters();
    println!("Total parameters: {}", parameter_count);

    Ok(())
}
/// Prints the structure of a stand-in "CNN" model.
///
/// The model is conceptual only: it is named `"cnn_model"` but is built
/// from three `Linear` layers used as placeholders, as the final message
/// printed by this function states.
///
/// # Errors
///
/// No step in the body is fallible; the function always returns `Ok(())`.
pub fn cnn_example<T>() -> Result<()>
where
    T: Float + Send + Sync + 'static + Debug + Clone,
{
    println!("=== CNN Example (Conceptual) ===");

    // Build the whole placeholder stack in a single builder chain.
    let network = Sequential::<T>::with_name("cnn_model")
        .add(Linear::new(28 * 28, 128))
        .add(Linear::new(128, 64))
        .add(Linear::new(64, 10));

    println!("{}", network.summary());
    println!("CNN model structure created (using Linear layers as placeholder)");

    Ok(())
}
/// Builds an encoder/decoder pair of `Sequential` models and prints both.
///
/// The encoder chains `Linear` layers with the size pairs
/// `(784, 256)`, `(256, 128)`, `(128, 64)`, `(64, 32)`; the decoder uses
/// the mirrored pairs `(32, 64)`, `(64, 128)`, `(128, 256)`, `(256, 784)`.
/// The two models are only constructed and summarized; they are not
/// connected to each other or run.
///
/// # Errors
///
/// No step in the body is fallible; the function always returns `Ok(())`.
pub fn autoencoder_example<T>() -> Result<()>
where
    T: Float + Send + Sync + 'static + Debug + Clone,
{
    println!("=== Autoencoder Example ===");

    // Neither model is mutated after construction, so the bindings are
    // immutable (the previous `mut` only produced `unused_mut` warnings).
    let encoder = Sequential::<T>::with_name("encoder")
        .add(Linear::new(784, 256))
        .add(Linear::new(256, 128))
        .add(Linear::new(128, 64))
        .add(Linear::new(64, 32));

    let decoder = Sequential::<T>::with_name("decoder")
        .add(Linear::new(32, 64))
        .add(Linear::new(64, 128))
        .add(Linear::new(128, 256))
        .add(Linear::new(256, 784));

    println!("Encoder:");
    println!("{}", encoder.summary());
    println!("\nDecoder:");
    println!("{}", decoder.summary());
    println!("Autoencoder models created successfully!");

    Ok(())
}
/// Prepares everything a training run would need, without training.
///
/// The function builds a three-layer `Sequential` model, creates dummy
/// training (1000 samples) and validation (200 samples) datasets with
/// 10 input values per sample, wraps each in a `DataLoader` with batch
/// size 32, and builds a `FitConfig`. It then prints the configured
/// epochs, batch size and early-stopping patience. No fitting is performed.
///
/// The `From<f32>` bound is required because this function calls
/// `create_dummy_dataset::<T>`, which declares that bound.
///
/// # Errors
///
/// Propagates any error returned by `create_dummy_dataset`.
pub fn training_example<T>() -> Result<()>
where
    T: Float + Send + Sync + 'static + Debug + Clone + From<f32>,
{
    println!("=== Training Process Example ===");

    // The model is only summarized here, so it does not need to be mutable.
    let model = Sequential::<T>::with_name("training_demo")
        .add(Linear::new(10, 64))
        .add(Linear::new(64, 32))
        .add(Linear::new(32, 1));
    println!("Model created:");
    println!("{}", model.summary());

    let train_data = create_dummy_dataset::<T>(1000, 10)?;
    let val_data = create_dummy_dataset::<T>(200, 10)?;

    // The loaders are built to show the setup but are never consumed, hence
    // the underscore prefix. Training data is sampled randomly, validation
    // data in order.
    // NOTE(review): `phase5_dataset_example` imports `crate::data::Dataset`
    // before calling `len()` on a dataset — confirm `len()` resolves here.
    let train_sampler = Box::new(RandomSampler::new(train_data.len()));
    let _train_loader = DataLoader::new(&train_data, train_sampler, 32);
    let val_sampler = Box::new(SequentialSampler::new(val_data.len()));
    let _val_loader = DataLoader::new(&val_data, val_sampler, 32);

    let config = FitConfig::new()
        .epochs(10)
        .batch_size(32)
        .verbose(true)
        .early_stopping(3);

    println!("Training configuration:");
    println!(" - Epochs: {}", config.epochs);
    println!(" - Batch size: {}", config.batch_size);
    println!(" - Early stopping patience: {:?}", config.patience);
    println!("Training example structure prepared!");

    Ok(())
}
/// Runs a single prediction and a batch prediction on an untrained model.
///
/// A three-layer `Sequential` model is built and summarized. One input of
/// shape `[1, 5]` is passed to `predict`, then a 100-sample dummy dataset
/// (5 values per sample) is fed through `predict_batch` in batches of 10
/// and the number of returned batches is printed.
///
/// The `From<f32>` bound is required because this function calls
/// `create_dummy_dataset::<T>`, which declares that bound; it is also used
/// to convert the literal input values into `T`.
///
/// # Errors
///
/// Propagates errors from `predict`, `create_dummy_dataset` and
/// `predict_batch`.
pub fn prediction_example<T>() -> Result<()>
where
    T: Float + Send + Sync + 'static + Debug + Clone + From<f32>,
{
    println!("=== Prediction and Evaluation Example ===");

    let model = Sequential::<T>::with_name("prediction_demo")
        .add(Linear::new(5, 32))
        .add(Linear::new(32, 16))
        .add(Linear::new(16, 3));
    println!("Model for prediction:");
    println!("{}", model.summary());

    // Convert the sample values into the model's element type. Bare float
    // literals would be typed as `f32`/`f64` rather than the generic `T`.
    // The fully qualified call avoids ambiguity with `NumCast::from`, which
    // is also reachable through the `Float` bound.
    let input_data: Vec<T> = [1.0_f32, 2.0, 3.0, 4.0, 5.0]
        .iter()
        .map(|&value| <T as From<f32>>::from(value))
        .collect();
    let input_tensor = Tensor::from_vec(input_data, vec![1, 5]);
    let input_var = Variable::new(input_tensor, false);

    // Only success matters for this demo; the output itself is not inspected.
    let _prediction = model.predict(&input_var)?;
    println!("Single prediction completed");

    // NOTE(review): `phase5_dataset_example` imports `crate::data::Dataset`
    // before calling `len()` on a dataset — confirm `len()` resolves here.
    let test_data = create_dummy_dataset::<T>(100, 5)?;
    let test_sampler = Box::new(SequentialSampler::new(test_data.len()));
    let mut test_loader = DataLoader::new(&test_data, test_sampler, 10);
    let predictions = model.predict_batch(&mut test_loader)?;
    println!("Batch prediction completed: {} batches", predictions.len());

    println!("Prediction and evaluation examples completed!");
    Ok(())
}
/// Saves a model to disk and loads it back into a fresh `Sequential`.
///
/// A three-layer model is built and summarized, written to
/// `models/my_model.rustorch` (relative to the current working directory),
/// and then loaded into a newly created, empty `Sequential`. The parameter
/// count of the loaded model is printed.
///
/// # Errors
///
/// Returns an error if the parent directory of the model path cannot be
/// created, or if `save` or `load` fails.
pub fn model_persistence_example<T>() -> Result<()>
where
    T: Float + Send + Sync + 'static + Debug + Clone,
{
    println!("=== Model Persistence Example ===");

    let mut model = Sequential::<T>::with_name("persistence_demo")
        .add(Linear::new(20, 64))
        .add(Linear::new(64, 32))
        .add(Linear::new(32, 5));
    println!("Original model:");
    println!("{}", model.summary());

    let model_path = "models/my_model.rustorch";

    // Make sure the target directory exists so the save does not depend on
    // the state of the working directory. `create_dir_all` succeeds if the
    // directory is already present.
    if let Some(parent_dir) = std::path::Path::new(model_path).parent() {
        std::fs::create_dir_all(parent_dir)?;
    }

    model.save(model_path)?;
    println!("Model saved to: {}", model_path);

    // Start from an empty model and populate it from the saved file.
    let mut loaded_model = Sequential::<T>::new();
    loaded_model.load(model_path)?;
    println!("Model loaded from: {}", model_path);
    println!("Loaded model parameters: {}", loaded_model.total_parameters());

    Ok(())
}
/// Sketches a transfer-learning setup (conceptual only).
///
/// An empty `Sequential` named `"pretrained_model"` is created as a
/// stand-in for a pretrained network, and a separate two-layer
/// `"transfer_model"` is built and summarized. The printed configuration
/// reports a hard-coded `0` frozen layers and the length of the transfer
/// model as the number of trainable layers. No weights are loaded or
/// frozen.
///
/// # Errors
///
/// No step in the body is fallible; the function always returns `Ok(())`.
pub fn transfer_learning_example<T>() -> Result<()>
where
    T: Float + Send + Sync + 'static + Debug + Clone,
{
    println!("=== Transfer Learning Example (Conceptual) ===");

    // Placeholder for a pretrained network. It is constructed but never
    // used, so it is underscore-prefixed and immutable.
    let _pretrained_model = Sequential::<T>::with_name("pretrained_model");

    let mut transfer_model = Sequential::<T>::with_name("transfer_model")
        .add(Linear::new(512, 256))
        .add(Linear::new(256, 10));

    println!("Transfer learning model:");
    println!("{}", transfer_model.summary());

    println!("Transfer learning configuration:");
    println!(" - Frozen layers: {} (feature extraction)", 0);
    println!(
        " - Trainable layers: {} (classification head)",
        transfer_model.len()
    );

    Ok(())
}
/// Creates a deterministic synthetic dataset for the examples.
///
/// For each sample index `i` in `0..size`:
/// - the feature tensor has shape `[input_dim]` and holds
///   `(i + j) * 0.01` for `j` in `0..input_dim`;
/// - the target tensor has shape `[1]` and holds `i * 0.1`.
///
/// Values are computed in `f32` and converted to `T` through `From<f32>`.
///
/// # Errors
///
/// Returns whatever `TensorDataset::from_features_targets` returns for the
/// generated feature and target lists.
fn create_dummy_dataset<T>(size: usize, input_dim: usize) -> Result<TensorDataset<T>>
where
    T: Float + Send + Sync + 'static + Debug + Clone + From<f32>,
{
    // `T` is bounded by both `Float` (whose supertrait `NumCast` has a `from`
    // method) and `From<f32>`, so a bare `T::from(..)` is ambiguous. Name the
    // intended trait explicitly.
    let to_element = |value: f32| <T as From<f32>>::from(value);

    // One feature tensor and one target tensor per sample.
    let mut input_tensors = Vec::with_capacity(size);
    let mut target_tensors = Vec::with_capacity(size);

    for i in 0..size {
        let input_data: Vec<T> = (0..input_dim)
            .map(|j| to_element((i + j) as f32 * 0.01))
            .collect();
        input_tensors.push(Tensor::from_vec(input_data, vec![input_dim]));

        let target_data: Vec<T> = vec![to_element(i as f32 * 0.1)];
        target_tensors.push(Tensor::from_vec(target_data, vec![1]));
    }

    TensorDataset::from_features_targets(input_tensors, target_tensors)
}
/// Walks through the dataset / data-loader API on synthetic data.
///
/// A 100-sample dummy dataset (5 values per sample) is created and wrapped
/// in a `DataLoader` with a random sampler and batch size 10. At most five
/// batches are pulled from the loader, and the running batch number and
/// each batch's length are printed.
///
/// # Errors
///
/// Propagates any error returned by `create_dummy_dataset`.
pub fn phase5_dataset_example<T>() -> Result<()>
where
    T: Float + Send + Sync + 'static + Debug + Clone + From<f32>,
{
    use crate::data::Dataset;

    println!("=== Phase 5 Dataset API Example ===");

    let samples = create_dummy_dataset::<T>(100, 5)?;
    println!("Created dataset with {} samples", samples.len());

    let shuffled_order = Box::new(RandomSampler::new(samples.len()));
    let mut loader = DataLoader::new(&samples, shuffled_order, 10);
    println!("Created DataLoader with batch size: {}", loader.batch_size());

    // Pull at most five batches, stopping early if the loader runs dry.
    for batch_number in 1..=5 {
        match loader.next_batch() {
            Some(batch) => {
                println!("Processed batch {} with {} items", batch_number, batch.len());
            }
            None => break,
        }
    }

    println!("Phase 5 Dataset API example completed!");
    Ok(())
}
pub fn run_all_examples<T>() -> Result<()>
where
T: Float + Send + Sync + 'static + Debug + Clone + From<f32>,
{
println!("Running Sequential API Examples");
println!("================================\n");
simple_mlp_example::<T>()?;
println!();
cnn_example::<T>()?;
println!();
autoencoder_example::<T>()?;
println!();
training_example::<T>()?;
println!();
prediction_example::<T>()?;
println!();
model_persistence_example::<T>()?;
println!();
transfer_learning_example::<T>()?;
println!();
phase5_dataset_example::<T>()?;
println!();
println!("All examples completed successfully!");
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that an example finished with `Ok`.
    fn assert_example_ok(outcome: Result<()>) {
        assert!(outcome.is_ok());
    }

    /// The simple MLP example completes for `f32`.
    #[test]
    fn test_simple_mlp_example() {
        assert_example_ok(simple_mlp_example::<f32>());
    }

    /// The conceptual CNN example completes for `f32`.
    #[test]
    fn test_cnn_example() {
        assert_example_ok(cnn_example::<f32>());
    }

    /// The autoencoder example completes for `f32`.
    #[test]
    fn test_autoencoder_example() {
        assert_example_ok(autoencoder_example::<f32>());
    }

    /// The training-setup example completes for `f32`.
    #[test]
    fn test_training_example() {
        assert_example_ok(training_example::<f32>());
    }

    /// The prediction example completes for `f32`.
    #[test]
    fn test_prediction_example() {
        assert_example_ok(prediction_example::<f32>());
    }

    /// The save/load example completes for `f32`.
    #[test]
    fn test_model_persistence_example() {
        assert_example_ok(model_persistence_example::<f32>());
    }

    /// The transfer-learning example completes for `f32`.
    #[test]
    fn test_transfer_learning_example() {
        assert_example_ok(transfer_learning_example::<f32>());
    }

    /// The combined runner completes for `f32`.
    #[test]
    fn test_run_all_examples() {
        assert_example_ok(run_all_examples::<f32>());
    }

    /// The dataset / data-loader example completes for `f32`.
    #[test]
    fn test_phase5_dataset_example() {
        assert_example_ok(phase5_dataset_example::<f32>());
    }

    /// The dummy dataset builder succeeds and yields the requested size.
    #[test]
    fn test_create_dummy_dataset() {
        let built = create_dummy_dataset::<f32>(100, 5);
        assert!(built.is_ok());

        let samples = built.unwrap();
        assert_eq!(samples.len(), 100);
    }
}