Skip to main content

gam_solve/
seeding.rs

1use ndarray::Array1;
2use std::collections::HashSet;
3
4pub use gam_problem::{SeedConfig, SeedRiskProfile};
5use gam_problem::{clamp_seed_rho_to_bounds, normalize_seed_bounds};
6
7fn add_seed_dedup(seeds: &mut Vec<Array1<f64>>, seen: &mut HashSet<Vec<u64>>, seed: Array1<f64>) {
8    let key: Vec<u64> = seed.iter().map(|&v| v.to_bits()).collect();
9    if seen.insert(key) {
10        seeds.push(seed);
11    }
12}
13
/// Natural logarithm of `x`, or `None` unless `x` is finite and strictly
/// positive (zero, negatives, NaN and infinities are all rejected).
fn safe_ln_pos(x: f64) -> Option<f64> {
    let in_domain = x.is_finite() && x > 0.0;
    in_domain.then(|| x.ln())
}
21
22fn spde_rho_triplet_from_log_tau_log_kappa_nu(
23    log_tau: f64,
24    log_kappa: f64,
25    nu: f64,
26    bounds: (f64, f64),
27) -> Option<Array1<f64>> {
28    if !(nu.is_finite() && nu > 1.0) {
29        return None;
30    }
31    let logc0 = 0.0;
32    let logc1 = safe_ln_pos(nu)?;
33    let logc2 = safe_ln_pos(0.5 * nu * (nu - 1.0))?;
34    let rho0 = clamp_seed_rho_to_bounds(log_tau + logc0 + 2.0 * nu * log_kappa, bounds);
35    let rho1 = clamp_seed_rho_to_bounds(log_tau + logc1 + 2.0 * (nu - 1.0) * log_kappa, bounds);
36    let rho2 = clamp_seed_rho_to_bounds(log_tau + logc2 + 2.0 * (nu - 2.0) * log_kappa, bounds);
37    Some(Array1::from_vec(vec![rho0, rho1, rho2]))
38}
39
40fn add_spde_manifold_seeds(
41    seeds: &mut Vec<Array1<f64>>,
42    seen: &mut HashSet<Vec<u64>>,
43    bounds: (f64, f64),
44    heuristic_rhos: Option<&[f64]>,
45    primary: &Array1<f64>,
46) {
47    if primary.len() != 3 {
48        return;
49    }
50    // Broad default manifold grid in (log_tau, log_kappa, nu).
51    let tau_anchors = [primary[2], 0.0, -2.0, 2.0];
52    let log_kappa_grid = [-2.0, -1.0, 0.0, 1.0, 2.0];
53    let nu_grid = [1.25, 1.5, 2.0, 2.5, 3.0, 4.0];
54    for &tau in &tau_anchors {
55        for &lk in &log_kappa_grid {
56            for &nu in &nu_grid {
57                if let Some(seed) = spde_rho_triplet_from_log_tau_log_kappa_nu(tau, lk, nu, bounds)
58                {
59                    add_seed_dedup(seeds, seen, seed);
60                }
61            }
62        }
63    }
64
65    // Data-informed anchor: convert the rho seed to lambdas, then invert to
66    // (nu, kappa^2, tau) when feasible.
67    if let Some(vals) = heuristic_rhos
68        && vals.len() == 3
69    {
70        let l0 = vals[0].exp();
71        let l1 = vals[1].exp();
72        let l2 = vals[2].exp();
73        if l0.is_finite() && l1.is_finite() && l2.is_finite() && l0 > 1e-12 && l2 > 1e-12 {
74            let r = (l1 * l1) / (l0 * l2);
75            if r > 2.0 {
76                let nu = r / (r - 2.0);
77                let kappa2 = l1 / ((r - 2.0) * l2);
78                if nu.is_finite() && nu > 1.0 && kappa2.is_finite() && kappa2 > 0.0 {
79                    let log_kappa = 0.5 * kappa2.ln();
80                    let c2 = 0.5 * nu * (nu - 1.0);
81                    if c2.is_finite() && c2 > 0.0 {
82                        let log_tau = (l2 / (c2 * kappa2.powf(nu - 2.0))).max(1e-12).ln();
83                        let local_nu = [nu, (nu - 0.3).max(1.05), nu + 0.3];
84                        let local_tau = [log_tau, log_tau - 1.0, log_tau + 1.0];
85                        let local_kappa = [log_kappa, log_kappa - 0.5, log_kappa + 0.5];
86                        for &t in &local_tau {
87                            for &lk in &local_kappa {
88                                for &n in &local_nu {
89                                    if let Some(seed) =
90                                        spde_rho_triplet_from_log_tau_log_kappa_nu(t, lk, n, bounds)
91                                    {
92                                        add_seed_dedup(seeds, seen, seed);
93                                    }
94                                }
95                            }
96                        }
97                    }
98                }
99            }
100        }
101    }
102}
103
104fn add_first_order_fallback_seeds(
105    seeds: &mut Vec<Array1<f64>>,
106    seen: &mut HashSet<Vec<u64>>,
107    bounds: (f64, f64),
108    heuristic_rhos: Option<&[f64]>,
109) {
110    // Degenerate λ2 -> 0 fallback (first-order mass+tension):
111    // λ0 = τ κ^2, λ1 = τ, λ2 ≈ 0.
112    let rho2_floor = bounds.0;
113    let default_log_kappa = [-2.0, -1.0, 0.0, 1.0];
114    let default_log_tau = [0.0, -2.0, 2.0];
115    for &t in &default_log_tau {
116        for &lk in &default_log_kappa {
117            let rho0 = clamp_seed_rho_to_bounds(t + 2.0 * lk, bounds);
118            let rho1 = clamp_seed_rho_to_bounds(t, bounds);
119            add_seed_dedup(seeds, seen, Array1::from_vec(vec![rho0, rho1, rho2_floor]));
120        }
121    }
122    if let Some(vals) = heuristic_rhos
123        && vals.len() == 3
124        && vals[0].is_finite()
125        && vals[1].is_finite()
126    {
127        let l0 = vals[0].exp();
128        let l1 = vals[1].exp();
129        let kappa2 = l0 / l1;
130        if kappa2.is_finite() && kappa2 > 0.0 {
131            let lk = 0.5 * kappa2.ln();
132            let t = vals[1];
133            let rho0 = clamp_seed_rho_to_bounds(t + 2.0 * lk, bounds);
134            let rho1 = clamp_seed_rho_to_bounds(t, bounds);
135            add_seed_dedup(seeds, seen, Array1::from_vec(vec![rho0, rho1, rho2_floor]));
136        }
137    }
138}
139
140fn add_nu2_reverse_manifold_seeds(
141    seeds: &mut Vec<Array1<f64>>,
142    seen: &mut HashSet<Vec<u64>>,
143    bounds: (f64, f64),
144    primary: &Array1<f64>,
145) {
146    if primary.len() != 3 {
147        return;
148    }
149    let ln_two = 2.0_f64.ln();
150    let tau_anchors = [primary[2], 0.0, -2.0, 2.0];
151    let log_kappa_grid = [-2.0, -1.0, 0.0, 1.0, 2.0];
152    for &tau_rho in &tau_anchors {
153        for &log_kappa in &log_kappa_grid {
154            // Continuous-order reverse map at nu=2:
155            // lambda0 = tau * kappa^4, lambda1 = tau * 2*kappa^2, lambda2 = tau.
156            let rho2 = clamp_seed_rho_to_bounds(tau_rho, bounds);
157            let rho1 = clamp_seed_rho_to_bounds(tau_rho + ln_two + 2.0 * log_kappa, bounds);
158            let rho0 = clamp_seed_rho_to_bounds(tau_rho + 4.0 * log_kappa, bounds);
159            add_seed_dedup(seeds, seen, Array1::from_vec(vec![rho0, rho1, rho2]));
160        }
161    }
162}
163
/// Radical inverse of `index` in `base`: the base-`base` digits of `index`
/// mirrored about the radix point, giving a value in `[0, 1)`.
///
/// `index == 0` yields `0.0`. `base` must be at least 2: with a positive
/// `index`, a base of 0 panics on the remainder and a base of 1 never
/// terminates.
fn halton(index: usize, base: usize) -> f64 {
    let mut remaining = index;
    let mut digit_weight = 1.0_f64;
    let mut value = 0.0_f64;
    loop {
        if remaining == 0 {
            break value;
        }
        // Each step consumes the least-significant digit at the next
        // negative power of the base.
        digit_weight /= base as f64;
        let digit = remaining % base;
        value += digit_weight * digit as f64;
        remaining /= base;
    }
}
174
/// Returns the first `n` prime numbers in increasing order (empty for `n == 0`).
///
/// Primality is decided by trial division by every integer `d >= 2` with
/// `d * d <= candidate`.
fn first_primes(n: usize) -> Vec<usize> {
    let mut found = Vec::with_capacity(n);
    let mut candidate = 2usize;
    while found.len() < n {
        let has_divisor = (2usize..)
            .take_while(|&d| d * d <= candidate)
            .any(|d| candidate.is_multiple_of(d));
        if !has_divisor {
            found.push(candidate);
        }
        candidate += 1;
    }
    found
}
195
/// Builds the ordered, de-duplicated list of starting vectors for the outer
/// smoothing-parameter search.
///
/// Every returned vector has `num_penalties` entries. The trailing
/// `config.num_auxiliary_trailing` entries (capped at `num_penalties`) are
/// treated as auxiliary parameters: in every seed they are overwritten with
/// the clamped trailing values of `heuristic_rhos` when that slice has exactly
/// `num_penalties` entries, and with `0.0` otherwise.
///
/// Seeds are appended in a fixed priority order and the list is finally
/// truncated to `config.max_seeds` (treated as at least 1), so earlier
/// families survive truncation first:
///
/// 1. the primary anchor (the clamped full-length heuristic, else a constant
///    vector at the clamped `risk_profile.anchor_rho_shift()`);
/// 2. the all-zero vector;
/// 3. for `GeneralizedLinear` / `Survival`, the all-upper-bound vector;
/// 4. with exactly three smoothing penalties: a first-order anchor, nu = 2
///    seeds, first-order fallbacks and general SPDE-manifold seeds;
/// 5. isotropic baselines from `risk_profile.baseline_centers()`;
/// 6. cluster-conflict probes (`num_penalties >= 10`) or coordinate and pair
///    probes (fewer penalties);
/// 7. global shifts from `risk_profile.global_shifts()`;
/// 8. the optional absolute over-smoothing probe;
/// 9. up to 8 Halton-perturbed exploratory seeds, only while there is room
///    below `max_seeds`.
///
/// With `num_penalties == 0` the result is a single empty vector.
pub fn generate_rho_candidates(
    num_penalties: usize,
    heuristic_rhos: Option<&[f64]>,
    config: &SeedConfig,
) -> Vec<Array1<f64>> {
    let mut seeds = Vec::new();
    let mut seen: HashSet<Vec<u64>> = HashSet::new();

    let bounds = normalize_seed_bounds(config.bounds);
    let max_seeds = config.max_seeds.max(1);
    let risk_shift = config.risk_profile.anchor_rho_shift();

    if num_penalties == 0 {
        add_seed_dedup(&mut seeds, &mut seen, Array1::<f64>::zeros(0));
        return seeds;
    }

    // Prefer a full heuristic vector (length == k) as the primary anchor.
    // Values are already in the outer optimizer's rho/theta parameter space.
    let num_aux = config.num_auxiliary_trailing.min(num_penalties);
    let num_smoothing = num_penalties - num_aux;
    // Values the auxiliary trailing dimensions are pinned to in every seed.
    let aux_initial: Vec<f64> = if num_aux > 0 {
        heuristic_rhos
            .filter(|h| h.len() == num_penalties)
            .map(|h| {
                h[num_smoothing..]
                    .iter()
                    .copied()
                    .map(|v| clamp_seed_rho_to_bounds(v, bounds))
                    .collect()
            })
            .unwrap_or_else(|| vec![0.0; num_aux])
    } else {
        Vec::new()
    };
    // Both halves of the chain apply the same clamp, so this is the whole
    // heuristic vector clamped element-wise.
    let heuristic_rhovec: Option<Array1<f64>> = heuristic_rhos.and_then(|vals| {
        if vals.len() == num_penalties {
            Some(Array1::from_iter(
                vals[..num_smoothing]
                    .iter()
                    .copied()
                    .map(|v| clamp_seed_rho_to_bounds(v, bounds))
                    .chain(
                        vals[num_smoothing..]
                            .iter()
                            .copied()
                            .map(|v| clamp_seed_rho_to_bounds(v, bounds)),
                    ),
            ))
        } else {
            None
        }
    });

    let primary = heuristic_rhovec.clone().unwrap_or_else(|| {
        Array1::<f64>::from_elem(num_penalties, clamp_seed_rho_to_bounds(risk_shift, bounds))
    });
    add_seed_dedup(&mut seeds, &mut seen, primary.clone());
    // Always include neutral baseline independently of heuristic anchor.
    add_seed_dedup(&mut seeds, &mut seen, Array1::zeros(num_penalties));
    // Generalized and survival models can hit PIRLS separation at moderate
    // smoothing levels. Put an aggressively over-smoothed isotropic seed near
    // the front so startup validation can still find a stable basin.
    match config.risk_profile {
        SeedRiskProfile::Gaussian | SeedRiskProfile::GaussianLocationScale => {}
        SeedRiskProfile::GeneralizedLinear | SeedRiskProfile::Survival => {
            add_seed_dedup(
                &mut seeds,
                &mut seen,
                Array1::from_elem(num_penalties, bounds.1),
            );
        }
    }
    // For exactly three smoothing penalties (mass/tension/stiffness), inject
    // physically coherent manifold seeds in rho-space:
    // - general SPDE manifold over (log_tau, log_kappa, nu),
    // - nu=2 reverse-map seeds,
    // - first-order fallback seeds (lambda2 near lower bound).
    if num_smoothing == 3 {
        let smoothing_primary =
            Array1::from_vec(primary.iter().take(num_smoothing).copied().collect());
        // Despite the name these are still rho (log-scale) values; the helpers
        // exponentiate them. Unlike the primary anchor, a heuristic slice of
        // any length >= 3 is accepted here and its values are not clamped.
        let smoothing_heuristic_lambdas = heuristic_rhos.and_then(|vals| {
            if vals.len() >= num_smoothing {
                Some(&vals[..num_smoothing])
            } else {
                None
            }
        });
        // Triplets are collected with their own dedup set first, then widened
        // to full length (auxiliary tail filled in) and merged below.
        let mut spde_prefix_seeds = Vec::new();
        let mut spde_prefix_seen: HashSet<Vec<u64>> = HashSet::new();
        // Guarantee a first-order fallback anchor regardless of later truncation.
        add_seed_dedup(
            &mut spde_prefix_seeds,
            &mut spde_prefix_seen,
            Array1::from_vec(vec![primary[0], primary[1], bounds.0]),
        );
        // Ensure a nu=2-consistent seed is always present before broader grids.
        add_nu2_reverse_manifold_seeds(
            &mut spde_prefix_seeds,
            &mut spde_prefix_seen,
            bounds,
            &smoothing_primary,
        );
        add_first_order_fallback_seeds(
            &mut spde_prefix_seeds,
            &mut spde_prefix_seen,
            bounds,
            smoothing_heuristic_lambdas,
        );
        add_spde_manifold_seeds(
            &mut spde_prefix_seeds,
            &mut spde_prefix_seen,
            bounds,
            smoothing_heuristic_lambdas,
            &smoothing_primary,
        );
        for prefix_seed in spde_prefix_seeds {
            let mut seed = Array1::<f64>::zeros(num_penalties);
            for i in 0..num_smoothing {
                seed[i] = prefix_seed[i];
            }
            for (i, &v) in aux_initial.iter().enumerate() {
                seed[num_smoothing + i] = v;
            }
            add_seed_dedup(&mut seeds, &mut seen, seed);
        }
    }

    // Broad symmetric baselines around the center to guarantee global coverage.
    for &center in config.risk_profile.baseline_centers() {
        add_seed_dedup(
            &mut seeds,
            &mut seen,
            Array1::from_elem(num_penalties, clamp_seed_rho_to_bounds(center, bounds)),
        );
    }

    // Probe step size grows with dimension; at most the first 12 coordinates
    // receive individual probes.
    let dims_to_touch = num_penalties.min(12);
    let step_base = if num_penalties <= 4 {
        2.0
    } else if num_penalties <= 12 {
        2.5
    } else {
        3.0
    };
    let high_dim_cluster_threshold = 10usize;

    if num_penalties >= high_dim_cluster_threshold {
        // High-dimensional path: probe relative scaling conflicts by clustering
        // penalties into low/high heuristic-magnitude groups.
        let mut sorted_idx: Vec<usize> = (0..num_penalties).collect();
        sorted_idx.sort_by(|&i, &j| primary[i].total_cmp(&primary[j]));

        // Lowest and highest thirds (by primary value) of the sorted indices.
        let cluster_size = (num_penalties / 3).max(1);
        let small_end = cluster_size.min(num_penalties);
        let large_start = num_penalties.saturating_sub(cluster_size);
        let small_cluster = &sorted_idx[..small_end];
        let large_cluster = &sorted_idx[large_start..];

        let small_scale = step_base;
        let large_scale = step_base + 0.75;

        // Large cluster up, small cluster down.
        let mut conflict_a = primary.clone();
        for &i in large_cluster {
            conflict_a[i] = clamp_seed_rho_to_bounds(primary[i] + large_scale, bounds);
        }
        for &i in small_cluster {
            conflict_a[i] = clamp_seed_rho_to_bounds(primary[i] - small_scale, bounds);
        }
        add_seed_dedup(&mut seeds, &mut seen, conflict_a);

        // Large cluster down, small cluster up.
        let mut conflict_b = primary.clone();
        for &i in large_cluster {
            conflict_b[i] = clamp_seed_rho_to_bounds(primary[i] - large_scale, bounds);
        }
        for &i in small_cluster {
            conflict_b[i] = clamp_seed_rho_to_bounds(primary[i] + small_scale, bounds);
        }
        add_seed_dedup(&mut seeds, &mut seen, conflict_b);

        // Only the large cluster moves up.
        let mut heavy_up = primary.clone();
        for &i in large_cluster {
            heavy_up[i] = clamp_seed_rho_to_bounds(primary[i] + large_scale, bounds);
        }
        add_seed_dedup(&mut seeds, &mut seen, heavy_up);

        // Only the small cluster moves down.
        let mut light_down = primary.clone();
        for &i in small_cluster {
            light_down[i] = clamp_seed_rho_to_bounds(primary[i] - small_scale, bounds);
        }
        add_seed_dedup(&mut seeds, &mut seen, light_down);
    } else {
        // Low-dimensional path: coordinate and sparse pair probes are still cheap.
        for i in 0..dims_to_touch {
            // The step widens with the magnitude of the anchor (capped at |8|).
            let scale = step_base + 0.25 * primary[i].abs().min(8.0);
            for dir in [-1.0, 1.0] {
                let mut s = primary.clone();
                s[i] = clamp_seed_rho_to_bounds(primary[i] + dir * scale, bounds);
                add_seed_dedup(&mut seeds, &mut seen, s);
            }
        }

        // Opposite-direction moves on every pair among the first 6 coordinates.
        let pair_dims = num_penalties.min(6);
        for i in 0..pair_dims {
            for j in (i + 1)..pair_dims {
                let mut s1 = primary.clone();
                s1[i] = clamp_seed_rho_to_bounds(primary[i] + step_base, bounds);
                s1[j] = clamp_seed_rho_to_bounds(primary[j] - step_base, bounds);
                add_seed_dedup(&mut seeds, &mut seen, s1);

                let mut s2 = primary.clone();
                s2[i] = clamp_seed_rho_to_bounds(primary[i] - step_base, bounds);
                s2[j] = clamp_seed_rho_to_bounds(primary[j] + step_base, bounds);
                add_seed_dedup(&mut seeds, &mut seen, s2);
            }
        }
    }

    // Global shrink/expand sweeps from the anchor to probe over/under-smoothing regimes.
    // The flexible (negative-shift) side MUST be probed as densely as the
    // over-smoothing side: the seed-screening proxy is a capped-inner-iteration
    // fit, and an over-smoothed seed converges trivially under that cap (its
    // coefficients collapse into the penalty null space, the LAML is locally
    // flat), so screening systematically ranks over-smoothed seeds first
    // (documented in `rank_seeds_with_screening`). For a GeneralizedLinear /
    // Survival model whose true optimum is flexible (e.g. a smooth Poisson
    // tensor surface that genuinely needs ~10 effective df), a seed grid that
    // only sweeps the over-smoothing side leaves the flexible basin unprobed,
    // so none of the few full-budget solves ever lands in it and the fit
    // over-smooths (#1082/#1373). Symmetric negative shifts give the flexible
    // basin a candidate; the keep-best multi-start then retains it only if it
    // actually scores better, so this can never worsen a fit — it only lets the
    // optimizer SEE the lower-λ basin. Over-smoothed seeds remain present (and
    // earlier in the list) so PIRLS-separation startup stability is unchanged.
    for &shift in config.risk_profile.global_shifts() {
        let swept = primary.mapv(|v| clamp_seed_rho_to_bounds(v + shift, bounds));
        add_seed_dedup(&mut seeds, &mut seen, swept);
    }

    // #1464 over-smoothing probe: an ABSOLUTE high-λ start on every smoothing
    // dimension (auxiliary dims left at the anchor's values; they are re-pinned
    // below). The global shift sweeps above reach only ≈ +4 from the anchor, so
    // a collapsing-kernel smooth whose true REML optimum is a large λ would never
    // be seeded into its over-smoothing basin. This puts a candidate IN it; the
    // keep-best multistart adopts it only when it scores strictly better, so it
    // can never worsen a fit. `None` (the default) skips this entirely.
    if let Some(probe_rho) = config.over_smoothing_probe_rho {
        let mut probe = primary.clone();
        for j in 0..num_smoothing {
            probe[j] = clamp_seed_rho_to_bounds(probe_rho, bounds);
        }
        add_seed_dedup(&mut seeds, &mut seen, probe);
    }

    // Low-discrepancy exploratory seeds around the anchor for basin discovery.
    // These are still deterministic and do not encode any solver-side bias.
    // Only the room left below `max_seeds` is used, and never more than 8.
    let exploratory = max_seeds.saturating_sub(seeds.len()).min(8);
    if exploratory > 0 {
        // One distinct prime base per coordinate.
        let primes = first_primes(num_penalties.max(1));
        let amp = config.risk_profile.exploratory_amplitude();
        for t in 0..exploratory {
            let mut s = primary.clone();
            for i in 0..num_penalties {
                let u = halton(t + 1, primes[i]); // (0,1)
                let centered = 2.0 * u - 1.0; // (-1,1)
                s[i] = clamp_seed_rho_to_bounds(primary[i] + amp * centered, bounds);
            }
            add_seed_dedup(&mut seeds, &mut seen, s);
        }
    }

    // Pin auxiliary trailing dimensions to their initial values in every seed.
    // Auxiliary params (e.g. SAS epsilon, log_delta) live in a different
    // parameter space than log-smoothing rho and must not be swept by the
    // smoothing seeding grid.  After pinning we re-dedup because seeds that
    // differed only in the (now-overwritten) auxiliary dimensions collapse.
    if num_aux > 0 {
        for seed in &mut seeds {
            for (i, &v) in aux_initial.iter().enumerate() {
                seed[num_smoothing + i] = v;
            }
        }
        // Order-preserving second dedup pass: the first occurrence wins.
        let mut deduped = Vec::new();
        let mut seen2: HashSet<Vec<u64>> = HashSet::new();
        for seed in seeds {
            let key: Vec<u64> = seed.iter().map(|&v| v.to_bits()).collect();
            if seen2.insert(key) {
                deduped.push(seed);
            }
        }
        seeds = deduped;
    }

    // Keep only the highest-priority (earliest) seeds.
    if seeds.len() > max_seeds {
        seeds.truncate(max_seeds);
    }

    // Defensive guard: the primary anchor is always inserted above and
    // `max_seeds` is at least 1, so the list should not be empty here.
    if seeds.is_empty() {
        seeds.push(Array1::<f64>::zeros(num_penalties));
    }

    seeds
}
499
500/// Choose an initial log-smoothing vector by evaluating the same objective the
501/// outer optimizer will minimize on a small deterministic grid around the
502/// analytic/heuristic seed.
503///
504/// This is initialization, not a fallback: no candidate is accepted unless it
505/// has a lower finite objective value under `eval_cost`, and the returned seed
506/// is still optimized by the normal outer solver.
507pub fn select_objective_seed_on_log_lambda_grid<F>(
508    rho_seed: &Array1<f64>,
509    bounds: (f64, f64),
510    n_smooths: usize,
511    nullspace_coords: &[usize],
512    mut eval_cost: F,
513) -> Array1<f64>
514where
515    F: FnMut(&Array1<f64>) -> Option<f64>,
516{
517    let k = rho_seed.len();
518    if k == 0 || n_smooths == 0 || n_smooths > k {
519        return rho_seed.clone();
520    }
521    let bnds = normalize_seed_bounds(bounds);
522    let clamp_vec = |v: &Array1<f64>| -> Array1<f64> {
523        let mut out = v.clone();
524        for i in 0..n_smooths {
525            out[i] = clamp_seed_rho_to_bounds(out[i], bnds);
526        }
527        out
528    };
529
530    let baseline_seed = clamp_vec(rho_seed);
531    let baseline_cost = eval_cost(&baseline_seed);
532    log::info!(
533        "[SEED-GRID] baseline rho=[{}] cost={}",
534        baseline_seed
535            .iter()
536            .map(|v| format!("{v:.2}"))
537            .collect::<Vec<_>>()
538            .join(","),
539        baseline_cost
540            .map(|c| format!("{c:.6e}"))
541            .unwrap_or_else(|| "non-finite".to_string()),
542    );
543
544    let shifts: [f64; 9] = [-12.0, -9.0, -6.0, -3.0, 0.0, 3.0, 6.0, 9.0, 12.0];
545    let mut best_seed = baseline_seed.clone();
546    let mut best_cost: Option<f64> = baseline_cost.filter(|c| c.is_finite());
547
548    for &delta in &shifts {
549        if delta == 0.0 && best_cost.is_some() {
550            continue;
551        }
552        let mut candidate = rho_seed.clone();
553        for i in 0..n_smooths {
554            candidate[i] = clamp_seed_rho_to_bounds(rho_seed[i] + delta, bnds);
555        }
556        let c_opt = eval_cost(&candidate);
557        log::info!(
558            "[SEED-GRID] shift={:+.1} rho=[{}] cost={}",
559            delta,
560            candidate
561                .iter()
562                .map(|v| format!("{v:.2}"))
563                .collect::<Vec<_>>()
564                .join(","),
565            c_opt
566                .map(|c| format!("{c:.6e}"))
567                .unwrap_or_else(|| "non-finite".to_string()),
568        );
569        if let Some(c) = c_opt
570            && c.is_finite()
571            && best_cost.map(|b| c < b).unwrap_or(true)
572        {
573            best_cost = Some(c);
574            best_seed = candidate;
575        }
576    }
577
578    if n_smooths <= 6 {
579        // Per-axis refinement around the best isotropic point. The ±3 steps
580        // resolve a mild per-coordinate imbalance; the explicit saturation
581        // target (`bnds.1`, the over-smoothing upper bound) reaches asymmetric
582        // corners where selected penalty blocks are fully active while the
583        // others stay at the refined anchor. Those corners are load-bearing for
584        // double-penalty (Marra-Wood null-space shrinkage) smooths (#1266): an
585        // unsupported term must be allowed to send BOTH its wiggliness and
586        // null-space coordinates high, while a supported sibling term remains
587        // free. The isotropic grid only moves all coordinates together, so it
588        // cannot express "shrink s(z), keep s(x)". Probing these corners is
589        // criterion-ranked — a candidate is adopted only when it strictly lowers
590        // the true REML/LAML cost — so a genuinely better interior optimum or
591        // supported smooth simply wins the comparison.
592        let saturation = clamp_seed_rho_to_bounds(bnds.1, bnds);
593        // Lower-saturation ("keep") corner, the symmetric dual of `saturation`.
594        // The per-axis sweep above probes the over-smoothing/shrink-out corner
595        // (`bnds.1`) so an unsupported double-penalty null-space coordinate can
596        // rail its λ_null up and select the term out (#1266). The MISSING corner
597        // is the opposite one: a SUPPORTED null space (a genuine linear/constant
598        // trend the data buy) has its global REML optimum at a LOW λ_null "keep"
599        // basin, separated from the high-λ_null annihilation shelf by a flat
600        // valley. Without a keep-direction probe the grid can seed only the shelf
601        // corner, leaving the outer optimizer to cross that flat valley to reach
602        // the keep basin — a crossing whose success rode on sub-ULP gradient signs
603        // that a covariate reflection x→−x flips, so the mirror fit stalled on the
604        // shelf and annihilated the supported trend (#1548). Probing the keep
605        // corner for EXACTLY the null-space coordinates (where un-shrinking is
606        // safe — the wiggliness penalty stays active, so there is no λ→0
607        // inner-cap overfit artifact) lets the grid seed the well-conditioned keep
608        // basin directly. It is criterion-ranked like every other probe: an
609        // unsupported term's keep corner is never cheaper than its shrink-out
610        // corner, so #1266 is untouched.
611        let keep_saturation = clamp_seed_rho_to_bounds(bnds.0, bnds);
612        for axis in 0..n_smooths {
613            let anchor = best_seed.clone();
614            let mut targets = vec![
615                clamp_seed_rho_to_bounds(anchor[axis] - 3.0, bnds),
616                clamp_seed_rho_to_bounds(anchor[axis] + 3.0, bnds),
617            ];
618            if (anchor[axis] - saturation).abs() > 1e-9 {
619                targets.push(saturation);
620            }
621            if nullspace_coords.contains(&axis) {
622                // Step toward the keep basin (a moderate un-shrink) and the full
623                // keep saturation, so the probe reaches the basin wherever it sits
624                // between the anchor and λ_null → 0.
625                targets.push(clamp_seed_rho_to_bounds(anchor[axis] - 6.0, bnds));
626                if (anchor[axis] - keep_saturation).abs() > 1e-9 {
627                    targets.push(keep_saturation);
628                }
629            }
630            for target in targets {
631                let mut candidate = anchor.clone();
632                candidate[axis] = target;
633                if let Some(c) = eval_cost(&candidate)
634                    && c.is_finite()
635                    && best_cost.map(|b| c < b).unwrap_or(true)
636                {
637                    best_cost = Some(c);
638                    best_seed = candidate;
639                }
640            }
641        }
642        for start in 0..n_smooths.saturating_sub(1) {
643            let anchor = best_seed.clone();
644            let mut candidate = anchor;
645            candidate[start] = saturation;
646            candidate[start + 1] = saturation;
647            if let Some(c) = eval_cost(&candidate)
648                && c.is_finite()
649                && best_cost.map(|b| c < b).unwrap_or(true)
650            {
651                best_cost = Some(c);
652                best_seed = candidate;
653            }
654        }
655    }
656
657    best_seed
658}
659
660#[cfg(test)]
661mod tests {
662    use super::*;
663
664    #[test]
665    fn uses_full_heuristicvector_as_primary_anchor() {
666        let cfg = SeedConfig {
667            risk_profile: SeedRiskProfile::Gaussian,
668            ..SeedConfig::default()
669        };
670        let heur = [-2.0, 0.0, 2.0];
671        let seeds = generate_rho_candidates(3, Some(&heur), &cfg);
672        assert!(!seeds.is_empty());
673        let first = &seeds[0];
674        assert_eq!(first.len(), 3);
675        assert!((first[0] - heur[0]).abs() < 1e-12);
676        assert!((first[1] - heur[1]).abs() < 1e-12);
677        assert!((first[2] - heur[2]).abs() < 1e-12);
678    }
679
680    #[test]
681    fn high_dim_uses_cluster_conflict_probeswithout_exploding() {
682        let cfg = SeedConfig {
683            max_seeds: 18,
684            risk_profile: SeedRiskProfile::GeneralizedLinear,
685            ..SeedConfig::default()
686        };
687        let heur = [-6.0, -5.0, -4.0, 0.0, 2.0, 4.0, -3.0, 0.0, 3.0, 5.0];
688        let seeds = generate_rho_candidates(10, Some(&heur), &cfg);
689        assert!(seeds.len() <= 18);
690        // Presence of at least one asymmetric cluster-conflict seed:
691        // some coordinates increased while others decreased vs primary.
692        let primary = &seeds[0];
693        let has_conflict = seeds.iter().skip(1).any(|s| {
694            let mut any_up = false;
695            let mut any_down = false;
696            for i in 0..s.len() {
697                if s[i] > primary[i] {
698                    any_up = true;
699                } else if s[i] < primary[i] {
700                    any_down = true;
701                }
702            }
703            any_up && any_down
704        });
705        assert!(has_conflict);
706    }
707
708    #[test]
709    fn includes_neutralzero_seed() {
710        let cfg = SeedConfig::default();
711        let seeds = generate_rho_candidates(5, None, &cfg);
712        let haszero = seeds
713            .iter()
714            .any(|s| s.iter().all(|v| (*v - 0.0).abs() < 1e-12));
715        assert!(haszero);
716    }
717
718    #[test]
719    fn generalized_linear_seeds_include_early_stability_retreat_seed() {
720        let cfg = SeedConfig {
721            risk_profile: SeedRiskProfile::GeneralizedLinear,
722            ..SeedConfig::default()
723        };
724        let seeds = generate_rho_candidates(3, None, &cfg);
725        let retreat = Array1::from_elem(3, cfg.bounds.1);
726        let retreat_idx = seeds
727            .iter()
728            .position(|seed| seed == retreat)
729            .expect("generalized-linear seeds should include an upper-bound retreat seed");
730        assert!(
731            retreat_idx <= 2,
732            "retreat seed should be available before broader exploratory seeds: {retreat_idx}"
733        );
734    }
735
736    #[test]
737    fn objective_grid_can_seed_adjacent_pair_oversmoothing_corner() {
738        let base = Array1::zeros(4);
739        let selected =
740            select_objective_seed_on_log_lambda_grid(&base, (-12.0, 12.0), 4, &[], |rho| {
741                let supported_cost = 0.1 * (rho[0].powi(2) + rho[1].powi(2));
742                let unsupported_gap = (rho[2] - 12.0).powi(2) + (rho[3] - 12.0).powi(2);
743                Some(supported_cost + unsupported_gap)
744            });
745        assert_eq!(selected.to_vec(), vec![0.0, 0.0, 12.0, 12.0]);
746    }
747
/// With default settings and three penalties, at least one candidate must
/// satisfy `2*rho1 - rho0 - rho2 == ln(4)` to within 1e-8.
#[test]
fn three_penalty_seeds_include_nu2_reverse_manifold_triplets() {
    let config = SeedConfig::default();
    let candidates = generate_rho_candidates(3, None, &config);
    let target = 4.0_f64.ln();

    // Only length-3 seeds are eligible; test each one against the target.
    let found = candidates
        .iter()
        .filter(|seed| seed.len() == 3)
        .any(|seed| {
            let second_difference = 2.0 * seed[1] - seed[0] - seed[2];
            (second_difference - target).abs() < 1e-8
        });
    assert!(found);
}
758
/// With a heuristic supplied, at least one length-3 candidate must have
/// `2*rho1 - rho0 - rho2` more than 1e-3 away from `ln(4)`.
///
/// NOTE(review): the predicate accepts any seed off the `ln(4)` value, not
/// only seeds that lie on a general-nu manifold — confirm this is strict
/// enough for the intent expressed in the test name.
#[test]
fn three_penalty_seeds_include_general_spde_manifold_points() {
    let config = SeedConfig::default();
    let heuristic = [2.0, 10.0, 3.0];
    let candidates = generate_rho_candidates(3, Some(&heuristic), &config);

    // For nu = 2 the second difference equals ln(4); look for a seed that
    // departs from that value.
    let nu2_value = 4.0_f64.ln();
    let found_off_nu2 = candidates
        .iter()
        .filter(|seed| seed.len() == 3)
        .any(|seed| {
            let second_difference = 2.0 * seed[1] - seed[0] - seed[2];
            (second_difference - nu2_value).abs() > 1e-3
        });
    assert!(found_off_nu2);
}
771
/// With bounds fixed at (-12, 12), some length-3 candidate must have its
/// third coordinate at the lower bound (-12, within 1e-12).
#[test]
fn three_penalty_seeds_include_first_order_fallbackwith_rho2_floor() {
    let config = SeedConfig {
        bounds: (-12.0, 12.0),
        ..SeedConfig::default()
    };
    let candidates = generate_rho_candidates(3, None, &config);

    let third_coord_on_floor = candidates
        .iter()
        .filter(|seed| seed.len() == 3)
        .any(|seed| (seed[2] + 12.0).abs() < 1e-12);
    assert!(third_coord_on_floor);
}
784
/// When the last two of four dimensions are declared auxiliary, every
/// generated seed must keep those two entries at the heuristic's values
/// (both 0 here), while the leading two entries vary across seeds.
#[test]
fn auxiliary_trailing_dims_pinned_to_initial_values() {
    // Mimics a SAS-style setup: two smoothing dims followed by two auxiliary
    // dims (epsilon = 0, log_delta = 0). The heuristic carries values for
    // both the smoothing and the auxiliary coordinates.
    let heuristic = [0.0, 10.0_f64.ln(), 0.0, 0.0];
    let config = SeedConfig {
        risk_profile: SeedRiskProfile::GeneralizedLinear,
        num_auxiliary_trailing: 2,
        ..SeedConfig::default()
    };
    let candidates = generate_rho_candidates(4, Some(&heuristic), &config);
    assert!(!candidates.is_empty());

    // No exceptions allowed: each seed's trailing pair must be (0, 0).
    for (idx, seed) in candidates.iter().enumerate() {
        assert_eq!(seed.len(), 4);
        let (aux_first, aux_second) = (seed[2], seed[3]);
        let pinned = aux_first.abs() < 1e-12 && aux_second.abs() < 1e-12;
        assert!(
            pinned,
            "seed {} has auxiliary dims [{}, {}], expected [0, 0]",
            idx,
            aux_first,
            aux_second,
        );
    }

    // At least one seed must move a smoothing coordinate away from zero.
    let smoothing_varies = candidates
        .iter()
        .any(|seed| seed.iter().take(2).any(|value| value.abs() > 1e-12));
    assert!(smoothing_varies);
}
815
/// Declaring one trailing auxiliary dimension must never yield more seeds
/// than the same configuration with no auxiliary dimensions.
#[test]
fn auxiliary_dims_dedup_collapses_identical_seeds() {
    // Number of candidates produced for three dimensions when `aux_dims`
    // trailing entries are treated as auxiliary; all other settings fixed.
    let seed_count = |aux_dims| {
        let config = SeedConfig {
            num_auxiliary_trailing: aux_dims,
            max_seeds: 32,
            risk_profile: SeedRiskProfile::GeneralizedLinear,
            ..SeedConfig::default()
        };
        generate_rho_candidates(3, None, &config).len()
    };

    let count_with_aux = seed_count(1);
    let count_without_aux = seed_count(0);

    // Seeds that differ only in the pinned coordinate are expected to merge,
    // so the pinned run can match but not exceed the unpinned run.
    assert!(count_with_aux <= count_without_aux);
}
837
/// For a cost that is the squared distance to (6, 6), the grid search started
/// at the origin must return (6, 6) to within 1e-12 per coordinate.
#[test]
fn objective_grid_seed_selects_lowest_finite_cost_candidate() {
    let origin = Array1::zeros(2);
    let picked = select_objective_seed_on_log_lambda_grid(
        &origin,
        (-12.0, 12.0),
        2,
        &[],
        |point| {
            let dx = point[0] - 6.0;
            let dy = point[1] - 6.0;
            Some(dx.powi(2) + dy.powi(2))
        },
    );

    for axis in 0..2 {
        assert!((picked[axis] - 6.0).abs() < 1e-12);
    }
}
849
/// When the baseline point is the unique minimiser of the cost (0 there, 1
/// everywhere else), the grid search must hand the baseline back unchanged.
#[test]
fn objective_grid_seed_keeps_baseline_when_no_candidate_improves_cost() {
    let baseline = Array1::from_vec(vec![1.0, -2.0]);
    let picked = select_objective_seed_on_log_lambda_grid(
        &baseline,
        (-12.0, 12.0),
        2,
        &[],
        |point| {
            // True only at (1, -2), up to a 1e-12 tolerance per coordinate.
            let at_baseline =
                (point[0] - 1.0).abs() < 1e-12 && (point[1] - (-2.0)).abs() < 1e-12;
            Some(if at_baseline { 0.0 } else { 1.0 })
        },
    );

    assert_eq!(picked, baseline);
}
864}