use ndarray::LinalgScalar;
use num_traits::Float;
use rand::distr::Distribution as RandDistribution;
use rand_distr::uniform::SampleUniform;
use rand::Rng;
use rand_distr::StandardNormal;
/// In-place linear-algebra interface over a flat collection of scalars.
///
/// Mutating methods update `self` in place. Methods taking a `diag` slice
/// apply one factor per coordinate; batched implementations apply the same
/// factors to every chain.
pub trait EuclideanVector: Clone {
    /// Element type stored in the vector.
    type Scalar: Float + LinalgScalar + SampleUniform + Copy;

    /// Total number of scalar entries.
    fn len(&self) -> usize;

    /// Whether the vector holds no entries at all.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// A new vector of zeros shaped like `self`.
    fn zeros_like(&self) -> Self;

    /// Sets every entry to zero.
    fn fill_zero(&mut self);

    /// Replaces the contents of `self` with those of `other`.
    fn assign(&mut self, other: &Self);

    /// `self += other`, element-wise.
    fn add_assign(&mut self, other: &Self);

    /// `self -= other`, element-wise.
    fn sub_assign(&mut self, other: &Self);

    /// `self += alpha * other`.
    fn add_scaled_assign(&mut self, other: &Self, alpha: Self::Scalar);

    /// `self *= alpha`.
    fn scale_assign(&mut self, alpha: Self::Scalar);

    /// Multiplies each coordinate by its matching `diag` factor.
    fn scale_diag_assign(&mut self, diag: &[Self::Scalar]);

    /// `self += alpha * diag * other`, with `diag` applied per coordinate.
    fn add_diag_scaled_assign(
        &mut self,
        other: &Self,
        diag: &[Self::Scalar],
        alpha: Self::Scalar,
    );

    /// Sum over all entries of `self_i * other_i`.
    fn dot(&self, other: &Self) -> Self::Scalar;

    /// Sum over all entries of `self_i^2 * diag_i`.
    fn quad_form_diag(&self, diag: &[Self::Scalar]) -> Self::Scalar;

    /// Overwrites every entry with a standard-normal sample. Implementations
    /// may draw from their own generator rather than `rng`.
    fn fill_standard_normal(&mut self, rng: &mut impl Rng)
    where
        StandardNormal: RandDistribution<Self::Scalar>;

    /// Copies all entries, in order, into `out`.
    fn write_to_slice(&self, out: &mut [Self::Scalar]);

