use crate::context::DeviceMetrics;
use crate::device::MemoryPressure;
use crate::orchestrator::policy_engine::PolicyResult;
use crate::telemetry::{should_log, Severity};
use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::time::{SystemTime, UNIX_EPOCH};
/// Maximum number of distinct stage keys kept in `DefaultRoutingEngine::cpu_history`.
/// When a stage that is not yet tracked arrives while the map already holds this many
/// entries, one existing entry is removed before the new stage is inserted.
const CPU_HISTORY_MAX_STAGES: usize = 256;
/// Where a stage is routed.
#[derive(Debug, Clone, PartialEq)]
pub enum RouteTarget {
    /// Rendered as `"local"`.
    Local,
    /// Rendered as `"cloud"`.
    Cloud,
    /// Carries an identifier; rendered as `"fallback"` by [`RouteTarget::as_str`]
    /// and as `"fallback:<id>"` by [`RouteTarget::to_json_string`] and `Display`.
    Fallback(String),
}

impl RouteTarget {
    /// Returns the bare label of the variant: `"local"`, `"cloud"`, or `"fallback"`.
    ///
    /// The identifier stored in `Fallback` is not part of the label.
    pub fn as_str(&self) -> &str {
        match *self {
            Self::Local => "local",
            Self::Cloud => "cloud",
            Self::Fallback(..) => "fallback",
        }
    }

    /// Returns the owned string form used in JSON payloads.
    ///
    /// `Local` and `Cloud` yield their label; `Fallback(id)` yields `"fallback:<id>"`.
    pub fn to_json_string(&self) -> String {
        let mut rendered = String::from(self.as_str());
        if let Self::Fallback(id) = self {
            rendered.push(':');
            rendered.push_str(id);
        }
        rendered
    }
}

impl fmt::Display for RouteTarget {
    /// Writes the same text that [`RouteTarget::to_json_string`] returns.
    /// Width and fill flags on the formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())?;
        match self {
            Self::Fallback(id) => {
                f.write_str(":")?;
                f.write_str(id)
            }
            Self::Local | Self::Cloud => Ok(()),
        }
    }
}
/// Small, copyable summary attached to every [`RoutingDecision`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct LocalReliabilityHint {
    /// NOTE(review): presumably the fraction of recent local runs that aborted;
    /// the producer is not visible in this file — confirm the range and meaning.
    pub recent_abort_rate: f32,
    /// NOTE(review): presumably the number of runs behind `recent_abort_rate` — confirm.
    pub sample_size: u32,
}

impl LocalReliabilityHint {
    /// Hint with a zero rate and zero samples; equal to `LocalReliabilityHint::default()`.
    pub const EMPTY: LocalReliabilityHint = LocalReliabilityHint {
        sample_size: 0,
        recent_abort_rate: 0.0,
    };
}
#[derive(Debug, Clone)]
pub struct RoutingDecision {
pub stage: String,
pub target: RouteTarget,
pub reason: String,
pub timestamp_ms: u64,
pub local_reliability_hint: LocalReliabilityHint,
}
impl RoutingDecision {
pub fn to_json(&self) -> String {
#[derive(serde::Serialize)]
struct LocalReliabilityHintWire {
recent_abort_rate: f32,
sample_size: u32,
}
#[derive(serde::Serialize)]
struct RoutingDecisionWire<'a> {
stage: &'a str,
target: String,
reason: &'a str,
timestamp_ms: u64,
local_reliability_hint: LocalReliabilityHintWire,
}
serde_json::to_string(&RoutingDecisionWire {
stage: &self.stage,
target: self.target.to_json_string(),
reason: &self.reason,
timestamp_ms: self.timestamp_ms,
local_reliability_hint: LocalReliabilityHintWire {
recent_abort_rate: self.local_reliability_hint.recent_abort_rate,
sample_size: self.local_reliability_hint.sample_size,
},
})
.unwrap_or_else(|_| String::from("{}"))
}
}
/// Tells the routing engine whether a local model can be used.
#[derive(Debug, Clone)]
pub struct LocalAvailability {
    /// `true` when a local model is present.
    pub local_model_exists: bool,
}

