use std::{cmp::min, time::Instant};
use dfdx::{
data::{ExactSizeDataset, IteratorBatchExt, IteratorCollateExt, IteratorStackExt},
nn::builders::*,
optim::Adam,
prelude::{mse_loss, Optimizer},
shapes::{Const, HasShape, Rank1, Rank2},
tensor::{
AsArray, AutoDevice, Gradients, OnesTensor, SampleTensor, Tensor, TensorFrom, Trace,
ZerosTensor,
},
tensor_ops::{Backward, BroadcastTo, ChooseFrom, RealizeTo, SelectTo, SumTo, TryGt},
};
use indicatif::ProgressIterator;
use rand::{rngs::StdRng, Rng, SeedableRng};
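/// A small MLP: 2 inputs -> 32 hidden -> 32 hidden -> 1 output, with ReLU
/// activations and a final Sigmoid squashing the output into [0, 1].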
type MlpStructure = (
(Linear<2, 32>, ReLU),
(Linear<32, 32>, ReLU),
(Linear<32, 1>, Sigmoid),
);
type Mlp = <MlpStructure as BuildOnDevice<AutoDevice, f32>>::Built;
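/// Bundles the model with its device, gradient buffer, and optimizer so that
/// all training state lives in one place.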
pub struct Predictor {
device: AutoDevice,
model: Mlp,
gradients: Gradients<f32, AutoDevice>,
optimizer: Adam<Mlp, f32, AutoDevice>,
}
impl Predictor {
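    /// Builds the MLP on the default device (seeded for reproducible weight
    /// initialization) and allocates its gradient buffer and Adam optimizer.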
pub fn new(seed: u64) -> Self {
let device = AutoDevice::seed_from_u64(seed);
let model = device.build_module::<MlpStructure, f32>();
let gradients = model.alloc_grads();
let optimizer: Adam<Mlp, f32, AutoDevice> = Adam::new(&model, Default::default());
Self {
device,
model,
gradients,
optimizer,
}
}
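    /// Runs a single point through the network by viewing it as a batch of
    /// size one and then selecting the lone row of the batched prediction.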
pub fn predict_single(
&self,
input: Tensor<Rank1<2>, f32, AutoDevice>,
) -> Tensor<Rank1<1>, f32, AutoDevice> {
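        // The model operates on batches, so broadcast the input to a 1x2
        // batch and erase the compile-time batch dimension to a runtime `usize`.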
let batched: Tensor<Rank2<1, 2>, _, _> = input.clone().broadcast();
let batched_realized: Tensor<(usize, Const<2>), _, _> = batched.try_realize().unwrap();
        assert_eq!(batched_realized.shape(), &(1usize, Const::<2>));
let batched_prediction = self.predict_batch(batched_realized);
        assert_eq!(batched_prediction.shape(), &(1usize, Const::<1>));
batched_prediction.select(self.device.tensor(0))
}
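    /// Runs a forward pass over a whole batch of points at once.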
pub fn predict_batch(
&self,
input: Tensor<(usize, Const<2>), f32, AutoDevice>,
) -> Tensor<(usize, Const<1>), f32, AutoDevice> {
self.model.forward(input)
}
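    /// Performs one training step on a batch: a traced forward pass, MSE loss
    /// against the expected outputs, backpropagation, and an Adam update.
    /// Returns the batch's loss for progress reporting.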
pub fn learn_batch(
&mut self,
input: Tensor<(usize, Const<2>), f32, AutoDevice>,
expected_output: Tensor<(usize, Const<1>), f32, AutoDevice>,
) -> f32 {
assert_eq!(input.shape().0, expected_output.shape().0);
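        // Trace the forward pass with a copy of the gradient buffer;
        // `backward` returns that buffer filled with this batch's gradients.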
let predictions = self
.model
.forward_mut(input.traced(self.gradients.to_owned()));
let loss = mse_loss(predictions, expected_output);
let batch_loss = loss.array();
self.gradients = loss.backward();
self.optimizer
.update(&mut self.model, &self.gradients)
.unwrap();
self.model.zero_grads(&mut self.gradients);
batch_loss
}
}
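/// The ground truth the network should learn: 1.0 for points outside the unit
/// circle centered at the origin, 0.0 for points inside it.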
fn function_we_would_like_the_nn_to_mimic(
input: Tensor<(Const<2>,), f32, AutoDevice>,
) -> Tensor<(Const<1>,), f32, AutoDevice> {
let dev = input.device().clone();
let distance_from_center: Tensor<(Const<1>,), f32, AutoDevice> =
input.powi(2).sum().sqrt().broadcast();
distance_from_center.gt(1.0).choose(dev.ones(), dev.zeros())
}
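/// A dataset of random 2D points, each labeled by the ground-truth function.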
struct XYPointsDataSet {
points_and_predictions: Vec<(
Tensor<Rank1<2>, f32, AutoDevice>,
Tensor<Rank1<1>, f32, AutoDevice>,
)>,
}
impl XYPointsDataSet {
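    /// Samples `size` points uniformly from the square [-2, 2] x [-2, 2] and
    /// labels each one with the ground-truth classifier.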
fn new(size: usize, seed: u64) -> Self {
let device = AutoDevice::seed_from_u64(seed);
        let points_and_predictions = (0..size)
            .map(|_| {
.map(|_| {
let point: Tensor<Rank1<2>, f32, AutoDevice> =
device.sample(rand_distr::Uniform::new(-2.0, 2.0));
let class = function_we_would_like_the_nn_to_mimic(point.to_owned());
(point, class)
})
.collect();
Self {
points_and_predictions,
}
}
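    /// Measures `predictor`'s mean squared error over the whole dataset,
    /// evaluated as a single full-size batch.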
    fn get_loss_of_predictor(&self, predictor: &Predictor) -> f32 {
        let mut total_loss = 0.0;
        for (points, classifications) in self
            .iter()
            .batch_exact(self.len())
            .collate()
            .stack()
            .progress()
        {
            let predictions = predictor.predict_batch(points);
            let loss = mse_loss(predictions, classifications);
            total_loss += loss.array();
        }
        total_loss
    }
}
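// Implementing `ExactSizeDataset` provides the `iter` and `shuffled` adapters
// that the training and evaluation loops build their batch pipelines on.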
impl ExactSizeDataset for XYPointsDataSet {
    type Item<'a> = (
        Tensor<Rank1<2>, f32, AutoDevice>,
        Tensor<Rank1<1>, f32, AutoDevice>,
    )
    where
        Self: 'a;
fn get(&self, index: usize) -> Self::Item<'_> {
self.points_and_predictions[index].to_owned()
}
fn len(&self) -> usize {
self.points_and_predictions.len()
}
}
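/// Trains the MLP to mimic the circle classifier, printing the average
/// training loss and the test loss after every epoch.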
fn main() {
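    // A single seeded RNG derives the dataset and model seeds, keeping the
    // whole run reproducible.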
let mut rng = StdRng::seed_from_u64(43);
let train_set = XYPointsDataSet::new(800, rng.gen());
let test_set = XYPointsDataSet::new(200, rng.gen());
let mut predictor = Predictor::new(rng.gen());
for epoch_number in 1..100 {
let mut num_batches = 0;
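        // Grow the batch size with the epoch number (capped at 128): small,
        // noisy updates early in training, then larger and more stable ones.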
let batch_size = min(epoch_number, 128);
let mut total_epoch_loss = 0.0;
let start = Instant::now();
for (points, classifications) in train_set
.shuffled(&mut rng)
.batch_exact(batch_size)
.collate()
.stack()
.progress()
{
num_batches += 1;
total_epoch_loss += predictor.learn_batch(points, classifications);
}
let duration = start.elapsed().as_secs_f32();
        println!(
            "Epoch {epoch_number} in {:.2} seconds ({:.3} batches/s): avg sample loss {:.5}, test loss = {:.3}",
            duration,
            (num_batches as f32) / duration,
            // `mse_loss` already averages over the batch, so the mean of the
            // per-batch losses is the average per-sample loss.
            total_epoch_loss / num_batches as f32,
            test_set.get_loss_of_predictor(&predictor)
        );
}
}