use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::prelude::*;
use scirs2_core::random::{Distribution, Normal};
use std::fmt::Debug;
use std::iter::Sum;
use crate::decomposition::svd;
use crate::error::{LinalgError, LinalgResult};
/// Sparse description of which entries of a partially known matrix are observed.
#[derive(Debug, Clone)]
pub struct ObservationMask {
    /// (row, col) indices of the observed entries.
    pub observed: Vec<(usize, usize)>,
    /// Number of rows of the underlying matrix.
    pub nrows: usize,
    /// Number of columns of the underlying matrix.
    pub ncols: usize,
}
impl ObservationMask {
    /// Build a mask from a boolean matrix where `true` marks an observed entry.
    pub fn from_bool_matrix(mask: &ArrayView2<bool>) -> Self {
        let (nrows, ncols) = mask.dim();
        let mut entries = Vec::new();
        for r in 0..nrows {
            for c in 0..ncols {
                if mask[[r, c]] {
                    entries.push((r, c));
                }
            }
        }
        Self {
            observed: entries,
            nrows,
            ncols,
        }
    }

    /// Build a mask directly from a list of observed (row, col) indices.
    pub fn from_indices(observed: Vec<(usize, usize)>, nrows: usize, ncols: usize) -> Self {
        Self {
            observed,
            nrows,
            ncols,
        }
    }

    /// Build a mask from a numeric matrix where NaN marks a missing entry.
    pub fn from_nan_matrix<F: Float>(matrix: &ArrayView2<F>) -> Self {
        let (nrows, ncols) = matrix.dim();
        let mut entries = Vec::new();
        for r in 0..nrows {
            for c in 0..ncols {
                if !matrix[[r, c]].is_nan() {
                    entries.push((r, c));
                }
            }
        }
        Self {
            observed: entries,
            nrows,
            ncols,
        }
    }

    /// Fraction of entries that are observed; 0.0 for an empty matrix.
    pub fn observation_ratio(&self) -> f64 {
        let total = self.nrows * self.ncols;
        match total {
            0 => 0.0,
            _ => self.observed.len() as f64 / total as f64,
        }
    }

    /// Whether the entry at (row, col) is observed.
    ///
    /// Linear scan over the index list; cheap for the small masks used here.
    pub fn is_observed(&self, row: usize, col: usize) -> bool {
        self.observed.iter().any(|&rc| rc == (row, col))
    }
}
/// Outcome of a matrix-completion solver.
#[derive(Debug, Clone)]
pub struct CompletionResult<F> {
    /// The completed (estimated) matrix.
    pub matrix: Array2<F>,
    /// Number of iterations actually performed.
    pub iterations: usize,
    /// Final RMSE over the observed entries.
    pub residual: F,
    /// Whether the relative-change stopping criterion was met before
    /// `max_iter` was exhausted.
    pub converged: bool,
}
/// Tunable parameters shared by the completion solvers.
#[derive(Debug, Clone)]
pub struct CompletionConfig<F> {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Relative-change threshold used to declare convergence.
    pub tolerance: F,
    /// Target rank; required by ALS, ignored by the other solvers.
    pub rank: Option<usize>,
    /// Regularization strength / singular-value shrinkage threshold.
    pub lambda: F,
    /// Optional explicit step size; each solver picks a default when `None`.
    pub step_size: Option<F>,
}
impl<F: Float> CompletionConfig<F> {
    /// Create a configuration with regularization `lambda` and defaults:
    /// 200 iterations, tolerance 1e-6 (falling back to machine epsilon if
    /// the conversion fails), no fixed rank, and a solver-chosen step size.
    pub fn new(lambda: F) -> Self {
        let tol = F::from(1e-6).unwrap_or(F::epsilon());
        Self {
            max_iter: 200,
            tolerance: tol,
            rank: None,
            lambda,
            step_size: None,
        }
    }

    /// Override the maximum number of iterations.
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Override the convergence tolerance.
    pub fn with_tolerance(mut self, tol: F) -> Self {
        self.tolerance = tol;
        self
    }

    /// Fix the target rank (required by ALS).
    pub fn with_rank(mut self, rank: usize) -> Self {
        self.rank = Some(rank);
        self
    }

