#![allow(dead_code)]
use crate::error::{SpecialError, SpecialResult};
use scirs2_core::ndarray::{Array1, ArrayView1};
use std::collections::HashMap;
use std::sync::LazyLock;
#[derive(Debug, Clone)]
pub struct PerformanceConfig {
pub use_lookup_tables: bool,
pub use_simd: bool,
pub use_parallel: bool,
pub chunksize: usize,
pub lookup_resolution: usize,
}
impl Default for PerformanceConfig {
fn default() -> Self {
Self {
use_lookup_tables: true,
use_simd: cfg!(feature = "simd"),
use_parallel: cfg!(feature = "parallel"),
chunksize: 8192,
lookup_resolution: 10000,
}
}
}
static GAMMA_LOOKUP_TABLE: LazyLock<Vec<f64>> = LazyLock::new(|| {
const TABLE_SIZE: usize = 2000;
const MIN_X: f64 = 0.5;
const MAX_X: f64 = 2.5;
const STEP: f64 = (MAX_X - MIN_X) / (TABLE_SIZE as f64);
let mut table = Vec::with_capacity(TABLE_SIZE + 1);
for i in 0..=TABLE_SIZE {
let x = MIN_X + i as f64 * STEP;
table.push(crate::gamma::gamma(x));
}
table
});
static J0_LOOKUP_TABLE: LazyLock<Vec<f64>> = LazyLock::new(|| {
const TABLE_SIZE: usize = 20000;
const MAX_X: f64 = 20.0;
const STEP: f64 = MAX_X / (TABLE_SIZE as f64);
let mut table = Vec::with_capacity(TABLE_SIZE + 1);
for i in 0..=TABLE_SIZE {
let x = i as f64 * STEP;
table.push(crate::bessel::j0(x));
}
table
});
#[allow(dead_code)]
fn gamma_stirling_precise(x: f64) -> f64 {
if x < 0.5 {
return std::f64::consts::PI
/ ((std::f64::consts::PI * x).sin() * gamma_stirling_precise(1.0 - x));
}
if x < 2.0 {
return gamma_stirling_precise(x + 1.0) / x;
}
let ln_sqrt_2pi = 0.9189385332046727; let z = x - 1.0;
let ln_gamma = ln_sqrt_2pi + (z + 0.5) * z.ln() - z + 1.0 / (12.0 * z)
- 1.0 / (360.0 * z * z * z)
+ 1.0 / (1260.0 * z * z * z * z * z);
ln_gamma.exp()
}
#[allow(dead_code)]
fn j0_series_precise(x: f64) -> f64 {
if x == 0.0 {
return 1.0;
}
if x.abs() > 8.0 {
let sqrt_2_over_pi_x = (2.0 / (std::f64::consts::PI * x)).sqrt();
let phase = x - std::f64::consts::PI / 4.0;
return sqrt_2_over_pi_x * phase.cos();
}
let mut result = 1.0;
let mut term = 1.0;
let x_squared_quarter = x * x / 4.0;
for k in 1..50 {
term *= -x_squared_quarter / (k * k) as f64;
result += term;
if term.abs() < 1e-15 * result.abs() {
break;
}
}
result
}
#[allow(dead_code)]
pub fn gamma_advancedfast(x: f64) -> SpecialResult<f64> {
const MIN_X: f64 = 0.5;
const MAX_X: f64 = 2.5;
const TABLE_SIZE: usize = 2000;
if (MIN_X..=MAX_X).contains(&x) {
let index_f = (x - MIN_X) / (MAX_X - MIN_X) * TABLE_SIZE as f64;
let index = index_f.floor() as usize;
if index >= TABLE_SIZE {
return Ok(GAMMA_LOOKUP_TABLE[TABLE_SIZE]);
}
let frac = index_f - index as f64;
let result =
GAMMA_LOOKUP_TABLE[index] * (1.0 - frac) + GAMMA_LOOKUP_TABLE[index + 1] * frac;
Ok(result)
} else if x < MIN_X {
if x <= 0.0 && x.fract() == 0.0 {
return Err(SpecialError::DomainError(
"Gamma undefined for non-positive integers".to_string(),
));
}
let reflected = gamma_advancedfast(1.0 - x)?;
Ok(std::f64::consts::PI / ((std::f64::consts::PI * x).sin() * reflected))
} else {
let mut result = gamma_advancedfast(x - 1.0)?;
result *= x - 1.0;
Ok(result)
}
}
#[allow(dead_code)]
pub fn j0_advancedfast(x: f64) -> SpecialResult<f64> {
let x_abs = x.abs();
const MAX_X: f64 = 20.0;
const TABLE_SIZE: usize = 20000;
if x_abs <= MAX_X {
let index_f = x_abs / MAX_X * TABLE_SIZE as f64;
let index = index_f.floor() as usize;
if index >= TABLE_SIZE {
return Ok(J0_LOOKUP_TABLE[TABLE_SIZE]);
}
let frac = index_f - index as f64;
let result = J0_LOOKUP_TABLE[index] * (1.0 - frac) + J0_LOOKUP_TABLE[index + 1] * frac;
Ok(result)
} else {
let sqrt_2_over_pi_x = (2.0 / (std::f64::consts::PI * x_abs)).sqrt();
let phase = x_abs - std::f64::consts::PI / 4.0;
Ok(sqrt_2_over_pi_x * phase.cos())
}
}
#[allow(dead_code)]
pub fn gamma_array_advancedfast(
input: &ArrayView1<f64>,
config: &PerformanceConfig,
) -> SpecialResult<Array1<f64>> {
let len = input.len();
let mut output = Array1::zeros(len);
if len > config.chunksize && config.use_parallel {
#[cfg(feature = "parallel")]
{
use scirs2_core::parallel_ops::*;
let input_slice = input.as_slice().ok_or_else(|| {
SpecialError::ComputationError("Input array is not contiguous".to_string())
})?;
let output_slice = output.as_slice_mut().ok_or_else(|| {
SpecialError::ComputationError("Output array is not contiguous".to_string())
})?;
input_slice
.par_chunks(config.chunksize)
.zip(output_slice.par_chunks_mut(config.chunksize))
.try_for_each(|(input_chunk, output_chunk)| -> SpecialResult<()> {
for (i, &x) in input_chunk.iter().enumerate() {
output_chunk[i] = gamma_advancedfast(x)?;
}
Ok(())
})?;
return Ok(output);
}
}
if config.use_simd && len >= 4 {
#[cfg(feature = "simd")]
{
let simd_chunks = len / 4;
for i in 0..simd_chunks {
let start = i * 4;
for j in 0..4 {
output[start + j] = gamma_advancedfast(input[start + j])?;
}
}
for i in (simd_chunks * 4)..len {
output[i] = gamma_advancedfast(input[i])?;
}
return Ok(output);
}
}
for (i, &x) in input.iter().enumerate() {
output[i] = gamma_advancedfast(x)?;
}
Ok(output)
}
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
pub function_name: String,
pub arraysize: usize,
pub time_ns: u64,
pub throughput_ops_per_sec: f64,
pub memory_bandwidth_gb_per_sec: f64,
pub cache_efficiency: Option<f64>,
}
#[allow(dead_code)]
pub fn benchmark_function<F>(
function_name: &str,
test_function: F,
inputsize: usize,
iterations: usize,
) -> PerformanceMetrics
where
F: Fn() -> SpecialResult<()>,
{
for _ in 0..3 {
let _ = test_function();
}
let start = std::time::Instant::now();
for _ in 0..iterations {
let _ = test_function();
}
let duration = start.elapsed();
let time_ns = duration.as_nanos() as u64 / iterations as u64;
let ops_per_sec = (inputsize as f64 * 1e9) / time_ns as f64;
let memory_bandwidth =
(inputsize * 16) as f64 * 1e9 / (1024.0 * 1024.0 * 1024.0) / (time_ns as f64 / 1e9);
PerformanceMetrics {
function_name: function_name.to_string(),
arraysize: inputsize,
time_ns,
throughput_ops_per_sec: ops_per_sec,
memory_bandwidth_gb_per_sec: memory_bandwidth,
cache_efficiency: None, }
}
#[allow(dead_code)]
pub fn erf_advancedfast(x: f64) -> SpecialResult<f64> {
let abs_x = x.abs();
if abs_x < 0.5 {
let x2 = x * x;
let result = x * (1.0 - x2 * (1.0 / 3.0 - x2 * (1.0 / 10.0 - x2 * (1.0 / 42.0))));
Ok(result * 2.0 / std::f64::consts::PI.sqrt())
} else if abs_x < 4.0 {
let t = 1.0 / (1.0 + 0.3275911 * abs_x);
let poly = t
* (0.254829592
+ t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
let result = 1.0 - poly * (-abs_x * abs_x).exp();
Ok(if x >= 0.0 { result } else { -result })
} else {
let result = if x >= 0.0 { 1.0 } else { -1.0 };
Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_gamma_advancedfast_accuracy() {
let test_values = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
let expected = [
1.7724538509055159, 1.0, 0.8862269254527580, 1.0, 1.3293403881791370, 2.0, ];
for (i, &x) in test_values.iter().enumerate() {
let result = gamma_advancedfast(x).expect("Operation failed");
assert_relative_eq!(result, expected[i], epsilon = 1e-10);
}
}
#[test]
fn test_j0_advancedfast_accuracy() {
let test_values = [0.0, 1.0, 2.0, 5.0, 10.0];
let expected = [
1.0,
0.7651976865579666,
0.2238907791412357,
-0.17759677131433830,
-0.24593576445134833,
];
for (i, &x) in test_values.iter().enumerate() {
let result = j0_advancedfast(x).expect("Operation failed");
assert_relative_eq!(result, expected[i], epsilon = 1e-8);
}
}
#[test]
fn test_performance_metrics() {
let input = scirs2_core::ndarray::Array1::linspace(0.1, 10.0, 1000);
let config = PerformanceConfig::default();
let metrics = benchmark_function(
"gamma_array_advancedfast",
|| gamma_array_advancedfast(&input.view(), &config).map(|_| ()),
1000,
10,
);
assert!(metrics.throughput_ops_per_sec > 0.0);
assert!(metrics.memory_bandwidth_gb_per_sec > 0.0);
}
}