use super::provider::IntegrationProvider;
use super::stage::{FallbackConfig, StageConfig};
use super::target::ExecutionTarget;
use crate::context::DeviceMetrics;
use crate::device::capabilities::HardwareCapabilities;
use crate::device::MemoryPressure;
use crate::orchestrator::routing_engine::{
LocalAvailability, LocalReliabilityHint, RouteTarget, RoutingDecision,
};
use std::time::{SystemTime, UNIX_EPOCH};
/// Snapshot of device state and target availability consulted by [`TargetResolver`].
#[derive(Debug, Clone)]
pub struct ResolutionContext {
    /// Device metrics at resolution time; the auto policy reads `resource.memory_pressure`.
    pub metrics: DeviceMetrics,
    /// Whether on-device execution may be selected (`new` defaults this to `false`).
    pub local_available: bool,
    /// Whether the server target may be selected (`new` defaults this to `true`).
    pub server_available: bool,
    /// Per-provider availability overrides. Providers absent from the map are treated as
    /// available by `is_cloud_available`.
    pub integration_available: std::collections::HashMap<IntegrationProvider, bool>,
    /// Hardware capabilities used for throttle and accelerator checks; `new` copies these
    /// from `metrics.capabilities`.
    pub capabilities: HardwareCapabilities,
}
/// Construction and builder helpers for [`ResolutionContext`].
impl ResolutionContext {
    /// Builds a context from `metrics` with local execution disabled, the server enabled,
    /// and no per-provider overrides. `capabilities` is a copy of `metrics.capabilities`.
    pub fn new(metrics: DeviceMetrics) -> Self {
        let hw = metrics.capabilities.clone();
        Self {
            integration_available: std::collections::HashMap::new(),
            local_available: false,
            server_available: true,
            capabilities: hw,
            metrics,
        }
    }

    /// Returns the context with `local_available` replaced by `available`.
    pub fn with_local_available(self, available: bool) -> Self {
        Self {
            local_available: available,
            ..self
        }
    }

    /// Returns the context with `server_available` replaced by `available`.
    pub fn with_server_available(self, available: bool) -> Self {
        Self {
            server_available: available,
            ..self
        }
    }

    /// Records an availability override for `provider`, replacing any earlier entry.
    pub fn with_integration_available(
        mut self,
        provider: IntegrationProvider,
        available: bool,
    ) -> Self {
        self.integration_available.insert(provider, available);
        self
    }

    /// Whether `provider` may be used. Providers without an explicit override are
    /// treated as available (optimistic default).
    pub fn is_cloud_available(&self, provider: &IntegrationProvider) -> bool {
        self.integration_available
            .get(provider)
            .copied()
            .unwrap_or(true)
    }
}
/// Outcome of a successful resolution: where the stage runs and why.
#[derive(Debug, Clone)]
pub struct ResolvedTarget {
    /// Routing target; both server and integration resolutions map to `RouteTarget::Cloud`.
    pub target: RouteTarget,
    /// Human-readable explanation; server/integration constructors add a prefix.
    pub reason: String,
    /// Integration provider; only the `integration` constructor sets this to `Some`.
    pub provider: Option<IntegrationProvider>,
    /// Model identifier to run.
    pub model: String,
    /// Requested model version; the `integration` constructor always leaves this `None`.
    pub version: Option<String>,
}
/// Constructors for each kind of resolved target, plus conversion to a [`RoutingDecision`].
impl ResolvedTarget {
    /// On-device target for `model`; `reason` is stored verbatim.
    pub fn local(model: &str, version: Option<&str>, reason: &str) -> Self {
        Self {
            model: model.to_owned(),
            version: version.map(str::to_owned),
            target: RouteTarget::Local,
            provider: None,
            reason: reason.to_owned(),
        }
    }

    /// Server target for `model`, routed as `RouteTarget::Cloud`; `reason` gets a
    /// `"server: "` prefix.
    pub fn server(model: &str, version: Option<&str>, reason: &str) -> Self {
        Self {
            model: model.to_owned(),
            version: version.map(str::to_owned),
            target: RouteTarget::Cloud,
            provider: None,
            reason: format!("server: {reason}"),
        }
    }

