use crate::error::{OptimizeError, OptimizeResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::random::{rngs::StdRng, RngExt, SeedableRng};
/// Kind of random sketch matrix used to compress gradients / linear systems.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SketchType {
    /// Dense k×n matrix with i.i.d. N(0, 1/k) entries (Box-Muller generated).
    Gaussian,
    /// Sparse CountSketch: exactly one random ±1 entry per column.
    CountSketch,
}
/// Configuration for sketched gradient descent.
#[derive(Debug, Clone)]
pub struct SketchedGdConfig {
    /// Target sketch dimension k (clamped to the problem dimension n at solve time).
    pub sketch_dim: usize,
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the absolute change in objective value.
    pub tol: f64,
    /// Step size for gradient-descent updates.
    pub step_size: f64,
    /// Which random sketch matrix to draw.
    pub sketch_type: SketchType,
    /// RNG seed, so runs are reproducible.
    pub seed: u64,
    /// If true, draw a fresh sketch every iteration; otherwise reuse the first one.
    pub resample_sketch: bool,
    /// If true, record the objective value at every iteration in the result.
    pub track_objective: bool,
}
impl Default for SketchedGdConfig {
fn default() -> Self {
Self {
sketch_dim: 10,
max_iter: 1000,
tol: 1e-6,
step_size: 0.01,
sketch_type: SketchType::Gaussian,
seed: 42,
resample_sketch: true,
track_objective: false,
}
}
}
/// Result of a sketched gradient descent / sketch-and-project run.
#[derive(Debug, Clone)]
pub struct SketchedGdResult {
    /// Final iterate.
    pub x: Array1<f64>,
    /// Objective value at the final iterate.
    pub fun: f64,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Whether the objective-change tolerance was met before `max_iter`.
    pub converged: bool,
    /// Per-iteration objective values (empty unless `track_objective` was set).
    pub objective_history: Vec<f64>,
    /// Euclidean norm of the gradient at the final iterate.
    pub grad_norm: f64,
}
/// Build a k×n Gaussian sketch whose entries are i.i.d. N(0, 1/k), so that
/// SᵀS approximates the identity in expectation.
fn generate_gaussian_sketch(k: usize, n: usize, rng: &mut StdRng) -> Array2<f64> {
    let inv_sqrt_k = (k as f64).sqrt().recip();
    let mut sketch = Array2::zeros((k, n));
    for row in 0..k {
        for col in 0..n {
            // Box-Muller transform (cosine branch): two uniforms -> one standard normal.
            // Clamp the first uniform away from zero so ln() stays finite.
            let a: f64 = rng.random::<f64>().max(1e-30);
            let b: f64 = rng.random::<f64>();
            let normal: f64 =
                (-2.0_f64 * a.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * b).cos();
            sketch[[row, col]] = normal * inv_sqrt_k;
        }
    }
    sketch
}
/// Build a k×n CountSketch matrix: each column has a single ±1 entry
/// placed in a uniformly random row.
fn generate_count_sketch(k: usize, n: usize, rng: &mut StdRng) -> Array2<f64> {
    let mut sketch = Array2::zeros((k, n));
    for col in 0..n {
        let target_row = rng.random_range(0..k);
        // Fair coin flip for the sign.
        let sign: f64 = if rng.random_range(0..2_u32) == 0 { 1.0 } else { -1.0 };
        sketch[[target_row, col]] = sign;
    }
    sketch
}
/// Gradient-descent solver that compresses each gradient through a random
/// sketch before taking a step.
pub struct SketchedGradientDescent {
    // Solver configuration (sketch type/size, step size, stopping criteria).
    config: SketchedGdConfig,
}
impl SketchedGradientDescent {
pub fn new(config: SketchedGdConfig) -> Self {
Self { config }
}
pub fn default_solver() -> Self {
Self::new(SketchedGdConfig::default())
}
pub fn minimize<F, G>(
&self,
objective: F,
gradient: G,
x0: &Array1<f64>,
) -> OptimizeResult<SketchedGdResult>
where
F: Fn(&ArrayView1<f64>) -> f64,
G: Fn(&ArrayView1<f64>) -> Array1<f64>,
{
let n = x0.len();
let k = self.config.sketch_dim.min(n);
if n == 0 {
return Err(OptimizeError::InvalidInput(
"Dimension must be at least 1".to_string(),
));
}
if k == 0 {
return Err(OptimizeError::InvalidInput(
"Sketch dimension must be at least 1".to_string(),
));
}
let mut rng = StdRng::seed_from_u64(self.config.seed);
let mut x = x0.clone();
let mut objective_history = Vec::new();
let mut sketch = match self.config.sketch_type {
SketchType::Gaussian => generate_gaussian_sketch(k, n, &mut rng),
SketchType::CountSketch => generate_count_sketch(k, n, &mut rng),
};
let mut prev_obj = objective(&x.view());
if self.config.track_objective {
objective_history.push(prev_obj);
}
let mut converged = false;
let mut iterations = 0;
for iter in 0..self.config.max_iter {
iterations = iter + 1;
if self.config.resample_sketch && iter > 0 {
sketch = match self.config.sketch_type {
SketchType::Gaussian => generate_gaussian_sketch(k, n, &mut rng),
SketchType::CountSketch => generate_count_sketch(k, n, &mut rng),
};
}
let grad = gradient(&x.view());
let s_grad = sketch.dot(&grad);
let direction = sketch.t().dot(&s_grad);
for j in 0..n {
x[j] -= self.config.step_size * direction[j];
}
let cur_obj = objective(&x.view());
if self.config.track_objective {
objective_history.push(cur_obj);
}
let change = (prev_obj - cur_obj).abs();
prev_obj = cur_obj;
if change < self.config.tol {
converged = true;
break;
}
}
let final_grad = gradient(&x.view());
let grad_norm = final_grad.dot(&final_grad).sqrt();
Ok(SketchedGdResult {
x,
fun: prev_obj,
iterations,
converged,
objective_history,
grad_norm,
})
}
pub fn sketch_and_project(
&self,
a: &Array2<f64>,
b: &Array1<f64>,
x0: &Array1<f64>,
) -> OptimizeResult<SketchedGdResult> {
let n = x0.len();
let k = self.config.sketch_dim.min(n);
if a.nrows() != n || a.ncols() != n {
return Err(OptimizeError::InvalidInput(format!(
"Matrix A has shape ({}, {}), expected ({}, {})",
a.nrows(),
a.ncols(),
n,
n
)));
}
if b.len() != n {
return Err(OptimizeError::InvalidInput(format!(
"Vector b has length {}, expected {}",
b.len(),
n
)));
}
let mut rng = StdRng::seed_from_u64(self.config.seed);
let mut x = x0.clone();
let mut objective_history = Vec::new();
let compute_obj = |x: &Array1<f64>| -> f64 {
let ax = a.dot(x);
0.5 * x.dot(&ax) - b.dot(x)
};
let mut prev_obj = compute_obj(&x);
if self.config.track_objective {
objective_history.push(prev_obj);
}
let mut converged = false;
let mut iterations = 0;
for iter in 0..self.config.max_iter {
iterations = iter + 1;
let sketch = match self.config.sketch_type {
SketchType::Gaussian => generate_gaussian_sketch(k, n, &mut rng),
SketchType::CountSketch => generate_count_sketch(k, n, &mut rng),
};
let grad = a.dot(&x) - b;
let s_grad = sketch.dot(&grad);
let sa = sketch.dot(a); let sa_st = sa.dot(&sketch.t());
let z = match solve_sketched_system(&sa_st, &s_grad) {
Some(z) => z,
None => {
let direction = sketch.t().dot(&s_grad);
for j in 0..n {
x[j] -= self.config.step_size * direction[j];
}
let cur_obj = compute_obj(&x);
if self.config.track_objective {
objective_history.push(cur_obj);
}
prev_obj = cur_obj;
continue;
}
};
let update = sketch.t().dot(&z);
for j in 0..n {
x[j] -= update[j];
}
let cur_obj = compute_obj(&x);
if self.config.track_objective {
objective_history.push(cur_obj);
}
let change = (prev_obj - cur_obj).abs();
prev_obj = cur_obj;
if change < self.config.tol {
converged = true;
break;
}
}
let final_grad = a.dot(&x) - b;
let grad_norm = final_grad.dot(&final_grad).sqrt();
Ok(SketchedGdResult {
x,
fun: prev_obj,
iterations,
converged,
objective_history,
grad_norm,
})
}
}
/// Solve the dense n×n linear system `a · x = b` by Gaussian elimination
/// with partial pivoting on an augmented matrix.
///
/// Returns `None` when shapes are inconsistent, `n == 0`, or any pivot is
/// smaller than 1e-14 in magnitude (matrix treated as numerically singular).
fn solve_sketched_system(a: &Array2<f64>, b: &Array1<f64>) -> Option<Array1<f64>> {
    let n = a.nrows();
    if n == 0 || a.ncols() != n || b.len() != n {
        return None;
    }
    // Build the augmented matrix [a | b].
    let mut aug = Array2::zeros((n, n + 1));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }
    // Forward elimination with partial pivoting.
    for col in 0..n {
        // Find the row at/below `col` with the largest absolute entry in this column.
        let mut max_val = aug[[col, col]].abs();
        let mut max_row = col;
        for row in (col + 1)..n {
            let val = aug[[row, col]].abs();
            if val > max_val {
                max_val = val;
                max_row = row;
            }
        }
        // Best pivot is (near-)zero: bail out rather than divide by ~0.
        if max_val < 1e-14 {
            return None;
        }
        // Swap the pivot row into place (includes the augmented column, j == n).
        if max_row != col {
            for j in 0..=n {
                let tmp = aug[[col, j]];
                aug[[col, j]] = aug[[max_row, j]];
                aug[[max_row, j]] = tmp;
            }
        }
        // Eliminate entries below the pivot.
        let pivot = aug[[col, col]];
        for row in (col + 1)..n {
            let factor = aug[[row, col]] / pivot;
            for j in col..=n {
                let val = aug[[col, j]];
                aug[[row, j]] -= factor * val;
            }
        }
    }
    // Back substitution on the upper-triangular system.
    let mut x = Array1::zeros(n);
    for i in (0..n).rev() {
        let mut sum = aug[[i, n]];
        for j in (i + 1)..n {
            sum -= aug[[i, j]] * x[j];
        }
        // Defensive re-check; pivots were already validated during elimination.
        if aug[[i, i]].abs() < 1e-14 {
            return None;
        }
        x[i] = sum / aug[[i, i]];
    }
    Some(x)
}
/// Convenience wrapper: run sketched gradient descent from `x0` with the
/// given configuration, or [`SketchedGdConfig::default`] when `config` is
/// `None`. See [`SketchedGradientDescent::minimize`] for details.
pub fn sketched_gradient_descent<F, G>(
    objective: F,
    gradient: G,
    x0: &Array1<f64>,
    config: Option<SketchedGdConfig>,
) -> OptimizeResult<SketchedGdResult>
where
    F: Fn(&ArrayView1<f64>) -> f64,
    G: Fn(&ArrayView1<f64>) -> Array1<f64>,
{
    SketchedGradientDescent::new(config.unwrap_or_default()).minimize(objective, gradient, x0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// Full-dimension (k = n) Gaussian sketch on a simple convex quadratic
    /// should converge to the minimum at the origin.
    #[test]
    fn test_gaussian_sketch_quadratic() {
        let objective = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[1] * x[1] };
        let gradient = |x: &ArrayView1<f64>| -> Array1<f64> { array![2.0 * x[0], 2.0 * x[1]] };
        let x0 = array![5.0, 3.0];
        let config = SketchedGdConfig {
            sketch_dim: 2,
            max_iter: 5000,
            tol: 1e-10,
            step_size: 0.1,
            sketch_type: SketchType::Gaussian,
            seed: 42,
            resample_sketch: false,
            ..Default::default()
        };
        let result = sketched_gradient_descent(objective, gradient, &x0, Some(config));
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert!(result.converged, "Did not converge, fun={}", result.fun);
        assert!(result.fun < 1e-4, "fun={}", result.fun);
    }

    /// CountSketch with resampling should make substantial progress on a
    /// shifted quadratic (looser bound: sparse sketches progress slower).
    #[test]
    fn test_count_sketch_convergence() {
        let objective = |x: &ArrayView1<f64>| -> f64 {
            (x[0] - 1.0).powi(2) + (x[1] - 2.0).powi(2) + (x[2] - 3.0).powi(2)
        };
        let gradient = |x: &ArrayView1<f64>| -> Array1<f64> {
            array![2.0 * (x[0] - 1.0), 2.0 * (x[1] - 2.0), 2.0 * (x[2] - 3.0)]
        };
        let x0 = array![0.0, 0.0, 0.0];
        let config = SketchedGdConfig {
            sketch_dim: 3,
            max_iter: 10000,
            tol: 1e-8,
            step_size: 0.1,
            sketch_type: SketchType::CountSketch,
            seed: 123,
            resample_sketch: true,
            ..Default::default()
        };
        let result = sketched_gradient_descent(objective, gradient, &x0, Some(config));
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert!(result.fun < 1.0, "fun={}", result.fun);
    }

    /// Sketch-and-project on a small SPD system should drive the residual
    /// ||Ax - b|| near zero.
    #[test]
    fn test_sketch_and_project() {
        let a = array![[2.0, 0.5], [0.5, 3.0]];
        let b = array![1.0, 2.0];
        let x0 = array![0.0, 0.0];
        let config = SketchedGdConfig {
            sketch_dim: 2,
            max_iter: 200,
            tol: 1e-10,
            sketch_type: SketchType::Gaussian,
            seed: 42,
            ..Default::default()
        };
        let solver = SketchedGradientDescent::new(config);
        let result = solver.sketch_and_project(&a, &b, &x0);
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert!(result.converged, "Did not converge, fun={}", result.fun);
        let residual = a.dot(&result.x) - &b;
        let res_norm = residual.dot(&residual).sqrt();
        assert!(res_norm < 1e-4, "Residual norm={}", res_norm);
    }

    /// A sketch dimension much smaller than n (5 vs 20) should still make
    /// progress on a high-dimensional quadratic.
    #[test]
    fn test_high_dimensional() {
        let n = 20;
        let objective = |x: &ArrayView1<f64>| -> f64 { x.dot(x) };
        let gradient = |x: &ArrayView1<f64>| -> Array1<f64> { x.mapv(|xi| 2.0 * xi) };
        let x0 = Array1::from_vec(vec![1.0; n]);
        let config = SketchedGdConfig {
            sketch_dim: 5,
            max_iter: 10000,
            tol: 1e-6,
            step_size: 0.05,
            sketch_type: SketchType::Gaussian,
            seed: 77,
            resample_sketch: true,
            ..Default::default()
        };
        let result = sketched_gradient_descent(objective, gradient, &x0, Some(config));
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert!(result.fun < 1.0, "fun={}, expected < 1.0", result.fun);
    }

    /// With `track_objective` and a fixed full-rank sketch, the recorded
    /// objective should be monotonically non-increasing (up to float noise).
    #[test]
    fn test_objective_history() {
        let objective = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[1] * x[1] };
        let gradient = |x: &ArrayView1<f64>| -> Array1<f64> { array![2.0 * x[0], 2.0 * x[1]] };
        let x0 = array![3.0, 4.0];
        let config = SketchedGdConfig {
            sketch_dim: 2,
            max_iter: 50,
            tol: 1e-20,
            step_size: 0.05,
            sketch_type: SketchType::Gaussian,
            seed: 42,
            resample_sketch: false,
            track_objective: true,
            ..Default::default()
        };
        let result = sketched_gradient_descent(objective, gradient, &x0, Some(config));
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert!(result.objective_history.len() > 1);
        for i in 1..result.objective_history.len() {
            assert!(
                result.objective_history[i] <= result.objective_history[i - 1] + 1e-10,
                "Objective increased at iter {}: {} -> {}",
                i,
                result.objective_history[i - 1],
                result.objective_history[i]
            );
        }
    }

    /// An empty starting point must be rejected as invalid input.
    #[test]
    fn test_zero_dimension_error() {
        let objective = |_x: &ArrayView1<f64>| -> f64 { 0.0 };
        let gradient = |_x: &ArrayView1<f64>| -> Array1<f64> { Array1::zeros(0) };
        let x0 = Array1::zeros(0);
        let result = sketched_gradient_descent(objective, gradient, &x0, None);
        assert!(result.is_err());
    }

    /// A `b` vector whose length disagrees with `x0`/`A` must be rejected.
    #[test]
    fn test_sketch_project_mismatch() {
        let a = Array2::eye(3);
        let b = array![1.0, 2.0];
        let x0 = array![0.0, 0.0, 0.0];
        let solver = SketchedGradientDescent::default_solver();
        let result = solver.sketch_and_project(&a, &b, &x0);
        assert!(result.is_err());
    }
}