// cjc_runtime/optimize.rs

//! Optimization & Root Finding — deterministic numerical solvers.
//!
//! # Determinism Contract
//!
//! All algorithms in this module are fully deterministic: given the same inputs
//! and the same objective/gradient functions, they produce **bit-identical** results
//! across runs. Floating-point reductions use `binned_sum_f64` from the accumulator
//! module to avoid ordering-dependent rounding.
//!
//! # Scalar Root Finding
//!
//! - [`bisect`] — bisection method (guaranteed convergence for bracketed roots)
//! - [`brentq`] — Brent's method (IQI + bisection fallback, superlinear convergence)
//! - [`newton_scalar`] — Newton-Raphson (quadratic convergence near a simple root)
//! - [`secant`] — secant method (superlinear convergence, no derivative needed)
//!
//! # Unconstrained Optimization
//!
//! - [`minimize_gd`] — gradient descent with a fixed learning rate
//! - [`minimize_bfgs`] — BFGS quasi-Newton with Armijo line search
//! - [`minimize_lbfgs`] — limited-memory BFGS keeping `m` correction pairs
//! - [`minimize_nelder_mead`] — Nelder-Mead simplex (derivative-free)
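//!
//! # Example
//!
//! A minimal usage sketch; the module path `cjc_runtime::optimize` is assumed
//! from the repository layout:
//!
//! ```
//! use cjc_runtime::optimize::{bisect, minimize_gd};
//!
//! // Root of x^2 - 2 bracketed on [1, 2].
//! let root = bisect(&|x: f64| x * x - 2.0, 1.0, 2.0, 1e-12, 100).unwrap();
//! assert!((root - std::f64::consts::SQRT_2).abs() < 1e-10);
//!
//! // Minimize x^2 + y^2 from (5, -3) by gradient descent.
//! let res = minimize_gd(
//!     &|x: &[f64]| x[0] * x[0] + x[1] * x[1],
//!     &|x: &[f64]| vec![2.0 * x[0], 2.0 * x[1]],
//!     &[5.0, -3.0],
//!     0.1,
//!     1000,
//!     1e-8,
//! );
//! assert!(res.converged);
//! ```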

use crate::accumulator::binned_sum_f64;

// ---------------------------------------------------------------------------
// Result type for vector optimizers
// ---------------------------------------------------------------------------

/// Result of an unconstrained optimization run.
#[derive(Debug, Clone)]
pub struct OptResult {
    /// Optimal point found.
    pub x: Vec<f64>,
    /// Objective function value at `x`.
    pub fun: f64,
    /// Number of iterations performed.
    pub niter: usize,
    /// Whether the solver met the requested tolerance.
    pub converged: bool,
}

// ===========================================================================
// Scalar Root Finding
// ===========================================================================

/// Bisection method for scalar root finding.
///
/// Finds `x` in `[a, b]` such that `|f(x)| < tol`, given that `f(a)` and
/// `f(b)` have opposite signs. Returns `Err` if the bracket is invalid.
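///
/// # Example
///
/// A minimal sketch; the path `cjc_runtime::optimize` is assumed:
///
/// ```
/// use cjc_runtime::optimize::bisect;
///
/// let f = |x: f64| x * x - 2.0;
/// let root = bisect(&f, 1.0, 2.0, 1e-12, 100).unwrap();
/// assert!((root - std::f64::consts::SQRT_2).abs() < 1e-10);
/// ```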
pub fn bisect(
    f: &dyn Fn(f64) -> f64,
    a: f64,
    b: f64,
    tol: f64,
    max_iter: usize,
) -> Result<f64, String> {
    let mut lo = a;
    let mut hi = b;
    let mut f_lo = f(lo);
    let f_hi = f(hi);

    if f_lo * f_hi > 0.0 {
        return Err(format!(
            "bisect: f(a) and f(b) must have opposite signs, got f({})={}, f({})={}",
            a, f_lo, b, f_hi
        ));
    }

    // If one endpoint is already a root, return it immediately.
    if f_lo == 0.0 {
        return Ok(lo);
    }
    if f_hi == 0.0 {
        return Ok(hi);
    }

    for _ in 0..max_iter {
        let mid = lo + (hi - lo) * 0.5; // avoids overflow vs (lo + hi) / 2
        let f_mid = f(mid);

        if f_mid.abs() < tol || (hi - lo) * 0.5 < tol {
            return Ok(mid);
        }

        if f_lo * f_mid < 0.0 {
            hi = mid;
        } else {
            lo = mid;
            f_lo = f_mid;
        }
    }

    // Return the best estimate even if not fully converged.
    Ok(lo + (hi - lo) * 0.5)
}

/// Brent's method for scalar root finding.
///
/// Combines inverse quadratic interpolation (IQI) with bisection fallback for
/// robust, superlinear convergence. Requires `f(a)` and `f(b)` to have opposite
/// signs.
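///
/// # Example
///
/// A minimal sketch; the path `cjc_runtime::optimize` is assumed:
///
/// ```
/// use cjc_runtime::optimize::brentq;
///
/// let f = |x: f64| x * x - 2.0;
/// let root = brentq(&f, 1.0, 2.0, 1e-12, 100).unwrap();
/// assert!((root - std::f64::consts::SQRT_2).abs() < 1e-10);
/// ```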
pub fn brentq(
    f: &dyn Fn(f64) -> f64,
    a: f64,
    b: f64,
    tol: f64,
    max_iter: usize,
) -> Result<f64, String> {
    let mut a = a;
    let mut b = b;
    let mut fa = f(a);
    let mut fb = f(b);

    if fa * fb > 0.0 {
        return Err(format!(
            "brentq: f(a) and f(b) must have opposite signs, got f({})={}, f({})={}",
            a, fa, b, fb
        ));
    }

    if fa.abs() < fb.abs() {
        core::mem::swap(&mut a, &mut b);
        core::mem::swap(&mut fa, &mut fb);
    }

    let mut c = a;
    let mut fc = fa;
    let mut mflag = true;
    let mut d = 0.0_f64; // previous value of `c` (only read when mflag == false)

    for _ in 0..max_iter {
        if fb.abs() < tol {
            return Ok(b);
        }
        if fa.abs() < tol {
            return Ok(a);
        }
        if (b - a).abs() < tol {
            return Ok(b);
        }

        // Attempt inverse quadratic interpolation or secant.
        let s = if (fa - fc).abs() > f64::EPSILON && (fb - fc).abs() > f64::EPSILON {
            // Inverse quadratic interpolation.
            let t1 = a * fb * fc / ((fa - fb) * (fa - fc));
            let t2 = b * fa * fc / ((fb - fa) * (fb - fc));
            let t3 = c * fa * fb / ((fc - fa) * (fc - fb));
            binned_sum_f64(&[t1, t2, t3])
        } else {
            // Secant method.
            b - fb * (b - a) / (fb - fa)
        };

        // Conditions for rejecting `s` in favour of bisection.
        let mid = (a + b) * 0.5;
        let cond1 = {
            // s not between (3a + b)/4 and b
            let lo = if (3.0 * a + b) / 4.0 < b {
                (3.0 * a + b) / 4.0
            } else {
                b
            };
            let hi = if (3.0 * a + b) / 4.0 > b {
                (3.0 * a + b) / 4.0
            } else {
                b
            };
            s < lo || s > hi
        };
        let cond2 = mflag && (s - b).abs() >= (b - c).abs() * 0.5;
        let cond3 = !mflag && (s - b).abs() >= (c - d).abs() * 0.5;
        let cond4 = mflag && (b - c).abs() < tol;
        let cond5 = !mflag && (c - d).abs() < tol;

        let s = if cond1 || cond2 || cond3 || cond4 || cond5 {
            mflag = true;
            mid
        } else {
            mflag = false;
            s
        };

        let fs = f(s);
        d = c;
        c = b;
        fc = fb;

        if fa * fs < 0.0 {
            b = s;
            fb = fs;
        } else {
            a = s;
            fa = fs;
        }

        // Keep |f(a)| >= |f(b)| so b is the better approximation.
        if fa.abs() < fb.abs() {
            core::mem::swap(&mut a, &mut b);
            core::mem::swap(&mut fa, &mut fb);
        }
    }

    Ok(b)
}

/// Newton-Raphson method for scalar root finding.
///
/// Uses `f` and its derivative `df` to iterate `x_{k+1} = x_k - f(x_k)/df(x_k)`.
/// Returns `Err` if the derivative is zero at any iterate.
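///
/// # Example
///
/// A minimal sketch; the path `cjc_runtime::optimize` is assumed:
///
/// ```
/// use cjc_runtime::optimize::newton_scalar;
///
/// let f = |x: f64| x * x - 4.0;
/// let df = |x: f64| 2.0 * x;
/// let root = newton_scalar(&f, &df, 3.0, 1e-12, 100).unwrap();
/// assert!((root - 2.0).abs() < 1e-10);
/// ```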
pub fn newton_scalar(
    f: &dyn Fn(f64) -> f64,
    df: &dyn Fn(f64) -> f64,
    x0: f64,
    tol: f64,
    max_iter: usize,
) -> Result<f64, String> {
    let mut x = x0;

    for _ in 0..max_iter {
        let fx = f(x);
        if fx.abs() < tol {
            return Ok(x);
        }

        let dfx = df(x);
        if dfx.abs() < f64::EPSILON {
            return Err(format!(
                "newton_scalar: derivative is zero at x={}, cannot continue",
                x
            ));
        }

        x -= fx / dfx;
    }

    if f(x).abs() < tol {
        Ok(x)
    } else {
        Err(format!(
            "newton_scalar: did not converge after {} iterations, x={}, f(x)={}",
            max_iter,
            x,
            f(x)
        ))
    }
}

/// Secant method for scalar root finding.
///
/// A derivative-free variant of Newton's method that approximates the derivative
/// by a finite difference over the two most recent iterates.
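///
/// # Example
///
/// A minimal sketch; the path `cjc_runtime::optimize` is assumed:
///
/// ```
/// use cjc_runtime::optimize::secant;
///
/// let f = |x: f64| x * x - 2.0;
/// let root = secant(&f, 1.0, 2.0, 1e-12, 100).unwrap();
/// assert!((root - std::f64::consts::SQRT_2).abs() < 1e-10);
/// ```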
pub fn secant(
    f: &dyn Fn(f64) -> f64,
    x0: f64,
    x1: f64,
    tol: f64,
    max_iter: usize,
) -> Result<f64, String> {
    let mut xp = x0; // x_{k-1}
    let mut xc = x1; // x_k
    let mut fp = f(xp);
    let mut fc = f(xc);

    for _ in 0..max_iter {
        if fc.abs() < tol {
            return Ok(xc);
        }

        let denom = fc - fp;
        if denom.abs() < f64::EPSILON {
            return Err(format!(
                "secant: denominator vanished (f(x_{{k-1}})={}, f(x_k)={} are too close)",
                fp, fc
            ));
        }

        let xn = xc - fc * (xc - xp) / denom;
        xp = xc;
        fp = fc;
        xc = xn;
        fc = f(xc);
    }

    if fc.abs() < tol {
        Ok(xc)
    } else {
        Err(format!(
            "secant: did not converge after {} iterations, x={}, f(x)={}",
            max_iter, xc, fc
        ))
    }
}

// ===========================================================================
// Internal helpers
// ===========================================================================

/// Deterministic L2 norm of a vector using binned summation.
fn norm_l2(v: &[f64]) -> f64 {
    let sq: Vec<f64> = v.iter().map(|&x| x * x).collect();
    binned_sum_f64(&sq).sqrt()
}

/// Deterministic dot product using binned summation.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    let prods: Vec<f64> = a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).collect();
    binned_sum_f64(&prods)
}

/// Armijo (backtracking) line search.
///
/// Finds a step size `alpha` satisfying the Armijo sufficient decrease condition:
///   `f(x + alpha * d) <= f(x) + c * alpha * grad^T d`
///
/// Parameters:
/// - `f`: objective function
/// - `x`: current point
/// - `d`: search direction
/// - `grad`: gradient at `x`
/// - `alpha0`: initial step size
/// - `c`: sufficient decrease parameter (typically 1e-4)
/// - `rho`: backtracking factor (typically 0.5)
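///
/// Worked example (illustrative): for `f(x) = x^2` at `x = [1.0]` with
/// `d = -grad = [-2.0]`, the directional derivative is `slope = -4.0` and
/// `f0 = 1.0`. The full step `alpha = 1` lands on `f(-1) = 1`, which fails
/// `1 <= 1 + 1e-4 * 1.0 * (-4.0)`; one backtrack to `alpha = 0.5` lands on
/// `f(0) = 0`, which satisfies the condition, so `0.5` is returned.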
fn armijo_line_search(
    f: &dyn Fn(&[f64]) -> f64,
    x: &[f64],
    d: &[f64],
    grad: &[f64],
    alpha0: f64,
    c: f64,
    rho: f64,
) -> f64 {
    let n = x.len();
    let f0 = f(x);
    let slope = dot(grad, d); // directional derivative

    let mut alpha = alpha0;
    let mut x_new = vec![0.0; n];

    // Cap iterations to prevent infinite loops on pathological functions.
    for _ in 0..60 {
        for i in 0..n {
            x_new[i] = x[i] + alpha * d[i];
        }
        let f_new = f(&x_new);
        if f_new <= f0 + c * alpha * slope {
            return alpha;
        }
        alpha *= rho;
    }

    alpha
}

// ===========================================================================
// Unconstrained Optimization — Vector
// ===========================================================================

/// Gradient descent with a fixed learning rate.
///
/// Minimizes `f` starting from `x0` by iterating `x_{k+1} = x_k - lr * grad(x_k)`.
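///
/// # Example
///
/// A minimal sketch; the path `cjc_runtime::optimize` is assumed:
///
/// ```
/// use cjc_runtime::optimize::minimize_gd;
///
/// // f(x, y) = x^2 + y^2, minimum at the origin.
/// let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
/// let g = |x: &[f64]| vec![2.0 * x[0], 2.0 * x[1]];
/// let res = minimize_gd(&f, &g, &[5.0, -3.0], 0.1, 1000, 1e-8);
/// assert!(res.converged);
/// assert!(res.x.iter().all(|&xi| xi.abs() < 1e-6));
/// ```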
pub fn minimize_gd(
    f: &dyn Fn(&[f64]) -> f64,
    grad: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: &[f64],
    lr: f64,
    max_iter: usize,
    tol: f64,
) -> OptResult {
    let n = x0.len();
    let mut x = x0.to_vec();

    for iter in 0..max_iter {
        let g = grad(&x);
        let gnorm = norm_l2(&g);
        if gnorm < tol {
            return OptResult {
                fun: f(&x),
                x,
                niter: iter,
                converged: true,
            };
        }
        for i in 0..n {
            x[i] -= lr * g[i];
        }
    }

    OptResult {
        fun: f(&x),
        x,
        niter: max_iter,
        converged: false,
    }
}

/// BFGS quasi-Newton method with Armijo line search.
///
/// Maintains an approximate inverse Hessian `H` (initialized to identity) and
/// updates it with the BFGS rank-2 formula at each step.
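///
/// # Example
///
/// A minimal sketch; the path `cjc_runtime::optimize` is assumed:
///
/// ```
/// use cjc_runtime::optimize::minimize_bfgs;
///
/// let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
/// let g = |x: &[f64]| vec![2.0 * x[0], 2.0 * x[1]];
/// let res = minimize_bfgs(&f, &g, &[10.0, -7.0], 1e-10, 200);
/// assert!(res.converged);
/// assert!(res.x.iter().all(|&xi| xi.abs() < 1e-5));
/// ```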
pub fn minimize_bfgs(
    f: &dyn Fn(&[f64]) -> f64,
    grad: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: &[f64],
    tol: f64,
    max_iter: usize,
) -> OptResult {
    let n = x0.len();
    let mut x = x0.to_vec();
    let mut g = grad(&x);

    // H = inverse Hessian approximation, stored as dense n×n in row-major.
    // Initialized to identity.
    let mut h = vec![0.0; n * n];
    for i in 0..n {
        h[i * n + i] = 1.0;
    }

    for iter in 0..max_iter {
        let gnorm = norm_l2(&g);
        if gnorm < tol {
            return OptResult {
                fun: f(&x),
                x,
                niter: iter,
                converged: true,
            };
        }

        // Search direction: d = -H * g
        let mut d = vec![0.0; n];
        for i in 0..n {
            let row: Vec<f64> = (0..n).map(|j| h[i * n + j] * g[j]).collect();
            d[i] = -binned_sum_f64(&row);
        }

        // Armijo line search.
        let alpha = armijo_line_search(f, &x, &d, &g, 1.0, 1e-4, 0.5);

        // Step: s = alpha * d
        let s: Vec<f64> = d.iter().map(|&di| alpha * di).collect();
        let mut x_new = vec![0.0; n];
        for i in 0..n {
            x_new[i] = x[i] + s[i];
        }

        let g_new = grad(&x_new);

        // y = g_new - g
        let y: Vec<f64> = g_new.iter().zip(g.iter()).map(|(&gn, &go)| gn - go).collect();

        let sy = dot(&s, &y);

        // Skip update if curvature condition is not met (sy too small).
        if sy > f64::EPSILON {
            // BFGS update:
            // H_new = (I - rho*s*y^T) * H * (I - rho*y*s^T) + rho*s*s^T
            // where rho = 1/sy
            let rho = 1.0 / sy;

            // Compute H*y
            let mut hy = vec![0.0; n];
            for i in 0..n {
                let row: Vec<f64> = (0..n).map(|j| h[i * n + j] * y[j]).collect();
                hy[i] = binned_sum_f64(&row);
            }

            let yhy = dot(&y, &hy);

            // Update H in place using the expanded (Sherman-Morrison-Woodbury) form:
            // H_new = H - rho * (H*y*s^T + s*y^T*H) + (1 + rho*y^T*H*y) * rho * s*s^T
            let factor = (1.0 + rho * yhy) * rho;
            for i in 0..n {
                for j in 0..n {
                    h[i * n + j] = h[i * n + j]
                        - rho * (hy[i] * s[j] + s[i] * hy[j])
                        + factor * s[i] * s[j];
                }
            }
        }

        x = x_new;
        g = g_new;
    }

    OptResult {
        fun: f(&x),
        x,
        niter: max_iter,
        converged: false,
    }
}

/// L-BFGS (limited-memory BFGS) with Armijo line search.
///
/// Uses at most `m` recent `(s, y)` pairs to approximate the inverse Hessian-vector
/// product via the two-loop recursion. Memory usage is O(m*n) instead of O(n^2).
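///
/// # Example
///
/// A minimal sketch; the path `cjc_runtime::optimize` is assumed:
///
/// ```
/// use cjc_runtime::optimize::minimize_lbfgs;
///
/// let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
/// let g = |x: &[f64]| vec![2.0 * x[0], 2.0 * x[1]];
/// let res = minimize_lbfgs(&f, &g, &[10.0, -7.0], 5, 1e-10, 200);
/// assert!(res.converged);
/// assert!(res.x.iter().all(|&xi| xi.abs() < 1e-5));
/// ```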
pub fn minimize_lbfgs(
    f: &dyn Fn(&[f64]) -> f64,
    grad: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: &[f64],
    m: usize,
    tol: f64,
    max_iter: usize,
) -> OptResult {
    let n = x0.len();
    let mut x = x0.to_vec();
    let mut g = grad(&x);

    // History buffers; the oldest pair is evicted once `m` pairs are stored.
    let mut s_hist: Vec<Vec<f64>> = Vec::new();
    let mut y_hist: Vec<Vec<f64>> = Vec::new();
    let mut rho_hist: Vec<f64> = Vec::new();

    for iter in 0..max_iter {
        let gnorm = norm_l2(&g);
        if gnorm < tol {
            return OptResult {
                fun: f(&x),
                x,
                niter: iter,
                converged: true,
            };
        }

        // Two-loop recursion to compute d = -H_k * g.
        let k = s_hist.len();
        let mut q = g.clone();
        let mut alpha_vals = vec![0.0; k];

        // First loop: newest to oldest.
        for i in (0..k).rev() {
            alpha_vals[i] = rho_hist[i] * dot(&s_hist[i], &q);
            for j in 0..n {
                q[j] -= alpha_vals[i] * y_hist[i][j];
            }
        }

        // Initial Hessian scaling: H0 = gamma * I
        // gamma = s^T y / y^T y for the most recent pair.
        let gamma = if k > 0 {
            let sy = dot(&s_hist[k - 1], &y_hist[k - 1]);
            let yy = dot(&y_hist[k - 1], &y_hist[k - 1]);
            if yy > f64::EPSILON { sy / yy } else { 1.0 }
        } else {
            1.0
        };

        let mut r: Vec<f64> = q.iter().map(|&qi| gamma * qi).collect();

        // Second loop: oldest to newest.
        for i in 0..k {
            let beta = rho_hist[i] * dot(&y_hist[i], &r);
            for j in 0..n {
                r[j] += s_hist[i][j] * (alpha_vals[i] - beta);
            }
        }

        // d = -r (the search direction)
        let d: Vec<f64> = r.iter().map(|&ri| -ri).collect();

        // Armijo line search.
        let alpha = armijo_line_search(f, &x, &d, &g, 1.0, 1e-4, 0.5);

        let s: Vec<f64> = d.iter().map(|&di| alpha * di).collect();
        let mut x_new = vec![0.0; n];
        for i in 0..n {
            x_new[i] = x[i] + s[i];
        }

        let g_new = grad(&x_new);
        let y: Vec<f64> = g_new.iter().zip(g.iter()).map(|(&gn, &go)| gn - go).collect();
        let sy = dot(&s, &y);

        if sy > f64::EPSILON {
            // Add to history, evicting the oldest pair if at capacity.
            if s_hist.len() == m {
                s_hist.remove(0);
                y_hist.remove(0);
                rho_hist.remove(0);
            }
            s_hist.push(s);
            y_hist.push(y);
            rho_hist.push(1.0 / sy);
        }

        x = x_new;
        g = g_new;
    }

    OptResult {
        fun: f(&x),
        x,
        niter: max_iter,
        converged: false,
    }
}

/// Nelder-Mead simplex method (derivative-free).
///
/// Constructs an initial simplex around `x0` and iteratively transforms it using
/// reflection, expansion, contraction, and shrinkage operations.
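///
/// # Example
///
/// A minimal sketch; the path `cjc_runtime::optimize` is assumed:
///
/// ```
/// use cjc_runtime::optimize::minimize_nelder_mead;
///
/// // f(x, y) = (x - 3)^2 + (y + 1)^2, minimum at (3, -1).
/// let f = |x: &[f64]| (x[0] - 3.0).powi(2) + (x[1] + 1.0).powi(2);
/// let res = minimize_nelder_mead(&f, &[0.0, 0.0], 1e-10, 5000);
/// assert!(res.converged);
/// assert!((res.x[0] - 3.0).abs() < 1e-4);
/// assert!((res.x[1] + 1.0).abs() < 1e-4);
/// ```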
pub fn minimize_nelder_mead(
    f: &dyn Fn(&[f64]) -> f64,
    x0: &[f64],
    tol: f64,
    max_iter: usize,
) -> OptResult {
    let n = x0.len();

    // Standard Nelder-Mead parameters.
    let alpha_reflect = 1.0;
    let gamma_expand = 2.0;
    let rho_contract = 0.5;
    let sigma_shrink = 0.5;

    // Build initial simplex: x0 plus n vertices offset along each axis.
    let mut simplex: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
    simplex.push(x0.to_vec());
    for i in 0..n {
        let mut v = x0.to_vec();
        // Use a perturbation that works for both zero and non-zero components.
        let delta = if v[i].abs() > f64::EPSILON {
            v[i] * 0.05
        } else {
            0.00025
        };
        v[i] += delta;
        simplex.push(v);
    }

    let mut fvals: Vec<f64> = simplex.iter().map(|v| f(v)).collect();

    for iter in 0..max_iter {
        // Sort simplex by function value (deterministic: uses total_cmp for ties).
        let mut indices: Vec<usize> = (0..=n).collect();
        indices.sort_by(|&a, &b| fvals[a].total_cmp(&fvals[b]));
        simplex = indices.iter().map(|&i| simplex[i].clone()).collect();
        fvals = indices.iter().map(|&i| fvals[i]).collect();

        // Check convergence: spread of function values across the simplex.
        let f_best = fvals[0];
        let f_worst = fvals[n];
        if (f_worst - f_best).abs() < tol {
            return OptResult {
                x: simplex[0].clone(),
                fun: f_best,
                niter: iter,
                converged: true,
            };
        }

        // Compute centroid of all vertices except the worst.
        let mut centroid = vec![0.0; n];
        for i in 0..n {
            for j in 0..n {
                centroid[j] += simplex[i][j];
            }
        }
        for j in 0..n {
            centroid[j] /= n as f64;
        }

        // Reflection.
        let xr: Vec<f64> = (0..n)
            .map(|j| centroid[j] + alpha_reflect * (centroid[j] - simplex[n][j]))
            .collect();
        let fr = f(&xr);

        if fr < fvals[0] {
            // Try expansion.
            let xe: Vec<f64> = (0..n)
                .map(|j| centroid[j] + gamma_expand * (xr[j] - centroid[j]))
                .collect();
            let fe = f(&xe);
            if fe < fr {
                simplex[n] = xe;
                fvals[n] = fe;
            } else {
                simplex[n] = xr;
                fvals[n] = fr;
            }
        } else if fr < fvals[n - 1] {
            // Accept reflection.
            simplex[n] = xr;
            fvals[n] = fr;
        } else {
            // Contraction.
            let (xc, fc) = if fr < fvals[n] {
                // Outside contraction.
                let xc: Vec<f64> = (0..n)
                    .map(|j| centroid[j] + rho_contract * (xr[j] - centroid[j]))
                    .collect();
                let fc = f(&xc);
                (xc, fc)
            } else {
                // Inside contraction.
                let xc: Vec<f64> = (0..n)
                    .map(|j| centroid[j] + rho_contract * (simplex[n][j] - centroid[j]))
                    .collect();
                let fc = f(&xc);
                (xc, fc)
            };

            if fc < fvals[n] {
                simplex[n] = xc;
                fvals[n] = fc;
            } else {
                // Shrink: move all vertices towards the best.
                for i in 1..=n {
                    for j in 0..n {
                        simplex[i][j] = simplex[0][j] + sigma_shrink * (simplex[i][j] - simplex[0][j]);
                    }
                    fvals[i] = f(&simplex[i]);
                }
            }
        }
    }

    // Scan for the best vertex to return.
    let mut best_idx = 0;
    for i in 1..=n {
        if fvals[i] < fvals[best_idx] {
            best_idx = i;
        }
    }

    OptResult {
        x: simplex[best_idx].clone(),
        fun: fvals[best_idx],
        niter: max_iter,
        converged: false,
    }
}

// ===========================================================================
// Constrained Optimization Utilities
// ===========================================================================

/// Penalty method for constrained optimization.
///
/// Given an objective function value `f_val` and an array of constraint
/// violation values `constraint_violations` (where positive values indicate
/// violation), returns the penalized objective:
///
///   f_val + penalty * sum(max(0, g_i)^2)
///
/// Uses Kahan summation for the penalty term accumulation.
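///
/// # Example
///
/// A minimal sketch; the path `cjc_runtime::optimize` is assumed:
///
/// ```
/// use cjc_runtime::optimize::penalty_objective;
///
/// // Only the positive violations 0.5 and 2.0 contribute:
/// // 1.0 + 10.0 * (0.25 + 4.0) = 43.5
/// let p = penalty_objective(1.0, &[0.5, -1.0, 2.0], 10.0);
/// assert!((p - 43.5).abs() < 1e-12);
/// ```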
pub fn penalty_objective(f_val: f64, constraint_violations: &[f64], penalty: f64) -> f64 {
    use cjc_repro::KahanAccumulatorF64;

    let mut acc = KahanAccumulatorF64::new();
    for &g in constraint_violations {
        if g > 0.0 {
            acc.add(g * g);
        }
    }
    f_val + penalty * acc.finalize()
}

/// Project a point onto a box constraint `[lower, upper]`.
///
/// For each component `i`, returns `x[i].clamp(lower[i], upper[i])`.
/// All arrays must have the same length.
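///
/// # Example
///
/// A minimal sketch; the path `cjc_runtime::optimize` is assumed:
///
/// ```
/// use cjc_runtime::optimize::project_box;
///
/// let p = project_box(&[-1.0, 0.5, 3.0], &[0.0; 3], &[1.0; 3]).unwrap();
/// assert_eq!(p, vec![0.0, 0.5, 1.0]);
/// ```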
pub fn project_box(x: &[f64], lower: &[f64], upper: &[f64]) -> Result<Vec<f64>, String> {
    let n = x.len();
    if lower.len() != n || upper.len() != n {
        return Err(format!(
            "project_box: all arrays must have same length, got x={}, lower={}, upper={}",
            n,
            lower.len(),
            upper.len()
        ));
    }
    let mut result = Vec::with_capacity(n);
    for i in 0..n {
        result.push(x[i].clamp(lower[i], upper[i]));
    }
    Ok(result)
}

/// Projected gradient descent step with box constraints.
///
/// Computes `project(x - lr * grad, lower, upper)`.
///
/// All arrays must have the same length.
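///
/// # Example
///
/// A minimal sketch; the path `cjc_runtime::optimize` is assumed:
///
/// ```
/// use cjc_runtime::optimize::projected_gd_step;
///
/// // Unconstrained step 0.9 - 0.1 * (-2.0) = 1.1 is clamped to the upper bound.
/// let p = projected_gd_step(&[0.9], &[-2.0], 0.1, &[0.0], &[1.0]).unwrap();
/// assert_eq!(p, vec![1.0]);
/// ```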
pub fn projected_gd_step(
    x: &[f64],
    grad: &[f64],
    lr: f64,
    lower: &[f64],
    upper: &[f64],
) -> Result<Vec<f64>, String> {
    let n = x.len();
    if grad.len() != n || lower.len() != n || upper.len() != n {
        return Err(format!(
            "projected_gd_step: all arrays must have same length, got x={}, grad={}, lower={}, upper={}",
            n,
            grad.len(),
            lower.len(),
            upper.len()
        ));
    }
    let mut result = Vec::with_capacity(n);
    for i in 0..n {
        let step = x[i] - lr * grad[i];
        result.push(step.clamp(lower[i], upper[i]));
    }
    Ok(result)
}

// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod optimize_tests {
    use super::*;

    // -----------------------------------------------------------------------
    // Test helpers
    // -----------------------------------------------------------------------

    /// Rosenbrock function: f(x,y) = (1-x)^2 + 100*(y - x^2)^2
    /// Minimum at (1, 1) with f = 0.
    fn rosenbrock(x: &[f64]) -> f64 {
        let a = 1.0 - x[0];
        let b = x[1] - x[0] * x[0];
        binned_sum_f64(&[a * a, 100.0 * b * b])
    }

    fn rosenbrock_grad(x: &[f64]) -> Vec<f64> {
        let dx = -2.0 * (1.0 - x[0]) + 200.0 * (x[1] - x[0] * x[0]) * (-2.0 * x[0]);
        let dy = 200.0 * (x[1] - x[0] * x[0]);
        vec![dx, dy]
    }

    /// Simple quadratic: f(x) = sum(x_i^2). Minimum at the origin.
    fn quadratic(x: &[f64]) -> f64 {
        let sq: Vec<f64> = x.iter().map(|&xi| xi * xi).collect();
        binned_sum_f64(&sq)
    }

    fn quadratic_grad(x: &[f64]) -> Vec<f64> {
        x.iter().map(|&xi| 2.0 * xi).collect()
    }

    // -----------------------------------------------------------------------
    // Scalar Root Finding
    // -----------------------------------------------------------------------

    #[test]
    fn test_bisect_sqrt2() {
        let f = |x: f64| x * x - 2.0;
        let root = bisect(&f, 1.0, 2.0, 1e-12, 100).unwrap();
        assert!((root - std::f64::consts::SQRT_2).abs() < 1e-10);
    }

    #[test]
    fn test_bisect_invalid_bracket() {
        let f = |x: f64| x * x + 1.0; // always positive
        let result = bisect(&f, 0.0, 2.0, 1e-12, 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_bisect_determinism() {
        let f = |x: f64| x * x - 2.0;
        let r1 = bisect(&f, 1.0, 2.0, 1e-12, 100).unwrap();
        let r2 = bisect(&f, 1.0, 2.0, 1e-12, 100).unwrap();
        assert_eq!(r1.to_bits(), r2.to_bits());
    }

    #[test]
    fn test_brentq_sqrt2() {
        let f = |x: f64| x * x - 2.0;
        let root = brentq(&f, 1.0, 2.0, 1e-12, 100).unwrap();
        assert!((root - std::f64::consts::SQRT_2).abs() < 1e-10);
    }

    #[test]
    fn test_brentq_invalid_bracket() {
        let f = |x: f64| x * x + 1.0;
        let result = brentq(&f, 0.0, 2.0, 1e-12, 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_brentq_determinism() {
        let f = |x: f64| x * x - 2.0;
        let r1 = brentq(&f, 1.0, 2.0, 1e-12, 100).unwrap();
        let r2 = brentq(&f, 1.0, 2.0, 1e-12, 100).unwrap();
        assert_eq!(r1.to_bits(), r2.to_bits());
    }

    #[test]
    fn test_newton_scalar_sqrt4() {
        let f = |x: f64| x * x - 4.0;
        let df = |x: f64| 2.0 * x;
        let root = newton_scalar(&f, &df, 3.0, 1e-12, 100).unwrap();
        assert!((root - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_newton_scalar_negative_root() {
        let f = |x: f64| x * x - 4.0;
        let df = |x: f64| 2.0 * x;
        let root = newton_scalar(&f, &df, -3.0, 1e-12, 100).unwrap();
        assert!((root - (-2.0)).abs() < 1e-10);
    }

    #[test]
    fn test_newton_scalar_determinism() {
        let f = |x: f64| x * x - 4.0;
        let df = |x: f64| 2.0 * x;
        let r1 = newton_scalar(&f, &df, 3.0, 1e-12, 100).unwrap();
        let r2 = newton_scalar(&f, &df, 3.0, 1e-12, 100).unwrap();
        assert_eq!(r1.to_bits(), r2.to_bits());
    }

    #[test]
    fn test_secant_sqrt2() {
        let f = |x: f64| x * x - 2.0;
        let root = secant(&f, 1.0, 2.0, 1e-12, 100).unwrap();
        assert!((root - std::f64::consts::SQRT_2).abs() < 1e-10);
    }

    #[test]
    fn test_secant_determinism() {
        let f = |x: f64| x * x - 2.0;
        let r1 = secant(&f, 1.0, 2.0, 1e-12, 100).unwrap();
        let r2 = secant(&f, 1.0, 2.0, 1e-12, 100).unwrap();
        assert_eq!(r1.to_bits(), r2.to_bits());
    }

    // -----------------------------------------------------------------------
    // Unconstrained Optimization
    // -----------------------------------------------------------------------

    #[test]
    fn test_minimize_gd_quadratic() {
        let res = minimize_gd(&quadratic, &quadratic_grad, &[5.0, -3.0, 2.0], 0.1, 1000, 1e-8);
        assert!(res.converged);
        for &xi in &res.x {
            assert!(xi.abs() < 1e-3);
        }
    }

    #[test]
    fn test_minimize_gd_rosenbrock() {
        // GD is slow on Rosenbrock, so we use many iterations and a loose tolerance.
        let res = minimize_gd(&rosenbrock, &rosenbrock_grad, &[-1.0, 1.0], 0.001, 100_000, 1e-6);
        // May or may not converge fully, but should get reasonably close.
        assert!((res.x[0] - 1.0).abs() < 0.5);
        assert!((res.x[1] - 1.0).abs() < 0.5);
    }

    #[test]
    fn test_minimize_gd_determinism() {
        let r1 = minimize_gd(&quadratic, &quadratic_grad, &[5.0, -3.0], 0.1, 500, 1e-8);
        let r2 = minimize_gd(&quadratic, &quadratic_grad, &[5.0, -3.0], 0.1, 500, 1e-8);
        assert_eq!(r1.x.len(), r2.x.len());
        for (a, b) in r1.x.iter().zip(r2.x.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
        assert_eq!(r1.fun.to_bits(), r2.fun.to_bits());
    }

    #[test]
    fn test_minimize_bfgs_rosenbrock() {
        let res = minimize_bfgs(&rosenbrock, &rosenbrock_grad, &[-1.0, 1.0], 1e-8, 500);
        assert!(res.converged, "BFGS did not converge on Rosenbrock");
        assert!(
            (res.x[0] - 1.0).abs() < 1e-4,
            "x[0]={} not near 1.0",
            res.x[0]
        );
        assert!(
            (res.x[1] - 1.0).abs() < 1e-4,
            "x[1]={} not near 1.0",
            res.x[1]
        );
        assert!(res.fun < 1e-8, "f(x)={} not near 0", res.fun);
    }

    #[test]
    fn test_minimize_bfgs_quadratic() {
        let res = minimize_bfgs(&quadratic, &quadratic_grad, &[10.0, -7.0, 3.0], 1e-10, 200);
        assert!(res.converged);
        for &xi in &res.x {
            assert!(xi.abs() < 1e-5);
        }
    }

    #[test]
    fn test_minimize_bfgs_determinism() {
        let r1 = minimize_bfgs(&rosenbrock, &rosenbrock_grad, &[-1.0, 1.0], 1e-8, 500);
        let r2 = minimize_bfgs(&rosenbrock, &rosenbrock_grad, &[-1.0, 1.0], 1e-8, 500);
        for (a, b) in r1.x.iter().zip(r2.x.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
        assert_eq!(r1.fun.to_bits(), r2.fun.to_bits());
        assert_eq!(r1.niter, r2.niter);
    }

    #[test]
    fn test_minimize_lbfgs_rosenbrock() {
        let res = minimize_lbfgs(&rosenbrock, &rosenbrock_grad, &[-1.0, 1.0], 10, 1e-8, 500);
        assert!(res.converged, "L-BFGS did not converge on Rosenbrock");
        assert!(
            (res.x[0] - 1.0).abs() < 1e-4,
            "x[0]={} not near 1.0",
            res.x[0]
        );
        assert!(
            (res.x[1] - 1.0).abs() < 1e-4,
            "x[1]={} not near 1.0",
            res.x[1]
        );
    }

    #[test]
    fn test_minimize_lbfgs_determinism() {
        let r1 = minimize_lbfgs(&rosenbrock, &rosenbrock_grad, &[-1.0, 1.0], 5, 1e-8, 300);
        let r2 = minimize_lbfgs(&rosenbrock, &rosenbrock_grad, &[-1.0, 1.0], 5, 1e-8, 300);
        for (a, b) in r1.x.iter().zip(r2.x.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
        assert_eq!(r1.niter, r2.niter);
    }

    #[test]
    fn test_minimize_nelder_mead_quadratic() {
        let res = minimize_nelder_mead(&quadratic, &[5.0, -3.0, 2.0], 1e-10, 5000);
        assert!(res.converged, "Nelder-Mead did not converge on quadratic");
        for &xi in &res.x {
            assert!(
                xi.abs() < 1e-3,
                "x_i={} not near 0",
                xi
            );
        }
        assert!(res.fun < 1e-6, "f(x)={} not near 0", res.fun);
    }

    #[test]
    fn test_minimize_nelder_mead_2d() {
        // f(x,y) = (x-3)^2 + (y+1)^2, minimum at (3, -1)
        let f = |x: &[f64]| {
            let a = x[0] - 3.0;
            let b = x[1] + 1.0;
            binned_sum_f64(&[a * a, b * b])
        };
        let res = minimize_nelder_mead(&f, &[0.0, 0.0], 1e-10, 5000);
        assert!(res.converged);
        assert!((res.x[0] - 3.0).abs() < 1e-4);
        assert!((res.x[1] - (-1.0)).abs() < 1e-4);
    }

    #[test]
    fn test_minimize_nelder_mead_determinism() {
        let r1 = minimize_nelder_mead(&quadratic, &[5.0, -3.0], 1e-10, 2000);
        let r2 = minimize_nelder_mead(&quadratic, &[5.0, -3.0], 1e-10, 2000);
        for (a, b) in r1.x.iter().zip(r2.x.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
        assert_eq!(r1.fun.to_bits(), r2.fun.to_bits());
        assert_eq!(r1.niter, r2.niter);
    }

    // -----------------------------------------------------------------------
    // Armijo line search
    // -----------------------------------------------------------------------

    #[test]
    fn test_armijo_decreases_function() {
        // Verify that Armijo returns a step that decreases the function.
        let x = &[2.0, 3.0];
        let g = quadratic_grad(x);
        let d: Vec<f64> = g.iter().map(|&gi| -gi).collect();
        let f0 = quadratic(x);
        let alpha = armijo_line_search(&quadratic, x, &d, &g, 1.0, 1e-4, 0.5);
        let x_new: Vec<f64> = x.iter().zip(d.iter()).map(|(&xi, &di)| xi + alpha * di).collect();
        let f1 = quadratic(&x_new);
        assert!(f1 < f0, "Armijo did not decrease: f0={}, f1={}", f0, f1);
    }
}