    /// Third-party integration target, routed as `RouteTarget::Cloud`. `reason` gets an
    /// `"integration/<provider>: "` prefix and no version is recorded.
    pub fn integration(provider: IntegrationProvider, model: &str, reason: &str) -> Self {
        Self {
            reason: format!("integration/{provider}: {reason}"),
            model: model.to_owned(),
            version: None,
            target: RouteTarget::Cloud,
            provider: Some(provider),
        }
    }

    /// Fallback target that carries the model name inside `RouteTarget::Fallback`.
    pub fn fallback(model: &str, version: Option<&str>, reason: &str) -> Self {
        Self {
            model: model.to_owned(),
            version: version.map(str::to_owned),
            target: RouteTarget::Fallback(model.to_owned()),
            provider: None,
            reason: reason.to_owned(),
        }
    }

    /// Converts this resolution into a [`RoutingDecision`] for `stage`, stamped with the
    /// current wall-clock time in milliseconds and carrying `local_reliability_hint`.
    pub fn to_routing_decision(
        &self,
        stage: &str,
        local_reliability_hint: LocalReliabilityHint,
    ) -> RoutingDecision {
        let timestamp_ms = current_timestamp_ms();
        RoutingDecision {
            local_reliability_hint,
            timestamp_ms,
            reason: self.reason.clone(),
            target: self.target.clone(),
            stage: stage.to_owned(),
        }
    }
}
/// Stateless policy object that turns a [`StageConfig`] into a [`ResolvedTarget`].
pub struct TargetResolver;

impl TargetResolver {
    /// Resolves `stage` against `context`.
    ///
    /// The primary target (`stage.target`, honouring `stage.prefer` and `stage.provider`)
    /// is tried first. If it fails, each entry of `stage.fallback` is tried in order and
    /// the first one that resolves is returned.
    ///
    /// # Errors
    ///
    /// Returns the primary target's error when the primary and every fallback fail.
    /// Errors produced by fallback entries are discarded.
    pub fn resolve(
        stage: &StageConfig,
        context: &ResolutionContext,
    ) -> Result<ResolvedTarget, ResolutionError> {
        let primary = Self::resolve_target(
            &stage.target,
            &stage.model,
            stage.version.as_deref(),
            stage.provider.as_ref(),
            stage.prefer.as_ref(),
            context,
        );
        let primary_err = match primary {
            Ok(resolved) => return Ok(resolved),
            Err(err) => err,
        };
        stage
            .fallback
            .iter()
            .find_map(|entry| Self::resolve_fallback(entry, &stage.model, context).ok())
            .ok_or(primary_err)
    }

    /// Dispatches on `target`. `Cloud` requires `provider`; `Auto` delegates to the auto
    /// policy with `prefer` as an optional first choice.
    fn resolve_target(
        target: &ExecutionTarget,
        model: &str,
        version: Option<&str>,
        provider: Option<&IntegrationProvider>,
        prefer: Option<&ExecutionTarget>,
        context: &ResolutionContext,
    ) -> Result<ResolvedTarget, ResolutionError> {
        match target {
            ExecutionTarget::Auto => Self::resolve_auto(model, version, provider, prefer, context),
            ExecutionTarget::Device => Self::resolve_device(model, version, context),
            ExecutionTarget::Server => Self::resolve_server(model, version, context),
            ExecutionTarget::Cloud => provider
                .ok_or_else(|| ResolutionError::MissingProvider(model.to_string()))
                .and_then(|p| Self::resolve_integration(p, model, context)),
        }
    }

    /// On-device target. Fails if local execution is unavailable or the capabilities
    /// report that the device should be throttled.
    fn resolve_device(
        model: &str,
        version: Option<&str>,
        context: &ResolutionContext,
    ) -> Result<ResolvedTarget, ResolutionError> {
        if !context.local_available {
            Err(ResolutionError::DeviceUnavailable(model.to_string()))
        } else if context.capabilities.should_throttle() {
            Err(ResolutionError::DeviceThrottled(
                "device is thermal throttled".to_string(),
            ))
        } else {
            Ok(ResolvedTarget::local(model, version, "device target resolved"))
        }
    }

