use crate::DType;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::integrate::traits::{QuadResult, TanhSinhOptions};
use std::f64::consts::PI;
/// Integrates `f` over `[a, b]` with adaptive tanh-sinh (double-exponential) quadrature.
///
/// The interval is mapped onto the reference interval `[-1, 1]` via
/// `x = mid + half_width * ξ`, and the trapezoidal rule is applied in the transformed
/// variable `t`, where `ξ = tanh(π/2 · sinh t)`. Level 0 uses step `h = 1`. Each further
/// level halves `h` and evaluates `f` only at the newly added nodes, reusing the previous
/// estimate through `I_new = I_old / 2 + h · half_width · Σ w·f`. Nodes cluster towards
/// the endpoints without ever reaching them, which lets the rule cope with integrable
/// endpoint singularities such as `1/√x` or `ln x`.
///
/// `f` is called once per level with a 1-D tensor of abscissae. It is expected to return
/// one value per abscissa, in the same order. NOTE(review): only the first
/// `min(#nodes, #values)` pairs are used; a length mismatch is not reported.
/// Non-finite values returned by `f` are skipped, i.e. treated as contributing zero.
///
/// Iteration stops as soon as two consecutive level estimates differ by less than
/// `options.atol + options.rtol * |I|`, or after `options.max_levels` levels.
///
/// # Returns
/// A [`QuadResult`] with these fields:
/// - `integral`: a one-element tensor holding the estimate.
/// - `error`: the absolute difference between the last two level estimates, or
///   `f64::INFINITY` if fewer than two levels were evaluated.
/// - `neval`: the total number of values received from `f`.
/// - `converged`: whether the tolerance was met.
///
/// # Errors
/// Propagates any error returned by `f`.
pub fn tanh_sinh_impl<R, C, F>(
    client: &C,
    f: F,
    a: f64,
    b: f64,
    options: &TanhSinhOptions,
) -> Result<QuadResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    F: Fn(&Tensor<R>) -> Result<Tensor<R>>,
{
    // Beyond |t| = 4 the abscissae round to ±1 in f64, so nothing useful lies there.
    const MAX_T: f64 = 4.0;

    let device = client.device();
    let mid = (a + b) / 2.0;
    let half_width = (b - a) / 2.0;

    let mut integral = 0.0;
    // |I_level - I_{level-1}| of the most recent refinement. It stays infinite until a
    // refinement has actually been computed, so an unconverged result never reports a
    // zero error (the previous code always reported 0.0 on the non-converged path).
    let mut error = f64::INFINITY;
    let mut neval = 0;
    let mut h = 1.0;

    for level in 0..options.max_levels {
        if level > 0 {
            h /= 2.0;
        }
        // Only the weights and reference abscissae are needed here.
        let (_, weights, x_points) = generate_tanh_sinh_points(h, MAX_T, level);
        if weights.is_empty() {
            continue;
        }

        // Map the reference abscissae from [-1, 1] onto [a, b] and evaluate `f` in one batch.
        let x_mapped: Vec<f64> = x_points.iter().map(|&x| mid + half_width * x).collect();
        let x_tensor = Tensor::<R>::from_slice(&x_mapped, &[x_mapped.len()], device);
        let f_values = f(&x_tensor)?;
        let f_vec: Vec<f64> = f_values.to_vec();
        neval += f_vec.len();

        // Non-finite samples returned by `f` are skipped rather than poisoning the sum.
        let level_sum: f64 = weights
            .iter()
            .zip(&f_vec)
            .filter(|&(_, fx)| fx.is_finite())
            .map(|(w, fx)| w * fx)
            .sum();
        let contribution = level_sum * h * half_width;

        if level == 0 {
            integral = contribution;
            continue;
        }

        // Halving h keeps every previous node, so the old estimate is reused at half weight.
        let refined = integral / 2.0 + contribution;
        error = (refined - integral).abs();
        integral = refined;

        if error < options.atol + options.rtol * integral.abs() {
            return Ok(QuadResult {
                integral: Tensor::<R>::from_slice(&[integral], &[1], device),
                error,
                neval,
                converged: true,
            });
        }
    }

    Ok(QuadResult {
        integral: Tensor::<R>::from_slice(&[integral], &[1], device),
        error,
        neval,
        converged: false,
    })
}
/// Generates the tanh-sinh nodes that are *new* at refinement `level`.
///
/// Level 0 uses the full grid `t = k·h` for `k = 0, ±1, ±2, …`. Every later level adds
/// only the nodes that bisect the previous grid, i.e. `t = k·h` for odd `k`. The caller
/// relies on this to refine its running estimate as `I_new = I_old / 2 + h·Σ w·f`.
/// (Earlier code shifted these nodes by `h/2`, which placed them off the bisection points
/// and biased every refined estimate.)
///
/// Nodes are generated for `|t| <= max_t` and emitted in `(t, -t)` order. The centre node
/// `t = 0` appears only at level 0. A node is dropped when its weight is not above `1e-50`
/// or when its abscissa lies within `1e-15` of ±1. Generation stops early once both weights
/// of a `±t` pair fall below `1e-50`. A non-positive or NaN `h` yields no nodes.
///
/// Returns `(t_points, weights, x_points)`. These are three vectors of equal length, and
/// `x_points` holds abscissae on the reference interval `[-1, 1]`.
fn generate_tanh_sinh_points(h: f64, max_t: f64, level: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    // Nodes whose weight is at or below this threshold are treated as negligible.
    const MIN_WEIGHT: f64 = 1e-50;
    // Abscissae this close to ±1 are dropped so the integrand is never sampled at
    // (or numerically on top of) an interval endpoint.
    const ENDPOINT_MARGIN: f64 = 1e-15;

    let mut t_points = Vec::new();
    let mut weights = Vec::new();
    let mut x_points = Vec::new();

    // A non-positive (or NaN) step would never advance `t` past `max_t`.
    if h.is_nan() || h <= 0.0 {
        return (t_points, weights, x_points);
    }

    // The centre node belongs to level 0 only: 0 is an even multiple of every later step.
    if level == 0 {
        let (x, w) = tanh_sinh_point(0.0);
        if w > MIN_WEIGHT {
            t_points.push(0.0);
            weights.push(w);
            x_points.push(x);
        }
    }

    // Level 0 visits every multiple of h; refinement levels visit only the odd multiples,
    // which are exactly the midpoints of the previous level's grid.
    let k_step = if level == 0 { 1 } else { 2 };
    let mut k: usize = 1;
    loop {
        let t = k as f64 * h;
        if t > max_t {
            break;
        }
        let (x_pos, w_pos) = tanh_sinh_point(t);
        let (x_neg, w_neg) = tanh_sinh_point(-t);
        for (node_t, x, w) in [(t, x_pos, w_pos), (-t, x_neg, w_neg)] {
            if w > MIN_WEIGHT && x.abs() < 1.0 - ENDPOINT_MARGIN {
                t_points.push(node_t);
                weights.push(w);
                x_points.push(x);
            }
        }
        if w_pos < MIN_WEIGHT && w_neg < MIN_WEIGHT {
            break;
        }
        k += k_step;
    }

    (t_points, weights, x_points)
}
/// Evaluates the tanh-sinh substitution at the parameter `t`.
///
/// Returns `(x, w)`, where:
/// - `x = tanh(π/2 · sinh t)` is the abscissa on `[-1, 1]`.
/// - `w = (π/2) · cosh t / cosh²(π/2 · sinh t)` is its derivative `dx/dt`, i.e. the
///   quadrature weight before scaling by the step size and the interval half-width.
///
/// Once `|π/2 · sinh t|` exceeds 20, the leading-order forms
/// `tanh u ≈ ±(1 − 2e^{−2|u|})` and `cosh u ≈ e^{|u|}/2` replace the library calls.
fn tanh_sinh_point(t: f64) -> (f64, f64) {
    // Above this |u| the asymptotic expansions are used.
    const ASYMPTOTIC_U: f64 = 20.0;

    let half_pi = PI / 2.0;
    let u = half_pi * t.sinh();

    let (abscissa, cosh_u) = if u.abs() > ASYMPTOTIC_U {
        let cosh_u = u.abs().exp() / 2.0;
        if u > 0.0 {
            (1.0 - 2.0 * (-2.0 * u).exp(), cosh_u)
        } else {
            (-1.0 + 2.0 * (2.0 * u).exp(), cosh_u)
        }
    } else {
        (u.tanh(), u.cosh())
    };

    let weight = half_pi * t.cosh() / (cosh_u * cosh_u);
    (abscissa, weight)
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    /// Builds a fresh CPU device together with a client bound to a clone of it.
    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuClient::new(device.clone());
        (device, client)
    }

    /// Runs `tanh_sinh_impl` with the scalar integrand `g` applied element-wise to the
    /// abscissa tensor, and returns the single value stored in the resulting integral.
    fn integrate_scalar<G>(g: G, a: f64, b: f64, options: &TanhSinhOptions) -> f64
    where
        G: Fn(f64) -> f64,
    {
        let (device, client) = setup();
        let result = tanh_sinh_impl(
            &client,
            |x| {
                let data: Vec<f64> = x.to_vec();
                let values: Vec<f64> = data.iter().map(|&xi| g(xi)).collect();
                Ok(Tensor::<CpuRuntime>::from_slice(&values, x.shape(), &device))
            },
            a,
            b,
            options,
        )
        .unwrap();
        let integral: Vec<f64> = result.integral.to_vec();
        integral[0]
    }

    #[test]
    fn test_tanh_sinh_smooth() {
        // Integral of sin(x) over [0, π] is 2.
        let value = integrate_scalar(|x| x.sin(), 0.0, PI, &TanhSinhOptions::default());
        assert!(
            (value - 2.0).abs() < 0.05,
            "integral = {}, expected 2.0",
            value
        );
    }

    #[test]
    fn test_tanh_sinh_singularity() {
        // Integral of x^(-1/2) over [0, 1] is 2 (singular at the left endpoint).
        let value = integrate_scalar(|x| 1.0 / x.sqrt(), 0.0, 1.0, &TanhSinhOptions::default());
        assert!(
            (value - 2.0).abs() < 0.01,
            "integral = {}, expected 2.0",
            value
        );
    }

    #[test]
    fn test_tanh_sinh_log_singularity() {
        // Integral of -ln(x) over [0, 1] is 1 (logarithmic singularity at 0).
        let value = integrate_scalar(|x| -x.ln(), 0.0, 1.0, &TanhSinhOptions::default());
        assert!(
            (value - 1.0).abs() < 0.01,
            "integral = {}, expected 1.0",
            value
        );
    }

    #[test]
    fn test_tanh_sinh_both_endpoints() {
        // Integral of 1/sqrt(x(1-x)) over [0, 1] is π (singular at both endpoints).
        let value = integrate_scalar(
            |x| 1.0 / (x * (1.0 - x)).sqrt(),
            0.0,
            1.0,
            &TanhSinhOptions::with_tolerances(1e-6, 1e-6),
        );
        assert!(
            (value - PI).abs() < 0.1,
            "integral = {}, expected π ≈ {}",
            value,
            PI
        );
    }
}