solvr 0.2.0-beta.2

Advanced computing library for real-world problem solving - optimization, differential equations, interpolation, statistics, and more
Documentation
//! Tanh-sinh (double exponential) quadrature.
//!
//! Highly effective for integrals with endpoint singularities.
//! Uses the transformation x = tanh(π/2 * sinh(t)).
use crate::DType;

use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;

use crate::integrate::traits::{QuadResult, TanhSinhOptions};

use std::f64::consts::PI;

/// Tanh-sinh quadrature implementation.
///
/// The double-exponential transformation:
/// x = tanh(π/2 * sinh(t))
/// dx = (π/2) * cosh(t) / cosh²(π/2 * sinh(t)) dt
///
/// This causes the integrand to decay double-exponentially at the endpoints,
/// making it ideal for functions with endpoint singularities.
pub fn tanh_sinh_impl<R, C, F>(
    client: &C,
    f: F,
    a: f64,
    b: f64,
    options: &TanhSinhOptions,
) -> Result<QuadResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    F: Fn(&Tensor<R>) -> Result<Tensor<R>>,
{
    let device = client.device();

    // Transform from [a, b] to [-1, 1]
    let mid = (a + b) / 2.0;
    let half_width = (b - a) / 2.0;

    let mut integral = 0.0;
    let mut prev_integral = 0.0;
    let mut neval = 0;

    // Start with h = 1, then halve
    let mut h = 1.0;
    let max_t = 4.0; // Integration range in t-space

    for level in 0..options.max_levels {
        if level > 0 {
            h /= 2.0;
        }

        // Generate quadrature points for this level
        let (t_points, weights, x_points) = generate_tanh_sinh_points(h, max_t, level);

        if t_points.is_empty() {
            continue;
        }

        // Transform x from [-1, 1] to [a, b]
        let x_transformed: Vec<f64> = x_points.iter().map(|&x| mid + half_width * x).collect();

        // Batch evaluate function
        let x_tensor = Tensor::<R>::from_slice(&x_transformed, &[x_transformed.len()], device);
        let f_values = f(&x_tensor)?;
        let f_vec: Vec<f64> = f_values.to_vec();
        neval += f_vec.len();

        // Compute weighted sum
        let mut level_sum = 0.0;
        for (i, &w) in weights.iter().enumerate() {
            if i < f_vec.len() && f_vec[i].is_finite() {
                level_sum += w * f_vec[i];
            }
        }

        // Scale by half_width and h
        if level == 0 {
            integral = level_sum * h * half_width;
        } else {
            // Add new points to existing integral
            integral = prev_integral / 2.0 + level_sum * h * half_width;
        }

        // Check convergence
        if level > 0 {
            let error = (integral - prev_integral).abs();
            let tolerance = options.atol + options.rtol * integral.abs();

            if error < tolerance {
                let result_tensor = Tensor::<R>::from_slice(&[integral], &[1], device);
                return Ok(QuadResult {
                    integral: result_tensor,
                    error,
                    neval,
                    converged: true,
                });
            }
        }

        prev_integral = integral;
    }

    // Did not converge, return best estimate
    let error = (integral - prev_integral).abs();
    let result_tensor = Tensor::<R>::from_slice(&[integral], &[1], device);

    Ok(QuadResult {
        integral: result_tensor,
        error,
        neval,
        converged: false,
    })
}

/// Generate tanh-sinh quadrature points and weights.
///
/// Returns (t_points, weights, x_points) where:
/// - t_points are in the t-space
/// - weights include the Jacobian of the transformation
/// - x_points are in [-1, 1]
fn generate_tanh_sinh_points(h: f64, max_t: f64, level: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let mut t_points = Vec::new();
    let mut weights = Vec::new();
    let mut x_points = Vec::new();

    // For level 0, use all points at spacing h
    // For level > 0, only add the new midpoints
    let step = h;
    let offset = if level == 0 { 0.0 } else { h / 2.0 };

    // Center point (t = 0 + offset)
    if level == 0 {
        let t = 0.0;
        let (x, w) = tanh_sinh_point(t);
        if w > 1e-50 {
            t_points.push(t);
            weights.push(w);
            x_points.push(x);
        }
    }

    // Positive and negative t
    let mut k = 1;
    loop {
        let t = k as f64 * step + offset;
        if t > max_t {
            break;
        }

        // Point at +t
        let (x_pos, w_pos) = tanh_sinh_point(t);
        if w_pos > 1e-50 && x_pos.abs() < 1.0 - 1e-15 {
            t_points.push(t);
            weights.push(w_pos);
            x_points.push(x_pos);
        }

        // Point at -t
        let (x_neg, w_neg) = tanh_sinh_point(-t);
        if w_neg > 1e-50 && x_neg.abs() < 1.0 - 1e-15 {
            t_points.push(-t);
            weights.push(w_neg);
            x_points.push(x_neg);
        }

        // Check if weight is too small to matter
        if w_pos < 1e-50 && w_neg < 1e-50 {
            break;
        }

        k += if level == 0 { 1 } else { 2 };
    }

    (t_points, weights, x_points)
}

