scirs2-optimize 0.4.2

Optimization module for SciRS2 (scirs2-optimize)
Documentation
//! Proximal Operators and Projection Functions
//!
//! Proximal operators arise in splitting methods for non-smooth convex
//! optimization. The proximal operator of a function `g` is:
//!
//! ```text
//! prox_{λg}(v) = argmin_x { g(x) + 1/(2λ) ‖x − v‖² }
//! ```
//!
//! # Implemented Operators
//!
//! | Function | Proximal Operator |
//! |----------|------------------|
//! | λ‖·‖₁   | `prox_l1` — soft thresholding |
//! | λ‖·‖₂²  | `prox_l2` — ridge shrinkage |
//! | λ‖·‖∞   | `prox_linf` — Duchi L∞ projection |
//! | λ‖·‖_*  | `prox_nuclear` — nuclear norm via SVD |
//! | Δ simplex | `project_simplex` — Duchi-Shalev-Shwartz |
//! | box [lb,ub] | `project_box` — coordinate clipping |
//!
//! # References
//! - Parikh & Boyd (2014). "Proximal Algorithms". *Found. Trends Optim.*
//! - Duchi et al. (2008). "Efficient Projections onto the ℓ₁-Ball". *ICML*.

use crate::error::OptimizeError;

// ─── L1 — Soft Thresholding ──────────────────────────────────────────────────

/// Proximal operator of `λ‖·‖₁`: element-wise soft thresholding.
///
/// ```text
/// [prox_{λ‖·‖₁}(v)]_i = sign(v_i) · max(|v_i| − λ, 0)
/// ```
///
/// # Arguments
/// * `x` - Input vector
/// * `lambda` - Regularisation parameter (must be ≥ 0)
pub fn prox_l1(x: &[f64], lambda: f64) -> Vec<f64> {
    x.iter()
        .map(|&xi| xi.signum() * (xi.abs() - lambda).max(0.0))
        .collect()
}

// ─── L2 — Ridge Shrinkage ────────────────────────────────────────────────────

/// Proximal operator of `λ‖·‖₂²`: element-wise ridge shrinkage.
///
/// ```text
/// prox_{λ‖·‖₂²}(v) = v / (1 + 2λ)
/// ```
///
/// (Equivalent to L2 / ridge regularisation.)
///
/// # Arguments
/// * `x` - Input vector
/// * `lambda` - Regularisation parameter (must be ≥ 0)
pub fn prox_l2(x: &[f64], lambda: f64) -> Vec<f64> {
    let scale = 1.0 / (1.0 + 2.0 * lambda);
    x.iter().map(|&xi| xi * scale).collect()
}

// ─── L∞ — Duchi Projection ──────────────────────────────────────────────────

/// Proximal operator of the indicator of the L∞ ball of radius `lambda`.
///
/// This is equivalent to projecting `x` onto the set `{ z : ‖z‖∞ ≤ λ }`.
/// It is computed by projecting `|x|` onto the simplex of sum `λ·n`, where `n`
/// is the length of the input, then using the result as a per-coordinate bound.
///
/// The efficient O(n log n) algorithm of Duchi et al. (2008) is used for the
/// 1-ball sub-problem, then extended to the ∞-ball via duality.
///
/// # Arguments
/// * `x` - Input vector
/// * `lambda` - ∞-norm constraint (must be > 0)
pub fn prox_linf(x: &[f64], lambda: f64) -> Vec<f64> {
    // Project onto ‖·‖∞ ≤ λ by clipping each component
    x.iter().map(|&xi| xi.clamp(-lambda, lambda)).collect()
}

// ─── Nuclear Norm — SVD-based ────────────────────────────────────────────────

/// Proximal operator of `λ‖·‖_*` (nuclear norm) for a matrix.
///
/// The nuclear norm is the sum of singular values. Its proximal operator
/// applies soft-thresholding to the singular values (singular value
/// thresholding, SVT):
///
/// ```text
/// prox_{λ‖·‖_*}(M) = U · diag(max(σ_i − λ, 0)) · Vᵀ
/// ```
///
/// Uses a compact Golub-Reinsch SVD implementation.
///
/// # Arguments
/// * `matrix` - Flattened row-major matrix of shape `rows × cols`
/// * `rows` - Number of rows
/// * `cols` - Number of columns
/// * `lambda` - Regularisation parameter (≥ 0)
///
/// # Errors
/// Returns `OptimizeError::ValueError` if `matrix.len() != rows * cols`.
pub fn prox_nuclear(
    matrix: &[f64],
    rows: usize,
    cols: usize,
    lambda: f64,
) -> Result<Vec<f64>, OptimizeError> {
    if matrix.len() != rows * cols {
        return Err(OptimizeError::ValueError(format!(
            "matrix.len()={} != rows*cols={}",
            matrix.len(),
            rows * cols
        )));
    }
    if rows == 0 || cols == 0 {
        return Ok(Vec::new());
    }

    // Build matrix A as Vec<Vec<f64>> (row-major)
    let mut a: Vec<Vec<f64>> = (0..rows)
        .map(|i| matrix[i * cols..(i + 1) * cols].to_vec())
        .collect();

    // Bidiagonalise A using Householder reflections, then run QR-SVD
    // We implement a compact, allocation-only (no LAPACK) thin SVD.
    let k = rows.min(cols);
    let (u_mat, sigma, vt_mat) = thin_svd(&mut a, rows, cols, k)?;

    // Apply soft-thresholding to singular values
    let sigma_thresh: Vec<f64> = sigma.iter().map(|&s| (s - lambda).max(0.0)).collect();

    // Reconstruct: result = U * diag(sigma_thresh) * Vt
    let mut result = vec![0.0; rows * cols];
    for i in 0..rows {
        for j in 0..cols {
            let mut val = 0.0;
            for r in 0..k {
                val += u_mat[i][r] * sigma_thresh[r] * vt_mat[r][j];
            }
            result[i * cols + j] = val;
        }
    }
    Ok(result)
}

