use crate::python::array::PyArray;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
/// Matrix product of `a` and `b` (exposed to Python as `matmul`).
///
/// Raises `ValueError`, prefixed with "Matrix multiplication failed", when the
/// underlying `matmul` on the wrapped arrays reports an error.
#[pyfunction]
fn matmul(a: &PyArray, b: &PyArray) -> PyResult<PyArray> {
    match a.inner.matmul(&b.inner) {
        Ok(product) => Ok(PyArray { inner: product }),
        Err(err) => Err(PyValueError::new_err(format!(
            "Matrix multiplication failed: {}",
            err
        ))),
    }
}
/// Dot product of `a` and `b`, returned as a single `f64`.
///
/// Raises `ValueError`, prefixed with "Dot product failed", when the underlying
/// `dot` on the wrapped arrays reports an error.
#[pyfunction]
fn dot(a: &PyArray, b: &PyArray) -> PyResult<f64> {
    a.inner
        .dot(&b.inner)
        .map_err(|err| PyValueError::new_err(format!("Dot product failed: {}", err)))
}
/// Matrix-vector product of `a` and `b`.
///
/// There is no dedicated matrix-vector kernel: this forwards to `matmul`, so its
/// behaviour and error messages are exactly those of `matmul`.
#[pyfunction]
fn matvec(a: &PyArray, b: &PyArray) -> PyResult<PyArray> {
    let product = matmul(a, b)?;
    Ok(product)
}
/// Determinant of a square 2-D matrix.
///
/// Raises `ValueError` if `a` is not a 2-D square matrix, if `linalg::det` fails,
/// or if the crate was built without the `lapack` feature.
#[pyfunction]
fn det(a: &PyArray) -> PyResult<f64> {
    let dims = a.inner.shape();
    let is_square = dims.len() == 2 && dims[0] == dims[1];
    if !is_square {
        return Err(PyValueError::new_err("Input must be a square matrix"));
    }
    #[cfg(feature = "lapack")]
    {
        crate::linalg::det(&a.inner).map_err(|err| {
            PyValueError::new_err(format!("Determinant calculation failed: {}", err))
        })
    }
    #[cfg(not(feature = "lapack"))]
    {
        Err(PyValueError::new_err(
            "Determinant calculation requires the 'lapack' feature. Rebuild with --features lapack",
        ))
    }
}
/// Sum of the main diagonal of a 2-D array (rectangular inputs allowed).
///
/// Only the first `min(rows, cols)` diagonal entries exist, so only those are summed.
/// Raises `ValueError` if `a` is not 2-D.
#[pyfunction]
fn trace(a: &PyArray) -> PyResult<f64> {
    let dims = a.inner.shape();
    if dims.len() != 2 {
        return Err(PyValueError::new_err("Input must be a 2D array"));
    }
    let (rows, cols) = (dims[0], dims[1]);
    // NOTE(review): the flat index `k * cols + k` assumes `to_vec` yields row-major order.
    let values = a.inner.to_vec();
    // Seed with +0.0 explicitly so an empty diagonal yields +0.0, not -0.0.
    let total = (0..rows.min(cols))
        .map(|k| values[k * cols + k])
        .fold(0.0, |acc, v| acc + v);
    Ok(total)
}
/// Inverse of a square 2-D matrix.
///
/// Raises `ValueError` if `a` is not a 2-D square matrix, if `linalg::inv` fails,
/// or if the crate was built without the `lapack` feature.
#[pyfunction]
fn inv(a: &PyArray) -> PyResult<PyArray> {
    let dims = a.inner.shape();
    let is_square = dims.len() == 2 && dims[0] == dims[1];
    if !is_square {
        return Err(PyValueError::new_err("Input must be a square matrix"));
    }
    #[cfg(feature = "lapack")]
    {
        match crate::linalg::inv(&a.inner) {
            Ok(inverse) => Ok(PyArray { inner: inverse }),
            Err(err) => Err(PyValueError::new_err(format!(
                "Matrix inversion failed: {}",
                err
            ))),
        }
    }
    #[cfg(not(feature = "lapack"))]
    {
        Err(PyValueError::new_err(
            "Matrix inversion requires the 'lapack' feature. Rebuild with --features lapack",
        ))
    }
}
/// Solves the linear system described by `a` and `b` via `linalg::solve`.
///
/// No shape validation happens here; it is left to `linalg::solve`. Raises `ValueError`
/// if the solve fails or if the crate was built without the `lapack` feature.
#[pyfunction]
fn solve(a: &PyArray, b: &PyArray) -> PyResult<PyArray> {
    #[cfg(feature = "lapack")]
    {
        crate::linalg::solve(&a.inner, &b.inner)
            .map(|solution| PyArray { inner: solution })
            .map_err(|err| PyValueError::new_err(format!("Linear system solve failed: {}", err)))
    }
    #[cfg(not(feature = "lapack"))]
    {
        // The inputs are only consumed by the lapack-backed branch.
        let _ = (a, b);
        Err(PyValueError::new_err(
            "Linear system solving requires the 'lapack' feature. Rebuild with --features lapack",
        ))
    }
}
/// Eigenvalues of a square 2-D matrix, returned as a real 1-D array.
///
/// `linalg_accelerated::eigvals` produces complex eigenvalues. A real `PyArray` cannot hold
/// complex values, so:
/// * if every eigenvalue is real (imaginary part within a small relative tolerance), the
///   real parts are returned, preserving their sign;
/// * otherwise a `ValueError` is raised. Collapsing values to their magnitudes would
///   silently turn e.g. an eigenvalue of -3 into 3.
///
/// Raises `ValueError` if `a` is not a 2-D square matrix, if the matrix has a non-real
/// eigenvalue, if the eigenvalue computation fails, or if the crate was built without the
/// `matrix_decomp` and `lapack` features.
#[pyfunction]
fn eigvals(a: &PyArray) -> PyResult<PyArray> {
    let shape = a.inner.shape();
    if shape.len() != 2 || shape[0] != shape[1] {
        return Err(PyValueError::new_err("Input must be a square matrix"));
    }
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    {
        use crate::linalg_accelerated::eigvals as linalg_eigvals;
        // Imaginary parts smaller than this (relative to the spectrum's scale) are treated
        // as round-off noise from the solver rather than genuinely complex eigenvalues.
        const IMAG_RTOL: f64 = 1e-10;
        let complex_result = linalg_eigvals(&a.inner)
            .map_err(|e| PyValueError::new_err(format!("Eigenvalue calculation failed: {}", e)))?;
        // NOTE(review): elements are assumed to be `num_complex::Complex<f64>` (fields
        // `re`/`im`), consistent with the `.norm()` method available on them.
        let data = complex_result.to_vec();
        // Largest eigenvalue magnitude, floored at 1 so tiny spectra still get an absolute
        // tolerance instead of a vanishing one.
        let scale = data.iter().map(|c| c.norm()).fold(1.0_f64, f64::max);
        if let Some(c) = data.iter().find(|c| c.im.abs() > IMAG_RTOL * scale) {
            return Err(PyValueError::new_err(format!(
                "Matrix has a complex eigenvalue ({} {:+}i); eigvals only supports real eigenvalues",
                c.re, c.im
            )));
        }
        let real_parts: Vec<f64> = data.iter().map(|c| c.re).collect();
        Ok(PyArray {
            inner: crate::array::Array::from_vec(real_parts),
        })
    }
    #[cfg(not(all(feature = "matrix_decomp", feature = "lapack")))]
    {
        Err(PyValueError::new_err("Eigenvalue calculation requires the 'matrix_decomp' and 'lapack' features. Rebuild with --features matrix_decomp,lapack"))
    }
}
/// Eigendecomposition of a square 2-D matrix, returned as `(values, vectors)`.
///
/// Both arrays are passed through unchanged from `linalg::eig`, which is called with
/// `None` as its second argument (its meaning is defined by `linalg::eig`). Raises
/// `ValueError` if `a` is not a 2-D square matrix, if the decomposition fails, or if the
/// crate was built without the `matrix_decomp` and `lapack` features.
#[pyfunction]
fn eig(a: &PyArray) -> PyResult<(PyArray, PyArray)> {
    let dims = a.inner.shape();
    let is_square = dims.len() == 2 && dims[0] == dims[1];
    if !is_square {
        return Err(PyValueError::new_err("Input must be a square matrix"));
    }
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    {
        let decomposition = crate::linalg::eig(&a.inner, None)
            .map_err(|err| PyValueError::new_err(format!("Eigendecomposition failed: {}", err)))?;
        let (eigenvalues, eigenvectors) = decomposition;
        Ok((
            PyArray { inner: eigenvalues },
            PyArray { inner: eigenvectors },
        ))
    }
    #[cfg(not(all(feature = "matrix_decomp", feature = "lapack")))]
    {
        Err(PyValueError::new_err("Eigendecomposition requires the 'matrix_decomp' and 'lapack' features. Rebuild with --features matrix_decomp,lapack"))
    }
}
/// Singular value decomposition of `a`, returned as `(u, s, vt)`.
///
/// NOTE(review): `full_matrices` (NumPy-style, default `True`) is accepted for signature
/// compatibility but is not used. The shapes of `u` and `vt` are whatever `linalg::svd`
/// produces, regardless of this flag.
///
/// Raises `ValueError` if `linalg::svd` fails or if the crate was built without the
/// `matrix_decomp` and `lapack` features.
#[pyfunction]
fn svd(a: &PyArray, full_matrices: Option<bool>) -> PyResult<(PyArray, PyArray, PyArray)> {
    // Deliberately unused; see the NOTE in the doc comment.
    let _ = full_matrices;
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    {
        crate::linalg::svd(&a.inner)
            .map(|(u, s, vt)| {
                (
                    PyArray { inner: u },
                    PyArray { inner: s },
                    PyArray { inner: vt },
                )
            })
            .map_err(|err| PyValueError::new_err(format!("SVD failed: {}", err)))
    }
    #[cfg(not(all(feature = "matrix_decomp", feature = "lapack")))]
    {
        // The input is only consumed by the feature-gated branch.
        let _ = a;
        Err(PyValueError::new_err("SVD requires the 'matrix_decomp' and 'lapack' features. Rebuild with --features matrix_decomp,lapack"))
    }
}
/// QR decomposition of `a`, returned as `(q, r)` straight from `linalg::qr`.
///
/// Raises `ValueError` if the decomposition fails or if the crate was built without the
/// `matrix_decomp` and `lapack` features.
#[pyfunction]
fn qr(a: &PyArray) -> PyResult<(PyArray, PyArray)> {
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    {
        crate::linalg::qr(&a.inner)
            .map(|(q, r)| (PyArray { inner: q }, PyArray { inner: r }))
            .map_err(|err| PyValueError::new_err(format!("QR decomposition failed: {}", err)))
    }
    #[cfg(not(all(feature = "matrix_decomp", feature = "lapack")))]
    {
        // The input is only consumed by the feature-gated branch.
        let _ = a;
        Err(PyValueError::new_err("QR decomposition requires the 'matrix_decomp' and 'lapack' features. Rebuild with --features matrix_decomp,lapack"))
    }
}
/// Cholesky factor of a square 2-D matrix, as returned by `linalg::cholesky`.
///
/// Raises `ValueError` if `a` is not a 2-D square matrix, if the factorization fails, or
/// if the crate was built without the `matrix_decomp` and `lapack` features.
#[pyfunction]
fn cholesky(a: &PyArray) -> PyResult<PyArray> {
    let dims = a.inner.shape();
    let is_square = dims.len() == 2 && dims[0] == dims[1];
    if !is_square {
        return Err(PyValueError::new_err(
            "Input must be a square matrix for Cholesky decomposition",
        ));
    }
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    {
        match crate::linalg::cholesky(&a.inner) {
            Ok(factor) => Ok(PyArray { inner: factor }),
            Err(err) => Err(PyValueError::new_err(format!(
                "Cholesky decomposition failed: {}",
                err
            ))),
        }
    }
    #[cfg(not(all(feature = "matrix_decomp", feature = "lapack")))]
    {
        Err(PyValueError::new_err("Cholesky decomposition requires the 'matrix_decomp' and 'lapack' features. Rebuild with --features matrix_decomp,lapack"))
    }
}
/// LU decomposition of `a`, returned as `(l, u, piv)`.
///
/// The pivot vector from `matrix_decomp::lu` is converted element-wise to `f64` so it can
/// be carried in a float `PyArray`. Raises `ValueError` if the decomposition fails or if
/// the crate was built without the `matrix_decomp` and `lapack` features.
#[pyfunction]
fn lu(a: &PyArray) -> PyResult<(PyArray, PyArray, PyArray)> {
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    {
        let (lower, upper, pivots) = crate::new_modules::matrix_decomp::lu(&a.inner)
            .map_err(|err| PyValueError::new_err(format!("LU decomposition failed: {}", err)))?;
        let pivots_as_f64: Vec<f64> = pivots.to_vec().into_iter().map(|p| p as f64).collect();
        let l_array = PyArray { inner: lower };
        let u_array = PyArray { inner: upper };
        let piv_array = PyArray {
            inner: crate::array::Array::from_vec(pivots_as_f64),
        };
        Ok((l_array, u_array, piv_array))
    }
    #[cfg(not(all(feature = "matrix_decomp", feature = "lapack")))]
    {
        // The input is only consumed by the feature-gated branch.
        let _ = a;
        Err(PyValueError::new_err("LU decomposition requires the 'matrix_decomp' and 'lapack' features. Rebuild with --features matrix_decomp,lapack"))
    }
}
/// Norm of `a`, computed by `linalg::norm` with the `p` selected by `ord`.
///
/// Accepted `ord` values:
/// * `None`, `"fro"`, `"2"` -> `p = 2.0`
/// * `"1"` -> `p = 1.0`
/// * `"inf"` -> `p = f64::INFINITY`
///
/// NOTE(review): mapping `"fro"` to `p = 2.0` assumes `linalg::norm` computes an
/// element-wise 2-norm (i.e. Frobenius) for matrices — confirm against `linalg::norm`.
///
/// Raises `ValueError` for any other `ord` string rather than silently computing a
/// 2-norm the caller did not ask for, and when `linalg::norm` itself fails.
#[pyfunction]
fn norm(a: &PyArray, ord: Option<String>) -> PyResult<f64> {
    use crate::linalg;
    let p = match ord.as_deref() {
        None | Some("fro") | Some("2") => 2.0,
        Some("1") => 1.0,
        Some("inf") => f64::INFINITY,
        Some(other) => {
            return Err(PyValueError::new_err(format!(
                "Unsupported norm order '{}': expected None, 'fro', '1', '2' or 'inf'",
                other
            )));
        }
    };
    linalg::norm(&a.inner, Some(p))
        .map_err(|e| PyValueError::new_err(format!("Norm calculation failed: {}", e)))
}
/// Condition number of `a`, as computed by `matrix_decomp::condition_number`.
///
/// Raises `ValueError` if the computation fails or if the crate was built without the
/// `matrix_decomp` and `lapack` features.
#[pyfunction]
fn cond(a: &PyArray) -> PyResult<f64> {
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    {
        match crate::new_modules::matrix_decomp::condition_number(&a.inner) {
            Ok(kappa) => Ok(kappa),
            Err(err) => Err(PyValueError::new_err(format!(
                "Condition number calculation failed: {}",
                err
            ))),
        }
    }
    #[cfg(not(all(feature = "matrix_decomp", feature = "lapack")))]
    {
        // The input is only consumed by the feature-gated branch.
        let _ = a;
        Err(PyValueError::new_err("Condition number calculation requires the 'matrix_decomp' and 'lapack' features. Rebuild with --features matrix_decomp,lapack"))
    }
}
/// Numerical rank of `a`, as computed by `linalg::matrix_rank`.
///
/// `tol` is forwarded unchanged. When it is `None`, the default tolerance is whatever
/// `linalg::matrix_rank` chooses. Raises `ValueError` if the computation fails or if the
/// crate was built without the `matrix_decomp` and `lapack` features.
#[pyfunction]
fn matrix_rank(a: &PyArray, tol: Option<f64>) -> PyResult<usize> {
    #[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
    {
        match crate::linalg::matrix_rank(&a.inner, tol) {
            Ok(rank) => Ok(rank),
            Err(err) => Err(PyValueError::new_err(format!(
                "Rank calculation failed: {}",
                err
            ))),
        }
    }
    #[cfg(not(all(feature = "matrix_decomp", feature = "lapack")))]
    {
        // The inputs are only consumed by the feature-gated branch.
        let _ = (a, tol);
        Err(PyValueError::new_err("Matrix rank calculation requires the 'matrix_decomp' and 'lapack' features. Rebuild with --features matrix_decomp,lapack"))
    }
}
/// Builds the `linalg` submodule, attaches it to `m`, and additionally exposes
/// `matmul` and `dot` directly on `m`.
///
/// All functions are wrapped against the parent module `m`.
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let py = m.py();
    let linalg_module = PyModule::new(py, "linalg")?;

    let submodule_functions = [
        wrap_pyfunction!(matmul, m)?,
        wrap_pyfunction!(dot, m)?,
        wrap_pyfunction!(matvec, m)?,
        wrap_pyfunction!(det, m)?,
        wrap_pyfunction!(trace, m)?,
        wrap_pyfunction!(inv, m)?,
        wrap_pyfunction!(solve, m)?,
        wrap_pyfunction!(eigvals, m)?,
        wrap_pyfunction!(eig, m)?,
        wrap_pyfunction!(svd, m)?,
        wrap_pyfunction!(qr, m)?,
        wrap_pyfunction!(cholesky, m)?,
        wrap_pyfunction!(lu, m)?,
        wrap_pyfunction!(norm, m)?,
        wrap_pyfunction!(cond, m)?,
        wrap_pyfunction!(matrix_rank, m)?,
    ];
    for function in submodule_functions {
        linalg_module.add_function(function)?;
    }
    m.add_submodule(&linalg_module)?;

    // Convenience top-level aliases for the two most common products.
    for function in [wrap_pyfunction!(matmul, m)?, wrap_pyfunction!(dot, m)?] {
        m.add_function(function)?;
    }
    Ok(())
}