use super::utils::extend_to_full_domain;
use super::{TworkType, compute_sve};
use crate::kernel::{
CentrosymmKernel, KernelProperties, LogisticKernel, RegularizedBoseKernel, SymmetryType,
};
use crate::poly::PiecewiseLegendrePoly;
use mdarray::DTensor;
/// Builds a one-segment piecewise polynomial on `[0, 1]`.
///
/// The coefficient tensor has shape `[2, 1]` with entries `1.0` (row 0) and `2.0` (row 1).
/// NOTE(review): presumably rows index Legendre coefficients and columns index segments;
/// confirm against `PiecewiseLegendrePoly::new`.
fn create_simple_poly_on_positive_domain() -> PiecewiseLegendrePoly {
    let coeffs = DTensor::<f64, 2>::from_fn([2, 1], |idx| match idx[0] {
        0 => 1.0,
        _ => 2.0,
    });
    PiecewiseLegendrePoly::new(coeffs, vec![0.0, 1.0], 0, Some(vec![1.0]), 0)
}
/// Builds a two-segment piecewise polynomial on `[0, 1]` with a breakpoint at `0.5`.
///
/// The `[2, 2]` coefficient tensor is filled row-major from `[1.0, 1.5, 0.5, 1.0]`.
/// Each segment has width `0.5`.
fn create_poly_with_segments() -> PiecewiseLegendrePoly {
    const ROW_MAJOR: [f64; 4] = [1.0, 1.5, 0.5, 1.0];
    let coeffs = DTensor::<f64, 2>::from_fn([2, 2], |idx| ROW_MAJOR[2 * idx[0] + idx[1]]);
    let breakpoints = vec![0.0, 0.5, 1.0];
    let widths = vec![0.5, 0.5];
    PiecewiseLegendrePoly::new(coeffs, breakpoints, 0, Some(widths), 0)
}
/// An even extension must satisfy `f(-x) == f(x)` at interior sample points.
#[test]
fn test_extend_even_symmetry() {
    let extended = extend_to_full_domain(
        vec![create_simple_poly_on_positive_domain()],
        SymmetryType::Even,
        1.0,
    );
    let f = &extended[0];
    let samples = [0.1, 0.3, 0.5, 0.7, 0.9];
    samples.iter().for_each(|&x| {
        let (right, left) = (f.evaluate(x), f.evaluate(-x));
        assert!(
            (right - left).abs() < 1e-14,
            "Even symmetry violated: f({}) = {}, f({}) = {}",
            x,
            right,
            -x,
            left
        );
    });
}
/// An odd extension must satisfy `f(-x) == -f(x)` at interior sample points.
#[test]
fn test_extend_odd_symmetry() {
    const SAMPLES: [f64; 5] = [0.1, 0.3, 0.5, 0.7, 0.9];
    let source = create_simple_poly_on_positive_domain();
    let extended = extend_to_full_domain(vec![source], SymmetryType::Odd, 1.0);
    let f = &extended[0];
    for x in SAMPLES {
        let at_pos = f.evaluate(x);
        let at_neg = f.evaluate(-x);
        // For an odd function the two values cancel.
        let residual = (at_pos + at_neg).abs();
        assert!(
            residual < 1e-14,
            "Odd symmetry violated: f({}) = {}, f({}) = {}",
            x,
            at_pos,
            -x,
            at_neg
        );
    }
}
/// On `x >= 0`, the even extension must equal the original polynomial scaled by `1/sqrt(2)`.
#[test]
fn test_positive_domain_preserved() {
    let source = create_simple_poly_on_positive_domain();
    let grid: Vec<f64> = (0..10).map(|i| i as f64 * 0.1).collect();
    // Sample the original before it is moved into the extension routine.
    let reference: Vec<f64> = grid.iter().map(|&x| source.evaluate(x)).collect();

    let extended = extend_to_full_domain(vec![source], SymmetryType::Even, 1.0);
    let f = &extended[0];
    // The test expects the extension to rescale the original by 1/sqrt(2).
    let scale = 1.0 / 2.0_f64.sqrt();

    for (&x, &orig) in grid.iter().zip(reference.iter()) {
        let actual = f.evaluate(x);
        let expected_normalized = orig * scale;
        assert!(
            (actual - expected_normalized).abs() < 1e-14,
            "Positive domain not preserved: f({}) = {} (expected {})",
            x,
            actual,
            expected_normalized
        );
    }
}
/// Extending a polynomial with knots `[0, 0.5, 1]` to the full domain must produce the
/// mirrored knot vector `[-1, -0.5, 0, 0.5, 1]`, with no extra or missing knots.
#[test]
fn test_segment_structure() {
    let poly = create_poly_with_segments();
    let polys_full = extend_to_full_domain(vec![poly], SymmetryType::Even, 1.0);
    let knots = &polys_full[0].knots;
    let expected_knots = [-1.0, -0.5, 0.0, 0.5, 1.0];

    // Check the count first. Otherwise extra knots pass silently, and missing knots
    // surface as an index-out-of-bounds panic instead of a descriptive failure.
    assert_eq!(
        knots.len(),
        expected_knots.len(),
        "Knot count mismatch: got {:?}, expected {:?}",
        knots,
        expected_knots
    );

    for (i, (&got, &expected)) in knots.iter().zip(expected_knots.iter()).enumerate() {
        assert!(
            (got - expected).abs() < 1e-14,
            "Knot {} mismatch: got {}, expected {}",
            i,
            got,
            expected
        );
    }
}
/// Extends two different positive-domain polynomials in one call.
///
/// Checks that the call returns one polynomial per input and that every result is even
/// (`f(-x) == f(x)`) at several interior sample points.
#[test]
fn test_multiple_polynomials() {
    let poly1 = create_simple_poly_on_positive_domain();
    let poly2 = create_poly_with_segments();
    let polys_full = extend_to_full_domain(vec![poly1, poly2], SymmetryType::Even, 1.0);
    assert_eq!(polys_full.len(), 2);

    // Sample points avoid 0.5, which is a knot of the segmented polynomial. Evaluation
    // exactly at a breakpoint may pick a different segment on each side.
    for (i, poly) in polys_full.iter().enumerate() {
        for x in [0.1, 0.3, 0.7, 0.9] {
            let val_pos = poly.evaluate(x);
            let val_neg = poly.evaluate(-x);
            assert!(
                (val_pos - val_neg).abs() < 1e-14,
                "Even symmetry violated for polynomial {}: f({}) = {}, f({}) = {}",
                i,
                x,
                val_pos,
                -x,
                val_neg
            );
        }
    }
}
fn test_sve_decomposition_kernel_impl<K>(kernel: K, _lambda: f64, epsilon: f64)
where
K: CentrosymmKernel + KernelProperties + Clone + 'static,
{
let sve_result = compute_sve(kernel.clone(), epsilon, None, None, TworkType::Auto);
let u_polys = sve_result.u.get_polys();
let v_polys = sve_result.v.get_polys();
let x_knots = &u_polys[0].knots;
let y_knots = &v_polys[0].knots;
let test_x: Vec<f64> = x_knots.windows(2).map(|w| 0.5 * (w[0] + w[1])).collect();
let test_y: Vec<f64> = y_knots.windows(2).map(|w| 0.5 * (w[0] + w[1])).collect();
let tolerance = epsilon * 200.0;
let mut max_error = 0.0f64;
let mut worst_x = 0.0f64;
let mut worst_y = 0.0f64;
let mut worst_direct = 0.0f64;
let mut worst_sve = 0.0f64;
let mut max_abs_direct = 0.0f64;
for &x in &test_x {
for &y in &test_y {
let k_direct = kernel.compute(x, y);
let mut k_sve = 0.0;
for l in 0..sve_result.s.len() {
let u_l_x = sve_result.u[l].evaluate(x);
let v_l_y = sve_result.v[l].evaluate(y);
k_sve += sve_result.s[l] * u_l_x * v_l_y;
}
let error = (k_direct - k_sve).abs();
if error > max_error {
max_error = error;
worst_x = x;
worst_y = y;
worst_direct = k_direct;
worst_sve = k_sve;
}
let abs_direct = k_direct.abs();
if abs_direct > max_abs_direct {
max_abs_direct = abs_direct;
}
}
}
eprintln!(
"Max SVE abs error: error={:.15e}, x={:.15e}, y={:.15e}, direct={:.15e}, sve={:.15e}, abs_tol={:.15e}, max|K|={:.15e}",
max_error, worst_x, worst_y, worst_direct, worst_sve, tolerance, max_abs_direct
);
assert!(
max_error < tolerance,
"SVE decomposition failed: max_error={:.15e} at x={}, y={}, direct={:.15e}, sve={:.15e}, abs_tol={:.15e}",
max_error,
worst_x,
worst_y,
worst_direct,
worst_sve,
tolerance
);
}
/// Checks that the expansion `sum_l s[l] * u[l](x) * v[l](y)` returned by `compute_sve`
/// reproduces `kernel.compute(x, y)` to within the absolute `tolerance`.
///
/// The kernel is sampled on the grid formed by the midpoints of consecutive knots of the
/// first `u` and first `v` polynomials. The worst sample is printed to stderr together
/// with the largest `|K|` seen, and the test panics if its error is not below `tolerance`.
/// `_lambda` is not used.
fn test_sve_decomposition_kernel_impl_with_tolerance<K>(
    kernel: K,
    _lambda: f64,
    epsilon: f64,
    tolerance: f64,
) where
    K: CentrosymmKernel + KernelProperties + Clone + 'static,
{
    let sve = compute_sve(kernel.clone(), epsilon, None, None, TworkType::Auto);

    // Sample strictly inside each knot interval rather than on the breakpoints.
    let u_polys = sve.u.get_polys();
    let v_polys = sve.v.get_polys();
    let test_x: Vec<f64> = u_polys[0]
        .knots
        .windows(2)
        .map(|pair| 0.5 * (pair[0] + pair[1]))
        .collect();
    let test_y: Vec<f64> = v_polys[0]
        .knots
        .windows(2)
        .map(|pair| 0.5 * (pair[0] + pair[1]))
        .collect();

    // (error, x, y, direct, sve) of the worst sample found so far.
    let mut worst = (0.0f64, 0.0f64, 0.0f64, 0.0f64, 0.0f64);
    let mut max_abs_direct = 0.0f64;

    for &x in &test_x {
        for &y in &test_y {
            let direct = kernel.compute(x, y);
            let expanded = (0..sve.s.len()).fold(0.0, |acc, l| {
                acc + sve.s[l] * sve.u[l].evaluate(x) * sve.v[l].evaluate(y)
            });
            let error = (direct - expanded).abs();
            if error > worst.0 {
                worst = (error, x, y, direct, expanded);
            }
            max_abs_direct = max_abs_direct.max(direct.abs());
        }
    }

    let (max_error, worst_x, worst_y, worst_direct, worst_sve) = worst;
    eprintln!(
        "Max SVE abs error: error={:.15e}, x={:.15e}, y={:.15e}, direct={:.15e}, sve={:.15e}, abs_tol={:.15e}, max|K|={:.15e}",
        max_error, worst_x, worst_y, worst_direct, worst_sve, tolerance, max_abs_direct
    );
    assert!(
        max_error < tolerance,
        "SVE decomposition failed: max_error={:.15e} at x={}, y={}, direct={:.15e}, sve={:.15e}, abs_tol={:.15e}",
        max_error,
        worst_x,
        worst_y,
        worst_direct,
        worst_sve,
        tolerance
    );
}
/// Checks the SVE of the logistic kernel for a moderate and a large cutoff.
/// The tighter accuracy is requested at the larger cutoff.
#[test]
fn test_sve_decomposition_logistic_kernel() {
    let cases = [(100.0, 1e-6), (1e5, 1e-12)];
    for (lambda, epsilon) in cases {
        test_sve_decomposition_kernel_impl(LogisticKernel::new(lambda), lambda, epsilon);
    }
}
/// Checks the SVE of the regularized Bose kernel at `lambda = 100`.
/// The absolute tolerance is set equal to the requested accuracy `epsilon`.
#[test]
fn test_sve_decomposition_regularized_bose_kernel() {
    let lambda = 100.0;
    let epsilon = 1e-6;
    let tolerance = epsilon;
    test_sve_decomposition_kernel_impl_with_tolerance(
        RegularizedBoseKernel::new(lambda),
        lambda,
        epsilon,
        tolerance,
    );
}