use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use sprs::{TriMat};
use crate::spatial_hash::SpatialHash;
/// Shared interface of the compactly supported Wendland basis functions
/// `phi(r) = (1 - r)_+^(l + k) * p(r)`, where `k` is the smoothness order and
/// `l = ndim / 2 + k + 1` (integer division).
pub trait Wendland {
    /// Smoothness order `k` implemented by this basis.
    fn msq_order(&self) -> i32;

    /// Number of spatial dimensions the basis was constructed for.
    fn read_ndim(&self) -> i32;

    /// Integer parameter `l = ndim / 2 + order + 1` (integer division).
    fn calc_msq_diff_int(ndim: i32, order: i32) -> i32 {
        let half_ndim = ndim / 2;
        half_ndim + order + 1
    }

    /// Exponent `l + order` applied to the `(1 - r)` factor.
    fn calc_msq_diff_power(order: i32, msq_diff_int: i32) -> i32 {
        msq_diff_int + order
    }

    /// Evaluates the basis at distance `r`. `r2` is expected to be `r * r`;
    /// it is passed separately so implementations that need the square do not
    /// recompute it.
    fn eval_sgl(&self, r: f64, r2: f64) -> f64;
}
/// Wendland basis of smoothness order 0: `phi(r) = (1 - r)_+^l`, with
/// `l = ndim / 2 + 1`.
pub struct WendlandBasisMSQ0 {
    ndim: i32,
    // `l`; not read after construction (evaluation only needs the exponent),
    // kept alongside `msq_diff_power` for reference.
    #[allow(dead_code)]
    msq_diff_int: i32,
    msq_diff_power: i32,
}

impl WendlandBasisMSQ0 {
    /// Builds the order-0 basis for `ndim` spatial dimensions.
    ///
    /// `order` must be `0`. It feeds the shared exponent formula, and any other
    /// value would pair this basis' polynomial with the wrong exponent
    /// (checked in debug builds).
    pub fn new(ndim: usize, order: i32) -> Self {
        debug_assert_eq!(order, 0, "WendlandBasisMSQ0 is the order-0 basis");
        let ndim = ndim as i32;
        let msq_diff_int = Self::calc_msq_diff_int(ndim, order);
        Self {
            ndim,
            msq_diff_int,
            msq_diff_power: Self::calc_msq_diff_power(order, msq_diff_int),
        }
    }
}

impl Wendland for WendlandBasisMSQ0 {
    fn msq_order(&self) -> i32 {
        0
    }

    fn read_ndim(&self) -> i32 {
        self.ndim
    }

    fn eval_sgl(&self, r: f64, _r2: f64) -> f64 {
        // Compact support: the basis is zero at and beyond unit distance.
        if r >= 1.0 {
            return 0.0;
        }
        // The clamp is a no-op for finite r < 1, but it turns a NaN `r` into a
        // zero factor.
        (1. - r).max(0.0).powi(self.msq_diff_power)
    }
}
/// Wendland basis of smoothness order 1:
/// `phi(r) = (1 - r)_+^(l + 1) * ((l + 1) r + 1)`, with `l = ndim / 2 + 2`.
pub struct WendlandBasisMSQ1 {
    ndim: i32,
    // `l`; not read after construction (evaluation only needs the exponent
    // `l + 1`, which is also the linear coefficient), kept for reference.
    #[allow(dead_code)]
    msq_diff_int: i32,
    msq_diff_power: i32,
}

impl WendlandBasisMSQ1 {
    /// Builds the order-1 basis for `ndim` spatial dimensions.
    ///
    /// `order` must be `1`. It feeds the shared exponent formula, and any other
    /// value would pair this basis' polynomial with the wrong exponent
    /// (checked in debug builds).
    pub fn new(ndim: usize, order: i32) -> Self {
        debug_assert_eq!(order, 1, "WendlandBasisMSQ1 is the order-1 basis");
        let ndim = ndim as i32;
        let msq_diff_int = Self::calc_msq_diff_int(ndim, order);
        Self {
            ndim,
            msq_diff_int,
            msq_diff_power: Self::calc_msq_diff_power(order, msq_diff_int),
        }
    }
}

