// oxiphysics_core/convex_analysis.rs
1#![allow(clippy::needless_range_loop, clippy::type_complexity)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! Convex analysis, proximal operators, and constrained optimisation.
6//!
7//! Implements convex sets (hyperplane, halfspace, ball, polytope), projection
8//! operators, proximal maps, dual decomposition with subgradient updates,
9//! ADMM, Frank-Wolfe / conditional gradient, and the ellipsoid method.
10
11#![allow(dead_code)]
12#![allow(clippy::too_many_arguments)]
13
14// ─── Convexity Check ─────────────────────────────────────────────────────────
15
/// Check whether a discretely sampled 1-D function is (approximately) convex.
///
/// A sampling is accepted as convex when no central second difference
/// `f[i+1] - 2·f[i] + f[i-1]` falls below `-tol`.
///
/// Returns `true` for slices with fewer than three samples (trivially convex).
pub fn is_convex_1d(f: &[f64], tol: f64) -> bool {
    // Negated `any` (rather than `all(>=)`) preserves the original's
    // treatment of NaN second differences (they do not fail the check).
    !f.windows(3).any(|w| w[2] - 2.0 * w[1] + w[0] < -tol)
}
34
35// ─── ConvexSet ───────────────────────────────────────────────────────────────
36
/// A convex subset of ℝⁿ.
///
/// Covers hyperplanes, halfspaces, Euclidean balls, and polytopes in
/// H-representation (finite intersections of halfspaces).
#[derive(Debug, Clone)]
pub enum ConvexSet {
    /// Hyperplane `{x : aᵀx = b}`.
    Hyperplane {
        /// Normal vector `a`.
        a: Vec<f64>,
        /// Right-hand side `b`.
        b: f64,
    },
    /// Halfspace `{x : aᵀx ≤ b}`.
    Halfspace {
        /// Normal vector `a`.
        a: Vec<f64>,
        /// Right-hand side `b`.
        b: f64,
    },
    /// Euclidean ball `{x : ||x - c|| ≤ r}`.
    Ball {
        /// Centre `c`.
        center: Vec<f64>,
        /// Radius `r`.
        radius: f64,
    },
    /// Polytope in H-form: `{x : A x ≤ b}`.
    ///
    /// `a_rows[i]` is the `i`-th row of A and `b_vec[i]` is `b_i`.
    Polytope {
        /// Rows of the constraint matrix.
        a_rows: Vec<Vec<f64>>,
        /// Right-hand side vector.
        b_vec: Vec<f64>,
    },
}

impl ConvexSet {
    /// Membership test for `x`, with a fixed absolute tolerance of `1e-9`.
    pub fn contains(&self, x: &[f64]) -> bool {
        // Inner product computed locally so the test reads self-contained.
        let ip = |u: &[f64], v: &[f64]| -> f64 {
            u.iter().zip(v.iter()).map(|(ui, vi)| ui * vi).sum()
        };
        match self {
            ConvexSet::Hyperplane { a, b } => (ip(a, x) - b).abs() < 1e-9,
            ConvexSet::Halfspace { a, b } => ip(a, x) <= b + 1e-9,
            ConvexSet::Ball { center, radius } => {
                let dist_sq: f64 = x
                    .iter()
                    .zip(center.iter())
                    .map(|(xi, ci)| (xi - ci) * (xi - ci))
                    .sum();
                dist_sq.sqrt() <= radius + 1e-9
            }
            ConvexSet::Polytope { a_rows, b_vec } => {
                for (row, bi) in a_rows.iter().zip(b_vec.iter()) {
                    if ip(row, x) > bi + 1e-9 {
                        return false;
                    }
                }
                true
            }
        }
    }
}
100
/// Inner product of two equal-length slices (extra elements on the longer
/// slice are ignored by `zip`).
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b.iter()).fold(0.0, |acc, (x, y)| acc + x * y)
}
105
106// ─── ProjectionOperator ───────────────────────────────────────────────────────
107
108/// Collection of projection operators onto standard convex sets.
109///
110/// Each method returns the projection of a given point onto a specified set.
111pub struct ProjectionOperator;
112
113impl ProjectionOperator {
114    /// Project `x` onto the hyperplane `{z : aᵀz = b}`.
115    ///
116    /// `proj(x) = x - (aᵀx - b) / ||a||² · a`
117    pub fn onto_hyperplane(x: &[f64], a: &[f64], b: f64) -> Vec<f64> {
118        let a_norm_sq: f64 = dot(a, a);
119        if a_norm_sq < 1e-14 {
120            return x.to_vec();
121        }
122        let alpha = (dot(a, x) - b) / a_norm_sq;
123        x.iter()
124            .zip(a.iter())
125            .map(|(xi, ai)| xi - alpha * ai)
126            .collect()
127    }
128
129    /// Project `x` onto the halfspace `{z : aᵀz ≤ b}`.
130    ///
131    /// If `x` is already in the halfspace returns `x` unchanged.
132    pub fn onto_halfspace(x: &[f64], a: &[f64], b: f64) -> Vec<f64> {
133        let ax = dot(a, x);
134        if ax <= b {
135            return x.to_vec();
136        }
137        let a_norm_sq: f64 = dot(a, a);
138        if a_norm_sq < 1e-14 {
139            return x.to_vec();
140        }
141        let alpha = (ax - b) / a_norm_sq;
142        x.iter()
143            .zip(a.iter())
144            .map(|(xi, ai)| xi - alpha * ai)
145            .collect()
146    }
147
148    /// Project `x` onto the Euclidean ball `B(center, radius)`.
149    pub fn onto_ball(x: &[f64], center: &[f64], radius: f64) -> Vec<f64> {
150        let diff: Vec<f64> = x
151            .iter()
152            .zip(center.iter())
153            .map(|(xi, ci)| xi - ci)
154            .collect();
155        let dist: f64 = diff.iter().map(|d| d * d).sum::<f64>().sqrt();
156        if dist <= radius {
157            return x.to_vec();
158        }
159        let scale = radius / dist;
160        center
161            .iter()
162            .zip(diff.iter())
163            .map(|(ci, di)| ci + scale * di)
164            .collect()
165    }
166
167    /// Project `x` onto the axis-aligned box `[lo, hi]` component-wise.
168    pub fn onto_box(x: &[f64], lo: &[f64], hi: &[f64]) -> Vec<f64> {
169        x.iter()
170            .zip(lo.iter().zip(hi.iter()))
171            .map(|(&xi, (&li, &hi_i))| xi.max(li).min(hi_i))
172            .collect()
173    }
174
175    /// Project `x` onto the probability simplex Δ = {z ≥ 0 : sum z = 1}.
176    pub fn onto_simplex(x: &[f64]) -> Vec<f64> {
177        proj_simplex(x)
178    }
179}
180
181// ─── Proximal Operators ───────────────────────────────────────────────────────
182
183/// A proximal operator parameterised by a regularisation strength `lambda`.
184#[derive(Debug, Clone)]
185pub struct ProximalOperator {
186    /// Regularisation strength λ > 0.
187    pub lambda: f64,
188}
189
190impl ProximalOperator {
191    /// Create a new proximal operator with the given `lambda`.
192    pub fn new(lambda: f64) -> Self {
193        Self { lambda }
194    }
195
196    /// Apply the L1 (soft-thresholding) proximal operator: `prox_{λ|·|}(x)`.
197    pub fn prox_l1(&self, x: f64) -> f64 {
198        prox_l1(x, self.lambda)
199    }
200
201    /// Apply the L2-squared proximal operator: `prox_{λ||·||²/2}(x)`.
202    pub fn prox_l2_sq(&self, x: f64) -> f64 {
203        prox_l2_sq(x, self.lambda)
204    }
205
206    /// Apply the Huber proximal operator with `delta` parameter.
207    pub fn prox_huber_val(&self, x: f64, delta: f64) -> f64 {
208        prox_huber(x, self.lambda, delta)
209    }
210
211    /// Apply the indicator proximal operator for the box `[lo, hi]`.
212    ///
213    /// The proximal map of the indicator function of a convex set is just the
214    /// projection onto that set.
215    pub fn prox_box_indicator(&self, x: f64, lo: f64, hi: f64) -> f64 {
216        x.max(lo).min(hi)
217    }
218
219    /// Apply the L1 proximal operator elementwise to a vector.
220    pub fn prox_l1_vec(&self, x: &[f64]) -> Vec<f64> {
221        x.iter().map(|&xi| prox_l1(xi, self.lambda)).collect()
222    }
223
224    /// Apply the L2-squared proximal operator elementwise to a vector.
225    pub fn prox_l2_sq_vec(&self, x: &[f64]) -> Vec<f64> {
226        x.iter().map(|&xi| prox_l2_sq(xi, self.lambda)).collect()
227    }
228}
229
/// Soft-thresholding (shrinkage) operator, the proximal map of the L1 norm:
/// `prox_{λ|·|}(x) = sign(x) · max(|x| − λ, 0)`.
pub fn prox_l1(x: f64, lambda: f64) -> f64 {
    let shrunk = (x.abs() - lambda).max(0.0);
    shrunk * x.signum()
}
236
/// Proximal map of the squared L2 norm (ridge shrinkage): `x / (1 + λ)`.
pub fn prox_l2_sq(x: f64, lambda: f64) -> f64 {
    let denom = 1.0 + lambda;
    x / denom
}
241
/// Proximal map of the Huber loss with parameter `delta`.
///
/// The Huber loss is `x²/2` on `|x| ≤ delta` and `delta·(|x| − delta/2)`
/// outside; its prox shrinks multiplicatively in the quadratic zone and by a
/// constant offset in the linear zone.
pub fn prox_huber(x: f64, lambda: f64, delta: f64) -> f64 {
    // The quadratic zone of the prox extends to delta·(1 + λ).
    let quadratic_zone = delta * (1.0 + lambda);
    if x.abs() <= quadratic_zone {
        x / (1.0 + lambda)
    } else {
        x - x.signum() * lambda * delta
    }
}
255
256// ─── Subgradient Descent ─────────────────────────────────────────────────────
257
/// Configuration for subgradient descent.
#[derive(Debug, Clone)]
pub struct SubgradientMethod {
    /// Step size (learning rate).
    pub step_size: f64,
    /// Maximum number of iterations.
    pub max_iter: usize,
}

