use crate::automl::auto_builder::ConfigDiagnostics;
use thiserror::Error;
use tracing::warn;
/// Warmup length, in samples, armed on the bus when the outermost critical section
/// exits, and on the built-in adapters by their `flush_state`.
const REARM_DELAY: u64 = 5;
/// A critical section that has been open for strictly more than this many samples
/// is force-exited by `AdaptationBus::step`.
const CRITICAL_TIMEOUT_SAMPLES: u64 = 1000;
/// Number of slots in the bus's ring buffer of recent `theta` values.
const THETA_RING: usize = 3;
/// Per-sample input handed to every adapter by `AdaptationBus::step`.
///
/// All adapters that fire within one `step` call receive the same context.
#[derive(Debug, Clone)]
pub struct AdaptContext {
    /// Parameter vector the adapters compute their deltas against.
    /// NOTE(review): the bus applies deltas to its own internal copy, so this is
    /// presumably expected to equal `AdaptationBus::theta()` — confirm with callers.
    pub theta: [f64; 2],
    /// Model diagnostics; `PlasticityAdapter` reads its `uncertainty` field.
    pub diagnostics: ConfigDiagnostics,
    /// Latest prediction; `DriftRateAdapter` uses `|prediction - target|` as its error.
    pub prediction: f64,
    /// Target value paired with `prediction`.
    pub target: f64,
    /// Index of the current sample; drives adapter scheduling (`update_period`)
    /// and the critical-section timeout.
    pub sample_idx: u64,
}
/// Additive change to `theta` proposed by an adapter.
///
/// The default value is a zero delta. The bus may clamp the proposal, or discard
/// it entirely during warmup, before adding it to its `theta`.
#[derive(Debug, Clone, Default)]
pub struct ThetaDelta {
    /// Per-component change, in the same order as `theta`.
    pub delta: [f64; 2],
}
/// A component that proposes small changes to the two-element parameter vector
/// `theta` managed by `AdaptationBus`.
///
/// Implementors must be `Send`. The bus uses the declared Lipschitz bound both to
/// decide whether to admit the adapter and to size each applied step.
pub trait MetaAdapter: Send {
    /// Declared Lipschitz bound `L` of the adapter's update.
    ///
    /// The bus rejects values above `1.0`, multiplies accepted values into a running
    /// product, and treats `1 - L` as the adapter's effective step size.
    fn lipschitz_bound(&self) -> f64;
    /// Number of samples between applications.
    ///
    /// The bus calls `apply` when `sample_idx % update_period() == 0`. `u64::MAX`
    /// means "never". The value must be non-zero.
    fn update_period(&self) -> u64;
    /// Computes the proposed delta for the current sample.
    ///
    /// The bus may clamp the result, or discard it while the bus or this adapter
    /// is warming up. The call still happens in that case, so state can be updated.
    fn apply(&mut self, ctx: &AdaptContext) -> ThetaDelta;
    /// Scalar progress signal.
    ///
    /// Not read by `AdaptationBus` itself. The built-in adapters return the negated
    /// value of their internal moving average.
    fn progress(&self) -> f64;
    /// Resets internal state.
    ///
    /// The bus calls this on every live adapter when the outermost critical section
    /// is entered. Default: no-op.
    fn flush_state(&mut self) {}
    /// Whether the adapter is still warming up.
    ///
    /// While `true`, the bus calls `apply` but discards the delta. Default: never
    /// warming up.
    fn is_in_warmup(&self) -> bool {
        false
    }
    /// Counts down `_n_elapsed` samples of warmup.
    ///
    /// The bus calls this with `1` on each step taken while the bus itself is in
    /// warmup. Default: no-op.
    fn advance_warmup(&mut self, _n_elapsed: u64) {}
}
/// Adapter that never fires (period `u64::MAX`) and always proposes a zero delta.
///
/// It declares a Lipschitz bound of `1.0`, so registering it leaves the bus's
/// Lipschitz product unchanged.
#[derive(Debug, Clone, Default)]
pub struct NoOpAdapter;
impl MetaAdapter for NoOpAdapter {
fn lipschitz_bound(&self) -> f64 {
1.0
}
fn update_period(&self) -> u64 {
u64::MAX
}
fn apply(&mut self, _ctx: &AdaptContext) -> ThetaDelta {
ThetaDelta::default()
}
fn progress(&self) -> f64 {
0.0
}
}
/// Adapter that nudges `theta[0]` by a fixed step according to the trend of a
/// moving average of the absolute prediction error.
#[derive(Debug, Clone)]
pub struct DriftRateAdapter {
    /// Step magnitude applied to `theta[0]`, in `[0, 0.2]`.
    delta_max: f64,
    /// EWMA of `|prediction - target|`. NaN until the first observation is seen.
    ewma_error: f64,
    /// EWMA smoothing factor.
    alpha: f64,
    /// Warmup samples left after `flush_state`.
    warmup_remaining: u64,
    /// When set, the next `apply` re-seeds `ewma_error` instead of stepping.
    needs_reprime: bool,
}

