use crate::numeric::cast;
use super::Constraint;
use crate::{FunctionCallResult, SolverError};
use num::Float;
/// An `ℓp` ball `{x : ||x - center||_p <= radius}`.
///
/// When `center` is `None` the ball is centered at the origin. Projection
/// onto the ball is computed numerically; `tolerance` and `max_iter`
/// control the iterative solvers used for it.
#[derive(Debug, Copy, Clone)]
pub struct BallP<'a, T = f64> {
    /// Center of the ball, or `None` for the origin.
    center: Option<&'a [T]>,
    /// Radius of the ball.
    radius: T,
    /// Exponent of the norm.
    p: T,
    /// Stopping tolerance for the numerical projection.
    tolerance: T,
    /// Iteration cap for the numerical projection.
    max_iter: usize,
}
impl<'a, T: Float> BallP<'a, T> {
    /// Creates a new `ℓp` ball `{x : ||x - center||_p <= radius}`.
    ///
    /// # Arguments
    ///
    /// - `center`: center of the ball; `None` places it at the origin
    /// - `radius`: radius of the ball
    /// - `p`: exponent of the norm
    /// - `tolerance`: stopping tolerance shared by the bisection on the
    ///   Lagrange multiplier and by the per-coordinate Newton solver; both
    ///   operate on quantities normalized by `radius`
    /// - `max_iter`: iteration cap for the bisection and for each Newton solve
    ///
    /// # Panics
    ///
    /// Panics unless `radius > 0`, `1 < p < ∞`, `tolerance > 0` and
    /// `max_iter > 0`.
    pub fn new(center: Option<&'a [T]>, radius: T, p: T, tolerance: T, max_iter: usize) -> Self {
        assert!(radius > T::zero());
        assert!(p > T::one() && p.is_finite());
        assert!(tolerance > T::zero());
        assert!(max_iter > 0);
        BallP {
            center,
            radius,
            p,
            tolerance,
            max_iter,
        }
    }

    /// Returns the `ℓp` norm `(Σ |x_i|^p)^(1/p)` of `x`.
    ///
    /// Entries are divided by the largest magnitude before being raised to
    /// the power `p`. This prevents the intermediate sum from spuriously
    /// underflowing to zero or overflowing to infinity when the norm itself
    /// is representable (e.g. `x = [1e-200, 1e-200]` with `p = 2`). A NaN
    /// entry yields NaN, an infinite entry yields infinity, and an empty or
    /// all-zero slice yields zero.
    #[inline]
    fn lp_norm(&self, x: &[T]) -> T {
        // Largest |x_i|. Once a NaN is seen it is kept so that it propagates.
        let scale = x.iter().fold(T::zero(), |m, xi| {
            let a = xi.abs();
            if m.is_nan() || a <= m {
                m
            } else {
                a
            }
        });
        if scale == T::zero() || !scale.is_finite() {
            // Nothing to rescale: the norm is exactly `scale`.
            return scale;
        }
        let sum = x
            .iter()
            .map(|xi| (xi.abs() / scale).powf(self.p))
            .fold(T::zero(), |acc, t| acc + t);
        scale * sum.powf(T::one() / self.p)
    }

    /// Projects `x` in place onto the origin-centered `ℓp` ball of radius
    /// `self.radius`. Points already inside the ball are left untouched.
    ///
    /// The ball of radius `r` is the unit ball scaled by `r`, so the
    /// projection satisfies `P_r(x) = r * P_1(x / r)`. The computation is
    /// therefore carried out on `x / r` against the unit ball and scaled
    /// back at the end. This keeps the target `Σ u_i^p = 1` representable,
    /// whereas `r^p` overflows or underflows for large `p` or extreme radii.
    /// It also makes the tolerances relative to the radius.
    ///
    /// For each coordinate with `a_i = |x_i| / r`, the projected magnitude
    /// `u_i(λ)` solves `u + λ p u^(p-1) = a_i` (see
    /// `solve_coordinate_newton`), and the sign of `x_i` is kept. The
    /// multiplier `λ >= 0` is located by bisection so that
    /// `Σ u_i(λ)^p = 1`, after an upper bracket is found by doubling from 1.
    ///
    /// # Errors
    ///
    /// Returns `SolverError::ProjectionFailed` if no upper bracket for `λ`
    /// is found below `1e20`. `x` is not modified in that case.
    fn project_lp_ball(&self, x: &mut [T]) -> FunctionCallResult {
        let p = self.p;
        let r = self.radius;
        let tol = self.tolerance;
        let max_iter = self.max_iter;

        if self.lp_norm(x) <= r {
            return Ok(());
        }

        // Magnitudes of the radius-normalized point x / r.
        let abs_x: Vec<T> = x.iter().map(|xi| xi.abs() / r).collect();

        // Σ u_i(λ)^p - 1 is positive while λ is too small (still outside the
        // unit ball). It decreases as λ grows, because each u_i(λ) does.
        let radius_error = |lambda: T| -> T {
            abs_x
                .iter()
                .map(|&a| Self::solve_coordinate_newton(a, lambda, p, tol, max_iter).powf(p))
                .fold(T::zero(), |sum, ui| sum + ui)
                - T::one()
        };

        // λ = 0 means no shrinkage, which is outside the ball. Double an
        // upper bound until the error turns non-positive.
        let mut lambda_lo = T::zero();
        let mut lambda_hi = T::one();
        while radius_error(lambda_hi) > T::zero() {
            lambda_hi = lambda_hi * cast::<T>(2.0);
            if lambda_hi > cast::<T>(1e20) {
                return Err(SolverError::ProjectionFailed(
                    "failed to bracket the Lagrange multiplier",
                ));
            }
        }

        // Bisection on [lambda_lo, lambda_hi], keeping err(lo) > 0 >= err(hi).
        for _ in 0..max_iter {
            let lambda_mid = cast::<T>(0.5) * (lambda_lo + lambda_hi);
            let err = radius_error(lambda_mid);
            if err.abs() <= tol {
                lambda_lo = lambda_mid;
                lambda_hi = lambda_mid;
                break;
            }
            if err > T::zero() {
                lambda_lo = lambda_mid;
            } else {
                lambda_hi = lambda_mid;
            }
        }

        let lambda_star = cast::<T>(0.5) * (lambda_lo + lambda_hi);
        x.iter_mut().zip(abs_x.iter()).for_each(|(xi, &a)| {
            let u = Self::solve_coordinate_newton(a, lambda_star, p, tol, max_iter);
            // Undo the radius normalization and restore the original sign.
            *xi = xi.signum() * (r * u);
        });
        Ok(())
    }