impl SubgradientMethod {
    /// Build a configuration from a step size and an iteration budget.
    pub fn new(step_size: f64, max_iter: usize) -> Self {
        SubgradientMethod { step_size, max_iter }
    }
}
276
/// Run `iters` steps of subgradient descent on `grad_f` starting at `x0`
/// with fixed step size `lr`.
///
/// Returns the final iterate.
pub fn subgradient_descent(
    grad_f: impl Fn(&[f64]) -> Vec<f64>,
    mut x0: Vec<f64>,
    lr: f64,
    iters: usize,
) -> Vec<f64> {
    for _ in 0..iters {
        let g = grad_f(&x0);
        // x ← x − lr·g, rebuilt as a fresh vector each step.
        x0 = x0
            .iter()
            .zip(g.iter())
            .map(|(xi, gi)| xi - lr * gi)
            .collect();
    }
    x0
}
295
296// ─── DualDecomposition ───────────────────────────────────────────────────────
297
/// Lagrangian relaxation with subgradient updates for dual decomposition.
///
/// Minimises `f(x) + g(z)` subject to `Ax + Bz = c` by relaxing the coupling
/// constraint; the dual variables `y` ascend along the dual subgradient
/// `Ax + Bz − c`.
pub struct DualDecomposition {
    /// Step size for dual variable updates.
    pub step_size: f64,
    /// Maximum number of dual iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the primal feasibility gap ‖Ax + Bz − c‖.
    pub tol: f64,
}

impl DualDecomposition {
    /// Create a new `DualDecomposition` solver.
    pub fn new(step_size: f64, max_iter: usize, tol: f64) -> Self {
        Self {
            step_size,
            max_iter,
            tol,
        }
    }

    /// Run dual decomposition.
    ///
    /// - `x_solve` / `z_solve`: Lagrangian minimisers for the two primal
    ///   blocks, each given the current block value and the duals `y`.
    /// - `mat_a`, `mat_b`, `c`: data of the coupling constraint `Ax + Bz = c`.
    ///
    /// Returns `(x_best, z_best, y_final, iterations)`, where the "best"
    /// iterates are those with the smallest feasibility gap seen so far.
    pub fn solve(
        &self,
        x_init: Vec<f64>,
        z_init: Vec<f64>,
        y_init: Vec<f64>,
        x_solve: &dyn Fn(&[f64], &[f64]) -> Vec<f64>,
        z_solve: &dyn Fn(&[f64], &[f64]) -> Vec<f64>,
        mat_a: &[Vec<f64>],
        mat_b: &[Vec<f64>],
        c: &[f64],
    ) -> (Vec<f64>, Vec<f64>, Vec<f64>, usize) {
        // Row/vector inner product, kept local so the method is self-contained.
        let row_dot = |row: &[f64], v: &[f64]| -> f64 {
            row.iter().zip(v.iter()).map(|(r, vi)| r * vi).sum()
        };

        let mut x = x_init;
        let mut z = z_init;
        let mut y = y_init;
        let (mut best_x, mut best_z) = (x.clone(), z.clone());
        let mut best_gap = f64::MAX;

        for iter in 0..self.max_iter {
            x = x_solve(&x, &y);
            z = z_solve(&z, &y);

            // Dual ascent along the feasibility residual Ax + Bz − c.
            let mut gap_sq = 0.0_f64;
            for (i, &ci) in c.iter().enumerate() {
                let residual = row_dot(&mat_a[i], &x) + row_dot(&mat_b[i], &z) - ci;
                gap_sq += residual * residual;
                y[i] += self.step_size * residual;
            }
            let gap = gap_sq.sqrt();

            if gap < best_gap {
                best_gap = gap;
                best_x = x.clone();
                best_z = z.clone();
            }
            if gap < self.tol {
                return (best_x, best_z, y, iter + 1);
            }
        }
        (best_x, best_z, y, self.max_iter)
    }
}
375
/// Multiply a matrix `a` (given as rows) by the vector `x`.
fn matvec(a: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    a.iter()
        .map(|row| row.iter().zip(x.iter()).map(|(r, xi)| r * xi).sum())
        .collect()
}
380
381// ─── ADMMSolver ──────────────────────────────────────────────────────────────
382
/// Result returned by the ADMM solver.
#[derive(Debug, Clone)]
pub struct AdmmResult {
    /// Primal variable `x`.
    pub x: Vec<f64>,
    /// Split variable `z`.
    pub z: Vec<f64>,
    /// Dual variable `u` (scaled form).
    pub u: Vec<f64>,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Whether the algorithm converged within tolerance.
    pub converged: bool,
    /// Final primal residual ‖x − z‖.
    pub primal_residual: f64,
}

