use crate::error::QuantRS2Result;
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::Complex64;
#[derive(Debug, Clone)]
pub struct CsrMatrix<T> {
pub data: Vec<T>,
pub indices: Vec<usize>,
pub indptr: Vec<usize>,
pub shape: (usize, usize),
}
impl<T: Clone + Default> CsrMatrix<T> {
pub fn zeros(shape: (usize, usize)) -> Self {
Self {
data: Vec::new(),
indices: Vec::new(),
indptr: vec![0; shape.0 + 1],
shape,
}
}
pub const fn new(
data: Vec<T>,
indices: Vec<usize>,
indptr: Vec<usize>,
shape: (usize, usize),
) -> Result<Self, String> {
Ok(Self {
data,
indices,
indptr,
shape,
})
}
pub const fn shape(&self) -> (usize, usize) {
self.shape
}
pub fn nnz(&self) -> usize {
self.data.len()
}
}
impl CsrMatrix<Complex64> {
pub fn to_dense(&self) -> Array2<Complex64> {
let (rows, cols) = self.shape;
let mut dense = Array2::zeros((rows, cols));
for row in 0..rows {
let start = self.indptr[row];
let end = self.indptr[row + 1];
for idx in start..end {
let col = self.indices[idx];
let val = self.data[idx];
dense[(row, col)] = val;
}
}
dense
}
}
pub struct SvdResult {
pub u: Array2<f64>,
pub s: Array1<f64>,
pub vt: Array2<f64>,
}
pub fn svd(
matrix: &ArrayView2<f64>,
full_matrices: bool,
compute_uv: Option<bool>,
) -> QuantRS2Result<(Array2<f64>, Array1<f64>, Array2<f64>)> {
let (u, s, vt) = scirs2_linalg::svd(matrix, true, None)
.map_err(|e| crate::error::QuantRS2Error::ComputationError(format!("SVD failed: {e:?}")))?;
Ok((u, s, vt))
}
pub fn svd_simple(matrix: &Array2<f64>) -> QuantRS2Result<SvdResult> {
let (u, s, vt) = svd(&matrix.view(), true, Some(true))?;
Ok(SvdResult { u, s, vt })
}
pub fn randomized_svd(
matrix: &ArrayView2<f64>,
rank: usize,
oversampling: Option<usize>,
n_iter: Option<usize>,
random_state: Option<u64>,
) -> QuantRS2Result<(Array2<f64>, Array1<f64>, Array2<f64>)> {
let (u, s, vt) = svd(matrix, false, Some(true))?;
let k = rank.min(s.len());
Ok((
u.slice(scirs2_core::ndarray::s![.., ..k]).to_owned(),
s.slice(scirs2_core::ndarray::s![..k]).to_owned(),
vt.slice(scirs2_core::ndarray::s![..k, ..]).to_owned(),
))
}
pub fn truncated_svd(
matrix: &ArrayView2<f64>,
rank: usize,
random_state: Option<u64>,
) -> QuantRS2Result<(Array2<f64>, Array1<f64>, Array2<f64>)> {
randomized_svd(matrix, rank, Some(10), Some(2), random_state)
}