use crate::inference::row_measure::CoresetCertificate;
use crate::solver::evidence::{
GaussianMixtureConfig, StackingConfig, StackingWeights, TopologyScoreScale,
UNION_STRUCTURE_LADDER, UnionStructure, UnionStructureFit, fit_gaussian_mixture,
fit_union_ladder, fit_union_structure, solve_stacking_weights, union_per_point_log_density,
};
use crate::solver::priority_selection::{PriorityCandidate, rank_priority_candidates};
use ndarray::{Array2, ArrayView2};
use serde_json::Value as JsonValue;
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// ln(2π); scaled by the null-space dimension in [`tk_normalized_score`].
const TK_LOG_2PI: f64 = 1.8378770664093453_f64;
/// Mixture orders swept by [`AutoTopologyKind::mixture_ladder`]. Its last entry is also the
/// order used when a candidate is spelled plain `"mixture"`.
pub const MIXTURE_K_LADDER: &[usize] = &[1, 2, 3, 5, 7, 9];
/// Fold count for the stacking cross-validation table.
// NOTE(review): presumably the `folds` value callers pass to `adjudicate_cross_class_race`;
// inside this file only the tests reference it — confirm against callers.
pub const STACKING_CV_FOLDS: usize = 5;
/// A topology candidate that the automatic selector can fit and compare.
///
/// The first five variants are the smooth topologies returned by [`AutoTopologyKind::all`].
/// `Mixture` and `Union` are the discrete classes (see [`AutoTopologyKind::is_discrete_class`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AutoTopologyKind {
    Euclidean,
    Circle,
    Sphere,
    Torus,
    Cylinder,
    /// Gaussian mixture with `k` components.
    Mixture {
        k: usize,
    },
    /// Composite structure described by a [`UnionStructure`].
    Union {
        structure: UnionStructure,
    },
}
impl AutoTopologyKind {
    /// Canonical class name.
    ///
    /// Every mixture order shares `"mixture"`; unions use the name reported by their
    /// [`UnionStructure`].
    pub const fn as_str(self) -> &'static str {
        match self {
            AutoTopologyKind::Euclidean => "euclidean",
            AutoTopologyKind::Circle => "circle",
            AutoTopologyKind::Sphere => "sphere",
            AutoTopologyKind::Torus => "torus",
            AutoTopologyKind::Cylinder => "cylinder",
            AutoTopologyKind::Mixture { .. } => "mixture",
            AutoTopologyKind::Union { structure } => structure.as_str(),
        }
    }

    /// Like [`as_str`](Self::as_str), but keeps the mixture order (`"mixture_k{k}"`).
    ///
    /// Distinct mixture candidates therefore get distinct names. The result parses back with
    /// [`parse`](Self::parse).
    pub fn display_name(self) -> String {
        match self {
            AutoTopologyKind::Mixture { k } => format!("mixture_k{k}"),
            other => other.as_str().to_string(),
        }
    }

    /// `true` for `Mixture { .. }`.
    pub const fn is_discrete_mixture(self) -> bool {
        matches!(self, AutoTopologyKind::Mixture { .. })
    }

    /// `true` for `Union { .. }`.
    pub const fn is_structured_union(self) -> bool {
        matches!(self, AutoTopologyKind::Union { .. })
    }

    /// `true` for the discrete classes (mixtures and unions), `false` for smooth topologies.
    pub const fn is_discrete_class(self) -> bool {
        self.is_discrete_mixture() || self.is_structured_union()
    }

    /// Parses a user-supplied candidate name.
    ///
    /// The input is trimmed and lower-cased, and `-` is treated as `_`. Accepted spellings:
    /// - any union name understood by [`parse_union_name`];
    /// - `mixture`, meaning the largest order in [`MIXTURE_K_LADDER`];
    /// - `mixture` followed by an optional `_`, an optional `k`, and a positive decimal
    ///   order, e.g. `mixture_k5`, `mixture5`, `mixture_5`;
    /// - `euclidean` (aliases `flat`, `euclidean_patch`, `euclideanpatch`), `circle` (aliases
    ///   `periodic`, `s1`), `sphere` (alias `s2`), `torus`, `cylinder`.
    ///
    /// # Errors
    /// Returns an error for unknown names, for a mixture suffix that is not a plain decimal
    /// order, or for a mixture order of zero.
    pub fn parse(value: &str) -> Result<Self, String> {
        let normalized = value.trim().to_ascii_lowercase().replace('-', "_");
        if let Some(structure) = parse_union_name(&normalized) {
            return Ok(AutoTopologyKind::Union { structure });
        }
        if let Some(rest) = normalized.strip_prefix("mixture") {
            return Self::parse_mixture_order(rest, value).map(|k| AutoTopologyKind::Mixture { k });
        }
        match normalized.as_str() {
            "euclidean" | "flat" | "euclidean_patch" | "euclideanpatch" => {
                Ok(AutoTopologyKind::Euclidean)
            }
            "circle" | "periodic" | "s1" => Ok(AutoTopologyKind::Circle),
            "sphere" | "s2" => Ok(AutoTopologyKind::Sphere),
            "torus" => Ok(AutoTopologyKind::Torus),
            "cylinder" => Ok(AutoTopologyKind::Cylinder),
            other => Err(format!(
                "topology candidate must be euclidean, circle, sphere, torus, cylinder, mixture[_k{{n}}], or a union (union_circle+circle, union_circle+cluster, union_line+cluster); got {other:?}"
            )),
        }
    }

    /// Parses the text after the `mixture` prefix into a component count.
    ///
    /// `suffix` may carry one leading `_` and one `k` before the digits. An empty remainder
    /// selects the largest ladder order. `value` is the raw user input, used only in error
    /// messages.
    fn parse_mixture_order(suffix: &str, value: &str) -> Result<usize, String> {
        let order = suffix.strip_prefix('_').unwrap_or(suffix);
        let order = order.strip_prefix('k').unwrap_or(order);
        if order.is_empty() {
            return Ok(*MIXTURE_K_LADDER.last().unwrap_or(&7));
        }
        // Require plain decimal digits. `usize::from_str` would otherwise accept a leading
        // `+`, and collecting digits from anywhere would turn `mixture_2_3` into k = 23.
        if !order.bytes().all(|b| b.is_ascii_digit()) {
            return Err(format!(
                "mixture order must be a positive integer; got {value:?}"
            ));
        }
        let k: usize = order
            .parse()
            .map_err(|_| format!("mixture order must be a positive integer; got {value:?}"))?;
        if k == 0 {
            return Err("mixture order k must be >= 1".to_string());
        }
        Ok(k)
    }

    /// The smooth topologies, in a fixed order.
    pub fn all() -> Vec<Self> {
        vec![
            AutoTopologyKind::Euclidean,
            AutoTopologyKind::Circle,
            AutoTopologyKind::Sphere,
            AutoTopologyKind::Torus,
            AutoTopologyKind::Cylinder,
        ]
    }

    /// One `Mixture` candidate per order in [`MIXTURE_K_LADDER`].
    pub fn mixture_ladder() -> Vec<Self> {
        MIXTURE_K_LADDER
            .iter()
            .map(|&k| AutoTopologyKind::Mixture { k })
            .collect()
    }

    /// One `Union` candidate per entry in `UNION_STRUCTURE_LADDER`.
    pub fn union_ladder() -> Vec<Self> {
        UNION_STRUCTURE_LADDER
            .iter()
            .map(|&structure| AutoTopologyKind::Union { structure })
            .collect()
    }
}
/// Maps an already-normalized name such as `union_circle_circle` onto a [`UnionStructure`].
///
/// Everything after the `union` prefix has each `_` or `-` rewritten as `+`, and leading and
/// trailing `+` are trimmed before matching. Returns `None` when the prefix is missing or the
/// component list is not a recognized composite.
pub fn parse_union_name(normalized: &str) -> Option<UnionStructure> {
    let tail = normalized.strip_prefix("union")?;
    let joined = tail.replace(|c: char| c == '_' || c == '-', "+");
    let structure = match joined.trim_matches('+') {
        "circle+circle" => UnionStructure::CircleCircle,
        "circle+cluster" | "circle+point+cluster" | "circle+pointcluster" => {
            UnionStructure::CirclePointCluster
        }
        "line+cluster" | "line+point+cluster" | "line+pointcluster" => {
            UnionStructure::LineCluster
        }
        _ => return None,
    };
    Some(structure)
}
/// Configuration for automatic topology selection.
#[derive(Debug, Clone)]
pub struct TopologyAutoSelector {
    /// Topologies to fit, in evaluation order.
    pub candidates: Vec<AutoTopologyKind>,
    /// How the TK score is normalized (passed to [`tk_normalized_score`]).
    pub score_scale: TopologyScoreScale,
    /// Optional latent name taken from the JSON configuration.
    // NOTE(review): not read anywhere in this file; confirm its consumer.
    pub latent: Option<String>,
}
impl TopologyAutoSelector {
    /// Creates a selector over `candidates`, defaulting to the smooth topologies from
    /// [`AutoTopologyKind::all`] when `None`.
    ///
    /// The selector scores per effective dimension and has no latent name.
    pub fn new(candidates: Option<Vec<AutoTopologyKind>>) -> Self {
        Self {
            candidates: candidates.unwrap_or_else(AutoTopologyKind::all),
            score_scale: TopologyScoreScale::PerEffectiveDim,
            latent: None,
        }
    }