impl DriftRateAdapter {
    /// Creates an adapter whose step magnitude is `|delta_max|` clamped to `[0, 0.2]`.
    ///
    /// A NaN `delta_max` is mapped to a zero step. `f64::clamp` propagates NaN, and a
    /// NaN step would corrupt `theta[0]` even when the adapter wants no change,
    /// because `0.0 * NaN` is NaN.
    pub fn new(delta_max: f64) -> Self {
        let delta_max = if delta_max.is_nan() {
            0.0
        } else {
            delta_max.abs().clamp(0.0, 0.2)
        };
        Self {
            delta_max,
            ewma_error: f64::NAN,
            alpha: 0.1,
            warmup_remaining: 0,
            needs_reprime: false,
        }
    }
}
impl MetaAdapter for DriftRateAdapter {
    /// Declares a bound of exactly `1.0`.
    ///
    /// The bus therefore computes a zero step budget for this adapter, and
    /// `clamp_delta` applies its delta unclamped.
    fn lipschitz_bound(&self) -> f64 {
        1.0
    }

    /// Runs on every sample.
    fn update_period(&self) -> u64 {
        1
    }

    /// Nudges `theta[0]` according to the trend of the EWMA of `|prediction - target|`.
    ///
    /// * If the EWMA rose by more than 2 %, the step is `+delta_max`.
    /// * If it fell by more than 2 %, the step is `-delta_max`.
    /// * Otherwise (the dead band) there is no change.
    ///
    /// The returned delta is adjusted so that `ctx.theta[0] + delta` stays inside
    /// `[0, 1]`. `theta[1]` is never touched. A call that (re-)seeds the EWMA
    /// returns a zero delta.
    fn apply(&mut self, ctx: &AdaptContext) -> ThetaDelta {
        let err = (ctx.prediction - ctx.target).abs();
        // Re-seed in three cases: after a flush, before the first observation (NaN),
        // or after the average overflowed to infinity. `(1 - alpha) * inf == inf`, so
        // an infinite EWMA would never recover, and the dead-band comparisons would
        // keep the adapter frozen forever.
        if self.needs_reprime || !self.ewma_error.is_finite() {
            self.ewma_error = err;
            self.needs_reprime = false;
            return ThetaDelta::default();
        }
        let prev_ewma = self.ewma_error;
        self.ewma_error = self.alpha * err + (1.0 - self.alpha) * self.ewma_error;
        // A ±2 % dead band avoids reacting to noise in the smoothed error.
        let direction = if self.ewma_error > prev_ewma * 1.02 {
            1.0
        } else if self.ewma_error < prev_ewma * 0.98 {
            -1.0
        } else {
            0.0
        };
        let raw_delta = direction * self.delta_max;
        let new_theta_0 = (ctx.theta[0] + raw_delta).clamp(0.0, 1.0);
        let actual_delta = new_theta_0 - ctx.theta[0];
        ThetaDelta {
            delta: [actual_delta, 0.0],
        }
    }

    /// Negated error EWMA, so that a higher value means a lower error.
    ///
    /// It is NaN until the first observation after construction or a flush.
    fn progress(&self) -> f64 {
        -self.ewma_error
    }

    /// Drops the error history, forces a re-seed on the next `apply`, and arms a
    /// `REARM_DELAY`-sample warmup.
    fn flush_state(&mut self) {
        self.ewma_error = f64::NAN;
        self.needs_reprime = true;
        self.warmup_remaining = REARM_DELAY;
    }

    /// True while warmup samples remain.
    fn is_in_warmup(&self) -> bool {
        self.warmup_remaining > 0
    }

    /// Counts down `n` warmup samples, saturating at zero.
    fn advance_warmup(&mut self, n: u64) {
        self.warmup_remaining = self.warmup_remaining.saturating_sub(n);
    }
}
/// Adapter that moves `theta[1]` up or down according to the model's reported
/// uncertainty.
#[derive(Debug, Clone)]
pub struct PlasticityAdapter {
    /// Step magnitude applied to `theta[1]`, in `[0, 0.1]`.
    delta_max: f64,
    /// EWMA of the clamped uncertainty. It starts at, and is flushed back to, `0.5`.
    utility_ewma: f64,
    /// EWMA smoothing factor.
    alpha: f64,
    /// Warmup samples left after `flush_state`.
    warmup_remaining: u64,
    /// Set by `flush_state` and cleared by `apply`.
    needs_reprime: bool,
}

impl PlasticityAdapter {
    /// Creates an adapter whose step magnitude is `|delta_max|` clamped to `[0, 0.1]`.
    ///
    /// A NaN `delta_max` is mapped to a zero step. `f64::clamp` propagates NaN, and
    /// this adapter always steps (by `+delta_max` or `-delta_max / 2`), so a NaN step
    /// would immediately corrupt `theta[1]`.
    pub fn new(delta_max: f64) -> Self {
        let delta_max = if delta_max.is_nan() {
            0.0
        } else {
            delta_max.abs().clamp(0.0, 0.1)
        };
        Self {
            delta_max,
            utility_ewma: 0.5,
            alpha: 0.05,
            warmup_remaining: 0,
            needs_reprime: false,
        }
    }
}
impl MetaAdapter for PlasticityAdapter {
    /// Declares a bound of exactly `1.0`.
    ///
    /// The bus therefore computes a zero step budget for this adapter, and
    /// `clamp_delta` applies its delta unclamped.
    fn lipschitz_bound(&self) -> f64 {
        1.0
    }

    /// Runs every 10th sample.
    fn update_period(&self) -> u64 {
        10
    }