    /// Overwrites all entries, in order, from `input`.
    fn read_from_slice(&mut self, input: &[Self::Scalar]);
}
/// [`EuclideanVector`] for a single owned `ndarray` vector.
///
/// Element-wise updates use [`ndarray::Zip`], which panics if operand lengths
/// disagree. Methods that take a `diag` or buffer slice also assert its
/// length up front, so the panic message says which call was misused.
impl<T> EuclideanVector for ndarray::Array1<T>
where
    T: Float + LinalgScalar + SampleUniform + Copy,
    StandardNormal: RandDistribution<T>,
{
    type Scalar = T;

    fn len(&self) -> usize {
        // Call ndarray's inherent `len` by path. The bare `self.len()` also
        // resolves to it (inherent methods win), but it reads like this
        // method calling itself.
        ndarray::ArrayBase::len(self)
    }

    fn zeros_like(&self) -> Self {
        ndarray::Array1::zeros(self.len())
    }

    fn fill_zero(&mut self) {
        self.fill(T::zero());
    }

    fn assign(&mut self, other: &Self) {
        // `clone_from` lets ndarray reuse `self`'s existing allocation.
        self.clone_from(other);
    }

    fn add_assign(&mut self, other: &Self) {
        ndarray::Zip::from(self).and(other).for_each(|a, b| {
            *a = *a + *b;
        });
    }

    fn sub_assign(&mut self, other: &Self) {
        ndarray::Zip::from(self).and(other).for_each(|a, b| {
            *a = *a - *b;
        });
    }

    fn add_scaled_assign(&mut self, other: &Self, alpha: Self::Scalar) {
        ndarray::Zip::from(self).and(other).for_each(|a, b| {
            *a = *a + *b * alpha;
        });
    }

    fn scale_assign(&mut self, alpha: Self::Scalar) {
        self.mapv_inplace(|x| x * alpha);
    }

    /// # Panics
    /// If `diag.len() != self.len()`.
    fn scale_diag_assign(&mut self, diag: &[Self::Scalar]) {
        assert_eq!(
            diag.len(),
            self.len(),
            "scale_diag_assign dimension mismatch"
        );
        ndarray::Zip::from(self).and(diag).for_each(|a, scale| {
            *a = *a * *scale;
        });
    }

    /// # Panics
    /// If `diag.len() != self.len()`, or if `other` has a different length
    /// (the latter is detected by `Zip`).
    fn add_diag_scaled_assign(&mut self, other: &Self, diag: &[Self::Scalar], alpha: Self::Scalar) {
        assert_eq!(
            diag.len(),
            self.len(),
            "add_diag_scaled_assign dimension mismatch"
        );
        ndarray::Zip::from(self)
            .and(other)
            .and(diag)
            .for_each(|a, b, scale| {
                *a = *a + *b * *scale * alpha;
            });
    }

    fn dot(&self, other: &Self) -> Self::Scalar {
        self.view().dot(&other.view())
    }

    /// # Panics
    /// If `diag.len() != self.len()`.
    fn quad_form_diag(&self, diag: &[Self::Scalar]) -> Self::Scalar {
        assert_eq!(diag.len(), self.len(), "quad_form_diag dimension mismatch");
        self.iter()
            .zip(diag.iter())
            .fold(T::zero(), |acc, (&x, &d)| acc + x * x * d)
    }

    fn fill_standard_normal(&mut self, rng: &mut impl Rng)
    where
        StandardNormal: RandDistribution<Self::Scalar>,
    {
        self.iter_mut()
            .for_each(|x| *x = rng.sample(StandardNormal));
    }

    /// Copies the entries, in logical order, into `out`.
    ///
    /// Works for any memory layout. `Zip` walks `self` logically, so
    /// non-contiguous or reversed-stride arrays copy correctly instead of
    /// panicking the way `as_slice()` + `expect` would.
    ///
    /// # Panics
    /// If `out.len() != self.len()`.
    fn write_to_slice(&self, out: &mut [Self::Scalar]) {
        assert_eq!(
            out.len(),
            self.len(),
            "write_to_slice called with mismatched buffer length"
        );
        ndarray::Zip::from(out)
            .and(self)
            .for_each(|dst, &src| *dst = src);
    }

    /// Overwrites the entries, in logical order, from `input`.
    ///
    /// Works for any memory layout (see `write_to_slice`).
    ///
    /// # Panics
    /// If `input.len() != self.len()`.
    fn read_from_slice(&mut self, input: &[Self::Scalar]) {
        assert_eq!(
            input.len(),
            self.len(),
            "read_from_slice called with mismatched buffer length"
        );
        ndarray::Zip::from(self)
            .and(input)
            .for_each(|dst, &src| *dst = src);
    }
}
/// An [`EuclideanVector`] that stores one or more independent chains.
///
/// Quantities with one value per chain (energies, log acceptance ratios,
/// uniform draws) are carried as `Energy`. Per-chain accept/reject decisions
/// are carried as `Mask`. The `energy_*` helpers are associated functions, so
/// generic code can do arithmetic on `Energy` without knowing its concrete
/// type.
pub trait BatchVector: EuclideanVector {
    /// One scalar per chain.
    type Energy: Clone;
    /// One accept/reject flag per chain.
    type Mask;

    /// Number of chains held in `self`.
    fn n_chains(&self) -> usize;
    /// Number of coordinates in each chain.
    fn dim_per_chain(&self) -> usize;