impl Wendland for WendlandBasisMSQ1 {
    fn msq_order(&self) -> i32 {
        1
    }

    fn read_ndim(&self) -> i32 {
        self.ndim
    }

    fn eval_sgl(&self, r: f64, _r2: f64) -> f64 {
        // Compact support: the basis is zero at and beyond unit distance.
        if r >= 1.0 {
            return 0.0;
        }
        // For order 1 the exponent `l + 1` doubles as the linear coefficient.
        (1. - r).max(0.0).powi(self.msq_diff_power)
            * ((f64::from(self.msq_diff_power) * r) + 1.0)
    }
}
/// Wendland basis of smoothness order 2:
/// `phi(r) = (1 - r)_+^(l + 2) * ((l^2 + 4l + 3) r^2 + (3l + 6) r + 3) / 3`,
/// with `l = ndim / 2 + 3`.
pub struct WendlandBasisMSQ2 {
    ndim: i32,
    // `l`; only used during construction to derive the polynomial
    // coefficients below, kept for reference.
    #[allow(dead_code)]
    msq_diff_int: i32,
    msq_diff_power: i32,
    /// Coefficient of `r` in the polynomial factor (before the `/ 3`).
    j_poly_r1: f64,
    /// Coefficient of `r^2` in the polynomial factor (before the `/ 3`).
    j_poly_r2: f64,
}

impl WendlandBasisMSQ2 {
    /// Builds the order-2 basis for `ndim` spatial dimensions.
    ///
    /// `order` must be `2`. It feeds the shared exponent formula, and any other
    /// value would pair this basis' polynomial with the wrong exponent
    /// (checked in debug builds).
    pub fn new(ndim: usize, order: i32) -> Self {
        debug_assert_eq!(order, 2, "WendlandBasisMSQ2 is the order-2 basis");
        let ndim = ndim as i32;
        let msq_diff_int = Self::calc_msq_diff_int(ndim, order);
        Self {
            ndim,
            msq_diff_int,
            msq_diff_power: Self::calc_msq_diff_power(order, msq_diff_int),
            j_poly_r1: Self::calc_j_poly_r1(msq_diff_int),
            j_poly_r2: Self::calc_j_poly_r2(msq_diff_int),
        }
    }

    /// `3l + 6`.
    fn calc_j_poly_r1(msq_diff_int: i32) -> f64 {
        f64::from((3 * msq_diff_int) + 6)
    }

    /// `l^2 + 4l + 3`.
    fn calc_j_poly_r2(msq_diff_int: i32) -> f64 {
        f64::from(msq_diff_int.pow(2) + 4 * msq_diff_int + 3)
    }
}

impl Wendland for WendlandBasisMSQ2 {
    fn msq_order(&self) -> i32 {
        2
    }

    fn read_ndim(&self) -> i32 {
        self.ndim
    }

    fn eval_sgl(&self, r: f64, r2: f64) -> f64 {
        // Compact support: the basis is zero at and beyond unit distance.
        if r >= 1.0 {
            return 0.0;
        }
        // Dividing by the constant term (3) normalises phi(0) to 1.
        (1. - r).max(0.0).powi(self.msq_diff_power)
            * ((self.j_poly_r2 * r2) + (self.j_poly_r1 * r) + 3.0)
            / 3.0
    }
}
/// Wendland basis of smoothness order 3:
/// `phi(r) = (1 - r)_+^(l + 3) * ((l^3 + 9l^2 + 23l + 15) r^3
///     + (6l^2 + 36l + 45) r^2 + (15l + 45) r + 15) / 15`,
/// with `l = ndim / 2 + 4`.
pub struct WendlandBasisMSQ3 {
    ndim: i32,
    // `l`; only used during construction to derive the polynomial
    // coefficients below, kept for reference.
    #[allow(dead_code)]
    msq_diff_int: i32,
    msq_diff_power: i32,
    /// Coefficient of `r` in the polynomial factor (before the `/ 15`).
    j_poly_r1: f64,
    /// Coefficient of `r^2` in the polynomial factor (before the `/ 15`).
    j_poly_r2: f64,
    /// Coefficient of `r^3` in the polynomial factor (before the `/ 15`).
    j_poly_r3: f64,
}

