use burn_core as burn;
use burn::module::Module;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use burn_autodiff::Autodiff;
use burn_ndarray::NdArray;
use burn_nn::{Linear, LinearConfig, Relu};
use burn_optim::{GradientsParams, LearningRate, Optimizer, SgdConfig};
/// Backend used by [`train_euclidean_fm_sgd`]: the `NdArray<f32>` backend wrapped in
/// `Autodiff` so that `loss.backward()` can produce gradients.
pub type BurnBackend = Autodiff<NdArray<f32>>;
/// Two-layer MLP (`Linear -> ReLU -> Linear`) conditioned on `x_t`, `x1` and `t`.
///
/// [`train_euclidean_fm_sgd`] fits it by mean-squared error to the target `x1 - x0`
/// produced by [`euclidean_path_targets`].
#[derive(Module, Debug)]
pub struct BurnEuclideanCondMlp<B: Backend> {
    // Input layer: `2 * d + 2` concatenated features -> `hidden` units (see `new`).
    l1: Linear<B>,
    // Output layer: `hidden` units -> `d` outputs.
    l2: Linear<B>,
}
impl<B: Backend> BurnEuclideanCondMlp<B> {
    /// Builds the network on `device` for `d`-dimensional samples and `hidden` hidden units.
    ///
    /// The first layer expects `2 * d + 2` inputs: `x_t` (`d`), `x1` (`d`), `t` (1) and the
    /// constant all-ones column (1) that [`Self::forward`] appends.
    pub fn new(device: &B::Device, d: usize, hidden: usize) -> Self {
        let input_width = 2 * d + 2;
        // Struct-literal fields are evaluated in the order written, so `l1` is still
        // initialised before `l2`.
        Self {
            l1: LinearConfig::new(input_width, hidden).init(device),
            l2: LinearConfig::new(hidden, d).init(device),
        }
    }