    /// Moves `theta[1]` based on `ctx.diagnostics.uncertainty`, clamped to `[0, 1]`.
    ///
    /// * When the current (unsmoothed) value exceeds `0.6`, the step is `+delta_max`.
    /// * Otherwise the step is `-delta_max / 2`. A NaN value also takes this branch.
    ///
    /// The result keeps `ctx.theta[1] + delta` inside `[0, 1]`. `theta[0]` is never
    /// touched. The value is also folded into `utility_ewma`, which feeds only
    /// `progress()`.
    fn apply(&mut self, ctx: &AdaptContext) -> ThetaDelta {
        let utility = ctx.diagnostics.uncertainty.clamp(0.0, 1.0);
        // `clamp` propagates NaN. Folding a NaN into the EWMA would poison
        // `utility_ewma` (and `progress()`) until the next flush, so skip that update.
        if !utility.is_nan() {
            self.utility_ewma = self.alpha * utility + (1.0 - self.alpha) * self.utility_ewma;
        }
        self.needs_reprime = false;
        let direction = if utility > 0.6 { 1.0 } else { -0.5 };
        let raw_delta = direction * self.delta_max;
        let new_theta_1 = (ctx.theta[1] + raw_delta).clamp(0.0, 1.0);
        let actual_delta = new_theta_1 - ctx.theta[1];
        ThetaDelta {
            delta: [0.0, actual_delta],
        }
    }

    /// Negated uncertainty EWMA.
    fn progress(&self) -> f64 {
        -self.utility_ewma
    }

    /// Resets the EWMA to its initial `0.5`, flags a re-prime, and arms a
    /// `REARM_DELAY`-sample warmup.
    fn flush_state(&mut self) {
        self.utility_ewma = 0.5;
        self.needs_reprime = true;
        self.warmup_remaining = REARM_DELAY;
    }

    /// True while warmup samples remain.
    fn is_in_warmup(&self) -> bool {
        self.warmup_remaining > 0
    }

    /// Counts down `n` warmup samples, saturating at zero.
    fn advance_warmup(&mut self, n: u64) {
        self.warmup_remaining = self.warmup_remaining.saturating_sub(n);
    }
}
/// Reasons `AdaptationBus::register` can refuse an adapter.
#[derive(Debug, Error)]
pub enum BusError {
    /// The adapter's declared Lipschitz bound exceeds `1.0` (beyond `f64::EPSILON`).
    #[error("adapter declared Lipschitz bound {declared:.4} > 1.0 — expansive adapters break compose-safety")]
    LipschitzExpansive {
        /// The bound returned by `MetaAdapter::lipschitz_bound`.
        declared: f64,
    },
    /// Multiplying the adapter's bound into the bus's running product would push
    /// the product above `1.0 + f64::EPSILON`.
    ///
    /// NOTE: despite the variant name and the "≥" in the message, `register` only
    /// rejects a product strictly greater than `1.0 + f64::EPSILON`.
    #[error("adding adapter (L={new:.4}) pushes product to {product:.4} ≥ 1.0 — contraction guarantee lost")]
    ProductReachesOne {
        /// The rejected adapter's declared bound.
        new: f64,
        /// The product that would have resulted.
        product: f64,
    },
    /// A faster-updating adapter's effective step (`1 - L`) is larger than the
    /// timescale rule allows relative to a slower adapter.
    #[error("timescale violation: adapter period {new_period} is faster than existing period {slow_period} but has larger effective step ({new_step:.4e} > {slow_step:.4e})")]
    TimescaleViolation {
        /// Update period of the faster adapter.
        new_period: u64,
        /// Update period of the slower adapter.
        slow_period: u64,
        /// Effective step `1 - L` of the faster adapter.
        new_step: f64,
        /// Effective step `1 - L` of the slower adapter.
        slow_step: f64,
    },
}
/// Nesting state of the bus's critical section.
#[derive(Debug, Default)]
struct CriticalSection {
    /// Current nesting depth; `0` means no critical section is open.
    depth: u32,
    /// Sample index at which the outermost section was entered; `None` while closed.
    entered_at: Option<u64>,
}
/// Coordinates a set of `MetaAdapter`s that jointly adjust a two-element parameter
/// vector `theta`, whose components are clamped into `[0, 1]`.
///
/// The bus:
///
/// * admits adapters based on their declared Lipschitz bounds and update periods;
/// * schedules adapters by `update_period`;
/// * freezes `theta` while a critical section is open;
/// * warns when `theta` keeps taking large steps.
pub struct AdaptationBus {
    /// Live adapters, applied in registration order by `step`.
    adapters: Vec<Box<dyn MetaAdapter>>,
    /// Adapters registered while a critical section was open. They are re-registered
    /// when the outermost section exits.
    pending_queue: Vec<Box<dyn MetaAdapter>>,
    /// Current parameter vector.
    theta: [f64; 2],
    /// Recent `theta` values recorded by the oscillation detector.
    theta_ring: [[f64; 2]; THETA_RING],
    /// Index of the most recently written `theta_ring` slot.
    ring_head: usize,
    /// Product of the declared Lipschitz bounds of the live adapters.
    l_product: f64,
    /// Number of consecutive steps whose `theta` change exceeded `oscillation_margin`.
    oscillation_count: u32,
    /// Euclidean step size above which a step counts towards `oscillation_count`.
    oscillation_margin: f64,
    /// Value of `oscillation_count` at which the warning and back-off fire.
    oscillation_threshold: u32,
    /// Critical-section nesting state.
    critical: CriticalSection,
    /// Remaining samples of the bus's post-critical-section warmup.
    warmup_remaining: u64,
}
impl AdaptationBus {
    /// Creates an empty bus whose `theta` is `theta_init` clamped into `[0, 1]`
    /// component-wise.
    ///
    /// Initial state:
    ///
    /// * the oscillation ring is pre-filled with that initial `theta`;
    /// * the Lipschitz product starts at `1.0` (the empty product);
    /// * the oscillation margin is `0.05`;
    /// * the oscillation threshold is `5` consecutive large steps.
    pub fn new(theta_init: [f64; 2]) -> Self {
        let theta = [theta_init[0].clamp(0.0, 1.0), theta_init[1].clamp(0.0, 1.0)];
        Self {
            adapters: Vec::new(),
            pending_queue: Vec::new(),
            theta,
            theta_ring: [[theta[0], theta[1]]; THETA_RING],
            ring_head: 0,
            l_product: 1.0,
            oscillation_count: 0,
            oscillation_margin: 0.05,
            oscillation_threshold: 5,
            critical: CriticalSection::default(),
            warmup_remaining: 0,
        }
    }