impl WendlandBasisMSQ3 {
    /// Builds the order-3 basis for `ndim` spatial dimensions.
    ///
    /// `order` must be `3`. It feeds the shared exponent formula, and any other
    /// value would pair this basis' polynomial with the wrong exponent
    /// (checked in debug builds).
    pub fn new(ndim: usize, order: i32) -> Self {
        debug_assert_eq!(order, 3, "WendlandBasisMSQ3 is the order-3 basis");
        let ndim = ndim as i32;
        let msq_diff_int = Self::calc_msq_diff_int(ndim, order);
        Self {
            ndim,
            msq_diff_int,
            msq_diff_power: Self::calc_msq_diff_power(order, msq_diff_int),
            j_poly_r1: Self::calc_j_poly_r1(msq_diff_int),
            j_poly_r2: Self::calc_j_poly_r2(msq_diff_int),
            j_poly_r3: Self::calc_j_poly_r3(msq_diff_int),
        }
    }

    /// `15l + 45`.
    fn calc_j_poly_r1(msq_diff_int: i32) -> f64 {
        f64::from((15 * msq_diff_int) + 45)
    }

    /// `6l^2 + 36l + 45`.
    fn calc_j_poly_r2(msq_diff_int: i32) -> f64 {
        f64::from(6 * msq_diff_int.pow(2) + 36 * msq_diff_int + 45)
    }

    /// `l^3 + 9l^2 + 23l + 15`.
    fn calc_j_poly_r3(msq_diff_int: i32) -> f64 {
        f64::from(msq_diff_int.pow(3) + 9 * msq_diff_int.pow(2) + 23 * msq_diff_int + 15)
    }
}

impl Wendland for WendlandBasisMSQ3 {
    fn msq_order(&self) -> i32 {
        3
    }

    fn read_ndim(&self) -> i32 {
        self.ndim
    }

    fn eval_sgl(&self, r: f64, r2: f64) -> f64 {
        // Compact support: the basis is zero at and beyond unit distance.
        if r >= 1.0 {
            return 0.0;
        }
        let r3 = r2 * r;
        // Dividing by the constant term (15) normalises phi(0) to 1.
        (1. - r).max(0.0).powi(self.msq_diff_power)
            * ((self.j_poly_r3 * r3) + (self.j_poly_r2 * r2) + (self.j_poly_r1 * r) + 15.)
            / 15.
    }
}
/// Wendland basis whose smoothness order (0 through 3) is chosen at runtime.
///
/// Every [`Wendland`] method is forwarded to the wrapped concrete basis.
pub enum WendlandKernel {
    /// Order-0 basis.
    Q0(WendlandBasisMSQ0),
    /// Order-1 basis.
    Q1(WendlandBasisMSQ1),
    /// Order-2 basis.
    Q2(WendlandBasisMSQ2),
    /// Order-3 basis.
    Q3(WendlandBasisMSQ3),
}

impl WendlandKernel {
    /// Builds the Wendland basis of order `msq_order` for `ndim` dimensions.
    ///
    /// # Panics
    ///
    /// Panics if `msq_order` is not in `0..=3`, the only implemented orders.
    pub fn new(ndim: usize, msq_order: i32) -> Self {
        match msq_order {
            0 => WendlandKernel::Q0(WendlandBasisMSQ0::new(ndim, msq_order)),
            1 => WendlandKernel::Q1(WendlandBasisMSQ1::new(ndim, msq_order)),
            2 => WendlandKernel::Q2(WendlandBasisMSQ2::new(ndim, msq_order)),
            3 => WendlandKernel::Q3(WendlandBasisMSQ3::new(ndim, msq_order)),
            // Reachable from caller input, so fail with a diagnostic rather
            // than `unreachable!()`.
            other => panic!(
                "unsupported Wendland msq_order {}: only orders 0 through 3 are implemented",
                other
            ),
        }
    }
}