    /// Half the squared norm of each chain.
    fn kinetic_energy(&self) -> Self::Energy;
    /// Half of the sum of `inv_diag_i * p_i^2` for each chain.
    fn kinetic_energy_diag(&self, inv_diag: &[Self::Scalar]) -> Self::Energy;

    /// Copies the chains of `other` into `self` wherever `mask` is set.
    fn masked_assign(&mut self, other: &Self, mask: &Self::Mask);

    /// Overwrites every entry with a standard-normal sample.
    fn fill_random_normal(&mut self, rng: &mut impl Rng)
    where
        StandardNormal: RandDistribution<Self::Scalar>;
    /// Draws one uniform value between 0 and 1 per chain.
    fn sample_uniform(&self, rng: &mut impl Rng) -> Self::Energy
    where
        StandardNormal: RandDistribution<Self::Scalar>;

    /// `a - b`, per chain.
    fn energy_sub(a: &Self::Energy, b: &Self::Energy) -> Self::Energy;
    /// `a + b`, per chain.
    fn energy_add(a: &Self::Energy, b: &Self::Energy) -> Self::Energy;
    /// `-a`, per chain.
    fn energy_neg(a: &Self::Energy) -> Self::Energy;
    /// Natural logarithm, per chain.
    fn energy_ln(a: &Self::Energy) -> Self::Energy;

    /// Mean over chains of `min(1, exp(log_accept))`. Non-finite entries
    /// count as 0.
    fn mean_acceptance(log_accept: &Self::Energy) -> Self::Scalar;
    /// Per-chain acceptance decision: `log_accept >= ln_u`.
    fn accept_mask(log_accept: &Self::Energy, ln_u: &Self::Energy) -> Self::Mask;
}
impl<T> BatchVector for ndarray::Array1<T>
where
T: Float + LinalgScalar + SampleUniform + Copy,
StandardNormal: RandDistribution<T>,
{
type Energy = T;
type Mask = bool;
fn n_chains(&self) -> usize {
1
}
fn dim_per_chain(&self) -> usize {
self.len()
}
fn kinetic_energy(&self) -> T {
self.dot(self) * T::from(0.5).unwrap()
}
fn kinetic_energy_diag(&self, inv_diag: &[Self::Scalar]) -> T {
self.quad_form_diag(inv_diag) * T::from(0.5).unwrap()
}
fn masked_assign(&mut self, other: &Self, mask: &bool) {
if *mask {
self.assign(other);
}
}
fn fill_random_normal(&mut self, rng: &mut impl Rng)
where
StandardNormal: RandDistribution<Self::Scalar>,
{
self.iter_mut()
.for_each(|x| *x = rng.sample(StandardNormal));
}
fn sample_uniform(&self, rng: &mut impl Rng) -> T
where
StandardNormal: RandDistribution<Self::Scalar>,
{
use rand::distr::Uniform;
let dist = Uniform::new(T::zero(), T::one()).unwrap();
rng.sample(dist)
}
fn energy_sub(a: &T, b: &T) -> T {
*a - *b
}
fn energy_add(a: &T, b: &T) -> T {
*a + *b
}
fn energy_neg(a: &T) -> T {
T::zero() - *a
}
fn energy_ln(a: &T) -> T {
a.ln()
}
fn mean_acceptance(log_accept: &T) -> T {
if !log_accept.is_finite() {
T::zero()
} else if *log_accept < T::zero() {
log_accept.exp()
} else {
T::one()
}
}
fn accept_mask(log_accept: &T, ln_u: &T) -> bool {
*log_accept >= *ln_u
}
}
#[cfg(feature = "burn")]
mod burn_impl {
    //! [`EuclideanVector`] and [`BatchVector`] implementations for burn tensors.
    //!
    //! `Tensor<B, 1>` is a single vector. `Tensor<B, 2>` is a batch laid out as
    //! `[n_chains, dim_per_chain]`.
    //!
    //! Every helper tensor created here (diagonal scales, random noise, uploaded
    //! host data, uniform draws) is allocated on the device of the tensor it is
    //! combined with or replaces, not on `B::Device::default()`. This prevents
    //! tensors on a non-default device from being mixed across devices, or
    //! silently moved to another one.
    use super::{BatchVector, EuclideanVector};
    use burn::prelude::{Backend, Tensor};
    use burn::tensor::Element;
    use burn::tensor::ElementConversion;
    use num_traits::{Float, FromPrimitive};
    use rand::Rng;
    use rand::distr::Distribution as RandDistribution;
    use rand_distr::StandardNormal;
    use rand_distr::uniform::SampleUniform;