impl LocalAvailability {
    /// Builds an availability record from a single presence flag.
    pub fn new(exists: bool) -> Self {
        let local_model_exists = exists;
        Self { local_model_exists }
    }
}
/// Strategy interface that selects a [`RouteTarget`] for a stage.
pub trait RoutingEngine {
/// Produces a [`RoutingDecision`] for `stage` from the supplied device metrics,
/// policy result, and local availability.
///
/// Takes `&mut self`, so an implementation may update internal state on each call.
fn decide(
&mut self,
stage: &str,
metrics: &DeviceMetrics,
policy: &PolicyResult,
availability: &LocalAvailability,
) -> RoutingDecision;
/// Receives a decision together with a latency value in milliseconds.
///
/// NOTE(review): presumably `latency_ms` is the observed execution latency of
/// `decision` — confirm against callers; the implementation in this file ignores both.
fn record_feedback(&mut self, decision: &RoutingDecision, latency_ms: u32);
}
/// Rule-based routing engine.
///
/// It keeps a short per-stage window of "CPU hot" samples so that CPU load only
/// counts as sustained after several consecutive hot readings.
pub struct DefaultRoutingEngine {
    /// Number of consecutive hot samples required before a stage's CPU load counts
    /// as sustained. Never below 1.
    cpu_sustain_samples: usize,
    /// A CPU reading at or above this percentage counts as hot.
    cpu_sustain_threshold_pct: f32,
    /// Per-stage window of the most recent hot / not-hot samples, oldest first.
    /// Holds at most `CPU_HISTORY_MAX_STAGES` stages.
    cpu_history: HashMap<String, VecDeque<bool>>,
}

impl DefaultRoutingEngine {
    /// Creates an engine that needs 2 consecutive samples at or above 95% CPU
    /// before reporting sustained load, with no history recorded yet.
    pub fn new() -> Self {
        Self {
            cpu_sustain_samples: 2,
            cpu_sustain_threshold_pct: 95.0,
            cpu_history: HashMap::new(),
        }
    }

    /// Sets how many consecutive hot samples are required.
    ///
    /// A value of 0 is raised to 1 so that the window can never be empty.
    pub fn with_cpu_sustain_samples(mut self, samples: usize) -> Self {
        self.cpu_sustain_samples = samples.max(1);
        self
    }

    /// Returns the current wall-clock time as milliseconds since the Unix epoch.
    ///
    /// Yields 0 if the system clock reads earlier than the epoch, and saturates at
    /// `u64::MAX` instead of wrapping if the millisecond count does not fit in `u64`.
    fn current_timestamp_ms() -> u64 {
        let millis = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis();
        // `as_millis` returns `u128`. Clamp before narrowing so the cast below is
        // lossless; a bare `as u64` would silently discard the high bits.
        millis.min(u128::from(u64::MAX)) as u64
    }

    /// Prints the decision as one JSON line on stdout when `Severity::Info`
    /// logging is enabled.
    fn log_decision(&self, decision: &RoutingDecision) {
        if should_log(Severity::Info) {
            println!("{}", decision.to_json());
        }
    }