impl Wendland for WendlandKernel {
    fn msq_order(&self) -> i32 {
        match self {
            WendlandKernel::Q0(k) => k.msq_order(),
            WendlandKernel::Q1(k) => k.msq_order(),
            WendlandKernel::Q2(k) => k.msq_order(),
            WendlandKernel::Q3(k) => k.msq_order(),
        }
    }

    fn read_ndim(&self) -> i32 {
        match self {
            WendlandKernel::Q0(k) => k.read_ndim(),
            WendlandKernel::Q1(k) => k.read_ndim(),
            WendlandKernel::Q2(k) => k.read_ndim(),
            WendlandKernel::Q3(k) => k.read_ndim(),
        }
    }

    fn eval_sgl(&self, r: f64, r2: f64) -> f64 {
        match self {
            WendlandKernel::Q0(k) => k.eval_sgl(r, r2),
            WendlandKernel::Q1(k) => k.eval_sgl(r, r2),
            WendlandKernel::Q2(k) => k.eval_sgl(r, r2),
            WendlandKernel::Q3(k) => k.eval_sgl(r, r2),
        }
    }
}
/// Sparse kernel built from a [`WendlandKernel`] basis over per-dimension
/// scaled coordinates.
///
/// Coordinates are divided component-wise by `scale` (see `scale_sgl` /
/// `scale_arr`). The basis vanishes at scaled distance >= 1, so only pairs
/// closer than that produce off-diagonal entries.
pub struct Kernel {
    basis: WendlandKernel,
    /// Per-dimension divisors applied by `scale_sgl` / `scale_arr`.
    scale: Array1<f64>,
    /// Added to every diagonal entry of the assembled matrix.
    whitenoise: f64,
}

impl Kernel {
    /// Constant added to every diagonal entry on top of `whitenoise` and the
    /// squared training error. Presumably a jitter that keeps the matrix
    /// well-conditioned when both are zero.
    const DIAG_JITTER: f64 = 1e-10;

    /// Creates a kernel using the Wendland basis of order `msq_order` in
    /// `ndim` dimensions. `scale_view` is copied into an owned, standard-layout
    /// array.
    ///
    /// # Panics
    ///
    /// Panics if `msq_order` is not supported by [`WendlandKernel::new`].
    pub fn new(
        ndim: usize,
        msq_order: i32,
        scale_view: ArrayView1<f64>,
        whitenoise: f64,
    ) -> Self {
        let basis = WendlandKernel::new(ndim, msq_order);
        let scale = scale_view.as_standard_layout().into_owned();
        Self {
            basis,
            scale,
            whitenoise,
        }
    }

    /// Evaluates the basis at distance `r`; `r2` is expected to be `r * r`.
    pub fn eval_sgl(&self, r: f64, r2: f64) -> f64 {
        self.basis.eval_sgl(r, r2)
    }

    /// Smoothness order of the underlying basis.
    pub fn msq_order(&self) -> i32 {
        self.basis.msq_order()
    }

    /// White-noise term added to the diagonal.
    pub fn whitenoise(&self) -> f64 {
        self.whitenoise
    }

    /// Returns an owned copy of the per-dimension scale.
    pub fn scale(&self) -> Array1<f64> {
        self.scale.clone()
    }

    /// Divides a single point component-wise by the scale.
    pub fn scale_sgl(&self, sample_point: ArrayView1<f64>) -> Array1<f64> {
        &sample_point / &self.scale
    }

    /// Divides every row of `sample_points` component-wise by the scale and
    /// returns the result as a new standard-layout array.
    pub fn scale_arr(&self, sample_points: ArrayView2<f64>) -> Array2<f64> {
        let mut scaled_points = sample_points.as_standard_layout().into_owned();
        scaled_points /= &self.scale;
        scaled_points
    }