    /// Server target. Fails only when the server is marked unavailable.
    fn resolve_server(
        model: &str,
        version: Option<&str>,
        context: &ResolutionContext,
    ) -> Result<ResolvedTarget, ResolutionError> {
        if context.server_available {
            Ok(ResolvedTarget::server(model, version, "server target resolved"))
        } else {
            Err(ResolutionError::ServerUnavailable(model.to_string()))
        }
    }

    /// Integration target for `provider`. Fails when the context marks it unavailable.
    fn resolve_integration(
        provider: &IntegrationProvider,
        model: &str,
        context: &ResolutionContext,
    ) -> Result<ResolvedTarget, ResolutionError> {
        if context.is_cloud_available(provider) {
            Ok(ResolvedTarget::integration(
                *provider,
                model,
                "integration target resolved",
            ))
        } else {
            Err(ResolutionError::IntegrationUnavailable(
                *provider,
                model.to_string(),
            ))
        }
    }

    /// Auto policy. Candidates are tried in this order:
    ///
    /// 1. the explicitly preferred target, if any;
    /// 2. the device, when available and `should_prefer_device` approves;
    /// 3. the stage's integration provider, when given and available;
    /// 4. the server;
    /// 5. the device as a last resort. The throttle check is not repeated here.
    fn resolve_auto(
        model: &str,
        version: Option<&str>,
        provider: Option<&IntegrationProvider>,
        prefer: Option<&ExecutionTarget>,
        context: &ResolutionContext,
    ) -> Result<ResolvedTarget, ResolutionError> {
        if let Some(preferred) = prefer {
            // `prefer` is cleared on the nested call so a preferred `Auto` cannot recurse
            // into this branch again.
            let attempt = Self::resolve_target(preferred, model, version, provider, None, context);
            if let Ok(resolved) = attempt {
                return Ok(resolved);
            }
        }

        if context.local_available && Self::should_prefer_device(context) {
            return Ok(ResolvedTarget::local(
                model,
                version,
                "auto: device preferred (good conditions)",
            ));
        }

        if let Some(&reachable) = provider.filter(|p| context.is_cloud_available(p)) {
            return Ok(ResolvedTarget::integration(
                reachable,
                model,
                "auto: integration available",
            ));
        }

        if context.server_available {
            Ok(ResolvedTarget::server(model, version, "auto: server available"))
        } else if context.local_available {
            Ok(ResolvedTarget::local(model, version, "auto: fallback to device"))
        } else {
            Err(ResolutionError::NoTargetAvailable(model.to_string()))
        }
    }

    /// Whether the device conditions are good enough to pick it ahead of remote options.
    /// The device must not need throttling, memory pressure must not be `Critical`, and
    /// at least one accelerator (GPU, Metal, or NNAPI) must be reported.
    fn should_prefer_device(context: &ResolutionContext) -> bool {
        let caps = &context.capabilities;
        if caps.should_throttle() {
            return false;
        }
        if context.metrics.resource.memory_pressure == MemoryPressure::Critical {
            return false;
        }
        caps.has_gpu || caps.has_metal || caps.has_nnapi
    }

