use burn_core as burn;
use burn::module::Module;
use burn::tensor::{backend::Backend, Tensor};
use burn_nn::{Linear, LinearConfig};
use burn_optim::{GradientsParams, LearningRate, Optimizer, SgdConfig};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
use wass::semidiscrete::{
assign_hard_from_scores, fit_potentials_sgd_neg_dot, scores_neg_dot, SemidiscreteSgdConfig,
};
use crate::linear::LinearCondField;
use crate::sd_fm::{
sample_categorical_from_probs, SdFmTrainAssignment, SdFmTrainConfig, TrainedSdFm,
};
use crate::{Error, Result};
/// Burn backend used by the training entry points in this module.
///
/// This is an alias of `crate::burn_euclidean::BurnBackend`. The training loops below call
/// `loss.backward()` on tensors of this backend, so it must be an autodiff-capable backend.
pub type BurnBackend = crate::burn_euclidean::BurnBackend;
/// Linear conditional velocity field `v(x_t, y, t) = W [x_t; y; t] + c`.
///
/// It is implemented as a single Burn `Linear` layer over the concatenated features.
#[derive(Module, Debug)]
struct BurnLinearCondField<B: Backend> {
    /// Affine map from the `2 * d + 1` concatenated features to `d` outputs.
    linear: Linear<B>,
    /// Data dimension `d`, used to size and check the exported ndarray field.
    d: usize,
}
impl<B: Backend> BurnLinearCondField<B> {
fn new(device: &B::Device, d: usize) -> Self {
let in_dim = 2 * d + 1;
let linear = LinearConfig::new(in_dim, d).with_bias(true).init(device);
Self { linear, d }
}
fn forward(&self, x_t: Tensor<B, 2>, y: Tensor<B, 2>, t: Tensor<B, 2>) -> Tensor<B, 2> {
let feats = Tensor::cat(vec![x_t, y, t], 1);
self.linear.forward(feats)
}
fn export_to_ndarray(&self) -> crate::Result<LinearCondField> {
let w_data = self.linear.weight.to_data();
let w_shape = &w_data.shape;
debug_assert_eq!(w_shape.len(), 2);
let d_in = w_shape[0];
let d_out = w_shape[1];
debug_assert_eq!(d_out, self.d);
debug_assert_eq!(d_in, 2 * self.d + 1);
let b = match self.linear.bias.as_ref() {
Some(b) => b
.to_data()
.to_vec::<f32>()
.map_err(|_| crate::Error::Shape("bias to_vec failed"))?,
None => vec![0.0; d_out],
};
let w_flat: Vec<f32> = w_data
.to_vec::<f32>()
.map_err(|_| crate::Error::Shape("weight to_vec failed"))?;
let mut w = Array2::<f32>::zeros((d_out, 2 * self.d + 2));
for j in 0..d_out {
for i in 0..d_in {
w[[j, i]] = w_flat[i * d_out + j];
}
w[[j, 2 * self.d + 1]] = b[j];
}
Ok(LinearCondField { w })
}
}
/// Copies a 2-D ndarray into a Burn tensor with the same `(rows, cols)` shape.
///
/// Elements are read in logical row-major order. Arrays in any memory layout (e.g.
/// Fortran-order or otherwise non-standard strides) are therefore accepted, not just
/// C-contiguous ones.
///
/// # Errors
/// Currently infallible. The `Result` is kept so existing call sites using `?` are unchanged.
fn ndarray_to_burn_2<B: Backend>(
    device: &B::Device,
    x: &Array2<f32>,
) -> crate::Result<Tensor<B, 2>> {
    let (n, d) = x.dim();
    // `iter()` visits elements with the last index varying fastest, which matches the
    // row-major order TensorData expects, whatever the array's strides are.
    let values: Vec<f32> = x.iter().copied().collect();
    let data = burn::tensor::TensorData::new(values, [n, d]);
    Ok(Tensor::from_data(data, device))
}
/// Copies a 1-D ndarray of length `n` into a Burn column tensor of shape `(n, 1)`.
///
/// Elements are read in logical order, so arrays with any stride (e.g. reversed or
/// non-unit-stride owned arrays) are accepted, not just contiguous ones.
///
/// # Errors
/// Currently infallible. The `Result` is kept so existing call sites using `?` are unchanged.
fn ndarray_to_burn_2_keepdim<B: Backend>(
    device: &B::Device,
    x: &Array1<f32>,
) -> crate::Result<Tensor<B, 2>> {
    let n = x.len();
    let values: Vec<f32> = x.iter().copied().collect();
    let data = burn::tensor::TensorData::new(values, [n, 1]);
    Ok(Tensor::from_data(data, device))
}
pub fn train_sd_fm_semidiscrete_linear_burn(
device: &<BurnBackend as Backend>::Device,
y: &ArrayView2<f32>,
b: &ArrayView1<f32>,
pot_cfg: &SemidiscreteSgdConfig,
fm_cfg: &SdFmTrainConfig,
assignment: SdFmTrainAssignment,
lr: LearningRate,
) -> Result<TrainedSdFm> {
let n = y.nrows();
let d = y.ncols();
if n == 0 || d == 0 {
return Err(Error::Domain("y must be non-empty"));
}
if b.len() != n {
return Err(Error::Shape("b length must match y.nrows()"));
}
if b.iter().any(|&x| x < 0.0) {
return Err(Error::Domain("b must be nonnegative"));
}
let bs = b.sum();
if bs <= 0.0 {
return Err(Error::Domain("b must have positive total mass"));
}
if fm_cfg.steps == 0 || fm_cfg.batch_size == 0 {
return Err(Error::Domain("steps and batch_size must be >= 1"));
}
let b_norm = b.to_owned() / bs;
let g = match assignment {
SdFmTrainAssignment::SemidiscretePotentials => {
fit_potentials_sgd_neg_dot(y, &b_norm.view(), pot_cfg)
.map_err(|_| Error::Domain("failed to fit semidiscrete potentials"))?
}
SdFmTrainAssignment::CategoricalFromB => Array1::<f32>::zeros(n),
};
let mut model = BurnLinearCondField::<BurnBackend>::new(device, d);
let mut optim = SgdConfig::new().init::<BurnBackend, BurnLinearCondField<BurnBackend>>();
let mut rng = ChaCha8Rng::seed_from_u64(fm_cfg.seed);
let bs = fm_cfg.batch_size;
let mut x0s = Array2::<f32>::zeros((bs, d));
let mut ys = Array2::<f32>::zeros((bs, d));
let mut ts = Array1::<f32>::zeros(bs);
let mut xts = Array2::<f32>::zeros((bs, d));
let mut us = Array2::<f32>::zeros((bs, d));
for _step in 0..fm_cfg.steps {
for i in 0..bs {
for k in 0..d {
x0s[[i, k]] = StandardNormal.sample(&mut rng);
}
let x0 = x0s.row(i);
let j = match assignment {
SdFmTrainAssignment::SemidiscretePotentials => {
let scores = scores_neg_dot(&x0, y, &g.view());
assign_hard_from_scores(&scores.view())
}
SdFmTrainAssignment::CategoricalFromB => {
sample_categorical_from_probs(&b_norm.view(), &mut rng)
}
};
let yj = y.row(j);
for k in 0..d {
ys[[i, k]] = yj[k];
}
let t = fm_cfg.t_schedule.sample_t(&mut rng);
ts[i] = t;
for k in 0..d {
let x0k = x0s[[i, k]];
let yk = ys[[i, k]];
xts[[i, k]] = (1.0 - t) * x0k + t * yk;
us[[i, k]] = yk - x0k;
}
}
let x_t = ndarray_to_burn_2::<BurnBackend>(device, &xts)?;
let y_b = ndarray_to_burn_2::<BurnBackend>(device, &ys)?;
let t_b = ndarray_to_burn_2_keepdim::<BurnBackend>(device, &ts)?;
let u_b = ndarray_to_burn_2::<BurnBackend>(device, &us)?;
let pred = model.forward(x_t, y_b, t_b);
let loss = (pred - u_b).powf_scalar(2.0).mean();
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(lr, model, grads);
}
let field = model.export_to_ndarray()?;
Ok(TrainedSdFm {
y: y.to_owned(),
b: b_norm,
g,
assignment,
field,
})
}
pub fn train_rfm_minibatch_ot_linear_burn(
device: &<BurnBackend as Backend>::Device,
y: &ArrayView2<f32>,
b: &ArrayView1<f32>,
rfm_cfg: &crate::sd_fm::RfmMinibatchOtConfig,
fm_cfg: &SdFmTrainConfig,
lr: LearningRate,
) -> Result<TrainedSdFm> {
let n = y.nrows();
let d = y.ncols();
if n == 0 || d == 0 {
return Err(Error::Domain("y must be non-empty"));
}
if b.len() != n {
return Err(Error::Shape("b length must match y.nrows()"));
}
if b.iter().any(|&x| x < 0.0) {
return Err(Error::Domain("b must be nonnegative"));
}
let bs = b.sum();
if bs <= 0.0 {
return Err(Error::Domain("b must have positive total mass"));
}
if fm_cfg.steps == 0 || fm_cfg.batch_size == 0 {
return Err(Error::Domain("steps and batch_size must be >= 1"));
}
if rfm_cfg.pairing_every == 0 {
return Err(Error::Domain("rfm_cfg.pairing_every must be >= 1"));
}
let b_norm = b.to_owned() / bs;
let g = Array1::<f32>::zeros(n);
let mut model = BurnLinearCondField::<BurnBackend>::new(device, d);
let mut optim = SgdConfig::new().init::<BurnBackend, BurnLinearCondField<BurnBackend>>();
let mut rng = ChaCha8Rng::seed_from_u64(fm_cfg.seed);
let bs = fm_cfg.batch_size;
let mut x0s = Array2::<f32>::zeros((bs, d));
let mut ys = Array2::<f32>::zeros((bs, d));
let mut perm: Vec<usize> = Vec::new();
let mut ts = Array1::<f32>::zeros(bs);
let mut xts = Array2::<f32>::zeros((bs, d));
let mut us = Array2::<f32>::zeros((bs, d));
for step in 0..fm_cfg.steps {
let recompute = step == 0 || (step % rfm_cfg.pairing_every == 0);
if recompute {
for i in 0..bs {
for k in 0..d {
x0s[[i, k]] = StandardNormal.sample(&mut rng);
}
}
for i in 0..bs {
let j = sample_categorical_from_probs(&b_norm.view(), &mut rng);
let yj = y.row(j);
for k in 0..d {
ys[[i, k]] = yj[k];
}
}
perm = crate::rfm::apply_pairing(&rfm_cfg.pairing, &x0s.view(), &ys.view(), rfm_cfg)?;
}
debug_assert!(perm.len() >= bs, "perm shorter than batch_size");
debug_assert!(
perm.iter().take(bs).all(|&p| p < bs),
"perm index out of range"
);
for i in 0..bs {
let p = perm[i];
let t = fm_cfg.t_schedule.sample_t(&mut rng);
ts[i] = t;
for k in 0..d {
let x0k = x0s[[i, k]];
let yk = ys[[p, k]];
xts[[i, k]] = (1.0 - t) * x0k + t * yk;
us[[i, k]] = yk - x0k;
}
}
let x_t = ndarray_to_burn_2::<BurnBackend>(device, &xts)?;
let y1 = ndarray_to_burn_2::<BurnBackend>(device, &ys.select(ndarray::Axis(0), &perm))?;
let t_b = ndarray_to_burn_2_keepdim::<BurnBackend>(device, &ts)?;
let u_b = ndarray_to_burn_2::<BurnBackend>(device, &us)?;
let pred = model.forward(x_t, y1, t_b);
let loss = (pred - u_b).powf_scalar(2.0).mean();
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(lr, model, grads);
}
let field = model.export_to_ndarray()?;
Ok(TrainedSdFm {
y: y.to_owned(),
b: b_norm,
g,
assignment: SdFmTrainAssignment::CategoricalFromB,
field,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A few SGD steps on the four corners of the unit square must produce an
    /// exported linear field of shape `(d, 2 * d + 2)`.
    #[test]
    fn burn_sd_fm_trains_and_exports_linear_field() {
        let dim = 2;
        let corners: Vec<f32> = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let targets = Array2::<f32>::from_shape_vec((4, dim), corners).unwrap();
        let weights = Array1::<f32>::ones(4);

        let device: <BurnBackend as Backend>::Device = Default::default();
        let potentials_cfg = SemidiscreteSgdConfig::default();
        let train_cfg = SdFmTrainConfig {
            batch_size: 16,
            steps: 3,
            ..Default::default()
        };

        let trained = train_sd_fm_semidiscrete_linear_burn(
            &device,
            &targets.view(),
            &weights.view(),
            &potentials_cfg,
            &train_cfg,
            SdFmTrainAssignment::CategoricalFromB,
            1e-2,
        )
        .unwrap();

        let (rows, cols) = trained.field.w.dim();
        assert_eq!(rows, dim);
        assert_eq!(cols, 2 * dim + 2);
    }
}