    /// Registers `adapter` with the bus.
    ///
    /// The adapter is validated in order:
    ///
    /// 1. Its declared Lipschitz bound must not exceed `1.0 + f64::EPSILON`. A NaN
    ///    bound is also rejected, since it would silently turn the product into NaN.
    /// 2. The running product of all bounds must stay `<= 1.0 + f64::EPSILON`.
    /// 3. The timescale rule enforced by `validate_timescale` must hold.
    ///
    /// While a critical section is open, a validated adapter is parked in the pending
    /// queue instead of going live. It is re-validated and registered when the
    /// outermost critical section exits; a failure at that point is only logged.
    ///
    /// # Errors
    ///
    /// Returns `BusError::LipschitzExpansive`, `BusError::ProductReachesOne` or
    /// `BusError::TimescaleViolation` when the corresponding check fails.
    pub fn register(&mut self, adapter: Box<dyn MetaAdapter>) -> Result<(), BusError> {
        let l = adapter.lipschitz_bound();
        // `NaN > x` is always false, so NaN has to be rejected explicitly.
        if l.is_nan() || l > 1.0 + f64::EPSILON {
            return Err(BusError::LipschitzExpansive { declared: l });
        }
        let new_product = self.l_product * l;
        if new_product > 1.0 + f64::EPSILON {
            return Err(BusError::ProductReachesOne {
                new: l,
                product: new_product,
            });
        }
        self.validate_timescale(&*adapter)?;
        if self.critical.depth > 0 {
            // The product is not committed here: `drain_pending_queue` calls
            // `register` again, which re-checks and commits it.
            self.pending_queue.push(adapter);
            return Ok(());
        }
        self.l_product = new_product;
        self.adapters.push(adapter);
        Ok(())
    }

    /// Advances the bus by one sample and returns the resulting `theta`.
    ///
    /// 1. A critical section open for more than `CRITICAL_TIMEOUT_SAMPLES` samples is
    ///    force-exited.
    /// 2. While a critical section is open, `theta` is returned unchanged and no
    ///    adapter is called.
    /// 3. While the bus is warming up, one warmup sample is counted down on the bus
    ///    and on every adapter. The countdown happens before adapters run, so the
    ///    step that brings the bus counter to zero already applies deltas.
    /// 4. Each adapter whose `update_period` divides `ctx.sample_idx` is applied, in
    ///    registration order, and all of them see the same `ctx`.
    ///    * Periods of `u64::MAX` (never) and `0` (invalid) are skipped.
    ///    * During bus or adapter warmup the delta is computed but discarded.
    ///    * Otherwise the delta is limited to a norm of `(1 - L_i) * |theta|` by
    ///      `clamp_delta`, added to `theta`, and each component is clamped into `[0, 1]`.
    /// 5. The oscillation detector runs on the resulting `theta`.
    pub fn step(&mut self, ctx: &AdaptContext) -> [f64; 2] {
        self.check_critical_timeout(ctx.sample_idx);
        if self.critical.depth > 0 {
            return self.theta;
        }
        if self.warmup_remaining > 0 {
            self.warmup_remaining = self.warmup_remaining.saturating_sub(1);
            for a in &mut self.adapters {
                a.advance_warmup(1);
            }
        }
        for adapter in &mut self.adapters {
            let period = adapter.update_period();
            // `checked_rem` yields `None` for a zero period instead of panicking on
            // the division by zero that `%` would perform.
            if period == u64::MAX || ctx.sample_idx.checked_rem(period) != Some(0) {
                continue;
            }
            if self.warmup_remaining > 0 || adapter.is_in_warmup() {
                // Keep the adapter's internal state current, but drop its proposal.
                adapter.apply(ctx);
                continue;
            }
            let delta = adapter.apply(ctx);
            let l_i = adapter.lipschitz_bound();
            // The step budget scales with the contraction margin `1 - L_i`. When
            // `L_i == 1` (or `theta == 0`) the budget is zero, and `clamp_delta`
            // passes the delta through unchanged.
            let max_norm = (1.0 - l_i) * norm2(&self.theta);
            let applied = clamp_delta(delta.delta, max_norm);
            for (i, applied_i) in applied.iter().enumerate() {
                self.theta[i] = (self.theta[i] + applied_i).clamp(0.0, 1.0);
            }
        }
        self.detect_oscillation();
        self.theta
    }

