use crate::gauss::Rule;
use crate::interpolation2d::Interpolate2D;
use crate::kernel::{AbstractKernel, CentrosymmKernel, KernelProperties, SymmetryType};
use crate::numeric::CustomNumeric;
use mdarray::DTensor;
use std::fmt::Debug;
/// Kernel values sampled on a two-dimensional quadrature grid, bundled with the
/// quadrature rules and the segment breakpoints that accompany that grid.
#[derive(Debug, Clone)]
pub struct DiscretizedKernel<T> {
/// Sampled kernel values. The `matrix_from_gauss*` functions in this module fill
/// entry `[i, j]` from node `gauss_x.x[i]` and node `gauss_y.x[j]`.
pub matrix: DTensor<T, 2>,
/// Quadrature rule along the first (row) axis.
pub gauss_x: Rule<T>,
/// Quadrature rule along the second (column) axis.
pub gauss_y: Rule<T>,
/// Breakpoints along x: either supplied by the caller (e.g. from SVE hints) or, when
/// built through `new_legacy`, the two endpoints `[gauss_x.a, gauss_x.b]`.
pub segments_x: Vec<T>,
/// Breakpoints along y: either supplied by the caller (e.g. from SVE hints) or, when
/// built through `new_legacy`, the two endpoints `[gauss_y.a, gauss_y.b]`.
pub segments_y: Vec<T>,
}
impl<T: CustomNumeric + Clone> DiscretizedKernel<T> {
    /// Bundles a sampled kernel matrix with its quadrature rules and segment breakpoints.
    ///
    /// No consistency checks are performed; the arguments are stored as given.
    pub fn new(
        matrix: DTensor<T, 2>,
        gauss_x: Rule<T>,
        gauss_y: Rule<T>,
        segments_x: Vec<T>,
        segments_y: Vec<T>,
    ) -> Self {
        Self {
            matrix,
            gauss_x,
            gauss_y,
            segments_x,
            segments_y,
        }
    }

    /// Builds a discretized kernel without explicit breakpoints.
    ///
    /// Each axis gets exactly two breakpoints, the endpoints `a` and `b` of the
    /// corresponding quadrature rule.
    pub fn new_legacy(matrix: DTensor<T, 2>, gauss_x: Rule<T>, gauss_y: Rule<T>) -> Self {
        // Read the endpoints first so the rules can be moved into the struct
        // instead of being deep-cloned.
        let segments_x = vec![gauss_x.a, gauss_x.b];
        let segments_y = vec![gauss_y.a, gauss_y.b];
        Self {
            matrix,
            gauss_x,
            gauss_y,
            segments_x,
            segments_y,
        }
    }

    /// Returns `true` when the underlying matrix reports itself as empty.
    pub fn is_empty(&self) -> bool {
        self.matrix.is_empty()
    }

    /// Number of rows of the matrix (first component of its shape).
    pub fn nrows(&self) -> usize {
        self.matrix.shape().0
    }

    /// Number of columns of the matrix (second component of its shape).
    pub fn ncols(&self) -> usize {
        self.matrix.shape().1
    }