    /// Runs the MLP on the concatenation `[x_t, x1, t, 1]` taken along dimension 1.
    ///
    /// The concatenated width must match the `2 * d + 2` inputs `l1` was built with.
    /// The training loop passes `x_t` and `x1` as `[batch, d]` and `t` as `[batch, 1]`.
    /// Returns the output of `l2` (width `d`).
    pub fn forward(&self, x_t: Tensor<B, 2>, x1: Tensor<B, 2>, t: Tensor<B, 2>) -> Tensor<B, 2> {
        let batch = t.dims()[0];
        let device = t.device();
        // Trailing all-ones column; it is one of the two "+ 2" inputs counted in `new`.
        let pieces = vec![x_t, x1, t, Tensor::<B, 2>::ones([batch, 1], &device)];
        let activated = Relu.forward(self.l1.forward(Tensor::cat(pieces, 1)));
        self.l2.forward(activated)
    }
}
/// Computes the straight-line interpolant and its target velocity.
///
/// Returns `(x_t, u_t)` with `x_t = x0 * (1 - t) + x1 * t` and `u_t = x1 - x0`.
/// The training loop passes `x0`/`x1` as `[batch, d]` and `t` as `[batch, 1]`, so `t`
/// is combined with `[batch, d]` tensors via broadcasting.
pub fn euclidean_path_targets<B: Backend>(
    x0: Tensor<B, 2>,
    x1: Tensor<B, 2>,
    t: Tensor<B, 2>,
) -> (Tensor<B, 2>, Tensor<B, 2>) {
    // Weight kept on the source sample, `1 - t`; negation is exact, so `-t + 1 == 1 - t`.
    let keep = t.clone().neg().add_scalar(1.0);
    let point = x0.clone() * keep + x1.clone() * t;
    // The path is linear in t, so its velocity does not depend on t.
    let velocity = x1 - x0;
    (point, velocity)
}
pub fn train_euclidean_fm_sgd(
device: &<BurnBackend as Backend>::Device,
d: usize,
hidden: usize,
steps: usize,
batch_size: usize,
lr: LearningRate,
) -> crate::Result<BurnEuclideanCondMlp<BurnBackend>> {
if d == 0 {
return Err(crate::Error::Domain("d must be > 0"));
}
if hidden == 0 {
return Err(crate::Error::Domain("hidden must be > 0"));
}
if steps == 0 {
return Err(crate::Error::Domain("steps must be > 0"));
}
if batch_size == 0 {
return Err(crate::Error::Domain("batch_size must be > 0"));
}
if !lr.is_finite() || lr <= 0.0 {
return Err(crate::Error::Domain("lr must be finite and positive"));
}
use burn::tensor::Distribution;
let mut model = BurnEuclideanCondMlp::<BurnBackend>::new(device, d, hidden);
let config = SgdConfig::new();
let mut optim = config.init::<BurnBackend, BurnEuclideanCondMlp<BurnBackend>>();
for _ in 0..steps {
let x0 = Tensor::<BurnBackend, 2>::random(
[batch_size, d],
Distribution::Normal(0.0, 1.0),
device,
);
let x1 = Tensor::<BurnBackend, 2>::random(
[batch_size, d],
Distribution::Normal(0.0, 1.0),
device,
);
let t = Tensor::<BurnBackend, 2>::random(
[batch_size, 1],
Distribution::Uniform(0.0, 1.0),
device,
);
let (xt, ut) = euclidean_path_targets(x0, x1.clone(), t.clone());
let pred = model.forward(xt, x1, t);
let loss = (pred - ut).powf_scalar(2.0).mean();
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(lr, model, grads);
}
Ok(model)
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Distribution;

    /// Path targets and the model forward pass produce `[batch, d]` outputs.
    #[test]
    fn burn_euclidean_shapes_smoke() {
        let device = <BurnBackend as Backend>::Device::default();
        let batch = 4usize;
        let d = 3usize;
        let x0 =
            Tensor::<BurnBackend, 2>::random([batch, d], Distribution::Normal(0.0, 1.0), &device);
        let x1 =
            Tensor::<BurnBackend, 2>::random([batch, d], Distribution::Normal(0.0, 1.0), &device);
        let t =
            Tensor::<BurnBackend, 2>::random([batch, 1], Distribution::Uniform(0.0, 1.0), &device);
        let (xt, ut) = euclidean_path_targets(x0, x1.clone(), t.clone());
        assert_eq!(xt.dims(), [batch, d]);
        assert_eq!(ut.dims(), [batch, d]);
        let model = BurnEuclideanCondMlp::<BurnBackend>::new(&device, d, 8);
        let pred = model.forward(xt, x1, t);
        assert_eq!(pred.dims(), [batch, d]);
    }

    /// A couple of SGD steps on valid arguments complete without error.
    #[test]
    fn burn_euclidean_train_smoke() {
        let device = <BurnBackend as Backend>::Device::default();
        let _model = train_euclidean_fm_sgd(&device, 4, 16, 2, 8, 1e-2).unwrap();
    }

    /// Every invalid hyper-parameter is rejected with an error instead of training.
    #[test]
    fn burn_euclidean_train_rejects_bad_args() {
        let device = <BurnBackend as Backend>::Device::default();

        // Zero-sized arguments: d, hidden, steps, batch_size.
        assert!(train_euclidean_fm_sgd(&device, 0, 16, 2, 8, 1e-2).is_err());
        assert!(train_euclidean_fm_sgd(&device, 4, 0, 2, 8, 1e-2).is_err());
        assert!(train_euclidean_fm_sgd(&device, 4, 16, 0, 8, 1e-2).is_err());
        assert!(train_euclidean_fm_sgd(&device, 4, 16, 2, 0, 1e-2).is_err());

        // Learning rates that are not finite and strictly positive.
        let bad_lrs: [LearningRate; 4] =
            [0.0, -1e-2, LearningRate::NAN, LearningRate::INFINITY];
        for lr in bad_lrs {
            assert!(
                train_euclidean_fm_sgd(&device, 4, 16, 2, 8, lr).is_err(),
                "lr = {lr} should be rejected"
            );
        }
    }
}