    /// Opens (or nests) a critical section and returns the new nesting depth.
    ///
    /// Entering the outermost section records `sample_idx` for the timeout check and
    /// calls `flush_state` on every live adapter.
    pub fn enter_critical(&mut self, sample_idx: u64) -> u32 {
        let prev = self.critical.depth;
        self.critical.depth += 1;
        if prev == 0 {
            self.critical.entered_at = Some(sample_idx);
            self.flush_adapter_state();
        }
        self.critical.depth
    }

    /// Closes one level of critical section and returns the remaining depth.
    ///
    /// Closing the outermost level arms the bus warmup and drains the pending
    /// registration queue. Calling this with no section open logs a warning and
    /// returns `0`.
    pub fn exit_critical(&mut self, sample_idx: u64) -> u32 {
        if self.critical.depth == 0 {
            warn!(
                "AdaptationBus: exit_critical called without matching enter — likely a caller bug"
            );
            return 0;
        }
        self.critical.depth -= 1;
        if self.critical.depth == 0 {
            self.critical.entered_at = None;
            self.arm_warmup(sample_idx);
            self.drain_pending_queue();
        }
        self.critical.depth
    }

    /// Current parameter vector.
    pub fn theta(&self) -> [f64; 2] {
        self.theta
    }

    /// Product of the declared Lipschitz bounds of all live adapters.
    pub fn lipschitz_product(&self) -> f64 {
        self.l_product
    }

    /// Whether at least one critical section is open.
    pub fn in_critical_section(&self) -> bool {
        self.critical.depth > 0
    }

    /// Current critical-section nesting depth.
    pub fn critical_depth(&self) -> u32 {
        self.critical.depth
    }

    /// Whether the bus is still in its post-critical-section warmup.
    pub fn in_warmup(&self) -> bool {
        self.warmup_remaining > 0
    }

    /// Calls `flush_state` on every live adapter.
    fn flush_adapter_state(&mut self) {
        for adapter in &mut self.adapters {
            adapter.flush_state();
        }
    }

    /// Starts a `REARM_DELAY`-sample warmup on the bus. `_sample_idx` is unused.
    fn arm_warmup(&mut self, _sample_idx: u64) {
        self.warmup_remaining = REARM_DELAY;
    }

    /// Re-registers every adapter queued during the critical section.
    ///
    /// An adapter that now fails validation is logged and dropped.
    fn drain_pending_queue(&mut self) {
        let pending = std::mem::take(&mut self.pending_queue);
        for adapter in pending {
            if let Err(e) = self.register(adapter) {
                warn!("AdaptationBus: deferred adapter registration failed: {e}");
            }
        }
    }

    /// Force-closes a critical section that has stayed open for strictly more than
    /// `CRITICAL_TIMEOUT_SAMPLES` samples.
    ///
    /// The whole nesting is closed at once: depth is reset to 1 and then exited, so
    /// warmup is armed and the pending queue drained.
    fn check_critical_timeout(&mut self, sample_idx: u64) {
        if let Some(entered_at) = self.critical.entered_at {
            if sample_idx.saturating_sub(entered_at) > CRITICAL_TIMEOUT_SAMPLES {
                warn!(
                    "AdaptationBus: critical section timeout ({CRITICAL_TIMEOUT_SAMPLES} samples) \
                     — force-exiting. Plasticity source likely did not call exit_critical. \
                     entered_at={entered_at}, current={sample_idx}"
                );
                self.critical.depth = 1;
                self.exit_critical(sample_idx);
            }
        }
    }

    /// Records the current `theta` in the ring and tracks runs of large steps.
    ///
    /// A step is "large" when the Euclidean distance between the current `theta` and
    /// the previously recorded one exceeds `oscillation_margin`. After
    /// `oscillation_threshold` consecutive large steps, a warning is logged, the
    /// back-off runs, and the count restarts. Any small step resets the count.
    fn detect_oscillation(&mut self) {
        let prev_idx = self.ring_head;
        let diff = norm2_diff(&self.theta, &self.theta_ring[prev_idx]);
        self.ring_head = (self.ring_head + 1) % THETA_RING;
        self.theta_ring[self.ring_head] = self.theta;
        if diff > self.oscillation_margin {
            self.oscillation_count += 1;
        } else {
            self.oscillation_count = 0;
        }
        if self.oscillation_count >= self.oscillation_threshold {
            warn!(
                "AdaptationBus: oscillation detected ({} consecutive steps with Δθ={:.4e} > margin={:.4e}). \
                 Declared ∏L_i={:.4}. This may indicate an adapter with under-declared Lipschitz bound. \
                 Halving the oscillation margin as defence-in-depth (adapter step sizes are unchanged).",
                self.oscillation_count, diff, self.oscillation_margin, self.l_product
            );
            self.backoff_fastest_adapter();
            self.oscillation_count = 0;
        }
    }

    /// Halves `oscillation_margin`, with a floor of `1e-6`.
    ///
    /// NOTE(review): despite the name, no adapter's step size is touched. Only the
    /// detector's sensitivity changes, which makes future warnings more likely.
    fn backoff_fastest_adapter(&mut self) {
        self.oscillation_margin *= 0.5;
        self.oscillation_margin = self.oscillation_margin.max(1e-6);
    }

