pub mod gemm;
pub mod qr;
pub mod svd;
pub use gemm::{distributed_gemm_simulate, BlockCyclicMatrix, CommCost};
pub use qr::{caqr_simulate, HouseholderReflector};
pub use svd::{distributed_svd_simulate, thick_restart_lanczos, LanczosSvdConfig};
use crate::error::LinalgResult;
use scirs2_core::ndarray::Array2;
/// Layout parameters shared by the simulated distributed kernels.
///
/// A reference to this value is forwarded unchanged to
/// [`distributed_gemm_simulate`] and [`caqr_simulate`]; the SVD entry point
/// does not currently take a configuration.
///
/// The type derives `PartialEq`, `Eq` and `Hash` so configurations can be
/// compared in tests and used as keys (e.g. for caching per-layout results).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DistribConfig {
    /// Block (tile) size used when distributing a matrix across processes.
    // NOTE(review): presumably the tile edge of the block-cyclic layout used by
    // `gemm::BlockCyclicMatrix`; whether 0 is rejected downstream is not visible
    // here — confirm in the simulate functions.
    pub block_size: usize,
    /// Number of process rows in the simulated process grid.
    pub n_proc_rows: usize,
    /// Number of process columns in the simulated process grid.
    pub n_proc_cols: usize,
}

impl Default for DistribConfig {
    /// Returns a configuration with `block_size = 64` on a 2 × 2 process grid.
    fn default() -> Self {
        Self {
            block_size: 64,
            n_proc_rows: 2,
            n_proc_cols: 2,
        }
    }
}
/// Common entry points for distributed dense linear-algebra backends.
///
/// Every operation is an associated function with no `self` receiver, so a
/// backend is selected by type, e.g.
/// `SimulatedDistributed::distributed_gemm(&a, &b, &cfg)`. Because of that the
/// trait is not dyn-compatible: `dyn DistributedLinearAlgebra` cannot be formed.
pub trait DistributedLinearAlgebra {
    /// Multiplies `a` by `b` using the process layout described by `config`.
    ///
    /// # Errors
    /// Returns whatever error the backend reports (e.g. incompatible shapes).
    fn distributed_gemm(
        a: &Array2<f64>,
        b: &Array2<f64>,
        config: &DistribConfig,
    ) -> LinalgResult<Array2<f64>>;

    /// Factorizes `a` using the process layout described by `config`.
    ///
    /// Returns a pair of matrices; presumably `(Q, R)` — confirm against the
    /// backend implementation.
    ///
    /// # Errors
    /// Returns whatever error the backend reports.
    fn distributed_qr(
        a: &Array2<f64>,
        config: &DistribConfig,
    ) -> LinalgResult<(Array2<f64>, Array2<f64>)>;

    /// Computes a truncated SVD of `a`, with `k` controlling how many
    /// singular values/vectors are produced.
    ///
    /// Unlike the other two operations this one takes no [`DistribConfig`].
    /// The result is a (matrix, values, matrix) triple; presumably
    /// `(U, sigma, V)` or `(U, sigma, Vᵀ)` — confirm the orientation of the
    /// last factor against the backend.
    ///
    /// # Errors
    /// Returns whatever error the backend reports.
    fn distributed_svd(
        a: &Array2<f64>,
        k: usize,
    ) -> LinalgResult<(Array2<f64>, Vec<f64>, Array2<f64>)>;
}
/// Single-process stand-in for a distributed backend.
///
/// Implements [`DistributedLinearAlgebra`] by delegating each operation to
/// the matching `*_simulate` function re-exported from the `gemm`, `qr` and
/// `svd` submodules. The type carries no state; the derives make it usable
/// wherever a cheap, printable marker value is expected.
#[derive(Debug, Clone, Copy, Default)]
pub struct SimulatedDistributed;

impl DistributedLinearAlgebra for SimulatedDistributed {
    /// Delegates to [`distributed_gemm_simulate`], propagating its result.
    fn distributed_gemm(
        a: &Array2<f64>,
        b: &Array2<f64>,
        config: &DistribConfig,
    ) -> LinalgResult<Array2<f64>> {
        distributed_gemm_simulate(a, b, config)
    }

    /// Delegates to [`caqr_simulate`], propagating its result.
    fn distributed_qr(
        a: &Array2<f64>,
        config: &DistribConfig,
    ) -> LinalgResult<(Array2<f64>, Array2<f64>)> {
        caqr_simulate(a, config)
    }

    /// Delegates to [`distributed_svd_simulate`], propagating its result.
    /// No [`DistribConfig`] is involved in this path.
    fn distributed_svd(
        a: &Array2<f64>,
        k: usize,
    ) -> LinalgResult<(Array2<f64>, Vec<f64>, Array2<f64>)> {
        distributed_svd_simulate(a, k)
    }
}