    /// Uploads `values` to `device` as a rank-1 float tensor of the same length.
    fn slice_to_tensor<T, B>(values: &[T], device: &B::Device) -> Tensor<B, 1>
    where
        T: Element + Copy,
        B: Backend<FloatElem = T>,
    {
        Tensor::<B, 1>::from_data(
            burn::tensor::TensorData::new(values.to_vec(), [values.len()]),
            device,
        )
    }

    /// Uploads `diag` to `device` and broadcasts it to `[n_rows, diag.len()]`.
    /// Every row (chain) is then scaled by the same per-coordinate factors.
    fn expand_diag<T, B>(diag: &[T], n_rows: usize, device: &B::Device) -> Tensor<B, 2>
    where
        T: Element + Copy,
        B: Backend<FloatElem = T>,
    {
        let row: Tensor<B, 2> = slice_to_tensor::<T, B>(diag, device).unsqueeze_dim(0);
        row.expand([n_rows, diag.len()])
    }

    impl<T, B> EuclideanVector for Tensor<B, 1>
    where
        T: Float + Element + ElementConversion + SampleUniform + FromPrimitive + Copy,
        B: Backend<FloatElem = T>,
        StandardNormal: RandDistribution<T>,
    {
        type Scalar = T;

        fn len(&self) -> usize {
            self.dims()[0]
        }

        fn zeros_like(&self) -> Self {
            Tensor::<B, 1>::zeros_like(self)
        }

        fn fill_zero(&mut self) {
            self.inplace(|x| Tensor::<B, 1>::zeros_like(&x));
        }

        fn assign(&mut self, other: &Self) {
            self.inplace(|_| other.clone());
        }

        fn add_assign(&mut self, other: &Self) {
            self.inplace(|x| x.add(other.clone()));
        }

        fn sub_assign(&mut self, other: &Self) {
            self.inplace(|x| x.sub(other.clone()));
        }

        fn add_scaled_assign(&mut self, other: &Self, alpha: Self::Scalar) {
            self.inplace(|x| x.add(other.clone().mul_scalar(alpha)));
        }

        fn scale_assign(&mut self, alpha: Self::Scalar) {
            self.inplace(|x| x.mul_scalar(alpha));
        }

        /// # Panics
        /// If `diag.len() != self.len()`.
        fn scale_diag_assign(&mut self, diag: &[Self::Scalar]) {
            assert_eq!(
                diag.len(),
                self.len(),
                "scale_diag_assign dimension mismatch"
            );
            let scale = slice_to_tensor::<T, B>(diag, &self.device());
            self.inplace(|x| x.mul(scale));
        }

        /// # Panics
        /// If `diag.len() != self.len()`.
        fn add_diag_scaled_assign(
            &mut self,
            other: &Self,
            diag: &[Self::Scalar],
            alpha: Self::Scalar,
        ) {
            assert_eq!(
                diag.len(),
                self.len(),
                "add_diag_scaled_assign dimension mismatch"
            );
            let scale = slice_to_tensor::<T, B>(diag, &self.device()).mul_scalar(alpha);
            self.inplace(|x| x.add(other.clone().mul(scale)));
        }

        fn dot(&self, other: &Self) -> Self::Scalar {
            self.clone().mul(other.clone()).sum().into_scalar()
        }

        /// # Panics
        /// If `diag.len() != self.len()`.
        fn quad_form_diag(&self, diag: &[Self::Scalar]) -> Self::Scalar {
            assert_eq!(diag.len(), self.len(), "quad_form_diag dimension mismatch");
            let scale = slice_to_tensor::<T, B>(diag, &self.device());
            self.clone()
                .mul(self.clone())
                .mul(scale)
                .sum()
                .into_scalar()
        }

        /// Replaces `self` with standard-normal noise of the same shape, on the
        /// same device.
        ///
        /// NOTE: the samples come from burn's `Tensor::random`; `_rng` is not
        /// used, so seeding the caller's RNG does not make this draw
        /// reproducible.
        fn fill_standard_normal(&mut self, _rng: &mut impl Rng)
        where
            StandardNormal: RandDistribution<Self::Scalar>,
        {
            let noise = Tensor::<B, 1>::random(
                self.shape(),
                burn::tensor::Distribution::Normal(0.0, 1.0),
                &self.device(),
            );
            self.inplace(|_| noise);
        }

        /// # Panics
        /// If the tensor data is not readable as a dense `[T]` slice, or if
        /// `out` has a different length.
        fn write_to_slice(&self, out: &mut [Self::Scalar]) {
            let data = self.to_data();
            let slice = data.as_slice().expect("Tensor data expected to be dense");
            assert_eq!(
                out.len(),
                slice.len(),
                "write_to_slice called with mismatched buffer length"
            );
            out.copy_from_slice(slice);
        }

        /// Replaces `self` with `input`, uploaded to `self`'s device.
        ///
        /// # Panics
        /// If `input.len() != self.len()`.
        fn read_from_slice(&mut self, input: &[Self::Scalar]) {
            assert_eq!(
                input.len(),
                self.len(),
                "read_from_slice called with mismatched buffer length"
            );
            let updated = slice_to_tensor::<T, B>(input, &self.device());
            self.inplace(|_| updated);
        }
    }