/// Alternating Direction Method of Multipliers (ADMM) for consensus problems:
///
/// ```text
///   minimise   f(x) + g(z)
///   subject to x = z
/// ```
///
/// Uses the scaled-dual augmented Lagrangian with penalty parameter `rho`.
pub struct AdmmSolver {
    /// Augmented Lagrangian penalty ρ > 0.
    pub rho: f64,
    /// Maximum number of ADMM iterations.
    pub max_iter: usize,
    /// Convergence tolerance on primal and dual residuals.
    pub tol: f64,
}

impl AdmmSolver {
    /// Create a new ADMM solver.
    pub fn new(rho: f64, max_iter: usize, tol: f64) -> Self {
        Self { rho, max_iter, tol }
    }

    /// Solve the consensus ADMM problem.
    ///
    /// - `x_update`: solves `argmin_x f(x) + (ρ/2)||x − z + u||²`; receives
    ///   the previous `x`, the target `z − u`, and ρ.
    /// - `z_update`: solves `argmin_z g(z) + (ρ/2)||x − z + u||²`; receives
    ///   the previous `z`, the target `x + u`, and ρ.
    pub fn solve(
        &self,
        x_init: Vec<f64>,
        x_update: &dyn Fn(&[f64], &[f64], f64) -> Vec<f64>,
        z_update: &dyn Fn(&[f64], &[f64], f64) -> Vec<f64>,
    ) -> AdmmResult {
        let dim = x_init.len();
        let mut x = x_init;
        let mut z = vec![0.0_f64; dim];
        let mut u = vec![0.0_f64; dim];

        for iter in 0..self.max_iter {
            // x-minimisation against the target z − u.
            let target_x: Vec<f64> = z.iter().zip(&u).map(|(zi, ui)| zi - ui).collect();
            x = x_update(&x, &target_x, self.rho);

            // z-minimisation against the target x + u.
            let target_z: Vec<f64> = x.iter().zip(&u).map(|(xi, ui)| xi + ui).collect();
            let z_next = z_update(&z, &target_z, self.rho);

            // Scaled dual update and residual accumulation.
            let mut primal_sq = 0.0_f64;
            let mut dual_sq = 0.0_f64;
            for i in 0..dim {
                let r = x[i] - z_next[i];
                u[i] += r;
                primal_sq += r * r;
                dual_sq += self.rho * (z_next[i] - z[i]).powi(2);
            }
            z = z_next;

            let primal_res = primal_sq.sqrt();
            if primal_res < self.tol && dual_sq.sqrt() < self.tol {
                return AdmmResult {
                    x,
                    z,
                    u,
                    iterations: iter + 1,
                    converged: true,
                    primal_residual: primal_res,
                };
            }
        }

        // Budget exhausted: report the current primal residual ‖x − z‖.
        let primal_residual = x
            .iter()
            .zip(&z)
            .map(|(xi, zi)| (xi - zi) * (xi - zi))
            .sum::<f64>()
            .sqrt();
        AdmmResult {
            x,
            z,
            u,
            iterations: self.max_iter,
            converged: false,
            primal_residual,
        }
    }
}
488
489// ─── FrankWolfeOptimizer ──────────────────────────────────────────────────────
490
/// Result from the Frank-Wolfe algorithm.
#[derive(Debug, Clone)]
pub struct FrankWolfeResult {
    /// Approximate minimiser.
    pub x: Vec<f64>,
    /// Final objective value.
    pub objective: f64,
    /// Number of iterations.
    pub iterations: usize,
    /// Whether the duality gap dropped below tolerance.
    pub converged: bool,
}

/// Frank-Wolfe (conditional gradient) method for smooth convex minimisation
/// over a compact convex set, driven by a linear minimisation oracle.
pub struct FrankWolfeOptimizer {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Tolerance on the Frank-Wolfe duality gap.
    pub tol: f64,
}

impl FrankWolfeOptimizer {
    /// Create a new Frank-Wolfe optimiser.
    pub fn new(max_iter: usize, tol: f64) -> Self {
        Self { max_iter, tol }
    }

    /// Minimise `f` starting from the feasible point `x_init`.
    ///
    /// - `grad_f`: gradient oracle `∇f(x)`.
    /// - `lmo`: linear minimisation oracle `d ↦ argmin_{s ∈ C} d·s`.
    /// - `f_val`: objective oracle `f(x)`.
    pub fn minimize(
        &self,
        x_init: Vec<f64>,
        grad_f: &dyn Fn(&[f64]) -> Vec<f64>,
        lmo: &dyn Fn(&[f64]) -> Vec<f64>,
        f_val: &dyn Fn(&[f64]) -> f64,
    ) -> FrankWolfeResult {
        let dim = x_init.len();
        let mut x = x_init;

        for iter in 0..self.max_iter {
            let grad = grad_f(&x);
            let atom = lmo(&grad);

            // Frank-Wolfe duality gap: ∇f(x) · (x − s).
            let mut gap = 0.0_f64;
            for ((gi, xi), si) in grad.iter().zip(x.iter()).zip(atom.iter()) {
                gap += gi * (xi - si);
            }
            if gap < self.tol {
                return FrankWolfeResult {
                    objective: f_val(&x),
                    x,
                    iterations: iter + 1,
                    converged: true,
                };
            }

            // Classic diminishing step γ_t = 2 / (t + 2).
            let gamma = 2.0 / (iter as f64 + 2.0);
            for i in 0..dim {
                x[i] = (1.0 - gamma) * x[i] + gamma * atom[i];
            }
        }

        FrankWolfeResult {
            objective: f_val(&x),
            x,
            iterations: self.max_iter,
            converged: false,
        }
    }
}
572
573// ─── EllipsoidMethod ─────────────────────────────────────────────────────────
574
/// Result from the ellipsoid method.
#[derive(Debug, Clone)]
pub struct EllipsoidResult {
    /// Approximate centre of feasibility.
    pub x: Vec<f64>,
    /// Whether a feasible point was found (all constraints satisfied).
    pub feasible: bool,
    /// Number of iterations performed.
    pub iterations: usize,
}

