use crate::solver::evidence::{
GaussianMixtureConfig, StackingConfig, StackingWeights, TopologyScoreScale,
UNION_STRUCTURE_LADDER, UnionStructure, UnionStructureFit, fit_gaussian_mixture,
fit_union_ladder, fit_union_structure, solve_stacking_weights, union_per_point_log_density,
};
use crate::solver::priority_selection::{PriorityCandidate, rank_priority_candidates};
use ndarray::{Array2, ArrayView2};
use serde_json::Value as JsonValue;
/// Natural logarithm of `2π`, used by the null-space term of [`tk_normalized_score`].
const TK_LOG_2PI: f64 = 1.837_877_066_409_345_3;
/// Default ladder of mixture orders `k`; the last entry is the order used for a bare `mixture`.
pub const MIXTURE_K_LADDER: &[usize] = &[1, 2, 3, 5, 7, 9];
/// Default number of cross-validation folds for stacking.
pub const STACKING_CV_FOLDS: usize = 5;
/// A topology family that [`TopologyAutoSelector`] can fit and rank.
///
/// `Euclidean`, `Circle`, `Sphere`, `Torus` and `Cylinder` are the smooth
/// families returned by [`AutoTopologyKind::all`]; `Mixture` and `Union` are
/// the discrete classes (see [`AutoTopologyKind::is_discrete_class`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AutoTopologyKind {
    /// Flat topology; also parsed from `flat` / `euclidean_patch`.
    Euclidean,
    /// Circle; also parsed from `periodic` / `s1`.
    Circle,
    /// Sphere; also parsed from `s2`.
    Sphere,
    /// Torus.
    Torus,
    /// Cylinder.
    Cylinder,
    /// Mixture with `k` components.
    Mixture {
        /// Mixture order; [`AutoTopologyKind::parse`] never produces `0`.
        k: usize,
    },
    /// Composite union structure.
    Union {
        /// Which union composite this candidate denotes.
        structure: UnionStructure,
    },
}
impl AutoTopologyKind {
    /// Returns the canonical short name of the family.
    ///
    /// All mixture orders share the name `"mixture"`; use
    /// [`AutoTopologyKind::display_name`] to tell them apart.
    pub const fn as_str(self) -> &'static str {
        match self {
            AutoTopologyKind::Euclidean => "euclidean",
            AutoTopologyKind::Circle => "circle",
            AutoTopologyKind::Sphere => "sphere",
            AutoTopologyKind::Torus => "torus",
            AutoTopologyKind::Cylinder => "cylinder",
            AutoTopologyKind::Mixture { .. } => "mixture",
            AutoTopologyKind::Union { structure } => structure.as_str(),
        }
    }

    /// Returns a human-readable name that includes the mixture order
    /// (e.g. `"mixture_k3"`); identical to `as_str` for every other kind.
    pub fn display_name(self) -> String {
        match self {
            AutoTopologyKind::Mixture { k } => format!("mixture_k{k}"),
            other => other.as_str().to_string(),
        }
    }

    /// `true` for any `Mixture { .. }` candidate.
    pub const fn is_discrete_mixture(self) -> bool {
        matches!(self, AutoTopologyKind::Mixture { .. })
    }

    /// `true` for any `Union { .. }` candidate.
    pub const fn is_structured_union(self) -> bool {
        matches!(self, AutoTopologyKind::Union { .. })
    }

    /// `true` for mixtures and unions, i.e. the non-smooth classes.
    pub const fn is_discrete_class(self) -> bool {
        self.is_discrete_mixture() || self.is_structured_union()
    }

    /// Parses a user-facing candidate name.
    ///
    /// The input is trimmed, ASCII-lowercased and `-` is mapped to `_`.
    /// Accepted forms are:
    /// - union names (see [`parse_union_name`]);
    /// - `mixture`, optionally followed by an order such as `mixture_k3`,
    ///   `mixture3` or `mixture_3` (a bare `mixture` uses the last entry of
    ///   `MIXTURE_K_LADDER`);
    /// - the smooth-family names and aliases listed on [`AutoTopologyKind`].
    ///
    /// # Errors
    /// Returns a message for unknown names, for a mixture suffix that is not a
    /// plain integer (e.g. `mixture_kx`, `mixture_k1_2`, `mixtures`), and for
    /// order `0`.
    pub fn parse(value: &str) -> Result<Self, String> {
        let normalized = value.trim().to_ascii_lowercase().replace('-', "_");
        if let Some(structure) = parse_union_name(&normalized) {
            return Ok(AutoTopologyKind::Union { structure });
        }
        if let Some(rest) = normalized.strip_prefix("mixture") {
            let k = Self::parse_mixture_order(rest, value)?;
            return Ok(AutoTopologyKind::Mixture { k });
        }
        match normalized.as_str() {
            "euclidean" | "flat" | "euclidean_patch" | "euclideanpatch" => {
                Ok(AutoTopologyKind::Euclidean)
            }
            "circle" | "periodic" | "s1" => Ok(AutoTopologyKind::Circle),
            "sphere" | "s2" => Ok(AutoTopologyKind::Sphere),
            "torus" => Ok(AutoTopologyKind::Torus),
            "cylinder" => Ok(AutoTopologyKind::Cylinder),
            other => Err(format!(
                "topology candidate must be euclidean, circle, sphere, torus, cylinder, mixture[_k{{n}}], or a union (union_circle+circle, union_circle+cluster, union_line+cluster); got {other:?}"
            )),
        }
    }

    /// Parses the normalized suffix that follows `mixture` into an order `k`.
    ///
    /// Surrounding `_` and an optional `k` marker are stripped. What remains
    /// must be empty (meaning "use the default order") or consist solely of
    /// ASCII digits. Collecting every digit from arbitrary text would turn
    /// `mixture_k1_2` into `k = 12` and accept typos such as `mixtures`.
    /// `value` is the caller's original text, used only in error messages.
    fn parse_mixture_order(rest: &str, value: &str) -> Result<usize, String> {
        let body = rest.trim_matches('_');
        let body = body.strip_prefix('k').unwrap_or(body).trim_start_matches('_');
        if body.is_empty() {
            return Ok(*MIXTURE_K_LADDER.last().unwrap_or(&7));
        }
        if !body.bytes().all(|b| b.is_ascii_digit()) {
            return Err(format!(
                "mixture order must be a positive integer; got {value:?}"
            ));
        }
        // Only overflow can fail here because the body is all digits.
        let k: usize = body
            .parse()
            .map_err(|_| format!("mixture order must be a positive integer; got {value:?}"))?;
        if k == 0 {
            return Err("mixture order k must be >= 1".to_string());
        }
        Ok(k)
    }

    /// The five smooth families, in declaration order.
    pub fn all() -> Vec<Self> {
        vec![
            AutoTopologyKind::Euclidean,
            AutoTopologyKind::Circle,
            AutoTopologyKind::Sphere,
            AutoTopologyKind::Torus,
            AutoTopologyKind::Cylinder,
        ]
    }

    /// One `Mixture` candidate per entry of `MIXTURE_K_LADDER`.
    pub fn mixture_ladder() -> Vec<Self> {
        MIXTURE_K_LADDER
            .iter()
            .map(|&k| AutoTopologyKind::Mixture { k })
            .collect()
    }

    /// One `Union` candidate per entry of `UNION_STRUCTURE_LADDER`.
    pub fn union_ladder() -> Vec<Self> {
        UNION_STRUCTURE_LADDER
            .iter()
            .map(|&structure| AutoTopologyKind::Union { structure })
            .collect()
    }
}
/// Recognises a union candidate name such as `union_circle+cluster`.
///
/// The text must start with `union`. After the prefix, `_` and `-` act as
/// `+` separators, and leading or trailing separators are ignored. Returns
/// `None` for anything that is not a known union composite.
pub fn parse_union_name(normalized: &str) -> Option<UnionStructure> {
    let tail = normalized.strip_prefix("union")?;
    let joined: String = tail
        .chars()
        .map(|ch| match ch {
            '_' | '-' => '+',
            other => other,
        })
        .collect();
    let structure = match joined.trim_matches('+') {
        "circle+circle" => UnionStructure::CircleCircle,
        "circle+cluster" | "circle+point+cluster" | "circle+pointcluster" => {
            UnionStructure::CirclePointCluster
        }
        "line+cluster" | "line+point+cluster" | "line+pointcluster" => {
            UnionStructure::LineCluster
        }
        _ => return None,
    };
    Some(structure)
}
/// Configuration for [`select_topology_with_fit`]: which topology candidates
/// to try and how their TK-normalized scores are scaled.
#[derive(Debug, Clone)]
pub struct TopologyAutoSelector {
    /// Candidates, fitted in this order (`from_json` rejects duplicates).
    pub candidates: Vec<AutoTopologyKind>,
    /// Divisor applied to the TK score (see [`tk_normalized_score`]).
    pub score_scale: TopologyScoreScale,
    /// Optional latent name from the JSON config. It is not interpreted in
    /// this module; presumably callers consume it (TODO confirm).
    pub latent: Option<String>,
}
impl TopologyAutoSelector {
    /// Creates a selector over `candidates`, or over [`AutoTopologyKind::all`]
    /// when `None`. Scoring is per effective dimension and no latent is set.
    pub fn new(candidates: Option<Vec<AutoTopologyKind>>) -> Self {
        Self {
            candidates: candidates.unwrap_or_else(AutoTopologyKind::all),
            score_scale: TopologyScoreScale::PerEffectiveDim,
            latent: None,
        }
    }

