use serde::{Deserialize, Serialize};
#[cfg(feature = "coreml")]
use crate::backends::{AneCapabilities, ComputeUnits};
use crate::backends::{DType, DeviceType, Quantization};
use crate::kernels::AttentionConfig;
/// Operating-system family of the compile target.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Platform {
    /// Apple desktop/laptop (Darwin).
    MacOS,
    /// Linux, excluding Android (which sets `target_os = "android"`).
    Linux,
    Windows,
    /// Any WebAssembly target (browser or WASI host).
    Wasm,
    IOS,
    Android,
    /// Compile target not covered by the variants above.
    Unknown,
}
impl Default for Platform {
fn default() -> Self {
Self::detect()
}
}
impl Platform {
    /// Identify the platform from compile-time target cfg values.
    ///
    /// Exactly one of the cfg blocks below survives compilation for any given
    /// target (the cfgs are mutually exclusive), and that surviving block is
    /// the function's tail expression.
    pub fn detect() -> Self {
        #[cfg(target_os = "macos")]
        {
            Self::MacOS
        }
        // Android sets `target_os = "android"` (never "linux"), so no nested
        // Android check is needed inside this branch.
        #[cfg(target_os = "linux")]
        {
            Self::Linux
        }
        #[cfg(target_os = "windows")]
        {
            Self::Windows
        }
        #[cfg(target_arch = "wasm32")]
        {
            Self::Wasm
        }
        #[cfg(target_os = "ios")]
        {
            Self::IOS
        }
        #[cfg(target_os = "android")]
        {
            Self::Android
        }
        #[cfg(not(any(
            target_os = "macos",
            target_os = "linux",
            target_os = "windows",
            target_arch = "wasm32",
            target_os = "ios",
            target_os = "android"
        )))]
        {
            Self::Unknown
        }
    }
    /// Whether this platform offers any GPU compute path at all.
    /// Wasm counts via WebGPU; Android is currently not supported here.
    pub fn supports_gpu(&self) -> bool {
        matches!(
            self,
            Self::MacOS | Self::Linux | Self::Windows | Self::IOS | Self::Wasm
        )
    }
    /// The GPU backend conventionally used on this platform, if any.
    pub fn default_gpu_backend(&self) -> Option<GpuBackend> {
        match self {
            Self::MacOS | Self::IOS => Some(GpuBackend::Metal),
            Self::Linux | Self::Windows => Some(GpuBackend::Cuda),
            Self::Wasm => Some(GpuBackend::WebGPU),
            Self::Android | Self::Unknown => None,
        }
    }
}
/// CPU instruction-set architecture of the compile target.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Architecture {
    Aarch64,
    X86_64,
    Wasm32,
    /// Compile target not covered by the variants above.
    Unknown,
}
impl Default for Architecture {
fn default() -> Self {
Self::detect()
}
}
impl Architecture {
    /// Identify the ISA from compile-time cfg values.
    ///
    /// Exactly one cfg block survives compilation per target and becomes the
    /// function's tail expression.
    pub fn detect() -> Self {
        #[cfg(target_arch = "aarch64")]
        {
            Self::Aarch64
        }
        #[cfg(target_arch = "x86_64")]
        {
            Self::X86_64
        }
        #[cfg(target_arch = "wasm32")]
        {
            Self::Wasm32
        }
        #[cfg(not(any(
            target_arch = "aarch64",
            target_arch = "x86_64",
            target_arch = "wasm32"
        )))]
        {
            Self::Unknown
        }
    }
    /// True when the ISA has baseline 128-bit SIMD (NEON / SSE).
    /// Note: wasm32's 128-bit SIMD proposal is deliberately not counted.
    pub fn has_simd(&self) -> bool {
        matches!(self, Self::Aarch64 | Self::X86_64)
    }
}
/// SIMD capability flags detected for the running CPU.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct CpuFeatures {
    /// ARM NEON (architecturally guaranteed on aarch64).
    pub neon: bool,
    /// x86 AVX2 (256-bit).
    pub avx2: bool,
    /// x86 AVX-512 Foundation (512-bit).
    pub avx512: bool,
    /// x86 SSE4.2 (128-bit).
    pub sse42: bool,
    /// ARM Scalable Vector Extension — never set by `detect` today.
    pub sve: bool,
    /// ARM SVE2 — never set by `detect` today.
    pub sve2: bool,
}
impl CpuFeatures {
    /// Detect SIMD features of the host CPU.
    ///
    /// - aarch64: NEON is architecturally guaranteed; SVE/SVE2 probing is not
    ///   implemented, so those flags stay `false`.
    /// - x86_64: SSE4.2 and AVX2 are probed at runtime (the std macro
    ///   short-circuits to `true` when the feature is enabled at compile
    ///   time). AVX-512 is only reported when enabled at compile time,
    ///   matching the original conservative behavior.
    pub fn detect() -> Self {
        let mut features = Self::default();
        #[cfg(target_arch = "aarch64")]
        {
            features.neon = true;
            // TODO: runtime SVE/SVE2 detection (e.g. via HWCAP on Linux);
            // both flags remain at their default `false` until then.
        }
        #[cfg(target_arch = "x86_64")]
        {
            features.sse42 = Self::detect_sse42_runtime();
            features.avx2 = Self::detect_avx2_runtime();
            #[cfg(target_feature = "avx512f")]
            {
                features.avx512 = true;
            }
        }
        features
    }
    /// Runtime AVX2 probe. `is_x86_feature_detected!` already evaluates to
    /// `true` when the feature is enabled at compile time, so no cfg
    /// gymnastics are required.
    #[cfg(target_arch = "x86_64")]
    fn detect_avx2_runtime() -> bool {
        std::arch::is_x86_feature_detected!("avx2")
    }
    /// Runtime SSE4.2 probe (see `detect_avx2_runtime` for rationale).
    #[cfg(target_arch = "x86_64")]
    fn detect_sse42_runtime() -> bool {
        std::arch::is_x86_feature_detected!("sse4.2")
    }
    /// Widest usable SIMD register width in bits (0 = scalar only).
    pub fn best_simd_width(&self) -> usize {
        if self.avx512 {
            512
        } else if self.avx2 {
            256
        } else if self.neon || self.sse42 {
            128
        } else {
            0
        }
    }
    /// Number of f32 lanes in the widest SIMD register (0 = scalar only).
    pub fn simd_float_lanes(&self) -> usize {
        self.best_simd_width() / 32
    }
}
/// GPU compute API families this crate can target.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GpuBackend {
    Metal,
    Cuda,
    WebGPU,
    /// Not produced by detection today; reserved for future use.
    Vulkan,
    /// Not produced by detection today; reserved for future use.
    OpenCL,
}
/// What is known about the GPU selected for compute.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GpuCapabilities {
    pub backend: GpuBackend,
    /// VRAM budget in MiB; `None` when unknown (e.g. unified memory).
    pub vram_mb: Option<usize>,
    pub compute_units: Option<usize>,
    /// Human-readable device name, when the driver reports one.
    pub name: Option<String>,
    pub supports_fp16: bool,
    pub supports_int8: bool,
    pub has_tensor_cores: bool,
    /// Per-threadgroup/block shared memory in bytes, when known.
    pub max_shared_memory: Option<usize>,
}
impl GpuCapabilities {
    /// Best-effort GPU detection for the current platform/feature set.
    /// Returns `None` when no usable GPU is found.
    pub fn detect() -> Option<Self> {
        #[cfg(all(target_os = "macos", feature = "metal-compute"))]
        {
            return Self::detect_metal();
        }
        // macOS without the metal-compute feature: assume a Metal GPU exists
        // but report conservative placeholder capabilities.
        #[cfg(all(target_os = "macos", not(feature = "metal-compute")))]
        {
            return Some(Self {
                backend: GpuBackend::Metal,
                vram_mb: None,
                compute_units: None,
                name: Some("Apple GPU (metal-compute feature not enabled)".to_string()),
                supports_fp16: true,
                supports_int8: true,
                has_tensor_cores: false,
                max_shared_memory: Some(32 * 1024),
            });
        }
        // iOS devices always have a Metal-capable GPU.
        #[cfg(target_os = "ios")]
        {
            return Some(Self {
                backend: GpuBackend::Metal,
                vram_mb: None,
                compute_units: None,
                name: Some("Apple GPU (iOS)".to_string()),
                supports_fp16: true,
                supports_int8: true,
                has_tensor_cores: false,
                max_shared_memory: Some(32 * 1024),
            });
        }
        #[cfg(any(target_os = "linux", target_os = "windows"))]
        {
            if let Some(cuda) = Self::detect_cuda() {
                return Some(cuda);
            }
        }
        #[cfg(target_arch = "wasm32")]
        {
            return Self::detect_webgpu();
        }
        // Reached only when every probe above failed or was compiled out.
        #[allow(unreachable_code)]
        None
    }
    /// Query Metal device info through the crate's Metal bindings.
    #[cfg(all(target_os = "macos", feature = "metal-compute"))]
    fn detect_metal() -> Option<Self> {
        use crate::metal::{get_device_info, is_metal_available};
        if !is_metal_available() {
            return None;
        }
        match get_device_info() {
            Some(info) => {
                // Unified memory is used as a proxy for Apple Silicon here —
                // presumably true for all shipping devices; TODO confirm.
                let is_apple_silicon = info.has_unified_memory;
                Some(Self {
                    backend: GpuBackend::Metal,
                    vram_mb: Some(info.recommended_max_working_set_size / (1024 * 1024)),
                    compute_units: Some(info.max_threads_per_threadgroup),
                    name: Some(info.name),
                    supports_fp16: is_apple_silicon,
                    supports_int8: true,
                    has_tensor_cores: is_apple_silicon,
                    max_shared_memory: Some(32 * 1024),
                })
            }
            // Metal is available but the device query failed: fall back to
            // generic Apple GPU capabilities.
            None => Some(Self {
                backend: GpuBackend::Metal,
                vram_mb: None,
                compute_units: None,
                name: Some("Apple GPU".to_string()),
                supports_fp16: true,
                supports_int8: true,
                has_tensor_cores: false,
                max_shared_memory: Some(32 * 1024),
            }),
        }
    }
    /// CUDA detection is not implemented yet; always reports no GPU.
    #[cfg(any(target_os = "linux", target_os = "windows"))]
    fn detect_cuda() -> Option<Self> {
        None
    }
    /// Assumes a WebGPU adapter is reachable from wasm; values are
    /// conservative defaults, not queried from the browser.
    #[cfg(target_arch = "wasm32")]
    fn detect_webgpu() -> Option<Self> {
        Some(Self {
            backend: GpuBackend::WebGPU,
            vram_mb: None,
            compute_units: None,
            name: Some("WebGPU (browser)".to_string()),
            supports_fp16: true,
            supports_int8: false,
            has_tensor_cores: false,
            max_shared_memory: Some(16 * 1024),
        })
    }
    /// Whether the model (plus ~20% headroom) fits in VRAM.
    /// Unknown VRAM (e.g. unified memory) is optimistically treated as a fit.
    pub fn can_fit_model(&self, model_size_gb: f32) -> bool {
        if let Some(vram_mb) = self.vram_mb {
            let vram_gb = vram_mb as f32 / 1024.0;
            vram_gb >= model_size_gb * 1.2
        } else {
            true
        }
    }
}
/// CPU core topology.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CoreInfo {
    pub physical_cores: usize,
    /// Includes SMT/hyper-threads.
    pub logical_cores: usize,
    /// Apple-silicon performance cores; `None` when the split is unknown.
    pub performance_cores: Option<usize>,
    /// Apple-silicon efficiency cores; `None` when the split is unknown.
    pub efficiency_cores: Option<usize>,
}
impl Default for CoreInfo {
fn default() -> Self {
Self::detect()
}
}
impl CoreInfo {
    /// Detect core counts; on macOS additionally split P/E cores via sysctl.
    pub fn detect() -> Self {
        let logical_cores = Self::detect_logical_cores();
        let physical_cores = Self::detect_physical_cores(logical_cores);
        #[cfg(target_os = "macos")]
        {
            let (perf, eff) = Self::detect_apple_cores();
            return Self {
                physical_cores,
                logical_cores,
                performance_cores: perf,
                efficiency_cores: eff,
            };
        }
        #[cfg(not(target_os = "macos"))]
        Self {
            physical_cores,
            logical_cores,
            performance_cores: None,
            efficiency_cores: None,
        }
    }
    /// Logical CPU count; falls back to 1 if the query fails.
    fn detect_logical_cores() -> usize {
        std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1)
    }
    /// Physical core count, with per-OS strategies.
    /// NOTE(review): the Linux fallback and the Windows path assume 2-way SMT
    /// (`logical / 2`), which undercounts on machines without SMT — confirm.
    fn detect_physical_cores(logical: usize) -> usize {
        #[cfg(target_os = "macos")]
        {
            Self::sysctl_physical_cores().unwrap_or(logical)
        }
        #[cfg(target_os = "linux")]
        {
            Self::linux_physical_cores().unwrap_or(logical / 2).max(1)
        }
        #[cfg(target_os = "windows")]
        {
            (logical / 2).max(1)
        }
        #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
        {
            logical
        }
    }
    /// `sysctl -n hw.physicalcpu` on macOS.
    #[cfg(target_os = "macos")]
    fn sysctl_physical_cores() -> Option<usize> {
        use std::process::Command;
        let output = Command::new("sysctl")
            .args(["-n", "hw.physicalcpu"])
            .output()
            .ok()?;
        String::from_utf8_lossy(&output.stdout).trim().parse().ok()
    }
    /// Count unique (physical id, core id) pairs from /proc/cpuinfo.
    #[cfg(target_os = "linux")]
    fn linux_physical_cores() -> Option<usize> {
        use std::fs;
        let cpuinfo = fs::read_to_string("/proc/cpuinfo").ok()?;
        let mut cores = std::collections::HashSet::new();
        let mut physical_id = None;
        let mut core_id = None;
        for line in cpuinfo.lines() {
            if line.starts_with("physical id") {
                physical_id = line
                    .split(':')
                    .nth(1)
                    .and_then(|s| s.trim().parse::<usize>().ok());
            } else if line.starts_with("core id") {
                core_id = line
                    .split(':')
                    .nth(1)
                    .and_then(|s| s.trim().parse::<usize>().ok());
            }
            // A (socket, core) pair identifies one physical core; reset both
            // after recording so the next processor entry starts fresh.
            if let (Some(pid), Some(cid)) = (physical_id, core_id) {
                cores.insert((pid, cid));
                physical_id = None;
                core_id = None;
            }
        }
        if cores.is_empty() {
            // Kernel did not expose topology (e.g. some VMs): fall back to
            // counting "processor" entries, i.e. logical CPUs.
            Some(
                cpuinfo
                    .lines()
                    .filter(|l| l.starts_with("processor"))
                    .count(),
            )
        } else {
            Some(cores.len())
        }
    }
    /// Apple-silicon P/E split via `hw.perflevel{0,1}.physicalcpu`.
    #[cfg(target_os = "macos")]
    fn detect_apple_cores() -> (Option<usize>, Option<usize>) {
        use std::process::Command;
        let perf = Command::new("sysctl")
            .args(["-n", "hw.perflevel0.physicalcpu"])
            .output()
            .ok()
            .and_then(|o| String::from_utf8_lossy(&o.stdout).trim().parse().ok());
        let eff = Command::new("sysctl")
            .args(["-n", "hw.perflevel1.physicalcpu"])
            .output()
            .ok()
            .and_then(|o| String::from_utf8_lossy(&o.stdout).trim().parse().ok());
        (perf, eff)
    }
    /// Thread count for compute pools: P-cores only when the split is known
    /// (avoids scheduling heavy work on E-cores), else all physical cores.
    pub fn recommended_threads(&self) -> usize {
        if let Some(perf) = self.performance_cores {
            perf
        } else {
            self.physical_cores
        }
    }
}
/// Apple Neural Engine availability and capability estimate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AneInfo {
    pub available: bool,
    /// Estimated throughput in TOPS (0.0 when unavailable).
    pub tops: f32,
    /// Largest model (MB) considered suitable for ANE residency.
    pub max_model_size_mb: usize,
    /// Operator names the ANE path is expected to handle.
    pub supported_ops: Vec<String>,
}
impl Default for AneInfo {
fn default() -> Self {
Self::detect()
}
}
impl AneInfo {
    /// Detect the Apple Neural Engine.
    ///
    /// Reported as present on all Apple-silicon Macs (macOS + aarch64) and
    /// unavailable everywhere else, including Intel Macs.
    pub fn detect() -> Self {
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        {
            Self {
                available: true,
                tops: Self::detect_ane_tops(),
                // Conservative practical limit for ANE-resident models.
                max_model_size_mb: 2048,
                supported_ops: vec![
                    "MatMul".to_string(),
                    "Conv2D".to_string(),
                    "GELU".to_string(),
                    "SiLU".to_string(),
                    "LayerNorm".to_string(),
                    "Softmax".to_string(),
                    "Add".to_string(),
                    "Mul".to_string(),
                ],
            }
        }
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        {
            Self {
                available: false,
                tops: 0.0,
                max_model_size_mb: 0,
                supported_ops: vec![],
            }
        }
    }
    /// Estimate ANE throughput (TOPS) from the CPU brand string.
    ///
    /// Within a generation only the Ultra parts (two fused dies, hence two
    /// ANEs) differ from their siblings, so the per-tier branches collapse.
    /// Newest generations are checked first so e.g. "M4" never falls through
    /// to an older tier.
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    fn detect_ane_tops() -> f32 {
        use std::process::Command;
        if let Ok(output) = Command::new("sysctl")
            .args(["-n", "machdep.cpu.brand_string"])
            .output()
        {
            let brand = String::from_utf8_lossy(&output.stdout).to_lowercase();
            if brand.contains("m4") {
                return 38.0;
            }
            if brand.contains("m3") {
                return 18.0;
            }
            if brand.contains("m2") {
                return if brand.contains("ultra") { 31.6 } else { 15.8 };
            }
            if brand.contains("m1") {
                return if brand.contains("ultra") { 22.0 } else { 11.0 };
            }
        }
        // Unknown Apple Silicon: assume at least an M1-class ANE.
        11.0
    }
    /// Whether the whole model can live on the ANE.
    pub fn is_model_suitable(&self, model_size_mb: usize) -> bool {
        self.available && model_size_mb <= self.max_model_size_mb
    }
    /// Placement policy: small models run fully on the ANE, mid-size models
    /// split across ANE and GPU, anything larger bypasses the ANE entirely.
    pub fn recommended_strategy(&self, model_size_mb: usize) -> AneStrategy {
        if !self.available {
            return AneStrategy::GpuOnly;
        }
        if model_size_mb <= 500 {
            AneStrategy::AneOnly
        } else if model_size_mb <= self.max_model_size_mb {
            AneStrategy::Hybrid
        } else {
            AneStrategy::GpuOnly
        }
    }
}
/// How to place a model relative to the Apple Neural Engine.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AneStrategy {
    /// Run the whole model on the ANE.
    AneOnly,
    /// Split the model between ANE and GPU.
    Hybrid,
    /// Bypass the ANE entirely.
    GpuOnly,
}
/// Everything detected about the host, used to pick inference settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemCapabilities {
    pub platform: Platform,
    pub arch: Architecture,
    pub cpu_features: CpuFeatures,
    /// `None` when no usable GPU was found.
    pub gpu: Option<GpuCapabilities>,
    pub ane: AneInfo,
    /// Total physical RAM in MiB.
    pub memory_mb: usize,
    /// Currently-available RAM in MiB; `None` when not measurable
    /// (consumers fall back to `memory_mb / 2`).
    pub available_memory_mb: Option<usize>,
    pub cores: CoreInfo,
}
impl Default for SystemCapabilities {
fn default() -> Self {
Self::detect()
}
}
impl SystemCapabilities {
    /// Probe the host once and collect everything needed for backend selection.
    pub fn detect() -> Self {
        Self {
            platform: Platform::detect(),
            arch: Architecture::detect(),
            cpu_features: CpuFeatures::detect(),
            gpu: GpuCapabilities::detect(),
            ane: AneInfo::detect(),
            memory_mb: Self::detect_total_memory(),
            available_memory_mb: Self::detect_available_memory(),
            cores: CoreInfo::detect(),
        }
    }
    /// Total physical RAM in MiB; falls back to 8 GiB on desktop OSes when
    /// detection fails, and 4 GiB on wasm/unknown targets.
    fn detect_total_memory() -> usize {
        #[cfg(target_os = "macos")]
        {
            Self::macos_total_memory().unwrap_or(8 * 1024)
        }
        #[cfg(target_os = "linux")]
        {
            Self::linux_total_memory().unwrap_or(8 * 1024)
        }
        #[cfg(target_os = "windows")]
        {
            Self::windows_total_memory().unwrap_or(8 * 1024)
        }
        #[cfg(target_arch = "wasm32")]
        {
            4 * 1024
        }
        #[cfg(not(any(
            target_os = "macos",
            target_os = "linux",
            target_os = "windows",
            target_arch = "wasm32"
        )))]
        {
            4 * 1024
        }
    }
    /// Currently-available RAM in MiB; only implemented for Linux via
    /// `MemAvailable`. Callers treat `None` as `memory_mb / 2`.
    fn detect_available_memory() -> Option<usize> {
        #[cfg(target_os = "macos")]
        {
            // TODO: could query host statistics / `vm_stat`; unknown for now.
            None
        }
        #[cfg(target_os = "linux")]
        {
            Self::linux_available_memory()
        }
        #[cfg(not(any(target_os = "macos", target_os = "linux")))]
        {
            None
        }
    }
    /// `sysctl -n hw.memsize` (bytes), converted to MiB.
    #[cfg(target_os = "macos")]
    fn macos_total_memory() -> Option<usize> {
        use std::process::Command;
        let output = Command::new("sysctl")
            .args(["-n", "hw.memsize"])
            .output()
            .ok()?;
        let bytes: u64 = String::from_utf8_lossy(&output.stdout)
            .trim()
            .parse()
            .ok()?;
        Some((bytes / (1024 * 1024)) as usize)
    }
    /// `MemTotal` from /proc/meminfo (kB), converted to MiB.
    #[cfg(target_os = "linux")]
    fn linux_total_memory() -> Option<usize> {
        use std::fs;
        let meminfo = fs::read_to_string("/proc/meminfo").ok()?;
        for line in meminfo.lines() {
            if line.starts_with("MemTotal:") {
                let parts: Vec<&str> = line.split_whitespace().collect();
                if parts.len() >= 2 {
                    let kb: usize = parts[1].parse().ok()?;
                    return Some(kb / 1024);
                }
            }
        }
        None
    }
    /// `MemAvailable` from /proc/meminfo (kB), converted to MiB.
    #[cfg(target_os = "linux")]
    fn linux_available_memory() -> Option<usize> {
        use std::fs;
        let meminfo = fs::read_to_string("/proc/meminfo").ok()?;
        for line in meminfo.lines() {
            if line.starts_with("MemAvailable:") {
                let parts: Vec<&str> = line.split_whitespace().collect();
                if parts.len() >= 2 {
                    let kb: usize = parts[1].parse().ok()?;
                    return Some(kb / 1024);
                }
            }
        }
        None
    }
    /// Windows memory detection is not implemented yet; callers fall back
    /// to the 8 GiB default.
    #[cfg(target_os = "windows")]
    fn windows_total_memory() -> Option<usize> {
        None
    }
    /// Build a good general-purpose inference configuration, sized for a
    /// ~7 GB (7B-class) model and 2048-token sequences.
    pub fn optimal_config(&self) -> InferenceConfig {
        let compute_backend = self.select_compute_backend();
        // Assume a 7B-class model (~7 GB at f16) when picking quantization.
        let quantization = self.optimal_quantization(7.0);
        let batch_size = self.recommended_batch_size(2048);
        let thread_count = self.cores.recommended_threads();
        let block_size = self.optimal_block_size();
        InferenceConfig {
            compute_backend,
            quantization,
            batch_size,
            thread_count,
            block_size,
            use_flash_attention: true,
            device_type: self.optimal_device_type(),
            dtype: self.optimal_dtype(),
        }
    }
    /// Attention defaults sized to this machine. Head counts/dims follow a
    /// Llama-7B-like layout (32 heads, 8 KV heads, head_dim 128); callers
    /// should override those for other models.
    pub fn optimal_attention_config(&self) -> AttentionConfig {
        AttentionConfig {
            num_heads: 32,
            num_kv_heads: 8,
            head_dim: 128,
            // Memory-based sequence cap, clamped to 8K for attention buffers.
            max_seq_len: self.optimal_max_seq_len().min(8192),
            causal: true,
            // 0.0 presumably means "use the kernel default 1/sqrt(head_dim)"
            // — TODO confirm AttentionConfig's contract.
            scale: 0.0,
        }
    }
    /// Pick the least aggressive quantization that still fits the model:
    /// VRAM tiers first when VRAM is known, system-RAM tiers otherwise.
    pub fn optimal_quantization(&self, model_size_gb: f32) -> Quantization {
        let available_mb = self.available_memory_mb.unwrap_or(self.memory_mb / 2);
        let available_gb = available_mb as f32 / 1024.0;
        if let Some(ref gpu) = self.gpu {
            if let Some(vram_mb) = gpu.vram_mb {
                let vram_gb = vram_mb as f32 / 1024.0;
                if vram_gb >= model_size_gb * 1.5 {
                    return Quantization::F16;
                } else if vram_gb >= model_size_gb * 0.75 {
                    return Quantization::Q8;
                } else if vram_gb >= model_size_gb * 0.4 {
                    return Quantization::Q4K;
                }
                // Falls through to the RAM-based tiers below.
            }
        }
        if available_gb >= model_size_gb * 4.0 {
            Quantization::F16
        } else if available_gb >= model_size_gb * 1.5 {
            Quantization::Q8
        } else if available_gb >= model_size_gb * 0.6 {
            Quantization::Q4K
        } else {
            Quantization::Q4
        }
    }
    /// Batch size bounded by the KV-cache memory budget (half the available
    /// RAM), clamped to [1, 64].
    pub fn recommended_batch_size(&self, seq_len: usize) -> usize {
        let available_mb = self.available_memory_mb.unwrap_or(self.memory_mb / 2);
        // 128 bytes of KV cache per token — a rough heuristic; TODO confirm
        // against the actual cache layout.
        let kv_per_token_kb = 128.0 / 1024.0;
        let kv_per_batch_mb = (kv_per_token_kb * seq_len as f32) / 1024.0;
        let available_for_batch_mb = available_mb as f32 * 0.5;
        let max_batch = (available_for_batch_mb / kv_per_batch_mb).floor() as usize;
        max_batch.clamp(1, 64)
    }
    /// Backend selection assuming a 7B-class model (~7 GiB).
    fn select_compute_backend(&self) -> ComputeBackend {
        self.select_compute_backend_for_model(7.0 * 1024.0)
    }
    /// Pick a backend for a model of the given size (MB): ANE/hybrid first
    /// when the coreml feature is enabled, then GPU, then the best CPU tier.
    pub fn select_compute_backend_for_model(&self, model_size_mb: f32) -> ComputeBackend {
        #[cfg(feature = "coreml")]
        {
            if self.ane.available {
                let strategy = self.ane.recommended_strategy(model_size_mb as usize);
                match strategy {
                    AneStrategy::AneOnly => {
                        return ComputeBackend::CoreML;
                    }
                    AneStrategy::Hybrid => {
                        // Hybrid needs a Metal GPU to pair with the ANE.
                        if let Some(ref gpu) = self.gpu {
                            if matches!(gpu.backend, GpuBackend::Metal) {
                                return ComputeBackend::HybridAne;
                            }
                        }
                        return ComputeBackend::CoreML;
                    }
                    AneStrategy::GpuOnly => {
                        // Fall through to GPU/CPU selection below.
                    }
                }
            }
        }
        if let Some(ref gpu) = self.gpu {
            match gpu.backend {
                GpuBackend::Metal => return ComputeBackend::Metal,
                GpuBackend::Cuda => return ComputeBackend::Cuda,
                GpuBackend::WebGPU => return ComputeBackend::WebGPU,
                _ => {}
            }
        }
        if self.cpu_features.avx512 {
            ComputeBackend::CpuAvx512
        } else if self.cpu_features.avx2 {
            ComputeBackend::CpuAvx2
        } else if self.cpu_features.neon {
            ComputeBackend::CpuNeon
        } else {
            ComputeBackend::CpuScalar
        }
    }
    /// Prefer the ANE when present (best perf/W), else the normal choice.
    pub fn select_power_efficient_backend(&self) -> ComputeBackend {
        #[cfg(feature = "coreml")]
        {
            if self.ane.available {
                return ComputeBackend::CoreML;
            }
        }
        self.select_compute_backend()
    }
    /// Device for tensor placement; GPU when one is present, else CPU.
    fn optimal_device_type(&self) -> DeviceType {
        if let Some(ref gpu) = self.gpu {
            match gpu.backend {
                GpuBackend::Metal => DeviceType::Metal,
                GpuBackend::Cuda => DeviceType::Cuda(0),
                _ => DeviceType::Cpu,
            }
        } else {
            DeviceType::Cpu
        }
    }
    /// f16 on fp16-capable GPUs, f32 everywhere else.
    fn optimal_dtype(&self) -> DType {
        if let Some(ref gpu) = self.gpu {
            if gpu.supports_fp16 {
                return DType::F16;
            }
        }
        DType::F32
    }
    /// Attention tile size: sized from GPU shared memory when known,
    /// otherwise platform defaults.
    fn optimal_block_size(&self) -> usize {
        if let Some(ref gpu) = self.gpu {
            if let Some(shared_mem) = gpu.max_shared_memory {
                // Rough tile budget: head_dim floats * 4 bytes, two tiles,
                // double-buffered — TODO confirm against the kernel layout.
                let head_dim = 128;
                let max_block = shared_mem / (head_dim * 4 * 2 * 2);
                return max_block.clamp(32, 128);
            }
        }
        #[cfg(target_os = "macos")]
        {
            64
        }
        #[cfg(not(target_os = "macos"))]
        {
            32
        }
    }
    /// Maximum sequence length affordable for the KV cache, by memory tier.
    fn optimal_max_seq_len(&self) -> usize {
        let available_mb = self.available_memory_mb.unwrap_or(self.memory_mb / 2);
        if available_mb >= 32 * 1024 {
            32768
        } else if available_mb >= 16 * 1024 {
            16384
        } else if available_mb >= 8 * 1024 {
            8192
        } else if available_mb >= 4 * 1024 {
            4096
        } else {
            2048
        }
    }
    /// Whether the model is runnable at all, assuming aggressive (~Q4)
    /// quantization (0.4 bytes/weight) plus ~2 GB of runtime overhead.
    pub fn can_run_model(&self, model_size_gb: f32) -> bool {
        let available_mb = self.available_memory_mb.unwrap_or(self.memory_mb / 2);
        let available_gb = available_mb as f32 / 1024.0;
        let min_required_gb = model_size_gb * 0.4 + 2.0;
        available_gb >= min_required_gb
    }
    /// One-line human-readable description of the detected hardware.
    pub fn summary(&self) -> String {
        let mut parts = vec![];
        parts.push(format!("{:?} ({:?})", self.platform, self.arch));
        parts.push(format!(
            "{} cores ({} physical)",
            self.cores.logical_cores, self.cores.physical_cores
        ));
        if let Some(perf) = self.cores.performance_cores {
            parts.push(format!(
                "{}P+{}E cores",
                perf,
                self.cores.efficiency_cores.unwrap_or(0)
            ));
        }
        parts.push(format!("{}GB RAM", self.memory_mb / 1024));
        if let Some(ref gpu) = self.gpu {
            let gpu_info = match gpu.vram_mb {
                Some(vram) => format!("{:?} ({}GB VRAM)", gpu.backend, vram / 1024),
                None => format!("{:?}", gpu.backend),
            };
            parts.push(gpu_info);
        } else {
            parts.push("No GPU".to_string());
        }
        if self.ane.available {
            parts.push(format!("ANE ({:.0} TOPS)", self.ane.tops));
        }
        let simd = if self.cpu_features.avx512 {
            "AVX-512"
        } else if self.cpu_features.avx2 {
            "AVX2"
        } else if self.cpu_features.neon {
            "NEON"
        } else if self.cpu_features.sse42 {
            "SSE4.2"
        } else {
            "Scalar"
        };
        parts.push(simd.to_string());
        parts.join(", ")
    }
    /// Human-readable ANE description.
    pub fn ane_summary(&self) -> String {
        if !self.ane.available {
            return "ANE: Not available".to_string();
        }
        format!(
            "ANE: {:.0} TOPS, max model {}MB, {} supported ops",
            self.ane.tops,
            self.ane.max_model_size_mb,
            self.ane.supported_ops.len()
        )
    }
}
/// Concrete execution backend chosen for inference.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ComputeBackend {
    Metal,
    /// Apple Neural Engine via CoreML.
    CoreML,
    /// ANE + Metal GPU split.
    HybridAne,
    Cuda,
    WebGPU,
    /// CPU SIMD tiers, best to worst.
    CpuAvx512,
    CpuAvx2,
    CpuNeon,
    CpuScalar,
}
impl ComputeBackend {
    /// Anything that is not a CPU SIMD tier counts as an accelerator backend
    /// (GPU or ANE).
    pub fn is_gpu(&self) -> bool {
        !matches!(
            self,
            Self::CpuAvx512 | Self::CpuAvx2 | Self::CpuNeon | Self::CpuScalar
        )
    }
    /// Backends that route (at least partially) through the Apple Neural Engine.
    pub fn uses_ane(&self) -> bool {
        match self {
            Self::CoreML | Self::HybridAne => true,
            _ => false,
        }
    }
    /// Rough throughput multiplier relative to scalar CPU (1.0).
    pub fn relative_performance(&self) -> f32 {
        match self {
            Self::Cuda => 15.0,
            Self::HybridAne => 12.0,
            Self::Metal => 10.0,
            Self::CoreML => 8.0,
            Self::WebGPU => 5.0,
            Self::CpuAvx512 => 4.0,
            Self::CpuAvx2 => 2.5,
            Self::CpuNeon => 2.0,
            Self::CpuScalar => 1.0,
        }
    }
    /// Rough performance-per-watt multiplier relative to scalar CPU (1.0).
    pub fn power_efficiency(&self) -> f32 {
        match self {
            Self::CoreML => 4.0,
            Self::HybridAne => 3.0,
            Self::Metal => 2.0,
            Self::WebGPU => 1.5,
            Self::CpuNeon => 1.5,
            Self::CpuAvx2 => 1.3,
            Self::CpuAvx512 => 1.2,
            Self::Cuda => 1.0,
            Self::CpuScalar => 1.0,
        }
    }
}
/// A complete set of inference settings derived from system capabilities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    pub compute_backend: ComputeBackend,
    pub quantization: Quantization,
    pub batch_size: usize,
    pub thread_count: usize,
    /// Attention tile/block size.
    pub block_size: usize,
    pub use_flash_attention: bool,
    pub device_type: DeviceType,
    pub dtype: DType,
}
impl Default for InferenceConfig {
fn default() -> Self {
Self::auto()
}
}
impl InferenceConfig {
    /// Detect the host and build the best general-purpose configuration.
    pub fn auto() -> Self {
        SystemCapabilities::detect().optimal_config()
    }
    /// Preset for memory-constrained machines: aggressive quantization,
    /// batch of 1, small attention blocks.
    pub fn low_memory() -> Self {
        let mut config = Self::auto();
        config.quantization = Quantization::Q4K;
        config.batch_size = 1;
        config.block_size = 32;
        config
    }
    /// Preset trading latency for throughput via larger batches and blocks.
    pub fn high_throughput() -> Self {
        let caps = SystemCapabilities::detect();
        let mut config = caps.optimal_config();
        config.batch_size = (config.batch_size * 2).min(32);
        config.block_size = 128;
        config
    }
    /// Preset minimizing time-to-first-token: single batch, small blocks,
    /// all logical cores.
    pub fn low_latency() -> Self {
        // Detect once and reuse the capabilities; calling Self::auto() here
        // would probe the system a second time.
        let caps = SystemCapabilities::detect();
        let mut config = caps.optimal_config();
        config.batch_size = 1;
        config.block_size = 32;
        config.thread_count = caps.cores.logical_cores;
        config
    }
    /// Crude tokens/second estimate for comparing configurations against
    /// each other; not a measured number.
    pub fn estimated_tokens_per_second(&self) -> f32 {
        let base = match self.compute_backend {
            ComputeBackend::Cuda => 100.0,
            ComputeBackend::HybridAne => 90.0,
            ComputeBackend::Metal => 80.0,
            ComputeBackend::CoreML => 60.0,
            ComputeBackend::WebGPU => 40.0,
            ComputeBackend::CpuAvx512 => 30.0,
            ComputeBackend::CpuAvx2 => 20.0,
            ComputeBackend::CpuNeon => 20.0,
            ComputeBackend::CpuScalar => 5.0,
        };
        // Lower-precision weights stream faster from memory.
        let quant_factor = match self.quantization {
            Quantization::Q2K => 2.5,
            Quantization::Q4 | Quantization::Q4K => 2.0,
            Quantization::Q8 => 1.5,
            Quantization::F16 | Quantization::Bf16 => 1.0,
            Quantization::None => 0.5,
        };
        // Batching helps sub-linearly.
        let batch_factor = (self.batch_size as f32).sqrt();
        base * quant_factor * batch_factor
    }
    /// Preset preferring the most power-efficient backend (ANE when present).
    pub fn power_efficient() -> Self {
        let caps = SystemCapabilities::detect();
        let mut config = caps.optimal_config();
        config.compute_backend = caps.select_power_efficient_backend();
        config.batch_size = 1;
        config
    }
}
// Detection tests are necessarily host-dependent: they assert per-target
// expectations behind cfg gates rather than fixed values.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_platform_detection() {
        let platform = Platform::detect();
        #[cfg(target_os = "macos")]
        assert_eq!(platform, Platform::MacOS);
        #[cfg(target_os = "linux")]
        assert_eq!(platform, Platform::Linux);
        #[cfg(target_os = "windows")]
        assert_eq!(platform, Platform::Windows);
    }
    #[test]
    fn test_architecture_detection() {
        let arch = Architecture::detect();
        #[cfg(target_arch = "aarch64")]
        assert_eq!(arch, Architecture::Aarch64);
        #[cfg(target_arch = "x86_64")]
        assert_eq!(arch, Architecture::X86_64);
    }
    #[test]
    fn test_cpu_features_detection() {
        let features = CpuFeatures::detect();
        #[cfg(target_arch = "aarch64")]
        assert!(features.neon, "NEON should always be available on aarch64");
        // x86_64 relies on SSE4.2 being present on any remotely modern CPU.
        #[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
        assert!(
            features.best_simd_width() >= 128,
            "Should have at least 128-bit SIMD"
        );
    }
    #[test]
    fn test_system_capabilities_detect() {
        let caps = SystemCapabilities::detect();
        assert!(caps.cores.physical_cores >= 1);
        assert!(caps.cores.logical_cores >= 1);
        assert!(caps.memory_mb > 0, "Memory should be detected");
        #[cfg(target_os = "macos")]
        assert_eq!(caps.platform, Platform::MacOS);
        #[cfg(target_arch = "aarch64")]
        assert_eq!(caps.arch, Architecture::Aarch64);
    }
    #[test]
    fn test_optimal_config() {
        let caps = SystemCapabilities::detect();
        let config = caps.optimal_config();
        assert!(config.batch_size >= 1);
        assert!(config.thread_count >= 1);
        assert!(config.block_size >= 16);
        #[cfg(all(target_os = "macos", feature = "metal-compute"))]
        {
            if caps.gpu.is_some() {
                assert_eq!(config.compute_backend, ComputeBackend::Metal);
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            if !config.compute_backend.is_gpu() {
                assert_eq!(config.compute_backend, ComputeBackend::CpuNeon);
            }
        }
    }
    #[test]
    fn test_inference_config_auto() {
        let config = InferenceConfig::auto();
        assert!(config.batch_size >= 1);
        assert!(config.thread_count >= 1);
        assert!(config.use_flash_attention);
    }
    #[test]
    fn test_inference_config_presets() {
        let low_mem = InferenceConfig::low_memory();
        let high_throughput = InferenceConfig::high_throughput();
        let low_latency = InferenceConfig::low_latency();
        assert!(matches!(
            low_mem.quantization,
            Quantization::Q4 | Quantization::Q4K | Quantization::Q2K
        ));
        assert_eq!(low_mem.batch_size, 1);
        assert_eq!(low_latency.batch_size, 1);
        assert!(high_throughput.batch_size >= 2);
    }
    #[test]
    fn test_optimal_quantization() {
        let caps = SystemCapabilities::detect();
        let quant_small = caps.optimal_quantization(1.0);
        let quant_large = caps.optimal_quantization(70.0);
        assert!(
            quant_large.bytes_per_weight() <= quant_small.bytes_per_weight(),
            "Larger models should use more aggressive quantization"
        );
    }
    #[test]
    fn test_recommended_batch_size() {
        let caps = SystemCapabilities::detect();
        let batch_short = caps.recommended_batch_size(512);
        let batch_long = caps.recommended_batch_size(8192);
        assert!(
            batch_short >= batch_long,
            "Shorter sequences should allow larger batches"
        );
    }
    #[test]
    fn test_can_run_model() {
        let caps = SystemCapabilities::detect();
        assert!(caps.can_run_model(0.1), "Should be able to run 100MB model");
        assert!(
            !caps.can_run_model(1000.0),
            "Should not be able to run 1TB model"
        );
    }
    #[test]
    fn test_system_summary() {
        let caps = SystemCapabilities::detect();
        let summary = caps.summary();
        assert!(!summary.is_empty());
        assert!(summary.contains("cores") || summary.contains("RAM"));
    }
    #[test]
    fn test_compute_backend_properties() {
        assert!(ComputeBackend::Metal.is_gpu());
        assert!(ComputeBackend::Cuda.is_gpu());
        assert!(!ComputeBackend::CpuNeon.is_gpu());
        assert!(!ComputeBackend::CpuScalar.is_gpu());
        assert!(
            ComputeBackend::Metal.relative_performance()
                > ComputeBackend::CpuNeon.relative_performance()
        );
    }
    #[test]
    fn test_gpu_can_fit_model() {
        // 16 GB of VRAM fits a 7 GB model with 20% headroom, but not 70 GB.
        let gpu = GpuCapabilities {
            backend: GpuBackend::Metal,
            vram_mb: Some(16 * 1024),
            compute_units: Some(128),
            name: Some("Test GPU".to_string()),
            supports_fp16: true,
            supports_int8: true,
            has_tensor_cores: true,
            max_shared_memory: Some(32 * 1024),
        };
        assert!(gpu.can_fit_model(7.0));
        assert!(!gpu.can_fit_model(70.0));
    }
    #[test]
    fn test_core_info() {
        let cores = CoreInfo::detect();
        assert!(cores.physical_cores >= 1);
        assert!(cores.logical_cores >= 1);
        assert!(cores.logical_cores >= cores.physical_cores);
        let recommended = cores.recommended_threads();
        assert!(recommended >= 1);
        assert!(recommended <= cores.logical_cores);
    }
    #[test]
    fn test_estimated_tokens_per_second() {
        let config = InferenceConfig::auto();
        let tps = config.estimated_tokens_per_second();
        assert!(tps > 0.0);
        let low_latency = InferenceConfig::low_latency();
        let tps_low_latency = low_latency.estimated_tokens_per_second();
        assert!(tps_low_latency > 0.0);
    }
    #[test]
    fn test_ane_info_detect() {
        let ane = AneInfo::detect();
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        {
            assert!(ane.available, "ANE should be available on Apple Silicon");
            assert!(ane.tops > 0.0, "ANE TOPS should be positive");
            assert!(
                ane.max_model_size_mb > 0,
                "ANE max model size should be positive"
            );
            assert!(
                !ane.supported_ops.is_empty(),
                "ANE should have supported ops"
            );
        }
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        {
            assert!(
                !ane.available,
                "ANE should not be available on non-Apple Silicon"
            );
        }
    }
    #[test]
    fn test_ane_model_suitability() {
        let ane = AneInfo {
            available: true,
            tops: 38.0,
            max_model_size_mb: 2048,
            supported_ops: vec!["MatMul".to_string()],
        };
        assert!(ane.is_model_suitable(500));
        assert!(ane.is_model_suitable(2048));
        assert!(!ane.is_model_suitable(4096));
        assert!(!ane.is_model_suitable(8192));
    }
    #[test]
    fn test_ane_strategy_recommendation() {
        // Thresholds: <=500 ANE-only, <=max Hybrid, otherwise GPU-only.
        let ane = AneInfo {
            available: true,
            tops: 38.0,
            max_model_size_mb: 2048,
            supported_ops: vec!["MatMul".to_string()],
        };
        assert_eq!(ane.recommended_strategy(300), AneStrategy::AneOnly);
        assert_eq!(ane.recommended_strategy(1000), AneStrategy::Hybrid);
        assert_eq!(ane.recommended_strategy(4000), AneStrategy::GpuOnly);
    }
    #[test]
    fn test_ane_strategy_unavailable() {
        let ane = AneInfo {
            available: false,
            tops: 0.0,
            max_model_size_mb: 0,
            supported_ops: vec![],
        };
        assert_eq!(ane.recommended_strategy(100), AneStrategy::GpuOnly);
        assert_eq!(ane.recommended_strategy(1000), AneStrategy::GpuOnly);
        assert_eq!(ane.recommended_strategy(10000), AneStrategy::GpuOnly);
    }
    #[test]
    fn test_compute_backend_ane_properties() {
        assert!(ComputeBackend::CoreML.uses_ane());
        assert!(ComputeBackend::HybridAne.uses_ane());
        assert!(!ComputeBackend::Metal.uses_ane());
        assert!(!ComputeBackend::Cuda.uses_ane());
        assert!(!ComputeBackend::CpuNeon.uses_ane());
        assert!(ComputeBackend::CoreML.is_gpu());
        assert!(ComputeBackend::HybridAne.is_gpu());
    }
    #[test]
    fn test_compute_backend_power_efficiency() {
        assert!(
            ComputeBackend::CoreML.power_efficiency() > ComputeBackend::Metal.power_efficiency(),
            "CoreML should be more power efficient than Metal"
        );
        assert!(
            ComputeBackend::HybridAne.power_efficiency() > ComputeBackend::Metal.power_efficiency(),
            "HybridAne should be more power efficient than Metal"
        );
    }
    #[test]
    fn test_system_capabilities_includes_ane() {
        let caps = SystemCapabilities::detect();
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        {
            assert!(caps.ane.available);
            let summary = caps.summary();
            assert!(summary.contains("ANE"), "Summary should include ANE info");
        }
    }
    #[test]
    fn test_ane_summary() {
        let caps = SystemCapabilities::detect();
        let ane_summary = caps.ane_summary();
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        {
            assert!(ane_summary.contains("TOPS"));
            assert!(ane_summary.contains("supported ops"));
        }
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        {
            assert!(ane_summary.contains("Not available"));
        }
    }
    #[test]
    fn test_power_efficient_config() {
        let config = InferenceConfig::power_efficient();
        assert_eq!(config.batch_size, 1);
        #[cfg(all(target_os = "macos", target_arch = "aarch64", feature = "coreml"))]
        {
            assert!(
                config.compute_backend.uses_ane(),
                "Power efficient config should use ANE on Apple Silicon"
            );
        }
    }
    // Smoke test: selection must not panic for any model size tier.
    #[test]
    fn test_select_compute_backend_for_model_size() {
        let caps = SystemCapabilities::detect();
        let _small_backend = caps.select_compute_backend_for_model(500.0);
        let _medium_backend = caps.select_compute_backend_for_model(2000.0);
        let _large_backend = caps.select_compute_backend_for_model(10000.0);
    }
}