use crate::error::IntegrateError;
use scirs2_core::parallel_ops::*;
use std::sync::Arc;
/// An ODE problem with a precomputed uniform time grid, built once and
/// reused for many integrations (serial, parallel batch, or async).
///
/// The trait bound on `F` lives on the `impl` blocks rather than here:
/// bounds on struct definitions are redundant (every use site that needs
/// them restates them) and make the type harder to name in other generic
/// code.
pub struct CachedOdeProblem<F> {
    // Shared right-hand side f(t, y, dydt); writes the derivative into
    // its third argument. Arc so batch workers can share it cheaply.
    rhs: Arc<F>,
    // Initial time of the integration interval.
    t0: f64,
    // Fixed step size of the cached grid.
    dt: f64,
    // Number of RK4 steps: ceil((t1 - t0) / dt), clamped to at least 1.
    n_steps: usize,
    // Expected length of the state vector.
    state_dim: usize,
}
impl<F> CachedOdeProblem<F>
where
    F: Fn(f64, &[f64], &mut [f64]) + Send + Sync + 'static,
{
    /// Builds a problem whose time grid is fixed up front: `n_steps`
    /// uniform steps of size `dt` starting at `t0`.
    ///
    /// `n_steps` is `ceil((t1 - t0) / dt)`, clamped to at least one step,
    /// so the final time can overshoot `t1` when `dt` does not divide the
    /// span evenly. NOTE(review): `t1` itself is not retained, so the last
    /// step cannot be clamped to land exactly on `t1` — confirm callers
    /// pick `dt` dividing the span.
    pub fn new(rhs: F, t0: f64, t1: f64, dt: f64, state_dim: usize) -> Self {
        let steps = (((t1 - t0) / dt).ceil() as usize).max(1);
        Self {
            rhs: Arc::new(rhs),
            t0,
            dt,
            n_steps: steps,
            state_dim,
        }
    }

    /// Integrates from `t0` over the cached grid with the classical
    /// fourth-order Runge–Kutta scheme and returns the final state.
    ///
    /// # Errors
    /// `IntegrateError::DimensionMismatch` when `y0.len() != state_dim`.
    pub fn integrate(&self, y0: &[f64]) -> Result<Vec<f64>, IntegrateError> {
        if y0.len() != self.state_dim {
            return Err(IntegrateError::DimensionMismatch(format!(
                "y0.len()={} != state_dim={}",
                y0.len(),
                self.state_dim
            )));
        }
        let f = &*self.rhs;
        let n = self.state_dim;
        let h = self.dt;
        // Hoisted constants; 0.5*h and h/6.0 match the original
        // left-associated expressions bit-for-bit.
        let half_h = 0.5 * h;
        let sixth_h = h / 6.0;
        let mut state = y0.to_vec();
        // Scratch buffers reused across every step — no per-step allocation.
        let mut k1 = vec![0.0_f64; n];
        let mut k2 = vec![0.0_f64; n];
        let mut k3 = vec![0.0_f64; n];
        let mut k4 = vec![0.0_f64; n];
        let mut probe = vec![0.0_f64; n];
        let mut t = self.t0;
        for _ in 0..self.n_steps {
            // Stage 1: slope at the left endpoint.
            f(t, &state, &mut k1);
            // Stage 2: slope at the midpoint, advanced along k1.
            for ((p, &yi), &ki) in probe.iter_mut().zip(&state).zip(&k1) {
                *p = yi + half_h * ki;
            }
            f(t + half_h, &probe, &mut k2);
            // Stage 3: slope at the midpoint, advanced along k2.
            for ((p, &yi), &ki) in probe.iter_mut().zip(&state).zip(&k2) {
                *p = yi + half_h * ki;
            }
            f(t + half_h, &probe, &mut k3);
            // Stage 4: slope at the right endpoint, advanced along k3.
            for ((p, &yi), &ki) in probe.iter_mut().zip(&state).zip(&k3) {
                *p = yi + h * ki;
            }
            f(t + h, &probe, &mut k4);
            // Weighted update: y += h/6 * (k1 + 2*k2 + 2*k3 + k4).
            for ((((yi, &a), &b), &c), &d) in
                state.iter_mut().zip(&k1).zip(&k2).zip(&k3).zip(&k4)
            {
                *yi += sixth_h * (a + 2.0 * b + 2.0 * c + d);
            }
            t += h;
        }
        Ok(state)
    }

    /// Integrates each initial state in `batch_y0` in parallel, preserving
    /// input order; fails on the first dimension mismatch.
    pub fn integrate_batch(&self, batch_y0: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, IntegrateError> {
        parallel_map_result(batch_y0, |y0| self.integrate(y0))
    }

    /// Step size of the cached time grid.
    pub fn dt(&self) -> f64 {
        self.dt
    }

    /// Number of RK4 steps taken per integration.
    pub fn n_steps(&self) -> usize {
        self.n_steps
    }

    /// Expected length of the state vector.
    pub fn state_dim(&self) -> usize {
        self.state_dim
    }
}
/// Runs a parallel batch integration on tokio's blocking thread pool so the
/// CPU-bound RK4 loops never stall the async executor.
///
/// # Errors
/// Propagates any `IntegrateError` from the batch itself, or a
/// `ComputationError` if the blocking task panicked.
pub async fn integrate_batch_async<F>(
    problem: Arc<CachedOdeProblem<F>>,
    batch_y0: Vec<Vec<f64>>,
) -> Result<Vec<Vec<f64>>, IntegrateError>
where
    F: Fn(f64, &[f64], &mut [f64]) + Send + Sync + 'static,
{
    let handle = tokio::task::spawn_blocking(move || problem.integrate_batch(&batch_y0));
    // The outer Result reports join failure (panic/cancel); the inner one is
    // the integration outcome and is returned as-is.
    match handle.await {
        Ok(outcome) => outcome,
        Err(e) => Err(IntegrateError::ComputationError(format!(
            "spawn_blocking panicked: {e}"
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// RHS of y' = -y, whose exact solution is y0 * exp(-t).
    fn exponential_decay() -> impl Fn(f64, &[f64], &mut [f64]) + Send + Sync + 'static {
        |_t, y, dydt| {
            dydt[0] = -y[0];
        }
    }

    #[test]
    fn test_cached_ode_exponential_decay() {
        // Integrate y' = -y from t=0 to t=1; compare against exp(-1).
        let problem = CachedOdeProblem::new(exponential_decay(), 0.0, 1.0, 0.001, 1);
        let y_final = problem.integrate(&[1.0]).expect("integration failed");
        let expected = std::f64::consts::E.recip();
        let abs_err = (y_final[0] - expected).abs();
        assert!(
            abs_err < 1e-5,
            "Expected ≈{expected:.6}, got {:.6}",
            y_final[0]
        );
    }

    #[test]
    fn test_batch_integration_matches_serial() {
        let problem = Arc::new(CachedOdeProblem::new(exponential_decay(), 0.0, 0.5, 0.001, 1));
        let batch_y0 = vec![vec![1.0], vec![2.0], vec![0.5]];
        let batch_result = problem.integrate_batch(&batch_y0).expect("batch failed");
        // Each parallel result must agree with its serial counterpart.
        for (y0, batched) in batch_y0.iter().zip(&batch_result) {
            let serial = problem.integrate(y0).expect("serial failed");
            let delta = (serial[0] - batched[0]).abs();
            assert!(
                delta < 1e-14,
                "Batch/serial mismatch: serial={:.10} batch={:.10}",
                serial[0],
                batched[0]
            );
        }
    }

    #[test]
    fn test_neural_ode_repeated_forward_same_result() {
        // Two independently-decaying components; repeated forwards on a
        // cached problem must be bit-identical.
        let two_dim_decay = |_t: f64, y: &[f64], dydt: &mut [f64]| {
            dydt[0] = -y[0];
            dydt[1] = -2.0 * y[1];
        };
        let problem = Arc::new(CachedOdeProblem::new(two_dim_decay, 0.0, 1.0, 0.01, 2));
        let y0 = vec![1.0, 1.0];
        let r1 = problem.integrate(&y0).expect("first forward failed");
        let r2 = problem.integrate(&y0).expect("second forward failed");
        let r3 = problem.integrate(&y0).expect("third forward failed");
        assert_eq!(r1, r2, "Results differ between calls 1 and 2");
        assert_eq!(r1, r3, "Results differ between calls 1 and 3");
    }

    #[test]
    fn test_dimension_mismatch_returns_error() {
        let problem = CachedOdeProblem::new(exponential_decay(), 0.0, 1.0, 0.01, 1);
        // state_dim is 1, so a 2-element initial state must be rejected.
        let outcome = problem.integrate(&[1.0, 2.0]);
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_async_batch_returns_correct_shape() {
        let problem = Arc::new(CachedOdeProblem::new(exponential_decay(), 0.0, 0.5, 0.01, 1));
        let batch_y0 = vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]];
        let n_inputs = batch_y0.len();
        let results = integrate_batch_async(problem, batch_y0)
            .await
            .expect("async batch failed");
        assert_eq!(results.len(), n_inputs);
        for row in &results {
            assert_eq!(row.len(), 1, "Each result must have state_dim=1 entries");
        }
    }

    #[tokio::test]
    async fn test_async_matches_sync() {
        // Two identical problems: one driven through the async path, one sync.
        let build = || Arc::new(CachedOdeProblem::new(exponential_decay(), 0.0, 1.0, 0.001, 1));
        let (problem_async, problem_sync) = (build(), build());
        let batch_y0 = vec![vec![1.0], vec![0.5], vec![2.0]];
        let async_results = integrate_batch_async(problem_async, batch_y0.clone())
            .await
            .expect("async failed");
        let sync_results = problem_sync.integrate_batch(&batch_y0).expect("sync failed");
        for (a, s) in async_results.iter().zip(&sync_results) {
            assert!(
                (a[0] - s[0]).abs() < 1e-14,
                "Async/sync mismatch: {:.10} vs {:.10}",
                a[0],
                s[0]
            );
        }
    }
}