use core::fmt;
use crate::automl::ModelFactory;
/// Optimization objective an online meta-learner can score a model family against.
///
/// `Ord` is derived so objectives can key a `BTreeMap` (see `MetaScore`).
/// `#[non_exhaustive]` reserves the right to add objectives without a
/// semver-breaking change for downstream crates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[non_exhaustive]
pub enum Objective {
    /// Root-mean-squared error (lower is better).
    RegressionRmse,
    /// Mean absolute error (lower is better).
    RegressionMae,
    /// Coefficient of determination R² (higher is better).
    RegressionR2,
    /// Fraction of correctly predicted movement directions (higher is better).
    DirectionalAccuracy,
    /// F1 score for classification (higher is better).
    ClassificationF1,
    /// Cohen's kappa for classification (higher is better).
    ClassificationKappa,
    /// Continuous ranked probability score for distributional forecasts (lower is better).
    DistributionalCrps,
}
impl Objective {
pub fn is_minimization(self) -> bool {
matches!(
self,
Objective::RegressionRmse | Objective::RegressionMae | Objective::DistributionalCrps
)
}
pub fn as_str(self) -> &'static str {
match self {
Objective::RegressionRmse => "regression_rmse",
Objective::RegressionMae => "regression_mae",
Objective::RegressionR2 => "regression_r2",
Objective::DirectionalAccuracy => "directional_accuracy",
Objective::ClassificationF1 => "classification_f1",
Objective::ClassificationKappa => "classification_kappa",
Objective::DistributionalCrps => "distributional_crps",
}
}
}
impl fmt::Display for Objective {
    /// Displays the stable snake_case identifier (delegates to [`Objective::as_str`]).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
/// Coarse model-size bucket used as an extra (minimized) axis in Pareto
/// comparisons; `Ord` is derived so `Tiny < Small < Medium < Large`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ComplexityClass {
    /// Hint ≤ 100 (see `from_hint`).
    Tiny,
    /// Hint ≤ 1_000.
    Small,
    /// Hint ≤ 10_000.
    Medium,
    /// Hint > 10_000.
    Large,
}
impl Default for ComplexityClass {
fn default() -> Self {
ComplexityClass::Tiny
}
}
impl ComplexityClass {
pub fn from_hint(hint: usize) -> Self {
if hint <= 100 {
ComplexityClass::Tiny
} else if hint <= 1_000 {
ComplexityClass::Small
} else if hint <= 10_000 {
ComplexityClass::Medium
} else {
ComplexityClass::Large
}
}
}
/// Contract an online meta-learner (hyperparameter adapter) declares about
/// itself: its contraction factor, objective surface, cost bucket, and which
/// kinds of knobs it touches.
pub trait MetaLearner: Send + Sync {
    /// Lipschitz constant L of the adaptation map. L < 1 declares a strict
    /// contraction; L = 1.0 declares a non-expansive identity (no adaptation).
    fn lipschitz_bound(&self) -> f64;
    /// Objectives this learner optimizes; empty slice means no objective surface.
    fn objectives(&self) -> &[Objective];
    /// Cost bucket of the model family this learner tunes.
    fn complexity_class(&self) -> ComplexityClass;
    /// Whether continuous hyperparameters (e.g. learning rates) are tuned.
    fn tunes_continuous_knobs(&self) -> bool;
    /// Whether discrete/structural changes (e.g. architecture edits) are made.
    fn tunes_structure(&self) -> bool;
    /// True when this learner performs no adaptation at all.
    ///
    /// The exact `== 1.0` float comparison is deliberate: no-op implementations
    /// return the literal 1.0, so this is an identity check, not a tolerance test.
    fn is_no_op(&self) -> bool {
        self.lipschitz_bound() == 1.0 && !self.tunes_continuous_knobs() && !self.tunes_structure()
    }
    /// Optional human-readable justification for the learner's behavior
    /// (implementations that opt out are expected to override this).
    fn rationale(&self) -> Option<&str> {
        None
    }
}
/// Explicit opt-out [`MetaLearner`]: identity map (L = 1.0), empty objective
/// surface, tunes nothing, and always surfaces a rationale for the opt-out.
#[derive(Debug, Clone)]
pub struct NoOpMetaLearner {
    /// Human-readable justification, surfaced via `rationale()`.
    reason: &'static str,
    /// Cost bucket reported via `complexity_class()`.
    complexity: ComplexityClass,
}
impl NoOpMetaLearner {
    /// Opt-out with the conventional `Small` complexity bucket.
    pub fn with_reason(reason: &'static str) -> Self {
        Self::new(reason, ComplexityClass::Small)
    }