    /// Parses a selector from a JSON object.
    ///
    /// All keys are optional, and JSON `null` is treated as absent:
    /// - `candidates`: non-empty list of names accepted by
    ///   [`AutoTopologyKind::parse`]; defaults to [`AutoTopologyKind::all`].
    /// - `score_scale`: `"per_effective_dim"` (default) or
    ///   `"per_observation"`; ASCII case and `-`/`_` are normalized.
    /// - `latent`: string.
    ///
    /// # Errors
    /// Returns a message if `value` is not an object, if a key has the wrong
    /// JSON type, if a candidate name fails to parse or is repeated, or if
    /// `score_scale` is unrecognised.
    pub fn from_json(value: &JsonValue) -> Result<Self, String> {
        let obj = value
            .as_object()
            .ok_or_else(|| "topology_auto_selector must be an object".to_string())?;
        let candidates = match obj.get("candidates").filter(|value| !value.is_null()) {
            None => AutoTopologyKind::all(),
            Some(raw) => {
                let items = raw.as_array().ok_or_else(|| {
                    "topology_auto_selector.candidates must be a list".to_string()
                })?;
                if items.is_empty() {
                    return Err(
                        "topology_auto_selector.candidates must have at least one entry"
                            .to_string(),
                    );
                }
                let mut out = Vec::with_capacity(items.len());
                for (idx, item) in items.iter().enumerate() {
                    let name = item.as_str().ok_or_else(|| {
                        format!("topology_auto_selector.candidates[{idx}] must be a string")
                    })?;
                    let kind = AutoTopologyKind::parse(name)?;
                    if out.contains(&kind) {
                        // display_name keeps the mixture order, e.g. "mixture_k3",
                        // instead of a bare "mixture" for every order.
                        return Err(format!(
                            "topology_auto_selector duplicate candidate {:?}",
                            kind.display_name()
                        ));
                    }
                    out.push(kind);
                }
                out
            }
        };
        // Reject a present but non-string value, as is done for `latent`,
        // instead of silently falling back to the default.
        let score_scale_name = obj
            .get("score_scale")
            .filter(|value| !value.is_null())
            .map(|value| {
                value.as_str().ok_or_else(|| {
                    "topology_auto_selector.score_scale must be a string".to_string()
                })
            })
            .transpose()?
            .unwrap_or("per_effective_dim");
        let score_scale = match score_scale_name
            .trim()
            .to_ascii_lowercase()
            .replace('-', "_")
            .as_str()
        {
            "per_observation" => TopologyScoreScale::PerObservation,
            "per_effective_dim" => TopologyScoreScale::PerEffectiveDim,
            other => {
                return Err(format!(
                    "topology_auto_selector.score_scale must be per_effective_dim or per_observation; got {other:?}"
                ));
            }
        };
        let latent = obj
            .get("latent")
            .filter(|value| !value.is_null())
            .map(|value| {
                value
                    .as_str()
                    .map(str::to_string)
                    .ok_or_else(|| "topology_auto_selector.latent must be a string".to_string())
            })
            .transpose()?;
        Ok(Self {
            candidates,
            score_scale,
            latent,
        })
    }
}
/// Evidence that the caller's fit callback reports for one candidate. It is
/// consumed by [`select_topology_with_fit`].
#[derive(Debug, Clone)]
pub struct TopologyAutoFitEvidence<FitHandle> {
    /// Name copied into the ranked output.
    pub topology_name: String,
    /// Raw REML criterion, before the null-space correction.
    pub raw_reml: f64,
    /// Null-space dimension; must be finite and not meaningfully negative.
    pub null_dim: f64,
    /// Null-space Hessian log-determinant; required when `null_dim > 0`.
    pub null_space_logdet: Option<f64>,
    /// Divisor under `TopologyScoreScale::PerEffectiveDim`.
    pub effective_dim: f64,
    /// Divisor under `TopologyScoreScale::PerObservation`.
    pub n_obs: usize,
    /// Opaque caller payload, carried through to the ranked output.
    pub fit_handle: FitHandle,
}
/// One scored candidate in a [`TopologyAutoSelectorResult`].
#[derive(Debug, Clone)]
pub struct TopologyAutoRankedFit<FitHandle> {
    /// Name taken from the fit evidence.
    pub topology_name: String,
    /// Score produced by [`tk_normalized_score`].
    pub tk_score: f64,
    /// Raw REML criterion, copied from the evidence.
    pub raw_reml: f64,
    /// Effective dimension, copied from the evidence.
    pub effective_dim: f64,
    /// Observation count, copied from the evidence.
    pub n_obs: usize,
    /// Caller payload, copied from the evidence.
    pub fit_handle: FitHandle,
}
/// Ranked candidates plus the position of the winner within `ranked`.
#[derive(Debug, Clone)]
pub struct TopologyAutoSelectorResult<FitHandle> {
    /// Candidates in ranked order.
    pub ranked: Vec<TopologyAutoRankedFit<FitHandle>>,
    /// Index into `ranked` of the selected candidate.
    pub winner_index: usize,
}
impl<FitHandle> TopologyAutoSelectorResult<FitHandle> {
    /// Returns the selected candidate, or `None` when `winner_index` is out of range.
    pub fn winner(&self) -> Option<&TopologyAutoRankedFit<FitHandle>> {
        self.ranked.iter().nth(self.winner_index)
    }
}
pub fn select_topology_with_fit<FitHandle, FitErr>(
selector: &TopologyAutoSelector,
mut fit_one: impl FnMut(AutoTopologyKind) -> Result<TopologyAutoFitEvidence<FitHandle>, FitErr>,
) -> Result<TopologyAutoSelectorResult<FitHandle>, String>
where
FitErr: ToString,
{
let mut ranked = Vec::with_capacity(selector.candidates.len());
let mut errors = Vec::new();
for candidate in &selector.candidates {
match fit_one(*candidate) {
Ok(evidence) => {
let tk_score = tk_normalized_score(
evidence.raw_reml,
evidence.null_dim,
evidence.null_space_logdet,
evidence.effective_dim,
evidence.n_obs,
selector.score_scale,
)?;
ranked.push(TopologyAutoRankedFit {
topology_name: evidence.topology_name,
tk_score,
raw_reml: evidence.raw_reml,
effective_dim: evidence.effective_dim,
n_obs: evidence.n_obs,
fit_handle: evidence.fit_handle,
});
}
Err(err) => errors.push(format!("{}: {}", candidate.as_str(), err.to_string())),
}
}
if ranked.is_empty() {
return Err(format!(
"TopologyAutoSelector found no fittable topology candidates{}",
if errors.is_empty() {
String::new()
} else {
format!(" ({})", errors.join("; "))
}
));
}
ranked = rank_priority_candidates(
ranked
.into_iter()
.enumerate()
.map(|(idx, row)| {
let score = row.tk_score;
PriorityCandidate::new(row, idx, score, 0)
})
.collect(),
)
.into_iter()
.map(|row| row.item)
.collect();
Ok(TopologyAutoSelectorResult {
ranked,
winner_index: 0,
})
}
/// Computes the TK-normalized selection score for one fitted candidate.
///
/// The null-space correction `-0.5 * null_dim * ln(2π) + 0.5 * logdet` is
/// added to `raw_reml`. The sum is then divided by `n_obs` (`PerObservation`)
/// or by `effective_dim` (`PerEffectiveDim`). When `null_dim` is zero (after
/// clamping tiny negatives) the correction is zero and `null_space_logdet` is
/// ignored.
///
/// # Errors
/// Returns a message in any of these cases:
/// - `raw_reml` or `null_dim` is non-finite;
/// - `null_dim` is below `-1e-9`;
/// - a positive `null_dim` comes with a missing or non-finite
///   `null_space_logdet`;
/// - the selected divisor is invalid (`n_obs == 0`, or `effective_dim` is not
///   finite and positive).
pub fn tk_normalized_score(
    raw_reml: f64,
    null_dim: f64,
    null_space_logdet: Option<f64>,
    effective_dim: f64,
    n_obs: usize,
    score_scale: TopologyScoreScale,
) -> Result<f64, String> {
    if !(raw_reml.is_finite() && null_dim.is_finite()) {
        return Err("TopologyAutoSelector received non-finite TK evidence inputs".to_string());
    }
    // Checked separately so that a finite negative dimension is not
    // misreported as a non-finite input.
    if null_dim < -1.0e-9 {
        return Err(format!(
            "TopologyAutoSelector received negative null-space dimension {null_dim}"
        ));
    }
    // Clamp round-off negatives in [-1e-9, 0) to zero.
    let null_dim = null_dim.max(0.0);
    let normalizer = if null_dim == 0.0 {
        0.0
    } else {
        let logdet = null_space_logdet.ok_or_else(|| {
            "TopologyAutoSelector TK normalizer requires null-space Hessian logdet".to_string()
        })?;
        if !logdet.is_finite() {
            return Err("TopologyAutoSelector null-space Hessian logdet is not finite".to_string());
        }
        -0.5 * null_dim * TK_LOG_2PI + 0.5 * logdet
    };
    let tk = raw_reml + normalizer;
    match score_scale {
        TopologyScoreScale::PerObservation => {
            if n_obs == 0 {
                Err("TopologyAutoSelector requires n_obs > 0".to_string())
            } else {
                Ok(tk / n_obs as f64)
            }
        }
        TopologyScoreScale::PerEffectiveDim => {
            if !(effective_dim.is_finite() && effective_dim > 0.0) {
                Err("TopologyAutoSelector requires finite positive effective_dim".to_string())
            } else {
                Ok(tk / effective_dim)
            }
        }
    }
}
/// One successfully fitted mixture order from [`fit_mixture_rung`].
#[derive(Debug, Clone)]
pub struct MixtureRungFit {
    /// Mixture order.
    pub k: usize,
    /// The fitted mixture.
    pub fit: crate::solver::evidence::GaussianMixtureFit,
    /// Value of `fit.num_free_parameters()`.
    pub num_parameters: usize,
    /// Value of `fit.laplace_negative_log_evidence(data)`.
    pub negative_log_evidence: f64,
}
/// All fitted mixture orders, ranked, plus the position of the winner.
#[derive(Debug, Clone)]
pub struct MixtureRungResult {
    /// Fitted orders in ranked order.
    pub fits: Vec<MixtureRungFit>,
    /// Index into `fits` of the selected order.
    pub winner_index: usize,
}
impl MixtureRungResult {
    /// Returns the selected fit.
    ///
    /// # Panics
    /// Panics if `winner_index` is out of bounds for `fits`.
    pub fn winner(&self) -> &MixtureRungFit {
        let chosen = self.winner_index;
        &self.fits[chosen]
    }
}
/// Maximum number of extra mixture orders that the refinement stage of
/// [`fit_mixture_rung`] will try after the initial ladder sweep.
pub const MIXTURE_REFINEMENT_MAX_PROBES: usize = 16;
/// Fits Gaussian mixtures over a ladder of orders, then refines locally
/// around the best one.
///
/// Every order in `ladder` is fitted once. Orders outside `1..=n`, where `n`
/// is the number of rows, are skipped. The refinement then repeatedly takes
/// the order with the lowest negative log evidence (smaller `k` wins ties)
/// and tries its first unattempted neighbour, `k - 1` before `k + 1`. It
/// stops when both neighbours have been tried or after
/// `MIXTURE_REFINEMENT_MAX_PROBES` probes. Results are ranked by
/// `rank_priority_candidates` with the NLE as score and `k` as tie key;
/// `winner_index` is 0.
///
/// Fits whose Laplace NLE is NaN or infinite are recorded as errors and
/// excluded, so that they cannot steer the refinement or the ranking.
///
/// # Errors
/// Returns a message listing every per-order failure if no order could be
/// fitted with finite evidence.
pub fn fit_mixture_rung(
    data: ArrayView2<'_, f64>,
    ladder: &[usize],
    config: GaussianMixtureConfig,
) -> Result<MixtureRungResult, String> {
    let n = data.nrows();
    let mut fits: Vec<MixtureRungFit> = Vec::new();
    let mut errors: Vec<String> = Vec::new();
    let mut attempted: std::collections::BTreeSet<usize> = std::collections::BTreeSet::new();
    // Fits order `k` at most once. Out-of-range orders are skipped silently;
    // fit failures and non-finite evidence are recorded in `errors`.
    let try_order = |k: usize,
                     fits: &mut Vec<MixtureRungFit>,
                     errors: &mut Vec<String>,
                     attempted: &mut std::collections::BTreeSet<usize>| {
        if k == 0 || k > n || !attempted.insert(k) {
            return;
        }
        let fit = match fit_gaussian_mixture(data, k, config) {
            Ok(fit) => fit,
            Err(e) => {
                errors.push(format!("mixture k={k} fit: {e}"));
                return;
            }
        };
        match fit.laplace_negative_log_evidence(data) {
            // A NaN NLE would compare as `Equal` in the `min_by` below and
            // could be picked as the refinement centre.
            Ok(nle) if nle.is_finite() => {
                let num_parameters = fit.num_free_parameters();
                fits.push(MixtureRungFit {
                    k,
                    fit,
                    num_parameters,
                    negative_log_evidence: nle,
                });
            }
            Ok(nle) => errors.push(format!(
                "mixture k={k} evidence: non-finite negative log evidence {nle}"
            )),
            Err(e) => errors.push(format!("mixture k={k} evidence: {e}")),
        }
    };
    for &k in ladder {
        try_order(k, &mut fits, &mut errors, &mut attempted);
    }
    if fits.is_empty() {
        return Err(format!(
            "mixture rung produced no fittable orders{}",
            if errors.is_empty() {
                String::new()
            } else {
                format!(" ({})", errors.join("; "))
            }
        ));
    }
    for _ in 0..MIXTURE_REFINEMENT_MAX_PROBES {
        let best_k = fits
            .iter()
            .min_by(|a, b| {
                a.negative_log_evidence
                    .partial_cmp(&b.negative_log_evidence)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    .then(a.k.cmp(&b.k))
            })
            .map(|f| f.k)
            .unwrap_or(1);
        let next = [best_k.saturating_sub(1), best_k + 1]
            .into_iter()
            .find(|&k| k >= 1 && k <= n && !attempted.contains(&k));
        let Some(k) = next else {
            break;
        };
        try_order(k, &mut fits, &mut errors, &mut attempted);
    }
    let ranked = rank_priority_candidates(
        fits.into_iter()
            .enumerate()
            .map(|(idx, row)| {
                let score = row.negative_log_evidence;
                let tie = row.k;
                PriorityCandidate::new(row, idx, score, tie)
            })
            .collect(),
    )
    .into_iter()
    .map(|row| row.item)
    .collect::<Vec<_>>();
    Ok(MixtureRungResult {
        fits: ranked,
        winner_index: 0,
    })
}
/// One fitted union composite, with its parameter count and evidence
/// copied out of the underlying [`UnionStructureFit`].
#[derive(Debug, Clone)]
pub struct UnionRungFit {
    /// Composite kind, copied from `fit.structure`.
    pub structure: UnionStructure,
    /// The underlying composite fit.
    pub fit: UnionStructureFit,
    /// Copied from `fit.total_parameters`.
    pub total_parameters: usize,
    /// Copied from `fit.negative_log_evidence`.
    pub negative_log_evidence: f64,
}
/// All fitted union composites plus the position of the winner.
#[derive(Debug, Clone)]
pub struct UnionRungResult {
    /// Fitted composites.
    pub fits: Vec<UnionRungFit>,
    /// Index into `fits` of the selected composite.
    pub winner_index: usize,
}
impl UnionRungResult {
    /// Returns the selected composite.
    ///
    /// # Panics
    /// Panics if `winner_index` is out of bounds for `fits`.
    pub fn winner(&self) -> &UnionRungFit {
        let chosen = self.winner_index;
        &self.fits[chosen]
    }
}
/// Fits every composite in the union ladder and selects the best one.
///
/// `fits` keeps the order returned by `fit_union_ladder`. `winner_index`
/// points at the first fit with the strictly lowest finite negative log
/// evidence; this matches the lowest-NLE criterion used by the mixture rung.
/// If `fit_union_ladder` already returns fits sorted by ascending NLE, this
/// is index 0. If no fit has finite evidence it falls back to 0.
///
/// # Errors
/// Propagates errors from `fit_union_ladder`, and returns a message when the
/// ladder yields no fits.
pub fn fit_union_rung(
    data: ArrayView2<'_, f64>,
    config: GaussianMixtureConfig,
) -> Result<UnionRungResult, String> {
    let ladder = fit_union_ladder(data, config)?;
    let fits: Vec<UnionRungFit> = ladder
        .into_iter()
        .map(|fit| UnionRungFit {
            structure: fit.structure,
            total_parameters: fit.total_parameters,
            negative_log_evidence: fit.negative_log_evidence,
            fit,
        })
        .collect();
    if fits.is_empty() {
        return Err("union rung produced no fittable composites".to_string());
    }
    // Previously hard-coded to 0, which picks the first composite regardless
    // of evidence unless the ladder happens to be sorted.
    let winner_index = fits
        .iter()
        .enumerate()
        .filter(|(_, fit)| fit.negative_log_evidence.is_finite())
        .fold(None, |best: Option<(usize, f64)>, (idx, fit)| match best {
            Some((_, best_nle)) if best_nle <= fit.negative_log_evidence => best,
            _ => Some((idx, fit.negative_log_evidence)),
        })
        .map_or(0, |(idx, _)| idx);
    Ok(UnionRungResult {
        fits,
        winner_index,
    })
}
/// Fits a single union composite `structure` and wraps it as a [`UnionRungFit`].
///
/// # Errors
/// Propagates the error from `fit_union_structure`.
pub fn fit_union_candidate(
    data: ArrayView2<'_, f64>,
    structure: UnionStructure,
    config: GaussianMixtureConfig,
) -> Result<UnionRungFit, String> {
    let composite = fit_union_structure(data, structure, config)?;
    // Copy the scalar summaries out before the fit itself is moved in.
    let (structure, total_parameters, negative_log_evidence) = (
        composite.structure,
        composite.total_parameters,
        composite.negative_log_evidence,
    );
    Ok(UnionRungFit {
        structure,
        fit: composite,
        total_parameters,
        negative_log_evidence,
    })
}
/// Callback that fits a candidate on the `train` row indices and returns one
/// held-out log density per entry of `eval`, in the same order, or an error
/// message.
pub type HeldOutDensityProvider<'a> = Box<dyn Fn(&[usize], &[usize]) -> Result<Vec<f64>, String> + 'a>;
/// Builds a held-out density provider for a `k`-component Gaussian mixture.
///
/// Each call fits a mixture on the `train` rows of `data` and returns the
/// per-point log density of the `eval` rows. The order is capped at
/// `max(train.len(), 1)`.
///
/// The provider captures the `data` view itself. `ArrayView2` is `Copy` and
/// the provider is already bound by `'a`, so there is no need to deep-copy
/// the whole matrix for every candidate.
pub fn mixture_density_provider<'a>(
    data: ArrayView2<'a, f64>,
    k: usize,
    config: GaussianMixtureConfig,
) -> HeldOutDensityProvider<'a> {
    Box::new(
        move |train: &[usize], eval: &[usize]| -> Result<Vec<f64>, String> {
            let train_mat = gather_rows(data, train);
            let fit = fit_gaussian_mixture(train_mat.view(), k.min(train.len().max(1)), config)?;
            let eval_mat = gather_rows(data, eval);
            let dens = fit.per_point_log_density(eval_mat.view())?;
            Ok(dens.to_vec())
        },
    )
}
/// Builds a held-out density provider for the union composite `structure`.
///
/// Each call delegates to `union_per_point_log_density`, passing the `train`
/// rows and the `eval` rows of `data`.
///
/// The provider captures the `data` view itself. `ArrayView2` is `Copy` and
/// the provider is already bound by `'a`, so there is no need to deep-copy
/// the whole matrix for every candidate.
pub fn union_density_provider<'a>(
    data: ArrayView2<'a, f64>,
    structure: UnionStructure,
    config: GaussianMixtureConfig,
) -> HeldOutDensityProvider<'a> {
    Box::new(
        move |train: &[usize], eval: &[usize]| -> Result<Vec<f64>, String> {
            let train_mat = gather_rows(data, train);
            let eval_mat = gather_rows(data, eval);
            let dens =
                union_per_point_log_density(train_mat.view(), eval_mat.view(), structure, config)?;
            Ok(dens.to_vec())
        },
    )
}
/// Copies the rows of `data` listed in `idx` into a new owned matrix with the
/// same column count. Rows appear in `idx` order, and repeats are allowed.
///
/// # Panics
/// Panics if an index is out of bounds and `data` has at least one column.
fn gather_rows(data: ArrayView2<'_, f64>, idx: &[usize]) -> Array2<f64> {
    Array2::from_shape_fn((idx.len(), data.ncols()), |(row, col)| data[[idx[row], col]])
}
/// Splits row indices `0..n` into round-robin cross-validation folds.
///
/// The fold count is `folds` clamped to `2..=max(n, 2)`, and row `i` is held
/// out in fold `i % fold_count`. The result is a list of `(train, eval)`
/// pairs in fold order, each ascending. Folds where either side would be
/// empty are dropped, so `n <= 1` yields an empty list.
pub fn deterministic_cv_folds(n: usize, folds: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    let fold_count = folds.clamp(2, n.max(2));
    (0..fold_count)
        .map(|fold| {
            let (eval, train): (Vec<usize>, Vec<usize>) =
                (0..n).partition(|&i| i % fold_count == fold);
            (train, eval)
        })
        .filter(|(train, eval)| !train.is_empty() && !eval.is_empty())
        .collect()
}
/// Builds an `n × providers.len()` table of held-out log densities.
///
/// For every fold of [`deterministic_cv_folds`]`(n, folds)`, each provider
/// is called in turn on `(train, eval)`. Its output is written into that
/// provider's column, at the eval rows. Entries that no fold writes stay
/// `-∞`.
///
/// # Errors
/// Returns a message if there are no providers, if the CV partition is
/// empty, if a provider fails, or if a provider returns the wrong number of
/// densities.
pub fn build_cv_log_density_table(
    n: usize,
    folds: usize,
    providers: &[HeldOutDensityProvider<'_>],
) -> Result<Array2<f64>, String> {
    if providers.is_empty() {
        return Err("stacking table requires at least one candidate provider".to_string());
    }
    let partition = deterministic_cv_folds(n, folds);
    if partition.is_empty() {
        return Err("stacking CV partition is empty (n too small for folds)".to_string());
    }
    let mut table = Array2::<f64>::from_elem((n, providers.len()), f64::NEG_INFINITY);
    for (train, eval) in partition.iter() {
        for (column, provider) in providers.iter().enumerate() {
            let held_out = provider(train, eval)?;
            if held_out.len() != eval.len() {
                return Err(format!(
                    "provider {column} returned {} densities for {} eval rows",
                    held_out.len(),
                    eval.len()
                ));
            }
            for (&row, &log_density) in eval.iter().zip(held_out.iter()) {
                table[[row, column]] = log_density;
            }
        }
    }
    Ok(table)
}
/// Outcome of [`adjudicate_cross_class_race`].
#[derive(Debug, Clone)]
pub struct CrossClassRaceVerdict {
    /// `display_name` of each candidate, in input order.
    pub candidate_names: Vec<String>,
    /// `true` when the race mixed discrete and smooth candidates.
    pub is_cross_class: bool,
    /// Negative log evidence of each candidate, in input order.
    pub negative_log_evidence: Vec<f64>,
    /// Stacking weights; present only for cross-class races.
    pub stacking: Option<StackingWeights>,
    /// Index of the winning candidate in input order.
    pub winner_index: usize,
    /// Which criterion picked the winner.
    pub headline: Headline,
}
/// Criterion that decided a [`CrossClassRaceVerdict`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Headline {
    /// Lowest finite negative log evidence (single-class race).
    Evidence,
    /// Largest stacking weight (cross-class race).
    Stacking,
}
/// One entrant in [`adjudicate_cross_class_race`].
pub struct CrossClassCandidate<'a> {
    /// Topology kind; decides whether the candidate counts as discrete.
    pub kind: AutoTopologyKind,
    /// Negative log evidence, used directly in single-class races.
    pub negative_log_evidence: f64,
    /// Held-out density callback, used only in cross-class (stacking) races.
    pub density_provider: HeldOutDensityProvider<'a>,
}
/// Picks a winner among `candidates`.
///
/// - If every candidate is discrete (mixture/union), or every candidate is
///   smooth, the winner is the first candidate with the lowest finite
///   negative log evidence. The fallback is index 0.
/// - If the race mixes discrete and smooth candidates, the evidences are not
///   compared directly. Instead a cross-validated log-density table is built
///   from each candidate's provider and stacking weights are solved. The
///   winner is the first candidate with the strictly largest weight. The
///   fallback is index 0.
///
/// # Errors
/// Returns a message for an empty candidate list, and propagates errors from
/// [`build_cv_log_density_table`] and `solve_stacking_weights`.
pub fn adjudicate_cross_class_race(
    n: usize,
    candidates: Vec<CrossClassCandidate<'_>>,
    folds: usize,
    stacking_config: StackingConfig,
) -> Result<CrossClassRaceVerdict, String> {
    if candidates.is_empty() {
        return Err("cross-class race requires at least one candidate".to_string());
    }
    let candidate_names: Vec<String> = candidates
        .iter()
        .map(|candidate| candidate.kind.display_name())
        .collect();
    let negative_log_evidence: Vec<f64> = candidates
        .iter()
        .map(|candidate| candidate.negative_log_evidence)
        .collect();
    let discrete_count = candidates
        .iter()
        .filter(|candidate| candidate.kind.is_discrete_class())
        .count();
    // Cross-class means at least one discrete and at least one smooth entrant.
    let is_cross_class = discrete_count > 0 && discrete_count < candidates.len();
    if !is_cross_class {
        let (winner_index, _) = negative_log_evidence
            .iter()
            .enumerate()
            .filter(|(_, nle)| nle.is_finite())
            .fold((0usize, f64::INFINITY), |(best_idx, best_nle), (idx, &nle)| {
                if nle < best_nle {
                    (idx, nle)
                } else {
                    (best_idx, best_nle)
                }
            });
        return Ok(CrossClassRaceVerdict {
            candidate_names,
            is_cross_class: false,
            negative_log_evidence,
            stacking: None,
            winner_index,
            headline: Headline::Evidence,
        });
    }
    let providers: Vec<HeldOutDensityProvider<'_>> = candidates
        .into_iter()
        .map(|candidate| candidate.density_provider)
        .collect();
    let table = build_cv_log_density_table(n, folds, &providers)?;
    let stacking = solve_stacking_weights(table.view(), stacking_config)?;
    let (winner_index, _) = stacking.weights.iter().enumerate().fold(
        (0usize, f64::NEG_INFINITY),
        |(best_idx, best_w), (idx, &w)| {
            if w > best_w {
                (idx, w)
            } else {
                (best_idx, best_w)
            }
        },
    );
    Ok(CrossClassRaceVerdict {
        candidate_names,
        is_cross_class: true,
        negative_log_evidence,
        stacking: Some(stacking),
        winner_index,
        headline: Headline::Stacking,
    })
}