    /// Parses a selector from its JSON configuration object.
    ///
    /// All keys are optional, and an explicit `null` counts as absent:
    /// - `candidates`: a non-empty array of distinct names accepted by
    ///   [`AutoTopologyKind::parse`]. Defaults to [`AutoTopologyKind::all`].
    /// - `score_scale`: `"per_effective_dim"` (the default) or `"per_observation"`, compared
    ///   case-insensitively with `-` treated as `_`.
    /// - `latent`: a string.
    ///
    /// # Errors
    /// Returns an error when `value` is not an object, a key has the wrong JSON type, a
    /// candidate name or score scale is unrecognized, or a candidate is listed twice.
    pub fn from_json(value: &JsonValue) -> Result<Self, String> {
        let obj = value
            .as_object()
            .ok_or_else(|| "topology_auto_selector must be an object".to_string())?;
        let candidates = match obj.get("candidates").filter(|value| !value.is_null()) {
            None => AutoTopologyKind::all(),
            Some(raw) => {
                let items = raw.as_array().ok_or_else(|| {
                    "topology_auto_selector.candidates must be a list".to_string()
                })?;
                if items.is_empty() {
                    return Err(
                        "topology_auto_selector.candidates must have at least one entry"
                            .to_string(),
                    );
                }
                let mut out = Vec::with_capacity(items.len());
                for (idx, item) in items.iter().enumerate() {
                    let name = item.as_str().ok_or_else(|| {
                        format!("topology_auto_selector.candidates[{idx}] must be a string")
                    })?;
                    let kind = AutoTopologyKind::parse(name)?;
                    if out.contains(&kind) {
                        // `display_name` keeps the mixture order, so a repeated `mixture_k3`
                        // is reported as such instead of as a bare "mixture".
                        return Err(format!(
                            "topology_auto_selector duplicate candidate {:?}",
                            kind.display_name()
                        ));
                    }
                    out.push(kind);
                }
                out
            }
        };
        // Absent or null selects the default. Any other non-string value is rejected rather
        // than silently replaced by the default.
        let score_scale_name = match obj.get("score_scale").filter(|value| !value.is_null()) {
            None => "per_effective_dim".to_string(),
            Some(raw) => raw
                .as_str()
                .ok_or_else(|| "topology_auto_selector.score_scale must be a string".to_string())?
                .trim()
                .to_ascii_lowercase()
                .replace('-', "_"),
        };
        let score_scale = match score_scale_name.as_str() {
            "per_observation" => TopologyScoreScale::PerObservation,
            "per_effective_dim" => TopologyScoreScale::PerEffectiveDim,
            other => {
                return Err(format!(
                    "topology_auto_selector.score_scale must be per_effective_dim or per_observation; got {other:?}"
                ));
            }
        };
        let latent = obj
            .get("latent")
            .filter(|value| !value.is_null())
            .map(|value| {
                value
                    .as_str()
                    .map(str::to_string)
                    .ok_or_else(|| "topology_auto_selector.latent must be a string".to_string())
            })
            .transpose()?;
        Ok(Self {
            candidates,
            score_scale,
            latent,
        })
    }
}
/// Raw evidence reported by a caller-supplied fit of one topology candidate.
#[derive(Debug, Clone)]
pub struct TopologyAutoFitEvidence<FitHandle> {
    /// Name stored on the resulting ranked row.
    pub topology_name: String,
    /// Unnormalized REML criterion; must be finite.
    pub raw_reml: f64,
    /// Null-space dimension; must be finite and not below `-1e-9` (smaller negatives are
    /// clamped to zero).
    pub null_dim: f64,
    /// Log-determinant of the null-space Hessian; required whenever `null_dim > 0`.
    pub null_space_logdet: Option<f64>,
    /// Divisor under `TopologyScoreScale::PerEffectiveDim`; must then be finite and positive.
    pub effective_dim: f64,
    /// Divisor under `TopologyScoreScale::PerObservation`; must then be non-zero.
    pub n_obs: usize,
    /// Opaque handle carried through to the ranked result.
    pub fit_handle: FitHandle,
}
/// A fitted topology candidate together with its TK score, as ranked by the selector.
#[derive(Debug, Clone)]
pub struct TopologyAutoRankedFit<FitHandle> {
    /// Name carried over from the fit evidence.
    pub topology_name: String,
    /// Normalized score produced by `tk_normalized_score`.
    pub tk_score: f64,
    /// Unnormalized REML criterion of the fit.
    pub raw_reml: f64,
    /// Effective dimension reported by the fit.
    pub effective_dim: f64,
    /// Observation count reported by the fit.
    pub n_obs: usize,
    /// Caller-defined handle to the underlying fit.
    pub fit_handle: FitHandle,
}
/// Ranked outcome of a topology selection.
#[derive(Debug, Clone)]
pub struct TopologyAutoSelectorResult<FitHandle> {
    /// Successfully scored candidates, in ranked order.
    pub ranked: Vec<TopologyAutoRankedFit<FitHandle>>,
    /// Position of the winner within `ranked`.
    pub winner_index: usize,
}
impl<FitHandle> TopologyAutoSelectorResult<FitHandle> {
    /// Returns the winning row, or `None` if `winner_index` is out of range.
    pub fn winner(&self) -> Option<&TopologyAutoRankedFit<FitHandle>> {
        let Self {
            ranked,
            winner_index,
        } = self;
        ranked.get(*winner_index)
    }
}
/// Output of one candidate in a parallel topology race.
#[derive(Debug, Clone)]
pub struct TopologyRaceParallelCandidate<FitResult> {
    /// Position of the candidate in the input vector.
    pub candidate_index: usize,
    /// Size of the dedicated Rayon pool the fit ran in.
    pub per_fit_threads: usize,
    /// Elapsed time of the fit closure alone (pool construction excluded).
    pub wall_time: Duration,
    /// Value returned by the fit closure.
    pub result: FitResult,
}
/// How a thread budget is split across the fits of a topology race.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TopologyRaceThreadPlan {
    /// Size of the coordinator pool that dispatches concurrent fits. It is 0 when fits run
    /// one at a time, because then no coordinator pool is built.
    coordinator_threads: usize,
    /// Size of each fit's dedicated Rayon pool.
    per_fit_threads: usize,
    /// Maximum number of fits running at once.
    concurrent_fits: usize,
}
impl TopologyRaceThreadPlan {
    /// Splits `max_total_threads` (clamped to at least 1) across `candidate_count` fits.
    ///
    /// Fits run concurrently only with at least two candidates and at least four threads.
    /// Then up to `max_total_threads / 2` fits run at once, each backed by one coordinator
    /// thread, and the remaining threads are divided evenly among the fits' pools (at least
    /// one each).
    ///
    /// Otherwise fits run one at a time. No coordinator thread is reserved in that case, so
    /// each fit receives the whole budget.
    fn for_budget(candidate_count: usize, max_total_threads: usize) -> Self {
        let max_total_threads = max_total_threads.max(1);
        if candidate_count <= 1 || max_total_threads < 4 {
            // Sequential path: the race driver runs fits on the calling thread and never
            // builds a coordinator pool. Reserving a coordinator thread here would only
            // shrink the fit pool.
            return Self {
                coordinator_threads: 0,
                per_fit_threads: max_total_threads,
                concurrent_fits: candidate_count.min(1),
            };
        }
        let concurrent_fits = candidate_count.min(max_total_threads / 2).max(1);
        let coordinator_threads = concurrent_fits;
        let remaining = max_total_threads.saturating_sub(coordinator_threads);
        let per_fit_threads = (remaining / concurrent_fits).max(1);
        Self {
            coordinator_threads,
            per_fit_threads,
            concurrent_fits,
        }
    }
}
/// Runs `fit_one` on every candidate, sized to the machine's available parallelism.
///
/// The budget is `std::thread::available_parallelism()`, or 1 when it cannot be queried.
/// Results come back in input order. See `run_topology_race_parallel_with_budget` for the
/// scheduling details and error behaviour.
pub fn run_topology_race_parallel<Candidate, FitResult, FitOne>(
    candidates: Vec<Candidate>,
    fit_one: FitOne,
) -> Result<Vec<TopologyRaceParallelCandidate<FitResult>>, String>
where
    Candidate: Send,
    FitResult: Send,
    FitOne: Fn(Candidate) -> FitResult + Sync,
{
    let budget = match std::thread::available_parallelism() {
        Ok(threads) => threads.get(),
        Err(_) => 1,
    };
    run_topology_race_parallel_with_budget(candidates, fit_one, budget)
}
/// Runs `fit_one` on every candidate under an explicit thread budget.
///
/// The budget is split by `TopologyRaceThreadPlan::for_budget`. Each candidate runs inside
/// its own Rayon pool of `plan.per_fit_threads` threads.
///
/// When the plan allows at most one concurrent fit, candidates run one after another on the
/// calling thread. Otherwise they are dispatched in batches of `plan.concurrent_fits` on a
/// coordinator pool, and each batch completes before the next starts. Results are returned
/// in input order.
///
/// # Errors
/// Returns an error if a Rayon pool cannot be built. The race stops after the candidate
/// (sequential path) or batch (concurrent path) in which the failure occurred. A panic
/// inside `fit_one` is not caught here.
fn run_topology_race_parallel_with_budget<Candidate, FitResult, FitOne>(
    candidates: Vec<Candidate>,
    fit_one: FitOne,
    max_total_threads: usize,
) -> Result<Vec<TopologyRaceParallelCandidate<FitResult>>, String>
where
    Candidate: Send,
    FitResult: Send,
    FitOne: Fn(Candidate) -> FitResult + Sync,
{
    let candidate_count = candidates.len();
    if candidate_count == 0 {
        return Ok(Vec::new());
    }
    let plan = TopologyRaceThreadPlan::for_budget(candidate_count, max_total_threads);
    // Wrapping in `Option` lets each candidate be moved out exactly once via `take`.
    let mut candidates: Vec<Option<Candidate>> = candidates.into_iter().map(Some).collect();
    // One slot per candidate, written by index, so output order matches input order
    // regardless of completion order.
    let slots: Vec<Mutex<Option<TopologyRaceParallelCandidate<FitResult>>>> =
        (0..candidate_count).map(|_| Mutex::new(None)).collect();
    // Records the first pool-construction failure reported by a worker.
    let pool_error: Mutex<Option<String>> = Mutex::new(None);
    if plan.concurrent_fits <= 1 {
        for idx in 0..candidate_count {
            let candidate = candidates[idx]
                .take()
                .expect("topology race candidate must be present");
            run_one_topology_race_candidate(
                idx,
                candidate,
                &fit_one,
                plan.per_fit_threads,
                &slots[idx],
                &pool_error,
            );
            if let Some(err) = pool_error.lock().expect("pool_error mutex poisoned").take() {
                return Err(err);
            }
        }
    } else {
        let coordinator_pool = rayon::ThreadPoolBuilder::new()
            .num_threads(plan.coordinator_threads)
            .thread_name(|idx| format!("topology-race-coordinator-{idx}"))
            .build()
            .map_err(|err| format!("topology race coordinator Rayon pool: {err}"))?;
        // Batches of at most `concurrent_fits` candidates bound how many fits (and thus
        // per-fit pools) are alive at once.
        let mut batch_start = 0usize;
        while batch_start < candidate_count {
            let batch_end = (batch_start + plan.concurrent_fits).min(candidate_count);
            // `scope` returns only after every task spawned in this batch has finished.
            coordinator_pool.scope(|scope| {
                for idx in batch_start..batch_end {
                    let candidate = candidates[idx]
                        .take()
                        .expect("topology race candidate must be present");
                    let slot = &slots[idx];
                    let pool_error = &pool_error;
                    let fit_one = &fit_one;
                    scope.spawn(move |_| {
                        run_one_topology_race_candidate(
                            idx,
                            candidate,
                            fit_one,
                            plan.per_fit_threads,
                            slot,
                            pool_error,
                        );
                    });
                }
            });
            if let Some(err) = pool_error.lock().expect("pool_error mutex poisoned").take() {
                return Err(err);
            }
            batch_start = batch_end;
        }
    }
    let mut out = Vec::with_capacity(candidate_count);
    for (idx, slot) in slots.into_iter().enumerate() {
        let row = slot
            .into_inner()
            .expect("topology race result mutex poisoned")
            .ok_or_else(|| format!("topology race candidate {idx} did not produce a result"))?;
        out.push(row);
    }
    Ok(out)
}
/// Fits a single race candidate inside a freshly built Rayon pool of `per_fit_threads`
/// threads.
///
/// On success the timed result is stored in `slot`. If the pool cannot be built, the error
/// message is stored in `pool_error` and `slot` is left untouched.
fn run_one_topology_race_candidate<Candidate, FitResult, FitOne>(
    candidate_index: usize,
    candidate: Candidate,
    fit_one: &FitOne,
    per_fit_threads: usize,
    slot: &Mutex<Option<TopologyRaceParallelCandidate<FitResult>>>,
    pool_error: &Mutex<Option<String>>,
) where
    Candidate: Send,
    FitResult: Send,
    FitOne: Fn(Candidate) -> FitResult + Sync,
{
    let built = rayon::ThreadPoolBuilder::new()
        .num_threads(per_fit_threads)
        .thread_name(move |idx| format!("topology-race-fit-{candidate_index}-{idx}"))
        .build();
    let fit_pool = match built {
        Ok(fit_pool) => fit_pool,
        Err(err) => {
            let message = format!("topology race candidate Rayon pool: {err}");
            *pool_error.lock().expect("pool_error mutex poisoned") = Some(message);
            return;
        }
    };
    // Only the fit itself is timed; pool construction is excluded.
    let clock = Instant::now();
    let result = fit_pool.install(|| fit_one(candidate));
    let wall_time = clock.elapsed();
    let row = TopologyRaceParallelCandidate {
        candidate_index,
        per_fit_threads,
        wall_time,
        result,
    };
    *slot.lock().expect("topology race result mutex poisoned") = Some(row);
}
pub fn select_topology_with_fit<FitHandle, FitErr>(
selector: &TopologyAutoSelector,
mut fit_one: impl FnMut(AutoTopologyKind) -> Result<TopologyAutoFitEvidence<FitHandle>, FitErr>,
) -> Result<TopologyAutoSelectorResult<FitHandle>, String>
where
FitErr: ToString,
{
let mut ranked = Vec::with_capacity(selector.candidates.len());
let mut errors = Vec::new();
for candidate in &selector.candidates {
match fit_one(*candidate) {
Ok(evidence) => {
let tk_score = tk_normalized_score(
evidence.raw_reml,
evidence.null_dim,
evidence.null_space_logdet,
evidence.effective_dim,
evidence.n_obs,
selector.score_scale,
)?;
ranked.push(TopologyAutoRankedFit {
topology_name: evidence.topology_name,
tk_score,
raw_reml: evidence.raw_reml,
effective_dim: evidence.effective_dim,
n_obs: evidence.n_obs,
fit_handle: evidence.fit_handle,
});
}
Err(err) => errors.push(format!("{}: {}", candidate.as_str(), err.to_string())),
}
}
if ranked.is_empty() {
return Err(format!(
"TopologyAutoSelector found no fittable topology candidates{}",
if errors.is_empty() {
String::new()
} else {
format!(" ({})", errors.join("; "))
}
));
}
ranked = rank_priority_candidates(
ranked
.into_iter()
.enumerate()
.map(|(idx, row)| {
let score = row.tk_score;
PriorityCandidate::new(row, idx, score, 0)
})
.collect(),
)
.into_iter()
.map(|row| row.item)
.collect();
Ok(TopologyAutoSelectorResult {
ranked,
winner_index: 0,
})
}
pub fn select_topology_with_fit_parallel<FitHandle, FitErr>(
selector: &TopologyAutoSelector,
fit_one: impl Fn(AutoTopologyKind) -> Result<TopologyAutoFitEvidence<FitHandle>, FitErr> + Sync,
) -> Result<TopologyAutoSelectorResult<FitHandle>, String>
where
FitHandle: Send,
FitErr: ToString + Send,
{
let candidates: Vec<AutoTopologyKind> = selector.candidates.clone();
let race = run_topology_race_parallel(candidates, |candidate| {
(candidate, fit_one(candidate))
})?;
let mut ranked = Vec::with_capacity(race.len());
let mut errors = Vec::new();
for entry in race {
let (candidate, fit_result) = entry.result;
match fit_result {
Ok(evidence) => {
let tk_score = tk_normalized_score(
evidence.raw_reml,
evidence.null_dim,
evidence.null_space_logdet,
evidence.effective_dim,
evidence.n_obs,
selector.score_scale,
)?;
ranked.push(TopologyAutoRankedFit {
topology_name: evidence.topology_name,
tk_score,
raw_reml: evidence.raw_reml,
effective_dim: evidence.effective_dim,
n_obs: evidence.n_obs,
fit_handle: evidence.fit_handle,
});
}
Err(err) => errors.push(format!("{}: {}", candidate.as_str(), err.to_string())),
}
}
if ranked.is_empty() {
return Err(format!(
"TopologyAutoSelector found no fittable topology candidates{}",
if errors.is_empty() {
String::new()
} else {
format!(" ({})", errors.join("; "))
}
));
}
ranked = rank_priority_candidates(
ranked
.into_iter()
.enumerate()
.map(|(idx, row)| {
let score = row.tk_score;
PriorityCandidate::new(row, idx, score, 0)
})
.collect(),
)
.into_iter()
.map(|row| row.item)
.collect();
Ok(TopologyAutoSelectorResult {
ranked,
winner_index: 0,
})
}
/// Computes the TK-normalized topology score from a fit's raw REML evidence.
///
/// The unscaled score is `raw_reml + normalizer`. The normalizer is zero when the clamped
/// null-space dimension is zero, and otherwise
/// `-0.5 * null_dim * ln(2π) + 0.5 * null_space_logdet`. The sum is then divided by `n_obs`
/// (`PerObservation`) or by `effective_dim` (`PerEffectiveDim`).
///
/// # Errors
/// Returns an error for any of the following:
/// - non-finite `raw_reml` or `null_dim`;
/// - `null_dim` below `-1e-9`;
/// - a missing or non-finite `null_space_logdet` when `null_dim > 0`;
/// - `n_obs == 0` under `PerObservation`;
/// - a non-finite or non-positive `effective_dim` under `PerEffectiveDim`.
pub fn tk_normalized_score(
    raw_reml: f64,
    null_dim: f64,
    null_space_logdet: Option<f64>,
    effective_dim: f64,
    n_obs: usize,
    score_scale: TopologyScoreScale,
) -> Result<f64, String> {
    if !(raw_reml.is_finite() && null_dim.is_finite()) {
        return Err("TopologyAutoSelector received non-finite TK evidence inputs".to_string());
    }
    // A clearly negative dimension is a caller bug and gets its own message, distinct from
    // the non-finite check above.
    if null_dim < -1.0e-9 {
        return Err(format!(
            "TopologyAutoSelector received negative null-space dimension {null_dim}"
        ));
    }
    // Round-off-sized negatives count as an empty null space.
    let null_dim = null_dim.max(0.0);
    let normalizer = if null_dim == 0.0 {
        0.0
    } else {
        let logdet = null_space_logdet.ok_or_else(|| {
            "TopologyAutoSelector TK normalizer requires null-space Hessian logdet".to_string()
        })?;
        if !logdet.is_finite() {
            return Err("TopologyAutoSelector null-space Hessian logdet is not finite".to_string());
        }
        -0.5 * null_dim * TK_LOG_2PI + 0.5 * logdet
    };
    let tk = raw_reml + normalizer;
    match score_scale {
        TopologyScoreScale::PerObservation => {
            if n_obs == 0 {
                Err("TopologyAutoSelector requires n_obs > 0".to_string())
            } else {
                Ok(tk / n_obs as f64)
            }
        }
        TopologyScoreScale::PerEffectiveDim => {
            if !(effective_dim.is_finite() && effective_dim > 0.0) {
                Err("TopologyAutoSelector requires finite positive effective_dim".to_string())
            } else {
                Ok(tk / effective_dim)
            }
        }
    }
}
/// One fitted mixture order produced by [`fit_mixture_rung`].
#[derive(Debug, Clone)]
pub struct MixtureRungFit {
    /// Number of mixture components.
    pub k: usize,
    /// The fitted mixture.
    pub fit: crate::solver::evidence::GaussianMixtureFit,
    /// Free-parameter count reported by the fit.
    pub num_parameters: usize,
    /// Laplace negative log evidence of the fit, evaluated on the full data.
    pub negative_log_evidence: f64,
}
/// All fitted orders of a mixture rung, in ranked order.
#[derive(Debug, Clone)]
pub struct MixtureRungResult {
    /// Fits in the order produced by the ranking.
    pub fits: Vec<MixtureRungFit>,
    /// Index of the winning fit in `fits` (always `0` as built by `fit_mixture_rung`).
    pub winner_index: usize,
}
impl MixtureRungResult {
    /// The winning fit.
    ///
    /// # Panics
    /// Panics if `winner_index` is out of bounds for `fits`.
    pub fn winner(&self) -> &MixtureRungFit {
        &self.fits[self.winner_index]
    }
}
/// Largest number of neighbouring orders [`fit_mixture_rung`] probes after its initial
/// ladder sweep.
pub const MIXTURE_REFINEMENT_MAX_PROBES: usize = 16;
/// Fits Gaussian mixtures over a ladder of orders and ranks them by Laplace evidence.
///
/// The search runs in two phases:
/// 1. Every order in `ladder` is fitted with [`fit_gaussian_mixture`] and scored by its
///    Laplace negative log evidence. Order `0`, orders above the row count, and repeats are
///    skipped.
/// 2. The search repeatedly probes the nearest untried neighbour of the current best order
///    (`k - 1` first, then `k + 1`), up to [`MIXTURE_REFINEMENT_MAX_PROBES`] times.
///
/// An order whose fit or evidence computation fails, or whose evidence is not finite, is
/// recorded as an error and excluded. The surviving fits are returned in the order produced
/// by `rank_priority_candidates` (score = negative log evidence, tie-break = `k`), and
/// `winner_index` is `0`.
///
/// # Errors
/// Returns the collected per-order errors when no ladder order produced a usable fit.
pub fn fit_mixture_rung(
    data: ArrayView2<'_, f64>,
    ladder: &[usize],
    config: GaussianMixtureConfig,
) -> Result<MixtureRungResult, String> {
    let n = data.nrows();
    let mut fits: Vec<MixtureRungFit> = Vec::new();
    let mut errors: Vec<String> = Vec::new();
    let mut attempted: std::collections::BTreeSet<usize> = std::collections::BTreeSet::new();
    // Fits order `k` at most once. Invalid or already-tried orders are ignored, and failures
    // are appended to `errors`.
    let try_order = |k: usize,
                     fits: &mut Vec<MixtureRungFit>,
                     errors: &mut Vec<String>,
                     attempted: &mut std::collections::BTreeSet<usize>| {
        if k == 0 || k > n || !attempted.insert(k) {
            return;
        }
        let fit = match fit_gaussian_mixture(data, k, config) {
            Ok(fit) => fit,
            Err(e) => {
                errors.push(format!("mixture k={k} fit: {e}"));
                return;
            }
        };
        match fit.laplace_negative_log_evidence(data) {
            // A NaN or infinite evidence value would corrupt the best-order search and the
            // ranking below (e.g. -inf beats every real fit), so it is treated as a failure.
            Ok(nle) if nle.is_finite() => {
                let num_parameters = fit.num_free_parameters();
                fits.push(MixtureRungFit {
                    k,
                    fit,
                    num_parameters,
                    negative_log_evidence: nle,
                });
            }
            Ok(nle) => errors.push(format!(
                "mixture k={k} evidence: non-finite negative log evidence {nle}"
            )),
            Err(e) => errors.push(format!("mixture k={k} evidence: {e}")),
        }
    };
    for &k in ladder {
        try_order(k, &mut fits, &mut errors, &mut attempted);
    }
    if fits.is_empty() {
        return Err(format!(
            "mixture rung produced no fittable orders{}",
            if errors.is_empty() {
                String::new()
            } else {
                format!(" ({})", errors.join("; "))
            }
        ));
    }
    // Local refinement around the current best order. Every probe either adds `k` to
    // `attempted` or the loop breaks, so it terminates.
    let mut probes = 0usize;
    while probes < MIXTURE_REFINEMENT_MAX_PROBES {
        let best_k = fits
            .iter()
            .min_by(|a, b| {
                a.negative_log_evidence
                    .partial_cmp(&b.negative_log_evidence)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    .then(a.k.cmp(&b.k))
            })
            .map(|f| f.k)
            .unwrap_or(1);
        let next = [best_k.saturating_sub(1), best_k + 1]
            .into_iter()
            .find(|&k| k >= 1 && k <= n && !attempted.contains(&k));
        let Some(k) = next else {
            break;
        };
        try_order(k, &mut fits, &mut errors, &mut attempted);
        probes += 1;
    }
    let ranked = rank_priority_candidates(
        fits.into_iter()
            .enumerate()
            .map(|(idx, row)| {
                let score = row.negative_log_evidence;
                let tie = row.k;
                PriorityCandidate::new(row, idx, score, tie)
            })
            .collect(),
    )
    .into_iter()
    .map(|row| row.item)
    .collect::<Vec<_>>();
    Ok(MixtureRungResult {
        fits: ranked,
        winner_index: 0,
    })
}
/// One fitted union composite, as produced by [`fit_union_rung`] or [`fit_union_candidate`].
#[derive(Debug, Clone)]
pub struct UnionRungFit {
    /// Which composite was fitted (copied from `fit.structure`).
    pub structure: UnionStructure,
    /// The underlying composite fit.
    pub fit: UnionStructureFit,
    /// Parameter count (copied from `fit.total_parameters`).
    pub total_parameters: usize,
    /// Negative log evidence (copied from `fit.negative_log_evidence`).
    pub negative_log_evidence: f64,
}
/// All fitted composites of the union rung.
#[derive(Debug, Clone)]
pub struct UnionRungResult {
    /// Fits in the order returned by `fit_union_ladder`.
    pub fits: Vec<UnionRungFit>,
    /// Index of the winning fit in `fits` (always `0` as built by `fit_union_rung`).
    // NOTE(review): this presumes `fit_union_ladder` returns its fits best-first — confirm.
    pub winner_index: usize,
}
impl UnionRungResult {
    /// The winning fit.
    ///
    /// # Panics
    /// Panics if `winner_index` is out of bounds for `fits`.
    pub fn winner(&self) -> &UnionRungFit {
        &self.fits[self.winner_index]
    }
}
/// Fits every composite of the union ladder via `fit_union_ladder` and wraps the results.
///
/// The fits keep the order `fit_union_ladder` returns them in, and `winner_index` is `0`.
///
/// # Errors
/// Propagates ladder errors. Also fails when the ladder yields no composites.
pub fn fit_union_rung(
    data: ArrayView2<'_, f64>,
    config: GaussianMixtureConfig,
) -> Result<UnionRungResult, String> {
    let mut fits = Vec::new();
    for composite in fit_union_ladder(data, config)? {
        let structure = composite.structure;
        let total_parameters = composite.total_parameters;
        let negative_log_evidence = composite.negative_log_evidence;
        fits.push(UnionRungFit {
            structure,
            fit: composite,
            total_parameters,
            negative_log_evidence,
        });
    }
    match fits.is_empty() {
        true => Err("union rung produced no fittable composites".to_string()),
        false => Ok(UnionRungResult {
            fits,
            winner_index: 0,
        }),
    }
}
/// Fits a single union composite `structure` and wraps it as a [`UnionRungFit`].
///
/// # Errors
/// Propagates any error from `fit_union_structure`.
pub fn fit_union_candidate(
    data: ArrayView2<'_, f64>,
    structure: UnionStructure,
    config: GaussianMixtureConfig,
) -> Result<UnionRungFit, String> {
    let composite = fit_union_structure(data, structure, config)?;
    let fitted_structure = composite.structure;
    let total_parameters = composite.total_parameters;
    let negative_log_evidence = composite.negative_log_evidence;
    Ok(UnionRungFit {
        structure: fitted_structure,
        fit: composite,
        total_parameters,
        negative_log_evidence,
    })
}
/// Held-out log-density oracle called as `provider(train_rows, eval_rows)`.
///
/// It must return exactly one log density per entry of `eval_rows`, in the same order
/// (`build_cv_log_density_table` enforces the length). The providers built in this file
/// refit on the training rows before scoring the evaluation rows.
pub type HeldOutDensityProvider<'a> =
    Box<dyn Fn(&[usize], &[usize]) -> Result<Vec<f64>, String> + 'a>;