/// Ellipsoid method for convex feasibility problems `g_i(x) ≤ 0`.
///
/// Starting from the ball `E(x0, R²·I)`, repeatedly performs a central cut
/// along a subgradient of the most-violated constraint until the centre is
/// feasible or the iteration budget is exhausted.
pub struct EllipsoidMethod {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Initial radius of the starting ball.
    pub initial_radius: f64,
    /// Tolerance for constraint satisfaction.
    pub tol: f64,
}

impl EllipsoidMethod {
    /// Create a new ellipsoid method solver.
    pub fn new(max_iter: usize, initial_radius: f64, tol: f64) -> Self {
        Self {
            max_iter,
            initial_radius,
            tol,
        }
    }

    /// Find a point satisfying every `(g, subgrad)` pair to within `tol`.
    ///
    /// Each element of `constraints` pairs the constraint value `g(x)` with a
    /// subgradient oracle.
    ///
    /// NOTE(review): the shape-matrix update divides by `n² − 1`, so the
    /// method implicitly assumes dimension `n ≥ 2` — confirm with callers.
    pub fn find_feasible(
        &self,
        x0: Vec<f64>,
        constraints: &[(Box<dyn Fn(&[f64]) -> f64>, Box<dyn Fn(&[f64]) -> Vec<f64>>)],
    ) -> EllipsoidResult {
        let n = x0.len();
        let mut centre = x0;

        // Shape matrix P of E = { x : (x−c)ᵀ P⁻¹ (x−c) ≤ 1 }, initially R²·I.
        let r_sq = self.initial_radius * self.initial_radius;
        let mut shape: Vec<Vec<f64>> = (0..n)
            .map(|i| (0..n).map(|j| if i == j { r_sq } else { 0.0 }).collect())
            .collect();

        for iter in 0..self.max_iter {
            if constraints.iter().all(|(g, _)| g(&centre) <= self.tol) {
                return EllipsoidResult {
                    x: centre,
                    feasible: true,
                    iterations: iter,
                };
            }

            // Pick the constraint with the largest violation above tol.
            let mut worst = 0usize;
            let mut worst_val = f64::NEG_INFINITY;
            for (i, (g, _)) in constraints.iter().enumerate() {
                let v = g(&centre);
                if v > self.tol && v > worst_val {
                    worst = i;
                    worst_val = v;
                }
            }

            let sg = constraints[worst].1(&centre);
            // p_sg = P·sg and the curvature term sgᵀ P sg, computed inline.
            let p_sg: Vec<f64> = shape
                .iter()
                .map(|row| row.iter().zip(sg.iter()).map(|(r, s)| r * s).sum())
                .collect();
            let curvature: f64 = sg.iter().zip(p_sg.iter()).map(|(s, p)| s * p).sum();
            if curvature < 1e-20 {
                break; // Degenerate cut direction; no further progress possible.
            }
            let norm = curvature.sqrt();
            let g_hat: Vec<f64> = p_sg.iter().map(|v| v / norm).collect();

            let nf = n as f64;
            // Centre step: c ← c − g_hat / (n + 1).
            for (ci, gi) in centre.iter_mut().zip(g_hat.iter()) {
                *ci -= gi / (nf + 1.0);
            }
            // Shape update: P ← n²/(n²−1) · (P − 2/(n+1) · g_hat g_hatᵀ).
            let blow_up = nf * nf / (nf * nf - 1.0);
            let rank1 = 2.0 / (nf + 1.0);
            for i in 0..n {
                for j in 0..n {
                    shape[i][j] = blow_up * (shape[i][j] - rank1 * g_hat[i] * g_hat[j]);
                }
            }
        }

        let feasible = constraints.iter().all(|(g, _)| g(&centre) <= self.tol);
        EllipsoidResult {
            x: centre,
            feasible,
            iterations: self.max_iter,
        }
    }
}
686
687// ─── Legendre–Fenchel Conjugate ───────────────────────────────────────────────
688
/// Legendre–Fenchel conjugate `f*(y) = sup_{x ∈ domain} (y·x − f(x))`,
/// approximated by a grid search over `n` equally-spaced points on
/// `domain = (lo, hi)`.
///
/// # Panics
/// Panics if `domain.0 >= domain.1` or `n == 0`.
pub fn conjugate_function_1d(f: impl Fn(f64) -> f64, y: f64, domain: (f64, f64), n: usize) -> f64 {
    assert!(domain.0 < domain.1, "domain must be non-empty");
    assert!(n > 0, "n must be at least 1");
    let (lo, hi) = domain;
    // For n == 1 the denominator clamps to 1, so only x = lo is sampled.
    let step = (hi - lo) / (n as f64 - 1.0).max(1.0);
    let mut best = f64::NEG_INFINITY;
    for i in 0..n {
        let x = lo + i as f64 * step;
        best = best.max(y * x - f(x));
    }
    best
}
706
707// ─── Support Function ─────────────────────────────────────────────────────────
708
/// Support function `h_C(d) = max_{v ∈ C} d·v` of a planar convex polygon
/// given by its vertex list.
#[derive(Debug, Clone)]
pub struct SupportFunction {
    /// Vertices of the convex set.
    pub vertices: Vec<[f64; 2]>,
}

impl SupportFunction {
    /// Construct from a list of vertices.
    pub fn new(vertices: Vec<[f64; 2]>) -> Self {
        SupportFunction { vertices }
    }

