use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict};
use rayon::prelude::*;
use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
use scirs2_core::{Array1, Array2};
use scirs2_linalg::compat::pinv;
use scirs2_linalg::{
basic_trace, cond,
eig,
lstsq,
matrix_norm,
matrix_rank,
vector_norm,
};
type SvdResult = (Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>);
#[pyfunction]
fn det_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
let binding = a.readonly();
let data = binding.as_array();
scirs2_linalg::det_f64_lapack(&data)
.map_err(|e| PyRuntimeError::new_err(format!("Determinant failed: {}", e)))
}
#[pyfunction]
fn inv_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
let binding = a.readonly();
let data = binding.as_array();
let result = scirs2_linalg::inv_f64_lapack(&data)
.map_err(|e| PyRuntimeError::new_err(format!("Inverse failed: {}", e)))?;
Ok(result.into_pyarray(py).unbind())
}
#[pyfunction]
fn trace_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
let binding = a.readonly();
let data = binding.as_array();
basic_trace(&data).map_err(|e| PyRuntimeError::new_err(format!("Trace failed: {}", e)))
}
#[pyfunction]
fn lu_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
let binding = a.readonly();
let data = binding.as_array();
let (p, l, u) = scirs2_linalg::lu_f64_lapack(&data)
.map_err(|e| PyRuntimeError::new_err(format!("LU decomposition failed: {}", e)))?;
let dict = PyDict::new(py);
dict.set_item("p", p.into_pyarray(py).unbind())?;
dict.set_item("l", l.into_pyarray(py).unbind())?;
dict.set_item("u", u.into_pyarray(py).unbind())?;
Ok(dict.into())
}
#[pyfunction]
fn qr_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
let binding = a.readonly();
let data = binding.as_array();
let (q, r) = scirs2_linalg::qr_f64_lapack(&data)
.map_err(|e| PyRuntimeError::new_err(format!("QR decomposition failed: {}", e)))?;
let dict = PyDict::new(py);
dict.set_item("q", q.into_pyarray(py).unbind())?;
dict.set_item("r", r.into_pyarray(py).unbind())?;
Ok(dict.into())
}
#[pyfunction]
#[pyo3(signature = (a, full_matrices=false))]
fn svd_py(py: Python, a: &Bound<'_, PyArray2<f64>>, full_matrices: bool) -> PyResult<Py<PyAny>> {
let binding = a.readonly();
let data = binding.as_array();
let (u, s, vt) = scirs2_linalg::svd_f64_lapack(&data, full_matrices)
.map_err(|e| PyRuntimeError::new_err(format!("SVD decomposition failed: {}", e)))?;
let dict = PyDict::new(py);
dict.set_item("u", u.into_pyarray(py).unbind())?;
dict.set_item("s", s.into_pyarray(py).unbind())?;
dict.set_item("vt", vt.into_pyarray(py).unbind())?;
Ok(dict.into())
}
#[pyfunction]
fn cholesky_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
let binding = a.readonly();
let data = binding.as_array();
let result = scirs2_linalg::cholesky_f64_lapack(&data)
.map_err(|e| PyRuntimeError::new_err(format!("Cholesky decomposition failed: {}", e)))?;
Ok(result.into_pyarray(py).unbind())
}
#[pyfunction]
fn eig_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
let binding = a.readonly();
let data = binding.as_array();
let (eigenvalues, eigenvectors) = scirs2_linalg::eig_f64_lapack(&data)
.map_err(|e| PyRuntimeError::new_err(format!("Eigenvalue decomposition failed: {}", e)))?;
let eigenvalues_real: Vec<f64> = eigenvalues.iter().map(|c| c.re).collect();
let eigenvalues_imag: Vec<f64> = eigenvalues.iter().map(|c| c.im).collect();
let (nrows, ncols) = eigenvectors.dim();
let mut eigenvectors_real = Array2::zeros((nrows, ncols));
let mut eigenvectors_imag = Array2::zeros((nrows, ncols));
for ((i, j), val) in eigenvectors.indexed_iter() {
eigenvectors_real[[i, j]] = val.re;
eigenvectors_imag[[i, j]] = val.im;
}
let dict = PyDict::new(py);
dict.set_item(
"eigenvalues_real",
Array1::from_vec(eigenvalues_real).into_pyarray(py).unbind(),
)?;
dict.set_item(
"eigenvalues_imag",
Array1::from_vec(eigenvalues_imag).into_pyarray(py).unbind(),
)?;
dict.set_item(
"eigenvectors_real",
eigenvectors_real.into_pyarray(py).unbind(),
)?;
dict.set_item(
"eigenvectors_imag",
eigenvectors_imag.into_pyarray(py).unbind(),
)?;
Ok(dict.into())
}
#[pyfunction]
fn eigh_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
let binding = a.readonly();
let data = binding.as_array();
let (eigenvalues, eigenvectors) = scirs2_linalg::eigh_f64_lapack(&data).map_err(|e| {
PyRuntimeError::new_err(format!("Symmetric eigenvalue decomposition failed: {}", e))
})?;
let dict = PyDict::new(py);
dict.set_item("eigenvalues", eigenvalues.into_pyarray(py).unbind())?;
dict.set_item("eigenvectors", eigenvectors.into_pyarray(py).unbind())?;
Ok(dict.into())
}
#[pyfunction]
fn eigvals_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
let binding = a.readonly();
let data = binding.as_array();
let (eigenvalues, _eigenvectors) = eig(&data, None)
.map_err(|e| PyRuntimeError::new_err(format!("Eigenvalue computation failed: {}", e)))?;
let real: Vec<f64> = eigenvalues.iter().map(|c| c.re).collect();
let imag: Vec<f64> = eigenvalues.iter().map(|c| c.im).collect();
let dict = PyDict::new(py);
dict.set_item("real", Array1::from_vec(real).into_pyarray(py).unbind())?;
dict.set_item("imag", Array1::from_vec(imag).into_pyarray(py).unbind())?;
Ok(dict.into())
}
#[pyfunction]
fn solve_py(
py: Python,
a: &Bound<'_, PyArray2<f64>>,
b: &Bound<'_, PyArray1<f64>>,
) -> PyResult<Py<PyArray1<f64>>> {
let a_binding = a.readonly();
let a_data = a_binding.as_array();
let b_binding = b.readonly();
let b_data = b_binding.as_array();
let result = scirs2_linalg::solve_f64_lapack(&a_data, &b_data)
.map_err(|e| PyRuntimeError::new_err(format!("Linear solve failed: {}", e)))?;
Ok(result.into_pyarray(py).unbind())
}
#[pyfunction]
fn lstsq_py(
py: Python,
a: &Bound<'_, PyArray2<f64>>,
b: &Bound<'_, PyArray1<f64>>,
) -> PyResult<Py<PyAny>> {
let a_binding = a.readonly();
let a_data = a_binding.as_array();
let b_binding = b.readonly();
let b_data = b_binding.as_array();
let result = lstsq(&a_data, &b_data, None)
.map_err(|e| PyRuntimeError::new_err(format!("Least squares failed: {}", e)))?;
let dict = PyDict::new(py);
dict.set_item("solution", result.x.into_pyarray(py).unbind())?;
dict.set_item("residuals", result.residuals)?;
dict.set_item("rank", result.rank)?;
dict.set_item("singular_values", result.s.into_pyarray(py).unbind())?;
Ok(dict.into())
}
#[pyfunction]
#[pyo3(signature = (a, ord="fro"))]
fn matrix_norm_py(a: &Bound<'_, PyArray2<f64>>, ord: &str) -> PyResult<f64> {
let binding = a.readonly();
let data = binding.as_array();
matrix_norm(&data, ord, None)
.map_err(|e| PyRuntimeError::new_err(format!("Matrix norm failed: {}", e)))
}
#[pyfunction]
#[pyo3(signature = (x, ord=2))]
fn vector_norm_py(x: &Bound<'_, PyArray1<f64>>, ord: usize) -> PyResult<f64> {
let binding = x.readonly();
let data = binding.as_array();
vector_norm(&data, ord)
.map_err(|e| PyRuntimeError::new_err(format!("Vector norm failed: {}", e)))
}
#[pyfunction]
fn cond_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
let binding = a.readonly();
let data = binding.as_array();
cond(&data, None, None)
.map_err(|e| PyRuntimeError::new_err(format!("Condition number failed: {}", e)))
}
#[pyfunction]
#[pyo3(signature = (a, tol=None))]
fn matrix_rank_py(a: &Bound<'_, PyArray2<f64>>, tol: Option<f64>) -> PyResult<usize> {
let binding = a.readonly();
let data = binding.as_array();
matrix_rank(&data, tol, None)
.map_err(|e| PyRuntimeError::new_err(format!("Matrix rank failed: {}", e)))
}
#[pyfunction]
#[pyo3(signature = (a, rcond=None))]
fn pinv_py(
py: Python,
a: &Bound<'_, PyArray2<f64>>,
rcond: Option<f64>,
) -> PyResult<Py<PyArray2<f64>>> {
let binding = a.readonly();
let data = binding.as_array();
let _ = rcond; let result =
pinv(&data).map_err(|e| PyRuntimeError::new_err(format!("Pseudoinverse failed: {}", e)))?;
Ok(result.into_pyarray(py).unbind())
}
fn vec2d_to_array2(rows: &[Vec<f64>]) -> Result<Array2<f64>, String> {
if rows.is_empty() {
return Err("Matrix has no rows".to_string());
}
let nrows = rows.len();
let ncols = rows[0].len();
for (i, row) in rows.iter().enumerate() {
if row.len() != ncols {
return Err(format!(
"Row {} has {} columns, expected {}",
i,
row.len(),
ncols
));
}
}
let flat: Vec<f64> = rows.iter().flat_map(|r| r.iter().cloned()).collect();
Array2::from_shape_vec((nrows, ncols), flat).map_err(|e| format!("Shape error: {}", e))
}
fn array2_to_vec2d(a: &Array2<f64>) -> Vec<Vec<f64>> {
let (nrows, ncols) = a.dim();
(0..nrows)
.map(|i| (0..ncols).map(|j| a[[i, j]]).collect())
.collect()
}
#[pyfunction]
fn batch_matmul_py(
a_list: Vec<Vec<Vec<f64>>>,
b_list: Vec<Vec<Vec<f64>>>,
) -> PyResult<Vec<Vec<Vec<f64>>>> {
if a_list.len() != b_list.len() {
return Err(PyRuntimeError::new_err(format!(
"a_list length {} does not match b_list length {}",
a_list.len(),
b_list.len()
)));
}
if a_list.is_empty() {
return Ok(vec![]);
}
let results: Vec<Result<Vec<Vec<f64>>, String>> = a_list
.par_iter()
.zip(b_list.par_iter())
.map(|(a_rows, b_rows)| {
let a = vec2d_to_array2(a_rows)?;
let b = vec2d_to_array2(b_rows)?;
let (_, a_cols) = a.dim();
let (b_rows_n, _) = b.dim();
if a_cols != b_rows_n {
return Err(format!(
"Incompatible shapes for matmul: inner dims {} != {}",
a_cols, b_rows_n
));
}
let c = a.dot(&b);
Ok(array2_to_vec2d(&c))
})
.collect();
results
.into_iter()
.map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_matmul failed: {}", e))))
.collect()
}
#[pyfunction]
fn batch_svd_py(matrices: Vec<Vec<Vec<f64>>>) -> PyResult<Vec<SvdResult>> {
if matrices.is_empty() {
return Ok(vec![]);
}
let results: Vec<Result<SvdResult, String>> = matrices
.par_iter()
.map(|mat_rows| {
let mat = vec2d_to_array2(mat_rows)?;
let (u, s, vt) = scirs2_linalg::svd_f64_lapack(&mat.view(), false)
.map_err(|e| format!("SVD failed: {}", e))?;
Ok((array2_to_vec2d(&u), s.to_vec(), array2_to_vec2d(&vt)))
})
.collect();
results
.into_iter()
.map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_svd failed: {}", e))))
.collect()
}
#[pyfunction]
fn batch_solve_py(a_list: Vec<Vec<Vec<f64>>>, b_list: Vec<Vec<f64>>) -> PyResult<Vec<Vec<f64>>> {
if a_list.len() != b_list.len() {
return Err(PyRuntimeError::new_err(format!(
"a_list length {} does not match b_list length {}",
a_list.len(),
b_list.len()
)));
}
if a_list.is_empty() {
return Ok(vec![]);
}
let results: Vec<Result<Vec<f64>, String>> = a_list
.par_iter()
.zip(b_list.par_iter())
.map(|(a_rows, b_vec)| {
let a = vec2d_to_array2(a_rows)?;
let b = Array1::from_vec(b_vec.clone());
let x = scirs2_linalg::solve_f64_lapack(&a.view(), &b.view())
.map_err(|e| format!("Solve failed: {}", e))?;
Ok(x.to_vec())
})
.collect();
results
.into_iter()
.map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_solve failed: {}", e))))
.collect()
}
#[pyfunction]
#[pyo3(signature = (matrices, ord=None))]
fn batch_matrix_norm_py(matrices: Vec<Vec<Vec<f64>>>, ord: Option<&str>) -> PyResult<Vec<f64>> {
if matrices.is_empty() {
return Ok(vec![]);
}
let norm_type = ord.unwrap_or("fro");
match norm_type {
"fro" | "nuc" | "1" | "inf" => {}
other => {
return Err(PyRuntimeError::new_err(format!(
"Unknown norm type '{}'. Supported: fro, nuc, 1, inf",
other
)));
}
}
let results: Vec<Result<f64, String>> = matrices
.par_iter()
.map(|mat_rows| {
let mat = vec2d_to_array2(mat_rows)?;
scirs2_linalg::matrix_norm(&mat.view(), norm_type, None)
.map_err(|e| format!("Matrix norm failed: {}", e))
})
.collect();
results
.into_iter()
.map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_matrix_norm failed: {}", e))))
.collect()
}
pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(det_py, m)?)?;
m.add_function(wrap_pyfunction!(inv_py, m)?)?;
m.add_function(wrap_pyfunction!(trace_py, m)?)?;
m.add_function(wrap_pyfunction!(lu_py, m)?)?;
m.add_function(wrap_pyfunction!(qr_py, m)?)?;
m.add_function(wrap_pyfunction!(svd_py, m)?)?;
m.add_function(wrap_pyfunction!(cholesky_py, m)?)?;
m.add_function(wrap_pyfunction!(eig_py, m)?)?;
m.add_function(wrap_pyfunction!(eigh_py, m)?)?;
m.add_function(wrap_pyfunction!(eigvals_py, m)?)?;
m.add_function(wrap_pyfunction!(solve_py, m)?)?;
m.add_function(wrap_pyfunction!(lstsq_py, m)?)?;
m.add_function(wrap_pyfunction!(matrix_norm_py, m)?)?;
m.add_function(wrap_pyfunction!(vector_norm_py, m)?)?;
m.add_function(wrap_pyfunction!(cond_py, m)?)?;
m.add_function(wrap_pyfunction!(matrix_rank_py, m)?)?;
m.add_function(wrap_pyfunction!(pinv_py, m)?)?;
m.add_function(wrap_pyfunction!(batch_matmul_py, m)?)?;
m.add_function(wrap_pyfunction!(batch_svd_py, m)?)?;
m.add_function(wrap_pyfunction!(batch_solve_py, m)?)?;
m.add_function(wrap_pyfunction!(batch_matrix_norm_py, m)?)?;
Ok(())
}