/// Builds a held-out density provider for a `k`-component Gaussian mixture.
///
/// The data are copied into the provider. Each call refits the mixture on the training
/// rows and returns the per-row log density of the evaluation rows.
pub fn mixture_density_provider<'a>(
    data: ArrayView2<'a, f64>,
    k: usize,
    config: GaussianMixtureConfig,
) -> HeldOutDensityProvider<'a> {
    let rows = data.to_owned();
    Box::new(
        move |train: &[usize], eval: &[usize]| -> Result<Vec<f64>, String> {
            let fit_rows = gather_rows(rows.view(), train);
            // Cap the order at the training-row count, treating an empty training set as
            // one row.
            let order = k.min(train.len().max(1));
            let mixture = fit_gaussian_mixture(fit_rows.view(), order, config)?;
            let score_rows = gather_rows(rows.view(), eval);
            let log_density = mixture.per_point_log_density(score_rows.view())?;
            Ok(log_density.to_vec())
        },
    )
}
/// Builds a held-out density provider for the union composite `structure`.
///
/// The data are copied into the provider. Each call delegates to
/// `union_per_point_log_density`, which receives the gathered training and evaluation rows.
pub fn union_density_provider<'a>(
    data: ArrayView2<'a, f64>,
    structure: UnionStructure,
    config: GaussianMixtureConfig,
) -> HeldOutDensityProvider<'a> {
    let rows = data.to_owned();
    Box::new(
        move |train: &[usize], eval: &[usize]| -> Result<Vec<f64>, String> {
            let fit_rows = gather_rows(rows.view(), train);
            let score_rows = gather_rows(rows.view(), eval);
            let log_density = union_per_point_log_density(
                fit_rows.view(),
                score_rows.view(),
                structure,
                config,
            )?;
            Ok(log_density.to_vec())
        },
    )
}
/// Copies the rows of `data` listed in `idx`, in that order and with repeats allowed, into
/// a new owned matrix.
///
/// Panics if an index is out of bounds and `data` has at least one column.
fn gather_rows(data: ArrayView2<'_, f64>, idx: &[usize]) -> Array2<f64> {
    Array2::<f64>::from_shape_fn((idx.len(), data.ncols()), |(row, col)| {
        data[[idx[row], col]]
    })
}
/// Splits row indices `0..n` into deterministic round-robin cross-validation folds.
///
/// The fold count is clamped to `2..=max(n, 2)`. Row `i` is held out in fold
/// `i % fold_count`. Each returned pair is `(train, eval)`, with both lists ascending.
/// Folds with an empty side are dropped, so `n < 2` yields no folds.
pub fn deterministic_cv_folds(n: usize, folds: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    let fold_count = folds.clamp(2, n.max(2));
    (0..fold_count)
        .map(|fold| (0..n).partition::<Vec<usize>, _>(|&i| i % fold_count != fold))
        .filter(|(train, eval)| !train.is_empty() && !eval.is_empty())
        .collect()
}
/// Builds an `n × providers.len()` table of held-out log densities.
///
/// For every fold from [`deterministic_cv_folds`], each provider is called with that fold's
/// `(train, eval)` rows. Its densities are written into the provider's column at the
/// evaluation rows. Cells never evaluated stay `-inf`.
///
/// # Errors
/// Returns an error when there are no providers, the partition is empty, a provider fails,
/// or a provider returns the wrong number of densities.
pub fn build_cv_log_density_table(
    n: usize,
    folds: usize,
    providers: &[HeldOutDensityProvider<'_>],
) -> Result<Array2<f64>, String> {
    if providers.is_empty() {
        return Err("stacking table requires at least one candidate provider".to_string());
    }
    let partition = deterministic_cv_folds(n, folds);
    if partition.is_empty() {
        return Err("stacking CV partition is empty (n too small for folds)".to_string());
    }
    let mut table = Array2::<f64>::from_elem((n, providers.len()), f64::NEG_INFINITY);
    for (train, eval) in partition.iter() {
        for (col, provider) in providers.iter().enumerate() {
            let densities = provider(train, eval)?;
            if densities.len() != eval.len() {
                return Err(format!(
                    "provider {col} returned {} densities for {} eval rows",
                    densities.len(),
                    eval.len()
                ));
            }
            for (&row, value) in eval.iter().zip(densities) {
                table[[row, col]] = value;
            }
        }
    }
    Ok(table)
}
/// Outcome of [`adjudicate_cross_class_race`].
#[derive(Debug, Clone)]
pub struct CrossClassRaceVerdict {
    /// `display_name` of each candidate, in input order.
    pub candidate_names: Vec<String>,
    /// Whether the race mixed smooth and discrete candidates and was decided by stacking.
    pub is_cross_class: bool,
    /// Reported negative log evidence of each candidate, in input order.
    pub negative_log_evidence: Vec<f64>,
    /// Stacking weights; present only for cross-class races.
    pub stacking: Option<StackingWeights>,
    /// Input-order index of the winning candidate.
    pub winner_index: usize,
    /// Which criterion decided `winner_index`.
    pub headline: Headline,
    /// Closest contender whose lead over the winner falls within the required certification
    /// margin. Only ever set for same-class races.
    pub insufficient_margin: Option<InsufficientRaceMargin>,
}
/// Criterion that produced a race winner.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Headline {
    /// Smallest finite negative log evidence.
    Evidence,
    /// Largest stacking weight.
    Stacking,
}
/// How precisely a candidate's reported negative log evidence is known.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum EvidenceCertification {
    /// Evidence is exact; any positive lead decides the race.
    Exact,
    /// Evidence is only known up to an enclosure of width `gap`.
    Enclosure { gap: f64 },
    /// Evidence was computed on a coreset described by `certificate`.
    Coreset { certificate: CoresetCertificate },
}
impl EvidenceCertification {
    /// Lead a race must exceed before it counts as decided under this certification.
    ///
    /// The margin is `0` for `Exact`, `gap` for `Enclosure`, and the certificate's race
    /// transfer margin for `Coreset`.
    pub fn required_margin(&self) -> f64 {
        match self {
            EvidenceCertification::Exact => 0.0,
            EvidenceCertification::Enclosure { gap } => *gap,
            EvidenceCertification::Coreset { certificate } => certificate.race_transfer_margin(),
        }
    }
    /// Maps a race lead onto the unified certification verdict ladder.
    ///
    /// Non-finite or non-positive leads are always `Insufficient`. `Exact` certifies any
    /// positive finite lead. `Enclosure` and `Coreset` delegate to the shared verdict helpers
    /// in `crate::inference::certificate_impls`.
    pub fn race_verdict(&self, race_lead: f64) -> crate::inference::certificates::Verdict {
        use crate::inference::certificates::Verdict;
        if !(race_lead.is_finite() && race_lead > 0.0) {
            return Verdict::Insufficient;
        }
        match self {
            EvidenceCertification::Exact => Verdict::Certified,
            EvidenceCertification::Enclosure { gap } => {
                // Synthetic enclosure spanning `[0, gap]` with every other field zeroed, so
                // the shared helper judges the lead against the gap alone.
                // NOTE(review): assumes `enclosure_margin_verdict` looks only at
                // `upper - lower`; confirm against its implementation.
                let enclosure = crate::solver::logdet_bounds::LogdetEnclosure {
                    block_diag_logdet: 0.0,
                    lower: 0.0,
                    upper: *gap,
                    rho: 0.0,
                    p2: 0.0,
                    p3: None,
                };
                crate::inference::certificate_impls::enclosure_margin_verdict(&enclosure, race_lead)
            }
            EvidenceCertification::Coreset { certificate } => {
                crate::inference::certificate_impls::coreset_race_verdict(
                    certificate.certify_margin(race_lead),
                )
            }
        }
    }
}
/// One entrant in [`adjudicate_cross_class_race`].
pub struct CrossClassCandidate<'a> {
    /// Topology class of the candidate.
    pub kind: AutoTopologyKind,
    /// Reported negative log evidence; smaller is better in same-class races.
    pub negative_log_evidence: f64,
    /// How precisely `negative_log_evidence` is known.
    pub certification: EvidenceCertification,
    /// Held-out density oracle, used only in cross-class (stacking) races.
    pub density_provider: HeldOutDensityProvider<'a>,
}
impl<'a> CrossClassCandidate<'a> {
    /// Creates a candidate whose evidence is certified as exact.
    pub fn exact(
        kind: AutoTopologyKind,
        negative_log_evidence: f64,
        density_provider: HeldOutDensityProvider<'a>,
    ) -> Self {
        let certification = EvidenceCertification::Exact;
        CrossClassCandidate {
            kind,
            negative_log_evidence,
            certification,
            density_provider,
        }
    }
}
/// A same-class race whose lead is too small to be trusted under the certifications
/// involved.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct InsufficientRaceMargin {
    /// Index of the candidate with the smallest finite negative log evidence.
    pub provisional_winner: usize,
    /// Index of the flagged contender with the smallest lead.
    pub contender: usize,
    /// Contender's negative log evidence minus the winner's (non-negative).
    pub lead: f64,
    /// Larger of the two candidates' `required_margin` values.
    pub required_margin: f64,
}
/// Chooses a winner among topology candidates fitted to the same `n` rows.
///
/// **Same-class race** (all candidates smooth, or all discrete): the winner is the candidate
/// with the smallest finite negative log evidence. Ties go to the first such candidate, and
/// index 0 wins if none is finite. A finite contender is flagged as a possible upset when
/// the larger of the two candidates' [`EvidenceCertification::required_margin`] values is
/// positive and its lead does not exceed that margin. The flagged contender with the
/// smallest lead is reported in `insufficient_margin`.
///
/// **Cross-class race** (at least one smooth and one discrete candidate): the evidence
/// values are not compared. Instead:
/// 1. Each candidate's held-out density provider fills a cross-validated log-density table
///    via [`build_cv_log_density_table`].
/// 2. [`solve_stacking_weights`] fits stacking weights over that table.
/// 3. The candidate with the largest weight wins (first on ties).
///
/// # Errors
/// Returns an error when `candidates` is empty, or when building the CV table or solving
/// the stacking weights fails.
pub fn adjudicate_cross_class_race(
    n: usize,
    candidates: Vec<CrossClassCandidate<'_>>,
    folds: usize,
    stacking_config: StackingConfig,
) -> Result<CrossClassRaceVerdict, String> {
    if candidates.is_empty() {
        return Err("cross-class race requires at least one candidate".to_string());
    }
    let names: Vec<String> = candidates.iter().map(|c| c.kind.display_name()).collect();
    let evidence: Vec<f64> = candidates.iter().map(|c| c.negative_log_evidence).collect();
    // The race is cross-class only when it mixes discrete (mixture/union) and smooth kinds.
    let has_discrete = candidates.iter().any(|c| c.kind.is_discrete_class());
    let has_smooth = candidates.iter().any(|c| !c.kind.is_discrete_class());
    let is_cross_class = has_discrete && has_smooth;
    if !is_cross_class {
        // Same-class race: evidence values are compared directly.
        let certifications: Vec<EvidenceCertification> =
            candidates.iter().map(|c| c.certification).collect();
        // Strict `<` keeps the first minimum. Index 0 remains the winner if nothing is finite.
        let mut winner_index = 0usize;
        let mut best = f64::INFINITY;
        for (idx, &nle) in evidence.iter().enumerate() {
            if nle.is_finite() && nle < best {
                best = nle;
                winner_index = idx;
            }
        }
        // Flag the contender whose lead lies within the looser of the two certification
        // margins, keeping the one with the smallest lead.
        let mut insufficient_margin: Option<InsufficientRaceMargin> = None;
        for (idx, &nle) in evidence.iter().enumerate() {
            if idx == winner_index || !nle.is_finite() {
                continue;
            }
            let lead = nle - best;
            let required = certifications[winner_index]
                .required_margin()
                .max(certifications[idx].required_margin());
            if required > 0.0 && lead <= required {
                let tighter = insufficient_margin.map(|m| lead < m.lead).unwrap_or(true);
                if tighter {
                    insufficient_margin = Some(InsufficientRaceMargin {
                        provisional_winner: winner_index,
                        contender: idx,
                        lead,
                        required_margin: required,
                    });
                }
            }
        }
        return Ok(CrossClassRaceVerdict {
            candidate_names: names,
            is_cross_class: false,
            negative_log_evidence: evidence,
            stacking: None,
            winner_index,
            headline: Headline::Evidence,
            insufficient_margin,
        });
    }
    // Cross-class race: decide by stacking held-out densities instead of raw evidence.
    let providers: Vec<HeldOutDensityProvider<'_>> =
        candidates.into_iter().map(|c| c.density_provider).collect();
    let table = build_cv_log_density_table(n, folds, &providers)?;
    let stacking = solve_stacking_weights(table.view(), stacking_config)?;
    // Strict `>` keeps the first maximum weight. NaN weights never win.
    let mut winner_index = 0usize;
    let mut best_w = f64::NEG_INFINITY;
    for (idx, &w) in stacking.weights.iter().enumerate() {
        if w > best_w {
            best_w = w;
            winner_index = idx;
        }
    }
    Ok(CrossClassRaceVerdict {
        candidate_names: names,
        is_cross_class: true,
        negative_log_evidence: evidence,
        stacking: Some(stacking),
        winner_index,
        headline: Headline::Stacking,
        insufficient_margin: None,
    })
}
/// Ascending grid of closure parameters γ in `[0, 1]` profiled by
/// [`profile_closure_within_smooth_class`]. The grid is denser near both ends.
pub const CLOSURE_GAMMA_GRID: &[f64] = &[
    0.0, 0.02, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 0.98, 1.0,
];
/// One evaluated grid point of a closure profile.
#[derive(Debug, Clone)]
pub struct ClosureProfilePoint<FitHandle> {
    /// Closure parameter γ at which the fit was run.
    pub gamma: f64,
    /// Score returned by the fit; the profiler treats smaller as better.
    pub tk_score: f64,
    /// Handle returned by the fit.
    pub fit_handle: FitHandle,
}
/// Result of profiling the closure parameter of a smooth topology.
#[derive(Debug, Clone)]
pub struct ClosureSelection<FitHandle> {
    /// Profile confidence interval computed from the evaluated grid.
    pub ci: crate::geometry::ClosureProfileCi,
    /// Grid point with the smallest score.
    pub representative: ClosureProfilePoint<FitHandle>,
    /// Copy of `ci.singular_boundary`.
    // NOTE(review): presumably tells the caller to fall back to the mixture rung — confirm.
    pub route_to_mixture_rung: bool,
}
/// Profiles the closure parameter γ of a smooth topology over [`CLOSURE_GAMMA_GRID`].
///
/// `fit_at_gamma(γ)` returns `(score, handle)` for that γ. All grid points are fitted through
/// [`run_topology_race_parallel`]. The resulting `(γ, score)` profile, sorted by γ, is passed
/// with `level` to `crate::geometry::profile_ci_from_grid`.
///
/// The grid point with the smallest score (first on ties) becomes the representative fit.
/// `route_to_mixture_rung` mirrors the interval's `singular_boundary` flag.
///
/// # Errors
/// Returns an error for any of the following:
/// - a failed fit (the message names the γ at which it failed);
/// - a non-finite score;
/// - fewer than two grid points;
/// - any error from the CI computation.
pub fn profile_closure_within_smooth_class<FitHandle, FitAtGamma>(
    fit_at_gamma: FitAtGamma,
    level: f64,
) -> Result<ClosureSelection<FitHandle>, String>
where
    FitHandle: Send,
    FitAtGamma: Fn(f64) -> Result<(f64, FitHandle), String> + Sync,
{
    let results = run_topology_race_parallel(CLOSURE_GAMMA_GRID.to_vec(), |gamma| {
        fit_at_gamma(gamma)
            .map(|(score, handle)| (gamma, score, handle))
            // Tag failures with their γ so the caller can see which grid point broke.
            .map_err(|err| format!("closure profile fit failed at γ={gamma}: {err}"))
    })?;
    let mut points: Vec<ClosureProfilePoint<FitHandle>> = Vec::with_capacity(results.len());
    for entry in results {
        let (gamma, tk_score, fit_handle) = entry.result?;
        if !tk_score.is_finite() {
            return Err(format!(
                "closure profile produced non-finite score at γ={gamma}"
            ));
        }
        points.push(ClosureProfilePoint {
            gamma,
            tk_score,
            fit_handle,
        });
    }
    if points.len() < 2 {
        return Err("closure profile needs at least two evaluable γ grid points".into());
    }
    points.sort_by(|a, b| a.gamma.partial_cmp(&b.gamma).expect("finite γ"));
    let grid: Vec<(f64, f64)> = points.iter().map(|p| (p.gamma, p.tk_score)).collect();
    let ci = crate::geometry::profile_ci_from_grid(&grid, level)?;
    // Scores were checked finite above, so `partial_cmp` cannot fail. `min_by` keeps the
    // first minimum.
    let hat_index = points
        .iter()
        .enumerate()
        .min_by(|(_, a), (_, b)| a.tk_score.partial_cmp(&b.tk_score).expect("finite score"))
        .map(|(i, _)| i)
        .expect("non-empty points");
    let representative = points.swap_remove(hat_index);
    // Read the flag before `ci` is moved into the result.
    let route_to_mixture_rung = ci.singular_boundary;
    Ok(ClosureSelection {
        ci,
        representative,
        route_to_mixture_rung,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    /// Deterministic stand-in for a topology candidate.
    #[derive(Clone)]
    struct SyntheticRaceCandidate {
        seed: u64,
        len: usize,
    }
    /// Deterministic, hash-like workload. It uses Rayon internally, so it runs inside
    /// whichever pool is current (the per-fit pool when raced).
    fn synthetic_fit(candidate: SyntheticRaceCandidate) -> Vec<u64> {
        (0..candidate.len)
            .into_par_iter()
            .map(|i| {
                let x = candidate.seed ^ ((i as u64 + 1) * 0x9e37_79b9_7f4a_7c15);
                x.rotate_left((i % 31) as u32)
                    .wrapping_mul(0xbf58_476d_1ce4_e5b9)
            })
            .collect()
    }
    /// The parallel race returns results in input order, identical to a sequential run.
    #[test]
    fn topology_race_parallel_matches_sequential_synthetic_candidates() {
        let candidates = vec![
            SyntheticRaceCandidate { seed: 11, len: 64 },
            SyntheticRaceCandidate { seed: 29, len: 64 },
            SyntheticRaceCandidate { seed: 47, len: 64 },
        ];
        let sequential = candidates
            .iter()
            .cloned()
            .map(synthetic_fit)
            .collect::<Vec<_>>();
        let parallel =
            run_topology_race_parallel_with_budget(candidates, synthetic_fit, 8).unwrap();
        assert_eq!(parallel.len(), 3);
        assert_eq!(
            parallel
                .iter()
                .map(|row| row.candidate_index)
                .collect::<Vec<_>>(),
            vec![0, 1, 2]
        );
        // Budget 8 with three candidates: three coordinators leave one thread per fit.
        assert!(parallel.iter().all(|row| row.per_fit_threads == 1));
        let wall_times = parallel.iter().map(|row| row.wall_time).collect::<Vec<_>>();
        assert_eq!(wall_times.len(), 3);
        assert_eq!(
            parallel
                .into_iter()
                .map(|row| row.result)
                .collect::<Vec<_>>(),
            sequential
        );
    }
    /// Provider returning zero log density for every eval row; never called in same-class races.
    fn trivial_provider<'a>() -> HeldOutDensityProvider<'a> {
        Box::new(|_train: &[usize], eval: &[usize]| Ok(vec![0.0; eval.len()]))
    }
    /// A lead inside the enclosure gap is flagged as provisional; a lead beyond it is not.
    #[test]
    fn same_class_race_respects_enclosure_decision_margin() {
        let near = vec![
            CrossClassCandidate {
                kind: AutoTopologyKind::Circle,
                negative_log_evidence: 100.0,
                certification: EvidenceCertification::Enclosure { gap: 1.0 },
                density_provider: trivial_provider(),
            },
            CrossClassCandidate {
                kind: AutoTopologyKind::Euclidean,
                negative_log_evidence: 100.5,
                certification: EvidenceCertification::Enclosure { gap: 1.0 },
                density_provider: trivial_provider(),
            },
        ];
        let verdict =
            adjudicate_cross_class_race(8, near, STACKING_CV_FOLDS, StackingConfig::default())
                .expect("same-class race");
        assert!(!verdict.is_cross_class);
        assert_eq!(verdict.winner_index, 0);
        let escalation = verdict
            .insufficient_margin
            .expect("lead inside the enclosure gap must be flagged provisional");
        assert_eq!(escalation.provisional_winner, 0);
        assert_eq!(escalation.contender, 1);
        assert!((escalation.lead - 0.5).abs() < 1e-12);
        assert!((escalation.required_margin - 1.0).abs() < 1e-12);
        let far = vec![
            CrossClassCandidate {
                kind: AutoTopologyKind::Circle,
                negative_log_evidence: 100.0,
                certification: EvidenceCertification::Enclosure { gap: 1.0 },
                density_provider: trivial_provider(),
            },
            CrossClassCandidate {
                kind: AutoTopologyKind::Euclidean,
                negative_log_evidence: 105.0,
                certification: EvidenceCertification::Enclosure { gap: 1.0 },
                density_provider: trivial_provider(),
            },
        ];
        let verdict_far =
            adjudicate_cross_class_race(8, far, STACKING_CV_FOLDS, StackingConfig::default())
                .expect("same-class race");
        assert_eq!(verdict_far.winner_index, 0);
        assert!(
            verdict_far.insufficient_margin.is_none(),
            "a lead clearing the enclosure gap must transfer the verdict"
        );
    }
    /// A lead of half the coreset transfer margin is flagged, and the reported margin
    /// matches the certificate.
    #[test]
    fn same_class_race_respects_coreset_transfer_margin() {
        let cert = CoresetCertificate::new(0.05, 0.1, 32, 1000).expect("certificate");
        let required = cert.race_transfer_margin();
        let lead = 0.5 * required;
        let candidates = vec![
            CrossClassCandidate {
                kind: AutoTopologyKind::Circle,
                negative_log_evidence: 10.0,
                certification: EvidenceCertification::Coreset { certificate: cert },
                density_provider: trivial_provider(),
            },
            CrossClassCandidate {
                kind: AutoTopologyKind::Euclidean,
                negative_log_evidence: 10.0 + lead,
                certification: EvidenceCertification::Coreset { certificate: cert },
                density_provider: trivial_provider(),
            },
        ];
        let verdict = adjudicate_cross_class_race(
            8,
            candidates,
            STACKING_CV_FOLDS,
            StackingConfig::default(),
        )
        .expect("same-class race");
        let escalation = verdict
            .insufficient_margin
            .expect("lead inside the coreset transfer margin must be flagged");
        assert!((escalation.required_margin - required).abs() < 1e-9);
    }
    /// Each certification kind maps small and large leads onto the unified verdict ladder.
    #[test]
    fn race_verdict_maps_onto_unified_ladder() {
        use crate::inference::certificates::Verdict;
        assert_eq!(
            EvidenceCertification::Exact.race_verdict(1e-6),
            Verdict::Certified
        );
        assert_eq!(
            EvidenceCertification::Exact.race_verdict(0.0),
            Verdict::Insufficient
        );
        let enc = EvidenceCertification::Enclosure { gap: 0.2 };
        assert_eq!(enc.race_verdict(0.5), Verdict::Certified);
        assert_eq!(enc.race_verdict(0.1), Verdict::Insufficient);
        let cert = CoresetCertificate::new(0.05, 0.1, 32, 1000).expect("certificate");
        let required = cert.race_transfer_margin();
        let coreset = EvidenceCertification::Coreset { certificate: cert };
        assert_eq!(coreset.race_verdict(0.5 * required), Verdict::Insufficient);
        assert_eq!(
            coreset.race_verdict(2.0 * required + 1.0),
            Verdict::Certified
        );
    }
    /// A quadratic profile with its minimum at γ = 0.7 yields an interior γ̂ and a CI
    /// excluding both endpoints.
    #[test]
    fn closure_profiler_recovers_interior_minimum_and_ci() {
        let selection = profile_closure_within_smooth_class(
            |gamma| Ok::<_, String>((100.0 + 80.0 * (gamma - 0.7).powi(2), gamma)),
            0.95,
        )
        .expect("closure profile");
        assert!(
            (selection.ci.gamma_hat - 0.7).abs() < 0.06,
            "γ̂ {}",
            selection.ci.gamma_hat
        );
        assert!(!selection.ci.ci_includes_circle);
        assert!(!selection.ci.ci_includes_interval);
        assert!(!selection.route_to_mixture_rung);
        assert!((selection.representative.gamma - selection.ci.gamma_hat).abs() < 1e-12);
    }
    /// A profile increasing in γ puts γ̂ at 0 and routes the selection to the mixture rung.
    #[test]
    fn closure_profiler_routes_collapse_to_mixture_rung() {
        let selection = profile_closure_within_smooth_class(
            |gamma| Ok::<_, String>((10.0 + 25.0 * gamma, gamma)),
            0.95,
        )
        .expect("closure profile");
        assert!(selection.ci.gamma_hat.abs() < 1e-9);
        assert!(selection.route_to_mixture_rung);
        assert!(selection.ci.ci_includes_interval);
    }
    /// Coordinator plus per-fit threads never exceed the budget; a budget of 2 runs fits
    /// one at a time.
    #[test]
    fn topology_race_thread_plan_bounds_nested_rayon_threads() {
        let plan = TopologyRaceThreadPlan::for_budget(3, 8);
        assert_eq!(plan.concurrent_fits, 3);
        assert!(
            plan.coordinator_threads + plan.concurrent_fits * plan.per_fit_threads <= 8,
            "plan must bound coordinator plus per-fit Rayon workers"
        );
        let small = TopologyRaceThreadPlan::for_budget(3, 2);
        assert_eq!(small.concurrent_fits, 1);
        assert!(small.coordinator_threads + small.per_fit_threads <= 2);
    }
}