    /// Solves the scalar equation `u + λ p u^(p-1) = a` for `u` in `[0, a]`.
    ///
    /// `a` is expected to be non-negative; the caller passes magnitudes.
    /// `f(u) = u + λ p u^(p-1) - a` is increasing on `[0, a]`, with
    /// `f(0) <= 0 <= f(a)`. The method keeps a bracket `[lo, hi]` around the
    /// root, updated from the sign of `f`. It takes a Newton step when that
    /// step is finite and lands strictly inside the bracket, and uses the
    /// bracket midpoint otherwise.
    ///
    /// Iteration stops when any of these holds:
    /// - `|f(u)| <= tol`;
    /// - the step is at most `tol * (1 + |u|)`;
    /// - `max_iter` iterations have run, in which case the bracket midpoint
    ///   is returned.
    ///
    /// Special cases: `a = 0` returns 0 and `λ = 0` returns `a`.
    fn solve_coordinate_newton(a: T, lambda: T, p: T, tol: T, max_iter: usize) -> T {
        if a == T::zero() {
            return T::zero();
        }
        if lambda == T::zero() {
            return a;
        }
        let mut lo = T::zero();
        let mut hi = a;
        // Initial guess; exact when p = 2.
        let mut u = (a / (T::one() + lambda * p)).clamp(lo, hi);
        for _ in 0..max_iter {
            let upm1 = u.powf(p - T::one());
            let f = u + lambda * p * upm1 - a;
            if f.abs() <= tol {
                return u;
            }
            if f > T::zero() {
                hi = u;
            } else {
                lo = u;
            }
            // f'(u) = 1 + λ p (p-1) u^(p-2); infinite at u = 0 when p < 2.
            let df = T::one() + lambda * p * (p - T::one()) * u.powf(p - cast::<T>(2.0));
            let mut candidate = u - f / df;
            if !candidate.is_finite() || candidate <= lo || candidate >= hi {
                // Newton left the bracket (or blew up): fall back to bisection.
                candidate = cast::<T>(0.5) * (lo + hi);
            }
            if (candidate - u).abs() <= tol * (T::one() + u.abs()) {
                return candidate;
            }
            u = candidate;
        }
        cast::<T>(0.5) * (lo + hi)
    }
}
impl<'a, T: Float> Constraint<T> for BallP<'a, T> {
    /// Projects `x` in place onto the `ℓp` ball, shifted by its center.
    ///
    /// With a center `c`, the problem is translated to the origin. `x - c`
    /// is projected onto the origin-centered ball in a scratch buffer, and
    /// the result is shifted back by `c`. Without a center, `x` is projected
    /// directly.
    ///
    /// # Panics
    ///
    /// Panics if a center is set and its length differs from `x.len()`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the origin-centered projection. In the
    /// centered case `x` is left untouched when that happens.
    fn project(&self, x: &mut [T]) -> FunctionCallResult {
        match self.center {
            None => self.project_lp_ball(x)?,
            Some(center) => {
                assert_eq!(
                    x.len(),
                    center.len(),
                    "x and xc have incompatible dimensions"
                );
                // Translate into the origin-centered frame.
                let mut offset: Vec<T> = x
                    .iter()
                    .zip(center.iter())
                    .map(|(&xi, &ci)| xi - ci)
                    .collect();
                self.project_lp_ball(&mut offset)?;
                // Translate back.
                for (xi, (&oi, &ci)) in x.iter_mut().zip(offset.iter().zip(center.iter())) {
                    *xi = ci + oi;
                }
            }
        }
        Ok(())
    }

    /// Every `ℓp` ball with `p >= 1` is a convex set.
    fn is_convex(&self) -> bool {
        true
    }
}