Skip to main content

scirs2/
integrate.rs

1//! Python bindings for scirs2-integrate
2//!
3//! Provides numerical integration similar to scipy.integrate
4
5// Allow deprecated with_gil for callback patterns where GIL must be acquired from Rust
6
7use pyo3::prelude::*;
8use pyo3::types::PyDict;
9use scirs2_core::ndarray::{Array1 as Array1_17, ArrayView1};
10use scirs2_core::python::numpy_compat::{scirs_to_numpy_array1, Array1};
11use scirs2_numpy::{PyArray1, PyReadonlyArray1};
12
13use scirs2_integrate::ode::{solve_ivp, ODEMethod, ODEOptions};
14use scirs2_integrate::quad::{quad, QuadOptions};
15
16// =============================================================================
17// Array-based Integration (works without callbacks)
18// =============================================================================
19
20/// Integrate using array data (y values at x points) - trapezoidal rule
21///
22/// Similar to scipy.integrate.trapezoid
23#[pyfunction]
24#[pyo3(signature = (y, x=None, dx=1.0))]
25fn trapezoid_array_py(
26    y: PyReadonlyArray1<f64>,
27    x: Option<PyReadonlyArray1<f64>>,
28    dx: f64,
29) -> PyResult<f64> {
30    let y_arr = y.as_array();
31
32    if y_arr.len() < 2 {
33        return Err(pyo3::exceptions::PyValueError::new_err(
34            "Need at least 2 points",
35        ));
36    }
37
38    let result = if let Some(x_py) = x {
39        let x_arr = x_py.as_array();
40        if x_arr.len() != y_arr.len() {
41            return Err(pyo3::exceptions::PyValueError::new_err(
42                "x and y must have same length",
43            ));
44        }
45        // Non-uniform spacing
46        let mut total = 0.0;
47        for i in 0..y_arr.len() - 1 {
48            let dx = x_arr[i + 1] - x_arr[i];
49            total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
50        }
51        total
52    } else {
53        // Uniform spacing with provided dx
54        let mut total = 0.0;
55        for i in 0..y_arr.len() - 1 {
56            total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
57        }
58        total
59    };
60
61    Ok(result)
62}
63
64/// Integrate using array data - Simpson's rule
65///
66/// Similar to scipy.integrate.simpson
67#[pyfunction]
68#[pyo3(signature = (y, x=None, dx=1.0))]
69fn simpson_array_py(
70    y: PyReadonlyArray1<f64>,
71    x: Option<PyReadonlyArray1<f64>>,
72    dx: f64,
73) -> PyResult<f64> {
74    let y_arr = y.as_array();
75    let n = y_arr.len();
76
77    if n < 3 {
78        return Err(pyo3::exceptions::PyValueError::new_err(
79            "Need at least 3 points",
80        ));
81    }
82
83    // Use Simpson's rule for even number of intervals, fall back to trapezoid for odd
84    let result = if let Some(x_py) = x {
85        let x_arr = x_py.as_array();
86        if x_arr.len() != y_arr.len() {
87            return Err(pyo3::exceptions::PyValueError::new_err(
88                "x and y must have same length",
89            ));
90        }
91
92        let mut total = 0.0;
93        let mut i = 0;
94        while i + 2 < n {
95            let h = (x_arr[i + 2] - x_arr[i]) / 2.0;
96            total += h / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
97            i += 2;
98        }
99        // Handle remaining interval with trapezoid
100        if i + 1 < n {
101            let h = x_arr[i + 1] - x_arr[i];
102            total += 0.5 * (y_arr[i] + y_arr[i + 1]) * h;
103        }
104        total
105    } else {
106        let mut total = 0.0;
107        let mut i = 0;
108        while i + 2 < n {
109            total += dx / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
110            i += 2;
111        }
112        // Handle remaining interval with trapezoid
113        if i + 1 < n {
114            total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
115        }
116        total
117    };
118
119    Ok(result)
120}
121
122/// Cumulative trapezoidal integration
123///
124/// Similar to scipy.integrate.cumulative_trapezoid
125#[pyfunction]
126#[pyo3(signature = (y, x=None, dx=1.0, initial=None))]
127fn cumulative_trapezoid_py(
128    py: Python,
129    y: PyReadonlyArray1<f64>,
130    x: Option<PyReadonlyArray1<f64>>,
131    dx: f64,
132    initial: Option<f64>,
133) -> PyResult<Py<PyArray1<f64>>> {
134    let y_arr = y.as_array();
135
136    if y_arr.len() < 2 {
137        return Err(pyo3::exceptions::PyValueError::new_err(
138            "Need at least 2 points",
139        ));
140    }
141
142    let n = y_arr.len();
143    let has_initial = initial.is_some();
144    let mut result = Vec::with_capacity(if has_initial { n } else { n - 1 });
145
146    if let Some(init) = initial {
147        result.push(init);
148    }
149
150    let mut cumsum = initial.unwrap_or(0.0);
151
152    if let Some(x_py) = x {
153        let x_arr = x_py.as_array();
154        for i in 0..y_arr.len() - 1 {
155            let dx_i = x_arr[i + 1] - x_arr[i];
156            cumsum += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx_i;
157            result.push(cumsum);
158        }
159    } else {
160        for i in 0..y_arr.len() - 1 {
161            cumsum += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
162            result.push(cumsum);
163        }
164    }
165
166    scirs_to_numpy_array1(Array1::from_vec(result), py)
167}
168
169/// Romberg integration using array data
170#[pyfunction]
171fn romberg_array_py(y: PyReadonlyArray1<f64>, dx: f64) -> PyResult<f64> {
172    let y_arr = y.as_array();
173    let n = y_arr.len();
174
175    if n < 3 {
176        return Err(pyo3::exceptions::PyValueError::new_err(
177            "Need at least 3 points",
178        ));
179    }
180
181    // Simple implementation using available data points
182    // This is essentially Simpson's rule as a good approximation
183    let mut total = 0.0;
184    let mut i = 0;
185    while i + 2 < n {
186        total += dx / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
187        i += 2;
188    }
189    if i + 1 < n {
190        total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
191    }
192
193    Ok(total)
194}
195
196// =============================================================================
197// Adaptive Quadrature
198// =============================================================================
199
200/// Adaptive quadrature integration
201///
202/// Parameters:
203/// - fun: Function to integrate
204/// - a: Lower bound
205/// - b: Upper bound
206/// - epsabs: Absolute error tolerance (default 1.49e-8)
207/// - epsrel: Relative error tolerance (default 1.49e-8)
208/// - maxiter: Maximum function evaluations (default 500)
209///
210/// Returns:
211/// - Dict with 'value' (integral), 'error' (estimated error), 'neval', 'success'
212#[pyfunction]
213#[pyo3(signature = (fun, a, b, epsabs=1.49e-8, epsrel=1.49e-8, maxiter=500))]
214fn quad_py(
215    py: Python,
216    fun: &Bound<'_, PyAny>,
217    a: f64,
218    b: f64,
219    epsabs: f64,
220    epsrel: f64,
221    maxiter: usize,
222) -> PyResult<Py<PyAny>> {
223    let fun_clone = fun.clone().unbind();
224    let f = |x: f64| -> f64 {
225        #[allow(deprecated)]
226        Python::with_gil(|py| {
227            let result = fun_clone
228                .bind(py)
229                .call1((x,))
230                .expect("Failed to call function");
231            result.extract().expect("Failed to extract result")
232        })
233    };
234
235    let options = QuadOptions {
236        abs_tol: epsabs,
237        rel_tol: epsrel,
238        max_evals: maxiter,
239        ..Default::default()
240    };
241
242    let result = quad(f, a, b, Some(options))
243        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
244
245    let dict = PyDict::new(py);
246    dict.set_item("value", result.value)?;
247    dict.set_item("error", result.abs_error)?;
248    dict.set_item("neval", result.n_evals)?;
249    dict.set_item("success", result.converged)?;
250
251    Ok(dict.into())
252}
253
254// =============================================================================
255// ODE Solvers
256// =============================================================================
257
258/// Solve an initial value problem for a system of ODEs
259///
260/// Parameters:
261/// - fun: Function computing dy/dt = f(t, y)
262/// - t_span: Tuple (t0, tf) for integration interval
263/// - y0: Initial state
264/// - method: 'RK45' (default), 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA'
265/// - rtol: Relative tolerance (default 1e-3)
266/// - atol: Absolute tolerance (default 1e-6)
267/// - max_step: Maximum step size (optional)
268///
269/// Returns:
270/// - Dict with 't' (times), 'y' (solutions), 'nfev', 'success', 'message'
271#[pyfunction]
272#[pyo3(signature = (fun, t_span, y0, method="RK45", rtol=1e-3, atol=1e-6, max_step=None))]
273fn solve_ivp_py(
274    py: Python,
275    fun: &Bound<'_, PyAny>,
276    t_span: (f64, f64),
277    y0: Vec<f64>,
278    method: &str,
279    rtol: f64,
280    atol: f64,
281    max_step: Option<f64>,
282) -> PyResult<Py<PyAny>> {
283    let fun_arc = std::sync::Arc::new(fun.clone().unbind());
284    let f = move |t: f64, y: ArrayView1<f64>| -> Array1_17<f64> {
285        let fun_clone = fun_arc.clone();
286        #[allow(deprecated)]
287        Python::with_gil(|py| {
288            let y_vec: Vec<f64> = y.to_vec();
289            let result = fun_clone
290                .bind(py)
291                .call1((t, y_vec))
292                .expect("Failed to call ODE function");
293            let result_vec: Vec<f64> = result.extract().expect("Failed to extract result");
294            Array1_17::from_vec(result_vec)
295        })
296    };
297
298    let ode_method = match method.to_uppercase().as_str() {
299        "EULER" => ODEMethod::Euler,
300        "RK4" => ODEMethod::RK4,
301        "RK23" => ODEMethod::RK23,
302        "RK45" => ODEMethod::RK45,
303        "DOP853" => ODEMethod::DOP853,
304        "BDF" => ODEMethod::Bdf,
305        "RADAU" => ODEMethod::Radau,
306        "LSODA" => ODEMethod::LSODA,
307        _ => ODEMethod::RK45,
308    };
309
310    let options = ODEOptions {
311        method: ode_method,
312        rtol,
313        atol,
314        max_step,
315        ..Default::default()
316    };
317
318    let y0_arr = Array1_17::from_vec(y0);
319    let result = solve_ivp(f, [t_span.0, t_span.1], y0_arr, Some(options))
320        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
321
322    // Convert results to Python
323    let t_vec: Vec<f64> = result.t.to_vec();
324
325    // Convert y (Vec<Array1>) to 2D array
326    let n_points = result.y.len();
327    let n_dim = if n_points > 0 { result.y[0].len() } else { 0 };
328    let mut y_flat = Vec::with_capacity(n_points * n_dim);
329    for arr in &result.y {
330        for &val in arr.iter() {
331            y_flat.push(val);
332        }
333    }
334
335    let dict = PyDict::new(py);
336    dict.set_item("t", scirs_to_numpy_array1(Array1::from_vec(t_vec), py)?)?;
337
338    // Create 2D array for y
339    let y_arr = scirs2_core::python::numpy_compat::Array2::from_shape_vec((n_dim, n_points), {
340        let mut transposed = Vec::with_capacity(n_points * n_dim);
341        for j in 0..n_dim {
342            for i in 0..n_points {
343                transposed.push(y_flat[i * n_dim + j]);
344            }
345        }
346        transposed
347    })
348    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
349    dict.set_item(
350        "y",
351        scirs2_core::python::numpy_compat::scirs_to_numpy_array2(y_arr, py)?,
352    )?;
353
354    dict.set_item("nfev", result.n_eval)?;
355    dict.set_item("success", result.success)?;
356    dict.set_item("message", result.message)?;
357
358    Ok(dict.into())
359}
360
361/// Python module registration
362pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
363    m.add_function(wrap_pyfunction!(trapezoid_array_py, m)?)?;
364    m.add_function(wrap_pyfunction!(simpson_array_py, m)?)?;
365    m.add_function(wrap_pyfunction!(cumulative_trapezoid_py, m)?)?;
366    m.add_function(wrap_pyfunction!(romberg_array_py, m)?)?;
367    m.add_function(wrap_pyfunction!(quad_py, m)?)?;
368    m.add_function(wrap_pyfunction!(solve_ivp_py, m)?)?;
369
370    Ok(())
371}