use super::scalar::{LapackScalar, NeedsRwork};
use crate::SVDConfig;
use mdarray::{DSlice, DTensor, Layout};
use mdarray_linalg::svd::SVDError;
use mdarray_linalg::transpose_in_place;
use mdarray_linalg::{get_dims, into_i32};
use num_complex::ComplexFloat;
use std::ptr::null_mut;
/// Computes the singular value decomposition of `a` through the LAPACK
/// `gesdd` (divide-and-conquer) or `gesvd` drivers.
///
/// # Arguments
///
/// * `a` - the `m × n` input matrix. Its contents are not preserved: the
///   driver helpers may transpose it in place before handing it to LAPACK.
/// * `s` - output buffer for the singular values. When `u`/`vt` are given it
///   must be `min(m, n) × min(m, n)`; otherwise it must hold at least
///   `min(m, n)` elements.
/// * `u`, `vt` - either both `Some` (full `m × m` and `n × n` factors are
///   requested, LAPACK job `'A'`) or both `None` (singular values only,
///   job `'N'`).
/// * `config` - driver selection. `Auto` picks `gesdd` when
///   `min(m, n) > 100` and retries with `gesvd` on a copy of `a` if `gesdd`
///   does not converge; `DivideConquer` forces `gesdd`; `Jacobi` selects
///   `gesvd`.
///
/// # Errors
///
/// * [`SVDError::InconsistentUV`] if exactly one of `u` / `vt` is provided.
/// * [`SVDError::BackendDidNotConverge`] if the selected driver (and, under
///   `Auto`, the `gesvd` fallback) reports a positive `info`.
///
/// # Panics
///
/// Panics if the shapes of `s`, `u` or `vt` do not match `a`, or if LAPACK
/// reports an illegal argument (negative `info`).
pub fn gsvd<
    La: Layout,
    Ls: Layout,
    Lu: Layout,
    Lvt: Layout,
    T: ComplexFloat + Default + LapackScalar + NeedsRwork,