    /// Iterates over all matrix entries in the order of the underlying tensor iterator.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        self.matrix.iter()
    }

    /// Returns a copy of the matrix with the square roots of the quadrature weights applied.
    ///
    /// Row `i` is multiplied by `sqrt(gauss_x.w[i])`, then column `j` is multiplied by
    /// `sqrt(gauss_y.w[j])`. The stored matrix is left untouched.
    ///
    /// The matrix is expected to have at least one row per `gauss_x` node and one column
    /// per `gauss_y` node; indexing goes out of bounds otherwise.
    pub fn apply_weights_for_sve(&self) -> DTensor<T, 2> {
        let mut weighted_matrix = self.matrix.clone();
        Self::scale_by_sqrt_weights(
            &mut weighted_matrix,
            &self.gauss_x,
            &self.gauss_y,
            |value, weight_sqrt| value * weight_sqrt,
        );
        weighted_matrix
    }

    /// Divides the stored matrix in place by the square roots of the quadrature weights.
    ///
    /// This is the element-wise inverse of the scaling done by
    /// [`apply_weights_for_sve`](Self::apply_weights_for_sve): row `i` is divided by
    /// `sqrt(gauss_x.w[i])`, then column `j` is divided by `sqrt(gauss_y.w[j])`.
    pub fn remove_weights_from_sve(&mut self) {
        Self::scale_by_sqrt_weights(
            &mut self.matrix,
            &self.gauss_x,
            &self.gauss_y,
            |value, weight_sqrt| value / weight_sqrt,
        );
    }

    /// Number of quadrature nodes along x.
    pub fn n_gauss_x(&self) -> usize {
        self.gauss_x.x.len()
    }

    /// Number of quadrature nodes along y.
    pub fn n_gauss_y(&self) -> usize {
        self.gauss_y.x.len()
    }

    /// Combines every matrix entry with the square-rooted weight of its row, then of its
    /// column, using `combine(entry, weight_sqrt)`.
    ///
    /// Shared by the weight-applying and weight-removing methods so both traverse the
    /// matrix in exactly the same way.
    fn scale_by_sqrt_weights(
        matrix: &mut DTensor<T, 2>,
        gauss_x: &Rule<T>,
        gauss_y: &Rule<T>,
        combine: impl Fn(T, T) -> T,
    ) {
        let shape = *matrix.shape();
        // Row pass: one square root per x node, reused across the whole row.
        for i in 0..gauss_x.x.len() {
            let weight_sqrt = gauss_x.w[i].sqrt();
            for j in 0..shape.1 {
                matrix[[i, j]] = combine(matrix[[i, j]], weight_sqrt);
            }
        }
        // Column pass: one square root per y node, reused across the whole column.
        for j in 0..gauss_y.x.len() {
            let weight_sqrt = gauss_y.w[j].sqrt();
            for i in 0..shape.0 {
                matrix[[i, j]] = combine(matrix[[i, j]], weight_sqrt);
            }
        }
    }
}
/// Samples the reduced kernel on the grid of quadrature nodes and attaches the
/// breakpoints reported by `hints`.
///
/// Entry `[i, j]` of the resulting matrix is
/// `kernel.compute_reduced(gauss_x.x[i], gauss_y.x[j], symmetry)`. Both rules are
/// cloned into the returned [`DiscretizedKernel`]; no range validation is done here.
pub fn matrix_from_gauss_with_segments<
    T: CustomNumeric + Clone + Send + Sync,
    K: CentrosymmKernel + KernelProperties,
    H: crate::kernel::SVEHints<T>,
>(
    kernel: &K,
    gauss_x: &Rule<T>,
    gauss_y: &Rule<T>,
    symmetry: SymmetryType,
    hints: &H,
) -> DiscretizedKernel<T> {
    let breakpoints_x = hints.segments_x();
    let breakpoints_y = hints.segments_y();

    // One kernel evaluation per (x node, y node) pair.
    let grid_shape = [gauss_x.x.len(), gauss_y.x.len()];
    let samples = DTensor::<T, 2>::from_fn(grid_shape, |idx| {
        kernel.compute_reduced(gauss_x.x[idx[0]], gauss_y.x[idx[1]], symmetry)
    });

    DiscretizedKernel::new(
        samples,
        gauss_x.clone(),
        gauss_y.clone(),
        breakpoints_x,
        breakpoints_y,
    )
}
/// Samples the reduced kernel on the grid of quadrature nodes, without explicit breakpoints.
///
/// Entry `[i, j]` of the resulting matrix is
/// `kernel.compute_reduced(gauss_x.x[i], gauss_y.x[j], symmetry)`. The result is built
/// with `DiscretizedKernel::new_legacy`, so each axis carries only the two endpoints of
/// its rule as breakpoints.
///
/// # Panics
///
/// Panics if any x node (converted to `f64`) lies outside
/// `[-1e-12, kernel.xmax() + 1e-12]`, or any y node lies outside
/// `[-1e-12, kernel.ymax() + 1e-12]`.
pub fn matrix_from_gauss<T: CustomNumeric + Clone, K: CentrosymmKernel + KernelProperties>(
    kernel: &K,
    gauss_x: &Rule<T>,
    gauss_y: &Rule<T>,
    symmetry: SymmetryType,
) -> DiscretizedKernel<T> {
    // Slack allowed on both ends of the admissible node range.
    const TOLERANCE: f64 = 1e-12;

    let x_limit = kernel.xmax();
    let y_limit = kernel.ymax();

    for &node in &gauss_x.x {
        let x_f64 = node.to_f64();
        assert!(
            (-TOLERANCE..=x_limit + TOLERANCE).contains(&x_f64),
            "Gauss x point {} is outside [0, {}]",
            x_f64,
            x_limit
        );
    }
    for &node in &gauss_y.x {
        let y_f64 = node.to_f64();
        assert!(
            (-TOLERANCE..=y_limit + TOLERANCE).contains(&y_f64),
            "Gauss y point {} is outside [0, {}]",
            y_f64,
            y_limit
        );
    }

    // One kernel evaluation per (x node, y node) pair.
    let grid_shape = [gauss_x.x.len(), gauss_y.x.len()];
    let samples = DTensor::<T, 2>::from_fn(grid_shape, |idx| {
        kernel.compute_reduced(gauss_x.x[idx[0]], gauss_y.x[idx[1]], symmetry)
    });

    DiscretizedKernel::new_legacy(samples, gauss_x.clone(), gauss_y.clone())
}
/// Samples the full (non-reduced) kernel on the grid of quadrature nodes and attaches
/// the breakpoints reported by `hints`.
///
/// Entry `[i, j]` of the resulting matrix is `kernel.compute(gauss_x.x[i], gauss_y.x[j])`.
/// Both rules are cloned into the returned [`DiscretizedKernel`].
pub fn matrix_from_gauss_noncentrosymmetric<
    T: CustomNumeric + Clone + Send + Sync,
    K: AbstractKernel + KernelProperties,
    H: crate::kernel::SVEHints<T>,