    /// Assembles the symmetric kernel matrix by testing every pair of points
    /// (O(n^2) distance evaluations).
    ///
    /// - Diagonal entry `i` is `1 + (training_error[i]^2 + whitenoise + 1e-10)`.
    ///   The leading 1 equals the basis value at `r = 0`.
    /// - Each pair at scaled distance below 1 stores the basis value at both
    ///   `(i, j)` and `(j, i)`.
    ///
    /// # Panics
    ///
    /// Panics if `training_error.len()` differs from the number of rows of
    /// `scaled_training_points`.
    pub fn naive_kernel_construction(
        &self,
        scaled_training_points: ArrayView2<f64>,
        training_error: ArrayView1<f64>,
    ) -> TriMat<f64> {
        let nsample = scaled_training_points.nrows();
        assert_eq!(
            training_error.len(),
            nsample,
            "training_error must have one entry per training point"
        );
        let mut training_kernel = TriMat::<f64>::new((nsample, nsample));
        for i in 0..nsample {
            training_kernel.add_triplet(i, i, self.diag_value(training_error[i]));
            let xi = scaled_training_points.row(i);
            // Visit each unordered pair once (j < i) and mirror the entry.
            for j in 0..i {
                let xj = scaled_training_points.row(j);
                if let Some(val) = self.pair_value(&xi, &xj) {
                    training_kernel.add_triplet(i, j, val);
                    training_kernel.add_triplet(j, i, val);
                }
            }
        }
        training_kernel
    }

    /// Assembles the same matrix as `naive_kernel_construction`, but only
    /// tests the candidate neighbours that `training_hash` reports for each
    /// point.
    ///
    /// The result matches the naive build only under two assumptions (TODO
    /// confirm against `SpatialHash`):
    /// - `training_hash` was built over these same scaled points.
    /// - It reports every point within scaled distance 1 of the query, each
    ///   index at most once.
    ///
    /// # Panics
    ///
    /// Panics if `training_error.len()` differs from the number of rows of
    /// `scaled_training_points`.
    pub fn hashed_kernel_construction(
        &self,
        scaled_training_points: ArrayView2<f64>,
        training_error: ArrayView1<f64>,
        training_hash: &SpatialHash,
    ) -> TriMat<f64> {
        let nsample = scaled_training_points.nrows();
        assert_eq!(
            training_error.len(),
            nsample,
            "training_error must have one entry per training point"
        );
        let mut training_kernel = TriMat::<f64>::new((nsample, nsample));
        for i in 0..nsample {
            training_kernel.add_triplet(i, i, self.diag_value(training_error[i]));
            let xi = scaled_training_points.row(i);
            // The hash query needs a contiguous slice. Rows of a non-standard
            // layout view are strided, so copy them rather than panicking.
            let xi_buf: Vec<f64>;
            let xi_coords: &[f64] = match xi.as_slice() {
                Some(coords) => coords,
                None => {
                    xi_buf = xi.to_vec();
                    &xi_buf
                }
            };
            training_hash.for_each_neighbor(xi_coords, |j| {
                // Handle each unordered pair from its lower index only; the
                // diagonal was already added above.
                if j <= i {
                    return;
                }
                let xj = scaled_training_points.row(j);
                if let Some(val) = self.pair_value(&xi, &xj) {
                    training_kernel.add_triplet(i, j, val);
                    training_kernel.add_triplet(j, i, val);
                }
            });
        }
        training_kernel
    }

    /// Diagonal entry for a point with the given training error.
    fn diag_value(&self, error: f64) -> f64 {
        // Grouping preserves the original summation order bit-for-bit.
        1.0 + (error.powi(2) + self.whitenoise + Self::DIAG_JITTER)
    }

