use scirs2_core::ndarray_ext::stats::mean;
use std::collections::HashMap;
use tempfile::tempdir;
use scirs2_core::array_protocol::{
self,
distributed_training::{
DistributedStrategy, DistributedTrainingConfig, DistributedTrainingFactory,
},
grad::Adam,
ml_ops::ActivationFunc,
neural::{Conv2D, Linear, MaxPool2D, Sequential},
serialization::{load_checkpoint, save_checkpoint, ModelSerializer, OnnxExporter},
training::Dataset,
training::{CrossEntropyLoss, DataLoader, InMemoryDataset, Trainer},
};
use scirs2_core::ndarray_ext::Array2;
/// Walks through a full (mock) workflow: model construction, distributed
/// training setup, serialization, checkpointing, and ONNX export.
#[allow(dead_code)]
fn main() {
    array_protocol::init();

    println!("Distributed Training and Model Serialization Example");
    println!("==================================================");

    // Part 1: a small CNN plus a synthetic train/validation split.
    println!("\nPart 1: Creating a Model and Dataset");
    println!("----------------------------------");
    let primary_model = create_model();
    println!("Created model with {} layers", primary_model.layers().len());
    let (train_split, val_split) = create_dataset();
    println!(
        "Created dataset with {} training samples and {} validation samples",
        Dataset::len(&train_split),
        Dataset::len(&val_split)
    );

    // Part 2: wrap the datasets for data-parallel training across workers.
    println!("\nPart 2: Distributed Training Setup");
    println!("-------------------------------");
    let distributed_cfg = DistributedTrainingConfig {
        strategy: DistributedStrategy::DataParallel,
        numworkers: 2,
        rank: 0,
        is_master: true,
        syncinterval: 1,
        backend: "threaded".to_string(),
        mixed_precision: false,
        gradient_accumulation_steps: 1,
    };
    println!(
        "Created distributed training config with {} workers using {} strategy",
        distributed_cfg.numworkers, distributed_cfg.strategy
    );
    let sharded_train =
        DistributedTrainingFactory::create_dataset(Box::new(train_split), &distributed_cfg);
    let sharded_val =
        DistributedTrainingFactory::create_dataset(Box::new(val_split), &distributed_cfg);
    println!(
        "Created distributed datasets with {} training samples and {} validation samples",
        sharded_train.len(),
        sharded_val.len()
    );
    let batch_size = 32;
    // Training loader shuffles with a fixed seed; validation loader does not shuffle.
    let train_loader = DataLoader::new(sharded_train, batch_size, true, Some(42));
    let val_loader = DataLoader::new(sharded_val, batch_size, false, None);
    println!("Created data loaders with batch size {batch_size}");

    // Part 3: optimizer, loss, and a trainer wrapped for distributed use.
    println!("\nPart 3: Training Setup");
    println!("--------------------");
    // Kept alive separately so it can be handed to the serializer below.
    let optimizer = Box::new(Adam::new(0.001, Some(0.9), Some(0.999), Some(1e-8)));
    let loss_function = Box::new(CrossEntropyLoss::new(Some("mean")));
    let trainer = Trainer::new(
        create_model(),
        Box::new(Adam::new(0.001, Some(0.9), Some(0.999), Some(1e-8))),
        loss_function,
    );
    println!("Created trainer with Adam optimizer and CrossEntropyLoss");
    let dist_trainer = DistributedTrainingFactory::create_trainer(trainer, distributed_cfg.clone());
    println!("Created distributed trainer");

    // Part 4: save the model to a temp directory, then read it back.
    println!("\nPart 4: Model Serialization");
    println!("-------------------------");
    let scratch_dir = tempdir().expect("Operation failed");
    let model_dir = scratch_dir.path().join("models");
    println!("Created model directory at: {}", model_dir.display());
    let serializer = ModelSerializer::new(&model_dir);
    match serializer.save_model(&primary_model, "example_model", "v1.0", Some(optimizer.as_ref())) {
        Ok(path) => println!("Saved model to: {}", path.display()),
        Err(e) => println!("Error saving model: {e}"),
    }
    match serializer.loadmodel("example_model", "v1.0") {
        Ok((model, optimizer)) => {
            println!("Loaded model with {} layers", model.layers().len());
            println!(
                "Loaded optimizer: {}",
                if optimizer.is_some() { "yes" } else { "no" }
            );
        }
        Err(e) => println!("Error loading model: {e}"),
    }

    // Part 5: round-trip a training checkpoint (model + optimizer + metrics).
    println!("\nPart 5: Checkpoint Management");
    println!("---------------------------");
    let mut metrics = HashMap::new();
    metrics.insert("loss".to_string(), 0.5);
    metrics.insert("accuracy".to_string(), 0.85);
    let ckpt_path = model_dir.join("checkpoint");
    match save_checkpoint(
        &primary_model,
        optimizer.as_ref(),
        &ckpt_path,
        10,
        metrics.clone(),
    ) {
        Ok(()) => println!("Saved checkpoint at epoch 10"),
        Err(e) => println!("Error saving checkpoint: {e}"),
    }
    match load_checkpoint(&ckpt_path) {
        Ok((model, optimizer, epoch, metrics)) => {
            println!("Loaded checkpoint from epoch {epoch}");
            println!("Loaded model with {} layers", model.layers().len());
            println!(
                "Metrics: loss = {}, accuracy = {}",
                metrics.get("loss").unwrap_or(&0.0),
                metrics.get("accuracy").unwrap_or(&0.0)
            );
        }
        Err(e) => println!("Error loading checkpoint: {e}"),
    }

    // Part 6: export to ONNX with an explicit input shape.
    println!("\nPart 6: ONNX Export");
    println!("-----------------");
    let onnx_file = model_dir.join("model.onnx");
    match OnnxExporter.export(&primary_model, &onnx_file, &[1, 3, 224, 224]) {
        Ok(()) => println!("Exported model to ONNX format at: {}", onnx_file.display()),
        Err(e) => println!("Error exporting model to ONNX: {e}"),
    }

    println!("\nDistributed Training and Model Serialization Example completed successfully!");
}
/// Builds the demo CNN: two conv + max-pool stages followed by two
/// fully connected layers producing 10 class scores.
#[allow(dead_code)]
fn create_model() -> Sequential {
    let mut net = Sequential::new("SimpleModel", Vec::new());

    // Stage 1: 3 -> 16 channels, ReLU, then 2x2 pooling.
    net.add_layer(Box::new(Conv2D::withshape(
        "conv1",
        3,
        3,
        3,
        16,
        (1, 1),
        (1, 1),
        true,
        Some(ActivationFunc::ReLU),
    )));
    net.add_layer(Box::new(MaxPool2D::new("pool1", (2, 2), None, (0, 0))));

    // Stage 2: 16 -> 32 channels, ReLU, then 2x2 pooling.
    net.add_layer(Box::new(Conv2D::withshape(
        "conv2",
        3,
        3,
        16,
        32,
        (1, 1),
        (1, 1),
        true,
        Some(ActivationFunc::ReLU),
    )));
    net.add_layer(Box::new(MaxPool2D::new("pool2", (2, 2), None, (0, 0))));

    // Classifier head.
    // NOTE(review): fc1 assumes a flattened 32 x 6 x 6 feature map, but
    // create_dataset() produces 3 * 28 * 28 inputs, which after two 2x2 pools
    // would give 7 x 7 spatial dims — confirm the intended input size.
    net.add_layer(Box::new(Linear::new_random(
        "fc1",
        32 * 6 * 6,
        128,
        true,
        Some(ActivationFunc::ReLU),
    )));
    net.add_layer(Box::new(Linear::new_random("fc_out", 128, 10, true, None)));

    net
}
/// Generates a random 10-class dataset (1000 samples of 3*28*28 features)
/// and splits it 80/20 into training and validation `InMemoryDataset`s.
#[allow(dead_code)]
fn create_dataset() -> (InMemoryDataset, InMemoryDataset) {
    let num_samples = 1000;
    let num_features = 3 * 28 * 28;
    let num_classes = 10;

    // Features drawn uniformly from [-1, 1).
    let features = Array2::<f64>::from_shape_fn((num_samples, num_features), |_| {
        scirs2_core::random::random::<f64>() * 2.0 - 1.0
    });

    // One-hot labels, one uniformly random class per sample.
    let mut labels = Array2::<f64>::zeros((num_samples, num_classes));
    for row in 0..num_samples {
        let class = (scirs2_core::random::random::<f64>() * num_classes as f64).floor() as usize;
        labels[[row, class]] = 1.0;
    }

    // First 80% of rows train, the remainder validate.
    let split_at = (num_samples as f64 * 0.8).floor() as usize;
    let train = InMemoryDataset::from_arrays(
        features
            .slice(scirs2_core::ndarray::s![..split_at, ..])
            .to_owned(),
        labels
            .slice(scirs2_core::ndarray::s![..split_at, ..])
            .to_owned(),
    );
    let validation = InMemoryDataset::from_arrays(
        features
            .slice(scirs2_core::ndarray::s![split_at.., ..])
            .to_owned(),
        labels
            .slice(scirs2_core::ndarray::s![split_at.., ..])
            .to_owned(),
    );

    (train, validation)
}