use std::f64::consts::PI;
use std::fmt;
use crate::error::{IntegrateError, IntegrateResult};
/// Outcome of a tanh-sinh integration; `nsum` reuses it for summation results.
#[derive(Clone, Debug)]
pub struct TanhSinhResult<T> {
    /// Final estimate (its natural log when log-space integration was requested).
    pub integral: T,
    /// Error estimate attached to `integral`.
    pub error: T,
    /// Number of integrand evaluations reported by the routine.
    pub nfev: usize,
    /// Refinement level at which the routine stopped.
    pub max_level: usize,
    /// `true` when the routine reports that it converged (or the answer is exact).
    pub success: bool,
}
impl<T: fmt::Display> fmt::Display for TanhSinhResult<T> {
    /// Renders every field on one line as `TanhSinhResult(name=value, ...)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            integral,
            error,
            nfev,
            max_level,
            success,
        } = self;
        write!(
            f,
            "TanhSinhResult(integral={}, error={}, nfev={}, max_level={}, success={})",
            integral, error, nfev, max_level, success
        )
    }
}
/// Tuning knobs for `tanhsinh` (and for the integral part of `nsum`).
#[derive(Clone, Debug)]
pub struct TanhSinhOptions {
    /// Absolute tolerance: converged once the error estimate is `<= atol`.
    pub atol: f64,
    /// Relative tolerance: converged once the error estimate is `<= rtol * |estimate|`.
    pub rtol: f64,
    /// Highest refinement level tried; level `k` uses the step `2^-k`.
    pub max_level: usize,
    /// Convergence is not tested before this level.
    pub min_level: usize,
    /// When `true`, the integrand returns `ln f(x)` and the result is reported as a log.
    pub log: bool,
}
impl Default for TanhSinhOptions {
    /// `rtol = 1e-8`, no absolute tolerance, levels 2 through 10, linear space.
    fn default() -> Self {
        TanhSinhOptions {
            log: false,
            min_level: 2,
            max_level: 10,
            rtol: 1.0e-8,
            atol: 0.0,
        }
    }
}
/// Nodes and weights of a tanh-sinh (double-exponential) quadrature rule on the
/// reference interval `[-1, 1]`.
#[derive(Clone, Debug)]
struct TanhSinhRule {
    /// Abscissae, all strictly inside `(-1, 1)`.
    points: Vec<f64>,
    /// Weight belonging to the abscissa at the same index.
    weights: Vec<f64>,
}
impl TanhSinhRule {
    /// Builds the rule for refinement `level`, i.e. step `h = 2^-level` in `t`.
    ///
    /// For every `t = j * h` with `|t| <= 3.5`, the node is `x = tanh(π/2 · sinh t)` and
    /// the weight is `h · π/2 · cosh t / cosh²(π/2 · sinh t)`. Nodes whose weight is
    /// `<= 1e-15`, or which lie within `1e-10` of `±1`, are dropped.
    ///
    /// NOTE(review): `1 << level` is evaluated as an `i32`, so levels of 31 or more
    /// overflow the shift.
    fn new(level: usize) -> Self {
        let step = ((1 << level) as f64).recip();
        let half_pi = std::f64::consts::FRAC_PI_2;
        let j_max = (3.5 / step) as i32;
        let (points, weights): (Vec<f64>, Vec<f64>) = (-j_max..=j_max)
            .filter_map(|j| {
                let t = f64::from(j) * step;
                let arg = half_pi * t.sinh();
                // Defensive guards against overflow of cosh/tanh for extreme arguments.
                if arg.abs() > 100.0 {
                    return None;
                }
                let cosh_arg = arg.cosh();
                if cosh_arg > 1e100 {
                    return None;
                }
                let x = arg.tanh();
                let w = step * half_pi * t.cosh() / (cosh_arg * cosh_arg);
                if w > 1e-15 && x.abs() < 1.0 - 1e-10 {
                    Some((x, w))
                } else {
                    None
                }
            })
            .unzip();
        TanhSinhRule { points, weights }
    }
    /// Maps the rule onto `[a, b]` via `x ↦ (a + b)/2 + (b - a)/2 · x`, scaling each
    /// weight by `(b - a)/2`. Returns `(points, weights)`.
    fn get_transformed(&self, a: f64, b: f64) -> (Vec<f64>, Vec<f64>) {
        let center = (a + b) / 2.0;
        let half_width = (b - a) / 2.0;
        self.points
            .iter()
            .zip(&self.weights)
            .map(|(&x, &w)| (center + half_width * x, half_width * w))
            .unzip()
    }
}
/// Grows-on-demand table of tanh-sinh rules, one per refinement level.
struct RuleCache {
    /// Entry `k` is the rule for level `k`; levels are filled in without gaps.
    rules: Vec<TanhSinhRule>,
}
impl RuleCache {
    /// Starts with no precomputed rules.
    fn new() -> Self {
        RuleCache { rules: Vec::new() }
    }
    /// Returns the rule for `level`, first building it and any missing lower level.
    fn get_rule(&mut self, level: usize) -> &TanhSinhRule {
        let built = self.rules.len();
        if built <= level {
            self.rules.extend((built..=level).map(TanhSinhRule::new));
        }
        &self.rules[level]
    }
}
/// Integrates `f` over `[a, b]` with tanh-sinh (double-exponential) quadrature.
///
/// Level `k` uses step `2^-k`. Starting at `max(min_level, 1)`, successive estimates are
/// compared. The routine stops when their difference is `<= atol` or
/// `<= rtol * |estimate|`, or after `options.max_level`.
///
/// Semi-infinite ranges are mapped onto a finite interval by a change of variables.
/// The doubly-infinite range is delegated to `infinite_range_integral`.
///
/// When `options.log` is `true`, `f` must return `ln f(x)` and `integral` is the log of
/// the integral.
///
/// Reversed limits (`a > b`) are handled by integrating over `[b, a]` and negating the
/// result. `nfev` counts the evaluations of every level, not just the last one.
///
/// # Errors
///
/// Returns `IntegrateError::ValueError` in these cases:
/// - either limit is NaN;
/// - both limits are infinite in the same direction;
/// - `a > b` while `options.log` is set (the result would be negative and has no real log).
///
/// Errors from the doubly-infinite helper are propagated.
#[allow(dead_code)]
pub fn tanhsinh<F>(
    f: F,
    a: f64,
    b: f64,
    options: Option<TanhSinhOptions>,
) -> IntegrateResult<TanhSinhResult<f64>>
where
    F: Fn(f64) -> f64,
{
    let options = options.unwrap_or_default();
    // NaN is also "not finite". Reject it first so it is not misreported by the
    // infinite-limit check or silently treated as an empty range.
    if a.is_nan() || b.is_nan() {
        return Err(IntegrateError::ValueError(
            "Integration limits must not be NaN".to_string(),
        ));
    }
    if a.is_infinite() && a == b {
        return Err(IntegrateError::ValueError(
            "Both integration limits cannot be infinite in the same direction".to_string(),
        ));
    }
    if a == b {
        return Ok(TanhSinhResult {
            integral: if options.log { f64::NEG_INFINITY } else { 0.0 },
            error: 0.0,
            nfev: 0,
            max_level: 0,
            success: true,
        });
    }
    if a > b {
        // ∫_a^b f = -∫_b^a f. Normalizing the orientation here means the substitutions
        // below only ever see (finite, finite), (finite, +inf), (-inf, finite) or
        // (-inf, +inf). Previously, (+inf, finite) and (finite, -inf) were integrated
        // over the whole real line.
        if options.log {
            return Err(IntegrateError::ValueError(
                "Reversed integration limits are not supported in log space".to_string(),
            ));
        }
        let forward = tanhsinh(f, b, a, Some(options))?;
        return Ok(TanhSinhResult {
            integral: -forward.integral,
            ..forward
        });
    }
    if a.is_infinite() && b.is_infinite() {
        return infinite_range_integral(f, options);
    }
    let transform = if a.is_infinite() || b.is_infinite() {
        Some(determine_transform(a, b))
    } else {
        None
    };
    let mut cache = RuleCache::new();
    let mut state = IntegrationState {
        a,
        b,
        estimate: 0.0,
        prev_estimate: 0.0,
        error: f64::INFINITY,
        nfev: 0,
        level: 0,
    };
    for level in 0..=options.max_level {
        let rule = cache.get_rule(level);
        // Overwrites `state.estimate`. `state.nfev` is intentionally not reset, so it
        // accumulates the evaluations of every level.
        evaluate_with_rule(&mut state, rule, &f, transform.as_ref(), options.log);
        let sum = state.estimate;
        if level >= options.min_level && level > 0 {
            state.error = (sum - state.prev_estimate).abs();
            if state.error <= options.atol
                || (sum != 0.0 && state.error <= options.rtol * sum.abs())
            {
                return Ok(TanhSinhResult {
                    integral: sum,
                    error: state.error,
                    nfev: state.nfev,
                    max_level: level,
                    success: true,
                });
            }
        }
        state.prev_estimate = sum;
        state.level = level + 1;
    }
    Ok(TanhSinhResult {
        integral: state.estimate,
        error: state.error,
        nfev: state.nfev,
        max_level: options.max_level,
        success: false,
    })
}
/// Mutable bookkeeping shared by the integration loops and their helpers.
struct IntegrationState {
    /// Lower integration limit.
    a: f64,
    /// Upper integration limit.
    b: f64,
    /// Estimate produced by the most recently evaluated level.
    estimate: f64,
    /// Estimate of the level evaluated before that; compared against `estimate`.
    prev_estimate: f64,
    /// Current error estimate.
    error: f64,
    /// Refinement-level counter; callers advance it after each level.
    level: usize,
    /// Counter of integrand evaluations.
    nfev: usize,
}
/// Change of variables used to map an infinite integration range onto a finite one.
enum TransformType {
    /// `[a, +inf)`; carries the finite lower limit `a`.
    SemiInfiniteRight(f64),
    /// `(-inf, b]`; carries the finite upper limit `b`.
    SemiInfiniteLeft(f64),
    /// `(-inf, +inf)`, and the fallback for every other combination.
    DoubleInfinite,
}
/// Picks the substitution for a range with at least one infinite limit.
///
/// Only the forward-oriented pairs `(finite, +inf)` and `(-inf, finite)` are recognized
/// as semi-infinite. Every other combination, including reversed ones such as
/// `(+inf, finite)` or `(finite, -inf)`, yields `DoubleInfinite`.
#[allow(dead_code)]
fn determine_transform(a: f64, b: f64) -> TransformType {
    match (a.is_finite(), b.is_finite()) {
        (true, false) if b.is_infinite() && b.is_sign_positive() => {
            TransformType::SemiInfiniteRight(a)
        }
        (false, true) if a.is_infinite() && a.is_sign_negative() => {
            TransformType::SemiInfiniteLeft(b)
        }
        _ => TransformType::DoubleInfinite,
    }
}
#[allow(dead_code)]
fn evaluate_with_rule<F>(
state: &mut IntegrationState,
rule: &TanhSinhRule,
f: &F,
transform: Option<&TransformType>,
log_space: bool,
) where
F: Fn(f64) -> f64,
{
match transform {
None => {
let (points, weights) = rule.get_transformed(state.a, state.b);
compute_sum(state, &points, &weights, f, None, log_space);
}
Some(TransformType::SemiInfiniteRight(a)) => {
let (mut points, mut weights) = rule.get_transformed(0.0, 1.0);
for i in 0..points.len() {
let t = points[i];
if t < 1.0 - f64::EPSILON {
let jacobian = 1.0 / (1.0_f64 - t).powi(2);
weights[i] *= jacobian;
points[i] = *a + t / (1.0 - t);
} else {
weights[i] = 0.0;
points[i] = f64::INFINITY;
}
}
compute_sum(state, &points, &weights, f, None, log_space);
}
Some(TransformType::SemiInfiniteLeft(b)) => {
let (mut points, mut weights) = rule.get_transformed(0.0, 1.0);
for i in 0..points.len() {
let t = points[i];
if t < 1.0 - f64::EPSILON {
let jacobian = 1.0 / (1.0_f64 - t).powi(2);
weights[i] *= jacobian;
points[i] = *b - t / (1.0 - t);
} else {
weights[i] = 0.0;
points[i] = f64::NEG_INFINITY;
}
}
compute_sum(state, &points, &weights, f, None, log_space);
}
Some(TransformType::DoubleInfinite) => {
let (mut points, mut weights) = rule.get_transformed(-1.0, 1.0);
for i in 0..points.len() {
let t = points[i];
let t_squared = t * t;
if t_squared < 1.0 - f64::EPSILON {
let denominator = 1.0 - t_squared;
let jacobian = (1.0 + t_squared) / (denominator * denominator);
weights[i] *= jacobian;
points[i] = t / denominator;
} else {
weights[i] = 0.0;
if t > 0.0 {
points[i] = f64::INFINITY;
} else {
points[i] = f64::NEG_INFINITY;
}
}
}
compute_sum(state, &points, &weights, f, None, log_space);
}
}
}
/// Evaluates `f` at the usable nodes and stores the weighted sum in `state.estimate`.
///
/// A node is skipped, and `f` is not called for it, when its weight is zero or
/// non-finite or its abscissa is non-finite. `state.nfev` is incremented by the number
/// of nodes actually evaluated.
///
/// In linear space the estimate is `Σ w·f(x)`, or `Σ tf(f(x), w)` when `transform_f` is
/// given. In log space `f` returns log-values and the estimate is the log-sum-exp of
/// `f(x) + ln w` (or of `tf(f(x), w)`).
///
/// In log space, when the largest term is infinite the estimate is that term: `-inf`
/// for an all-zero integrand or no usable nodes, `+inf` for a divergent term. A NaN
/// term still propagates. Previously this case produced `inf - inf = NaN`.
#[allow(dead_code)]
fn compute_sum<F>(
    state: &mut IntegrationState,
    points: &[f64],
    weights: &[f64],
    f: &F,
    transform_f: Option<&dyn Fn(f64, f64) -> f64>,
    log_space: bool,
) where
    F: Fn(f64) -> f64,
{
    let nodes = points
        .iter()
        .zip(weights)
        .filter(|&(&x, &w)| w.is_finite() && w != 0.0 && x.is_finite())
        .map(|(&x, &w)| (x, w));
    let mut evaluations = 0usize;
    if log_space {
        let mut values: Vec<f64> = Vec::with_capacity(points.len());
        for (x, w) in nodes {
            evaluations += 1;
            let val = f(x);
            values.push(match transform_f {
                Some(tf) => tf(val, w),
                None => val + w.ln(),
            });
        }
        // `f64::max` ignores NaN, matching a `val > max` running maximum.
        let max_val = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        state.estimate = if max_val.is_infinite() {
            // Shifting by an infinite maximum would give inf - inf = NaN.
            if values.iter().any(|v| v.is_nan()) {
                f64::NAN
            } else {
                max_val
            }
        } else {
            let mut scaled = 0.0;
            for &v in &values {
                scaled += (v - max_val).exp();
            }
            max_val + scaled.ln()
        };
    } else {
        let mut sum = 0.0;
        for (x, w) in nodes {
            evaluations += 1;
            let val = f(x);
            sum += match transform_f {
                Some(tf) => tf(val, w),
                None => val * w,
            };
        }
        state.estimate = sum;
    }
    state.nfev += evaluations;
}
/// Sets `state.error` from the change between the current and previous estimates.
///
/// The absolute difference is damped by a factor of 4 once `state.level > 2`. A
/// non-finite `prev_estimate` yields an infinite error. An exactly-zero error is
/// replaced by `EPSILON * max(|estimate|, 1)` so a converged result never reports
/// zero error.
#[allow(dead_code)]
fn estimate_error(state: &mut IntegrationState) {
    let raw = if state.prev_estimate.is_finite() {
        let diff = (state.estimate - state.prev_estimate).abs();
        if state.level > 2 {
            diff * 0.25
        } else {
            diff
        }
    } else {
        f64::INFINITY
    };
    state.error = if raw == 0.0 {
        f64::EPSILON * state.estimate.abs().max(1.0)
    } else {
        raw
    };
}
/// Integrates `f` over the whole real line, `(-inf, +inf)`.
///
/// Uses the substitution `x = t / (1 - t^2)`, `dx = (1 + t^2) / (1 - t^2)^2 dt`, which maps
/// `t ∈ (-1, 1)` onto the real line, and applies the tanh-sinh rule in `t`. Levels start
/// at `max(min_level, 1)`. From the second evaluated level on, the error is the
/// difference between successive estimates, scaled by 0.25 once the level exceeds 2.
///
/// When `options.log` is `true`, `f` returns `ln f(x)` and the estimate is the
/// log-sum-exp of the weighted log terms. In that mode NaN and `+inf` values are
/// skipped, while `-inf` (log of zero) is kept. In linear mode every non-finite value is
/// skipped. `nfev` counts only the evaluations that contribute.
///
/// A former shortcut returned `sqrt(pi)` whenever `f` matched `exp(-x^2)` at five
/// integer points. It was removed: it was wrong for integrands that agree only at
/// those points, and it ignored `options.log`.
#[allow(dead_code)]
fn infinite_range_integral<F>(
    f: F,
    options: TanhSinhOptions,
) -> IntegrateResult<TanhSinhResult<f64>>
where
    F: Fn(f64) -> f64,
{
    let mut cache = RuleCache::new();
    let mut state = IntegrationState {
        a: -1.0,
        b: 1.0,
        estimate: 0.0,
        prev_estimate: f64::NAN,
        error: f64::INFINITY,
        nfev: 0,
        level: options.min_level.max(1),
    };
    let mut log_terms: Vec<f64> = Vec::new();
    for level in state.level..=options.max_level {
        let rule = cache.get_rule(level);
        let (points, weights) = rule.get_transformed(-1.0, 1.0);
        let mut linear_sum = 0.0;
        log_terms.clear();
        let mut level_evals = 0;
        for (&t, &w) in points.iter().zip(&weights) {
            let t_squared = t * t;
            if t_squared >= 1.0 - f64::EPSILON {
                continue;
            }
            let denominator = 1.0 - t_squared;
            let transformed_x = t / denominator;
            let jacobian = (1.0 + t_squared) / (denominator * denominator);
            let val = f(transformed_x);
            if options.log {
                if val.is_nan() || val == f64::INFINITY {
                    continue;
                }
                log_terms.push(val + (jacobian * w).ln());
            } else {
                if !val.is_finite() {
                    continue;
                }
                linear_sum += val * jacobian * w;
            }
            level_evals += 1;
        }
        state.nfev += level_evals;
        state.estimate = if options.log {
            let max_term = log_terms.iter().copied().fold(f64::NEG_INFINITY, f64::max);
            if max_term.is_infinite() {
                // No usable term, or all terms are -inf: the integral is 0 (log = -inf).
                max_term
            } else {
                let scaled: f64 = log_terms.iter().map(|&v| (v - max_term).exp()).sum();
                max_term + scaled.ln()
            }
        } else {
            linear_sum
        };
        if level >= options.min_level {
            if state.prev_estimate.is_finite() {
                state.error = (state.estimate - state.prev_estimate).abs();
                if level > 2 {
                    state.error *= 0.25;
                }
            }
            if state.error == 0.0 {
                state.error = f64::EPSILON * state.estimate.abs().max(1.0);
            }
            if state.error <= options.atol
                || (state.estimate != 0.0 && state.error <= options.rtol * state.estimate.abs())
            {
                return Ok(TanhSinhResult {
                    integral: state.estimate,
                    error: state.error,
                    nfev: state.nfev,
                    max_level: level,
                    success: true,
                });
            }
        }
        state.level = level + 1;
        state.prev_estimate = state.estimate;
    }
    Ok(TanhSinhResult {
        integral: state.estimate,
        error: state.error,
        nfev: state.nfev,
        max_level: options.max_level,
        success: false,
    })
}
/// Approximates `Σ f(a + k·step)` over every `k >= 0` with `a + k·step <= b`.
///
/// When both limits are finite and the range spans at most `max_terms` steps
/// (default 1000), every term is added directly. Otherwise, for finite `a`, the first
/// `max_terms + 1` terms are added directly.
///
/// The remaining tail is approximated by the midpoint-rule integral
/// `(1/step) ∫_{s - step/2}^{b + step/2} f(x) dx`, computed with `tanhsinh`. Here `s` is
/// the first term not summed directly, and the upper limit stays `+inf` for infinite
/// `b`. For `a = -inf`, the whole sum is approximated by that integral.
///
/// Abscissae are computed as `a + k·step` rather than by repeated addition. This avoids
/// drift that could drop the endpoint `b`, and it cannot loop forever once an abscissa
/// overflows to infinity.
///
/// An empty range (`a > b`) sums to 0. `options` is forwarded to `tanhsinh`, but the
/// tail integral is combined linearly with the direct sum, so `options.log` is not
/// meaningful here.
///
/// # Errors
///
/// Returns `IntegrateError::ValueError` in these cases:
/// - `step` is not a positive finite number;
/// - either limit is NaN;
/// - both limits are infinite.
///
/// Errors from the tail integration are also propagated.
#[allow(dead_code)]
pub fn nsum<F>(
    f: F,
    a: f64,
    b: f64,
    step: f64,
    max_terms: Option<usize>,
    options: Option<TanhSinhOptions>,
) -> IntegrateResult<TanhSinhResult<f64>>
where
    F: Fn(f64) -> f64,
{
    /// Result of a sum computed purely by direct addition.
    fn exact(sum: f64, nfev: usize) -> TanhSinhResult<f64> {
        TanhSinhResult {
            integral: sum,
            error: f64::EPSILON * sum.abs(),
            nfev,
            max_level: 0,
            success: true,
        }
    }

    // `!(step > 0.0)` also rejects NaN, which `step <= 0.0` would let through.
    if !(step > 0.0) {
        return Err(IntegrateError::ValueError(
            "Step size must be positive".to_string(),
        ));
    }
    if step.is_infinite() {
        return Err(IntegrateError::ValueError(
            "Step size must be finite".to_string(),
        ));
    }
    if a.is_nan() || b.is_nan() {
        return Err(IntegrateError::ValueError(
            "Summation limits must not be NaN".to_string(),
        ));
    }
    let options = options.unwrap_or_default();
    let max_terms = max_terms.unwrap_or(1000);
    if a > b {
        return Ok(exact(0.0, 0));
    }
    if a.is_infinite() && b.is_infinite() {
        return Err(IntegrateError::ValueError(
            "Both summation limits cannot be infinite".to_string(),
        ));
    }
    let span = (b - a) / step;
    if a.is_finite() && b.is_finite() && span <= max_terms as f64 {
        // A few ulps of slack so `b` itself is included when (b - a) / step is an
        // integer up to rounding error.
        let last = (span + span * 4.0 * f64::EPSILON).floor() as usize;
        let mut sum = 0.0;
        for k in 0..=last {
            sum += f(a + k as f64 * step);
        }
        return Ok(exact(sum, last + 1));
    }
    let mut direct_sum = 0.0;
    let mut n_terms = 0;
    let remainder_start = if a.is_finite() {
        for k in 0..=max_terms {
            let x = a + k as f64 * step;
            // Stop at `b`, or if the abscissa has overflowed to infinity.
            if x > b || !x.is_finite() {
                break;
            }
            direct_sum += f(x);
            n_terms += 1;
        }
        a + n_terms as f64 * step
    } else {
        a
    };
    // Nothing left to integrate, or the tail starts beyond the representable range.
    if remainder_start > b || remainder_start == f64::INFINITY {
        return Ok(exact(direct_sum, n_terms));
    }
    let integrate_start = remainder_start - step / 2.0;
    let integrate_end = if b.is_finite() { b + step / 2.0 } else { b };
    let tail = tanhsinh(f, integrate_start, integrate_end, Some(options))?;
    Ok(TanhSinhResult {
        integral: direct_sum + tail.integral / step,
        error: tail.error / step,
        nfev: n_terms + tail.nfev,
        max_level: tail.max_level,
        success: tail.success,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Runs `tanhsinh`, panicking with a uniform message if it reports an error.
    fn integrate<F: Fn(f64) -> f64>(
        f: F,
        a: f64,
        b: f64,
        options: Option<TanhSinhOptions>,
    ) -> TanhSinhResult<f64> {
        tanhsinh(f, a, b, options).expect("Operation failed")
    }

    /// Asserts that `result.integral` is within `epsilon` of `expected` and that the
    /// routine reported success.
    fn assert_converged(result: &TanhSinhResult<f64>, expected: f64, epsilon: f64) {
        assert_abs_diff_eq!(result.integral, expected, epsilon = epsilon);
        assert!(result.success);
    }

    #[test]
    fn test_basic_integral() {
        let result = integrate(|x| x * x, 0.0, 1.0, None);
        assert_converged(&result, 1.0 / 3.0, 1e-10);
    }

    #[test]
    fn test_trig_integral() {
        let result = integrate(|x| x.sin(), 0.0, PI, None);
        assert_converged(&result, 2.0, 1e-10);
    }

    #[test]
    fn test_endpoint_singularity() {
        let loose = TanhSinhOptions {
            atol: 1e-5,
            rtol: 1e-5,
            ..TanhSinhOptions::default()
        };
        let result = integrate(|x| 1.0 / x.sqrt(), 0.0, 1.0, Some(loose));
        assert_converged(&result, 2.0, 2e-5);
    }

    #[test]
    fn test_semi_infinite_integral() {
        let result = integrate(|x| (-x).exp(), 0.0, f64::INFINITY, None);
        assert_converged(&result, 1.0, 1e-8);
    }

    #[test]
    fn test_infinite_integral() {
        let result = infinite_range_integral(|x| (-x * x).exp(), TanhSinhOptions::default())
            .expect("Operation failed");
        assert_converged(&result, PI.sqrt(), 1e-8);
    }

    #[test]
    fn test_log_space() {
        let log_opts = TanhSinhOptions {
            log: true,
            ..TanhSinhOptions::default()
        };
        let result = integrate(|x| -1000.0 * x * x, -1.0, 1.0, Some(log_opts));
        let expected = (PI / 1000.0).sqrt();
        assert_abs_diff_eq!(result.integral.exp(), expected, epsilon = 1e-8);
        assert!(result.success);
    }

    #[test]
    fn test_nsum_finite() {
        let result = nsum(|n| n, 1.0, 10.0, 1.0, None, None).expect("Operation failed");
        assert_converged(&result, 55.0, 1e-10);
    }

    #[test]
    fn test_nsum_infinite() {
        let result =
            nsum(|n| 1.0 / (n * n), 1.0, f64::INFINITY, 1.0, None, None).expect("Operation failed");
        assert_converged(&result, PI * PI / 6.0, 1e-6);
    }
}