use crate::error::OptimizeError;
/// Soft-thresholds every component of `x`.
///
/// Each output entry is `sign(x_i) * max(|x_i| - lambda, 0)`, the closed-form
/// proximal operator of `lambda * ||x||_1`.
///
/// # Arguments
/// * `x` - input vector.
/// * `lambda` - amount subtracted from each magnitude before flooring at zero.
///
/// # Returns
/// A freshly allocated vector with the same length as `x`.
pub fn prox_l1(x: &[f64], lambda: f64) -> Vec<f64> {
    let mut shrunk = Vec::with_capacity(x.len());
    for &value in x {
        // Magnitude left over after shrinkage; never negative.
        let magnitude = (value.abs() - lambda).max(0.0);
        shrunk.push(value.signum() * magnitude);
    }
    shrunk
}
/// Uniformly shrinks `x` by the factor `1 / (1 + 2 * lambda)`.
///
/// This is the closed-form proximal operator of the squared Euclidean
/// penalty `lambda * ||x||_2^2`.
///
/// # Arguments
/// * `x` - input vector.
/// * `lambda` - penalty weight.
///
/// # Returns
/// A freshly allocated vector with the same length as `x`.
pub fn prox_l2(x: &[f64], lambda: f64) -> Vec<f64> {
    // Computed once; every component is multiplied by the same factor.
    let shrink = (1.0 + 2.0 * lambda).recip();
    let mut out = Vec::with_capacity(x.len());
    out.extend(x.iter().map(|xi| xi * shrink));
    out
}
/// Clips every component of `x` into the interval `[-lambda, lambda]`.
///
/// Component-wise clipping is the Euclidean projection onto the l-infinity
/// ball of radius `lambda`.
///
/// # Arguments
/// * `x` - input vector.
/// * `lambda` - clipping radius.
///
/// # Returns
/// A freshly allocated vector with the same length as `x`.
///
/// # Panics
/// Panics (inside `f64::clamp`) when `x` is non-empty and `lambda` is
/// negative or NaN.
pub fn prox_linf(x: &[f64], lambda: f64) -> Vec<f64> {
    let (lower, upper) = (-lambda, lambda);
    let mut clipped = Vec::with_capacity(x.len());
    for &value in x {
        clipped.push(value.clamp(lower, upper));
    }
    clipped
}
/// Singular-value soft-thresholding of a row-major matrix.
///
/// The matrix is decomposed with `thin_svd`, every singular value `s` is
/// replaced by `max(s - lambda, 0)`, and the matrix is rebuilt as the sum of
/// the rank-one terms `u * shrunk_s * v^T`. This is the proximal operator of
/// `lambda` times the nuclear norm.
///
/// # Arguments
/// * `matrix` - `rows * cols` entries; row `i` occupies
///   `matrix[i * cols..(i + 1) * cols]`.
/// * `rows`, `cols` - matrix dimensions.
/// * `lambda` - amount subtracted from every singular value.
///
/// # Returns
/// The thresholded matrix in the same row-major layout, or an empty vector
/// when either dimension is zero.
///
/// # Errors
/// Returns `OptimizeError::ValueError` when `rows * cols` overflows `usize`
/// or differs from `matrix.len()`, and forwards any error from `thin_svd`.
pub fn prox_nuclear(
    matrix: &[f64],
    rows: usize,
    cols: usize,
    lambda: f64,
) -> Result<Vec<f64>, OptimizeError> {
    // Checked multiplication: a plain `rows * cols` panics in debug builds
    // and wraps in release builds for absurd dimensions.
    let expected_len = rows.checked_mul(cols).ok_or_else(|| {
        OptimizeError::ValueError(format!(
            "rows*cols overflows usize (rows={}, cols={})",
            rows, cols
        ))
    })?;
    if matrix.len() != expected_len {
        return Err(OptimizeError::ValueError(format!(
            "matrix.len()={} != rows*cols={}",
            matrix.len(),
            expected_len
        )));
    }
    if rows == 0 || cols == 0 {
        return Ok(Vec::new());
    }

    let mut a: Vec<Vec<f64>> = matrix
        .chunks_exact(cols)
        .map(|row| row.to_vec())
        .collect();
    let k = rows.min(cols);
    let (u_vecs, sigma, v_vecs) = thin_svd(&mut a, rows, cols, k)?;

    // `thin_svd` hands back one left vector (length `rows`), one singular
    // value and one right vector (length `cols`) per extracted triplet, and
    // it may stop before `k` triplets when the remaining singular values are
    // numerically zero. Walk the triplets that were actually returned instead
    // of indexing `0..k`, and read each left vector by row index.
    let mut result = vec![0.0; expected_len];
    for ((u, v), &s) in u_vecs.iter().zip(v_vecs.iter()).zip(sigma.iter()) {
        let shrunk = (s - lambda).max(0.0);
        if shrunk == 0.0 {
            // Fully thresholded triplet contributes nothing.
            continue;
        }
        for (out_row, &ui) in result.chunks_exact_mut(cols).zip(u.iter()) {
            let weight = ui * shrunk;
            for (out, &vj) in out_row.iter_mut().zip(v.iter()) {
                *out += weight * vj;
            }
        }
    }
    Ok(result)
}
/// Approximates up to `k` leading singular triplets of `a` using power
/// iteration with rank-one deflation.
///
/// For each triplet a fixed, deterministic start vector is refined with a
/// fixed number of alternating sweeps `u <- W v`, `v <- W^T u` on the
/// deflated working copy `W`; the converged rank-one term is then subtracted
/// from `W` before the next triplet is extracted. There is no convergence
/// test, so accuracy depends on the gap between consecutive singular values.
///
/// # Arguments
/// * `a` - matrix as a list of rows; it is only read (a working copy is
///   deflated instead), the `&mut` is kept for existing callers.
/// * `rows`, `cols` - dimensions of `a`.
/// * `k` - maximum number of triplets to extract.
///
/// # Returns
/// `(u_vecs, sigma_vals, v_vecs)` where entry `r` of each holds the `r`-th
/// left singular vector (length `rows`), singular value, and right singular
/// vector (length `cols`). Fewer than `k` entries are returned when a
/// singular value below `1e-14` is encountered.
///
/// # Errors
/// The current implementation always returns `Ok`.
fn thin_svd(
    a: &mut Vec<Vec<f64>>,
    rows: usize,
    cols: usize,
    k: usize,
) -> Result<(Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>), OptimizeError> {
    // Power-iteration sweeps spent on each triplet.
    const N_ITER: usize = 100;
    // Singular values below this are treated as zero and end the extraction.
    const SIGMA_TOL: f64 = 1e-14;

    // Writes the matrix-vector product `w * v` into `u`, one dot product per row.
    fn mat_vec(w: &[Vec<f64>], v: &[f64], u: &mut [f64]) {
        for (ui, row) in u.iter_mut().zip(w) {
            *ui = row.iter().zip(v).map(|(wij, vj)| wij * vj).sum();
        }
    }

    let mut u_vecs: Vec<Vec<f64>> = Vec::with_capacity(k);
    let mut sigma_vals: Vec<f64> = Vec::with_capacity(k);
    let mut v_vecs: Vec<Vec<f64>> = Vec::with_capacity(k);
    let mut work: Vec<Vec<f64>> = a.clone();

    for _ in 0..k {
        // Deterministic start vector with no zero entries.
        let mut v_vec: Vec<f64> = (0..cols).map(|i| (i as f64 + 1.0).sin()).collect();
        normalise_vec(&mut v_vec);
        let mut u_vec = vec![0.0; rows];

        for _ in 0..N_ITER {
            mat_vec(&work, &v_vec, &mut u_vec);
            // Keep `u` at unit length so that `W^T u` has magnitude ~sigma
            // rather than ~sigma^2. Without this, matrices whose largest
            // singular value is below ~1e-7 fall under the normalisation
            // cutoff, `v` decays to zero and the triplet is lost; very large
            // matrices overflow for the same reason.
            normalise_vec(&mut u_vec);
            for (j, vj) in v_vec.iter_mut().enumerate() {
                *vj = work.iter().zip(&u_vec).map(|(row, ui)| row[j] * ui).sum();
            }
            normalise_vec(&mut v_vec);
        }

        // With `v` fixed, `W v = sigma * u`, so the norm is the singular value.
        mat_vec(&work, &v_vec, &mut u_vec);
        let sigma = norm_vec(&u_vec);
        if sigma < SIGMA_TOL {
            break;
        }
        for ui in &mut u_vec {
            *ui /= sigma;
        }

        // Deflate: remove the rank-one component that was just found.
        for (row, &ui) in work.iter_mut().zip(&u_vec) {
            for (wij, &vj) in row.iter_mut().zip(&v_vec) {
                *wij -= sigma * ui * vj;
            }
        }

        u_vecs.push(u_vec);
        sigma_vals.push(sigma);
        v_vecs.push(v_vec);
    }

    Ok((u_vecs, sigma_vals, v_vecs))
}
/// Rescales `v` in place to unit Euclidean length.
///
/// Vectors whose norm is not greater than `1e-14` (including a NaN norm) are
/// left untouched, so a numerically zero vector is never divided by ~0.
///
/// Takes a slice rather than `&mut Vec<f64>` because no `Vec`-specific
/// operation is needed; `&mut Vec<f64>` arguments still coerce.
fn normalise_vec(v: &mut [f64]) {
    let n = norm_vec(v);
    if n > 1e-14 {
        v.iter_mut().for_each(|vi| *vi /= n);
    }
}
/// Euclidean norm of `v`: the square root of the sum of its squared entries.
fn norm_vec(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|x| x * x).sum();
    sum_of_squares.sqrt()
}
/// Euclidean projection of `x` onto the probability simplex
/// `{ w : w_i >= 0, sum(w) = 1 }`.
///
/// Uses the sort-based method: with the entries sorted in descending order,
/// the last index `i` at which `s_i > (s_0 + ... + s_i - 1) / (i + 1)` fixes
/// the shift `theta`, and the result is `max(x_i - theta, 0)` per component.
///
/// # Arguments
/// * `x` - input vector.
///
/// # Returns
/// A vector of the same length as `x`; empty input yields an empty vector.
/// If `x` contains NaN the call still returns, but the values are
/// unspecified.
pub fn project_simplex(x: &[f64]) -> Vec<f64> {
    if x.is_empty() {
        return Vec::new();
    }

    let mut sorted = x.to_vec();
    // `total_cmp` is a total order even when NaN is present. A comparator
    // built from `partial_cmp(..).unwrap_or(Equal)` is not, and the standard
    // sort is allowed to panic on such comparators.
    sorted.sort_unstable_by(|a, b| b.total_cmp(a));

    // Fallback matches the shift for the first index, used when no index
    // passes the test below (only possible with non-finite input).
    let mut theta = sorted[0] - 1.0;
    let mut cumsum = 0.0;
    for (i, &si) in sorted.iter().enumerate() {
        cumsum += si;
        let candidate = (cumsum - 1.0) / (i as f64 + 1.0);
        if si > candidate {
            // Remember the shift of the last passing index; no second pass
            // over the prefix is needed.
            theta = candidate;
        }
    }

    x.iter().map(|&xi| (xi - theta).max(0.0)).collect()
}
/// Projects `x` onto the axis-aligned box `[lb, ub]` by clamping each
/// component `x[i]` into `[lb[i], ub[i]]`.
///
/// # Arguments
/// * `x` - point to project.
/// * `lb` - per-component lower bounds, same length as `x`.
/// * `ub` - per-component upper bounds, same length as `x`.
///
/// # Returns
/// The clamped vector, with the same length as `x`.
///
/// # Errors
/// Returns `OptimizeError::ValueError` when the three slices differ in
/// length, or when some pair of bounds is not a valid interval
/// (`lb[i] > ub[i]`, or either bound is NaN).
pub fn project_box(x: &[f64], lb: &[f64], ub: &[f64]) -> Result<Vec<f64>, OptimizeError> {
    let n = x.len();
    if lb.len() != n || ub.len() != n {
        return Err(OptimizeError::ValueError(format!(
            "x.len()={}, lb.len()={}, ub.len()={}",
            n,
            lb.len(),
            ub.len()
        )));
    }
    x.iter()
        .zip(lb.iter().zip(ub.iter()))
        .enumerate()
        .map(|(i, (&xi, (&lo, &hi)))| {
            // `f64::clamp` panics on an inverted or NaN interval; report it
            // through the error channel instead.
            if lo.is_nan() || hi.is_nan() || lo > hi {
                return Err(OptimizeError::ValueError(format!(
                    "invalid bounds at index {}: lb={}, ub={}",
                    i, lo, hi
                )));
            }
            Ok(xi.clamp(lo, hi))
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Compares `actual[i]` with every `expected[i]` using absolute tolerance `eps`.
    fn assert_leading_entries(actual: &[f64], expected: &[f64], eps: f64) {
        for (idx, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(actual[idx], want, epsilon = eps);
        }
    }

    /// Checks that `point` sums to one and has no entry below `-1e-12`.
    fn assert_on_simplex(point: &[f64]) {
        let total: f64 = point.iter().sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
        assert!(point.iter().all(|&v| v >= -1e-12));
    }

    /// Soft-thresholding with `lambda = 1`: small entries vanish, large ones shrink by one.
    #[test]
    fn test_prox_l1_soft_threshold() {
        let shrunk = prox_l1(&[-3.0, -0.5, 0.0, 0.5, 3.0], 1.0);
        assert_leading_entries(&shrunk, &[-2.0, 0.0, 0.0, 0.0, 2.0], 1e-12);
    }

    /// With `lambda = 0` the L1 prox is expected to return its input.
    #[test]
    fn test_prox_l1_zero_lambda() {
        let input = vec![1.0, -2.0, 3.0];
        let output = prox_l1(&input, 0.0);
        output
            .iter()
            .zip(&input)
            .for_each(|(got, want)| assert_abs_diff_eq!(*got, *want, epsilon = 1e-12));
    }

    /// `lambda = 0.5` gives a scale factor of one half.
    #[test]
    fn test_prox_l2_ridge() {
        let scaled = prox_l2(&[2.0, -4.0], 0.5);
        assert_leading_entries(&scaled, &[1.0, -2.0], 1e-12);
    }

    /// Entries outside `[-2, 2]` are clipped, entries inside are kept.
    #[test]
    fn test_prox_linf_clipping() {
        let clipped = prox_linf(&[-3.0, 1.0, 4.0], 2.0);
        assert_leading_entries(&clipped, &[-2.0, 1.0, 2.0], 1e-12);
    }

    /// Projecting a point that already lies on the simplex.
    #[test]
    fn test_project_simplex_basic() {
        let projected = project_simplex(&[0.5, 0.3, 0.2]);
        assert_on_simplex(&projected);
    }

    /// A symmetric point far from the simplex should land on the uniform distribution.
    #[test]
    fn test_project_simplex_needs_projection() {
        let projected = project_simplex(&[3.0, 3.0, 3.0]);
        assert_on_simplex(&projected);
        let uniform = 1.0 / 3.0;
        for &entry in &projected {
            assert_abs_diff_eq!(entry, uniform, epsilon = 1e-10);
        }
    }

    /// One entry below, one inside and one above its interval.
    #[test]
    fn test_project_box() {
        let point = [-2.0, 0.5, 3.0];
        let lower = [-1.0, 0.0, 0.0];
        let upper = [1.0, 1.0, 2.0];
        let projected = project_box(&point, &lower, &upper).expect("box projection failed");
        assert_leading_entries(&projected, &[-1.0, 0.5, 2.0], 1e-12);
    }

    /// A lower-bound slice of the wrong length must be rejected.
    #[test]
    fn test_project_box_length_mismatch() {
        let outcome = project_box(&[1.0, 2.0], &[0.0], &[1.0, 2.0]);
        assert!(outcome.is_err());
    }

    /// With `lambda = 0` the nuclear prox is expected to reproduce the matrix.
    #[test]
    fn test_prox_nuclear_identity() {
        let entries = vec![1.0, 2.0, 3.0, 4.0];
        let rebuilt = prox_nuclear(&entries, 2, 2, 0.0).expect("nuclear prox failed");
        rebuilt
            .iter()
            .zip(&entries)
            .for_each(|(got, want)| assert_abs_diff_eq!(*got, *want, epsilon = 1e-6));
    }

    /// Thresholding a diagonal matrix must reduce both diagonal entries.
    #[test]
    fn test_prox_nuclear_shrinks_singular_values() {
        let diagonal = [5.0, 0.0, 0.0, 3.0];
        let shrunk = prox_nuclear(&diagonal, 2, 2, 2.0).expect("nuclear prox failed");
        assert!(shrunk[0] < 5.0, "diagonal element should shrink");
        assert!(shrunk[3] < 3.0, "diagonal element should shrink");
    }

    /// Two entries cannot form a 2x2 matrix.
    #[test]
    fn test_prox_nuclear_bad_size() {
        let outcome = prox_nuclear(&[1.0, 2.0], 2, 2, 1.0);
        assert!(outcome.is_err());
    }
}