use super::OptimizationResult;
use crate::primitives::Vector;
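/// Proximal operators and Euclidean projections onto common constraint sets.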
pub mod prox {
use crate::primitives::Vector;
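    /// Soft-thresholding operator, the proximal operator of `lambda * ||v||_1`:
    /// each component is moved toward zero by `lambda`, and components with
    /// magnitude at most `lambda` become exactly zero. `lambda` is expected to
    /// be non-negative.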
#[must_use]
pub fn soft_threshold(v: &Vector<f32>, lambda: f32) -> Vector<f32> {
let mut result = Vector::zeros(v.len());
for i in 0..v.len() {
let val = v[i];
result[i] = if val > lambda {
val - lambda
} else if val < -lambda {
val + lambda
} else {
0.0
};
}
result
}
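    /// Projection onto the non-negative orthant: negative components are
    /// clamped to zero, non-negative components are kept as-is.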
#[must_use]
pub fn nonnegative(x: &Vector<f32>) -> Vector<f32> {
let mut result = Vector::zeros(x.len());
for i in 0..x.len() {
result[i] = x[i].max(0.0);
}
result
}
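    /// Projection onto the L2 ball of the given `radius` centred at the origin.
    /// Points inside the ball are returned unchanged; points outside are
    /// rescaled onto the ball's surface.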
#[must_use]
pub fn project_l2_ball(x: &Vector<f32>, radius: f32) -> Vector<f32> {
let mut norm_sq = 0.0;
for i in 0..x.len() {
norm_sq += x[i] * x[i];
}
let norm = norm_sq.sqrt();
        if norm <= radius {
            // Already inside the ball: return the point unchanged.
            x.clone()
        } else {
            // Outside the ball: rescale onto its surface.
            let scale = radius / norm;
let mut result = Vector::zeros(x.len());
for i in 0..x.len() {
result[i] = scale * x[i];
}
result
}
}
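    /// Component-wise projection onto the box `[lower[i], upper[i]]`.
    /// `x`, `lower`, and `upper` must have the same length, and each
    /// `lower[i]` should not exceed `upper[i]`.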
#[must_use]
pub fn project_box(x: &Vector<f32>, lower: &Vector<f32>, upper: &Vector<f32>) -> Vector<f32> {
let mut result = Vector::zeros(x.len());
for i in 0..x.len() {
result[i] = x[i].max(lower[i]).min(upper[i]);
}
result
}
}
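/// A first-order optimizer. Stochastic optimizers implement [`step`](Optimizer::step);
/// batch methods additionally override [`minimize`](Optimizer::minimize).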
pub trait Optimizer {
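    /// Applies a single in-place update to `params` using the supplied `gradients`.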
fn step(&mut self, params: &mut Vector<f32>, gradients: &Vector<f32>);
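    /// Runs a full batch minimization of `objective` starting from `x0`.
    ///
    /// The default implementation panics: optimizers that only support
    /// stochastic updates should be driven through [`step`](Optimizer::step) instead.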
fn minimize<F, G>(
&mut self,
_objective: F,
_gradient: G,
_x0: Vector<f32>,
) -> OptimizationResult
where
F: Fn(&Vector<f32>) -> f32,
G: Fn(&Vector<f32>) -> Vector<f32>,
{
panic!(
"{} does not support batch optimization (minimize). Use step() for stochastic updates.",
std::any::type_name::<Self>()
)
}
fn reset(&mut self);
}
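
// Illustrative tests sketching the expected behaviour of the proximal
// operators above. They rely only on the `Vector<f32>` API already used in
// this file (`Vector::zeros`, `len`, and index assignment); adjust the
// construction if the concrete `Vector` type offers a more direct builder.
#[cfg(test)]
mod tests {
    use super::prox;
    use crate::primitives::Vector;

    #[test]
    fn soft_threshold_shrinks_each_component_toward_zero() {
        let mut v = Vector::zeros(3);
        v[0] = 3.0;
        v[1] = -0.5;
        v[2] = -2.0;
        let out = prox::soft_threshold(&v, 1.0);
        assert!((out[0] - 2.0).abs() < 1e-6); // shrunk by lambda
        assert!(out[1].abs() < 1e-6); // magnitude below lambda -> zero
        assert!((out[2] + 1.0).abs() < 1e-6); // shrunk toward zero from below
    }

    #[test]
    fn l2_ball_projection_caps_the_norm_at_the_radius() {
        let mut x = Vector::zeros(2);
        x[0] = 3.0;
        x[1] = 4.0; // Euclidean norm 5
        let out = prox::project_l2_ball(&x, 1.0);
        let norm = (out[0] * out[0] + out[1] * out[1]).sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }
}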