    /// Basis value for the pair `(xi, xj)`. Returns `None` when the points are
    /// at scaled distance >= 1, outside the compact support, so no entry is
    /// stored.
    fn pair_value(&self, xi: &ArrayView1<f64>, xj: &ArrayView1<f64>) -> Option<f64> {
        let r2: f64 = xi
            .iter()
            .zip(xj.iter())
            .map(|(ki, kj)| (ki - kj).powi(2))
            .sum();
        if r2 >= 1.0 {
            return None;
        }
        Some(self.eval_sgl(r2.sqrt(), r2))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spatial_hash::SpatialHash;
    use ndarray::{arr1, Array1, Array2};
    use sprs::TriMat;
    use std::time::Instant;

    /// Regular 2-D grid of `nx * ny` points with spacing `h`. Returned both as
    /// nested rows (the form passed to `SpatialHash::build`) and as an
    /// `(nx * ny, 2)` array.
    fn grid_points(nx: usize, ny: usize, h: f64) -> (Vec<Vec<f64>>, Array2<f64>) {
        let rows: Vec<Vec<f64>> = (0..nx)
            .flat_map(|i| (0..ny).map(move |j| vec![i as f64 * h, j as f64 * h]))
            .collect();
        let flat: Vec<f64> = rows.iter().flatten().copied().collect();
        let points = Array2::from_shape_vec((rows.len(), 2), flat).unwrap();
        (rows, points)
    }

    /// `m`'s triplets sorted by `(row, col)`, so two assemblies can be
    /// compared independently of insertion order.
    fn sorted_triplets(m: &TriMat<f64>) -> Vec<(usize, usize, f64)> {
        let mut triplets: Vec<(usize, usize, f64)> = m
            .row_inds()
            .iter()
            .zip(m.col_inds())
            .zip(m.data())
            .map(|((&row, &col), &val)| (row, col, val))
            .collect();
        triplets.sort_by_key(|&(row, col, _)| (row, col));
        triplets
    }

    /// Order-2 kernel with unit scale and negligible white noise.
    fn unit_scale_kernel(ndim: usize) -> Kernel {
        let scale = Array1::from_elem(ndim, 1.0);
        Kernel::new(ndim, 2, scale.view(), 1e-10)
    }

    /// Fast regression check: the hashed build must produce exactly the same
    /// entries (positions and values) as the naive build.
    #[test]
    fn hashed_matches_naive_on_small_grid() {
        let (rows, x_train) = grid_points(20, 20, 0.2);
        let y_err = arr1(&vec![0.0; rows.len()]);
        let kernel = unit_scale_kernel(2);
        let hash = SpatialHash::build(&rows, 1.0);

        let naive =
            sorted_triplets(&kernel.naive_kernel_construction(x_train.view(), y_err.view()));
        let hashed = sorted_triplets(&kernel.hashed_kernel_construction(
            x_train.view(),
            y_err.view(),
            &hash,
        ));

        assert_eq!(
            naive.len(),
            hashed.len(),
            "hashed kernel missed or added interactions"
        );
        for (a, b) in naive.iter().zip(&hashed) {
            assert_eq!((a.0, a.1), (b.0, b.1), "entry positions differ");
            assert!(
                (a.2 - b.2).abs() <= 1e-12,
                "value mismatch at ({}, {})",
                a.0,
                a.1
            );
        }
    }

    /// Benchmark: naive vs hashed assembly on a 150 x 150 grid.
    ///
    /// The naive build performs about 2.5e8 pair distance evaluations, so this
    /// is ignored by default. Run with
    /// `cargo test --release -- --ignored timing_kernel_build_naive_vs_hashed --nocapture`.
    #[test]
    #[ignore = "slow benchmark; run explicitly with --release -- --ignored"]
    fn timing_kernel_build_naive_vs_hashed() {
        let (x_train_vec, x_train) = grid_points(150, 150, 0.2);
        let nsample = x_train_vec.len();
        println!("points: {}", nsample);
        let y_err = arr1(&vec![0.0; nsample]);
        let mykernel = unit_scale_kernel(2);
        let hash = SpatialHash::build(&x_train_vec, 1.0);

        let t0 = Instant::now();
        let kn = mykernel.naive_kernel_construction(x_train.view(), y_err.view());
        let t_naive = t0.elapsed();

        let t1 = Instant::now();
        let kh = mykernel.hashed_kernel_construction(x_train.view(), y_err.view(), &hash);
        let t_hash = t1.elapsed();

        let nnz_naive = kn.nnz();
        let nnz_hash = kh.nnz();
        println!("naive build : {:?}", t_naive);
        println!("hashed build: {:?}", t_hash);
        println!("speedup : {:.2}x", t_naive.as_secs_f64() / t_hash.as_secs_f64());
        println!("nnz naive : {}", nnz_naive);
        println!("nnz hash : {}", nnz_hash);
        assert_eq!(
            nnz_naive, nnz_hash,
            "hashed kernel missed or added interactions"
        );
    }
}