    /// Maximum of `v · direction` over all vertices.
    ///
    /// Returns `f64::NEG_INFINITY` for an empty vertex set.
    pub fn evaluate(&self, direction: [f64; 2]) -> f64 {
        let mut best = f64::NEG_INFINITY;
        for v in &self.vertices {
            best = best.max(v[0] * direction[0] + v[1] * direction[1]);
        }
        best
    }
}
734
735// ─── Dykstra's Alternating Projections ───────────────────────────────────────
736
/// Project `x` onto the intersection of convex sets with Dykstra's algorithm.
///
/// Each element of `sets` is a projection operator `π_i: ℝⁿ → ℝⁿ`.  The
/// per-set correction terms are what distinguish Dykstra's method from plain
/// alternating projections and make it converge to the true projection.
/// Runs `iters` full sweeps over all sets.
pub fn dykstra_projection(
    x: &[f64],
    sets: &[&dyn Fn(&[f64]) -> Vec<f64>],
    iters: usize,
) -> Vec<f64> {
    if sets.is_empty() {
        return x.to_vec();
    }
    let dim = x.len();
    let mut current = x.to_vec();
    // One correction vector per set, all starting at zero.
    let mut corrections = vec![vec![0.0_f64; dim]; sets.len()];

    for _ in 0..iters {
        for (correction, proj) in corrections.iter_mut().zip(sets.iter()) {
            // Re-add this set's correction before projecting.
            let shifted: Vec<f64> = current
                .iter()
                .zip(correction.iter())
                .map(|(ci, di)| ci + di)
                .collect();
            let projected = proj(&shifted);
            for k in 0..dim {
                correction[k] = shifted[k] - projected[k];
            }
            current = projected;
        }
    }
    current
}
770
771// ─── Box Projection ───────────────────────────────────────────────────────────
772
/// Component-wise projection of `x` onto the axis-aligned box `[lo, hi]`.
///
/// # Panics
/// Panics if `lo`, `hi`, and `x` have different lengths.
pub fn proj_box(x: &[f64], lo: &[f64], hi: &[f64]) -> Vec<f64> {
    assert_eq!(x.len(), lo.len());
    assert_eq!(x.len(), hi.len());
    let mut out = Vec::with_capacity(x.len());
    for i in 0..x.len() {
        out.push(x[i].max(lo[i]).min(hi[i]));
    }
    out
}
785
786// ─── Simplex Projection ───────────────────────────────────────────────────────
787
/// Euclidean projection of `x` onto the probability simplex
/// Δ = {z : z ≥ 0, Σ z_i = 1}.
///
/// Implements the O(n log n) sort-based algorithm of Duchi et al. (2008).
pub fn proj_simplex(x: &[f64]) -> Vec<f64> {
    if x.is_empty() {
        return Vec::new();
    }
    // Sort a copy in descending order; NaN comparisons fall back to Equal.
    let mut desc = x.to_vec();
    desc.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));

    // rho = largest index i with desc[i] − (prefix_sum(i) − 1)/(i + 1) > 0.
    let mut prefix = 0.0_f64;
    let mut rho = 0usize;
    for (i, &v) in desc.iter().enumerate() {
        prefix += v;
        if v - (prefix - 1.0) / (i as f64 + 1.0) > 0.0 {
            rho = i;
        }
    }

    // Threshold theta chosen so the clipped coordinates sum to one.
    let head: f64 = desc[..=rho].iter().sum();
    let theta = (head - 1.0) / (rho as f64 + 1.0);
    x.iter().map(|&v| (v - theta).max(0.0)).collect()
}
813
814// ─── Moreau Envelope ─────────────────────────────────────────────────────────
815
/// Moreau envelope (inf-convolution) smoothing of a convex function.
///
/// `M_μf(x) = inf_z [f(z) + ||x − z||²/(2μ)]` is a smooth approximation of
/// `f`.
///
/// NOTE(review): the type name `MorseEnvelope` looks like a misspelling of
/// "MoreauEnvelope"; kept as-is because renaming would break callers.
#[derive(Debug, Clone)]
pub struct MorseEnvelope {
    /// Smoothing parameter μ > 0.
    pub mu: f64,
}

impl MorseEnvelope {
    /// Create a Moreau-envelope smoother with parameter `mu`.
    pub fn new(mu: f64) -> Self {
        MorseEnvelope { mu }
    }

