1use ndarray::Array1;
2use std::collections::HashSet;
3
4pub use gam_problem::{SeedConfig, SeedRiskProfile};
5use gam_problem::{clamp_seed_rho_to_bounds, normalize_seed_bounds};
6
7fn add_seed_dedup(seeds: &mut Vec<Array1<f64>>, seen: &mut HashSet<Vec<u64>>, seed: Array1<f64>) {
8 let key: Vec<u64> = seed.iter().map(|&v| v.to_bits()).collect();
9 if seen.insert(key) {
10 seeds.push(seed);
11 }
12}
13
/// Natural logarithm of `x` when `x` is finite and strictly positive.
///
/// Returns `None` for zero, negative values, infinities and NaN.
fn safe_ln_pos(x: f64) -> Option<f64> {
    let usable = x.is_finite() && x > 0.0;
    usable.then(|| x.ln())
}
21
22fn spde_rho_triplet_from_log_tau_log_kappa_nu(
23 log_tau: f64,
24 log_kappa: f64,
25 nu: f64,
26 bounds: (f64, f64),
27) -> Option<Array1<f64>> {
28 if !(nu.is_finite() && nu > 1.0) {
29 return None;
30 }
31 let logc0 = 0.0;
32 let logc1 = safe_ln_pos(nu)?;
33 let logc2 = safe_ln_pos(0.5 * nu * (nu - 1.0))?;
34 let rho0 = clamp_seed_rho_to_bounds(log_tau + logc0 + 2.0 * nu * log_kappa, bounds);
35 let rho1 = clamp_seed_rho_to_bounds(log_tau + logc1 + 2.0 * (nu - 1.0) * log_kappa, bounds);
36 let rho2 = clamp_seed_rho_to_bounds(log_tau + logc2 + 2.0 * (nu - 2.0) * log_kappa, bounds);
37 Some(Array1::from_vec(vec![rho0, rho1, rho2]))
38}
39
40fn add_spde_manifold_seeds(
41 seeds: &mut Vec<Array1<f64>>,
42 seen: &mut HashSet<Vec<u64>>,
43 bounds: (f64, f64),
44 heuristic_rhos: Option<&[f64]>,
45 primary: &Array1<f64>,
46) {
47 if primary.len() != 3 {
48 return;
49 }
50 let tau_anchors = [primary[2], 0.0, -2.0, 2.0];
52 let log_kappa_grid = [-2.0, -1.0, 0.0, 1.0, 2.0];
53 let nu_grid = [1.25, 1.5, 2.0, 2.5, 3.0, 4.0];
54 for &tau in &tau_anchors {
55 for &lk in &log_kappa_grid {
56 for &nu in &nu_grid {
57 if let Some(seed) = spde_rho_triplet_from_log_tau_log_kappa_nu(tau, lk, nu, bounds)
58 {
59 add_seed_dedup(seeds, seen, seed);
60 }
61 }
62 }
63 }
64
65 if let Some(vals) = heuristic_rhos
68 && vals.len() == 3
69 {
70 let l0 = vals[0].exp();
71 let l1 = vals[1].exp();
72 let l2 = vals[2].exp();
73 if l0.is_finite() && l1.is_finite() && l2.is_finite() && l0 > 1e-12 && l2 > 1e-12 {
74 let r = (l1 * l1) / (l0 * l2);
75 if r > 2.0 {
76 let nu = r / (r - 2.0);
77 let kappa2 = l1 / ((r - 2.0) * l2);
78 if nu.is_finite() && nu > 1.0 && kappa2.is_finite() && kappa2 > 0.0 {
79 let log_kappa = 0.5 * kappa2.ln();
80 let c2 = 0.5 * nu * (nu - 1.0);
81 if c2.is_finite() && c2 > 0.0 {
82 let log_tau = (l2 / (c2 * kappa2.powf(nu - 2.0))).max(1e-12).ln();
83 let local_nu = [nu, (nu - 0.3).max(1.05), nu + 0.3];
84 let local_tau = [log_tau, log_tau - 1.0, log_tau + 1.0];
85 let local_kappa = [log_kappa, log_kappa - 0.5, log_kappa + 0.5];
86 for &t in &local_tau {
87 for &lk in &local_kappa {
88 for &n in &local_nu {
89 if let Some(seed) =
90 spde_rho_triplet_from_log_tau_log_kappa_nu(t, lk, n, bounds)
91 {
92 add_seed_dedup(seeds, seen, seed);
93 }
94 }
95 }
96 }
97 }
98 }
99 }
100 }
101 }
102}
103
104fn add_first_order_fallback_seeds(
105 seeds: &mut Vec<Array1<f64>>,
106 seen: &mut HashSet<Vec<u64>>,
107 bounds: (f64, f64),
108 heuristic_rhos: Option<&[f64]>,
109) {
110 let rho2_floor = bounds.0;
113 let default_log_kappa = [-2.0, -1.0, 0.0, 1.0];
114 let default_log_tau = [0.0, -2.0, 2.0];
115 for &t in &default_log_tau {
116 for &lk in &default_log_kappa {
117 let rho0 = clamp_seed_rho_to_bounds(t + 2.0 * lk, bounds);
118 let rho1 = clamp_seed_rho_to_bounds(t, bounds);
119 add_seed_dedup(seeds, seen, Array1::from_vec(vec![rho0, rho1, rho2_floor]));
120 }
121 }
122 if let Some(vals) = heuristic_rhos
123 && vals.len() == 3
124 && vals[0].is_finite()
125 && vals[1].is_finite()
126 {
127 let l0 = vals[0].exp();
128 let l1 = vals[1].exp();
129 let kappa2 = l0 / l1;
130 if kappa2.is_finite() && kappa2 > 0.0 {
131 let lk = 0.5 * kappa2.ln();
132 let t = vals[1];
133 let rho0 = clamp_seed_rho_to_bounds(t + 2.0 * lk, bounds);
134 let rho1 = clamp_seed_rho_to_bounds(t, bounds);
135 add_seed_dedup(seeds, seen, Array1::from_vec(vec![rho0, rho1, rho2_floor]));
136 }
137 }
138}
139
140fn add_nu2_reverse_manifold_seeds(
141 seeds: &mut Vec<Array1<f64>>,
142 seen: &mut HashSet<Vec<u64>>,
143 bounds: (f64, f64),
144 primary: &Array1<f64>,
145) {
146 if primary.len() != 3 {
147 return;
148 }
149 let ln_two = 2.0_f64.ln();
150 let tau_anchors = [primary[2], 0.0, -2.0, 2.0];
151 let log_kappa_grid = [-2.0, -1.0, 0.0, 1.0, 2.0];
152 for &tau_rho in &tau_anchors {
153 for &log_kappa in &log_kappa_grid {
154 let rho2 = clamp_seed_rho_to_bounds(tau_rho, bounds);
157 let rho1 = clamp_seed_rho_to_bounds(tau_rho + ln_two + 2.0 * log_kappa, bounds);
158 let rho0 = clamp_seed_rho_to_bounds(tau_rho + 4.0 * log_kappa, bounds);
159 add_seed_dedup(seeds, seen, Array1::from_vec(vec![rho0, rho1, rho2]));
160 }
161 }
162}
163
/// Van der Corput radical inverse of `index` in `base` (one coordinate of a Halton point).
///
/// The base-`base` digits of `index`, least significant first, are mirrored about the
/// radix point: digit `k` contributes `d_k / base^(k+1)`. `index == 0` maps to `0.0`.
fn halton(index: usize, base: usize) -> f64 {
    let radix = base as f64;
    let digits = std::iter::successors((index > 0).then_some(index), |&rest| {
        let next = rest / base;
        (next > 0).then_some(next)
    })
    .map(|rest| rest % base);
    digits
        .fold((0.0_f64, 1.0_f64), |(acc, scale), digit| {
            let scale = scale / radix;
            (acc + scale * digit as f64, scale)
        })
        .0
}
174
/// Returns the first `n` prime numbers in increasing order (`2, 3, 5, ...`).
///
/// Each candidate is trial-divided only by the primes already found whose square does
/// not exceed it; that suffices because every composite has a prime factor no larger
/// than its square root.
fn first_primes(n: usize) -> Vec<usize> {
    let mut found: Vec<usize> = Vec::with_capacity(n);
    let mut candidate = 2usize;
    while found.len() < n {
        let has_small_factor = found
            .iter()
            .take_while(|&&p| p * p <= candidate)
            .any(|&p| candidate.is_multiple_of(p));
        if !has_small_factor {
            found.push(candidate);
        }
        candidate += 1;
    }
    found
}
195
196pub fn generate_rho_candidates(
197 num_penalties: usize,
198 heuristic_rhos: Option<&[f64]>,
199 config: &SeedConfig,
200) -> Vec<Array1<f64>> {
201 let mut seeds = Vec::new();
202 let mut seen: HashSet<Vec<u64>> = HashSet::new();
203
204 let bounds = normalize_seed_bounds(config.bounds);
205 let max_seeds = config.max_seeds.max(1);
206 let risk_shift = config.risk_profile.anchor_rho_shift();
207
208 if num_penalties == 0 {
209 add_seed_dedup(&mut seeds, &mut seen, Array1::<f64>::zeros(0));
210 return seeds;
211 }
212
213 let num_aux = config.num_auxiliary_trailing.min(num_penalties);
216 let num_smoothing = num_penalties - num_aux;
217 let aux_initial: Vec<f64> = if num_aux > 0 {
218 heuristic_rhos
219 .filter(|h| h.len() == num_penalties)
220 .map(|h| {
221 h[num_smoothing..]
222 .iter()
223 .copied()
224 .map(|v| clamp_seed_rho_to_bounds(v, bounds))
225 .collect()
226 })
227 .unwrap_or_else(|| vec![0.0; num_aux])
228 } else {
229 Vec::new()
230 };
231 let heuristic_rhovec: Option<Array1<f64>> = heuristic_rhos.and_then(|vals| {
232 if vals.len() == num_penalties {
233 Some(Array1::from_iter(
234 vals[..num_smoothing]
235 .iter()
236 .copied()
237 .map(|v| clamp_seed_rho_to_bounds(v, bounds))
238 .chain(
239 vals[num_smoothing..]
240 .iter()
241 .copied()
242 .map(|v| clamp_seed_rho_to_bounds(v, bounds)),
243 ),
244 ))
245 } else {
246 None
247 }
248 });
249
250 let primary = heuristic_rhovec.clone().unwrap_or_else(|| {
251 Array1::<f64>::from_elem(num_penalties, clamp_seed_rho_to_bounds(risk_shift, bounds))
252 });
253 add_seed_dedup(&mut seeds, &mut seen, primary.clone());
254 add_seed_dedup(&mut seeds, &mut seen, Array1::zeros(num_penalties));
256 match config.risk_profile {
260 SeedRiskProfile::Gaussian | SeedRiskProfile::GaussianLocationScale => {}
261 SeedRiskProfile::GeneralizedLinear | SeedRiskProfile::Survival => {
262 add_seed_dedup(
263 &mut seeds,
264 &mut seen,
265 Array1::from_elem(num_penalties, bounds.1),
266 );
267 }
268 }
269 if num_smoothing == 3 {
275 let smoothing_primary =
276 Array1::from_vec(primary.iter().take(num_smoothing).copied().collect());
277 let smoothing_heuristic_lambdas = heuristic_rhos.and_then(|vals| {
278 if vals.len() >= num_smoothing {
279 Some(&vals[..num_smoothing])
280 } else {
281 None
282 }
283 });
284 let mut spde_prefix_seeds = Vec::new();
285 let mut spde_prefix_seen: HashSet<Vec<u64>> = HashSet::new();
286 add_seed_dedup(
288 &mut spde_prefix_seeds,
289 &mut spde_prefix_seen,
290 Array1::from_vec(vec![primary[0], primary[1], bounds.0]),
291 );
292 add_nu2_reverse_manifold_seeds(
294 &mut spde_prefix_seeds,
295 &mut spde_prefix_seen,
296 bounds,
297 &smoothing_primary,
298 );
299 add_first_order_fallback_seeds(
300 &mut spde_prefix_seeds,
301 &mut spde_prefix_seen,
302 bounds,
303 smoothing_heuristic_lambdas,
304 );
305 add_spde_manifold_seeds(
306 &mut spde_prefix_seeds,
307 &mut spde_prefix_seen,
308 bounds,
309 smoothing_heuristic_lambdas,
310 &smoothing_primary,
311 );
312 for prefix_seed in spde_prefix_seeds {
313 let mut seed = Array1::<f64>::zeros(num_penalties);
314 for i in 0..num_smoothing {
315 seed[i] = prefix_seed[i];
316 }
317 for (i, &v) in aux_initial.iter().enumerate() {
318 seed[num_smoothing + i] = v;
319 }
320 add_seed_dedup(&mut seeds, &mut seen, seed);
321 }
322 }
323
324 for ¢er in config.risk_profile.baseline_centers() {
326 add_seed_dedup(
327 &mut seeds,
328 &mut seen,
329 Array1::from_elem(num_penalties, clamp_seed_rho_to_bounds(center, bounds)),
330 );
331 }
332
333 let dims_to_touch = num_penalties.min(12);
334 let step_base = if num_penalties <= 4 {
335 2.0
336 } else if num_penalties <= 12 {
337 2.5
338 } else {
339 3.0
340 };
341 let high_dim_cluster_threshold = 10usize;
342
343 if num_penalties >= high_dim_cluster_threshold {
344 let mut sorted_idx: Vec<usize> = (0..num_penalties).collect();
347 sorted_idx.sort_by(|&i, &j| primary[i].total_cmp(&primary[j]));
348
349 let cluster_size = (num_penalties / 3).max(1);
350 let small_end = cluster_size.min(num_penalties);
351 let large_start = num_penalties.saturating_sub(cluster_size);
352 let small_cluster = &sorted_idx[..small_end];
353 let large_cluster = &sorted_idx[large_start..];
354
355 let small_scale = step_base;
356 let large_scale = step_base + 0.75;
357
358 let mut conflict_a = primary.clone();
359 for &i in large_cluster {
360 conflict_a[i] = clamp_seed_rho_to_bounds(primary[i] + large_scale, bounds);
361 }
362 for &i in small_cluster {
363 conflict_a[i] = clamp_seed_rho_to_bounds(primary[i] - small_scale, bounds);
364 }
365 add_seed_dedup(&mut seeds, &mut seen, conflict_a);
366
367 let mut conflict_b = primary.clone();
368 for &i in large_cluster {
369 conflict_b[i] = clamp_seed_rho_to_bounds(primary[i] - large_scale, bounds);
370 }
371 for &i in small_cluster {
372 conflict_b[i] = clamp_seed_rho_to_bounds(primary[i] + small_scale, bounds);
373 }
374 add_seed_dedup(&mut seeds, &mut seen, conflict_b);
375
376 let mut heavy_up = primary.clone();
377 for &i in large_cluster {
378 heavy_up[i] = clamp_seed_rho_to_bounds(primary[i] + large_scale, bounds);
379 }
380 add_seed_dedup(&mut seeds, &mut seen, heavy_up);
381
382 let mut light_down = primary.clone();
383 for &i in small_cluster {
384 light_down[i] = clamp_seed_rho_to_bounds(primary[i] - small_scale, bounds);
385 }
386 add_seed_dedup(&mut seeds, &mut seen, light_down);
387 } else {
388 for i in 0..dims_to_touch {
390 let scale = step_base + 0.25 * primary[i].abs().min(8.0);
391 for dir in [-1.0, 1.0] {
392 let mut s = primary.clone();
393 s[i] = clamp_seed_rho_to_bounds(primary[i] + dir * scale, bounds);
394 add_seed_dedup(&mut seeds, &mut seen, s);
395 }
396 }
397
398 let pair_dims = num_penalties.min(6);
399 for i in 0..pair_dims {
400 for j in (i + 1)..pair_dims {
401 let mut s1 = primary.clone();
402 s1[i] = clamp_seed_rho_to_bounds(primary[i] + step_base, bounds);
403 s1[j] = clamp_seed_rho_to_bounds(primary[j] - step_base, bounds);
404 add_seed_dedup(&mut seeds, &mut seen, s1);
405
406 let mut s2 = primary.clone();
407 s2[i] = clamp_seed_rho_to_bounds(primary[i] - step_base, bounds);
408 s2[j] = clamp_seed_rho_to_bounds(primary[j] + step_base, bounds);
409 add_seed_dedup(&mut seeds, &mut seen, s2);
410 }
411 }
412 }
413
414 for &shift in config.risk_profile.global_shifts() {
431 let swept = primary.mapv(|v| clamp_seed_rho_to_bounds(v + shift, bounds));
432 add_seed_dedup(&mut seeds, &mut seen, swept);
433 }
434
435 if let Some(probe_rho) = config.over_smoothing_probe_rho {
443 let mut probe = primary.clone();
444 for j in 0..num_smoothing {
445 probe[j] = clamp_seed_rho_to_bounds(probe_rho, bounds);
446 }
447 add_seed_dedup(&mut seeds, &mut seen, probe);
448 }
449
450 let exploratory = max_seeds.saturating_sub(seeds.len()).min(8);
453 if exploratory > 0 {
454 let primes = first_primes(num_penalties.max(1));
455 let amp = config.risk_profile.exploratory_amplitude();
456 for t in 0..exploratory {
457 let mut s = primary.clone();
458 for i in 0..num_penalties {
459 let u = halton(t + 1, primes[i]); let centered = 2.0 * u - 1.0; s[i] = clamp_seed_rho_to_bounds(primary[i] + amp * centered, bounds);
462 }
463 add_seed_dedup(&mut seeds, &mut seen, s);
464 }
465 }
466
467 if num_aux > 0 {
473 for seed in &mut seeds {
474 for (i, &v) in aux_initial.iter().enumerate() {
475 seed[num_smoothing + i] = v;
476 }
477 }
478 let mut deduped = Vec::new();
479 let mut seen2: HashSet<Vec<u64>> = HashSet::new();
480 for seed in seeds {
481 let key: Vec<u64> = seed.iter().map(|&v| v.to_bits()).collect();
482 if seen2.insert(key) {
483 deduped.push(seed);
484 }
485 }
486 seeds = deduped;
487 }
488
489 if seeds.len() > max_seeds {
490 seeds.truncate(max_seeds);
491 }
492
493 if seeds.is_empty() {
494 seeds.push(Array1::<f64>::zeros(num_penalties));
495 }
496
497 seeds
498}
499
500pub fn select_objective_seed_on_log_lambda_grid<F>(
508 rho_seed: &Array1<f64>,
509 bounds: (f64, f64),
510 n_smooths: usize,
511 nullspace_coords: &[usize],
512 mut eval_cost: F,
513) -> Array1<f64>
514where
515 F: FnMut(&Array1<f64>) -> Option<f64>,
516{
517 let k = rho_seed.len();
518 if k == 0 || n_smooths == 0 || n_smooths > k {
519 return rho_seed.clone();
520 }
521 let bnds = normalize_seed_bounds(bounds);
522 let clamp_vec = |v: &Array1<f64>| -> Array1<f64> {
523 let mut out = v.clone();
524 for i in 0..n_smooths {
525 out[i] = clamp_seed_rho_to_bounds(out[i], bnds);
526 }
527 out
528 };
529
530 let baseline_seed = clamp_vec(rho_seed);
531 let baseline_cost = eval_cost(&baseline_seed);
532 log::info!(
533 "[SEED-GRID] baseline rho=[{}] cost={}",
534 baseline_seed
535 .iter()
536 .map(|v| format!("{v:.2}"))
537 .collect::<Vec<_>>()
538 .join(","),
539 baseline_cost
540 .map(|c| format!("{c:.6e}"))
541 .unwrap_or_else(|| "non-finite".to_string()),
542 );
543
544 let shifts: [f64; 9] = [-12.0, -9.0, -6.0, -3.0, 0.0, 3.0, 6.0, 9.0, 12.0];
545 let mut best_seed = baseline_seed.clone();
546 let mut best_cost: Option<f64> = baseline_cost.filter(|c| c.is_finite());
547
548 for &delta in &shifts {
549 if delta == 0.0 && best_cost.is_some() {
550 continue;
551 }
552 let mut candidate = rho_seed.clone();
553 for i in 0..n_smooths {
554 candidate[i] = clamp_seed_rho_to_bounds(rho_seed[i] + delta, bnds);
555 }
556 let c_opt = eval_cost(&candidate);
557 log::info!(
558 "[SEED-GRID] shift={:+.1} rho=[{}] cost={}",
559 delta,
560 candidate
561 .iter()
562 .map(|v| format!("{v:.2}"))
563 .collect::<Vec<_>>()
564 .join(","),
565 c_opt
566 .map(|c| format!("{c:.6e}"))
567 .unwrap_or_else(|| "non-finite".to_string()),
568 );
569 if let Some(c) = c_opt
570 && c.is_finite()
571 && best_cost.map(|b| c < b).unwrap_or(true)
572 {
573 best_cost = Some(c);
574 best_seed = candidate;
575 }
576 }
577
578 if n_smooths <= 6 {
579 let saturation = clamp_seed_rho_to_bounds(bnds.1, bnds);
593 let keep_saturation = clamp_seed_rho_to_bounds(bnds.0, bnds);
612 for axis in 0..n_smooths {
613 let anchor = best_seed.clone();
614 let mut targets = vec![
615 clamp_seed_rho_to_bounds(anchor[axis] - 3.0, bnds),
616 clamp_seed_rho_to_bounds(anchor[axis] + 3.0, bnds),
617 ];
618 if (anchor[axis] - saturation).abs() > 1e-9 {
619 targets.push(saturation);
620 }
621 if nullspace_coords.contains(&axis) {
622 targets.push(clamp_seed_rho_to_bounds(anchor[axis] - 6.0, bnds));
626 if (anchor[axis] - keep_saturation).abs() > 1e-9 {
627 targets.push(keep_saturation);
628 }
629 }
630 for target in targets {
631 let mut candidate = anchor.clone();
632 candidate[axis] = target;
633 if let Some(c) = eval_cost(&candidate)
634 && c.is_finite()
635 && best_cost.map(|b| c < b).unwrap_or(true)
636 {
637 best_cost = Some(c);
638 best_seed = candidate;
639 }
640 }
641 }
642 for start in 0..n_smooths.saturating_sub(1) {
643 let anchor = best_seed.clone();
644 let mut candidate = anchor;
645 candidate[start] = saturation;
646 candidate[start + 1] = saturation;
647 if let Some(c) = eval_cost(&candidate)
648 && c.is_finite()
649 && best_cost.map(|b| c < b).unwrap_or(true)
650 {
651 best_cost = Some(c);
652 best_seed = candidate;
653 }
654 }
655 }
656
657 best_seed
658}
659
#[cfg(test)]
mod tests {
    use super::*;