    /// Checks the timescale-separation rule between `new` and every live adapter.
    ///
    /// For each pair with different update periods whose slower member has a positive
    /// effective step (`1 - L`), the faster member's step must not exceed
    /// `slow_step * fast_period / slow_period`. Pairs with equal periods are not
    /// constrained.
    ///
    /// The rule is checked whether `new` is the faster or the slower adapter of the
    /// pair, so acceptance does not depend on registration order.
    ///
    /// # Errors
    ///
    /// Returns `BusError::TimescaleViolation` for the first violating pair. Its
    /// `new_period`/`new_step` fields describe the faster adapter, which may be an
    /// existing one when `new` is the slower of the pair.
    fn validate_timescale(&self, new: &dyn MetaAdapter) -> Result<(), BusError> {
        let new_period = new.update_period();
        let new_step = 1.0 - new.lipschitz_bound();
        for existing in &self.adapters {
            let existing_period = existing.update_period();
            let existing_step = 1.0 - existing.lipschitz_bound();
            let (fast_period, fast_step, slow_period, slow_step) = if new_period < existing_period
            {
                (new_period, new_step, existing_period, existing_step)
            } else if existing_period < new_period {
                (existing_period, existing_step, new_period, new_step)
            } else {
                continue;
            };
            if slow_step > 0.0 {
                let required_max_step = slow_step * (fast_period as f64 / slow_period as f64);
                if fast_step > required_max_step + f64::EPSILON {
                    return Err(BusError::TimescaleViolation {
                        new_period: fast_period,
                        slow_period,
                        new_step: fast_step,
                        slow_step,
                    });
                }
            }
        }
        Ok(())
    }
}
/// RAII guard: creating it enters a critical section on the bus, and dropping it
/// exits that section with the same `sample_idx`.
///
/// Marked `#[must_use]` because a guard that is never bound to a variable is
/// dropped immediately, which exits the section it has just entered.
#[must_use = "dropping a CriticalGuard immediately exits the critical section it just entered"]
pub struct CriticalGuard<'a> {
    /// The bus whose critical section is held open.
    bus: &'a mut AdaptationBus,
    /// Sample index passed to both `enter_critical` and `exit_critical`.
    sample_idx: u64,
}
impl<'a> CriticalGuard<'a> {
    /// Opens a critical section on `bus` at `sample_idx`.
    ///
    /// Returns a guard that closes the section again, using the same sample index,
    /// when it is dropped.
    pub fn new(bus: &'a mut AdaptationBus, sample_idx: u64) -> Self {
        let mut guard = Self { bus, sample_idx };
        guard.bus.enter_critical(guard.sample_idx);
        guard
    }
}
impl Drop for CriticalGuard<'_> {
    /// Closes the critical section this guard opened.
    ///
    /// An unbalanced exit is logged by `exit_critical` itself.
    fn drop(&mut self) {
        let Self { bus, sample_idx } = self;
        bus.exit_critical(*sample_idx);
    }
}
/// Euclidean length of a two-component vector.
pub(crate) fn norm2(v: &[f64; 2]) -> f64 {
    v.iter().map(|component| component * component).sum::<f64>().sqrt()
}
/// Euclidean distance between two two-component vectors.
pub(crate) fn norm2_diff(a: &[f64; 2], b: &[f64; 2]) -> f64 {
    a.iter()
        .zip(b.iter())
        .map(|(lhs, rhs)| (lhs - rhs) * (lhs - rhs))
        .sum::<f64>()
        .sqrt()
}
/// Scales `delta` down, preserving its direction, so that its Euclidean length does
/// not exceed `max_norm`.
///
/// A non-positive `max_norm` means "no budget": the delta is returned unchanged.
/// This is what lets adapters declaring `L = 1` (zero step budget on the bus) move
/// `theta` at all. A NaN budget yields a NaN delta.
fn clamp_delta(delta: [f64; 2], max_norm: f64) -> [f64; 2] {
    if max_norm <= 0.0 {
        return delta;
    }
    let length = (delta[0] * delta[0] + delta[1] * delta[1]).sqrt();
    let factor = if length <= max_norm {
        1.0
    } else {
        max_norm / length
    };
    delta.map(|component| component * factor)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context with zero prediction error and default diagnostics.
    fn ctx(theta: [f64; 2], sample_idx: u64) -> AdaptContext {
        AdaptContext {
            theta,
            diagnostics: ConfigDiagnostics::default(),
            prediction: 0.0,
            target: 0.0,
            sample_idx,
        }
    }

    /// Pulls `theta` towards `[0.5, 0.5]` by a fraction `rho` every 30 samples.
    struct ContractionAdapter {
        rho: f64,
    }

    impl ContractionAdapter {
        fn new(rho: f64) -> Self {
            Self { rho }
        }
    }

    impl MetaAdapter for ContractionAdapter {
        fn lipschitz_bound(&self) -> f64 {
            self.rho.max(1.0 - self.rho)
        }
        fn update_period(&self) -> u64 {
            30
        }
        fn apply(&mut self, ctx: &AdaptContext) -> ThetaDelta {
            let target = [0.5_f64, 0.5_f64];
            let delta = [
                self.rho * (target[0] - ctx.theta[0]),
                self.rho * (target[1] - ctx.theta[1]),
            ];
            ThetaDelta { delta }
        }
        fn progress(&self) -> f64 {
            0.0
        }
    }

    /// Declares an expansive bound, so registration must fail.
    struct ExpansiveAdapter;

    impl MetaAdapter for ExpansiveAdapter {
        fn lipschitz_bound(&self) -> f64 {
            1.5
        }
        fn update_period(&self) -> u64 {
            10
        }
        fn apply(&mut self, _ctx: &AdaptContext) -> ThetaDelta {
            ThetaDelta { delta: [0.1, 0.1] }
        }
        fn progress(&self) -> f64 {
            0.0
        }
    }

    #[test]
    fn bus_rejects_adapter_with_lipschitz_above_one() {
        let mut bus = AdaptationBus::new([0.5, 0.5]);
        let result = bus.register(Box::new(ExpansiveAdapter));
        assert!(
            result.is_err(),
            "bus must reject adapters with declared L > 1.0"
        );
        match result.unwrap_err() {
            BusError::LipschitzExpansive { declared } => {
                assert!((declared - 1.5).abs() < 1e-10, "declared={declared}");
            }
            e => panic!("expected LipschitzExpansive, got {e:?}"),
        }
    }

    #[test]
    fn bus_accepts_chain_of_non_expansive_adapters() {
        struct EdgeAdapter;
        impl MetaAdapter for EdgeAdapter {
            fn lipschitz_bound(&self) -> f64 {
                1.0
            }
            fn update_period(&self) -> u64 {
                100
            }
            fn apply(&mut self, _ctx: &AdaptContext) -> ThetaDelta {
                ThetaDelta::default()
            }
            fn progress(&self) -> f64 {
                0.0
            }
        }
        let mut bus = AdaptationBus::new([0.5, 0.5]);
        for _ in 0..5 {
            bus.register(Box::new(NoOpAdapter)).unwrap();
        }
        bus.register(Box::new(ContractionAdapter::new(0.3)))
            .unwrap();
        bus.register(Box::new(EdgeAdapter)).unwrap();
        assert!(
            bus.lipschitz_product() <= 1.0 + f64::EPSILON,
            "product should not exceed 1.0"
        );
    }

    #[test]
    fn bus_rejects_registration_when_product_exceeds_one() {
        // Declares a bound exactly at the per-adapter tolerance `1 + EPSILON`. One
        // such adapter is accepted, but the product of two exceeds the tolerance.
        struct NearUnitAdapter;
        impl MetaAdapter for NearUnitAdapter {
            fn lipschitz_bound(&self) -> f64 {
                1.0 + f64::EPSILON
            }
            fn update_period(&self) -> u64 {
                100
            }
            fn apply(&mut self, _ctx: &AdaptContext) -> ThetaDelta {
                ThetaDelta::default()
            }
            fn progress(&self) -> f64 {
                0.0
            }
        }
        let mut bus = AdaptationBus::new([0.5, 0.5]);
        bus.register(Box::new(NearUnitAdapter)).unwrap();
        let product_before = bus.lipschitz_product();
        match bus.register(Box::new(NearUnitAdapter)) {
            Err(BusError::ProductReachesOne { new, product }) => {
                assert_eq!(new, 1.0 + f64::EPSILON);
                assert!(product > 1.0 + f64::EPSILON, "product={product}");
            }
            other => panic!("expected ProductReachesOne, got {other:?}"),
        }
        assert_eq!(
            bus.lipschitz_product(),
            product_before,
            "a rejected adapter must not change the product"
        );
        assert_eq!(bus.adapters.len(), 1, "a rejected adapter must not go live");
    }

    #[test]
    fn bus_accepts_noop_and_contraction() {
        let mut bus = AdaptationBus::new([0.5, 0.5]);
        bus.register(Box::new(NoOpAdapter)).unwrap();
        bus.register(Box::new(ContractionAdapter::new(0.3)))
            .unwrap();
        let product = bus.lipschitz_product();
        assert!(
            (product - 0.7).abs() < 1e-10,
            "product should be 1.0 * 0.7 = 0.7, got {product}"
        );
    }

    #[test]
    fn critical_section_pauses_continuous_adapters() {
        let mut bus = AdaptationBus::new([0.5, 0.5]);
        bus.register(Box::new(ContractionAdapter::new(0.3)))
            .unwrap();
        let theta_before = bus.theta();
        let depth = bus.enter_critical(0);
        assert_eq!(depth, 1, "depth should be 1 after first enter");
        assert!(bus.in_critical_section());
        for i in 0..100u64 {
            let c = ctx([0.5, 0.5], i * 30);
            let theta = bus.step(&c);
            assert_eq!(
                theta, theta_before,
                "theta must not change during critical section, step {i}"
            );
        }
    }

    #[test]
    fn critical_section_flushes_adapter_state_on_exit() {
        let mut bus = AdaptationBus::new([0.5, 0.5]);
        bus.register(Box::new(DriftRateAdapter::new(0.01))).unwrap();
        for i in 0..20u64 {
            let mut c = ctx([0.5, 0.5], i);
            c.prediction = i as f64 * 0.1;
            c.target = i as f64 * 0.05;
            bus.step(&c);
        }
        bus.enter_critical(20);
        bus.exit_critical(21);
        assert!(
            bus.in_warmup(),
            "bus should be in warmup after exit_critical"
        );
    }

    #[test]
    fn nested_critical_section_correct_depth() {
        let mut bus = AdaptationBus::new([0.5, 0.5]);
        let d1 = bus.enter_critical(0);
        assert_eq!(d1, 1);
        let d2 = bus.enter_critical(1);
        assert_eq!(d2, 2, "nested enter should increment depth");
        let d3 = bus.exit_critical(2);
        assert_eq!(d3, 1, "first exit should decrement to 1");
        assert!(bus.in_critical_section(), "still nested after first exit");
        let d4 = bus.exit_critical(3);
        assert_eq!(d4, 0, "final exit should reach depth 0");
        assert!(
            !bus.in_critical_section(),
            "should not be in critical after final exit"
        );
    }

    #[test]
    fn exit_without_enter_is_idempotent_noop() {
        let mut bus = AdaptationBus::new([0.5, 0.5]);
        let depth = bus.exit_critical(0);
        assert_eq!(depth, 0, "exit on depth=0 should return 0");
        assert!(!bus.in_critical_section());
    }

    #[test]
    fn adapter_queued_during_critical_registered_on_exit() {
        let mut bus = AdaptationBus::new([0.5, 0.5]);
        bus.enter_critical(0);
        let adapter_count_before = bus.adapters.len();
        bus.register(Box::new(NoOpAdapter)).unwrap();
        assert_eq!(
            bus.adapters.len(),
            adapter_count_before,
            "adapter should not appear in live list during critical"
        );
        assert_eq!(
            bus.pending_queue.len(),
            1,
            "adapter should be in pending queue"
        );
        bus.exit_critical(1);
        assert_eq!(
            bus.pending_queue.len(),
            0,
            "pending queue should be drained on exit"
        );
        assert_eq!(
            bus.adapters.len(),
            adapter_count_before + 1,
            "adapter should be in live list after exit"
        );
    }

    #[test]
    fn oscillation_detector_fires_on_repeated_large_swings() {
        let mut bus = AdaptationBus::new([0.5, 0.5]);
        bus.oscillation_margin = 0.001;
        bus.oscillation_threshold = 3;
        let margin_before = bus.oscillation_margin;
        for i in 0..20u64 {
            bus.theta = if i % 2 == 0 { [0.3, 0.5] } else { [0.7, 0.5] };
            bus.detect_oscillation();
        }
        assert!(
            bus.oscillation_margin < margin_before,
            "oscillation backoff should have tightened margin"
        );
    }

    #[test]
    fn oscillation_detector_resets_after_stable_theta() {
        let mut bus = AdaptationBus::new([0.5, 0.5]);
        bus.oscillation_margin = 0.001;
        bus.oscillation_threshold = 10;
        for i in 0..5u64 {
            bus.theta = if i % 2 == 0 { [0.3, 0.5] } else { [0.7, 0.5] };
            bus.detect_oscillation();
        }
        assert!(
            bus.oscillation_count > 0,
            "oscillation count should be > 0 after swings"
        );
        for _ in 0..5u64 {
            bus.theta = [0.5, 0.5];
            bus.detect_oscillation();
        }
        assert_eq!(
            bus.oscillation_count, 0,
            "oscillation count should reset after stable theta"
        );
    }

    #[test]
    fn critical_guard_exits_on_drop() {
        let mut bus = AdaptationBus::new([0.5, 0.5]);
        assert!(!bus.in_critical_section());
        {
            let _guard = CriticalGuard::new(&mut bus, 0);
        }
        assert!(
            !bus.in_critical_section(),
            "bus should not be in critical after guard drop"
        );
    }

    #[test]
    fn norm2_of_unit_diagonal() {
        let v = [1.0_f64 / 2.0_f64.sqrt(), 1.0_f64 / 2.0_f64.sqrt()];
        let n = norm2(&v);
        assert!(
            (n - 1.0).abs() < 1e-12,
            "norm of unit vector should be 1.0, got {n}"
        );
    }

    #[test]
    fn clamp_delta_within_budget_unchanged() {
        let delta = [0.01, 0.01];
        let max_norm = 1.0;
        let result = clamp_delta(delta, max_norm);
        assert!((result[0] - 0.01).abs() < 1e-12);
        assert!((result[1] - 0.01).abs() < 1e-12);
    }

    #[test]
    fn clamp_delta_over_budget_scaled_down() {
        let delta = [0.1, 0.1];
        let max_norm = 0.05;
        let result = clamp_delta(delta, max_norm);
        let result_norm = norm2(&result);
        assert!(
            result_norm <= max_norm + 1e-12,
            "clamped delta norm {result_norm} should be <= max_norm {max_norm}"
        );
    }

    #[test]
    fn theta_stays_in_unit_hypercube_under_contraction() {
        let mut bus = AdaptationBus::new([0.1, 0.9]);
        bus.register(Box::new(ContractionAdapter::new(0.3)))
            .unwrap();
        for i in 0u64..1000 {
            let c = ctx([bus.theta()[0], bus.theta()[1]], i * 30);
            let theta = bus.step(&c);
            assert!(
                theta[0] >= 0.0 && theta[0] <= 1.0,
                "theta[0]={} out of `[0,1]` at step {i}",
                theta[0]
            );
            assert!(
                theta[1] >= 0.0 && theta[1] <= 1.0,
                "theta[1]={} out of `[0,1]` at step {i}",
                theta[1]
            );
        }
    }
}