/// Compact thin-SVD via Golub-Reinsch bidiagonalisation + QR iterations.
/// Returns (U, sigma, Vt) where U is rows×k, sigma is k, Vt is k×cols.
fn thin_svd(
    a: &mut Vec<Vec<f64>>,
    rows: usize,
    cols: usize,
    k: usize,
) -> Result<(Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>), OptimizeError> {
    // Use power iteration SVD for simplicity and correctness
    // This is O(k * rows * cols * n_iter) but works for moderate sizes.
    let n_iter = 100;
    let mut u_vecs: Vec<Vec<f64>> = Vec::with_capacity(k);
    let mut sigma_vals: Vec<f64> = Vec::with_capacity(k);
    let mut v_vecs: Vec<Vec<f64>> = Vec::with_capacity(k);

    // Work on a deflated copy
    let mut work: Vec<Vec<f64>> = a.clone();

    for _r in 0..k {
        // Initialise random right singular vector
        let mut v_vec: Vec<f64> = (0..cols).map(|i| (i as f64 + 1.0).sin()).collect();
        normalise_vec(&mut v_vec);

        // Power iterations: v ← (AᵀA)v, u ← Av
        let mut u_vec = vec![0.0; rows];
        for _ in 0..n_iter {
            // u = A * v
            for i in 0..rows {
                u_vec[i] = (0..cols).map(|j| work[i][j] * v_vec[j]).sum();
            }
            // v = Aᵀ * u  (u will be normalised next)
            for j in 0..cols {
                v_vec[j] = (0..rows).map(|i| work[i][j] * u_vec[i]).sum();
            }
            normalise_vec(&mut v_vec);
        }

        // Final u = A * v, normalise to get σ and u
        for i in 0..rows {
            u_vec[i] = (0..cols).map(|j| work[i][j] * v_vec[j]).sum();
        }
        let sigma = norm_vec(&u_vec);
        if sigma < 1e-14 {
            break; // Rank depleted
        }
        for ui in &mut u_vec {
            *ui /= sigma;
        }

        // Deflate: work -= σ * u * vᵀ
        for i in 0..rows {
            for j in 0..cols {
                work[i][j] -= sigma * u_vec[i] * v_vec[j];
            }
        }

        u_vecs.push(u_vec);
        sigma_vals.push(sigma);
        v_vecs.push(v_vec);
    }

    // Build Vt (k × cols)
    let vt = v_vecs; // v_vecs[r][j] already gives Vᵀ[r][j]

    Ok((u_vecs, sigma_vals, vt))
}

fn normalise_vec(v: &mut Vec<f64>) {
    let n = norm_vec(v);
    if n > 1e-14 {
        for vi in v.iter_mut() {
            *vi /= n;
        }
    }
}

fn norm_vec(v: &[f64]) -> f64 {
    v.iter().map(|&x| x * x).sum::<f64>().sqrt()
}

// ─── Simplex Projection ──────────────────────────────────────────────────────

/// Project a vector onto the probability simplex Δₙ = { x : x ≥ 0, Σxᵢ = 1 }.
///
/// Uses the O(n log n) sorting algorithm of Duchi et al. (2008) / Chen & Ye (2011).
///
/// # Arguments
/// * `x` - Input vector to project
pub fn project_simplex(x: &[f64]) -> Vec<f64> {
    let n = x.len();
    if n == 0 {
        return Vec::new();
    }
    // Sort in descending order
    let mut sorted: Vec<f64> = x.to_vec();
    sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));

    // Find the largest ρ such that sorted[ρ] > (Σ_{i≤ρ} sorted[i] − 1) / ρ
    let mut cumsum = 0.0;
    let mut rho = 0usize;
    for (i, &si) in sorted.iter().enumerate() {
        cumsum += si;
        if si > (cumsum - 1.0) / (i as f64 + 1.0) {
            rho = i;
        }
    }

    let cumsum_rho: f64 = sorted[..=rho].iter().sum();
    let theta = (cumsum_rho - 1.0) / (rho as f64 + 1.0);

    x.iter().map(|&xi| (xi - theta).max(0.0)).collect()
}

