use anyhow::Result;
use tch::nn::{Adam, ModuleT, OptimizerConfig, VarStore};
use tch::vision::dataset::Dataset;
use tch::TrainableCModule;
use tch::{CModule, Device};
fn train_and_save_model(dataset: &Dataset, device: Device) -> Result<()> {
let vs = VarStore::new(device);
let mut trainable = TrainableCModule::load("model.pt", vs.root())?;
trainable.set_train();
let initial_acc = trainable.batch_accuracy_for_logits(
&dataset.test_images,
&dataset.test_labels,
vs.device(),
1024,
);
println!("Initial accuracy: {:5.2}%", 100. * initial_acc);
let mut opt = Adam::default().build(&vs, 1e-4)?;
for epoch in 1..20 {
for (images, labels) in dataset.train_iter(128).shuffle().to_device(vs.device()).take(50) {
let loss = trainable.forward_t(&images, true).cross_entropy_for_logits(&labels);
opt.backward_step(&loss);
}
let test_accuracy = trainable.batch_accuracy_for_logits(
&dataset.test_images,
&dataset.test_labels,
vs.device(),
1024,
);
println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy,);
}
trainable.save("trained_model.pt")?;
Ok(())
}
fn load_trained_and_test_acc(dataset: &Dataset, device: Device) -> Result<()> {
let mut module = CModule::load_on_device("trained_model.pt", device)?;
module.set_eval();
let accuracy =
module.batch_accuracy_for_logits(&dataset.test_images, &dataset.test_labels, device, 1024);
println!("Updated accuracy: {:5.2}%", 100. * accuracy);
Ok(())
}
fn main() -> Result<()> {
let dataset = tch::vision::mnist::load_dir("data")?;
let dataset = Dataset {
train_images: dataset.train_images.view([-1, 1, 28, 28]),
test_images: dataset.test_images.view([-1, 1, 28, 28]),
..dataset
};
let device = Device::Cpu;
train_and_save_model(&dataset, device)?;
load_trained_and_test_acc(&dataset, device)?;
Ok(())
}