/// WGSL compute shader that evaluates the Gamma function element-wise in `f32`.
///
/// Bindings (group 0): binding 0 is the read-only `input` array, binding 1 is
/// the read-write `output` array. The entry point `main` runs with
/// `@workgroup_size(64)`, handles one element per invocation (`gid.x`) and
/// returns early for indices past `arrayLength(&input)`.
///
/// `lanczos_gamma` uses the Lanczos approximation (g = 7, nine coefficients).
/// Arguments below 0.5 go through the reflection formula
/// `Gamma(x) = PI / (sin(PI * x) * Gamma(1 - x))`.
///
/// The `t^(x + 0.5)` factor is applied in two halves around `exp(-t)`, because
/// a single `pow` exceeds the `f32` range from roughly `x = 26` even though the
/// final Gamma value stays representable up to roughly `x = 35`.
pub const GAMMA_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
const PI: f32 = 3.14159265358979323846;
// Lanczos approximation, g = 7, 9 coefficients.
fn lanczos_gamma(x_in: f32) -> f32 {
var x = x_in;
// Reflection factor PI / sin(PI * x); stays at exactly 1.0 when no reflection is used.
var sign = 1.0f;
if x < 0.5 {
sign = PI / (sin(PI * x));
x = 1.0 - x;
}
let g: f32 = 7.0;
x = x - 1.0;
let c0: f32 = 0.99999999999980993;
let c1: f32 = 676.5203681218851;
let c2: f32 = -1259.1392167224028;
let c3: f32 = 771.32342877765313;
let c4: f32 = -176.61502916214059;
let c5: f32 = 12.507343278686905;
let c6: f32 = -0.13857109526572012;
let c7: f32 = 9.9843695780195716e-6;
let c8: f32 = 1.5056327351493116e-7;
let s = c0
+ c1 / (x + 1.0)
+ c2 / (x + 2.0)
+ c3 / (x + 3.0)
+ c4 / (x + 4.0)
+ c5 / (x + 5.0)
+ c6 / (x + 6.0)
+ c7 / (x + 7.0)
+ c8 / (x + 8.0);
let t = x + g + 0.5;
// `pow(t, x + 0.5)` on its own overflows f32 long before Gamma itself does,
// so split the power and let exp(-t) pull the intermediate back into range.
let half_pow = pow(t, 0.5 * (x + 0.5));
let result = (sqrt(2.0 * PI) * half_pow * exp(-t)) * half_pow * s;
if sign != 1.0 { return sign / result; }
return result;
}
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if idx >= arrayLength(&input) { return; }
output[idx] = lanczos_gamma(input[idx]);
}
"#;
/// WGSL compute shader that evaluates an approximate error function
/// element-wise in `f32`.
///
/// Bindings (group 0): binding 0 is the read-only `input` array, binding 1 is
/// the read-write `output` array. The entry point `main` runs with
/// `@workgroup_size(64)`, handles one element per invocation (`gid.x`) and
/// returns early for indices past `arrayLength(&input)`.
///
/// `approx_erf` evaluates a degree-5 polynomial in
/// `t = 1 / (1 + 0.3275911 * |x|)` multiplied by `exp(-x * x)`, and mirrors the
/// result for negative `x` so the function is odd.
/// NOTE(review): the coefficients look like Abramowitz & Stegun formula 7.1.26
/// — confirm against the reference before relying on its stated error bound.
pub const ERF_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
fn approx_erf(x: f32) -> f32 {
let t = 1.0 / (1.0 + 0.3275911 * abs(x));
let y = 1.0 - (((((
1.061405429 * t
- 1.453152027) * t
+ 1.421413741) * t
- 0.284496736) * t
+ 0.254829592) * t * exp(-x * x));
return select(-y, y, x >= 0.0);
}
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if idx >= arrayLength(&input) { return; }
output[idx] = approx_erf(input[idx]);
}
"#;
/// WGSL compute shader that evaluates the Bessel function J0 element-wise in
/// `f32`.
///
/// Bindings (group 0): binding 0 is the read-only `input` array, binding 1 is
/// the read-write `output` array. The entry point `main` runs with
/// `@workgroup_size(64)`, handles one element per invocation (`gid.x`) and
/// returns early for indices past `arrayLength(&input)`.
///
/// `bessel_j0` works on `|x|`, so the result is even in `x`:
/// * for `|x| < 8` it returns a ratio of two polynomials in `x * x`;
/// * otherwise it uses `sqrt(0.636619772 / x)` times a combination of
///   `cos`/`sin` of `x - 0.785398164`, weighted by two polynomials in
///   `(8 / x)^2`.
///
/// The shader declares `PI`, but nothing in it references the constant.
pub const BESSEL_J0_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
const PI: f32 = 3.14159265358979323846;
fn bessel_j0(x_in: f32) -> f32 {
let x = abs(x_in);
if x < 8.0 {
let y = x * x;
let p1: f32 = 57568490574.0;
let p2: f32 = -13362590354.0;
let p3: f32 = 651619640.7;
let p4: f32 = -11214424.18;
let p5: f32 = 77392.33017;
let p6: f32 = -184.9052456;
let q1: f32 = 57568490411.0;
let q2: f32 = 1029532985.0;
let q3: f32 = 9494680.718;
let q4: f32 = 59272.64853;
let q5: f32 = 267.8532712;
let p = p1 + y * (p2 + y * (p3 + y * (p4 + y * (p5 + y * p6))));
let q = q1 + y * (q2 + y * (q3 + y * (q4 + y * (q5 + y))));
return p / q;
} else {
let z = 8.0 / x;
let y = z * z;
let xx = x - 0.785398164;
let pv = 1.0 + y * (-0.1098628627e-2 + y * (0.2734510407e-4
+ y * (-0.2073370639e-5 + y * 0.2093887211e-6)));
let qv = -0.1562499995e-1 + y * (0.1430488765e-3
+ y * (-0.6911147651e-5 + y * (0.7621095161e-6
- y * 0.934945152e-7)));
return sqrt(0.636619772 / x) * (cos(xx) * pv - z * sin(xx) * qv);
}
}
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if idx >= arrayLength(&input) { return; }
output[idx] = bessel_j0(input[idx]);
}
"#;
/// WGSL compute shader that evaluates an approximate complementary error
/// function element-wise in `f32`.
///
/// Bindings (group 0): binding 0 is the read-only `input` array, binding 1 is
/// the read-write `output` array. The entry point `main` runs with
/// `@workgroup_size(64)`, handles one element per invocation (`gid.x`) and
/// returns early for indices past `arrayLength(&input)`.
///
/// `approx_erfc` short-circuits for `|x| > 6`, returning `0.0` for positive
/// `x` and `2.0` for negative `x`. Otherwise it returns
/// `1.0 - approx_erf_inner(x)`, where `approx_erf_inner` is the same
/// polynomial-times-`exp(-x * x)` approximation used by [`ERF_WGSL`].
pub const ERFC_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
fn approx_erf_inner(x: f32) -> f32 {
let t = 1.0 / (1.0 + 0.3275911 * abs(x));
let y = 1.0 - (((((
1.061405429 * t
- 1.453152027) * t
+ 1.421413741) * t
- 0.284496736) * t
+ 0.254829592) * t * exp(-x * x));
return select(-y, y, x >= 0.0);
}
fn approx_erfc(x: f32) -> f32 {
// erfc saturates quickly: |erfc(x)| < f32_epsilon for |x| > ~6
if abs(x) > 6.0 {
return select(0.0, 2.0, x < 0.0);
}
return 1.0 - approx_erf_inner(x);
}
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if idx >= arrayLength(&input) { return; }
output[idx] = approx_erfc(input[idx]);
}
"#;
/// WGSL compute shader that evaluates an approximate inverse error function
/// element-wise in `f32`.
///
/// Bindings (group 0): binding 0 is the read-only `input` array, binding 1 is
/// the read-write `output` array. The entry point `main` runs with
/// `@workgroup_size(64)`, handles one element per invocation (`gid.x`) and
/// returns early for indices past `arrayLength(&input)`.
///
/// `approx_erfinv` behaves as follows:
/// * `|p| >= 1` returns the finite sentinel `1e10`, negated when `p < 0`
///   (not an infinity or NaN);
/// * `p == 0` returns `0.0`;
/// * otherwise it evaluates the closed-form expression credited in the shader
///   comment to Winitzki (2008), with `a = 0.147`, clamping both square-root
///   arguments at zero.
pub const ERFINV_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
const PI_F: f32 = 3.14159265358979323846;
const WINITZKI_A: f32 = 0.147;
const INV_WINITZKI_A: f32 = 6.802721088; // 1.0 / 0.147
fn approx_erfinv(p: f32) -> f32 {
let ap = abs(p);
if ap >= 1.0 {
// Return signed large value for |p| = 1 boundary
return select(1e10, -1e10, p < 0.0);
}
if p == 0.0 {
return 0.0;
}
let sign_p = select(-1.0f, 1.0f, p >= 0.0);
// Winitzki (2008): erfinv(p) ≈ sign(p) * sqrt(sqrt(c^2 - ln(1-p^2)/a) - c)
// where c = 2/(π·a) + ln(1-p^2)/2
let ln_term = log(1.0 - p * p);
let two_over_pia = 2.0 / (PI_F * WINITZKI_A);
let c = two_over_pia + ln_term * 0.5;
let discriminant = c * c - ln_term * INV_WINITZKI_A;
// discriminant is always non-negative for |p| < 1
let inner = sqrt(max(discriminant, 0.0)) - c;
return sign_p * sqrt(max(inner, 0.0));
}
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if idx >= arrayLength(&input) { return; }
output[idx] = approx_erfinv(input[idx]);
}
"#;
/// WGSL compute shader that evaluates the logarithm of the Gamma function
/// element-wise in `f32`.
///
/// Bindings (group 0): binding 0 is the read-only `input` array, binding 1 is
/// the read-write `output` array. The entry point `main` runs with
/// `@workgroup_size(64)`, handles one element per invocation (`gid.x`) and
/// returns early for indices past `arrayLength(&input)`.
///
/// `lanczos_lgamma` evaluates the Lanczos series (g = 7, nine coefficients)
/// directly in log space: `0.5 * log(2 * PI) + (x + 0.5) * log(t) - t + log(s)`.
/// For `x < 0.5` it first reflects to `1 - x` and returns
/// `log(PI / |sin(PI * x)|) - lgamma(1 - x)`. Because the sine is taken in
/// absolute value, the sign of Gamma is discarded on that branch.
///
/// `log_sign` doubles as the "reflection was applied" flag: it is compared
/// against `0.0` to pick the return expression.
pub const LGAMMA_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
const PI: f32 = 3.14159265358979323846;
fn lanczos_lgamma(x_in: f32) -> f32 {
var x = x_in;
var log_sign: f32 = 0.0;
if x < 0.5 {
log_sign = log(PI / abs(sin(PI * x)));
x = 1.0 - x;
}
let g: f32 = 7.0;
x = x - 1.0;
let c0: f32 = 0.99999999999980993;
let c1: f32 = 676.5203681218851;
let c2: f32 = -1259.1392167224028;
let c3: f32 = 771.32342877765313;
let c4: f32 = -176.61502916214059;
let c5: f32 = 12.507343278686905;
let c6: f32 = -0.13857109526572012;
let c7: f32 = 9.9843695780195716e-6;
let c8: f32 = 1.5056327351493116e-7;
let s = c0 + c1/(x+1.0) + c2/(x+2.0) + c3/(x+3.0) + c4/(x+4.0)
+ c5/(x+5.0) + c6/(x+6.0) + c7/(x+7.0) + c8/(x+8.0);
let t = x + g + 0.5;
let lgamma = 0.5 * log(2.0 * PI) + (x + 0.5) * log(t) - t + log(s);
if log_sign != 0.0 {
return log_sign - lgamma;
}
return lgamma;
}
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if idx >= arrayLength(&input) { return; }
output[idx] = lanczos_lgamma(input[idx]);
}
"#;
/// Errors reported by the wgpu-backed batch kernels in this module.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare outcomes
/// directly (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WgslDispatchError {
    /// No GPU could be used: either the adapter request failed, or the crate
    /// was built without the `wgpu_kernels` feature.
    GpuNotAvailable,
    /// A GPU was found but a later step failed (device request, polling or
    /// buffer mapping); the payload is a human-readable description.
    RuntimeError(String),
}
impl std::fmt::Display for WgslDispatchError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
WgslDispatchError::GpuNotAvailable => {
write!(f, "wgpu GPU device not available")
}
WgslDispatchError::RuntimeError(msg) => {
write!(f, "wgpu runtime error: {msg}")
}
}
}
}
/// Runs a unary, element-wise WGSL compute shader over `xs_f32` on the GPU.
///
/// `shader_src` must expose a compute entry point named `main` with
/// `@workgroup_size(64)`, reading from a read-only storage buffer at
/// `@group(0) @binding(0)` and writing to a read-write storage buffer at
/// `@group(0) @binding(1)`, one `f32` per invocation indexed by
/// `global_invocation_id.x`.
///
/// An empty input returns an empty vector without touching the GPU. Otherwise
/// a fresh instance, adapter and device are created on every call, the shader
/// is dispatched once, and the output is copied through a staging buffer and
/// read back.
///
/// # Errors
///
/// * [`WgslDispatchError::GpuNotAvailable`] when no adapter can be obtained.
/// * [`WgslDispatchError::RuntimeError`] when the device request fails, when
///   the input needs more workgroups than the device allows in one dispatch
///   dimension, or when polling / mapping the staging buffer fails.
#[cfg(feature = "wgpu_kernels")]
fn dispatch_unary_f32(shader_src: &str, xs_f32: &[f32]) -> Result<Vec<f32>, WgslDispatchError> {
    use wgpu::{
        util::BufferInitDescriptor, util::DeviceExt as _, Backends, BindGroupDescriptor,
        BindGroupEntry, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType,
        BufferBindingType, BufferDescriptor, BufferUsages, CommandEncoderDescriptor,
        ComputePassDescriptor, DeviceDescriptor, Features, Instance, InstanceDescriptor, Limits,
        MapMode, PowerPreference, RequestAdapterOptions, ShaderModuleDescriptor, ShaderSource,
        ShaderStages,
    };

    // Invocations per workgroup; must match `@workgroup_size(64)` in the shaders.
    const WORKGROUP_SIZE: usize = 64;

    let n = xs_f32.len();
    if n == 0 {
        return Ok(Vec::new());
    }

    let instance = Instance::new(InstanceDescriptor {
        backends: Backends::all(),
        flags: wgpu::InstanceFlags::default(),
        memory_budget_thresholds: Default::default(),
        backend_options: Default::default(),
        display: None,
    });
    let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
        power_preference: PowerPreference::HighPerformance,
        compatible_surface: None,
        force_fallback_adapter: false,
    }))
    .map_err(|_| WgslDispatchError::GpuNotAvailable)?;
    let (device, queue) = pollster::block_on(adapter.request_device(&DeviceDescriptor {
        label: Some("scirs2-special"),
        required_features: Features::empty(),
        required_limits: Limits::default(),
        ..Default::default()
    }))
    .map_err(|e| WgslDispatchError::RuntimeError(e.to_string()))?;

    // The shaders index by `gid.x` only, so the whole input must fit in the X
    // dimension of a single dispatch. Check this up front: an unchecked
    // `n as u32` would truncate silently, and an over-limit dispatch is a
    // validation error instead of a returned `Err`.
    let max_workgroups = device.limits().max_compute_workgroups_per_dimension;
    let workgroups = u32::try_from(n.div_ceil(WORKGROUP_SIZE))
        .ok()
        .filter(|&count| count <= max_workgroups)
        .ok_or_else(|| {
            WgslDispatchError::RuntimeError(format!(
                "input of {n} elements exceeds the single-dispatch capacity of \
                 {max_workgroups} workgroups of {WORKGROUP_SIZE} invocations"
            ))
        })?;

    let shader_module = device.create_shader_module(ShaderModuleDescriptor {
        label: Some("scirs2-special-shader"),
        source: ShaderSource::Wgsl(shader_src.into()),
    });
    let bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("scirs2-special-bgl"),
        entries: &[
            // binding 0: read-only input array
            BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::COMPUTE,
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            // binding 1: read-write output array
            BindGroupLayoutEntry {
                binding: 1,
                visibility: ShaderStages::COMPUTE,
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
        ],
    });
    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("scirs2-special-layout"),
        bind_group_layouts: &[Some(&bgl)],
        ..Default::default()
    });
    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("scirs2-special-pipeline"),
        layout: Some(&pipeline_layout),
        module: &shader_module,
        entry_point: Some("main"),
        compilation_options: Default::default(),
        cache: None,
    });

    // Serialise the input as little-endian bytes; the result is decoded the same way.
    let input_bytes: Vec<u8> = xs_f32.iter().flat_map(|v| v.to_le_bytes()).collect();
    let byte_len = (n * 4) as u64;
    let buf_input = device.create_buffer_init(&BufferInitDescriptor {
        label: Some("scirs2-special-input"),
        contents: &input_bytes,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
    });
    let buf_output = device.create_buffer(&BufferDescriptor {
        label: Some("scirs2-special-output"),
        size: byte_len,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });
    // The output buffer is not created with MAP_READ, so results are copied
    // into this mappable staging buffer before being read on the CPU.
    let buf_staging = device.create_buffer(&BufferDescriptor {
        label: Some("scirs2-special-staging"),
        size: byte_len,
        usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    let bind_group = device.create_bind_group(&BindGroupDescriptor {
        label: Some("scirs2-special-bg"),
        layout: &bgl,
        entries: &[
            BindGroupEntry {
                binding: 0,
                resource: buf_input.as_entire_binding(),
            },
            BindGroupEntry {
                binding: 1,
                resource: buf_output.as_entire_binding(),
            },
        ],
    });

    let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor {
        label: Some("scirs2-special-encoder"),
    });
    {
        // Scoped so the pass is dropped before the encoder is used again below.
        let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
            label: Some("scirs2-special-pass"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);
        cpass.dispatch_workgroups(workgroups, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&buf_output, 0, &buf_staging, 0, byte_len);
    queue.submit(Some(encoder.finish()));
    device
        .poll(wgpu::PollType::wait_indefinitely())
        .map_err(|e| WgslDispatchError::RuntimeError(format!("GPU poll error: {e:?}")))?;

    // Request the mapping, then poll so the callback gets a chance to run.
    let slice = buf_staging.slice(0..byte_len);
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(MapMode::Read, move |r| {
        let _ = tx.send(r);
    });
    device
        .poll(wgpu::PollType::wait_indefinitely())
        .map_err(|e| WgslDispatchError::RuntimeError(format!("GPU poll during map: {e:?}")))?;
    rx.recv()
        .map_err(|_| WgslDispatchError::RuntimeError("channel closed in map_async".into()))?
        .map_err(|e| WgslDispatchError::RuntimeError(format!("map_async failed: {e:?}")))?;

    let mapped = slice.get_mapped_range();
    let result: Vec<f32> = mapped
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect();
    // The mapped view must be released before the buffer can be unmapped.
    drop(mapped);
    buf_staging.unmap();
    Ok(result)
}
/// Evaluates the Gamma shader ([`GAMMA_WGSL`]) over `xs` on the GPU.
///
/// Inputs are narrowed to `f32` before dispatch and the `f32` results are
/// widened back to `f64`, so precision is limited to single precision.
///
/// # Errors
///
/// Propagates any [`WgslDispatchError`] from the dispatch helper.
#[cfg(feature = "wgpu_kernels")]
pub fn gamma_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    let narrowed: Vec<f32> = xs.iter().map(|&value| value as f32).collect();
    dispatch_unary_f32(GAMMA_WGSL, &narrowed)
        .map(|outputs| outputs.into_iter().map(f64::from).collect())
}
/// Evaluates the error-function shader ([`ERF_WGSL`]) over `xs` on the GPU.
///
/// Inputs are narrowed to `f32` before dispatch and the `f32` results are
/// widened back to `f64`, so precision is limited to single precision.
///
/// # Errors
///
/// Propagates any [`WgslDispatchError`] from the dispatch helper.
#[cfg(feature = "wgpu_kernels")]
pub fn erf_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    let narrowed: Vec<f32> = xs.iter().map(|&value| value as f32).collect();
    dispatch_unary_f32(ERF_WGSL, &narrowed)
        .map(|outputs| outputs.into_iter().map(f64::from).collect())
}
/// Evaluates the Bessel J0 shader ([`BESSEL_J0_WGSL`]) over `xs` on the GPU.
///
/// Inputs are narrowed to `f32` before dispatch and the `f32` results are
/// widened back to `f64`, so precision is limited to single precision.
///
/// # Errors
///
/// Propagates any [`WgslDispatchError`] from the dispatch helper.
#[cfg(feature = "wgpu_kernels")]
pub fn bessel_j0_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    let narrowed: Vec<f32> = xs.iter().map(|&value| value as f32).collect();
    dispatch_unary_f32(BESSEL_J0_WGSL, &narrowed)
        .map(|outputs| outputs.into_iter().map(f64::from).collect())
}
/// Evaluates the log-Gamma shader ([`LGAMMA_WGSL`]) over `xs` on the GPU.
///
/// Inputs are narrowed to `f32` before dispatch and the `f32` results are
/// widened back to `f64`, so precision is limited to single precision.
///
/// # Errors
///
/// Propagates any [`WgslDispatchError`] from the dispatch helper.
#[cfg(feature = "wgpu_kernels")]
pub fn lgamma_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    let narrowed: Vec<f32> = xs.iter().map(|&value| value as f32).collect();
    dispatch_unary_f32(LGAMMA_WGSL, &narrowed)
        .map(|outputs| outputs.into_iter().map(f64::from).collect())
}
/// Evaluates the complementary error-function shader ([`ERFC_WGSL`]) over
/// `xs` on the GPU.
///
/// Inputs are narrowed to `f32` before dispatch and the `f32` results are
/// widened back to `f64`, so precision is limited to single precision.
///
/// # Errors
///
/// Propagates any [`WgslDispatchError`] from the dispatch helper.
#[cfg(feature = "wgpu_kernels")]
pub fn erfc_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    let narrowed: Vec<f32> = xs.iter().map(|&value| value as f32).collect();
    dispatch_unary_f32(ERFC_WGSL, &narrowed)
        .map(|outputs| outputs.into_iter().map(f64::from).collect())
}
/// Evaluates the inverse error-function shader ([`ERFINV_WGSL`]) over `xs` on
/// the GPU.
///
/// Inputs are narrowed to `f32` before dispatch and the `f32` results are
/// widened back to `f64`, so precision is limited to single precision.
///
/// # Errors
///
/// Propagates any [`WgslDispatchError`] from the dispatch helper.
#[cfg(feature = "wgpu_kernels")]
pub fn erfinv_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    let narrowed: Vec<f32> = xs.iter().map(|&value| value as f32).collect();
    dispatch_unary_f32(ERFINV_WGSL, &narrowed)
        .map(|outputs| outputs.into_iter().map(f64::from).collect())
}
/// Fallback for builds without the `wgpu_kernels` feature.
///
/// The input is ignored and the call always fails with
/// [`WgslDispatchError::GpuNotAvailable`].
#[cfg(not(feature = "wgpu_kernels"))]
pub fn gamma_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    // No GPU backend is compiled in, so there is nothing to dispatch to.
    let unavailable = WgslDispatchError::GpuNotAvailable;
    Err(unavailable)
}
/// Fallback for builds without the `wgpu_kernels` feature.
///
/// The input is ignored and the call always fails with
/// [`WgslDispatchError::GpuNotAvailable`].
#[cfg(not(feature = "wgpu_kernels"))]
pub fn erf_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    // No GPU backend is compiled in, so there is nothing to dispatch to.
    let unavailable = WgslDispatchError::GpuNotAvailable;
    Err(unavailable)
}
/// Fallback for builds without the `wgpu_kernels` feature.
///
/// The input is ignored and the call always fails with
/// [`WgslDispatchError::GpuNotAvailable`].
#[cfg(not(feature = "wgpu_kernels"))]
pub fn bessel_j0_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    // No GPU backend is compiled in, so there is nothing to dispatch to.
    let unavailable = WgslDispatchError::GpuNotAvailable;
    Err(unavailable)
}
/// Fallback for builds without the `wgpu_kernels` feature.
///
/// The input is ignored and the call always fails with
/// [`WgslDispatchError::GpuNotAvailable`].
#[cfg(not(feature = "wgpu_kernels"))]
pub fn lgamma_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    // No GPU backend is compiled in, so there is nothing to dispatch to.
    let unavailable = WgslDispatchError::GpuNotAvailable;
    Err(unavailable)
}
/// Fallback for builds without the `wgpu_kernels` feature.
///
/// The input is ignored and the call always fails with
/// [`WgslDispatchError::GpuNotAvailable`].
#[cfg(not(feature = "wgpu_kernels"))]
pub fn erfc_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    // No GPU backend is compiled in, so there is nothing to dispatch to.
    let unavailable = WgslDispatchError::GpuNotAvailable;
    Err(unavailable)
}
/// Fallback for builds without the `wgpu_kernels` feature.
///
/// The input is ignored and the call always fails with
/// [`WgslDispatchError::GpuNotAvailable`].
#[cfg(not(feature = "wgpu_kernels"))]
pub fn erfinv_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    // No GPU backend is compiled in, so there is nothing to dispatch to.
    let unavailable = WgslDispatchError::GpuNotAvailable;
    Err(unavailable)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Passes for a successful dispatch or for the "no GPU" outcome; any other
    /// error fails the calling test.
    fn expect_ok_or_gpu_unavailable(outcome: Result<Vec<f64>, WgslDispatchError>) {
        if let Err(e) = outcome {
            if !matches!(e, WgslDispatchError::GpuNotAvailable) {
                panic!("unexpected error: {e}");
            }
        }
    }

    /// True when every needle occurs somewhere in `source`.
    fn mentions_all(source: &str, needles: &[&str]) -> bool {
        needles.iter().all(|&needle| source.contains(needle))
    }

    // --- Shader sources: sanity checks on the embedded WGSL text ---

    #[test]
    fn test_gamma_wgsl_source_is_non_empty() {
        assert!(!GAMMA_WGSL.is_empty());
        assert!(mentions_all(
            GAMMA_WGSL,
            &["@compute", "workgroup_size", "lanczos_gamma"]
        ));
    }

    #[test]
    fn test_erf_wgsl_source_is_non_empty() {
        assert!(!ERF_WGSL.is_empty());
        assert!(mentions_all(ERF_WGSL, &["@compute", "approx_erf"]));
    }

    #[test]
    fn test_bessel_j0_wgsl_source_is_non_empty() {
        assert!(!BESSEL_J0_WGSL.is_empty());
        assert!(mentions_all(BESSEL_J0_WGSL, &["@compute", "bessel_j0"]));
    }

    #[test]
    fn test_lgamma_wgsl_source_is_non_empty() {
        assert!(!LGAMMA_WGSL.is_empty());
        assert!(mentions_all(LGAMMA_WGSL, &["@compute", "lanczos_lgamma"]));
    }

    #[test]
    fn test_erfc_wgsl_source_is_non_empty() {
        assert!(!ERFC_WGSL.is_empty());
        assert!(mentions_all(
            ERFC_WGSL,
            &["@compute", "approx_erfc", "workgroup_size"]
        ));
    }

    #[test]
    fn test_erfinv_wgsl_source_is_non_empty() {
        assert!(!ERFINV_WGSL.is_empty());
        assert!(mentions_all(
            ERFINV_WGSL,
            &["@compute", "approx_erfinv", "workgroup_size"]
        ));
    }

    // --- Batch entry points: must succeed or report that no GPU is available ---

    #[test]
    fn test_gamma_batch_wgpu_returns_not_available() {
        let inputs = vec![1.0_f64, 2.0, 3.0];
        expect_ok_or_gpu_unavailable(gamma_batch_wgpu(&inputs));
    }

    #[test]
    fn test_erf_batch_wgpu_returns_not_available() {
        let inputs = vec![0.0_f64, 1.0];
        expect_ok_or_gpu_unavailable(erf_batch_wgpu(&inputs));
    }

    #[test]
    fn test_bessel_j0_batch_wgpu_returns_not_available() {
        let inputs = vec![0.0_f64, 2.405];
        expect_ok_or_gpu_unavailable(bessel_j0_batch_wgpu(&inputs));
    }

    #[test]
    fn test_lgamma_batch_wgpu_returns_not_available() {
        let inputs = vec![1.0_f64, 2.0, 3.0];
        expect_ok_or_gpu_unavailable(lgamma_batch_wgpu(&inputs));
    }

    #[test]
    fn test_erfc_batch_wgpu_returns_not_available() {
        let inputs = vec![0.0_f64, 1.0, -1.0];
        expect_ok_or_gpu_unavailable(erfc_batch_wgpu(&inputs));
    }

    #[test]
    fn test_erfinv_batch_wgpu_returns_not_available() {
        let inputs = vec![0.0_f64, 0.5, -0.5];
        expect_ok_or_gpu_unavailable(erfinv_batch_wgpu(&inputs));
    }

    // --- Error formatting ---

    #[test]
    fn test_wgsl_dispatch_error_display() {
        let unavailable = WgslDispatchError::GpuNotAvailable.to_string();
        assert!(unavailable.contains("not available"));
        let runtime = WgslDispatchError::RuntimeError("buffer overflow".into()).to_string();
        assert!(runtime.contains("buffer overflow"));
    }
}