    /// Records one CPU sample for `stage` and reports whether the load is sustained.
    ///
    /// A missing reading (`None`) is recorded as not hot. Returns `true` only when
    /// the stage's window holds exactly `cpu_sustain_samples` entries and every one
    /// of them is hot.
    fn cpu_is_sustained(&mut self, stage: &str, cpu_pct: Option<f32>) -> bool {
        let window = self.cpu_sustain_samples;
        let is_hot = matches!(cpu_pct, Some(pct) if pct >= self.cpu_sustain_threshold_pct);

        // Keep the map bounded: make room before a previously unseen stage is added.
        // The victim is whichever key the map yields first, i.e. an arbitrary stage.
        let is_new_stage = !self.cpu_history.contains_key(stage);
        if is_new_stage && self.cpu_history.len() >= CPU_HISTORY_MAX_STAGES {
            if let Some(victim) = self.cpu_history.keys().next().cloned() {
                self.cpu_history.remove(&victim);
            }
        }

        let history = self.cpu_history.entry(stage.to_string()).or_default();
        history.push_back(is_hot);
        // Drop the oldest samples so only the newest `window` remain.
        while history.len() > window {
            history.pop_front();
        }

        history.len() == window && history.iter().all(|hot| *hot)
    }

    /// Builds a [`RoutingDecision`] carrying `LocalReliabilityHint::EMPTY`.
    fn decision(
        stage: &str,
        target: RouteTarget,
        reason: impl Into<String>,
        timestamp_ms: u64,
    ) -> RoutingDecision {
        RoutingDecision {
            stage: stage.to_string(),
            target,
            reason: reason.into(),
            timestamp_ms,
            local_reliability_hint: LocalReliabilityHint::EMPTY,
        }
    }
}
impl Default for DefaultRoutingEngine {
fn default() -> Self {
Self::new()
}
}
impl RoutingEngine for DefaultRoutingEngine {
fn decide(
&mut self,
stage: &str,
metrics: &DeviceMetrics,
policy: &PolicyResult,
availability: &LocalAvailability,
) -> RoutingDecision {
let timestamp_ms = Self::current_timestamp_ms();
if !policy.allowed {
let reason = format!(
"policy_deny: {}",
policy
.reason
.as_deref()
.unwrap_or("policy denied cloud execution")
);
let decision = Self::decision(stage, RouteTarget::Local, reason, timestamp_ms);
self.log_decision(&decision);
return decision;
}
if !availability.local_model_exists {
let decision = Self::decision(
stage,
RouteTarget::Cloud,
"model_unavailable: local model not found",
timestamp_ms,
);
self.log_decision(&decision);
return decision;
}
if metrics.capabilities.should_throttle() {
let reason = format!(
"stress_throttle: battery {}%, thermal {:?}",
metrics.capabilities.battery_level(),
metrics.capabilities.thermal_state()
);
let decision = Self::decision(stage, RouteTarget::Cloud, reason, timestamp_ms);
self.log_decision(&decision);
return decision;
}
if metrics.resource.memory_pressure == MemoryPressure::Critical {
let reason = "stress_memory: memory pressure critical".to_string();
let decision = Self::decision(stage, RouteTarget::Cloud, reason, timestamp_ms);
self.log_decision(&decision);
return decision;
}
if self.cpu_is_sustained(stage, metrics.resource.cpu_pct) {
let reason = format!(
"stress_cpu_sustained: CPU >= {:.0}% for {} samples",
self.cpu_sustain_threshold_pct, self.cpu_sustain_samples
);
let decision = Self::decision(stage, RouteTarget::Cloud, reason, timestamp_ms);
self.log_decision(&decision);
return decision;
}
let decision = Self::decision(stage, RouteTarget::Local, "default_local", timestamp_ms);
self.log_decision(&decision);
decision
}
fn record_feedback(&mut self, _decision: &RoutingDecision, _latency_ms: u32) {
}
}
#[cfg(test)]
mod tests {
    use super::super::policy_engine::PolicyResult;
    use super::*;
    use crate::device::{HardwareCapabilities, MemoryPressure, ResourceSnapshot, ThermalState};

    /// Stage name shared by the single-stage tests.
    const STAGE: &str = "test_stage";
    /// Timestamp used by the hand-built decisions in the JSON tests.
    const FIXED_TIMESTAMP_MS: u64 = 1730559412312;