    impl<T, B> EuclideanVector for Tensor<B, 2>
    where
        T: Float + Element + ElementConversion + SampleUniform + FromPrimitive + Copy,
        B: Backend<FloatElem = T>,
        StandardNormal: RandDistribution<T>,
    {
        type Scalar = T;

        /// Total entry count across all chains (`rows * cols`).
        fn len(&self) -> usize {
            let [n_rows, dim] = self.dims();
            n_rows * dim
        }

        fn zeros_like(&self) -> Self {
            Tensor::<B, 2>::zeros_like(self)
        }

        fn fill_zero(&mut self) {
            self.inplace(|x| Tensor::<B, 2>::zeros_like(&x));
        }

        fn assign(&mut self, other: &Self) {
            self.inplace(|_| other.clone());
        }

        fn add_assign(&mut self, other: &Self) {
            self.inplace(|x| x.add(other.clone()));
        }

        fn sub_assign(&mut self, other: &Self) {
            self.inplace(|x| x.sub(other.clone()));
        }

        fn add_scaled_assign(&mut self, other: &Self, alpha: Self::Scalar) {
            self.inplace(|x| x.add(other.clone().mul_scalar(alpha)));
        }

        fn scale_assign(&mut self, alpha: Self::Scalar) {
            self.inplace(|x| x.mul_scalar(alpha));
        }

        /// Scales column `j` of every row by `diag[j]`.
        ///
        /// # Panics
        /// If `diag.len()` differs from the column count.
        fn scale_diag_assign(&mut self, diag: &[Self::Scalar]) {
            let [n_rows, dim] = self.dims();
            assert_eq!(diag.len(), dim, "scale_diag_assign dimension mismatch");
            let scale = expand_diag::<T, B>(diag, n_rows, &self.device());
            self.inplace(|x| x.mul(scale));
        }

        /// `self += alpha * diag * other`, with `diag` applied per column.
        ///
        /// # Panics
        /// If `diag.len()` differs from the column count.
        fn add_diag_scaled_assign(
            &mut self,
            other: &Self,
            diag: &[Self::Scalar],
            alpha: Self::Scalar,
        ) {
            let [n_rows, dim] = self.dims();
            assert_eq!(diag.len(), dim, "add_diag_scaled_assign dimension mismatch");
            let scale = expand_diag::<T, B>(diag, n_rows, &self.device()).mul_scalar(alpha);
            self.inplace(|x| x.add(other.clone().mul(scale)));
        }

        /// Sum over *all* entries (every chain) of `self_ij * other_ij`.
        fn dot(&self, other: &Self) -> Self::Scalar {
            self.clone().mul(other.clone()).sum().into_scalar()
        }

        /// Sum over all entries of `self_ij^2 * diag_j`.
        ///
        /// # Panics
        /// If `diag.len()` differs from the column count.
        fn quad_form_diag(&self, diag: &[Self::Scalar]) -> Self::Scalar {
            let [n_rows, dim] = self.dims();
            assert_eq!(diag.len(), dim, "quad_form_diag dimension mismatch");
            let scale = expand_diag::<T, B>(diag, n_rows, &self.device());
            self.clone()
                .mul(self.clone())
                .mul(scale)
                .sum()
                .into_scalar()
        }

        /// Replaces `self` with standard-normal noise of the same shape, on the
        /// same device.
        ///
        /// NOTE: the samples come from burn's `Tensor::random`; `_rng` is not
        /// used.
        fn fill_standard_normal(&mut self, _rng: &mut impl Rng)
        where
            StandardNormal: RandDistribution<Self::Scalar>,
        {
            let noise = Tensor::<B, 2>::random(
                self.shape(),
                burn::tensor::Distribution::Normal(0.0, 1.0),
                &self.device(),
            );
            self.inplace(|_| noise);
        }

        /// Copies all entries into `out`, in the order of the tensor's data
        /// (`to_data`).
        ///
        /// # Panics
        /// If the tensor data is not readable as a dense `[T]` slice, or if
        /// `out` has a different length.
        fn write_to_slice(&self, out: &mut [Self::Scalar]) {
            let data = self.to_data();
            let slice = data.as_slice().expect("Tensor data expected to be dense");
            assert_eq!(
                out.len(),
                slice.len(),
                "write_to_slice called with mismatched buffer length"
            );
            out.copy_from_slice(slice);
        }

        /// Replaces `self` with `input` in `self`'s shape, uploaded to
        /// `self`'s device.
        ///
        /// # Panics
        /// If `input.len()` is not `rows * cols`.
        fn read_from_slice(&mut self, input: &[Self::Scalar]) {
            let dims = self.dims();
            let [n_rows, dim] = dims;
            assert_eq!(
                input.len(),
                n_rows * dim,
                "read_from_slice called with mismatched buffer length"
            );
            let td = burn::tensor::TensorData::new(input.to_vec(), dims);
            let updated = Tensor::<B, 2>::from_data(td, &self.device());
            self.inplace(|_| updated);
        }
    }