    /// Fix the step size instead of using the solver default.
    pub fn with_step_size(mut self, step_size: F) -> Self {
        self.step_size = Some(step_size);
        self
    }
}
/// Singular value thresholding (SVT) operator — the proximal operator of the
/// nuclear norm.
///
/// Computes the SVD of `x`, shrinks every singular value by `tau` (clipping
/// negative results to zero), and reassembles the matrix. Returns the
/// all-zero matrix when every singular value falls at or below `tau`.
///
/// # Errors
/// Propagates any failure from the underlying SVD.
pub fn singular_value_threshold<F>(x: &ArrayView2<F>, tau: F) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
    let (u, s, vt) = svd(x, false, None)?;
    let k = s.len();
    // Soft-threshold the spectrum. Singular values are assumed sorted
    // descending (the leading-column slice below relies on it), so the
    // surviving values occupy the first `effective_rank` positions.
    let mut s_thresh = Array1::zeros(k);
    let mut effective_rank = 0;
    for i in 0..k {
        let val = s[i] - tau;
        if val > F::zero() {
            s_thresh[i] = val;
            effective_rank += 1;
        }
    }
    if effective_rank == 0 {
        return Ok(Array2::zeros(x.dim()));
    }
    let r = effective_rank;
    // Reassemble U_r * diag(s_thresh) * Vt_r. Scaling the columns of U_r
    // first lets the reconstruction be a single matrix product instead of a
    // hand-rolled O(m*n*r) triple loop.
    let mut u_scaled = u.slice(s![.., ..r]).to_owned();
    for kk in 0..r {
        let scale = s_thresh[kk];
        u_scaled.column_mut(kk).mapv_inplace(|v| v * scale);
    }
    Ok(u_scaled.dot(&vt.slice(s![..r, ..])))
}
pub fn svt_completion<F>(
observed_values: &ArrayView2<F>,
mask: &ObservationMask,
config: &CompletionConfig<F>,
) -> LinalgResult<CompletionResult<F>>
where
F: Float
+ NumAssign
+ Sum
+ Debug
+ scirs2_core::ndarray::ScalarOperand
+ Send
+ Sync
+ 'static,
{
let (m, n) = observed_values.dim();
if mask.nrows != m || mask.ncols != n {
return Err(LinalgError::DimensionError(
"Mask dimensions do not match matrix dimensions".to_string(),
));
}
let tau = config.lambda;
let delta = config.step_size.unwrap_or_else(|| {
F::from(1.2 * (m * n) as f64 / mask.observed.len().max(1) as f64).unwrap_or(F::one())
});
let mut y = Array2::zeros((m, n));
for &(i, j) in &mask.observed {
y[[i, j]] = delta * observed_values[[i, j]];
}
let mut x = Array2::zeros((m, n));
let mut converged = false;
let mut last_residual = F::infinity();
let mut iterations = 0;
for iter in 0..config.max_iter {
iterations = iter + 1;
x = singular_value_threshold(&y.view(), tau)?;
let mut residual = F::zero();
let mut obs_count = F::zero();
for &(i, j) in &mask.observed {
let diff = observed_values[[i, j]] - x[[i, j]];
residual += diff * diff;
obs_count += F::one();
}
residual = if obs_count > F::zero() {
(residual / obs_count).sqrt()
} else {
F::zero()
};
let rel_change = if last_residual > F::epsilon() {
(last_residual - residual).abs() / last_residual
} else {
F::zero()
};
if rel_change < config.tolerance && iter > 0 {
converged = true;
last_residual = residual;
break;
}
last_residual = residual;
for &(i, j) in &mask.observed {
y[[i, j]] += delta * (observed_values[[i, j]] - x[[i, j]]);
}
}
Ok(CompletionResult {
matrix: x,
iterations,
residual: last_residual,
converged,
})
}
/// Nuclear-norm-regularized completion via proximal gradient descent.
///
/// Each iteration takes a gradient step on the data-fit term (restricted to
/// the observed entries) followed by the SVT proximal operator with
/// threshold `step * lambda`. Convergence is declared when the relative
/// Frobenius change between successive iterates drops below
/// `config.tolerance`.
///
/// # Errors
/// Returns a `DimensionError` when the mask shape disagrees with
/// `observed_values`, and propagates SVD failures.
pub fn nuclear_norm_completion<F>(
    observed_values: &ArrayView2<F>,
    mask: &ObservationMask,
    config: &CompletionConfig<F>,
) -> LinalgResult<CompletionResult<F>>
where
    F: Float
        + NumAssign
        + Sum
        + Debug
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync
        + 'static,
{
    let (m, n) = observed_values.dim();
    if mask.nrows != m || mask.ncols != n {
        return Err(LinalgError::DimensionError(
            "Mask dimensions do not match matrix dimensions".to_string(),
        ));
    }
    let lambda = config.lambda;
    let step = config.step_size.unwrap_or(F::one());
    // Start from the observed entries (zeros elsewhere).
    let mut x = Array2::zeros((m, n));
    for &(i, j) in &mask.observed {
        x[[i, j]] = observed_values[[i, j]];
    }
    let mut converged = false;
    let mut last_residual = F::infinity();
    let mut iterations = 0;
    for iter in 0..config.max_iter {
        iterations = iter + 1;
        // Gradient step on the data-fit term, observed entries only.
        let mut g = x.clone();
        for &(i, j) in &mask.observed {
            g[[i, j]] -= step * (x[[i, j]] - observed_values[[i, j]]);
        }
        // Proximal step: singular value thresholding.
        let x_new = singular_value_threshold(&g.view(), step * lambda)?;
        // Relative Frobenius change between successive iterates.
        let mut change = F::zero();
        let mut norm_x = F::zero();
        for i in 0..m {
            for j in 0..n {
                let diff = x_new[[i, j]] - x[[i, j]];
                change += diff * diff;
                norm_x += x_new[[i, j]] * x_new[[i, j]];
            }
        }
        let rel_change = if norm_x > F::epsilon() {
            change.sqrt() / norm_x.sqrt()
        } else {
            change.sqrt()
        };
        x = x_new;
        // RMSE over the observed entries, reported to the caller.
        let mut residual = F::zero();
        let mut obs_count = F::zero();
        for &(i, j) in &mask.observed {
            let diff = observed_values[[i, j]] - x[[i, j]];
            residual += diff * diff;
            obs_count += F::one();
        }
        residual = if obs_count > F::zero() {
            (residual / obs_count).sqrt()
        } else {
            F::zero()
        };
        // Require at least one full iteration before declaring convergence,
        // for consistency with the other solvers in this module; testing on
        // iteration 0 could stop prematurely when the initialization happens
        // to sit near a fixed point of the proximal step.
        if rel_change < config.tolerance && iter > 0 {
            converged = true;
            last_residual = residual;
            break;
        }
        last_residual = residual;
    }
    Ok(CompletionResult {
        matrix: x,
        iterations,
        residual: last_residual,
        converged,
    })
}
/// Low-rank completion by alternating least squares (ALS).
///
/// Factorizes the matrix as `U * V^T` with the fixed rank from `config.rank`
/// (required), alternately solving ridge-regularized least-squares problems
/// for each row of `U` (holding `V` fixed) and each row of `V` (holding `U`
/// fixed), using only the observed entries. Factors are initialized with
/// small Gaussian noise. Convergence is declared when the observed-entry
/// RMSE stops changing by more than `config.tolerance` (relative).
///
/// # Errors
/// Returns a `DimensionError` on mask/matrix shape mismatch, an
/// `InvalidInput` error when `config.rank` is absent or outside
/// `[1, min(m, n)]`, and a `ComputationError` if the Gaussian sampler
/// cannot be constructed.
pub fn als_completion<F>(
    observed_values: &ArrayView2<F>,
    mask: &ObservationMask,
    config: &CompletionConfig<F>,
) -> LinalgResult<CompletionResult<F>>
where
    F: Float
        + NumAssign
        + Sum
        + Debug
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync
        + 'static,
{
    let (m, n) = observed_values.dim();
    if mask.nrows != m || mask.ncols != n {
        return Err(LinalgError::DimensionError(
            "Mask dimensions do not match matrix dimensions".to_string(),
        ));
    }
    let rank = config
        .rank
        .ok_or_else(|| LinalgError::InvalidInput("ALS requires a target rank".to_string()))?;
    if rank == 0 || rank > m.min(n) {
        return Err(LinalgError::InvalidInput(format!(
            "Rank ({rank}) must be in [1, {}]",
            m.min(n)
        )));
    }
    let lambda = config.lambda;
    // Initialize both factors with small Gaussian noise.
    let mut rng = scirs2_core::random::rng();
    let normal = Normal::new(0.0, 0.1).map_err(|e| {
        LinalgError::ComputationError(format!("Failed to create distribution: {e}"))
    })?;
    let mut u_factor = Array2::zeros((m, rank));
    let mut v_factor = Array2::zeros((n, rank));
    for i in 0..m {
        for j in 0..rank {
            u_factor[[i, j]] = F::from(normal.sample(&mut rng)).unwrap_or(F::zero());
        }
    }
    for i in 0..n {
        for j in 0..rank {
            v_factor[[i, j]] = F::from(normal.sample(&mut rng)).unwrap_or(F::zero());
        }
    }
    // Group observations by row and by column so each least-squares solve
    // can gather its data directly.
    let mut row_obs: Vec<Vec<(usize, F)>> = vec![Vec::new(); m];
    let mut col_obs: Vec<Vec<(usize, F)>> = vec![Vec::new(); n];
    for &(i, j) in &mask.observed {
        row_obs[i].push((j, observed_values[[i, j]]));
        col_obs[j].push((i, observed_values[[i, j]]));
    }
    let mut converged = false;
    let mut last_residual = F::infinity();
    let mut iterations = 0;
    for iter in 0..config.max_iter {
        iterations = iter + 1;
        // Update each row of U by ridge regression against the fixed V.
        for i in 0..m {
            if row_obs[i].is_empty() {
                continue;
            }
            let n_obs = row_obs[i].len();
            let mut v_sub = Array2::zeros((n_obs, rank));
            let mut b_vec = Array1::zeros(n_obs);
            for (idx, &(j, val)) in row_obs[i].iter().enumerate() {
                for kk in 0..rank {
                    v_sub[[idx, kk]] = v_factor[[j, kk]];
                }
                b_vec[idx] = val;
            }
            let vt_v = v_sub.t().dot(&v_sub);
            let vt_b = v_sub.t().dot(&b_vec);
            // Gram matrix with ridge term lambda * I.
            let mut gram = vt_v;
            for kk in 0..rank {
                gram[[kk, kk]] += lambda;
            }
            // A failed solve leaves this row's factor unchanged.
            if let Ok(sol) = solve_small_system(&gram.view(), &vt_b) {
                for kk in 0..rank {
                    u_factor[[i, kk]] = sol[kk];
                }
            }
        }
        // Update each row of V by ridge regression against the fixed U.
        for j in 0..n {
            if col_obs[j].is_empty() {
                continue;
            }
            let n_obs = col_obs[j].len();
            let mut u_sub = Array2::zeros((n_obs, rank));
            let mut b_vec = Array1::zeros(n_obs);
            for (idx, &(i, val)) in col_obs[j].iter().enumerate() {
                for kk in 0..rank {
                    u_sub[[idx, kk]] = u_factor[[i, kk]];
                }
                b_vec[idx] = val;
            }
            let ut_u = u_sub.t().dot(&u_sub);
            let ut_b = u_sub.t().dot(&b_vec);
            let mut gram = ut_u;
            for kk in 0..rank {
                gram[[kk, kk]] += lambda;
            }
            if let Ok(sol) = solve_small_system(&gram.view(), &ut_b) {
                for kk in 0..rank {
                    v_factor[[j, kk]] = sol[kk];
                }
            }
        }
        // RMSE of the current factorization over the observed entries.
        let mut residual = F::zero();
        let mut count = F::zero();
        for &(i, j) in &mask.observed {
            let mut pred = F::zero();
            for kk in 0..rank {
                pred += u_factor[[i, kk]] * v_factor[[j, kk]];
            }
            let diff = observed_values[[i, j]] - pred;
            residual += diff * diff;
            count += F::one();
        }
        residual = if count > F::zero() {
            (residual / count).sqrt()
        } else {
            F::zero()
        };
        let rel_change = if last_residual > F::epsilon() {
            (last_residual - residual).abs() / last_residual
        } else {
            F::zero()
        };
        if rel_change < config.tolerance && iter > 0 {
            converged = true;
            last_residual = residual;
            break;
        }
        last_residual = residual;
    }
    let matrix = u_factor.dot(&v_factor.t());
    Ok(CompletionResult {
        matrix,
        iterations,
        residual: last_residual,
        converged,
    })
}
/// Solve a small dense linear system `a * x = b` by Gaussian elimination
/// with partial pivoting.
///
/// Intended for the tiny (rank x rank) Gram systems produced by ALS. When a
/// pivot column is numerically degenerate, a small diagonal regularization
/// is added so the solve still yields a usable (ridge-like) solution
/// instead of blowing up.
///
/// # Errors
/// Returns a `DimensionError` when `a` is not square or `b` has the wrong
/// length.
fn solve_small_system<F>(a: &ArrayView2<F>, b: &Array1<F>) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
    let n = a.nrows();
    if a.ncols() != n || b.len() != n {
        return Err(LinalgError::DimensionError(
            "System dimensions mismatch".to_string(),
        ));
    }
    // Augmented matrix [A | b].
    let mut aug = Array2::zeros((n, n + 1));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }
    for col in 0..n {
        // Partial pivoting: find the largest-magnitude entry in this column.
        let mut max_val = F::zero();
        let mut max_row = col;
        for row in col..n {
            let abs_val = aug[[row, col]].abs();
            if abs_val > max_val {
                max_val = abs_val;
                max_row = row;
            }
        }
        if max_row != col {
            for j in 0..=n {
                let tmp = aug[[col, j]];
                aug[[col, j]] = aug[[max_row, j]];
                aug[[max_row, j]] = tmp;
            }
        }
        // Regularize AFTER the row swap so the boost lands on the entry that
        // actually serves as the pivot. (Regularizing before swapping — as a
        // previous revision did — moves the regularized value out of the
        // pivot position, defeating the safeguard.)
        if max_val < F::epsilon() * F::from(100.0).unwrap_or(F::one()) {
            aug[[col, col]] += F::epsilon() * F::from(1000.0).unwrap_or(F::one());
        }
        let pivot = aug[[col, col]];
        if pivot.abs() < F::epsilon() {
            continue;
        }
        // Eliminate everything below the pivot.
        for row in (col + 1)..n {
            let factor = aug[[row, col]] / pivot;
            for j in col..=n {
                let val = aug[[col, j]];
                aug[[row, j]] -= factor * val;
            }
        }
    }
    // Back substitution; a (near-)zero diagonal from a skipped pivot yields
    // x[i] = 0 rather than a division blow-up.
    let mut x = Array1::zeros(n);
    for i in (0..n).rev() {
        let mut sum = aug[[i, n]];
        for j in (i + 1)..n {
            sum -= aug[[i, j]] * x[j];
        }
        let diag = aug[[i, i]];
        x[i] = if diag.abs() > F::epsilon() {
            sum / diag
        } else {
            F::zero()
        };
    }
    Ok(x)
}
pub fn soft_impute<F>(
observed_values: &ArrayView2<F>,
mask: &ObservationMask,
config: &CompletionConfig<F>,
) -> LinalgResult<CompletionResult<F>>
where
F: Float
+ NumAssign
+ Sum
+ Debug
+ scirs2_core::ndarray::ScalarOperand
+ Send
+ Sync
+ 'static,
{
let (m, n) = observed_values.dim();
if mask.nrows != m || mask.ncols != n {
return Err(LinalgError::DimensionError(
"Mask dimensions do not match matrix dimensions".to_string(),
));
}
let lambda = config.lambda;
let mut x = Array2::zeros((m, n));
let mut converged = false;
let mut last_residual = F::infinity();
let mut iterations = 0;
for iter in 0..config.max_iter {
iterations = iter + 1;
let mut z = x.clone();
for &(i, j) in &mask.observed {
z[[i, j]] = observed_values[[i, j]];
}
let x_new = singular_value_threshold(&z.view(), lambda)?;
let mut change = F::zero();
let mut norm_x = F::zero();
for i in 0..m {
for j in 0..n {
let diff = x_new[[i, j]] - x[[i, j]];
change += diff * diff;
norm_x += x_new[[i, j]] * x_new[[i, j]];
}
}
let rel_change = if norm_x > F::epsilon() {
change.sqrt() / norm_x.sqrt()
} else {
change.sqrt()
};
x = x_new;
let mut residual = F::zero();
let mut count = F::zero();
for &(i, j) in &mask.observed {
let diff = observed_values[[i, j]] - x[[i, j]];
residual += diff * diff;
count += F::one();
}
residual = if count > F::zero() {
(residual / count).sqrt()
} else {
F::zero()
};
if rel_change < config.tolerance && iter > 0 {
converged = true;
last_residual = residual;
break;
}
last_residual = residual;
}
Ok(CompletionResult {
matrix: x,
iterations,
residual: last_residual,
converged,
})
}
/// Replace NaN entries of `matrix` according to `strategy`.
///
/// Supported strategies: `"zero"` (fill with 0), `"mean"` (global mean of
/// the non-NaN entries), `"row_mean"`, and `"col_mean"`. A row, column, or
/// matrix with no finite entries falls back to zero.
///
/// # Errors
/// Returns an `InvalidInput` error for an unrecognized strategy name.
pub fn fill_missing<F>(matrix: &ArrayView2<F>, strategy: &str) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + Debug + scirs2_core::ndarray::ScalarOperand + 'static,
{
    // Mean of the non-NaN values in `iter`, or zero when none exist.
    fn nan_mean<F: Float>(iter: impl Iterator<Item = F>) -> F {
        let (total, count) = iter
            .filter(|v| !v.is_nan())
            .fold((F::zero(), F::zero()), |(t, c), v| (t + v, c + F::one()));
        if count > F::zero() {
            total / count
        } else {
            F::zero()
        }
    }
    let (m, n) = matrix.dim();
    let mut result = matrix.to_owned();
    match strategy {
        "zero" => {
            result.mapv_inplace(|v| if v.is_nan() { F::zero() } else { v });
        }
        "mean" => {
            let mean = nan_mean(matrix.iter().copied());
            result.mapv_inplace(|v| if v.is_nan() { mean } else { v });
        }
        "row_mean" => {
            for i in 0..m {
                let row_mean = nan_mean(matrix.row(i).iter().copied());
                for j in 0..n {
                    if result[[i, j]].is_nan() {
                        result[[i, j]] = row_mean;
                    }
                }
            }
        }
        "col_mean" => {
            for j in 0..n {
                let col_mean = nan_mean(matrix.column(j).iter().copied());
                for i in 0..m {
                    if result[[i, j]].is_nan() {
                        result[[i, j]] = col_mean;
                    }
                }
            }
        }
        _ => {
            return Err(LinalgError::InvalidInput(format!(
                "Unknown fill strategy: '{strategy}'. Use 'zero', 'mean', 'row_mean', or 'col_mean'"
            )));
        }
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Build a random rank-`rank` matrix plus an observation mask covering
    /// roughly `obs_fraction` of its entries, guaranteeing at least one
    /// observation in every row and column so completion is well-posed.
    fn make_low_rank_observed(
        m: usize,
        n: usize,
        rank: usize,
        obs_fraction: f64,
    ) -> (Array2<f64>, ObservationMask) {
        let mut rng = scirs2_core::random::rng();
        let normal = Normal::new(0.0, 1.0).expect("Failed to create distribution");
        let mut u_gen = Array2::zeros((m, rank));
        let mut v_gen = Array2::zeros((n, rank));
        for i in 0..m {
            for j in 0..rank {
                u_gen[[i, j]] = normal.sample(&mut rng);
            }
        }
        for i in 0..n {
            for j in 0..rank {
                v_gen[[i, j]] = normal.sample(&mut rng);
            }
        }
        let full_matrix = u_gen.dot(&v_gen.t());
        // Sample entries independently with probability `obs_fraction`.
        let mut observed = Vec::new();
        for i in 0..m {
            for j in 0..n {
                let r: f64 = rng.random();
                if r < obs_fraction {
                    observed.push((i, j));
                }
            }
        }
        // Ensure no row is completely unobserved.
        for i in 0..m {
            let has_obs = observed.iter().any(|&(r, _)| r == i);
            if !has_obs {
                let j: usize = rng.random_range(0..n);
                observed.push((i, j));
            }
        }
        // Ensure no column is completely unobserved.
        for j in 0..n {
            let has_obs = observed.iter().any(|&(_, c)| c == j);
            if !has_obs {
                let i: usize = rng.random_range(0..m);
                observed.push((i, j));
            }
        }
        let mask = ObservationMask::from_indices(observed, m, n);
        (full_matrix, mask)
    }

    #[test]
    fn test_observation_mask_from_bool() {
        let mask_arr = array![[true, false, true], [false, true, false]];
        let mask = ObservationMask::from_bool_matrix(&mask_arr.view());
        assert_eq!(mask.nrows, 2);
        assert_eq!(mask.ncols, 3);
        assert_eq!(mask.observed.len(), 3);
        assert!(mask.is_observed(0, 0));
        assert!(!mask.is_observed(0, 1));
    }

    #[test]
    fn test_observation_mask_from_nan() {
        let mat = array![[1.0, f64::NAN, 3.0], [f64::NAN, 5.0, f64::NAN]];
        let mask = ObservationMask::from_nan_matrix(&mat.view());
        assert_eq!(mask.observed.len(), 3);
        assert!(mask.is_observed(0, 0));
        assert!(!mask.is_observed(0, 1));
    }

    #[test]
    fn test_observation_ratio() {
        let mask = ObservationMask::from_indices(vec![(0, 0), (1, 1)], 3, 3);
        let ratio = mask.observation_ratio();
        assert!((ratio - 2.0 / 9.0).abs() < 1e-10);
    }

    #[test]
    fn test_svt_basic() {
        // Singular values 3 and 2; thresholding by 1 leaves 2 and 1, so the
        // squared Frobenius norm should be ~4 + 1 = 5.
        let a = array![[3.0, 0.0], [0.0, 2.0], [0.0, 0.0]];
        let result = singular_value_threshold(&a.view(), 1.0);
        assert!(result.is_ok());
        let thresholded = result.expect("SVT failed");
        let frob_sq: f64 = thresholded.iter().map(|&x| x * x).sum();
        assert!(
            (frob_sq - 5.0).abs() < 0.5,
            "Frobenius norm squared should be ~5, got {frob_sq}"
        );
    }

    #[test]
    fn test_svt_full_threshold() {
        // Threshold exceeds every singular value: result must be zero.
        let a = array![[1.0, 0.0], [0.0, 0.5]];
        let result = singular_value_threshold(&a.view(), 2.0);
        assert!(result.is_ok());
        let thresholded = result.expect("SVT failed");
        for &val in thresholded.iter() {
            assert!(val.abs() < 1e-10, "Should be zero after full threshold");
        }
    }

    #[test]
    fn test_svt_completion_simple() {
        let (full_mat, mask) = make_low_rank_observed(8, 6, 2, 0.8);
        let config = CompletionConfig::new(0.1)
            .with_max_iter(100)
            .with_tolerance(1e-4);
        let result = svt_completion(&full_mat.view(), &mask, &config);
        assert!(result.is_ok());
        let comp = result.expect("SVT completion failed");
        assert_eq!(comp.matrix.nrows(), 8);
        assert_eq!(comp.matrix.ncols(), 6);
        assert!(comp.iterations > 0);
    }

    #[test]
    fn test_nuclear_norm_completion() {
        let (full_mat, mask) = make_low_rank_observed(8, 6, 2, 0.8);
        let config = CompletionConfig::new(0.01)
            .with_max_iter(50)
            .with_tolerance(1e-4);
        let result = nuclear_norm_completion(&full_mat.view(), &mask, &config);
        assert!(result.is_ok());
        let comp = result.expect("Nuclear norm completion failed");
        assert_eq!(comp.matrix.nrows(), 8);
        assert_eq!(comp.matrix.ncols(), 6);
    }

    #[test]
    fn test_als_completion_basic() {
        let (full_mat, mask) = make_low_rank_observed(10, 8, 2, 0.7);
        let config = CompletionConfig::new(0.01)
            .with_max_iter(100)
            .with_tolerance(1e-4)
            .with_rank(2);
        let result = als_completion(&full_mat.view(), &mask, &config);
        assert!(result.is_ok());
        let comp = result.expect("ALS completion failed");
        assert_eq!(comp.matrix.nrows(), 10);
        assert_eq!(comp.matrix.ncols(), 8);
    }

    #[test]
    fn test_als_requires_rank() {
        // ALS must reject a config without an explicit rank.
        let (full_mat, mask) = make_low_rank_observed(5, 5, 2, 0.8);
        let config = CompletionConfig::new(0.01);
        assert!(als_completion(&full_mat.view(), &mask, &config).is_err());
    }

    #[test]
    fn test_als_invalid_rank() {
        let (full_mat, mask) = make_low_rank_observed(5, 5, 2, 0.8);
        let config = CompletionConfig::new(0.01).with_rank(0);
        assert!(als_completion(&full_mat.view(), &mask, &config).is_err());
        let config2 = CompletionConfig::new(0.01).with_rank(100);
        assert!(als_completion(&full_mat.view(), &mask, &config2).is_err());
    }

    #[test]
    fn test_soft_impute_basic() {
        let (full_mat, mask) = make_low_rank_observed(8, 6, 2, 0.8);
        let config = CompletionConfig::new(0.05)
            .with_max_iter(50)
            .with_tolerance(1e-4);
        let result = soft_impute(&full_mat.view(), &mask, &config);
        assert!(result.is_ok());
        let comp = result.expect("Soft-Impute failed");
        assert_eq!(comp.matrix.nrows(), 8);
        assert_eq!(comp.matrix.ncols(), 6);
    }

    #[test]
    #[ignore = "SVD-based soft-impute with 200 iterations exceeds CI time budget"]
    fn test_soft_impute_observed_entries_fit() {
        let (full_mat, mask) = make_low_rank_observed(6, 5, 1, 0.9);
        let config = CompletionConfig::new(0.001)
            .with_max_iter(200)
            .with_tolerance(1e-6);
        let comp = soft_impute(&full_mat.view(), &mask, &config).expect("Soft-Impute failed");
        let mut max_obs_err = 0.0_f64;
        for &(i, j) in &mask.observed {
            let err = (full_mat[[i, j]] - comp.matrix[[i, j]]).abs();
            if err > max_obs_err {
                max_obs_err = err;
            }
        }
        assert!(
            max_obs_err < 5.0,
            "Max observed error too large: {max_obs_err}"
        );
    }

    #[test]
    fn test_dimension_mismatch_errors() {
        // A 3x3 mask against a 2x2 matrix must fail in every solver.
        let mat = array![[1.0, 2.0], [3.0, 4.0]];
        let bad_mask = ObservationMask::from_indices(vec![(0, 0)], 3, 3);
        let config = CompletionConfig::new(0.1);
        assert!(svt_completion(&mat.view(), &bad_mask, &config).is_err());
        assert!(nuclear_norm_completion(&mat.view(), &bad_mask, &config).is_err());
        assert!(soft_impute(&mat.view(), &bad_mask, &config).is_err());
        assert!(als_completion(&mat.view(), &bad_mask, &config.clone().with_rank(1)).is_err());
    }

    #[test]
    fn test_fill_missing_zero() {
        let mat = array![[1.0, f64::NAN], [f64::NAN, 4.0]];
        let filled = fill_missing(&mat.view(), "zero").expect("fill zero failed");
        assert_eq!(filled[[0, 0]], 1.0);
        assert_eq!(filled[[0, 1]], 0.0);
        assert_eq!(filled[[1, 0]], 0.0);
        assert_eq!(filled[[1, 1]], 4.0);
    }

    #[test]
    fn test_fill_missing_mean() {
        let mat = array![[1.0, f64::NAN], [f64::NAN, 3.0]];
        let filled = fill_missing(&mat.view(), "mean").expect("fill mean failed");
        assert!((filled[[0, 1]] - 2.0).abs() < 1e-10);
        assert!((filled[[1, 0]] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_fill_missing_row_mean() {
        let mat = array![[1.0, f64::NAN, 3.0], [4.0, 5.0, f64::NAN]];
        let filled = fill_missing(&mat.view(), "row_mean").expect("fill row_mean failed");
        assert!((filled[[0, 1]] - 2.0).abs() < 1e-10);
        assert!((filled[[1, 2]] - 4.5).abs() < 1e-10);
    }

    #[test]
    fn test_fill_missing_col_mean() {
        let mat = array![[1.0, f64::NAN], [3.0, 4.0], [f64::NAN, 6.0]];
        let filled = fill_missing(&mat.view(), "col_mean").expect("fill col_mean failed");
        assert!((filled[[2, 0]] - 2.0).abs() < 1e-10);
        assert!((filled[[0, 1]] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_fill_missing_invalid_strategy() {
        let mat = array![[1.0, f64::NAN]];
        assert!(fill_missing(&mat.view(), "invalid").is_err());
    }

    #[test]
    fn test_config_builder() {
        let config = CompletionConfig::new(0.5_f64)
            .with_max_iter(500)
            .with_tolerance(1e-8)
            .with_rank(3)
            .with_step_size(0.1);
        assert_eq!(config.max_iter, 500);
        assert!((config.tolerance - 1e-8).abs() < 1e-15);
        assert_eq!(config.rank, Some(3));
        assert!((config.step_size.expect("step") - 0.1).abs() < 1e-15);
    }

    #[test]
    fn test_solve_small_system() {
        // Verify A * x = b for a well-conditioned 2x2 system.
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let b = array![5.0, 4.0];
        let x = solve_small_system(&a.view(), &b).expect("solve failed");
        let ax = a.dot(&x);
        assert!((ax[0] - 5.0).abs() < 1e-6);
        assert!((ax[1] - 4.0).abs() < 1e-6);
    }
}