    /// Opt-out with an explicit complexity bucket.
    pub fn new(reason: &'static str, complexity: ComplexityClass) -> Self {
        Self { reason, complexity }
    }

    /// The human-readable justification for opting out of meta-learning.
    pub fn reason(&self) -> &'static str {
        self.reason
    }
}
impl Default for NoOpMetaLearner {
    /// The opt-out used when a factory never declares a MetaLearner.
    fn default() -> Self {
        Self::with_reason("default factory opt-out: no MetaLearner declared")
    }
}
impl MetaLearner for NoOpMetaLearner {
    /// Identity map: exactly 1.0 so `is_no_op()` recognizes this as a no-op.
    fn lipschitz_bound(&self) -> f64 {
        1.0
    }
    /// No objective surface at all.
    fn objectives(&self) -> &[Objective] {
        &[]
    }
    fn complexity_class(&self) -> ComplexityClass {
        self.complexity
    }
    /// Never tunes continuous hyperparameters.
    fn tunes_continuous_knobs(&self) -> bool {
        false
    }
    /// Never makes structural changes.
    fn tunes_structure(&self) -> bool {
        false
    }
    /// Always present: the opt-out must be justified.
    fn rationale(&self) -> Option<&str> {
        Some(self.reason)
    }
}
/// SPSA-based meta-learner for the SGBT regression family (see `rationale()`
/// on the trait impl for the tuning scheme).
#[derive(Debug, Clone, Copy)]
pub struct SgbtMetaLearner {
    /// Cost bucket reported via `complexity_class()`.
    complexity: ComplexityClass,
}
impl SgbtMetaLearner {
pub fn new(complexity: ComplexityClass) -> Self {
Self { complexity }
}
}
impl MetaLearner for SgbtMetaLearner {
    /// L = max(ρ, 1 − ρ) with ρ = 0.3, i.e. 0.7: a strict contraction.
    fn lipschitz_bound(&self) -> f64 {
        0.7
    }
    /// Regression objective surface: RMSE, MAE, R², directional accuracy.
    fn objectives(&self) -> &[Objective] {
        // Slice literal of const-evaluable values is promoted to 'static.
        &[
            Objective::RegressionRmse,
            Objective::RegressionMae,
            Objective::RegressionR2,
            Objective::DirectionalAccuracy,
        ]
    }
    fn complexity_class(&self) -> ComplexityClass {
        self.complexity
    }
    /// Tunes continuous knobs (learning rate, lambda) via SPSA.
    fn tunes_continuous_knobs(&self) -> bool {
        true
    }
    /// Makes structural changes at tree-replacement boundaries.
    fn tunes_structure(&self) -> bool {
        true
    }
    fn rationale(&self) -> Option<&str> {
        Some("SGBT regression family: SPSA on (lr, lambda) with ρ=0.3 convex-blended step (L=0.7); structural changes at tree-replacement boundary")
    }
}
/// SPSA-based meta-learner for the SGBT classification family (softmax
/// committee); mirrors [`SgbtMetaLearner`] on a classification objective surface.
#[derive(Debug, Clone, Copy)]
pub struct SgbtClassificationMetaLearner {
    /// Cost bucket reported via `complexity_class()`.
    complexity: ComplexityClass,
}
impl SgbtClassificationMetaLearner {
pub fn new(complexity: ComplexityClass) -> Self {
Self { complexity }
}
}
impl MetaLearner for SgbtClassificationMetaLearner {
    /// Same SPSA bound as the regression family: L = max(ρ, 1 − ρ) = 0.7.
    fn lipschitz_bound(&self) -> f64 {
        0.7
    }
    /// Classification objective surface: F1 and Cohen's kappa.
    fn objectives(&self) -> &[Objective] {
        // Slice literal of const-evaluable values is promoted to 'static.
        &[Objective::ClassificationF1, Objective::ClassificationKappa]
    }
    fn complexity_class(&self) -> ComplexityClass {
        self.complexity
    }
    /// Tunes continuous knobs (learning rate, lambda) via SPSA.
    fn tunes_continuous_knobs(&self) -> bool {
        true
    }
    /// Makes structural changes (same boundary scheme as the regression family).
    fn tunes_structure(&self) -> bool {
        true
    }
    fn rationale(&self) -> Option<&str> {
        Some("Classification SGBT family: SPSA on (lr, lambda) with ρ=0.3 convex-blended step (L=0.7); softmax committee, classification objective surface")
    }
}
/// Namespace for Pareto-comparison utilities over [`MetaScore`]s.
pub struct MetaSearch;
/// Per-candidate score card: measured value per objective plus the candidate's
/// complexity bucket (an extra minimized axis in Pareto comparisons).
#[derive(Debug, Clone, Default)]
pub struct MetaScore {
    /// Measured value per objective; BTreeMap keeps iteration order deterministic.
    pub values: std::collections::BTreeMap<Objective, f64>,
    /// Cost bucket of the candidate (defaults to `Tiny` via `Default`).
    pub complexity: ComplexityClass,
}
impl MetaScore {
    /// Empty score card for a candidate in the given complexity bucket.
    pub fn new(complexity: ComplexityClass) -> Self {
        Self {
            values: Default::default(),
            complexity,
        }
    }