    /// Resolves one fallback entry. Its own model, version and provider are used, and
    /// the model defaults to `original_model` when the entry does not override it.
    fn resolve_fallback(
        fallback: &FallbackConfig,
        original_model: &str,
        context: &ResolutionContext,
    ) -> Result<ResolvedTarget, ResolutionError> {
        Self::resolve_target(
            &fallback.target,
            fallback.model.as_deref().unwrap_or(original_model),
            fallback.version.as_deref(),
            fallback.provider.as_ref(),
            None,
            context,
        )
    }
}
/// Reasons a stage could not be mapped to an execution target.
#[derive(Debug, Clone)]
pub enum ResolutionError {
    /// Local execution is disabled in the context; carries the model name.
    DeviceUnavailable(String),
    /// The device capabilities asked for throttling; carries a textual reason.
    DeviceThrottled(String),
    /// The server target is disabled in the context; carries the model name.
    ServerUnavailable(String),
    /// Round-trip time in milliseconds (per the `Display` impl). Not produced by
    /// `TargetResolver` in this file.
    NetworkTooSlow(u32),
    /// The given provider is marked unavailable; carries the provider and model name.
    IntegrationUnavailable(IntegrationProvider, String),
    /// A `Cloud` target was requested without a provider; carries the model name.
    MissingProvider(String),
    /// The auto policy found no usable target; carries the model name.
    NoTargetAvailable(String),
}
impl std::fmt::Display for ResolutionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ResolutionError::DeviceUnavailable(model) => {
write!(f, "device model '{}' not available", model)
}
ResolutionError::DeviceThrottled(reason) => {
write!(f, "device throttled: {}", reason)
}
ResolutionError::ServerUnavailable(model) => {
write!(f, "server model '{}' not available", model)
}
ResolutionError::NetworkTooSlow(rtt) => {
write!(f, "network too slow ({}ms RTT)", rtt)
}
ResolutionError::IntegrationUnavailable(provider, model) => {
write!(
f,
"integration provider '{}' unavailable for model '{}'",
provider, model
)
}
ResolutionError::MissingProvider(model) => {
write!(f, "integration target for '{}' requires a provider", model)
}
ResolutionError::NoTargetAvailable(model) => {
write!(f, "no execution target available for model '{}'", model)
}
}
}
}
/// Lets `ResolutionError` be used wherever a `std::error::Error` is expected; the message
/// comes from the `Display` impl and no `source` is reported.
impl std::error::Error for ResolutionError {}
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`. `Duration::as_millis` returns a `u128`,
/// so the value is clamped to `u64::MAX` before narrowing. That makes the final cast
/// lossless instead of silently truncating the high bits.
fn current_timestamp_ms() -> u64 {
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    // Lossless: `millis` has just been clamped into the `u64` range.
    millis.min(u128::from(u64::MAX)) as u64
}
impl From<&LocalAvailability> for bool {
fn from(availability: &LocalAvailability) -> bool {
availability.local_model_exists
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default device metrics with on-device execution enabled; everything else keeps
    /// the `ResolutionContext::new` defaults.
    fn test_context() -> ResolutionContext {
        let ctx = ResolutionContext::new(DeviceMetrics::default());
        ctx.with_local_available(true)
    }

    #[test]
    fn test_resolve_device_target() {
        let ctx = test_context();
        let asr = StageConfig::new("asr", "wav2vec2-base-960h")
            .with_target(ExecutionTarget::Device)
            .with_version("1.0");
        let resolved = TargetResolver::resolve(&asr, &ctx).expect("device target resolves");
        assert_eq!(resolved.target, RouteTarget::Local);
        assert_eq!(resolved.model, "wav2vec2-base-960h");
        assert_eq!(resolved.version, Some("1.0".to_string()));
    }

    #[test]
    fn test_resolve_device_unavailable() {
        let ctx = test_context().with_local_available(false);
        let asr =
            StageConfig::new("asr", "wav2vec2-base-960h").with_target(ExecutionTarget::Device);
        let outcome = TargetResolver::resolve(&asr, &ctx);
        assert!(matches!(
            outcome,
            Err(ResolutionError::DeviceUnavailable(_))
        ));
    }

    #[test]
    fn test_resolve_server_target() {
        let ctx = test_context();
        let asr = StageConfig::new("asr", "whisper-large-v3")
            .with_target(ExecutionTarget::Server)
            .with_version("1.0");
        let resolved = TargetResolver::resolve(&asr, &ctx).expect("server target resolves");
        assert_eq!(resolved.target, RouteTarget::Cloud);
        assert!(resolved.reason.contains("server"));
    }

    #[test]
    fn test_resolve_integration_target() {
        let ctx = test_context().with_integration_available(IntegrationProvider::OpenAI, true);
        let llm =
            StageConfig::new("llm", "gpt-4o-mini").with_provider(IntegrationProvider::OpenAI);
        let resolved = TargetResolver::resolve(&llm, &ctx).expect("integration resolves");
        assert_eq!(resolved.target, RouteTarget::Cloud);
        assert_eq!(resolved.provider, Some(IntegrationProvider::OpenAI));
        assert!(resolved.reason.contains("integration"));
    }

    #[test]
    fn test_resolve_integration_missing_provider() {
        let ctx = test_context();
        let llm = StageConfig::new("llm", "gpt-4o-mini").with_target(ExecutionTarget::Cloud);
        let outcome = TargetResolver::resolve(&llm, &ctx);
        assert!(matches!(outcome, Err(ResolutionError::MissingProvider(_))));
    }

    #[test]
    fn test_resolve_auto_prefers_device() {
        let ctx = test_context();
        let asr =
            StageConfig::new("asr", "wav2vec2-base-960h").with_target(ExecutionTarget::Auto);
        let resolved = TargetResolver::resolve(&asr, &ctx).expect("auto resolves");
        // Default metrics may or may not report an accelerator, so either outcome is valid.
        assert!(matches!(
            resolved.target,
            RouteTarget::Local | RouteTarget::Cloud
        ));
    }

    #[test]
    fn test_resolve_auto_with_preference() {
        let ctx = test_context();
        let mut tts = StageConfig::new("tts", "piper-en-us").with_target(ExecutionTarget::Auto);
        tts.prefer = Some(ExecutionTarget::Device);
        let resolved = TargetResolver::resolve(&tts, &ctx).expect("preferred device resolves");
        assert_eq!(resolved.target, RouteTarget::Local);
    }

    #[test]
    fn test_resolve_with_fallback_chain() {
        let ctx = test_context()
            .with_local_available(false)
            .with_integration_available(IntegrationProvider::OpenAI, true);
        let whisper_api = FallbackConfig::new(ExecutionTarget::Cloud)
            .with_provider(IntegrationProvider::OpenAI)
            .with_model("whisper-1");
        let asr = StageConfig::new("asr", "whisper-large-v3")
            .with_target(ExecutionTarget::Device)
            .with_fallback(whisper_api);
        let resolved = TargetResolver::resolve(&asr, &ctx).expect("fallback resolves");
        assert_eq!(resolved.target, RouteTarget::Cloud);
        assert_eq!(resolved.provider, Some(IntegrationProvider::OpenAI));
        assert_eq!(resolved.model, "whisper-1");
    }

    #[test]
    fn test_resolution_context_builder() {
        let ctx = ResolutionContext::new(DeviceMetrics::default())
            .with_local_available(true)
            .with_server_available(true)
            .with_integration_available(IntegrationProvider::OpenAI, true)
            .with_integration_available(IntegrationProvider::Anthropic, false);
        assert!(ctx.local_available);
        assert!(ctx.server_available);
        assert!(ctx.is_cloud_available(&IntegrationProvider::OpenAI));
        assert!(!ctx.is_cloud_available(&IntegrationProvider::Anthropic));
        // No override recorded for Google, so the optimistic default applies.
        assert!(ctx.is_cloud_available(&IntegrationProvider::Google));
    }

    #[test]
    fn test_resolved_target_to_routing_decision() {
        let local = ResolvedTarget::local("wav2vec2", Some("1.0"), "test reason");
        let decision = local.to_routing_decision(
            "asr",
            crate::orchestrator::routing_engine::LocalReliabilityHint::EMPTY,
        );
        assert_eq!(decision.stage, "asr");
        assert_eq!(decision.target, RouteTarget::Local);
        assert_eq!(decision.reason, "test reason");
        assert!(decision.timestamp_ms > 0);
    }

    #[test]
    fn resolved_target_to_routing_decision_carries_local_reliability_hint() {
        let hint = crate::orchestrator::routing_engine::LocalReliabilityHint {
            recent_abort_rate: 0.75,
            sample_size: 4,
        };
        let local = ResolvedTarget::local("wav2vec2", Some("1.0"), "test reason");
        let decision = local.to_routing_decision("asr", hint);
        assert_eq!(decision.local_reliability_hint, hint);
    }
}