>(
    kernel: &K,
    gauss_x: &Rule<T>,
    gauss_y: &Rule<T>,
    hints: &H,
) -> DiscretizedKernel<T> {
    let breakpoints_x = hints.segments_x();
    let breakpoints_y = hints.segments_y();

    // One kernel evaluation per (x node, y node) pair.
    let grid_shape = [gauss_x.x.len(), gauss_y.x.len()];
    let samples = DTensor::<T, 2>::from_fn(grid_shape, |idx| {
        kernel.compute(gauss_x.x[idx[0]], gauss_y.x[idx[1]])
    });

    DiscretizedKernel::new(
        samples,
        gauss_x.clone(),
        gauss_y.clone(),
        breakpoints_x,
        breakpoints_y,
    )
}
/// Piecewise two-dimensional interpolant of a kernel: the domain is split into
/// rectangular cells by breakpoints along each axis, and every cell owns its own
/// `Interpolate2D`.
#[derive(Debug, Clone)]
pub struct InterpolatedKernel<T> {
/// Breakpoints along x; consecutive pairs delimit the cells in that direction.
pub segments_x: Vec<T>,
/// Breakpoints along y; consecutive pairs delimit the cells in that direction.
pub segments_y: Vec<T>,
/// First and last x breakpoint, as set by `from_kernel_and_segments`.
pub domain_x: (T, T),
/// First and last y breakpoint, as set by `from_kernel_and_segments`.
pub domain_y: (T, T),
/// One interpolator per cell, indexed as `[x cell, y cell]`.
pub interpolators: DTensor<Interpolate2D<T>, 2>,
/// Number of cells along x (`segments_x.len() - 1` when built by `from_kernel_and_segments`).
pub n_cells_x: usize,
/// Number of cells along y (`segments_y.len() - 1` when built by `from_kernel_and_segments`).
pub n_cells_y: usize,
}
impl<T: CustomNumeric + Debug + Clone + 'static> InterpolatedKernel<T> {
    /// Builds one interpolator per cell of the grid spanned by `segments_x` and `segments_y`.
    ///
    /// For cell `(i, j)` the reduced kernel (`kernel.compute_reduced(.., symmetry)`) is
    /// sampled on a `gauss_per_cell × gauss_per_cell` grid whose nodes come from
    /// `crate::gauss::legendre_generic(gauss_per_cell)` reseated onto
    /// `[segments_x[i], segments_x[i + 1]]` and `[segments_y[j], segments_y[j + 1]]`.
    /// The samples and both rules are handed to `Interpolate2D::new`.
    ///
    /// A breakpoint list with a single entry yields zero cells along that axis.
    ///
    /// # Panics
    ///
    /// Panics if `segments_x` or `segments_y` is empty.
    pub fn from_kernel_and_segments<K: CentrosymmKernel + KernelProperties>(
        kernel: &K,
        segments_x: Vec<T>,
        segments_y: Vec<T>,
        gauss_per_cell: usize,
        symmetry: SymmetryType,
    ) -> Self {
        // Fail loudly instead of underflowing `len() - 1` below.
        assert!(
            !segments_x.is_empty() && !segments_y.is_empty(),
            "segments_x and segments_y must each contain at least one breakpoint"
        );
        let n_cells_x = segments_x.len() - 1;
        let n_cells_y = segments_y.len() - 1;

        // A cell's rule depends on one axis only, so build each rule once per axis
        // cell rather than once per (x cell, y cell) pair.
        let rules_x: Vec<_> = segments_x
            .windows(2)
            .map(|ends| crate::gauss::legendre_generic::<T>(gauss_per_cell).reseat(ends[0], ends[1]))
            .collect();
        let rules_y: Vec<_> = segments_y
            .windows(2)
            .map(|ends| crate::gauss::legendre_generic::<T>(gauss_per_cell).reseat(ends[0], ends[1]))
            .collect();

        // Construct every interpolator directly in its final slot.
        let interpolators =
            DTensor::<Interpolate2D<T>, 2>::from_fn([n_cells_x, n_cells_y], |cell| {
                let rule_x = &rules_x[cell[0]];
                let rule_y = &rules_y[cell[1]];
                let cell_values =
                    DTensor::<T, 2>::from_fn([gauss_per_cell, gauss_per_cell], |node| {
                        kernel.compute_reduced(rule_x.x[node[0]], rule_y.x[node[1]], symmetry)
                    });
                Interpolate2D::new(&cell_values, rule_x, rule_y)
            });

        // Read the domain endpoints before the breakpoint vectors are moved into `Self`.
        let domain_x = (segments_x[0], segments_x[n_cells_x]);
        let domain_y = (segments_y[0], segments_y[n_cells_y]);

        Self {
            segments_x,
            segments_y,
            domain_x,
            domain_y,
            interpolators,
            n_cells_x,
            n_cells_y,
        }
    }

    /// Locates the cell containing `(x, y)`.
    ///
    /// Returns `Some((i, j))` with `segments_x[i] <= x < segments_x[i + 1]` and likewise
    /// for `y`; a coordinate equal to the last breakpoint is assigned to the last cell.
    /// Returns `None` when either coordinate lies outside the breakpoint range or an
    /// axis has fewer than two breakpoints.
    pub fn find_cell(&self, x: T, y: T) -> Option<(usize, usize)> {
        let i = self.binary_search_segments(&self.segments_x, x)?;
        let j = self.binary_search_segments(&self.segments_y, y)?;
        Some((i, j))
    }

    /// Bisects `segments` for the index `k` with `segments[k] <= value < segments[k + 1]`.
    ///
    /// The search assumes the breakpoints are sorted in ascending order.
    fn binary_search_segments(&self, segments: &[T], value: T) -> Option<usize> {
        // Fewer than two breakpoints define no cell; this also keeps the index
        // arithmetic below from underflowing.
        if segments.len() < 2 {
            return None;
        }
        let last = segments.len() - 1;
        if value < segments[0] || value > segments[last] {
            return None;
        }

        // Invariant: a cell containing `value`, if any, has an index in `left..right`.
        let mut left = 0;
        let mut right = last;
        while left < right {
            let mid = (left + right) / 2;
            if segments[mid] <= value && value < segments[mid + 1] {
                return Some(mid);
            } else if value < segments[mid] {
                right = mid;
            } else {
                left = mid + 1;
            }
        }

        // Only the closed upper boundary is not covered by a half-open cell.
        if value == segments[last] {
            Some(last - 1)
        } else {
            None
        }
    }

    /// Evaluates the interpolant at `(x, y)` using the interpolator of the enclosing cell.
    ///
    /// # Panics
    ///
    /// Panics with "Point is outside interpolation domain" when
    /// [`find_cell`](Self::find_cell) finds no cell for the point.
    pub fn evaluate(&self, x: T, y: T) -> T {
        let (i, j) = self
            .find_cell(x, y)
            .expect("Point is outside interpolation domain");
        self.interpolators[[i, j]].evaluate(x, y)
    }

    /// Returns `(domain_x, domain_y)`.
    pub fn domain(&self) -> ((T, T), (T, T)) {
        (self.domain_x, self.domain_y)
    }

    /// Number of cells along x.
    pub fn n_cells_x(&self) -> usize {
        self.n_cells_x
    }

    /// Number of cells along y.
    pub fn n_cells_y(&self) -> usize {
        self.n_cells_y
    }
}
// Unit tests for this module are kept in `kernelmatrix_tests.rs` and are only
// compiled for test builds.
#[cfg(test)]
#[path = "kernelmatrix_tests.rs"]
mod tests;