    /// Records (or overwrites) the measured value for one objective.
    pub fn record(&mut self, obj: Objective, value: f64) {
        self.values.insert(obj, value);
    }

    /// Looks up the measured value for one objective, if it was recorded.
    pub fn get(&self, obj: Objective) -> Option<f64> {
        self.values.get(&obj).copied()
    }
}
impl MetaSearch {
    /// Strict Pareto dominance of `a` over `b`.
    ///
    /// Metric axes are compared only on objectives *shared* by both score
    /// cards (objectives measured by exactly one side are ignored, and with
    /// zero shared objectives the candidates are incomparable). Minimization
    /// objectives are sign-flipped so "greater is better" holds uniformly.
    /// Complexity is an additional always-present axis where *lower* is
    /// better. `a` dominates iff it is no worse on every compared axis and
    /// strictly better on at least one, with at least one shared objective.
    pub fn pareto_dominates(a: &MetaScore, b: &MetaScore) -> bool {
        let mut shared = 0usize;
        let mut a_strictly_better_on_some = false;
        let mut a_no_worse_on_all = true;
        for (&obj, &av) in &a.values {
            if let Some(&bv) = b.values.get(&obj) {
                shared += 1;
                // Flip sign for error-style objectives so larger is always better.
                let (a_eff, b_eff) = if obj.is_minimization() {
                    (-av, -bv)
                } else {
                    (av, bv)
                };
                match a_eff.partial_cmp(&b_eff) {
                    Some(core::cmp::Ordering::Greater) => a_strictly_better_on_some = true,
                    Some(core::cmp::Ordering::Less) => {
                        a_no_worse_on_all = false;
                    }
                    Some(core::cmp::Ordering::Equal) => {}
                    // NaN on a shared axis: treat as incomparable, which
                    // conservatively blocks dominance.
                    None => {
                        a_no_worse_on_all = false;
                    }
                }
            }
        }
        // Complexity axis: lower bucket strictly wins ties on the metric axes.
        match a.complexity.cmp(&b.complexity) {
            core::cmp::Ordering::Less => a_strictly_better_on_some = true,
            core::cmp::Ordering::Greater => {
                a_no_worse_on_all = false;
            }
            core::cmp::Ordering::Equal => {}
        }
        shared > 0 && a_no_worse_on_all && a_strictly_better_on_some
    }

