1use crate::error::SciRS2Error;
24use pyo3::prelude::*;
25use pyo3_async_runtimes;
26use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyUntypedArrayMethods};
27
28#[pyfunction]
33pub fn fft_async<'py>(
34 py: Python<'py>,
35 data: &Bound<'_, PyArray1<f64>>,
36) -> PyResult<Bound<'py, PyAny>> {
37 let data_vec: Vec<f64> = {
38 let binding = data.readonly();
39 let arr = binding.as_array();
40 arr.iter().cloned().collect()
41 };
42
43 pyo3_async_runtimes::tokio::future_into_py(py, async move {
44 let (real_part, imag_part): (Vec<f64>, Vec<f64>) = tokio::task::spawn_blocking(move || {
47 use scirs2_core::Complex64;
48 use scirs2_fft::fft;
49
50 let result: Vec<Complex64> = fft(data_vec.as_slice(), None)
51 .map_err(|e| SciRS2Error::ComputationError(format!("FFT failed: {}", e)))?;
52
53 let real: Vec<f64> = result.iter().map(|c| c.re).collect();
54 let imag: Vec<f64> = result.iter().map(|c| c.im).collect();
55 Ok::<(Vec<f64>, Vec<f64>), SciRS2Error>((real, imag))
56 })
57 .await
58 .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
59
60 let py_result: Py<PyAny> = Python::attach(|py| {
62 use pyo3::types::PyDict;
63 use scirs2_core::Array1;
64 let dict = PyDict::new(py);
66 let real_arr: Array1<f64> = Array1::from_vec(real_part);
67 let imag_arr: Array1<f64> = Array1::from_vec(imag_part);
68 dict.set_item("real", real_arr.into_pyarray(py))?;
69 dict.set_item("imag", imag_arr.into_pyarray(py))?;
70 Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
71 })?;
72
73 Ok(py_result)
74 })
75}
76
77#[pyfunction]
82pub fn svd_async<'py>(
83 py: Python<'py>,
84 matrix: &Bound<'_, PyArray2<f64>>,
85 full_matrices: Option<bool>,
86) -> PyResult<Bound<'py, PyAny>> {
87 let matrix_shape = matrix.shape().to_vec();
88 let matrix_vec: Vec<f64> = {
89 let binding = matrix.readonly();
90 let arr = binding.as_array();
91 arr.iter().cloned().collect()
92 };
93 let full_matrices = full_matrices.unwrap_or(true);
94
95 pyo3_async_runtimes::tokio::future_into_py(py, async move {
96 let result = tokio::task::spawn_blocking(move || {
98 use scirs2_core::Array2;
99 use scirs2_linalg::svd_f64_lapack;
100
101 let arr = Array2::from_shape_vec((matrix_shape[0], matrix_shape[1]), matrix_vec)
102 .map_err(|e| SciRS2Error::ArrayError(format!("Array reshape failed: {}", e)))?;
103
104 svd_f64_lapack(&arr.view(), full_matrices)
105 .map_err(|e| SciRS2Error::ComputationError(format!("SVD failed: {}", e)))
106 })
107 .await
108 .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
109
110 let py_result: Py<PyAny> = Python::attach(|py| {
112 use pyo3::types::PyDict;
113 let dict = PyDict::new(py);
114 dict.set_item("U", result.0.into_pyarray(py))?;
115 dict.set_item("S", result.1.into_pyarray(py))?;
116 dict.set_item("Vt", result.2.into_pyarray(py))?;
117 Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
118 })?;
119
120 Ok(py_result)
121 })
122}
123
124#[pyfunction]
126pub fn qr_async<'py>(
127 py: Python<'py>,
128 matrix: &Bound<'_, PyArray2<f64>>,
129) -> PyResult<Bound<'py, PyAny>> {
130 let matrix_shape = matrix.shape().to_vec();
131 let matrix_vec: Vec<f64> = {
132 let binding = matrix.readonly();
133 let arr = binding.as_array();
134 arr.iter().cloned().collect()
135 };
136
137 pyo3_async_runtimes::tokio::future_into_py(py, async move {
138 let result = tokio::task::spawn_blocking(move || {
139 use scirs2_core::Array2;
140 use scirs2_linalg::qr_f64_lapack;
141
142 let arr = Array2::from_shape_vec((matrix_shape[0], matrix_shape[1]), matrix_vec)
143 .map_err(|e| SciRS2Error::ArrayError(format!("Array reshape failed: {}", e)))?;
144
145 qr_f64_lapack(&arr.view())
146 .map_err(|e| SciRS2Error::ComputationError(format!("QR failed: {}", e)))
147 })
148 .await
149 .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
150
151 let py_result: Py<PyAny> = Python::attach(|py| {
152 use pyo3::types::PyDict;
153 let dict = PyDict::new(py);
154 dict.set_item("Q", result.0.into_pyarray(py))?;
155 dict.set_item("R", result.1.into_pyarray(py))?;
156 Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
157 })?;
158
159 Ok(py_result)
160 })
161}
162
163#[pyfunction]
165pub fn quad_async<'py>(
166 py: Python<'py>,
167 func: Py<PyAny>,
168 a: f64,
169 b: f64,
170 epsabs: Option<f64>,
171 epsrel: Option<f64>,
172) -> PyResult<Bound<'py, PyAny>> {
173 pyo3_async_runtimes::tokio::future_into_py(py, async move {
174 let result: (f64, f64) = tokio::task::spawn_blocking(move || {
175 Python::attach(|py| {
176 use scirs2_integrate::quad::{quad, QuadOptions};
177
178 let abs_tol = epsabs.unwrap_or(1e-8);
179 let rel_tol = epsrel.unwrap_or(1e-8);
180
181 let integrand = |x: f64| -> f64 {
183 func.call1(py, (x,))
184 .and_then(|result| result.extract::<f64>(py))
185 .unwrap_or(f64::NAN)
186 };
187
188 let options = QuadOptions {
189 abs_tol,
190 rel_tol,
191 ..Default::default()
192 };
193
194 let result = quad(integrand, a, b, Some(options)).map_err(|e| {
195 PyErr::from(SciRS2Error::ComputationError(format!(
196 "Integration failed: {}",
197 e
198 )))
199 })?;
200
201 Ok::<(f64, f64), PyErr>((result.value, result.abs_error))
202 })
203 })
204 .await
205 .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
206
207 let py_result: Py<PyAny> = Python::attach(|py| {
208 use pyo3::types::PyDict;
209 let dict = PyDict::new(py);
210 dict.set_item("value", result.0)?;
211 dict.set_item("error", result.1)?;
212 Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
213 })?;
214
215 Ok(py_result)
216 })
217}
218
219#[pyfunction]
221pub fn minimize_async<'py>(
222 py: Python<'py>,
223 func: Py<PyAny>,
224 x0: &Bound<'_, PyArray1<f64>>,
225 method: Option<String>,
226 maxiter: Option<usize>,
227) -> PyResult<Bound<'py, PyAny>> {
228 let x0_vec: Vec<f64> = {
229 let binding = x0.readonly();
230 let arr = binding.as_array();
231 arr.iter().cloned().collect()
232 };
233
234 pyo3_async_runtimes::tokio::future_into_py(py, async move {
235 let result: (Vec<f64>, f64, usize) = tokio::task::spawn_blocking(move || {
236 Python::attach(|py| {
237 use scirs2_core::ndarray::ArrayView1;
238 use scirs2_optimize::unconstrained::{minimize, Method};
239
240 let objective = |x: &ArrayView1<f64>| -> f64 {
242 let x_slice = x.as_slice().unwrap_or(&[]);
243 let x_py = match pyo3::types::PyList::new(py, x_slice) {
244 Ok(list) => list,
245 Err(_) => return f64::NAN,
246 };
247 func.call1(py, (x_py,))
248 .and_then(|r| r.extract::<f64>(py))
249 .unwrap_or(f64::NAN)
250 };
251
252 let opt_method = match method.as_deref() {
253 Some("BFGS") => Method::BFGS,
254 Some("Newton") | Some("NewtonCG") => Method::NewtonCG,
255 Some("GradientDescent") | Some("CG") => Method::CG,
256 Some("NelderMead") => Method::NelderMead,
257 Some("LBFGS") => Method::LBFGS,
258 _ => Method::BFGS,
259 };
260
261 use scirs2_optimize::unconstrained::Options;
262 let options = Options {
263 max_iter: maxiter.unwrap_or(1000),
264 ..Default::default()
265 };
266
267 let result =
268 minimize(objective, &x0_vec, opt_method, Some(options)).map_err(|e| {
269 PyErr::from(SciRS2Error::ComputationError(format!(
270 "Optimization failed: {}",
271 e
272 )))
273 })?;
274
275 let x_vec = result.x.to_vec();
276 let fun_val: f64 = result.fun;
277 let nit = result.nit;
278 Ok::<(Vec<f64>, f64, usize), PyErr>((x_vec, fun_val, nit))
279 })
280 })
281 .await
282 .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
283
284 let py_result: Py<PyAny> = Python::attach(|py| {
285 use pyo3::types::PyDict;
286 use scirs2_core::Array1;
287
288 let dict = PyDict::new(py);
289 let x = Array1::from_vec(result.0);
290 dict.set_item("x", x.into_pyarray(py))?;
291 dict.set_item("fun", result.1)?;
292 dict.set_item("nit", result.2)?;
293 Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
294 })?;
295
296 Ok(py_result)
297 })
298}
299
300pub fn register_async_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
302 m.add_function(wrap_pyfunction!(fft_async, m)?)?;
303 m.add_function(wrap_pyfunction!(svd_async, m)?)?;
304 m.add_function(wrap_pyfunction!(qr_async, m)?)?;
305 m.add_function(wrap_pyfunction!(quad_async, m)?)?;
306 m.add_function(wrap_pyfunction!(minimize_async, m)?)?;
307 Ok(())
308}