use super::{evaluate_htst_temperature_sweep, ElementaryRateResult, ElementaryStep};
use crate::gpu::context::GpuContext;
/// Minimum number of temperatures a sweep must contain before
/// [`evaluate_htst_sweep_gpu`] considers it eligible for GPU offload.
const GPU_HTST_SWEEP_THRESHOLD: usize = 256;

/// Evaluates HTST rates for `step` at every temperature in `temperatures_k`
/// and the given `pressure_bar`, intended as the GPU-accelerated entry point.
///
/// A sweep is treated as eligible for offload only when `ctx` reports an
/// available GPU and the sweep holds at least [`GPU_HTST_SWEEP_THRESHOLD`]
/// temperatures. No GPU kernel is dispatched yet. Eligible and ineligible
/// sweeps alike delegate to [`evaluate_htst_temperature_sweep`], so the output
/// is exactly what that CPU function returns.
///
/// # Errors
///
/// Propagates any `Err(String)` produced by [`evaluate_htst_temperature_sweep`].
pub fn evaluate_htst_sweep_gpu(
    ctx: &GpuContext,
    step: &ElementaryStep,
    temperatures_k: &[f64],
    pressure_bar: f64,
) -> Result<Vec<ElementaryRateResult>, String> {
    let no_device = !ctx.capabilities.gpu_available;
    let below_threshold = temperatures_k.len() < GPU_HTST_SWEEP_THRESHOLD;

    if no_device || below_threshold {
        // CPU fallback: no GPU reported, or the sweep is shorter than the
        // offload threshold.
        return evaluate_htst_temperature_sweep(step, temperatures_k, pressure_bar);
    }

    // NOTE(review): the GPU offload path is not implemented yet; eligible
    // sweeps are also routed through the CPU implementation.
    evaluate_htst_temperature_sweep(step, temperatures_k, pressure_bar)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the elementary step shared by the parity tests.
    fn sample_step() -> ElementaryStep {
        ElementaryStep {
            step_id: "test".into(),
            activation_free_energy_ev: 0.5,
            reaction_free_energy_ev: -0.2,
            prefactor_s_inv: None,
        }
    }

    /// Asserts that `evaluate_htst_sweep_gpu` and the CPU sweep return the
    /// same number of results and element-wise matching forward rates.
    ///
    /// A relative tolerance is used because forward rates across a
    /// temperature sweep span many orders of magnitude. An absolute tolerance
    /// is impossibly strict for large rates and meaningless for tiny ones.
    /// The `MIN_POSITIVE` floor keeps the check well-defined when both rates
    /// are exactly zero.
    fn assert_parity(ctx: &GpuContext, step: &ElementaryStep, temps: &[f64]) {
        let gpu_result = evaluate_htst_sweep_gpu(ctx, step, temps, 1.0).unwrap();
        let cpu_result = evaluate_htst_temperature_sweep(step, temps, 1.0).unwrap();
        assert_eq!(gpu_result.len(), cpu_result.len());
        for (g, c) in gpu_result.iter().zip(&cpu_result) {
            let gpu_rate = g.forward_rate_s_inv;
            let cpu_rate = c.forward_rate_s_inv;
            let scale = gpu_rate.abs().max(cpu_rate.abs()).max(f64::MIN_POSITIVE);
            assert!(
                (gpu_rate - cpu_rate).abs() <= 1e-10 * scale,
                "GPU/CPU parity: {gpu_rate} vs {cpu_rate}"
            );
        }
    }

    #[test]
    fn gpu_sweep_falls_back_to_cpu() {
        let ctx = GpuContext::cpu_fallback();
        let step = sample_step();
        let temps: Vec<f64> = (300..=800).step_by(50).map(|t| t as f64).collect();
        assert_parity(&ctx, &step, &temps);
    }

    #[test]
    fn gpu_sweep_at_threshold_without_gpu_falls_back_to_cpu() {
        // A sweep exactly at the offload threshold still has to fall back
        // when the context reports no GPU.
        let ctx = GpuContext::cpu_fallback();
        let step = sample_step();
        let temps: Vec<f64> = (0..GPU_HTST_SWEEP_THRESHOLD)
            .map(|i| 300.0 + i as f64)
            .collect();
        assert_parity(&ctx, &step, &temps);
    }
}