    /// Indices of the non-dominated candidates in `scores` (O(n²) pairwise scan;
    /// order of the returned indices follows the input order).
    pub fn pareto_front(scores: &[MetaScore]) -> Vec<usize> {
        let mut front = Vec::with_capacity(scores.len());
        for (i, si) in scores.iter().enumerate() {
            let dominated = scores.iter().enumerate().any(|(j, sj)| {
                if i == j {
                    return false;
                }
                MetaSearch::pareto_dominates(sj, si)
            });
            if !dominated {
                front.push(i);
            }
        }
        front
    }
}
/// Extension point letting a [`ModelFactory`] declare its meta-learner.
///
/// The default is an explicit opt-out (no-op identity), sized from the
/// factory's `complexity_hint()` — presumably a raw size/parameter-count hint;
/// NOTE(review): confirm against the `ModelFactory` definition in `crate::automl`.
pub trait FactoryMetaLearner: ModelFactory {
    /// Boxed for object safety; families that adapt online override this.
    fn meta_learner(&self) -> Box<dyn MetaLearner> {
        Box::new(NoOpMetaLearner::new(
            "default ModelFactory::meta_learner: family did not override; \
online adaptation not declared",
            ComplexityClass::from_hint(self.complexity_hint()),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::automl::Factory;

    // Smoke test: the default (no-op) meta_learner() hook must be callable for
    // every factory constructor.
    #[test]
    fn meta_learner_no_op_default_compiles_for_all_models() {
        let _: Box<dyn MetaLearner> = Factory::sgbt(5).meta_learner();
        let _: Box<dyn MetaLearner> = Factory::distributional(5).meta_learner();
        let _: Box<dyn MetaLearner> = Factory::multiclass_sgbt(5, 3).meta_learner();
        let _: Box<dyn MetaLearner> = Factory::esn().meta_learner();
        let _: Box<dyn MetaLearner> = Factory::mamba(4).meta_learner();
        let _: Box<dyn MetaLearner> = Factory::mamba3(4).meta_learner();
        let _: Box<dyn MetaLearner> = Factory::mamba_bd(4).meta_learner();
        let _: Box<dyn MetaLearner> = Factory::slstm(4).meta_learner();
        let _: Box<dyn MetaLearner> = Factory::mgrade(4).meta_learner();
        let _: Box<dyn MetaLearner> = Factory::attention(8).meta_learner();
        let _: Box<dyn MetaLearner> = Factory::delta_product(8).meta_learner();
        let _: Box<dyn MetaLearner> = Factory::rwkv7(8).meta_learner();
        let _: Box<dyn MetaLearner> = Factory::spike_net().meta_learner();
        let _: Box<dyn MetaLearner> = Factory::kan(4).meta_learner();
        let _: Box<dyn MetaLearner> = Factory::ttt(4).meta_learner();
    }

    // The default opt-out must be a true identity: L = 1.0, no objectives,
    // no tuning, and a surfaced rationale.
    #[test]
    fn no_op_default_factory_has_identity_lipschitz() {
        let factory = Factory::esn();
        let m = factory.meta_learner();
        assert_eq!(
            m.lipschitz_bound(),
            1.0,
            "default no-op factory must declare L=1.0 (non-expansive identity), got {}",
            m.lipschitz_bound()
        );
        assert!(
            m.is_no_op(),
            "default no-op factory must report is_no_op() = true"
        );
        assert!(
            m.objectives().is_empty(),
            "no-op MetaLearner must declare no objective surface"
        );
        assert!(
            !m.tunes_continuous_knobs(),
            "no-op MetaLearner must NOT declare continuous-knob tuning"
        );
        assert!(
            !m.tunes_structure(),
            "no-op MetaLearner must NOT declare structural tuning"
        );
        assert!(
            m.rationale().is_some(),
            "no-op MetaLearner must surface a rationale for the opt-out"
        );
    }

    // The declared 0.7 bound must equal max(rho, 1 - rho) for rho = 0.3.
    #[test]
    fn meta_learner_lipschitz_bound_derives_from_implementation() {
        let m = SgbtMetaLearner::new(ComplexityClass::Medium);
        let rho: f64 = 0.3;
        let expected = rho.max(1.0_f64 - rho);
        assert!(
            (m.lipschitz_bound() - expected).abs() < 1e-12,
            "SgbtMetaLearner Lipschitz must derive from rho-blending: \
expected max(rho, 1-rho) = {}, got {}",
            expected,
            m.lipschitz_bound()
        );
        assert!(
            m.lipschitz_bound() < 1.0,
            "SgbtMetaLearner must be a strict contraction (L < 1), got {}",
            m.lipschitz_bound()
        );
        let mc = SgbtClassificationMetaLearner::new(ComplexityClass::Medium);
        assert!(
            (mc.lipschitz_bound() - expected).abs() < 1e-12,
            "SgbtClassificationMetaLearner must use the same SPSA Lipschitz bound"
        );
    }

    // Composing the meta-learner with an outer drift factor must stay < 1.
    #[test]
    fn lipschitz_product_satisfies_banach_contraction_invariant() {
        let m = SgbtMetaLearner::new(ComplexityClass::Medium);
        let l_drift = 0.95_f64;
        let product = m.lipschitz_bound() * l_drift;
        assert!(
            product < 1.0,
            "Banach contraction invariant requires ∏ L_i < 1; got {} = {} · {}",
            product,
            m.lipschitz_bound(),
            l_drift
        );
    }

    // Trade-off candidates (each wins one axis) must be mutually non-dominating
    // and both survive onto the Pareto front.
    #[test]
    fn meta_learner_across_family_comparison_is_pareto_not_scalar() {
        let mut a = MetaScore::new(ComplexityClass::Small);
        a.record(Objective::RegressionRmse, 0.10);
        a.record(Objective::RegressionR2, 0.50);
        let mut b = MetaScore::new(ComplexityClass::Small);
        b.record(Objective::RegressionRmse, 0.15);
        b.record(Objective::RegressionR2, 0.80);
        assert!(
            !MetaSearch::pareto_dominates(&a, &b),
            "a should NOT Pareto-dominate b (a wins RMSE, loses R²)"
        );
        assert!(
            !MetaSearch::pareto_dominates(&b, &a),
            "b should NOT Pareto-dominate a (b wins R², loses RMSE)"
        );
        let scores = [a, b];
        let front = MetaSearch::pareto_front(&scores);
        assert_eq!(
            front.len(),
            2,
            "trade-off candidates must both appear on Pareto front, got {:?}",
            front
        );
    }

    // A candidate better on every shared axis strictly dominates and leaves
    // the dominated one off the front.
    #[test]
    fn pareto_strict_dominance_excludes_dominated() {
        let mut a = MetaScore::new(ComplexityClass::Small);
        a.record(Objective::RegressionRmse, 0.10);
        a.record(Objective::RegressionR2, 0.80);
        let mut b = MetaScore::new(ComplexityClass::Small);
        b.record(Objective::RegressionRmse, 0.15);
        b.record(Objective::RegressionR2, 0.50);
        assert!(
            MetaSearch::pareto_dominates(&a, &b),
            "a beats b on every axis: a should strictly Pareto-dominate b"
        );
        assert!(
            !MetaSearch::pareto_dominates(&b, &a),
            "b cannot dominate a (worse on every axis)"
        );
        let scores = [a, b];
        let front = MetaSearch::pareto_front(&scores);
        assert_eq!(
            front,
            vec![0],
            "Pareto front must contain only the dominating candidate, got {:?}",
            front
        );
    }

    // Complexity is a real Pareto axis: equal metrics + lower bucket dominates.
    #[test]
    fn pareto_complexity_axis_breaks_equal_metric_ties() {
        let mut tiny = MetaScore::new(ComplexityClass::Tiny);
        tiny.record(Objective::RegressionRmse, 0.10);
        let mut large = MetaScore::new(ComplexityClass::Large);
        large.record(Objective::RegressionRmse, 0.10);
        assert!(
            MetaSearch::pareto_dominates(&tiny, &large),
            "equal metric + lower complexity must strictly dominate"
        );
        assert!(
            !MetaSearch::pareto_dominates(&large, &tiny),
            "equal metric + higher complexity cannot dominate"
        );
    }

    // NaN on a shared axis makes the pair incomparable (partial_cmp → None).
    #[test]
    fn pareto_handles_nan_as_incomparable() {
        let mut a = MetaScore::new(ComplexityClass::Small);
        a.record(Objective::RegressionRmse, f64::NAN);
        a.record(Objective::RegressionR2, 0.80);
        let mut b = MetaScore::new(ComplexityClass::Small);
        b.record(Objective::RegressionRmse, 0.15);
        b.record(Objective::RegressionR2, 0.50);
        assert!(
            !MetaSearch::pareto_dominates(&a, &b),
            "NaN on a shared axis must prevent dominance"
        );
    }

    // Zero shared objectives → never dominates, in either direction.
    // (Fixed: `®ression` mojibake — HTML-entity corruption of `&regression` —
    // made this test invalid Rust.)
    #[test]
    fn pareto_disjoint_objectives_are_incomparable() {
        let mut regression = MetaScore::new(ComplexityClass::Small);
        regression.record(Objective::RegressionRmse, 0.10);
        let mut classification = MetaScore::new(ComplexityClass::Small);
        classification.record(Objective::ClassificationF1, 0.80);
        assert!(
            !MetaSearch::pareto_dominates(&regression, &classification),
            "disjoint objective surfaces must not dominate"
        );
        assert!(
            !MetaSearch::pareto_dominates(&classification, &regression),
            "disjoint objective surfaces must not dominate (reverse direction)"
        );
    }

    // Composition with a within-model tuner must remain a strict contraction,
    // and the no-op must compose as an exact identity factor.
    #[test]
    fn meta_learner_compose_with_within_model_auto_tuner() {
        let factory_meta = SgbtMetaLearner::new(ComplexityClass::Medium);
        let within_model_spsa_lipschitz = 0.7_f64;
        let combined = factory_meta.lipschitz_bound() * within_model_spsa_lipschitz;
        assert!(
            combined < 1.0,
            "MetaLearner ∘ within-model SPSA must be a strict contraction; \
got {} = {} · {}",
            combined,
            factory_meta.lipschitz_bound(),
            within_model_spsa_lipschitz
        );
        let margin = 1.0 - combined;
        assert!(
            margin > 0.0,
            "Composition must leave a non-zero margin for additional adapters; \
got margin = {}",
            margin
        );
        let no_op = NoOpMetaLearner::default();
        let no_op_composed = no_op.lipschitz_bound() * within_model_spsa_lipschitz;
        assert!(
            (no_op_composed - within_model_spsa_lipschitz).abs() < 1e-12,
            "no-op MetaLearner must contribute a factor of 1 (identity); \
got composed = {}, expected {}",
            no_op_composed,
            within_model_spsa_lipschitz
        );
    }

    // Bucket boundaries are inclusive upper bounds: 100 / 1_000 / 10_000.
    #[test]
    fn complexity_class_buckets_match_documented_thresholds() {
        assert_eq!(ComplexityClass::from_hint(0), ComplexityClass::Tiny);
        assert_eq!(ComplexityClass::from_hint(100), ComplexityClass::Tiny);
        assert_eq!(ComplexityClass::from_hint(101), ComplexityClass::Small);
        assert_eq!(ComplexityClass::from_hint(1_000), ComplexityClass::Small);
        assert_eq!(ComplexityClass::from_hint(1_001), ComplexityClass::Medium);
        assert_eq!(ComplexityClass::from_hint(10_000), ComplexityClass::Medium);
        assert_eq!(ComplexityClass::from_hint(10_001), ComplexityClass::Large);
        assert_eq!(
            ComplexityClass::from_hint(usize::MAX),
            ComplexityClass::Large
        );
    }

    // Error-style objectives minimize; score-style objectives maximize.
    #[test]
    fn objective_sign_convention() {
        assert!(Objective::RegressionRmse.is_minimization());
        assert!(Objective::RegressionMae.is_minimization());
        assert!(Objective::DistributionalCrps.is_minimization());
        assert!(!Objective::RegressionR2.is_minimization());
        assert!(!Objective::DirectionalAccuracy.is_minimization());
        assert!(!Objective::ClassificationF1.is_minimization());
        assert!(!Objective::ClassificationKappa.is_minimization());
    }

    // MetaLearner must stay object-safe: all concrete impls box into dyn.
    #[test]
    fn meta_learner_is_trait_object_safe() {
        let _: Box<dyn MetaLearner> = Box::new(NoOpMetaLearner::default());
        let _: Box<dyn MetaLearner> = Box::new(SgbtMetaLearner::new(ComplexityClass::Medium));
        let _: Box<dyn MetaLearner> =
            Box::new(SgbtClassificationMetaLearner::new(ComplexityClass::Medium));
    }
}