use crate::distributed::{DistributedMatrix, DistributedVector, solvers, decomposition};
use crate::error::{LinalgError, LinalgResult};
/// Field-less marker type used as a namespace for distributed linear-algebra
/// helpers.
///
/// Its functions are associated functions (none of them take `self`), so they
/// are called as `DistributedLinalgOps::distributed_matmul(..)` and so on.
pub struct DistributedLinalgOps;
impl DistributedLinalgOps {
    /// Multiplies two distributed matrices after checking that their inner
    /// dimensions agree.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::DimensionError`] when the second component of
    /// `a.global_shape()` differs from the first component of
    /// `b.global_shape()`. Any error produced by `a.multiply(b)` is returned
    /// unchanged.
    pub fn distributed_matmul<T>(
        a: &DistributedMatrix<T>,
        b: &DistributedMatrix<T>,
    ) -> LinalgResult<DistributedMatrix<T>>
    where
        T: scirs2_core::numeric::Float + Send + Sync + 'static,
    {
        let (rows_a, inner_a) = a.global_shape();
        let (inner_b, cols_b) = b.global_shape();

        // Happy path first: compatible inner dimensions go straight to the
        // matrix's own multiplication routine.
        if inner_a == inner_b {
            return a.multiply(b);
        }

        Err(LinalgError::DimensionError(format!(
            "Matrix dimensions don't match for multiplication: ({}, {}) x ({}, {})",
            rows_a, inner_a, inner_b, cols_b
        )))
    }

    /// Adds two distributed matrices element-wise via `a.add(b)` after
    /// checking that both report the same global shape.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::DimensionError`] when the two global shapes
    /// differ. Any error produced by `a.add(b)` is returned unchanged.
    pub fn distributed_add<T>(
        a: &DistributedMatrix<T>,
        b: &DistributedMatrix<T>,
    ) -> LinalgResult<DistributedMatrix<T>>
    where
        T: scirs2_core::numeric::Float + Send + Sync + 'static,
    {
        let lhs_shape = a.global_shape();
        let rhs_shape = b.global_shape();

        if lhs_shape == rhs_shape {
            return a.add(b);
        }

        Err(LinalgError::DimensionError(format!(
            "Matrix dimensions don't match for addition: {:?} vs {:?}",
            lhs_shape, rhs_shape
        )))
    }

    /// Transposes a distributed matrix.
    ///
    /// Thin wrapper: forwards directly to `matrix.transpose()` and returns
    /// its result (including any error) as-is.
    pub fn distributed_transpose<T>(
        matrix: &DistributedMatrix<T>,
    ) -> LinalgResult<DistributedMatrix<T>>
    where
        T: scirs2_core::numeric::Float + Send + Sync + 'static,
    {
        matrix.transpose()
    }

    /// Solves a linear system given a distributed matrix `a` and a
    /// distributed vector `b`.
    ///
    /// Thin wrapper around `solvers::solve_linear_system`. No shape
    /// validation is done here; any checks and all errors come from the
    /// solver itself.
    pub fn distributed_solve<T>(
        a: &DistributedMatrix<T>,
        b: &DistributedVector<T>,
    ) -> LinalgResult<DistributedVector<T>>
    where
        T: scirs2_core::numeric::Float + Send + Sync + 'static,
    {
        solvers::solve_linear_system(a, b)
    }

    /// Computes an LU decomposition of a distributed matrix.
    ///
    /// Thin wrapper around `decomposition::lu_decomposition`; the returned
    /// pair is passed through untouched.
    // NOTE(review): the pair is presumably (L, U) in that order — confirm
    // against `decomposition::lu_decomposition`.
    pub fn distributed_lu<T>(
        matrix: &DistributedMatrix<T>,
    ) -> LinalgResult<(DistributedMatrix<T>, DistributedMatrix<T>)>
    where
        T: scirs2_core::numeric::Float + Send + Sync + 'static,
    {
        decomposition::lu_decomposition(matrix)
    }

    /// Computes a QR decomposition of a distributed matrix.
    ///
    /// Thin wrapper around `decomposition::qr_decomposition`; the returned
    /// pair is passed through untouched.
    // NOTE(review): the pair is presumably (Q, R) in that order — confirm
    // against `decomposition::qr_decomposition`.
    pub fn distributed_qr<T>(
        matrix: &DistributedMatrix<T>,
    ) -> LinalgResult<(DistributedMatrix<T>, DistributedMatrix<T>)>
    where
        T: scirs2_core::numeric::Float + Send + Sync + 'static,
    {
        decomposition::qr_decomposition(matrix)
    }
}