1use pyo3::prelude::*;
8use pyo3::types::PyDict;
9use scirs2_core::ndarray::{Array1 as Array1_17, ArrayView1};
10use scirs2_core::python::numpy_compat::{scirs_to_numpy_array1, Array1};
11use scirs2_numpy::{PyArray1, PyReadonlyArray1};
12
13use scirs2_integrate::ode::{solve_ivp, ODEMethod, ODEOptions};
14use scirs2_integrate::quad::{quad, QuadOptions};
15
16#[pyfunction]
24#[pyo3(signature = (y, x=None, dx=1.0))]
25fn trapezoid_array_py(
26 y: PyReadonlyArray1<f64>,
27 x: Option<PyReadonlyArray1<f64>>,
28 dx: f64,
29) -> PyResult<f64> {
30 let y_arr = y.as_array();
31
32 if y_arr.len() < 2 {
33 return Err(pyo3::exceptions::PyValueError::new_err(
34 "Need at least 2 points",
35 ));
36 }
37
38 let result = if let Some(x_py) = x {
39 let x_arr = x_py.as_array();
40 if x_arr.len() != y_arr.len() {
41 return Err(pyo3::exceptions::PyValueError::new_err(
42 "x and y must have same length",
43 ));
44 }
45 let mut total = 0.0;
47 for i in 0..y_arr.len() - 1 {
48 let dx = x_arr[i + 1] - x_arr[i];
49 total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
50 }
51 total
52 } else {
53 let mut total = 0.0;
55 for i in 0..y_arr.len() - 1 {
56 total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
57 }
58 total
59 };
60
61 Ok(result)
62}
63
64#[pyfunction]
68#[pyo3(signature = (y, x=None, dx=1.0))]
69fn simpson_array_py(
70 y: PyReadonlyArray1<f64>,
71 x: Option<PyReadonlyArray1<f64>>,
72 dx: f64,
73) -> PyResult<f64> {
74 let y_arr = y.as_array();
75 let n = y_arr.len();
76
77 if n < 3 {
78 return Err(pyo3::exceptions::PyValueError::new_err(
79 "Need at least 3 points",
80 ));
81 }
82
83 let result = if let Some(x_py) = x {
85 let x_arr = x_py.as_array();
86 if x_arr.len() != y_arr.len() {
87 return Err(pyo3::exceptions::PyValueError::new_err(
88 "x and y must have same length",
89 ));
90 }
91
92 let mut total = 0.0;
93 let mut i = 0;
94 while i + 2 < n {
95 let h = (x_arr[i + 2] - x_arr[i]) / 2.0;
96 total += h / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
97 i += 2;
98 }
99 if i + 1 < n {
101 let h = x_arr[i + 1] - x_arr[i];
102 total += 0.5 * (y_arr[i] + y_arr[i + 1]) * h;
103 }
104 total
105 } else {
106 let mut total = 0.0;
107 let mut i = 0;
108 while i + 2 < n {
109 total += dx / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
110 i += 2;
111 }
112 if i + 1 < n {
114 total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
115 }
116 total
117 };
118
119 Ok(result)
120}
121
122#[pyfunction]
126#[pyo3(signature = (y, x=None, dx=1.0, initial=None))]
127fn cumulative_trapezoid_py(
128 py: Python,
129 y: PyReadonlyArray1<f64>,
130 x: Option<PyReadonlyArray1<f64>>,
131 dx: f64,
132 initial: Option<f64>,
133) -> PyResult<Py<PyArray1<f64>>> {
134 let y_arr = y.as_array();
135
136 if y_arr.len() < 2 {
137 return Err(pyo3::exceptions::PyValueError::new_err(
138 "Need at least 2 points",
139 ));
140 }
141
142 let n = y_arr.len();
143 let has_initial = initial.is_some();
144 let mut result = Vec::with_capacity(if has_initial { n } else { n - 1 });
145
146 if let Some(init) = initial {
147 result.push(init);
148 }
149
150 let mut cumsum = initial.unwrap_or(0.0);
151
152 if let Some(x_py) = x {
153 let x_arr = x_py.as_array();
154 for i in 0..y_arr.len() - 1 {
155 let dx_i = x_arr[i + 1] - x_arr[i];
156 cumsum += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx_i;
157 result.push(cumsum);
158 }
159 } else {
160 for i in 0..y_arr.len() - 1 {
161 cumsum += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
162 result.push(cumsum);
163 }
164 }
165
166 scirs_to_numpy_array1(Array1::from_vec(result), py)
167}
168
169#[pyfunction]
171fn romberg_array_py(y: PyReadonlyArray1<f64>, dx: f64) -> PyResult<f64> {
172 let y_arr = y.as_array();
173 let n = y_arr.len();
174
175 if n < 3 {
176 return Err(pyo3::exceptions::PyValueError::new_err(
177 "Need at least 3 points",
178 ));
179 }
180
181 let mut total = 0.0;
184 let mut i = 0;
185 while i + 2 < n {
186 total += dx / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
187 i += 2;
188 }
189 if i + 1 < n {
190 total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
191 }
192
193 Ok(total)
194}
195
196#[pyfunction]
213#[pyo3(signature = (fun, a, b, epsabs=1.49e-8, epsrel=1.49e-8, maxiter=500))]
214fn quad_py(
215 py: Python,
216 fun: &Bound<'_, PyAny>,
217 a: f64,
218 b: f64,
219 epsabs: f64,
220 epsrel: f64,
221 maxiter: usize,
222) -> PyResult<Py<PyAny>> {
223 let fun_clone = fun.clone().unbind();
224 let f = |x: f64| -> f64 {
225 #[allow(deprecated)]
226 Python::with_gil(|py| {
227 let result = fun_clone
228 .bind(py)
229 .call1((x,))
230 .expect("Failed to call function");
231 result.extract().expect("Failed to extract result")
232 })
233 };
234
235 let options = QuadOptions {
236 abs_tol: epsabs,
237 rel_tol: epsrel,
238 max_evals: maxiter,
239 ..Default::default()
240 };
241
242 let result = quad(f, a, b, Some(options))
243 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
244
245 let dict = PyDict::new(py);
246 dict.set_item("value", result.value)?;
247 dict.set_item("error", result.abs_error)?;
248 dict.set_item("neval", result.n_evals)?;
249 dict.set_item("success", result.converged)?;
250
251 Ok(dict.into())
252}
253
254#[pyfunction]
272#[pyo3(signature = (fun, t_span, y0, method="RK45", rtol=1e-3, atol=1e-6, max_step=None))]
273fn solve_ivp_py(
274 py: Python,
275 fun: &Bound<'_, PyAny>,
276 t_span: (f64, f64),
277 y0: Vec<f64>,
278 method: &str,
279 rtol: f64,
280 atol: f64,
281 max_step: Option<f64>,
282) -> PyResult<Py<PyAny>> {
283 let fun_arc = std::sync::Arc::new(fun.clone().unbind());
284 let f = move |t: f64, y: ArrayView1<f64>| -> Array1_17<f64> {
285 let fun_clone = fun_arc.clone();
286 #[allow(deprecated)]
287 Python::with_gil(|py| {
288 let y_vec: Vec<f64> = y.to_vec();
289 let result = fun_clone
290 .bind(py)
291 .call1((t, y_vec))
292 .expect("Failed to call ODE function");
293 let result_vec: Vec<f64> = result.extract().expect("Failed to extract result");
294 Array1_17::from_vec(result_vec)
295 })
296 };
297
298 let ode_method = match method.to_uppercase().as_str() {
299 "EULER" => ODEMethod::Euler,
300 "RK4" => ODEMethod::RK4,
301 "RK23" => ODEMethod::RK23,
302 "RK45" => ODEMethod::RK45,
303 "DOP853" => ODEMethod::DOP853,
304 "BDF" => ODEMethod::Bdf,
305 "RADAU" => ODEMethod::Radau,
306 "LSODA" => ODEMethod::LSODA,
307 _ => ODEMethod::RK45,
308 };
309
310 let options = ODEOptions {
311 method: ode_method,
312 rtol,
313 atol,
314 max_step,
315 ..Default::default()
316 };
317
318 let y0_arr = Array1_17::from_vec(y0);
319 let result = solve_ivp(f, [t_span.0, t_span.1], y0_arr, Some(options))
320 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
321
322 let t_vec: Vec<f64> = result.t.to_vec();
324
325 let n_points = result.y.len();
327 let n_dim = if n_points > 0 { result.y[0].len() } else { 0 };
328 let mut y_flat = Vec::with_capacity(n_points * n_dim);
329 for arr in &result.y {
330 for &val in arr.iter() {
331 y_flat.push(val);
332 }
333 }
334
335 let dict = PyDict::new(py);
336 dict.set_item("t", scirs_to_numpy_array1(Array1::from_vec(t_vec), py)?)?;
337
338 let y_arr = scirs2_core::python::numpy_compat::Array2::from_shape_vec((n_dim, n_points), {
340 let mut transposed = Vec::with_capacity(n_points * n_dim);
341 for j in 0..n_dim {
342 for i in 0..n_points {
343 transposed.push(y_flat[i * n_dim + j]);
344 }
345 }
346 transposed
347 })
348 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
349 dict.set_item(
350 "y",
351 scirs2_core::python::numpy_compat::scirs_to_numpy_array2(y_arr, py)?,
352 )?;
353
354 dict.set_item("nfev", result.n_eval)?;
355 dict.set_item("success", result.success)?;
356 dict.set_item("message", result.message)?;
357
358 Ok(dict.into())
359}
360
361pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
363 m.add_function(wrap_pyfunction!(trapezoid_array_py, m)?)?;
364 m.add_function(wrap_pyfunction!(simpson_array_py, m)?)?;
365 m.add_function(wrap_pyfunction!(cumulative_trapezoid_py, m)?)?;
366 m.add_function(wrap_pyfunction!(romberg_array_py, m)?)?;
367 m.add_function(wrap_pyfunction!(quad_py, m)?)?;
368 m.add_function(wrap_pyfunction!(solve_ivp_py, m)?)?;
369
370 Ok(())
371}