use crate::device::{detect_capabilities, HardwareCapabilities, ResourceSnapshot};
pub use crate::ir::{Envelope, EnvelopeKind};
use crate::pipeline::{ExecutionTarget, IntegrationProvider, StageOptions};
/// Schema version for device-class data.
///
/// NOTE(review): presumably bumped whenever the string format produced by
/// `DeviceMetrics::canonical_device_class` changes — confirm against the
/// consumers of this constant.
pub const DEVICE_CLASS_SCHEMA_VERSION: u16 = 1;
/// Hardware capabilities of a device paired with its latest resource reading.
#[derive(Clone, Debug)]
pub struct DeviceMetrics {
    /// Detected hardware capabilities; several fields are refreshed from a
    /// live reading by `DeviceMetrics::with_live_snapshot`.
    pub capabilities: HardwareCapabilities,
    /// The resource snapshot these metrics were last built or refreshed with.
    pub resource: ResourceSnapshot,
}
impl DeviceMetrics {
    /// Returns a copy of these metrics refreshed from a live resource snapshot.
    ///
    /// `resource` is replaced by `snapshot` wholesale. The live fields of
    /// `capabilities` are overlaid from the snapshot as follows:
    /// available/total memory, CPU usage and battery level are written only
    /// when the snapshot carries a value (`Some`), so a sparse snapshot never
    /// erases previously detected values; `thermal_state` is always copied.
    pub fn with_live_snapshot(&self, snapshot: ResourceSnapshot) -> Self {
        let mut capabilities = self.capabilities.clone();

        if let Some(mb) = snapshot.available_mem_mb {
            capabilities.memory_available_mb = mb as u64;
        }
        if let Some(mb) = snapshot.total_mem_mb {
            capabilities.memory_total_mb = mb as u64;
        }
        if let Some(pct) = snapshot.cpu_pct {
            capabilities.cpu_usage_percent = pct;
        }
        if let Some(pct) = snapshot.battery_pct {
            capabilities.battery_level = pct;
        }
        // NOTE(review): unlike the optional fields above, the thermal state is
        // copied unconditionally, so a snapshot without a real thermal reading
        // replaces the detected value — confirm this is intended.
        capabilities.thermal_state = snapshot.thermal_state;

        Self {
            capabilities,
            resource: snapshot,
        }
    }