    /// Builds metrics whose capabilities and resource snapshot agree on the given
    /// battery level and thermal state.
    fn metrics_with_live_state(
        battery: u8,
        thermal_state: ThermalState,
        memory_pressure: MemoryPressure,
        cpu_pct: Option<f32>,
    ) -> DeviceMetrics {
        let mut resource = ResourceSnapshot::unknown();
        resource.battery_pct = Some(battery);
        resource.thermal_state = thermal_state;
        resource.cpu_pct = cpu_pct;
        resource.memory_pressure = memory_pressure;

        let capabilities = HardwareCapabilities {
            battery_level: battery,
            thermal_state,
            ..Default::default()
        };

        DeviceMetrics {
            capabilities,
            resource,
        }
    }

    /// Metrics with 50% battery, normal thermal state and normal memory pressure,
    /// varying only the CPU reading.
    fn calm_metrics(cpu_pct: Option<f32>) -> DeviceMetrics {
        metrics_with_live_state(50, ThermalState::Normal, MemoryPressure::Normal, cpu_pct)
    }

    /// Policy result that allows execution.
    fn allow_policy() -> PolicyResult {
        PolicyResult::allow(Some("policy passed".to_string()))
    }

    /// Runs `decide` with an allowing policy and the given local-model presence.
    fn decide_allowed(
        engine: &mut DefaultRoutingEngine,
        stage: &str,
        metrics: &DeviceMetrics,
        model_present: bool,
    ) -> RoutingDecision {
        engine.decide(
            stage,
            metrics,
            &allow_policy(),
            &LocalAvailability::new(model_present),
        )
    }

    /// Hand-built decision with the fixed timestamp, for serialization tests.
    fn fixed_decision(
        stage: &str,
        target: RouteTarget,
        reason: &str,
        hint: LocalReliabilityHint,
    ) -> RoutingDecision {
        RoutingDecision {
            stage: stage.to_string(),
            target,
            reason: reason.to_string(),
            timestamp_ms: FIXED_TIMESTAMP_MS,
            local_reliability_hint: hint,
        }
    }

    #[test]
    fn test_policy_deny_routes_local() {
        let mut engine = DefaultRoutingEngine::new();
        let denied = PolicyResult::deny("test policy denial".to_string());
        let decision = engine.decide(
            STAGE,
            &DeviceMetrics::default(),
            &denied,
            &LocalAvailability::new(true),
        );
        assert_eq!(decision.target, RouteTarget::Local);
        assert!(decision.reason.contains("policy_deny"));
        assert_eq!(decision.stage, "test_stage");
    }

    #[test]
    fn stress_throttle_routes_cloud() {
        let mut engine = DefaultRoutingEngine::new();
        let low_battery =
            metrics_with_live_state(10, ThermalState::Normal, MemoryPressure::Normal, None);
        let decision = decide_allowed(&mut engine, STAGE, &low_battery, true);
        assert_eq!(decision.target, RouteTarget::Cloud);
        assert!(decision.reason.contains("stress_throttle"));
    }

    #[test]
    fn critical_memory_routes_cloud() {
        let mut engine = DefaultRoutingEngine::new();
        let starved =
            metrics_with_live_state(50, ThermalState::Normal, MemoryPressure::Critical, None);
        let decision = decide_allowed(&mut engine, STAGE, &starved, true);
        assert_eq!(decision.target, RouteTarget::Cloud);
        assert!(decision.reason.contains("stress_memory"));
    }

    #[test]
    fn warm_but_not_stressed_conditions_do_not_fire_stress_branch() {
        let mut engine = DefaultRoutingEngine::new();
        let warm = metrics_with_live_state(25, ThermalState::Normal, MemoryPressure::Warn, None);
        let decision = decide_allowed(&mut engine, STAGE, &warm, true);
        assert!(!decision.reason.contains("stress_"));
    }

