use crate::{Result, Tensor, TensorError};
use scirs2_core::ndarray::{Array1, Array2};
/// Inverts a square 2D `f32` tensor using `scirs2_linalg::inv`.
///
/// The tensor's elements are obtained through `as_slice()` and copied into an
/// owned `Array2` whose shape is the tensor's two dimensions. The inverse
/// returned by scirs2-linalg is converted with `into_dyn()` and wrapped back
/// into a `Tensor`. The trailing `None` argument is forwarded unchanged to
/// `scirs2_linalg::inv` (presumably a worker-count option — confirm against
/// the scirs2-linalg docs).
///
/// NOTE(review): assumes `as_slice()` yields elements in row-major order,
/// which is how `Array2::from_shape_vec` interprets them.
///
/// # Errors
///
/// - Invalid-shape error when `as_slice()` returns `None`; this is checked
///   before any shape validation.
/// - Invalid-shape error when the tensor is not 2D, or is 2D but not square.
/// - Invalid-shape error if the `Array2` cannot be built from the data.
/// - `TensorError::BlasError` with operation `"inv"` when scirs2-linalg
///   reports a failure.
pub fn inverse_f32(input: &Tensor<f32>) -> Result<Tensor<f32>> {
    let values = match input.as_slice() {
        Some(slice) => slice,
        None => {
            return Err(TensorError::invalid_shape_simple(String::from(
                "Matrix inverse requires contiguous tensor data",
            )))
        }
    };

    // Destructure the dimensions directly: anything other than exactly two
    // axes is rejected before the squareness test runs.
    let (rows, cols) = match input.shape().dims()[..] {
        [rows, cols] => (rows, cols),
        _ => {
            return Err(TensorError::invalid_shape_simple(String::from(
                "Matrix inverse requires 2D tensor",
            )))
        }
    };
    if rows != cols {
        return Err(TensorError::invalid_shape_simple(format!(
            "Matrix inverse requires square matrix, got {}x{}",
            rows, cols
        )));
    }

    let matrix = Array2::from_shape_vec((rows, cols), values.to_vec()).map_err(|shape_err| {
        TensorError::invalid_shape_simple(format!(
            "Failed to create Array2 from tensor data: {}",
            shape_err
        ))
    })?;

    scirs2_linalg::inv(&matrix.view(), None)
        .map(|inverse| Tensor::from_array(inverse.into_dyn()))
        .map_err(|err| TensorError::BlasError {
            operation: "inv".to_string(),
            details: format!("scirs2-linalg inverse failed: {}", err),
            context: None,
        })
}
/// Computes the determinant of a square 2D `f32` tensor using `scirs2_linalg::det`.
///
/// The tensor's elements are obtained through `as_slice()` and copied into an
/// owned `order x order` `Array2`, which is passed to scirs2-linalg as a
/// view. The trailing `None` argument is forwarded unchanged (presumably a
/// worker-count option — confirm against the scirs2-linalg docs).
///
/// NOTE(review): assumes `as_slice()` yields elements in row-major order,
/// which is how `Array2::from_shape_vec` interprets them.
///
/// # Errors
///
/// - Invalid-shape error when `as_slice()` returns `None`; this is checked
///   before any shape validation.
/// - Invalid-shape error when the tensor is not 2D, or is 2D but not square.
/// - Invalid-shape error if the `Array2` cannot be built from the data.
/// - `TensorError::BlasError` with operation `"det"` when scirs2-linalg
///   reports a failure.
pub fn determinant_f32(input: &Tensor<f32>) -> Result<f32> {
    let elements = input.as_slice().ok_or_else(|| {
        TensorError::invalid_shape_simple(String::from(
            "Matrix determinant requires contiguous tensor data",
        ))
    })?;

    // Guarded arms: a square 2D shape yields its order, a non-square 2D shape
    // reports both extents, and any other rank is rejected outright.
    let order = match input.shape().dims()[..] {
        [rows, cols] if rows == cols => rows,
        [rows, cols] => {
            return Err(TensorError::invalid_shape_simple(format!(
                "Matrix determinant requires square matrix, got {}x{}",
                rows, cols
            )))
        }
        _ => {
            return Err(TensorError::invalid_shape_simple(String::from(
                "Matrix determinant requires 2D tensor",
            )))
        }
    };

    let square = Array2::from_shape_vec((order, order), elements.to_vec()).map_err(|err| {
        TensorError::invalid_shape_simple(format!(
            "Failed to create Array2 from tensor data: {}",
            err
        ))
    })?;

    match scirs2_linalg::det(&square.view(), None) {
        Ok(value) => Ok(value),
        Err(err) => Err(TensorError::BlasError {
            operation: "det".to_string(),
            details: format!("scirs2-linalg determinant failed: {}", err),
            context: None,
        }),
    }
}
/// Computes the singular value decomposition of a 2D `f32` tensor.
///
/// The tensor may be rectangular; only its rank is validated. Its elements are
/// obtained through `as_slice()`, copied into an owned `Array2`, and handed to
/// `scirs2_linalg::svd` with `true` as the second argument (presumably
/// requesting full `U`/`Vt` matrices — confirm against the scirs2-linalg
/// docs) and `None` as the third.
///
/// Returns the `(u, s, vt)` triple produced by scirs2-linalg, each component
/// converted with `into_dyn()` and wrapped in a `Tensor`.
///
/// NOTE(review): assumes `as_slice()` yields elements in row-major order,
/// which is how `Array2::from_shape_vec` interprets them.
///
/// # Errors
///
/// - Invalid-shape error when `as_slice()` returns `None`; this is checked
///   before the rank check.
/// - Invalid-shape error when the tensor is not 2D.
/// - Invalid-shape error if the `Array2` cannot be built from the data.
/// - `TensorError::BlasError` with operation `"svd"` when scirs2-linalg
///   reports a failure.
pub fn svd_f32(input: &Tensor<f32>) -> Result<(Tensor<f32>, Tensor<f32>, Tensor<f32>)> {
    let flat = input.as_slice().ok_or_else(|| {
        TensorError::invalid_shape_simple(String::from("SVD requires contiguous tensor data"))
    })?;

    let (rows, cols) = match input.shape().dims()[..] {
        [rows, cols] => (rows, cols),
        _ => {
            return Err(TensorError::invalid_shape_simple(String::from(
                "SVD requires 2D tensor",
            )))
        }
    };

    let matrix = Array2::from_shape_vec((rows, cols), flat.to_vec()).map_err(|shape_err| {
        TensorError::invalid_shape_simple(format!(
            "Failed to create Array2 from tensor data: {}",
            shape_err
        ))
    })?;

    scirs2_linalg::svd(&matrix.view(), true, None)
        .map(|(u, singular_values, vt)| {
            (
                Tensor::from_array(u.into_dyn()),
                Tensor::from_array(singular_values.into_dyn()),
                Tensor::from_array(vt.into_dyn()),
            )
        })
        .map_err(|err| TensorError::BlasError {
            operation: "svd".to_string(),
            details: format!("scirs2-linalg SVD failed: {}", err),
            context: None,
        })
}
/// Solves the linear system `A x = b` for `x` using `scirs2_linalg::solve`.
///
/// `a` must be a square 2D tensor and `b` a 1D tensor whose length equals the
/// number of rows of `a`. Both are read through `as_slice()` and copied into
/// an owned `Array2` / `Array1` before calling scirs2-linalg with a trailing
/// `None` argument (presumably a worker-count option — confirm against the
/// scirs2-linalg docs). The solution is converted with `into_dyn()` and
/// wrapped in a `Tensor`.
///
/// NOTE(review): assumes `as_slice()` yields `a` in row-major order, which is
/// how `Array2::from_shape_vec` interprets it.
///
/// # Errors
///
/// The following checks run in this order, each returning an invalid-shape
/// error: `a` not contiguous, `b` not contiguous, `a` not 2D, `b` not 1D,
/// `a` not square, and the row count of `a` differing from the length of `b`.
/// An invalid-shape error is also returned if the ndarray containers cannot
/// be built. `TensorError::BlasError` with operation `"solve"` is returned
/// when scirs2-linalg reports a failure.
pub fn solve_f32(a: &Tensor<f32>, b: &Tensor<f32>) -> Result<Tensor<f32>> {
    let a_values = match a.as_slice() {
        Some(values) => values,
        None => {
            return Err(TensorError::invalid_shape_simple(String::from(
                "Linear system solver requires contiguous tensor data for A",
            )))
        }
    };
    let b_values = match b.as_slice() {
        Some(values) => values,
        None => {
            return Err(TensorError::invalid_shape_simple(String::from(
                "Linear system solver requires contiguous tensor data for b",
            )))
        }
    };

    // Rank checks via slice patterns: A must have exactly two axes, b one.
    let (rows, cols) = match a.shape().dims()[..] {
        [rows, cols] => (rows, cols),
        _ => {
            return Err(TensorError::invalid_shape_simple(String::from(
                "Linear system solver requires 2D matrix A",
            )))
        }
    };
    let rhs_len = match b.shape().dims()[..] {
        [len] => len,
        _ => {
            return Err(TensorError::invalid_shape_simple(String::from(
                "Linear system solver requires 1D vector b",
            )))
        }
    };

    if rows != cols {
        return Err(TensorError::invalid_shape_simple(format!(
            "Linear system solver requires square matrix A, got {}x{}",
            rows, cols
        )));
    }
    if rows != rhs_len {
        return Err(TensorError::invalid_shape_simple(format!(
            "Matrix A and vector b dimensions don't match: A is {}x{}, b is {}",
            rows, cols, rhs_len
        )));
    }

    let coefficients =
        Array2::from_shape_vec((rows, cols), a_values.to_vec()).map_err(|err| {
            TensorError::invalid_shape_simple(format!(
                "Failed to create Array2 from tensor A: {}",
                err
            ))
        })?;
    let rhs = Array1::from_shape_vec(rhs_len, b_values.to_vec()).map_err(|err| {
        TensorError::invalid_shape_simple(format!(
            "Failed to create Array1 from tensor b: {}",
            err
        ))
    })?;

    scirs2_linalg::solve(&coefficients.view(), &rhs.view(), None)
        .map(|solution| Tensor::from_array(solution.into_dyn()))
        .map_err(|err| TensorError::BlasError {
            operation: "solve".to_string(),
            details: format!("scirs2-linalg solve failed: {}", err),
            context: None,
        })
}