    /// Batched chains: row `i` of the `[n_chains, dim]` tensor is chain `i`.
    impl<T, B> BatchVector for Tensor<B, 2>
    where
        T: Float + Element + ElementConversion + SampleUniform + FromPrimitive + Copy,
        B: Backend<FloatElem = T>,
        StandardNormal: RandDistribution<T>,
    {
        type Energy = Tensor<B, 1>;
        type Mask = Tensor<B, 1, burn::tensor::Bool>;

        fn n_chains(&self) -> usize {
            self.dims()[0]
        }

        fn dim_per_chain(&self) -> usize {
            self.dims()[1]
        }

        /// `0.5 * sum_j p_ij^2` for each row `i`.
        fn kinetic_energy(&self) -> Tensor<B, 1> {
            self.clone()
                .mul(self.clone())
                .sum_dim(1)
                .squeeze(1)
                .mul_scalar(T::from(0.5).unwrap())
        }

        /// `0.5 * sum_j inv_diag_j * p_ij^2` for each row `i`.
        ///
        /// # Panics
        /// If `inv_diag.len()` differs from the column count.
        fn kinetic_energy_diag(&self, inv_diag: &[Self::Scalar]) -> Tensor<B, 1> {
            let [n_chains, dim] = self.dims();
            assert_eq!(
                inv_diag.len(),
                dim,
                "kinetic_energy_diag dimension mismatch"
            );
            let inv = expand_diag::<T, B>(inv_diag, n_chains, &self.device());
            self.clone()
                .mul(self.clone())
                .mul(inv)
                .sum_dim(1)
                .squeeze(1)
                .mul_scalar(T::from(0.5).unwrap())
        }

        /// Replaces row `i` of `self` with row `i` of `other` wherever
        /// `mask[i]` is true.
        fn masked_assign(&mut self, other: &Self, mask: &Tensor<B, 1, burn::tensor::Bool>) {
            let [n_chains, dim] = self.dims();
            // Broadcast the per-chain mask across each row's coordinates.
            let per_chain: Tensor<B, 2, burn::tensor::Bool> = mask.clone().unsqueeze_dim(1);
            let full_mask = per_chain.expand([n_chains, dim]);
            self.inplace(|x| x.mask_where(full_mask, other.clone()));
        }

        /// Same operation as `EuclideanVector::fill_standard_normal`.
        fn fill_random_normal(&mut self, rng: &mut impl Rng)
        where
            StandardNormal: RandDistribution<Self::Scalar>,
        {
            EuclideanVector::fill_standard_normal(self, rng);
        }

        /// One `Uniform(0, 1)` draw per chain, on `self`'s device.
        ///
        /// NOTE: drawn with burn's `Tensor::random`; `_rng` is not used.
        fn sample_uniform(&self, _rng: &mut impl Rng) -> Tensor<B, 1>
        where
            StandardNormal: RandDistribution<Self::Scalar>,
        {
            Tensor::<B, 1>::random(
                burn::tensor::Shape::new([self.dims()[0]]),
                burn::tensor::Distribution::Uniform(0.0, 1.0),
                &self.device(),
            )
        }

        fn energy_sub(a: &Tensor<B, 1>, b: &Tensor<B, 1>) -> Tensor<B, 1> {
            a.clone().sub(b.clone())
        }

        fn energy_add(a: &Tensor<B, 1>, b: &Tensor<B, 1>) -> Tensor<B, 1> {
            a.clone().add(b.clone())
        }

        fn energy_neg(a: &Tensor<B, 1>) -> Tensor<B, 1> {
            a.clone().neg()
        }

        fn energy_ln(a: &Tensor<B, 1>) -> Tensor<B, 1> {
            a.clone().log()
        }

        /// Mean over chains of `min(1, exp(log_accept_i))`, with non-finite
        /// entries counted as 0. The values are read back to the host.
        /// With zero chains the result is `0 / 0` (NaN).
        ///
        /// # Panics
        /// If the data cannot be viewed as a dense `[T]` slice.
        fn mean_acceptance(log_accept: &Tensor<B, 1>) -> T {
            let data = log_accept.to_data();
            let values = data
                .as_slice::<T>()
                .expect("Expected batched log acceptance tensor to be dense");
            let total = values.iter().copied().fold(T::zero(), |sum, value| {
                if !value.is_finite() {
                    sum
                } else if value < T::zero() {
                    sum + value.exp()
                } else {
                    sum + T::one()
                }
            });
            total / T::from_usize(values.len()).unwrap()
        }

        /// Per-chain acceptance decision: `log_accept >= ln_u`.
        fn accept_mask(
            log_accept: &Tensor<B, 1>,
            ln_u: &Tensor<B, 1>,
        ) -> Tensor<B, 1, burn::tensor::Bool> {
            log_accept.clone().greater_equal(ln_u.clone())
        }
    }
}