    /// Moreau envelope of |·| at `x`: quadratic `x²/(2μ)` near the origin,
    /// linear `|x| − μ/2` once `|x| ≥ μ`.
    pub fn eval_l1(&self, x: f64) -> f64 {
        let a = x.abs();
        if a >= self.mu {
            a - self.mu / 2.0
        } else {
            x * x / (2.0 * self.mu)
        }
    }
}
837
/// Moreau envelope of the L1 norm: a smooth approximation to `|x|`.
///
/// Equals `x²/(2μ)` on `|x| < μ` and `|x| − μ/2` elsewhere (Huber-shaped).
pub fn moreau_envelope_l1(x: f64, mu: f64) -> f64 {
    match x.abs() {
        a if a >= mu => a - mu / 2.0,
        _ => x * x / (2.0 * mu),
    }
}
848
849// ─── Tests ───────────────────────────────────────────────────────────────────
850
851#[cfg(test)]
852mod tests {
853    use super::*;
854
855    // --- is_convex_1d ---
856
857    #[test]
858    fn test_convex_quadratic() {
859        let f: Vec<f64> = (-10..=10).map(|i| (i as f64).powi(2)).collect();
860        assert!(is_convex_1d(&f, 1e-9));
861    }
862
863    #[test]
864    fn test_non_convex_neg_quadratic() {
865        let f: Vec<f64> = (-10..=10).map(|i| -(i as f64).powi(2)).collect();
866        assert!(!is_convex_1d(&f, 1e-9));
867    }
868
869    #[test]
870    fn test_convex_empty() {
871        assert!(is_convex_1d(&[], 1e-9));
872    }
873
874    #[test]
875    fn test_convex_single() {
876        assert!(is_convex_1d(&[3.0], 1e-9));
877    }
878
879    #[test]
880    fn test_convex_linear() {
881        let f: Vec<f64> = (0..20).map(|i| 3.0 * i as f64 + 1.0).collect();
882        assert!(is_convex_1d(&f, 1e-9));
883    }
884
885    // --- ConvexSet::contains ---
886
887    #[test]
888    fn test_convex_set_hyperplane_contains() {
889        let hs = ConvexSet::Hyperplane {
890            a: vec![1.0, 0.0],
891            b: 3.0,
892        };
893        assert!(hs.contains(&[3.0, 5.0]));
894        assert!(!hs.contains(&[2.0, 5.0]));
895    }
896
897    #[test]
898    fn test_convex_set_halfspace_contains() {
899        let hs = ConvexSet::Halfspace {
900            a: vec![1.0, 1.0],
901            b: 2.0,
902        };
903        assert!(hs.contains(&[0.5, 0.5]));
904        assert!(!hs.contains(&[2.0, 2.0]));
905    }
906
907    #[test]
908    fn test_convex_set_ball_contains() {
909        let ball = ConvexSet::Ball {
910            center: vec![0.0, 0.0],
911            radius: 1.0,
912        };
913        assert!(ball.contains(&[0.5, 0.5]));
914        assert!(!ball.contains(&[1.0, 1.0]));
915    }
916
917    #[test]
918    fn test_convex_set_polytope_contains() {
919        // Unit cube [0,1]²
920        let poly = ConvexSet::Polytope {
921            a_rows: vec![
922                vec![1.0, 0.0],
923                vec![-1.0, 0.0],
924                vec![0.0, 1.0],
925                vec![0.0, -1.0],
926            ],
927            b_vec: vec![1.0, 0.0, 1.0, 0.0],
928        };
929        assert!(poly.contains(&[0.5, 0.5]));
930        assert!(!poly.contains(&[1.5, 0.5]));
931    }
932
933    // --- ProjectionOperator ---
934
935    #[test]
936    fn test_proj_onto_hyperplane() {
937        // Hyperplane: x + y = 1; project (0, 0)
938        let a = vec![1.0_f64, 1.0];
939        let p = ProjectionOperator::onto_hyperplane(&[0.0, 0.0], &a, 1.0);
940        let val: f64 = p[0] + p[1];
941        assert!(
942            (val - 1.0).abs() < 1e-10,
943            "projected point must satisfy a·x = b"
944        );
945    }
946
947    #[test]
948    fn test_proj_onto_halfspace_inside() {
949        let a = vec![1.0_f64, 0.0];
950        let x = vec![0.5_f64, 1.0];
951        let p = ProjectionOperator::onto_halfspace(&x, &a, 1.0);
952        assert!((p[0] - 0.5).abs() < 1e-12); // already inside
953    }
954
955    #[test]
956    fn test_proj_onto_halfspace_outside() {
957        let a = vec![1.0_f64, 0.0];
958        let x = vec![2.0_f64, 1.0];
959        let p = ProjectionOperator::onto_halfspace(&x, &a, 1.0);
960        assert!((p[0] - 1.0).abs() < 1e-10); // projected to boundary
961    }
962
963    #[test]
964    fn test_proj_onto_ball_inside() {
965        let x = [0.1_f64, 0.1];
966        let p = ProjectionOperator::onto_ball(&x, &[0.0, 0.0], 1.0);
967        assert!((p[0] - 0.1).abs() < 1e-12);
968        assert!((p[1] - 0.1).abs() < 1e-12);
969    }
970
971    #[test]
972    fn test_proj_onto_ball_outside() {
973        let x = [3.0_f64, 4.0];
974        let p = ProjectionOperator::onto_ball(&x, &[0.0, 0.0], 1.0);
975        let dist: f64 = (p[0].powi(2) + p[1].powi(2)).sqrt();
976        assert!(
977            (dist - 1.0).abs() < 1e-10,
978            "projected point must lie on the ball boundary"
979        );
980    }
981
982    #[test]
983    fn test_proj_onto_box() {
984        let x = vec![-1.0_f64, 0.5, 3.0];
985        let lo = vec![0.0_f64, 0.0, 0.0];
986        let hi = vec![1.0_f64, 1.0, 1.0];
987        let p = ProjectionOperator::onto_box(&x, &lo, &hi);
988        assert_eq!(p, vec![0.0, 0.5, 1.0]);
989    }
990
991    #[test]
992    fn test_proj_onto_simplex_via_operator() {
993        let x = vec![3.0_f64, 1.0, -1.0, 0.5];
994        let p = ProjectionOperator::onto_simplex(&x);
995        let sum: f64 = p.iter().sum();
996        assert!((sum - 1.0).abs() < 1e-10);
997    }
998
999    // --- prox_l1 ---
1000
1001    #[test]
1002    fn test_prox_l1_positive_above_threshold() {
1003        assert!((prox_l1(3.0, 1.0) - 2.0).abs() < 1e-12);
1004    }
1005
1006    #[test]
1007    fn test_prox_l1_negative_above_threshold() {
1008        assert!((prox_l1(-3.0, 1.0) + 2.0).abs() < 1e-12);
1009    }
1010
1011    #[test]
1012    fn test_prox_l1_below_threshold() {
1013        assert_eq!(prox_l1(0.5, 1.0), 0.0);
1014    }
1015
1016    #[test]
1017    fn test_prox_l1_at_threshold() {
1018        assert_eq!(prox_l1(1.0, 1.0), 0.0);
1019    }
1020
1021    #[test]
1022    fn test_prox_l1_zero() {
1023        assert_eq!(prox_l1(0.0, 0.5), 0.0);
1024    }
1025
1026    // --- prox_l2_sq ---
1027
1028    #[test]
1029    fn test_prox_l2_sq_basic() {
1030        assert!((prox_l2_sq(3.0, 2.0) - 1.0).abs() < 1e-12);
1031    }
1032
1033    #[test]
1034    fn test_prox_l2_sq_lambda_zero() {
1035        assert!((prox_l2_sq(5.0, 0.0) - 5.0).abs() < 1e-12);
1036    }
1037
1038    #[test]
1039    fn test_prox_l2_sq_negative() {
1040        assert!((prox_l2_sq(-4.0, 1.0) + 2.0).abs() < 1e-12);
1041    }
1042
1043    // --- prox_huber ---
1044
1045    #[test]
1046    fn test_prox_huber_small_x() {
1047        let val = prox_huber(0.5, 1.0, 1.0);
1048        assert!((val - 0.25).abs() < 1e-12, "val = {}", val);
1049    }
1050
1051    #[test]
1052    fn test_prox_huber_large_x() {
1053        let val = prox_huber(5.0, 1.0, 1.0);
1054        assert!((val - 4.0).abs() < 1e-12, "val = {}", val);
1055    }
1056
1057    // --- ProximalOperator struct ---
1058
1059    #[test]
1060    fn test_proximal_operator_l1() {
1061        let po = ProximalOperator::new(2.0);
1062        assert!((po.prox_l1(5.0) - 3.0).abs() < 1e-12);
1063    }
1064
1065    #[test]
1066    fn test_proximal_operator_l2_sq() {
1067        let po = ProximalOperator::new(1.0);
1068        assert!((po.prox_l2_sq(4.0) - 2.0).abs() < 1e-12);
1069    }
1070
1071    #[test]
1072    fn test_proximal_operator_box_indicator() {
1073        let po = ProximalOperator::new(1.0);
1074        assert!((po.prox_box_indicator(3.0, 0.0, 2.0) - 2.0).abs() < 1e-12);
1075        assert!((po.prox_box_indicator(-1.0, 0.0, 2.0) - 0.0).abs() < 1e-12);
1076    }
1077
1078    #[test]
1079    fn test_proximal_operator_l1_vec() {
1080        let po = ProximalOperator::new(1.0);
1081        let x = vec![3.0, -3.0, 0.5, 0.0];
1082        let p = po.prox_l1_vec(&x);
1083        assert!((p[0] - 2.0).abs() < 1e-12);
1084        assert!((p[1] + 2.0).abs() < 1e-12);
1085        assert_eq!(p[2], 0.0);
1086        assert_eq!(p[3], 0.0);
1087    }
1088
1089    // --- subgradient_descent ---
1090
1091    #[test]
1092    fn test_subgradient_descent_quadratic() {
1093        let result = subgradient_descent(|x| vec![2.0 * (x[0] - 3.0)], vec![0.0], 0.1, 500);
1094        assert!((result[0] - 3.0).abs() < 0.2, "result = {}", result[0]);
1095    }
1096
1097    #[test]
1098    fn test_subgradient_descent_zero_iters() {
1099        let result = subgradient_descent(|_| vec![1.0], vec![5.0], 0.1, 0);
1100        assert!((result[0] - 5.0).abs() < 1e-12);
1101    }
1102
1103    // --- ADMM solver ---
1104
1105    #[test]
1106    fn test_admm_lasso_converges() {
1107        // Minimise (1/2)||x - v||² + lambda||z||_1  s.t. x = z
1108        // x-update: argmin_x (1/2)||x-v||² + (rho/2)||x - (z-u)||²
1109        //         = (v + rho*(z-u)) / (1 + rho)
1110        // z-update: prox_{lambda/rho}(x + u) = soft-threshold
1111        // Closed-form solution: x* = z* = soft_threshold(v, lambda)
1112        let v = vec![3.0_f64, -2.0];
1113        let rho = 1.0_f64;
1114        let lambda = 0.5_f64;
1115        let v_clone1 = v.clone();
1116        let v_clone2 = v.clone();
1117        let admm = AdmmSolver::new(rho, 500, 1e-6);
1118        let x_upd = move |_x: &[f64], z_minus_u: &[f64], rho_val: f64| -> Vec<f64> {
1119            v_clone1
1120                .iter()
1121                .zip(z_minus_u.iter())
1122                .map(|(&vi, &zui)| (vi + rho_val * zui) / (1.0 + rho_val))
1123                .collect()
1124        };
1125        let z_upd = move |_z: &[f64], x_plus_u: &[f64], rho_val: f64| -> Vec<f64> {
1126            let _ = v_clone2;
1127            x_plus_u
1128                .iter()
1129                .map(|&xi| prox_l1(xi, lambda / rho_val))
1130                .collect()
1131        };
1132        let x_init = vec![0.0_f64, 0.0];
1133        let result = admm.solve(x_init, &x_upd, &z_upd);
1134        // Closed-form solution: soft_threshold(v, lambda)
1135        let expected: Vec<f64> = v.iter().map(|&vi| prox_l1(vi, lambda)).collect();
1136        for i in 0..2 {
1137            assert!(
1138                (result.x[i] - expected[i]).abs() < 0.1,
1139                "x[{}] = {}, expected {}",
1140                i,
1141                result.x[i],
1142                expected[i]
1143            );
1144        }
1145    }
1146
1147    #[test]
1148    fn test_admm_result_fields() {
1149        let admm = AdmmSolver::new(1.0, 10, 1e-8);
1150        let x_upd = |_x: &[f64], z_u: &[f64], _rho: f64| z_u.to_vec();
1151        let z_upd = |_z: &[f64], x_u: &[f64], _rho: f64| x_u.to_vec();
1152        let res = admm.solve(vec![0.0], &x_upd, &z_upd);
1153        assert!(res.iterations > 0);
1154    }
1155
1156    // --- Frank-Wolfe on simplex ---
1157
1158    #[test]
1159    fn test_frank_wolfe_simplex_quadratic() {
1160        // Minimise ||x - c||² over the simplex, c = [0.3, 0.7]
1161        let c = [0.3_f64, 0.7];
1162        let grad_f = |x: &[f64]| -> Vec<f64> {
1163            x.iter()
1164                .zip(c.iter())
1165                .map(|(xi, ci)| 2.0 * (xi - ci))
1166                .collect()
1167        };
1168        let lmo = |d: &[f64]| -> Vec<f64> {
1169            // argmin d·s over simplex = standard basis vector for argmin d_i
1170            let n = d.len();
1171            let min_idx = d
1172                .iter()
1173                .enumerate()
1174                .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
1175                .map(|(i, _)| i)
1176                .unwrap_or(0);
1177            let mut s = vec![0.0_f64; n];
1178            s[min_idx] = 1.0;
1179            s
1180        };
1181        let f_val = |x: &[f64]| -> f64 {
1182            x.iter()
1183                .zip(c.iter())
1184                .map(|(xi, ci)| (xi - ci).powi(2))
1185                .sum()
1186        };
1187        let fw = FrankWolfeOptimizer::new(500, 1e-6);
1188        let x0 = vec![0.5_f64, 0.5];
1189        let result = fw.minimize(x0, &grad_f, &lmo, &f_val);
1190        // Optimal point on simplex closest to c = (0.3, 0.7) is exactly c (it's already on simplex)
1191        let sum: f64 = result.x.iter().sum();
1192        assert!((sum - 1.0).abs() < 0.01, "sum = {}", sum);
1193        for xi in &result.x {
1194            assert!(*xi >= -1e-10, "simplex component negative");
1195        }
1196    }
1197
1198    #[test]
1199    fn test_frank_wolfe_result_converged_flag() {
1200        let grad_f = |x: &[f64]| vec![2.0 * x[0]];
1201        let lmo = |_d: &[f64]| vec![0.0_f64];
1202        let f_val = |x: &[f64]| x[0].powi(2);
1203        let fw = FrankWolfeOptimizer::new(1000, 1e-6);
1204        let result = fw.minimize(vec![1.0], &grad_f, &lmo, &f_val);
1205        let _ = result.converged; // field exists and is accessible
1206        assert!(result.objective >= 0.0);
1207    }
1208
1209    // --- EllipsoidMethod ---
1210
1211    #[test]
1212    fn test_ellipsoid_ball_feasibility() {
1213        // Find x ∈ ℝ² with ||x - (1,1)||² ≤ 0.5 and x[0] ≥ 0
1214        let constraint1: Box<dyn Fn(&[f64]) -> f64> =
1215            Box::new(|x: &[f64]| (x[0] - 1.0).powi(2) + (x[1] - 1.0).powi(2) - 0.5);
1216        let sg1: Box<dyn Fn(&[f64]) -> Vec<f64>> =
1217            Box::new(|x: &[f64]| vec![2.0 * (x[0] - 1.0), 2.0 * (x[1] - 1.0)]);
1218        let constraint2: Box<dyn Fn(&[f64]) -> f64> = Box::new(|x: &[f64]| -x[0]);
1219        let sg2: Box<dyn Fn(&[f64]) -> Vec<f64>> = Box::new(|_x: &[f64]| vec![-1.0, 0.0]);
1220        let ellipsoid = EllipsoidMethod::new(200, 5.0, 1e-4);
1221        let result =
1222            ellipsoid.find_feasible(vec![0.0, 0.0], &[(constraint1, sg1), (constraint2, sg2)]);
1223        assert!(
1224            result.feasible,
1225            "ellipsoid method should find feasible point"
1226        );
1227    }
1228
1229    #[test]
1230    fn test_ellipsoid_result_fields() {
1231        let c: Box<dyn Fn(&[f64]) -> f64> = Box::new(|x: &[f64]| x[0] - 1.0);
1232        let sg: Box<dyn Fn(&[f64]) -> Vec<f64>> = Box::new(|_x: &[f64]| vec![1.0]);
1233        let em = EllipsoidMethod::new(50, 10.0, 1e-4);
1234        let result = em.find_feasible(vec![0.5], &[(c, sg)]);
1235        assert!(result.iterations > 0 || result.feasible);
1236    }
1237
1238    // --- conjugate_function_1d ---
1239
1240    #[test]
1241    fn test_conjugate_x_squared_half() {
1242        let y = 2.0_f64;
1243        let conj = conjugate_function_1d(|x| x * x / 2.0, y, (-10.0, 10.0), 10_000);
1244        assert!((conj - y * y / 2.0).abs() < 0.01, "conj = {}", conj);
1245    }
1246
1247    #[test]
1248    fn test_conjugate_abs() {
1249        let conj = conjugate_function_1d(|x| x.abs(), 0.5, (-5.0, 5.0), 1000);
1250        assert!(conj >= -0.01, "f*(0.5) should be ≥ 0, got {}", conj);
1251    }
1252
1253    // --- SupportFunction ---
1254
1255    #[test]
1256    fn test_support_function_unit_square() {
1257        let verts = vec![[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]];
1258        let sf = SupportFunction::new(verts);
1259        assert!((sf.evaluate([1.0, 0.0]) - 1.0).abs() < 1e-12);
1260        assert!((sf.evaluate([1.0, 1.0]) - 2.0).abs() < 1e-12);
1261    }
1262
1263    #[test]
1264    fn test_support_function_empty() {
1265        let sf = SupportFunction::new(vec![]);
1266        assert!(sf.evaluate([1.0, 0.0]).is_infinite());
1267    }
1268
1269    // --- proj_box ---
1270
1271    #[test]
1272    fn test_proj_box_clamps() {
1273        let x = vec![-1.0, 0.5, 3.0];
1274        let lo = vec![0.0, 0.0, 0.0];
1275        let hi = vec![1.0, 1.0, 1.0];
1276        let p = proj_box(&x, &lo, &hi);
1277        assert_eq!(p, vec![0.0, 0.5, 1.0]);
1278    }
1279
1280    #[test]
1281    fn test_proj_box_no_clamp() {
1282        let x = vec![0.3, 0.7];
1283        let lo = vec![0.0, 0.0];
1284        let hi = vec![1.0, 1.0];
1285        let p = proj_box(&x, &lo, &hi);
1286        assert!((p[0] - 0.3).abs() < 1e-12);
1287        assert!((p[1] - 0.7).abs() < 1e-12);
1288    }
1289
1290    // --- proj_simplex ---
1291
1292    #[test]
1293    fn test_proj_simplex_sums_to_one() {
1294        let x = vec![3.0, 1.0, -1.0, 0.5];
1295        let p = proj_simplex(&x);
1296        let sum: f64 = p.iter().sum();
1297        assert!((sum - 1.0).abs() < 1e-10, "sum = {}", sum);
1298    }
1299
1300    #[test]
1301    fn test_proj_simplex_non_negative() {
1302        let x = vec![-5.0, -3.0, -1.0];
1303        let p = proj_simplex(&x);
1304        for &pi in &p {
1305            assert!(pi >= -1e-12, "component {} is negative", pi);
1306        }
1307    }
1308
1309    #[test]
1310    fn test_proj_simplex_already_on_simplex() {
1311        let x = vec![0.25, 0.25, 0.25, 0.25];
1312        let p = proj_simplex(&x);
1313        let sum: f64 = p.iter().sum();
1314        assert!((sum - 1.0).abs() < 1e-10);
1315        for (&xi, &pi) in x.iter().zip(p.iter()) {
1316            assert!((xi - pi).abs() < 1e-10);
1317        }
1318    }
1319
1320    #[test]
1321    fn test_proj_simplex_empty() {
1322        let p = proj_simplex(&[]);
1323        assert!(p.is_empty());
1324    }
1325
1326    #[test]
1327    fn test_proj_simplex_single() {
1328        let p = proj_simplex(&[3.0]);
1329        assert!((p[0] - 1.0).abs() < 1e-10);
1330    }
1331
1332    // --- dykstra_projection ---
1333
1334    #[test]
1335    fn test_dykstra_no_sets() {
1336        let x = vec![1.0, 2.0, 3.0];
1337        let result = dykstra_projection(&x, &[], 10);
1338        assert_eq!(result, x);
1339    }
1340
1341    #[test]
1342    fn test_dykstra_box_projection() {
1343        let x = vec![-0.5, 1.5];
1344        let proj_box_fn =
1345            |z: &[f64]| -> Vec<f64> { z.iter().map(|&xi| xi.clamp(0.0, 1.0)).collect() };
1346        let sets: Vec<&dyn Fn(&[f64]) -> Vec<f64>> = vec![&proj_box_fn];
1347        let result = dykstra_projection(&x, &sets, 5);
1348        assert!((result[0] - 0.0).abs() < 1e-10);
1349        assert!((result[1] - 1.0).abs() < 1e-10);
1350    }
1351
1352    // --- moreau_envelope_l1 ---
1353
1354    #[test]
1355    fn test_moreau_envelope_l1_large_x() {
1356        let val = moreau_envelope_l1(2.0, 0.5);
1357        assert!((val - (2.0 - 0.25)).abs() < 1e-12, "val = {}", val);
1358    }
1359
1360    #[test]
1361    fn test_moreau_envelope_l1_small_x() {
1362        let val = moreau_envelope_l1(0.5, 2.0);
1363        assert!((val - (0.25 / 4.0)).abs() < 1e-12, "val = {}", val);
1364    }
1365
1366    #[test]
1367    fn test_moreau_envelope_l1_non_negative() {
1368        for x in [-3.0, -1.0, 0.0, 0.5, 2.0] {
1369            let val = moreau_envelope_l1(x, 1.0);
1370            assert!(
1371                val >= 0.0,
1372                "Moreau envelope of |·| should be ≥ 0, got {}",
1373                val
1374            );
1375        }
1376    }
1377
1378    #[test]
1379    fn test_moreau_envelope_l1_continuous_at_mu() {
1380        let mu = 1.5;
1381        let from_above = moreau_envelope_l1(mu + 1e-9, mu);
1382        let from_inside = moreau_envelope_l1(mu - 1e-9, mu);
1383        assert!(
1384            (from_above - from_inside).abs() < 1e-6,
1385            "discontinuity at mu"
1386        );
1387    }
1388
1389    #[test]
1390    fn test_morse_envelope_struct() {
1391        let me = MorseEnvelope::new(1.0);
1392        let v = me.eval_l1(2.0);
1393        assert!((v - 1.5).abs() < 1e-12);
1394    }
1395
1396    // --- SubgradientMethod struct ---
1397
1398    #[test]
1399    fn test_subgradient_method_struct() {
1400        let sm = SubgradientMethod::new(0.01, 100);
1401        assert!((sm.step_size - 0.01).abs() < 1e-12);
1402        assert_eq!(sm.max_iter, 100);
1403    }
1404
1405    // --- DualDecomposition ---
1406
1407    #[test]
1408    fn test_dual_decomposition_basic() {
1409        // Simple consensus: min x² + z², s.t. x = z (A=I, B=-I, c=0)
1410        // Optimal: x = z = 0
1411        let dd = DualDecomposition::new(0.1, 200, 1e-4);
1412        let x_solve = |_x: &[f64], y: &[f64]| -> Vec<f64> {
1413            // argmin_x x² + y*x = -y/2
1414            vec![-y[0] / 2.0]
1415        };
1416        let z_solve = |_z: &[f64], y: &[f64]| -> Vec<f64> {
1417            // argmin_z z² - y*z = y/2
1418            vec![y[0] / 2.0]
1419        };
1420        let mat_a = vec![vec![1.0_f64]];
1421        let mat_b = vec![vec![-1.0_f64]];
1422        let c = vec![0.0_f64];
1423        let (x, z, _y, _iters) = dd.solve(
1424            vec![1.0],
1425            vec![1.0],
1426            vec![0.0],
1427            &x_solve,
1428            &z_solve,
1429            &mat_a,
1430            &mat_b,
1431            &c,
1432        );
1433        // Best iterate should have small constraint violation |x - z|
1434        assert!(
1435            (x[0] - z[0]).abs() < 0.5,
1436            "|x - z| = {}",
1437            (x[0] - z[0]).abs()
1438        );
1439    }
1440}