    /// `true` when `a` and `b` differ by strictly less than `tol`.
    fn close(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Discrete "curvature" `2*rho1 - rho0 - rho2`, which equals `ln 4` on the nu=2 manifold.
    fn triplet_curvature(s: &Array1<f64>) -> f64 {
        2.0 * s[1] - s[0] - s[2]
    }

    #[test]
    fn uses_full_heuristicvector_as_primary_anchor() {
        let heur = [-2.0, 0.0, 2.0];
        let cfg = SeedConfig {
            risk_profile: SeedRiskProfile::Gaussian,
            ..SeedConfig::default()
        };
        let seeds = generate_rho_candidates(3, Some(&heur), &cfg);
        let first = seeds.first().expect("candidate list must not be empty");
        assert_eq!(first.len(), 3);
        for (got, want) in first.iter().zip(heur.iter()) {
            assert!(close(*got, *want, 1e-12));
        }
    }

    #[test]
    fn high_dim_uses_cluster_conflict_probeswithout_exploding() {
        let heur = [-6.0, -5.0, -4.0, 0.0, 2.0, 4.0, -3.0, 0.0, 3.0, 5.0];
        let cfg = SeedConfig {
            max_seeds: 18,
            risk_profile: SeedRiskProfile::GeneralizedLinear,
            ..SeedConfig::default()
        };
        let seeds = generate_rho_candidates(10, Some(&heur), &cfg);
        assert!(seeds.len() <= 18);

        let anchor = &seeds[0];
        let moves_both_ways = |s: &Array1<f64>| {
            let up = s.iter().zip(anchor.iter()).any(|(a, b)| a > b);
            let down = s.iter().zip(anchor.iter()).any(|(a, b)| a < b);
            up && down
        };
        assert!(seeds.iter().skip(1).any(moves_both_ways));
    }

    #[test]
    fn includes_neutralzero_seed() {
        let seeds = generate_rho_candidates(5, None, &SeedConfig::default());
        assert!(seeds.iter().any(|s| s.iter().all(|v| v.abs() < 1e-12)));
    }

    #[test]
    fn generalized_linear_seeds_include_early_stability_retreat_seed() {
        let cfg = SeedConfig {
            risk_profile: SeedRiskProfile::GeneralizedLinear,
            ..SeedConfig::default()
        };
        let seeds = generate_rho_candidates(3, None, &cfg);
        let retreat = Array1::from_elem(3, cfg.bounds.1);
        let Some(retreat_idx) = seeds.iter().position(|seed| *seed == retreat) else {
            panic!("generalized-linear seeds should include an upper-bound retreat seed");
        };
        assert!(
            retreat_idx <= 2,
            "retreat seed should be available before broader exploratory seeds: {retreat_idx}"
        );
    }

    #[test]
    fn objective_grid_can_seed_adjacent_pair_oversmoothing_corner() {
        let corner_cost = |rho: &Array1<f64>| {
            let supported = 0.1 * (rho[0].powi(2) + rho[1].powi(2));
            let unsupported = (rho[2] - 12.0).powi(2) + (rho[3] - 12.0).powi(2);
            Some(supported + unsupported)
        };
        let start = Array1::zeros(4);
        let chosen =
            select_objective_seed_on_log_lambda_grid(&start, (-12.0, 12.0), 4, &[], corner_cost);
        assert_eq!(chosen.to_vec(), vec![0.0, 0.0, 12.0, 12.0]);
    }

    #[test]
    fn three_penalty_seeds_include_nu2_reverse_manifold_triplets() {
        let seeds = generate_rho_candidates(3, None, &SeedConfig::default());
        let ln4 = 4.0_f64.ln();
        assert!(
            seeds
                .iter()
                .filter(|s| s.len() == 3)
                .any(|s| close(triplet_curvature(s), ln4, 1e-8))
        );
    }

    #[test]
    fn three_penalty_seeds_include_general_spde_manifold_points() {
        let heur = [2.0, 10.0, 3.0];
        let seeds = generate_rho_candidates(3, Some(&heur), &SeedConfig::default());
        let ln4 = 4.0_f64.ln();
        assert!(
            seeds
                .iter()
                .filter(|s| s.len() == 3)
                .any(|s| (triplet_curvature(s) - ln4).abs() > 1e-3)
        );
    }

    #[test]
    fn three_penalty_seeds_include_first_order_fallbackwith_rho2_floor() {
        let cfg = SeedConfig {
            bounds: (-12.0, 12.0),
            ..SeedConfig::default()
        };
        let seeds = generate_rho_candidates(3, None, &cfg);
        assert!(
            seeds
                .iter()
                .filter(|s| s.len() == 3)
                .any(|s| close(s[2], -12.0, 1e-12))
        );
    }

    #[test]
    fn auxiliary_trailing_dims_pinned_to_initial_values() {
        let heur = [0.0, 10.0_f64.ln(), 0.0, 0.0];
        let cfg = SeedConfig {
            num_auxiliary_trailing: 2,
            risk_profile: SeedRiskProfile::GeneralizedLinear,
            ..SeedConfig::default()
        };
        let seeds = generate_rho_candidates(4, Some(&heur), &cfg);
        assert!(!seeds.is_empty());

        for (idx, seed) in seeds.iter().enumerate() {
            assert_eq!(seed.len(), 4);
            assert!(
                seed[2].abs() < 1e-12 && seed[3].abs() < 1e-12,
                "seed {} has auxiliary dims [{}, {}], expected [0, 0]",
                idx,
                seed[2],
                seed[3],
            );
        }

        assert!(
            seeds
                .iter()
                .any(|s| s[0].abs() > 1e-12 || s[1].abs() > 1e-12)
        );
    }

    #[test]
    fn auxiliary_dims_dedup_collapses_identical_seeds() {
        let config_for = |num_aux: usize| SeedConfig {
            num_auxiliary_trailing: num_aux,
            max_seeds: 32,
            risk_profile: SeedRiskProfile::GeneralizedLinear,
            ..SeedConfig::default()
        };
        let with_aux = generate_rho_candidates(3, None, &config_for(1));
        let without_aux = generate_rho_candidates(3, None, &config_for(0));
        assert!(with_aux.len() <= without_aux.len());
    }

    #[test]
    fn objective_grid_seed_selects_lowest_finite_cost_candidate() {
        let start = Array1::from_vec(vec![0.0, 0.0]);
        let chosen = select_objective_seed_on_log_lambda_grid(&start, (-12.0, 12.0), 2, &[], |rho| {
            Some((rho[0] - 6.0).powi(2) + (rho[1] - 6.0).powi(2))
        });

        assert!(close(chosen[0], 6.0, 1e-12));
        assert!(close(chosen[1], 6.0, 1e-12));
    }

    #[test]
    fn objective_grid_seed_keeps_baseline_when_no_candidate_improves_cost() {
        let start = Array1::from_vec(vec![1.0, -2.0]);
        let chosen = select_objective_seed_on_log_lambda_grid(&start, (-12.0, 12.0), 2, &[], |rho| {
            let at_start = close(rho[0], 1.0, 1e-12) && close(rho[1], -2.0, 1e-12);
            Some(if at_start { 0.0 } else { 1.0 })
        });

        assert_eq!(chosen, start);
    }
}