>(
    a: &mut DSlice<T, 2, La>,
    s: &mut DSlice<T, 2, Ls>,
    mut u: Option<&mut DSlice<T, 2, Lu>>,
    mut vt: Option<&mut DSlice<T, 2, Lvt>>,
    config: SVDConfig,
) -> Result<(), SVDError>
where
    T::Real: Into<T>,
{
    let (m, n) = get_dims!(a);
    let min_mn = m.min(n);

    let auto = matches!(config, SVDConfig::Auto);
    let use_divide_conquer = match config {
        // Heuristic threshold: only larger problems go to divide-and-conquer.
        SVDConfig::Auto => min_mn > 100,
        SVDConfig::DivideConquer => true,
        SVDConfig::Jacobi => false,
    };

    let job = match (&u, &vt) {
        (Some(x), Some(y)) => {
            let ((mu, nu), (ms, ns), (mvt, nvt)) = get_dims!(x, s, y);
            assert_eq!(mu, nu, "U must be square (m × m)");
            assert_eq!(mvt, nvt, "VT must be square (n × n)");
            assert_eq!(ns, ms, "s must be square (min(m,n),min(m,n))");
            assert_eq!(
                ms, min_mn,
                "s must have min(m, n) rows (number of singular values)"
            );
            assert_eq!(mu, m, "U must have the same number of rows as A: U(m, m)");
            assert_eq!(
                nvt, n,
                "VT must have the same number of columns as A: VT(n, n)"
            );
            'A'
        }
        (None, None) => {
            // LAPACK still writes min(m, n) singular values through the raw
            // pointer of `s`; refuse buffers that cannot hold them instead of
            // letting the backend write out of bounds.
            let (ms, ns) = get_dims!(s);
            assert!(
                i64::from(ms) * i64::from(ns) >= i64::from(min_mn),
                "s must have room for min(m, n) singular values"
            );
            'N'
        }
        _ => return Err(SVDError::InconsistentUV),
    };

    // Null pointers are passed for U / VT when only singular values are wanted.
    let u_ptr: *mut T = u.as_mut().map_or(null_mut(), |x| x.as_mut_ptr());
    let vt_ptr: *mut T = vt.as_mut().map_or(null_mut(), |x| x.as_mut_ptr());
    let s_ptr: *mut T = s.as_mut_ptr();

    // The first driver call does not leave `a` intact, so keep a pristine copy
    // whenever the `Auto` fallback to `gesvd` might be needed.
    let a_backup = if auto && use_divide_conquer {
        let (rows, cols) = (m as usize, n as usize);
        let mut copy = DTensor::<T, 2>::from_elem([rows, cols], T::default());
        for i in 0..rows {
            for j in 0..cols {
                copy[[i, j]] = a[[i, j]];
            }
        }
        Some(copy)
    } else {
        None
    };

    let mut info = if use_divide_conquer {
        call_gesdd(a, m, n, s_ptr, Some(u_ptr), Some(vt_ptr), job)
    } else {
        call_gesvd(a, m, n, s_ptr, Some(u_ptr), Some(vt_ptr), job)
    };
    if info < 0 {
        panic!(
            "Invalid argument to SVD: the {}-th parameter had an illegal value.",
            -info
        );
    }

    // `a_backup` is only populated for `Auto` + divide-and-conquer, which is
    // exactly the situation in which a retry with `gesvd` is attempted.
    if info > 0 {
        if let Some(mut backup) = a_backup {
            info = call_gesvd(
                &mut backup,
                m,
                n,
                s_ptr,
                Some(u_ptr),
                Some(vt_ptr),
                job,
            );
            if info < 0 {
                panic!(
                    "Invalid argument to fallback SVD: the {}-th parameter had an illegal value.",
                    -info
                );
            }
        }
    }

    if info > 0 {
        return Err(SVDError::BackendDidNotConverge {
            superdiagonals: info,
        });
    }

    // Both factors are present exactly when job == 'A'.
    // NOTE(review): the transposition presumably converts LAPACK's
    // column-major output to the layout callers expect — confirm.
    if let (Some(u), Some(vt)) = (u, vt) {
        transpose_in_place(u);
        transpose_in_place(vt);
    }
    Ok(())
}
fn call_gesdd<T: ComplexFloat + Default + LapackScalar + NeedsRwork, La: Layout>(
a: &mut DSlice<T, 2, La>,
m: i32,
n: i32,
s_ptr: *mut T,
u_ptr: Option<*mut T>,
vt_ptr: Option<*mut T>,
job: char,
) -> i32
where
T::Real: Into<T>,
{
let mut work = T::allocate(1);
let lwork = -1i32;
let mut iwork = vec![0i32; 8 * m.min(n) as usize];
let mut info = 0;
let row_major = a.stride(1) == 1;
assert!(
row_major || a.stride(0) == 1,
"a must be contiguous in one dimension"
);
if row_major {
transpose_in_place(a)
};
let mut rwork = vec![0.0; T::rwork_len(m, n)];
unsafe {
T::lapack_gesdd(
job as i8,
m,
n,
a.as_mut_ptr() as *mut _,
m,
s_ptr as *mut _,
u_ptr.unwrap() as *mut _,
m,
vt_ptr.unwrap() as *mut _,
n,
work.as_mut_ptr() as *mut _,
lwork,
rwork.as_mut_ptr() as *mut _,
iwork.as_mut_ptr() as *mut _,
&mut info,
);
}
let lwork = T::lwork_from_query(work.first().expect("Query buffer is empty"));
let mut work = T::allocate(lwork);
let lwork = lwork as usize;
unsafe {
T::lapack_gesdd(
job as i8,
m,
n,
a.as_mut_ptr() as *mut _,
m,
s_ptr as *mut _,
u_ptr.unwrap() as *mut _,
m,
vt_ptr.unwrap() as *mut _,
n,
work.as_mut_ptr() as *mut _,
lwork as i32,
rwork.as_mut_ptr() as *mut _,
iwork.as_mut_ptr() as *mut _,
&mut info,
);
}
info
}
/// Runs the LAPACK SVD driver `gesvd` on `a`.
///
/// `a` must have unit stride in at least one dimension; when its second
/// dimension has unit stride it is transposed in place before the call.
/// `s_ptr`, `u_ptr` and `vt_ptr` are forwarded to LAPACK as the output
/// buffers for the singular values, `U` and `VT`; a missing `u_ptr` /
/// `vt_ptr` is forwarded as a null pointer. `job` is used for both the
/// `JOBU` and `JOBVT` characters.
///
/// The routine is called twice: first as a workspace query (`lwork = -1`),
/// then with a workspace of the reported size.
///
/// Returns LAPACK's `info` value. If the workspace query itself reports a
/// non-zero `info`, that value is returned without running the
/// decomposition.
///
/// # Panics
///
/// Panics if `a` is not contiguous in either dimension.
fn call_gesvd<T: ComplexFloat + Default + LapackScalar + NeedsRwork, La: Layout>(
    a: &mut DSlice<T, 2, La>,
    m: i32,
    n: i32,
    s_ptr: *mut T,
    u_ptr: Option<*mut T>,
    vt_ptr: Option<*mut T>,
    job: char,
) -> i32
where
    T::Real: Into<T>,
{
    let row_major = a.stride(1) == 1;
    assert!(
        row_major || a.stride(0) == 1,
        "a must be contiguous in one dimension"
    );
    if row_major {
        transpose_in_place(a);
    }

    let u_raw: *mut T = u_ptr.unwrap_or(null_mut());
    let vt_raw: *mut T = vt_ptr.unwrap_or(null_mut());

    // LAPACK rejects leading dimensions smaller than 1, even for empty input.
    let lda = m.max(1);
    let ldu = m.max(1);
    let ldvt = n.max(1);

    let mut info = 0;
    let mut rwork = vec![0.0; T::rwork_len(m, n)];

    // Workspace query: with lwork = -1 LAPACK only reports the optimal size
    // in the first workspace slot.
    let mut query = T::allocate(1);
    // SAFETY: `a`, `query` and `rwork` are live buffers owned or borrowed by
    // this function; `s_ptr`, `u_raw` and `vt_raw` are supplied by the caller,
    // who is responsible for them being valid for an `m × n` decomposition
    // with the requested `job`.
    unsafe {
        T::lapack_gesvd(
            job as i8,
            job as i8,
            m,
            n,
            a.as_mut_ptr() as *mut _,
            lda,
            s_ptr as *mut _,
            u_raw as *mut _,
            ldu,
            vt_raw as *mut _,
            ldvt,
            query.as_mut_ptr() as *mut _,
            -1i32,
            rwork.as_mut_ptr() as *mut _,
            &mut info,
        );
    }
    if info != 0 {
        // The query failed, so the reported workspace size is meaningless;
        // surface the error instead of allocating from it.
        return info;
    }

    let lwork = T::lwork_from_query(&query[0]);
    let mut work = T::allocate(lwork);
    // SAFETY: same pointer validity as for the query above; `work` now holds
    // the number of elements LAPACK asked for.
    unsafe {
        T::lapack_gesvd(
            job as i8,
            job as i8,
            m,
            n,
            a.as_mut_ptr() as *mut _,
            lda,
            s_ptr as *mut _,
            u_raw as *mut _,
            ldu,
            vt_raw as *mut _,
            ldvt,
            work.as_mut_ptr() as *mut _,
            lwork,
            rwork.as_mut_ptr() as *mut _,
            &mut info,
        );
    }
    info
}