    #[test]
    fn sustained_cpu_routes_cloud_on_second_hot_sample() {
        let mut engine = DefaultRoutingEngine::new();
        let hot = calm_metrics(Some(96.0));
        let first = decide_allowed(&mut engine, STAGE, &hot, true);
        let second = decide_allowed(&mut engine, STAGE, &hot, true);
        assert!(!first.reason.contains("stress_cpu_sustained"));
        assert_eq!(second.target, RouteTarget::Cloud);
        assert!(second.reason.contains("stress_cpu_sustained"));
    }

    #[test]
    fn model_unavailable_overrides_stress() {
        let mut engine = DefaultRoutingEngine::new();
        let stressed =
            metrics_with_live_state(10, ThermalState::Hot, MemoryPressure::Critical, Some(99.0));
        let decision = decide_allowed(&mut engine, STAGE, &stressed, false);
        assert_eq!(decision.target, RouteTarget::Cloud);
        assert!(decision.reason.contains("model_unavailable"));
    }

    #[test]
    fn sustained_cpu_history_is_stage_scoped() {
        let mut engine = DefaultRoutingEngine::new();
        let hot = calm_metrics(Some(96.0));
        let stage_a_first = decide_allowed(&mut engine, "stage_a", &hot, true);
        let stage_b_first = decide_allowed(&mut engine, "stage_b", &hot, true);
        let stage_a_second = decide_allowed(&mut engine, "stage_a", &hot, true);
        assert!(!stage_a_first.reason.contains("stress_cpu_sustained"));
        assert!(!stage_b_first.reason.contains("stress_cpu_sustained"));
        assert!(stage_a_second.reason.contains("stress_cpu_sustained"));
    }

    #[test]
    fn sustained_cpu_history_map_stays_bounded() {
        let mut engine = DefaultRoutingEngine::new();
        let hot = calm_metrics(Some(96.0));
        let stage_count = CPU_HISTORY_MAX_STAGES + 32;
        for idx in 0..stage_count {
            let stage = format!("stage-{}", idx);
            let _ = decide_allowed(&mut engine, &stage, &hot, true);
        }
        assert!(
            engine.cpu_history.len() <= CPU_HISTORY_MAX_STAGES,
            "CPU history should stay bounded"
        );
    }

    #[test]
    fn sustained_cpu_resets_when_sample_drops_below_threshold() {
        let mut engine = DefaultRoutingEngine::new();
        let hot = calm_metrics(Some(96.0));
        let cool = calm_metrics(Some(20.0));
        let _ = decide_allowed(&mut engine, STAGE, &hot, true);
        let _ = decide_allowed(&mut engine, STAGE, &cool, true);
        let decision = decide_allowed(&mut engine, STAGE, &hot, true);
        assert!(!decision.reason.contains("stress_cpu_sustained"));
    }

    #[test]
    fn test_missing_model_routes_cloud() {
        let mut engine = DefaultRoutingEngine::new();
        let decision = decide_allowed(&mut engine, STAGE, &DeviceMetrics::default(), false);
        assert_eq!(decision.target, RouteTarget::Cloud);
        assert!(decision.reason.contains("model_unavailable"));
    }

    #[test]
    fn default_unstressed_conditions_route_local() {
        let mut engine = DefaultRoutingEngine::new();
        let relaxed =
            metrics_with_live_state(80, ThermalState::Normal, MemoryPressure::Normal, Some(20.0));
        let decision = decide_allowed(&mut engine, STAGE, &relaxed, true);
        assert_eq!(decision.target, RouteTarget::Local);
        assert!(decision.reason.contains("default_local"));
    }

    #[test]
    fn test_routing_decision_json_format() {
        let json = fixed_decision(
            "motivator",
            RouteTarget::Cloud,
            "low network latency (110ms)",
            LocalReliabilityHint::EMPTY,
        )
        .to_json();
        assert!(json.contains("\"stage\":\"motivator\""));
        assert!(json.contains("\"target\":\"cloud\""));
        assert!(json.contains("\"reason\":\"low network latency (110ms)\""));
        assert!(json.contains("\"timestamp_ms\":1730559412312"));
    }

