use super::types::*;
use super::OrchestrationAuthority;
use crate::cache_provider::{CacheProvider, FilesystemCacheProvider};
use crate::device::ResourceSnapshotProvider;
use crate::ir::Envelope;
use crate::orchestrator::policy_engine::{DefaultPolicyEngine, PolicyEngine};
use crate::orchestrator::routing_engine::{
DefaultRoutingEngine, LocalAvailability, LocalReliabilityHint, RouteTarget, RoutingDecision,
RoutingEngine,
};
use crate::pipeline::ExecutionTarget;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// TTL used by `record_abort_for_hysteresis_default_ttl` when the caller gives none.
const DEFAULT_HYSTERESIS_TTL: Duration = Duration::from_secs(30);
/// Maximum number of outcomes kept per `(model_id, signal)` reliability history.
const RELIABILITY_WINDOW: usize = 32;
/// Default count of most-recent outcomes that must all be unreliable before local is skipped.
const DEFAULT_HISTORY_BIAS_K: usize = 3;
/// Upper bound on the number of entries kept in the hysteresis map.
const MAX_HYSTERESIS_KEYS: usize = 256;
/// Upper bound on the number of keys kept in the reliability map.
const MAX_RELIABILITY_KEYS: usize = 256;
/// Orchestration authority that makes policy, routing and model-selection
/// decisions on-device, using recent local execution outcomes as feedback.
pub struct LocalAuthority {
// Evaluated in `apply_policy` and again before every routing decision.
policy_engine: DefaultPolicyEngine,
// Locked for the duration of each `decide` call.
routing_engine: Mutex<DefaultRoutingEngine>,
// Answers whether a model is cached locally and where it lives.
cache_provider: Arc<dyn CacheProvider>,
// Optional source of live resource snapshots; when `None`, the
// `StageContext`'s own resource monitor is queried instead.
resource_provider: Option<Arc<dyn ResourceSnapshotProvider>>,
// Expiry instant per `(model_id, abort reason)`; a live entry steers that
// model to cloud.
hysteresis: Mutex<HashMap<(String, AbortReason), Instant>>,
// Recent device outcomes per `(model_id, signal)`, oldest first, capped at
// `RELIABILITY_WINDOW` entries each.
reliability: Mutex<HashMap<(String, SignalContext), VecDeque<OutcomeCategory>>>,
// Number of most-recent outcomes inspected by the history-bias check (>= 1
// when set through `with_history_bias_k`).
history_bias_k: usize,
}
impl LocalAuthority {
pub fn new() -> Self {
crate::device::capabilities::prewarm();
Self {
policy_engine: DefaultPolicyEngine::with_default_policy(),
routing_engine: Mutex::new(DefaultRoutingEngine::new()),
cache_provider: Arc::new(FilesystemCacheProvider::new()),
resource_provider: None,
hysteresis: Mutex::new(HashMap::new()),
reliability: Mutex::new(HashMap::new()),
history_bias_k: DEFAULT_HISTORY_BIAS_K,
}
}
/// Builds an authority with the default policy and the supplied model cache.
///
/// Calls `crate::device::capabilities::prewarm()` before constructing any field.
pub fn with_cache_provider(cache_provider: Arc<dyn CacheProvider>) -> Self {
    crate::device::capabilities::prewarm();
    let policy_engine = DefaultPolicyEngine::with_default_policy();
    let routing_engine = Mutex::new(DefaultRoutingEngine::new());
    Self {
        policy_engine,
        routing_engine,
        cache_provider,
        resource_provider: None,
        hysteresis: Mutex::default(),
        reliability: Mutex::default(),
        history_bias_k: DEFAULT_HISTORY_BIAS_K,
    }
}
pub fn with_policy_engine(policy_engine: DefaultPolicyEngine) -> Self {
crate::device::capabilities::prewarm();
Self {
policy_engine,
routing_engine: Mutex::new(DefaultRoutingEngine::new()),
cache_provider: Arc::new(FilesystemCacheProvider::new()),
resource_provider: None,
hysteresis: Mutex::new(HashMap::new()),
reliability: Mutex::new(HashMap::new()),
history_bias_k: DEFAULT_HISTORY_BIAS_K,
}
}
/// Builds an authority from an explicit policy engine and model cache.
///
/// Calls `crate::device::capabilities::prewarm()` before constructing any field.
pub fn with_policy_and_cache(
    policy_engine: DefaultPolicyEngine,
    cache_provider: Arc<dyn CacheProvider>,
) -> Self {
    crate::device::capabilities::prewarm();
    let routing_engine = Mutex::new(DefaultRoutingEngine::new());
    Self {
        policy_engine,
        routing_engine,
        cache_provider,
        resource_provider: None,
        hysteresis: Mutex::default(),
        reliability: Mutex::default(),
        history_bias_k: DEFAULT_HISTORY_BIAS_K,
    }
}
pub fn with_resource_provider(mut self, provider: Arc<dyn ResourceSnapshotProvider>) -> Self {
self.resource_provider = Some(provider);
self
}
/// Builder step: sets how many of the most recent outcomes must all be
/// unreliable before local execution is skipped. Values below 1 are raised to 1.
pub fn with_history_bias_k(mut self, k: usize) -> Self {
    // A window of zero would make the "all recent outcomes unreliable" check
    // vacuously true, so it is never stored.
    self.history_bias_k = if k == 0 { 1 } else { k };
    self
}
/// Records that `model_id` aborted local execution for `reason`; while the
/// entry is live (until `ttl` elapses) routing for that model is steered to cloud.
///
/// A TTL so large that the expiry instant would overflow is clamped to roughly
/// one year instead of panicking. If the hysteresis lock is poisoned the call
/// is a silent no-op.
pub fn record_abort_for_hysteresis(&self, model_id: &str, reason: AbortReason, ttl: Duration) {
    // Long enough to act as "does not expire" in practice while staying representable.
    const CLAMPED_TTL: Duration = Duration::from_secs(365 * 24 * 60 * 60);
    let now = Instant::now();
    // `Instant + Duration` panics on overflow (e.g. `Duration::MAX`), so use
    // checked arithmetic and fall back to the clamp.
    let Some(expires_at) = now
        .checked_add(ttl)
        .or_else(|| now.checked_add(CLAMPED_TTL))
    else {
        return;
    };
    if let Ok(mut hysteresis) = self.hysteresis.lock() {
        hysteresis.insert((model_id.to_string(), reason), expires_at);
        // One prune after the insert both drops expired entries and enforces
        // the size cap; pruning before the insert as well would be redundant.
        Self::prune_hysteresis(&mut hysteresis);
    }
}
/// Same as `record_abort_for_hysteresis`, using `DEFAULT_HYSTERESIS_TTL` as the TTL.
pub fn record_abort_for_hysteresis_default_ttl(&self, model_id: &str, reason: AbortReason) {
self.record_abort_for_hysteresis(model_id, reason, DEFAULT_HYSTERESIS_TTL);
}
/// Returns whether the cache provider reports `model_id` as cached.
fn check_model_exists(&self, model_id: &str) -> bool {
self.cache_provider.is_model_cached(model_id)
}
/// Returns the cache provider's path for `model_id` as a `String`.
///
/// Yields `None` when the provider has no path or the path is not valid UTF-8.
fn find_local_model(&self, model_id: &str) -> Option<String> {
    let path = self.cache_provider.get_model_path(model_id)?;
    path.to_str().map(str::to_owned)
}
/// Returns the abort reason of the live hysteresis entry for `model_id` that
/// expires last, or `None` when there is no live entry or the lock is poisoned.
///
/// Expired entries are pruned as a side effect.
fn active_hysteresis_for(&self, model_id: &str) -> Option<AbortReason> {
    let Ok(mut entries) = self.hysteresis.lock() else {
        return None;
    };
    Self::prune_hysteresis(&mut entries);
    let latest = entries
        .iter()
        .filter(|((candidate, _), _)| candidate == model_id)
        .max_by_key(|(_, deadline)| **deadline);
    latest.map(|((_, reason), _)| *reason)
}
/// Returns a copy of the recorded outcomes for `(model_id, signal)`, oldest first.
///
/// An unknown key or a poisoned lock yields an empty history.
fn history_snapshot(&self, model_id: &str, signal: SignalContext) -> VecDeque<OutcomeCategory> {
    let Ok(histories) = self.reliability.lock() else {
        return VecDeque::new();
    };
    let key = (model_id.to_string(), signal);
    match histories.get(&key) {
        Some(history) => history.clone(),
        None => VecDeque::new(),
    }
}
/// Summarises the recorded history for `(model_id, signal)` as the fraction of
/// outcomes flagged unreliable and the number of samples.
///
/// Returns `LocalReliabilityHint::EMPTY` when nothing has been recorded.
fn reliability_hint(&self, model_id: &str, signal: SignalContext) -> LocalReliabilityHint {
    let history = self.history_snapshot(model_id, signal);
    let total = history.len();
    if total == 0 {
        return LocalReliabilityHint::EMPTY;
    }
    let mut unreliable = 0usize;
    for category in &history {
        if category.is_local_unreliable() {
            unreliable += 1;
        }
    }
    LocalReliabilityHint {
        recent_abort_rate: unreliable as f32 / total as f32,
        sample_size: total as u32,
    }
}
fn history_bias_should_skip_local(&self, model_id: &str, signal: SignalContext) -> bool {
let history = self.history_snapshot(model_id, signal);
if history.len() < self.history_bias_k {
return false;
}
history
.iter()
.rev()
.take(self.history_bias_k)
.all(OutcomeCategory::is_local_unreliable)
}
/// Drops expired hysteresis entries, then evicts the soonest-to-expire entries
/// until at most `MAX_HYSTERESIS_KEYS` remain.
fn prune_hysteresis(hysteresis: &mut HashMap<(String, AbortReason), Instant>) {
    let cutoff = Instant::now();
    hysteresis.retain(|_, deadline| *deadline > cutoff);
    while hysteresis.len() > MAX_HYSTERESIS_KEYS {
        let soonest = hysteresis
            .iter()
            .min_by_key(|(_, deadline)| **deadline)
            .map(|(key, _)| key.clone());
        match soonest {
            Some(key) => {
                hysteresis.remove(&key);
            }
            None => break,
        }
    }
}
/// Evicts the keys with the shortest histories until at most
/// `MAX_RELIABILITY_KEYS` remain.
fn prune_reliability(
    reliability: &mut HashMap<(String, SignalContext), VecDeque<OutcomeCategory>>,
) {
    while reliability.len() > MAX_RELIABILITY_KEYS {
        let shortest = reliability
            .iter()
            .min_by_key(|(_, history)| history.len())
            .map(|(key, _)| key.clone());
        match shortest {
            Some(key) => {
                reliability.remove(&key);
            }
            None => break,
        }
    }
}
}
/// `Default` is equivalent to `LocalAuthority::new()`.
impl Default for LocalAuthority {
fn default() -> Self {
Self::new()
}
}
impl OrchestrationAuthority for LocalAuthority {
/// Evaluates the policy engine for `request` and maps the result to
/// `Allow`, `Transform` (when transforms were applied) or `Deny`.
///
/// The decision reason is the engine's reason, or a generic local-evaluation
/// message when the engine gave none.
fn apply_policy(&self, request: &PolicyRequest) -> AuthorityDecision<PolicyOutcome> {
    let evaluation =
        self.policy_engine
            .evaluate(&request.stage_id, &request.envelope, &request.metrics);
    let outcome = if !evaluation.allowed {
        let denial = evaluation
            .reason
            .clone()
            .unwrap_or_else(|| "Policy denied".to_string());
        PolicyOutcome::Deny { reason: denial }
    } else if evaluation.transforms_applied.is_empty() {
        PolicyOutcome::Allow
    } else {
        PolicyOutcome::Transform {
            transforms: evaluation.transforms_applied.clone(),
        }
    };
    let reason = match evaluation.reason {
        Some(text) => text,
        None => "Local policy evaluation".to_string(),
    };
    AuthorityDecision {
        result: outcome,
        reason,
        source: DecisionSource::Local,
        confidence: 1.0,
        timestamp_ms: now_ms(),
    }
}
/// Resolves the execution target for `context`, returning only the decision
/// part of `resolve_target_with_feedback`.
fn resolve_target(&self, context: &StageContext) -> AuthorityDecision<ResolvedTarget> {
self.resolve_target_with_feedback(context).decision
}
/// Resolves the execution target: an explicit (non-`Auto`) target on the
/// context wins, otherwise the routing engine decides.
fn resolve_target_with_feedback(&self, context: &StageContext) -> TargetResolution {
    self.explicit_target_resolution(context)
        .unwrap_or_else(|| self.resolve_with_routing_engine(context))
}
/// Chooses where to load `request.model_id` from: the local cache path when
/// one is found, otherwise the registry URL for that model id.
fn select_model(&self, request: &ModelRequest) -> AuthorityDecision<ModelSelection> {
    let source = match self.find_local_model(&request.model_id) {
        Some(path) => ModelSource::Local { path },
        None => ModelSource::Registry {
            url: format!("https://api.xybrid.dev/v1/models/{}", request.model_id),
        },
    };
    let reason = if source.is_local() {
        format!("Model '{}' found locally", request.model_id)
    } else {
        format!(
            "Model '{}' not found locally, will fetch from registry",
            request.model_id
        )
    };
    let selection = ModelSelection {
        model_id: request.model_id.clone(),
        variant: None,
        source,
    };
    AuthorityDecision {
        result: selection,
        reason,
        source: DecisionSource::Local,
        confidence: 1.0,
        timestamp_ms: now_ms(),
    }
}
/// Name of this authority: always `"local"`.
fn name(&self) -> &str {
"local"
}
/// Feeds a finished execution back into routing state.
///
/// Only outcomes whose target is `ResolvedTarget::Device` are considered.
/// A cloud-fallback abort starts hysteresis for the model with the default
/// TTL. When the outcome carries a signal context, its category is appended
/// to the `(model_id, signal)` history, which is trimmed to
/// `RELIABILITY_WINDOW` entries. A poisoned reliability lock is ignored.
fn record_outcome(&self, outcome: &ExecutionOutcome) {
    match outcome.target {
        ResolvedTarget::Device => {}
        _ => return,
    }
    let category = outcome.effective_category();
    let model_id = outcome.effective_model_id().to_string();
    if let OutcomeCategory::AbortedForCloudFallback { reason } = &category {
        self.record_abort_for_hysteresis_default_ttl(&model_id, *reason);
    }
    let Some(signal) = outcome.signal_context else {
        return;
    };
    let Ok(mut histories) = self.reliability.lock() else {
        return;
    };
    let key = (model_id, signal);
    if !histories.contains_key(&key) {
        // Make room before inserting a new key so the fresh one-entry history
        // is never chosen as the eviction victim.
        while histories.len() >= MAX_RELIABILITY_KEYS {
            let shortest = histories
                .iter()
                .min_by_key(|(_, history)| history.len())
                .map(|(victim_key, _)| victim_key.clone());
            let Some(victim) = shortest else {
                break;
            };
            histories.remove(&victim);
        }
    }
    let window = histories.entry(key).or_default();
    window.push_back(category);
    while window.len() > RELIABILITY_WINDOW {
        window.pop_front();
    }
    Self::prune_reliability(&mut histories);
}
}
impl LocalAuthority {
/// Returns the context's device metrics overlaid with a live resource snapshot.
///
/// The snapshot comes from the configured resource provider if there is one,
/// otherwise from the context's resource monitor; both are asked with a
/// 500 ms argument.
fn routing_metrics(&self, context: &StageContext) -> crate::context::DeviceMetrics {
    let max_age = Duration::from_millis(500);
    let snapshot = match &self.resource_provider {
        Some(provider) => provider.current_snapshot(max_age),
        None => context.resource_monitor.current_snapshot(max_age),
    };
    context.metrics.with_live_snapshot(snapshot)
}
/// Converts a routing-engine target into a resolved target: `Local` becomes
/// `Device`, `Cloud` becomes the `"xybrid"` cloud provider, and a fallback id
/// is carried over unchanged as the server endpoint.
fn target_from_route(target: RouteTarget) -> ResolvedTarget {
    match target {
        RouteTarget::Fallback(endpoint) => ResolvedTarget::Server { endpoint },
        RouteTarget::Cloud => ResolvedTarget::Cloud {
            provider: String::from("xybrid"),
        },
        RouteTarget::Local => ResolvedTarget::Device,
    }
}
/// Honours an explicit execution target on the context.
///
/// Returns `None` when no target is set or it is `Auto`, so the caller falls
/// through to the routing engine. Otherwise returns a full-confidence decision
/// together with the signal context derived from the live metrics.
fn explicit_target_resolution(&self, context: &StageContext) -> Option<TargetResolution> {
    let requested = context.explicit_target.as_ref()?;
    let target = match requested {
        ExecutionTarget::Auto => return None,
        ExecutionTarget::Device => ResolvedTarget::Device,
        ExecutionTarget::Server => ResolvedTarget::Server {
            endpoint: String::from("https://api.xybrid.dev"),
        },
        ExecutionTarget::Cloud => ResolvedTarget::Cloud {
            provider: String::from("xybrid"),
        },
    };
    let live_metrics = self.routing_metrics(context);
    let decision = AuthorityDecision {
        result: target,
        reason: format!("Explicit target from pipeline YAML: {:?}", requested),
        source: DecisionSource::Local,
        confidence: 1.0,
        timestamp_ms: now_ms(),
    };
    let signal = SignalContext::from_metrics(&live_metrics);
    Some(TargetResolution::new(
        decision,
        context.model_id.clone(),
        Some(signal),
    ))
}
/// Resolves the target through the routing engine, using the caller-supplied
/// local availability when present and the cache provider's answer otherwise.
fn resolve_with_routing_engine(&self, context: &StageContext) -> TargetResolution {
    let availability = match context.local_availability.clone() {
        Some(supplied) => supplied,
        None => LocalAvailability::new(self.check_model_exists(&context.model_id)),
    };
    self.resolve_with_routing_engine_and_availability(context, availability)
}
/// Resolves the execution target for `context` given the local availability.
///
/// Steps, in order:
/// 1. Evaluate policy against a fresh envelope and the live metrics.
/// 2. If policy allows: an active hysteresis entry for the model, or a run of
///    `history_bias_k` unreliable recent outcomes, short-circuits to cloud.
/// 3. Otherwise (including policy denial) the routing engine decides.
///
/// Every returned resolution carries the signal context and reliability hint.
/// A poisoned routing-engine lock is recovered rather than panicking, matching
/// how the other locks in this type tolerate poisoning.
fn resolve_with_routing_engine_and_availability(
    &self,
    context: &StageContext,
    availability: LocalAvailability,
) -> TargetResolution {
    let envelope = Envelope::new(context.input_kind.clone());
    let live_metrics = self.routing_metrics(context);
    let signal = SignalContext::from_metrics(&live_metrics);
    let hint = self.reliability_hint(&context.model_id, signal);
    let policy_result =
        self.policy_engine
            .evaluate(&context.stage_id, &envelope, &live_metrics);
    // The cloud short-circuits apply only when policy allows the request, so a
    // policy denial always reaches the routing engine.
    if policy_result.allowed {
        if let Some(reason) = self.active_hysteresis_for(&context.model_id) {
            let decision = AuthorityDecision {
                result: ResolvedTarget::Cloud {
                    provider: "xybrid".to_string(),
                },
                reason: format!(
                    "hysteresis: recent local abort for model '{}' ({})",
                    context.model_id, reason
                ),
                source: DecisionSource::Local,
                confidence: 0.9,
                timestamp_ms: now_ms(),
            };
            return TargetResolution::new(decision, context.model_id.clone(), Some(signal))
                .with_reliability_hint(hint);
        }
        if self.history_bias_should_skip_local(&context.model_id, signal) {
            let decision = AuthorityDecision {
                result: ResolvedTarget::Cloud {
                    provider: "xybrid".to_string(),
                },
                reason: format!(
                    "history_bias: recent local failure rate {:.0}% over {} samples",
                    hint.recent_abort_rate * 100.0,
                    hint.sample_size
                ),
                source: DecisionSource::Local,
                confidence: 0.85,
                timestamp_ms: now_ms(),
            };
            return TargetResolution::new(decision, context.model_id.clone(), Some(signal))
                .with_reliability_hint(hint);
        }
    }
    let decision = {
        // A panic elsewhere while this lock was held must not turn every later
        // routing call into a panic; take the guard back from the poison error.
        let mut routing_engine = self
            .routing_engine
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        routing_engine.decide(
            &context.stage_id,
            &live_metrics,
            &policy_result,
            &availability,
        )
    };
    let target = Self::target_from_route(decision.target);
    TargetResolution::new(
        AuthorityDecision {
            result: target,
            reason: decision.reason,
            source: DecisionSource::Local,
            confidence: 0.8,
            timestamp_ms: decision.timestamp_ms,
        },
        context.model_id.clone(),
        Some(signal),
    )
    .with_reliability_hint(hint)
}
/// Resolves the target for `context` and expresses it as a `RoutingDecision`.
///
/// The current implementation always returns `Some`. A missing reliability
/// hint is reported as `LocalReliabilityHint::EMPTY`.
pub fn resolve_routing_decision(&self, context: &StageContext) -> Option<RoutingDecision> {
    let resolution = self.resolve_target_with_feedback(context);
    let local_reliability_hint = resolution
        .local_reliability_hint
        .unwrap_or(LocalReliabilityHint::EMPTY);
    let decision = resolution.decision;
    let target = match decision.result {
        ResolvedTarget::Server { endpoint } => RouteTarget::Fallback(endpoint),
        ResolvedTarget::Cloud { .. } => RouteTarget::Cloud,
        ResolvedTarget::Device => RouteTarget::Local,
    };
    Some(RoutingDecision {
        stage: context.stage_id.clone(),
        target,
        reason: decision.reason,
        timestamp_ms: decision.timestamp_ms,
        local_reliability_hint,
    })
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::cache_provider::CacheProvider;
use crate::context::DeviceMetrics;
use crate::device::{MemoryPressure, ResourceMonitor, ResourceSnapshot, ThermalState};
use crate::ir::EnvelopeKind;
use std::path::PathBuf;
// Test fixture: default-constructed device metrics.
fn default_metrics() -> DeviceMetrics {
DeviceMetrics::default()
}
// Test fixture: policy document with a `deny_cloud_if` rule matching text input.
fn deny_all_text_policy() -> String {
r#"
version: "0.1.0"
deny_cloud_if:
- input.kind == "text"
signature: "test-deny-all"
"#
.to_string()
}
// Test fixture: envelope wrapping `text` as a text payload.
fn text_envelope(text: &str) -> Envelope {
Envelope::new(EnvelopeKind::Text(text.to_string()))
}
// Resource provider stub that always returns the snapshot it was built with.
#[derive(Debug)]
struct FixedResourceProvider(ResourceSnapshot);
impl ResourceSnapshotProvider for FixedResourceProvider {
// `_max_age` is ignored; the stored snapshot is returned as-is.
fn current_snapshot(&self, _max_age: Duration) -> ResourceSnapshot {
self.0
}
}
// Cache provider stub that reports every model as cached under `/tmp`.
#[derive(Debug)]
struct CachedProvider;
impl CacheProvider for CachedProvider {
// Every model id is reported as cached.
fn is_model_cached(&self, _model_id: &str) -> bool {
true
}
// Path is always `/tmp/<model_id>`.
fn get_model_path(&self, model_id: &str) -> Option<PathBuf> {
Some(PathBuf::from(format!("/tmp/{model_id}")))
}
fn cache_dir(&self) -> PathBuf {
PathBuf::from("/tmp")
}
fn name(&self) -> &'static str {
"cached-test"
}
}
// Test fixture: text-input stage context for "test-model" with no explicit
// target and no caller-supplied local availability.
fn text_context() -> StageContext {
StageContext {
stage_id: "test-stage".to_string(),
model_id: "test-model".to_string(),
input_kind: EnvelopeKind::Text("test".to_string()),
metrics: default_metrics(),
resource_monitor: ResourceMonitor::global(),
explicit_target: None,
local_availability: None,
device_class: None,
device_class_schema_version: None,
}
}
// Test fixture: signal context with Warn memory pressure, Normal thermal
// state and CPU bucket 5.
fn signal() -> SignalContext {
SignalContext {
memory_pressure: MemoryPressure::Warn,
thermal_state: ThermalState::Normal,
cpu_bucket: Some(5),
}
}
// The default policy allows a plain text request with a local,
// full-confidence decision.
#[test]
fn test_local_authority_default_allows() {
let authority = LocalAuthority::new();
let request = PolicyRequest {
stage_id: "test".to_string(),
envelope: text_envelope("hello"),
metrics: default_metrics(),
};
let decision = authority.apply_policy(&request);
assert!(decision.result.is_allowed());
assert_eq!(decision.source, DecisionSource::Local);
assert_eq!(decision.confidence, 1.0);
}
// An explicit `Device` target is honoured and the reason mentions "Explicit".
#[test]
fn test_local_authority_explicit_device_target() {
let authority = LocalAuthority::new();
let context = StageContext {
stage_id: "test".to_string(),
model_id: "test-model".to_string(),
input_kind: EnvelopeKind::Text("test".to_string()),
metrics: default_metrics(),
resource_monitor: ResourceMonitor::global(),
explicit_target: Some(ExecutionTarget::Device),
local_availability: None,
device_class: None,
device_class_schema_version: None,
};
let decision = authority.resolve_target(&context);
assert_eq!(decision.result, ResolvedTarget::Device);
assert!(decision.reason.contains("Explicit"));
}
// An explicit `Cloud` target resolves to a cloud target.
#[test]
fn test_local_authority_explicit_cloud_target() {
let authority = LocalAuthority::new();
let context = StageContext {
stage_id: "test".to_string(),
model_id: "test-model".to_string(),
input_kind: EnvelopeKind::Text("test".to_string()),
metrics: default_metrics(),
resource_monitor: ResourceMonitor::global(),
explicit_target: Some(ExecutionTarget::Cloud),
local_availability: None,
device_class: None,
device_class_schema_version: None,
};
let decision = authority.resolve_target(&context);
assert!(matches!(decision.result, ResolvedTarget::Cloud { .. }));
}
// Caller-supplied availability (`false`) wins over a cache provider that
// reports the model as cached.
#[test]
fn caller_local_availability_overrides_cache_provider() {
let authority = LocalAuthority::with_cache_provider(Arc::new(CachedProvider));
let mut context = text_context();
context.local_availability = Some(LocalAvailability::new(false));
let decision = authority.resolve_target(&context);
assert!(matches!(decision.result, ResolvedTarget::Cloud { .. }));
assert!(decision.reason.contains("model_unavailable"));
}
// A model id expected to be absent from the cache selects the registry source.
#[test]
fn test_local_authority_model_selection_not_found() {
let authority = LocalAuthority::new();
let request = ModelRequest {
model_id: "nonexistent-model-xyz".to_string(),
task: "test".to_string(),
constraints: ModelConstraints::default(),
};
let decision = authority.select_model(&request);
assert!(matches!(
decision.result.source,
ModelSource::Registry { .. }
));
assert!(decision.reason.contains("not found locally"));
}
// The authority reports its name as "local".
#[test]
fn test_local_authority_name() {
let authority = LocalAuthority::new();
assert_eq!(authority.name(), "local");
}
// If "kokoro-82m" is in the local cache, its path must mention "kokoro".
// NOTE(review): when the lookup returns `None` nothing is asserted.
#[test]
fn test_find_local_model_sdk_cache_structure() {
let authority = LocalAuthority::new();
let path = authority.find_local_model("kokoro-82m");
if let Some(p) = &path {
let p_lower = p.to_lowercase();
assert!(
p_lower.contains("kokoro"),
"Expected path to contain 'kokoro', got: {}",
p
);
}
}
// With `NoopCacheProvider` injected, model selection resolves to the registry.
#[test]
fn test_with_custom_cache_provider() {
use crate::cache_provider::NoopCacheProvider;
let provider = Arc::new(NoopCacheProvider);
let authority = LocalAuthority::with_cache_provider(provider);
let request = ModelRequest {
model_id: "any-model".to_string(),
task: "test".to_string(),
constraints: ModelConstraints::default(),
};
let decision = authority.select_model(&request);
assert!(matches!(
decision.result.source,
ModelSource::Registry { .. }
));
}
// A critical-memory snapshot from the injected resource provider reaches
// routing: the decision goes to cloud with a "stress_memory" reason.
#[test]
fn fake_resource_provider_feeds_routing_metrics() {
let mut snapshot = ResourceSnapshot::unknown();
snapshot.memory_pressure = MemoryPressure::Critical;
snapshot.thermal_state = ThermalState::Normal;
snapshot.cpu_pct = Some(10.0);
let authority = LocalAuthority::with_cache_provider(Arc::new(CachedProvider))
.with_resource_provider(Arc::new(FixedResourceProvider(snapshot)));
let decision = authority.resolve_target(&text_context());
assert!(matches!(decision.result, ResolvedTarget::Cloud { .. }));
assert!(decision.reason.contains("stress_memory"));
}
// Hysteresis applies only to the model it was recorded for and stops applying
// once its TTL has passed.
// NOTE(review): timing-sensitive — the first assertion assumes the resolve
// happens within the 20 ms TTL.
#[test]
fn hysteresis_is_model_scoped_and_expires() {
let authority = LocalAuthority::with_cache_provider(Arc::new(CachedProvider));
authority.record_abort_for_hysteresis(
"test-model",
AbortReason::StressMemory,
Duration::from_millis(20),
);
let decision = authority.resolve_target(&text_context());
assert!(matches!(decision.result, ResolvedTarget::Cloud { .. }));
assert!(decision.reason.contains("hysteresis"));
let mut other = text_context();
other.model_id = "other-model".to_string();
let other_decision = authority.resolve_target(&other);
assert!(!other_decision.reason.contains("hysteresis"));
// Sleep past the TTL so the entry is expired on the next resolve.
std::thread::sleep(Duration::from_millis(30));
let expired = authority.resolve_target(&text_context());
assert!(!expired.reason.contains("hysteresis"));
}
// When policy denies, an active hysteresis entry does not redirect to cloud:
// the decision stays on device with a "policy_deny" reason.
#[test]
fn policy_deny_overrides_hysteresis() {
let mut policy = DefaultPolicyEngine::new();
policy
.load_policies(deny_all_text_policy().into_bytes())
.expect("load deny-all policy");
let authority = LocalAuthority::with_policy_and_cache(policy, Arc::new(CachedProvider));
authority.record_abort_for_hysteresis_default_ttl("test-model", AbortReason::StressMemory);
let decision = authority.resolve_target(&text_context());
assert_eq!(decision.result, ResolvedTarget::Device);
assert!(decision.reason.contains("policy_deny"));
}
// A device outcome categorised as a cloud-fallback abort starts hysteresis,
// so the next resolve for that model goes to cloud.
#[test]
fn device_abort_outcome_enters_hysteresis() {
let authority = LocalAuthority::with_cache_provider(Arc::new(CachedProvider));
authority.record_outcome(&ExecutionOutcome {
stage_id: "test-stage".to_string(),
target: ResolvedTarget::Device,
latency_ms: 12,
success: false,
error: None,
category: Some(OutcomeCategory::AbortedForCloudFallback {
reason: AbortReason::StressMemory,
}),
model_id: Some("test-model".to_string()),
signal_context: Some(signal()),
});
let decision = authority.resolve_target(&text_context());
assert!(matches!(decision.result, ResolvedTarget::Cloud { .. }));
assert!(decision.reason.contains("hysteresis"));
}
// Failures of cloud-targeted executions are not recorded in the local
// reliability history, so no history bias or samples appear.
#[test]
fn cloud_failures_do_not_bias_local_reliability() {
let authority =
LocalAuthority::with_cache_provider(Arc::new(CachedProvider)).with_history_bias_k(3);
for idx in 0..3 {
authority.record_outcome(&ExecutionOutcome {
stage_id: "test-stage".to_string(),
target: ResolvedTarget::Cloud {
provider: "xybrid".to_string(),
},
latency_ms: 10,
success: false,
error: Some(format!("cloud-failure-{idx}")),
category: Some(OutcomeCategory::HardFail {
reason: "cloud_failed".to_string(),
}),
model_id: Some("test-model".to_string()),
signal_context: Some(signal()),
});
}
let mut snapshot = ResourceSnapshot::unknown();
snapshot.memory_pressure = MemoryPressure::Warn;
snapshot.thermal_state = ThermalState::Normal;
snapshot.cpu_pct = Some(55.0);
let authority = authority.with_resource_provider(Arc::new(FixedResourceProvider(snapshot)));
let decision = authority
.resolve_routing_decision(&text_context())
.expect("routing decision");
assert!(!decision.reason.contains("history_bias"));
assert_eq!(decision.local_reliability_hint.sample_size, 0);
}
// When policy denies, three recorded local failures do not trigger the
// history-bias redirect: the decision stays on device with "policy_deny".
#[test]
fn policy_deny_overrides_history_bias() {
let mut policy = DefaultPolicyEngine::new();
policy
.load_policies(deny_all_text_policy().into_bytes())
.expect("load deny-all policy");
let authority = LocalAuthority::with_policy_and_cache(policy, Arc::new(CachedProvider))
.with_history_bias_k(3);
for idx in 0..3 {
authority.record_outcome(&ExecutionOutcome {
stage_id: "test-stage".to_string(),
target: ResolvedTarget::Device,
latency_ms: 10,
success: false,
error: Some(format!("failure-{idx}")),
category: Some(OutcomeCategory::HardFail {
reason: "local_failed".to_string(),
}),
model_id: Some("test-model".to_string()),
signal_context: Some(signal()),
});
}
let decision = authority.resolve_target(&text_context());
assert_eq!(decision.result, ResolvedTarget::Device);
assert!(decision.reason.contains("policy_deny"));
}
// Recording more distinct models than `MAX_HYSTERESIS_KEYS` never grows the
// hysteresis map past that cap.
#[test]
fn hysteresis_map_stays_bounded() {
let authority = LocalAuthority::with_cache_provider(Arc::new(CachedProvider));
for idx in 0..(MAX_HYSTERESIS_KEYS + 32) {
authority.record_abort_for_hysteresis_default_ttl(
&format!("model-{idx}"),
AbortReason::StressMemory,
);
}
assert!(
authority.hysteresis.lock().unwrap().len() <= MAX_HYSTERESIS_KEYS,
"hysteresis should stay bounded"
);
}
// Recording outcomes for more distinct models than `MAX_RELIABILITY_KEYS`
// never grows the reliability map past that cap.
#[test]
fn reliability_map_stays_bounded() {
let authority = LocalAuthority::with_cache_provider(Arc::new(CachedProvider));
for idx in 0..(MAX_RELIABILITY_KEYS + 32) {
authority.record_outcome(&ExecutionOutcome {
stage_id: "test-stage".to_string(),
target: ResolvedTarget::Device,
latency_ms: 10,
success: false,
error: Some("local_failed".to_string()),
category: Some(OutcomeCategory::HardFail {
reason: "local_failed".to_string(),
}),
model_id: Some(format!("model-{idx}")),
signal_context: Some(signal()),
});
}
assert!(
authority.reliability.lock().unwrap().len() <= MAX_RELIABILITY_KEYS,
"reliability should stay bounded"
);
}
// After 33 outcomes for one key, the history holds exactly
// `RELIABILITY_WINDOW` entries: the oldest (failure-0) is gone and the rest
// remain in insertion order.
#[test]
fn reliability_window_evicts_oldest_after_32_entries() {
let authority = LocalAuthority::with_cache_provider(Arc::new(CachedProvider));
for idx in 0..33 {
authority.record_outcome(&ExecutionOutcome {
stage_id: "test-stage".to_string(),
target: ResolvedTarget::Device,
latency_ms: 10,
success: false,
error: Some(format!("failure-{idx}")),
category: Some(OutcomeCategory::HardFail {
reason: format!("failure-{idx}"),
}),
model_id: Some("test-model".to_string()),
signal_context: Some(signal()),
});
}
let history = authority.history_snapshot("test-model", signal());
assert_eq!(history.len(), RELIABILITY_WINDOW);
let actual_reasons: Vec<String> = history
.iter()
.map(|c| match c {
OutcomeCategory::HardFail { reason } => reason.clone(),
other => panic!("expected HardFail, got {other:?}"),
})
.collect();
let expected_reasons: Vec<String> = (1..=RELIABILITY_WINDOW)
.map(|i| format!("failure-{i}"))
.collect();
assert_eq!(
actual_reasons, expected_reasons,
"history must contain failure-1..failure-{RELIABILITY_WINDOW} in FIFO order (oldest first)"
);
}
// Test helper: records one device `HardFail` for `(model, sig)` with the given
// reason, then returns the history snapshot for that same key.
fn record_hard_fail_and_snapshot(
authority: &LocalAuthority,
model: &str,
sig: SignalContext,
reason: &str,
) -> std::collections::VecDeque<OutcomeCategory> {
authority.record_outcome(&ExecutionOutcome {
stage_id: "test-stage".to_string(),
target: ResolvedTarget::Device,
latency_ms: 10,
success: false,
error: Some(reason.to_string()),
category: Some(OutcomeCategory::HardFail {
reason: reason.to_string(),
}),
model_id: Some(model.to_string()),
signal_context: Some(sig),
});
authority.history_snapshot(model, sig)
}
// Signals differing only in memory pressure keep separate histories.
#[test]
fn reliability_history_is_scoped_by_memory_pressure() {
let authority = LocalAuthority::with_cache_provider(Arc::new(CachedProvider));
let mut warn_signal = signal();
warn_signal.memory_pressure = MemoryPressure::Warn;
let mut critical_signal = signal();
critical_signal.memory_pressure = MemoryPressure::Critical;
let warn_history =
record_hard_fail_and_snapshot(&authority, "test-model", warn_signal, "warn-failure");
let critical_history = record_hard_fail_and_snapshot(
&authority,
"test-model",
critical_signal,
"critical-failure",
);
assert_eq!(warn_history.len(), 1);
assert_eq!(critical_history.len(), 1);
assert_ne!(warn_history, critical_history);
}
// Signals differing only in thermal state keep separate histories.
#[test]
fn reliability_history_is_scoped_by_thermal_state() {
let authority = LocalAuthority::with_cache_provider(Arc::new(CachedProvider));
let mut normal_signal = signal();
normal_signal.thermal_state = ThermalState::Normal;
let mut hot_signal = signal();
hot_signal.thermal_state = ThermalState::Hot;
let normal_history =
record_hard_fail_and_snapshot(&authority, "test-model", normal_signal, "normal-fail");
let hot_history =
record_hard_fail_and_snapshot(&authority, "test-model", hot_signal, "hot-fail");
assert_eq!(normal_history.len(), 1);
assert_eq!(hot_history.len(), 1);
assert_ne!(normal_history, hot_history);
}
// Signals differing only in CPU bucket keep separate histories.
#[test]
fn reliability_history_is_scoped_by_cpu_bucket() {
let authority = LocalAuthority::with_cache_provider(Arc::new(CachedProvider));
let mut low_cpu_signal = signal();
low_cpu_signal.cpu_bucket = Some(2);
let mut high_cpu_signal = signal();
high_cpu_signal.cpu_bucket = Some(9);
let low_history =
record_hard_fail_and_snapshot(&authority, "test-model", low_cpu_signal, "low-cpu-fail");
let high_history = record_hard_fail_and_snapshot(
&authority,
"test-model",
high_cpu_signal,
"high-cpu-fail",
);
assert_eq!(low_history.len(), 1);
assert_eq!(high_history.len(), 1);
assert_ne!(low_history, high_history);
}
// Different model ids with the same signal keep separate histories.
#[test]
fn reliability_history_is_scoped_by_model_id() {
let authority = LocalAuthority::with_cache_provider(Arc::new(CachedProvider));
let history_a = record_hard_fail_and_snapshot(&authority, "model-a", signal(), "a-fail");
let history_b = record_hard_fail_and_snapshot(&authority, "model-b", signal(), "b-fail");
assert_eq!(history_a.len(), 1);
assert_eq!(history_b.len(), 1);
assert_ne!(history_a, history_b);
}
// Three consecutive local hard failures under the matching signal route to
// cloud via history bias, and the hint reports 3 samples at a 100% rate.
#[test]
fn reliability_history_bias_routes_cloud_with_hint() {
let authority =
LocalAuthority::with_cache_provider(Arc::new(CachedProvider)).with_history_bias_k(3);
for idx in 0..3 {
authority.record_outcome(&ExecutionOutcome {
stage_id: "test-stage".to_string(),
target: ResolvedTarget::Device,
latency_ms: 10,
success: false,
error: Some(format!("failure-{idx}")),
category: Some(OutcomeCategory::HardFail {
reason: "local_failed".to_string(),
}),
model_id: Some("test-model".to_string()),
signal_context: Some(signal()),
});
}
// NOTE(review): this snapshot is presumably chosen so the derived signal
// equals `signal()` (Warn / Normal / bucket 5) — confirm against
// `SignalContext::from_metrics`.
let mut snapshot = ResourceSnapshot::unknown();
snapshot.memory_pressure = MemoryPressure::Warn;
snapshot.thermal_state = ThermalState::Normal;
snapshot.cpu_pct = Some(55.0);
let authority = authority.with_resource_provider(Arc::new(FixedResourceProvider(snapshot)));
let decision = authority
.resolve_routing_decision(&text_context())
.expect("routing decision");
assert_eq!(decision.target, RouteTarget::Cloud);
assert!(decision.reason.contains("history_bias"));
assert_eq!(decision.local_reliability_hint.sample_size, 3);
assert_eq!(decision.local_reliability_hint.recent_abort_rate, 1.0);
}
// A success as the most recent of three outcomes breaks the run of failures,
// so history bias does not trigger.
#[test]
fn success_reduces_history_bias() {
let authority =
LocalAuthority::with_cache_provider(Arc::new(CachedProvider)).with_history_bias_k(3);
for category in [
OutcomeCategory::HardFail {
reason: "a".to_string(),
},
OutcomeCategory::HardFail {
reason: "b".to_string(),
},
OutcomeCategory::Success,
] {
authority.record_outcome(&ExecutionOutcome {
stage_id: "test-stage".to_string(),
target: ResolvedTarget::Device,
latency_ms: 10,
success: matches!(category, OutcomeCategory::Success),
error: None,
category: Some(category),
model_id: Some("test-model".to_string()),
signal_context: Some(signal()),
});
}
let mut snapshot = ResourceSnapshot::unknown();
snapshot.memory_pressure = MemoryPressure::Warn;
snapshot.thermal_state = ThermalState::Normal;
snapshot.cpu_pct = Some(55.0);
let authority = authority.with_resource_provider(Arc::new(FixedResourceProvider(snapshot)));
let decision = authority.resolve_target(&text_context());
assert!(!decision.reason.contains("history_bias"));
}
// A fallback id passes through `target_from_route` unchanged, and rebuilding
// a `RouteTarget::Fallback` from it renders as "fallback:model_v2" (a single
// prefix).
#[test]
fn target_from_route_round_trips_fallback_without_prefix_doubling() {
let routed =
LocalAuthority::target_from_route(RouteTarget::Fallback("model_v2".to_string()));
let endpoint = match routed {
ResolvedTarget::Server { endpoint } => endpoint,
other => panic!("expected Server target, got {other:?}"),
};
assert_eq!(endpoint, "model_v2");
let reverse = match endpoint.as_str() {
"model_v2" => RouteTarget::Fallback(endpoint.clone()),
_ => unreachable!(),
};
assert_eq!(reverse.to_json_string(), "fallback:model_v2");
assert_eq!(reverse.to_string(), "fallback:model_v2");
}
/// Returns `true` when `dir_name` contains `query`, compared case-insensitively,
/// either verbatim or after removing `-` and `_` from both names.
fn model_dir_matches(query: &str, dir_name: &str) -> bool {
    let query = query.to_lowercase();
    let dir_name = dir_name.to_lowercase();
    if dir_name.contains(&query) {
        return true;
    }
    // One pass with a char-set pattern instead of two chained `replace` calls.
    let strip_separators = |name: &str| name.replace(['-', '_'], "");
    strip_separators(&dir_name).contains(&strip_separators(&query))
}
// Exercises the name-matching rule on representative (query, directory) pairs.
// NOTE(review): this checks the rule as written in the test module, not the
// cache provider's own lookup.
#[test]
fn test_model_matching_logic() {
    let matching_pairs = [
        ("kokoro-82m", "kokoro-82m-v1.0-onnx"),
        ("kokoro-82m", "kokoro82mv10onnx"),
        ("whisper-tiny", "whisper-tiny"),
    ];
    for (query, dir_name) in matching_pairs {
        assert!(
            model_dir_matches(query, dir_name),
            "Expected '{}' to match '{}' but it didn't",
            query,
            dir_name
        );
    }
    // Unrelated names must not match.
    assert!(!model_dir_matches("whisper-tiny", "kokoro-82m-v1.0-onnx"));
}
}