/// Compute a single tanh-sinh point and weight.
///
/// x = tanh(π/2 * sinh(t))
/// w = (π/2) * cosh(t) / cosh²(π/2 * sinh(t))
fn tanh_sinh_point(t: f64) -> (f64, f64) {
    let pi_half = PI / 2.0;

    // sinh(t) and cosh(t)
    let sinh_t = t.sinh();
    let cosh_t = t.cosh();

    // u = π/2 * sinh(t)
    let u = pi_half * sinh_t;

    // x = tanh(u)
    // For large |u|, use asymptotic form to avoid overflow
    let x = if u.abs() > 20.0 {
        if u > 0.0 {
            1.0 - 2.0 * (-2.0 * u).exp()
        } else {
            -1.0 + 2.0 * (2.0 * u).exp()
        }
    } else {
        u.tanh()
    };

    // cosh(u) = 1/sqrt(1 - tanh²(u)) = 1/sqrt(1 - x²)
    // But for numerical stability, compute directly
    let cosh_u = if u.abs() > 20.0 {
        (u.abs()).exp() / 2.0
    } else {
        u.cosh()
    };

    // w = (π/2) * cosh(t) / cosh²(u)
    let w = pi_half * cosh_t / (cosh_u * cosh_u);

    (x, w)
}

#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuClient::new(device.clone());
        (device, client)
    }

    #[test]
    fn test_tanh_sinh_smooth() {
        let (device, client) = setup();

        // Integrate sin(x) from 0 to π, exact = 2
        let result = tanh_sinh_impl(
            &client,
            |x| {
                let data: Vec<f64> = x.to_vec();
                let sin_data: Vec<f64> = data.iter().map(|&xi| xi.sin()).collect();
                Ok(Tensor::<CpuRuntime>::from_slice(
                    &sin_data,
                    x.shape(),
                    &device,
                ))
            },
            0.0,
            PI,
            &TanhSinhOptions::default(),
        )
        .unwrap();

        let integral: Vec<f64> = result.integral.to_vec();
        // tanh_sinh is optimized for endpoint singularities, not smooth functions
        // For smooth functions, accuracy is lower
        assert!(
            (integral[0] - 2.0).abs() < 0.05,
            "integral = {}, expected 2.0",
            integral[0]
        );
    }

    #[test]
    fn test_tanh_sinh_singularity() {
        let (device, client) = setup();

        // Integrate 1/sqrt(x) from 0 to 1, exact = 2
        // This has a singularity at x=0
        let result = tanh_sinh_impl(
            &client,
            |x| {
                let data: Vec<f64> = x.to_vec();
                let inv_sqrt: Vec<f64> = data.iter().map(|&xi| 1.0 / xi.sqrt()).collect();
                Ok(Tensor::<CpuRuntime>::from_slice(
                    &inv_sqrt,
                    x.shape(),
                    &device,
                ))
            },
            0.0,
            1.0,
            &TanhSinhOptions::default(),
        )
        .unwrap();

        let integral: Vec<f64> = result.integral.to_vec();
        assert!(
            (integral[0] - 2.0).abs() < 0.01,
            "integral = {}, expected 2.0",
            integral[0]
        );
    }

    #[test]
    fn test_tanh_sinh_log_singularity() {
        let (device, client) = setup();

        // Integrate -log(x) from 0 to 1, exact = 1
        // This has a log singularity at x=0
        let result = tanh_sinh_impl(
            &client,
            |x| {
                let data: Vec<f64> = x.to_vec();
                let neg_log: Vec<f64> = data.iter().map(|&xi| -xi.ln()).collect();
                Ok(Tensor::<CpuRuntime>::from_slice(
                    &neg_log,
                    x.shape(),
                    &device,
                ))
            },
            0.0,
            1.0,
            &TanhSinhOptions::default(),
        )
        .unwrap();

        let integral: Vec<f64> = result.integral.to_vec();
        assert!(
            (integral[0] - 1.0).abs() < 0.01,
            "integral = {}, expected 1.0",
            integral[0]
        );
    }

    #[test]
    fn test_tanh_sinh_both_endpoints() {
        let (device, client) = setup();

        // Integrate 1/sqrt(x*(1-x)) from 0 to 1, exact = π
        // This has singularities at both endpoints
        let result = tanh_sinh_impl(
            &client,
            |x| {
                let data: Vec<f64> = x.to_vec();
                let vals: Vec<f64> = data
                    .iter()
                    .map(|&xi| 1.0 / (xi * (1.0 - xi)).sqrt())
                    .collect();
                Ok(Tensor::<CpuRuntime>::from_slice(&vals, x.shape(), &device))
            },
            0.0,
            1.0,
            &TanhSinhOptions::with_tolerances(1e-6, 1e-6),
        )
        .unwrap();

        let integral: Vec<f64> = result.integral.to_vec();
        assert!(
            (integral[0] - PI).abs() < 0.1,
            "integral = {}, expected π ≈ {}",
            integral[0],
            PI
        );
    }
}