use mdarray::{DSlice, DTensor, Layout};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum SVDError {
#[error("Backend error code: {0}")]
BackendError(i32),
#[error("Inconsistent U and VT: must be both Some or both None")]
InconsistentUV,
#[error("Backend failed to converge: {superdiagonals} superdiagonals did not converge to zero")]
BackendDidNotConverge { superdiagonals: i32 },
}
/// Owned output of a full decomposition, as returned through [`SVDResult`].
///
/// All three factors are rank-2 tensors with element type `T`.
// NOTE(review): the shapes and the exact layout of `s` are decided by the
// implementing backend and are not visible here — confirm before relying on
// them.
pub struct SVDDecomp<T> {
    /// Presumably the singular values, held in a rank-2 tensor (verify how
    /// the backend arranges them).
    pub s: DTensor<T, 2>,

    /// Presumably the left singular vectors `U` (verify against the backend).
    pub u: DTensor<T, 2>,

    /// Presumably the transposed right singular vectors `Vᵀ` (verify against
    /// the backend).
    pub vt: DTensor<T, 2>,
}
/// Shorthand for the result of a full decomposition: an [`SVDDecomp`] on
/// success, an [`SVDError`] on failure.
pub type SVDResult<T> = Result<SVDDecomp<T>, SVDError>;
/// Interface implemented by SVD backends for element type `T`.
///
/// There are two families of methods: the allocating ones (`svd`, `svd_s`)
/// return freshly owned tensors, while the `*_overwrite` ones write into
/// caller-provided rank-2 slices. Every method takes the input `a` by mutable
/// reference.
// NOTE(review): whether implementations preserve or clobber the contents of
// `a` is not specified by these signatures — confirm per backend.
pub trait SVD<T> {
    /// Decomposes `a` and returns the factors as an owned [`SVDDecomp`].
    ///
    /// # Errors
    ///
    /// Returns an [`SVDError`] when the implementation reports a failure.
    fn svd<L>(&self, a: &mut DSlice<T, 2, L>) -> SVDResult<T>
    where
        L: Layout;

    /// Variant of [`SVD::svd`] that returns a single rank-2 tensor —
    /// presumably only the singular values (confirm with the implementation).
    ///
    /// # Errors
    ///
    /// Returns an [`SVDError`] when the implementation reports a failure.
    fn svd_s<L>(&self, a: &mut DSlice<T, 2, L>) -> Result<DTensor<T, 2>, SVDError>
    where
        L: Layout;

    /// Decomposes `a`, taking `s`, `u` and `vt` as caller-provided output
    /// slices instead of returning new tensors.
    ///
    /// Each slice may use its own layout. The required shapes of the outputs
    /// are not expressed in the signature — verify against the implementation.
    ///
    /// # Errors
    ///
    /// Returns an [`SVDError`] when the implementation reports a failure.
    fn svd_overwrite<L, Ls, Lu, Lvt>(
        &self,
        a: &mut DSlice<T, 2, L>,
        s: &mut DSlice<T, 2, Ls>,
        u: &mut DSlice<T, 2, Lu>,
        vt: &mut DSlice<T, 2, Lvt>,
    ) -> Result<(), SVDError>
    where
        L: Layout,
        Ls: Layout,
        Lu: Layout,
        Lvt: Layout;

    /// Like [`SVD::svd_overwrite`], but with `s` as the only output slice.
    ///
    /// # Errors
    ///
    /// Returns an [`SVDError`] when the implementation reports a failure.
    fn svd_overwrite_s<L, Ls>(
        &self,
        a: &mut DSlice<T, 2, L>,
        s: &mut DSlice<T, 2, Ls>,
    ) -> Result<(), SVDError>
    where
        L: Layout,
        Ls: Layout;
}