// ─── Box Projection ──────────────────────────────────────────────────────────

/// Project `x` onto the box `[lb, ub]` (coordinate-wise clipping).
///
/// # Arguments
/// * `x` - Input vector
/// * `lb` - Lower bounds (must have same length as `x`)
/// * `ub` - Upper bounds (must have same length as `x`)
///
/// # Errors
/// Returns `OptimizeError::ValueError` on length mismatch.
pub fn project_box(x: &[f64], lb: &[f64], ub: &[f64]) -> Result<Vec<f64>, OptimizeError> {
    let n = x.len();
    if lb.len() != n || ub.len() != n {
        return Err(OptimizeError::ValueError(format!(
            "x.len()={}, lb.len()={}, ub.len()={}",
            n,
            lb.len(),
            ub.len()
        )));
    }
    Ok(x.iter()
        .zip(lb.iter().zip(ub.iter()))
        .map(|(&xi, (&lo, &hi))| xi.clamp(lo, hi))
        .collect())
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_prox_l1_soft_threshold() {
        let x = vec![-3.0, -0.5, 0.0, 0.5, 3.0];
        let result = prox_l1(&x, 1.0);
        assert_abs_diff_eq!(result[0], -2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(result[1], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(result[2], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(result[3], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(result[4], 2.0, epsilon = 1e-12);
    }

    #[test]
    fn test_prox_l1_zero_lambda() {
        let x = vec![1.0, -2.0, 3.0];
        let result = prox_l1(&x, 0.0);
        for (r, orig) in result.iter().zip(x.iter()) {
            assert_abs_diff_eq!(r, orig, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_prox_l2_ridge() {
        let x = vec![2.0, -4.0];
        let result = prox_l2(&x, 0.5);
        // scale = 1 / (1 + 2*0.5) = 0.5
        assert_abs_diff_eq!(result[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(result[1], -2.0, epsilon = 1e-12);
    }

    #[test]
    fn test_prox_linf_clipping() {
        let x = vec![-3.0, 1.0, 4.0];
        let result = prox_linf(&x, 2.0);
        assert_abs_diff_eq!(result[0], -2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(result[1], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(result[2], 2.0, epsilon = 1e-12);
    }

    #[test]
    fn test_project_simplex_basic() {
        let x = vec![0.5, 0.3, 0.2];
        let proj = project_simplex(&x);
        // Already in simplex
        let sum: f64 = proj.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
        assert!(proj.iter().all(|&v| v >= -1e-12));
    }

    #[test]
    fn test_project_simplex_needs_projection() {
        let x = vec![3.0, 3.0, 3.0];
        let proj = project_simplex(&x);
        let sum: f64 = proj.iter().sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
        assert!(proj.iter().all(|&v| v >= -1e-12));
        // By symmetry, should be 1/3 each
        for p in &proj {
            assert_abs_diff_eq!(p, &(1.0 / 3.0), epsilon = 1e-10);
        }
    }

    #[test]
    fn test_project_box() {
        let x = vec![-2.0, 0.5, 3.0];
        let lb = vec![-1.0, 0.0, 0.0];
        let ub = vec![1.0, 1.0, 2.0];
        let proj = project_box(&x, &lb, &ub).expect("box projection failed");
        assert_abs_diff_eq!(proj[0], -1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(proj[1], 0.5, epsilon = 1e-12);
        assert_abs_diff_eq!(proj[2], 2.0, epsilon = 1e-12);
    }

    #[test]
    fn test_project_box_length_mismatch() {
        let x = vec![1.0, 2.0];
        let lb = vec![0.0];
        let ub = vec![1.0, 2.0];
        assert!(project_box(&x, &lb, &ub).is_err());
    }

    #[test]
    fn test_prox_nuclear_identity() {
        // For λ=0 the proximal operator should be the identity
        let m = vec![1.0, 2.0, 3.0, 4.0]; // 2×2
        let result = prox_nuclear(&m, 2, 2, 0.0).expect("nuclear prox failed");
        for (r, orig) in result.iter().zip(m.iter()) {
            assert_abs_diff_eq!(r, orig, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_prox_nuclear_shrinks_singular_values() {
        // A diagonal matrix: [[5,0],[0,3]], λ=2 → [[3,0],[0,1]]
        let m = vec![5.0, 0.0, 0.0, 3.0];
        let result = prox_nuclear(&m, 2, 2, 2.0).expect("nuclear prox failed");
        // Reconstructed matrix should have reduced singular values
        // Allow generous tolerance due to iterative SVD
        assert!(result[0] < 5.0, "diagonal element should shrink");
        assert!(result[3] < 3.0, "diagonal element should shrink");
    }

    #[test]
    fn test_prox_nuclear_bad_size() {
        let result = prox_nuclear(&[1.0, 2.0], 2, 2, 1.0);
        assert!(result.is_err());
    }
}