    /// Builds a stable, hyphen-separated device-class label.
    ///
    /// `arch` comes from `std::env::consts::ARCH` (the architecture this
    /// binary was compiled for), normalized by `normalized_arch`. Shapes:
    /// - Android: `android-{arch}-unknown`
    /// - macOS / Linux / Windows: `desktop-{platform}-{arch}-{accelerator}`,
    ///   where the accelerator is the NPU type if present, else the GPU type
    ///   if present, else `cpu`
    /// - iOS and any other platform: `unknown-{platform}-{arch}`
    pub fn canonical_device_class(&self) -> String {
        let caps = &self.capabilities;
        let platform = caps.platform.as_str();
        let arch = normalized_arch(std::env::consts::ARCH);

        match platform {
            "android" => format!("android-{arch}-unknown"),
            "ios" => format!("unknown-ios-{arch}"),
            "macos" | "linux" | "windows" => {
                // An NPU takes precedence over a GPU; neither means CPU-only.
                let accelerator = match (caps.has_npu, caps.has_gpu) {
                    (true, _) => caps.npu_type.as_str(),
                    (false, true) => caps.gpu_type.as_str(),
                    (false, false) => "cpu",
                };
                format!("desktop-{platform}-{arch}-{accelerator}")
            }
            other => format!("unknown-{other}-{arch}"),
        }
    }
}
/// Normalizes a Rust target-architecture name into a device-class token.
///
/// The exact string `"aarch64"` becomes `"arm64"`. Any other input is ASCII
/// lowercased with every `_` turned into `-` (e.g. `"x86_64"` -> `"x86-64"`).
fn normalized_arch(arch: &str) -> String {
    if arch == "aarch64" {
        return String::from("arm64");
    }
    arch.chars()
        .map(|ch| if ch == '_' { '-' } else { ch.to_ascii_lowercase() })
        .collect()
}
impl Default for DeviceMetrics {
    /// Detects the current hardware capabilities and pairs them with
    /// `ResourceSnapshot::default()`.
    fn default() -> Self {
        let capabilities = detect_capabilities();
        let resource = ResourceSnapshot::default();
        Self {
            capabilities,
            resource,
        }
    }
}
/// Declarative description of one pipeline stage, assembled with the
/// `with_*` builder methods on `StageDescriptor`.
#[derive(Clone, Debug)]
pub struct StageDescriptor {
    /// Stage name.
    pub name: String,
    /// Path to a local bundle; a stage without one is never locally runnable.
    pub bundle_path: Option<String>,
    /// Requested execution target; `None` leaves the choice unspecified.
    pub target: Option<ExecutionTarget>,
    /// Cloud integration provider, if the stage is served remotely.
    pub provider: Option<IntegrationProvider>,
    /// Optional model name (not interpreted by the `StageDescriptor` methods).
    pub model: Option<String>,
    /// Optional stage options (not interpreted by the `StageDescriptor` methods).
    pub options: Option<StageOptions>,
}
impl StageDescriptor {
    /// Creates a stage called `name` with every optional setting unset.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            bundle_path: None,
            target: None,
            provider: None,
            model: None,
            options: None,
        }
    }

    /// Sets the local bundle path backing this stage.
    pub fn with_bundle_path(self, path: impl Into<String>) -> Self {
        Self {
            bundle_path: Some(path.into()),
            ..self
        }
    }

    /// Sets the requested execution target, replacing any previous one.
    pub fn with_target(self, target: ExecutionTarget) -> Self {
        Self {
            target: Some(target),
            ..self
        }
    }

    /// Attaches a cloud integration provider.
    ///
    /// This also forces the target to `ExecutionTarget::Cloud`, overriding any
    /// target set earlier; a later `with_target` call can still change it.
    pub fn with_provider(self, provider: IntegrationProvider) -> Self {
        Self {
            provider: Some(provider),
            target: Some(ExecutionTarget::Cloud),
            ..self
        }
    }

    /// Sets the model name.
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self {
            model: Some(model.into()),
            ..self
        }
    }

    /// Sets the stage options.
    pub fn with_options(self, options: StageOptions) -> Self {
        Self {
            options: Some(options),
            ..self
        }
    }

    /// Returns `true` when the stage has a bundle path and its target (if any)
    /// allows local execution. An unset target counts as allowing it.
    pub fn is_locally_runnable(&self) -> bool {
        let target_allows_local = match &self.target {
            Some(target) => target.allows_local(),
            None => true,
        };
        target_allows_local && self.bundle_path.is_some()
    }

    /// Returns `true` when a provider is attached or the target is `Cloud`.
    pub fn is_cloud(&self) -> bool {
        if self.provider.is_some() {
            return true;
        }
        matches!(self.target, Some(ExecutionTarget::Cloud))
    }

    /// Returns `true` when no provider is attached and the target is `Device`
    /// or unset. Other targets (e.g. `Auto`) are neither device nor cloud
    /// according to these two predicates.
    pub fn is_device(&self) -> bool {
        if self.provider.is_some() {
            return false;
        }
        matches!(self.target, None | Some(ExecutionTarget::Device))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{GpuType, MemoryPressure, Platform, ThermalState};

    #[test]
    fn default_device_metrics_carries_unknown_memory_pressure() {
        let metrics = DeviceMetrics::default();
        assert_eq!(metrics.resource.memory_pressure, MemoryPressure::Unknown);
    }

    #[test]
    fn stage_local_runnability_requires_bundle_path() {
        let stage = StageDescriptor::new("test-model");
        assert!(!stage.is_locally_runnable());
    }

    #[test]
    fn stage_local_runnability_respects_network_target() {
        let stage = StageDescriptor::new("test-model")
            .with_bundle_path("/tmp/test-model")
            .with_target(ExecutionTarget::Cloud);
        assert!(!stage.is_locally_runnable());
    }

    #[test]
    fn stage_local_runnability_allows_auto_with_bundle_path() {
        let stage = StageDescriptor::new("test-model")
            .with_bundle_path("/tmp/test-model")
            .with_target(ExecutionTarget::Auto);
        assert!(stage.is_locally_runnable());
    }

    #[test]
    fn stage_without_target_or_provider_is_device_not_cloud() {
        let stage = StageDescriptor::new("test-model");
        assert!(stage.is_device());
        assert!(!stage.is_cloud());
    }

    #[test]
    fn stage_with_device_target_is_device_not_cloud() {
        let stage = StageDescriptor::new("test-model").with_target(ExecutionTarget::Device);
        assert!(stage.is_device());
        assert!(!stage.is_cloud());
    }

    #[test]
    fn stage_with_cloud_target_is_cloud_not_device() {
        let stage = StageDescriptor::new("test-model").with_target(ExecutionTarget::Cloud);
        assert!(stage.is_cloud());
        assert!(!stage.is_device());
    }

    #[test]
    fn normalized_arch_maps_aarch64_and_kebab_cases_others() {
        assert_eq!(normalized_arch("aarch64"), "arm64");
        assert_eq!(normalized_arch("x86_64"), "x86-64");
        assert_eq!(normalized_arch("RISCV64"), "riscv64");
    }

    #[test]
    fn with_live_snapshot_preserves_capability_memory_when_snapshot_missing() {
        let metrics = DeviceMetrics::default();
        let snapshot = ResourceSnapshot::unknown();
        let merged = metrics.with_live_snapshot(snapshot);
        assert!(
            merged.capabilities.memory_total_mb > 0,
            "memory_total_mb must not be zeroed by an unknown snapshot"
        );
    }

    #[test]
    fn with_live_snapshot_uses_snapshot_memory_when_present() {
        let metrics = DeviceMetrics::default();
        let mut snapshot = ResourceSnapshot::unknown();
        snapshot.memory_pressure = MemoryPressure::Normal;
        snapshot.available_mem_mb = Some(2048);
        snapshot.total_mem_mb = Some(8192);
        snapshot.cpu_pct = Some(42.5);
        let merged = metrics.with_live_snapshot(snapshot);
        assert_eq!(merged.capabilities.memory_available_mb, 2048);
        assert_eq!(merged.capabilities.memory_total_mb, 8192);
        assert_eq!(merged.capabilities.cpu_usage_percent, 42.5);
    }

    #[test]
    fn with_live_snapshot_overlays_battery_and_thermal_when_snapshot_carries_them() {
        let metrics = DeviceMetrics::default();
        let mut snapshot = ResourceSnapshot::unknown();
        snapshot.battery_pct = Some(42);
        snapshot.thermal_state = ThermalState::Hot;
        let merged = metrics.with_live_snapshot(snapshot);
        assert_eq!(merged.capabilities.battery_level, 42);
        assert_eq!(merged.capabilities.thermal_state, ThermalState::Hot);
    }

    #[test]
    fn canonical_device_class_uses_stable_desktop_capability_bucket() {
        let mut metrics = DeviceMetrics::default();
        metrics.capabilities.platform = Platform::MacOS;
        metrics.capabilities.has_npu = false;
        metrics.capabilities.has_gpu = true;
        metrics.capabilities.gpu_type = GpuType::Metal;
        let class = metrics.canonical_device_class();
        assert!(class.starts_with("desktop-macos-"));
        assert!(class.ends_with("-metal"));
    }

    #[test]
    fn canonical_device_class_uses_android_unknown_fallback() {
        let mut metrics = DeviceMetrics::default();
        metrics.capabilities.platform = Platform::Android;
        let class = metrics.canonical_device_class();
        assert!(class.starts_with("android-"));
        assert!(class.ends_with("-unknown"));
    }
}