    #[test]
    fn routing_decision_json_includes_local_reliability_hint() {
        let decision = fixed_decision(
            "stage-1",
            RouteTarget::Cloud,
            "history_bias",
            LocalReliabilityHint {
                recent_abort_rate: 1.0,
                sample_size: 3,
            },
        );
        let parsed: serde_json::Value =
            serde_json::from_str(&decision.to_json()).expect("to_json output must be valid JSON");
        let hint = parsed
            .get("local_reliability_hint")
            .expect("hint must be present at the payload top level");
        assert_eq!(hint["recent_abort_rate"].as_f64(), Some(1.0));
        assert_eq!(hint["sample_size"].as_u64(), Some(3));
        assert!(
            hint["sample_size"].is_u64(),
            "sample_size must serialize as integer (got {:?})",
            hint["sample_size"]
        );
    }

    #[test]
    fn routing_decision_json_emits_empty_local_reliability_hint() {
        let decision = fixed_decision(
            "stage-1",
            RouteTarget::Local,
            "default_local",
            LocalReliabilityHint::EMPTY,
        );
        let parsed: serde_json::Value =
            serde_json::from_str(&decision.to_json()).expect("to_json output must be valid JSON");
        let hint = parsed
            .get("local_reliability_hint")
            .expect("hint must be present even when EMPTY");
        assert_eq!(hint["recent_abort_rate"].as_f64(), Some(0.0));
        assert_eq!(hint["sample_size"].as_u64(), Some(0));
    }

    #[test]
    fn routing_decision_json_escapes_special_characters_in_reason() {
        let decision = fixed_decision(
            "stage-1",
            RouteTarget::Cloud,
            r#"hysteresis: recent local abort for model 'weird-"model' (stress_memory)"#,
            LocalReliabilityHint::EMPTY,
        );
        let json = decision.to_json();
        let parsed: serde_json::Value =
            serde_json::from_str(&json).expect("to_json output must be valid JSON");
        assert_eq!(parsed["stage"], "stage-1");
        assert_eq!(
            parsed["reason"],
            r#"hysteresis: recent local abort for model 'weird-"model' (stress_memory)"#
        );
    }

    #[test]
    fn test_route_target_to_json_string() {
        assert_eq!(RouteTarget::Local.to_json_string(), "local");
        assert_eq!(RouteTarget::Cloud.to_json_string(), "cloud");
        let fallback = RouteTarget::Fallback("model_v2".to_string());
        assert_eq!(fallback.to_json_string(), "fallback:model_v2");
    }

    #[test]
    fn test_route_target_display() {
        assert_eq!(format!("{}", RouteTarget::Local), "local");
        assert_eq!(format!("{}", RouteTarget::Cloud), "cloud");
        assert_eq!(
            format!("{}", RouteTarget::Fallback("model_v2".to_string())),
            "fallback:model_v2"
        );
        assert_eq!(RouteTarget::Local.to_string(), "local");
        assert_eq!(RouteTarget::Cloud.to_string(), "cloud");
        assert_eq!(
            RouteTarget::Fallback("backup".to_string()).to_string(),
            "fallback:backup"
        );
    }

    #[test]
    fn boundary_battery_throttle_threshold() {
        let mut engine = DefaultRoutingEngine::new();

        let below =
            metrics_with_live_state(19, ThermalState::Normal, MemoryPressure::Normal, None);
        let decision = decide_allowed(&mut engine, STAGE, &below, true);
        assert_eq!(decision.target, RouteTarget::Cloud);
        assert!(decision.reason.contains("stress_throttle"));

        let at_threshold =
            metrics_with_live_state(20, ThermalState::Normal, MemoryPressure::Normal, None);
        let decision = decide_allowed(&mut engine, STAGE, &at_threshold, true);
        assert_eq!(decision.target, RouteTarget::Local);
    }
}