use crate::dtype::DType;
use crate::error::{RusTorchError, RusTorchResult};
use crate::gpu::device_cache::{CoreMLCache, DeviceCache};
use crate::gpu::smart_device_selector::{OperationProfile, OperationType, SmartDeviceSelector};
use crate::gpu::{DeviceCapability, DeviceType, GpuDevice, OpType};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Description of a tensor that the executor inspects when picking a device.
#[derive(Debug, Clone)]
pub struct TensorInfo {
/// Element data type; checked for CoreML eligibility and used to look up the element size.
pub dtype: DType,
/// Size of each dimension; the number of entries is treated as the tensor rank.
pub shape: Vec<usize>,
/// When `true`, CoreML is ruled out for this tensor.
pub requires_custom_kernel: bool,
/// Size of the tensor data in bytes; compared against the executor's small/large thresholds.
pub memory_size_bytes: usize,
}
/// Strategy used to move data between devices.
///
/// NOTE(review): nothing in this part of the file consumes this enum, so the
/// per-variant descriptions below are inferred from the names only — confirm
/// against the code that performs transfers.
#[derive(Debug, Clone, Copy)]
pub enum TransferMethod {
    /// Presumably shares memory between devices without copying.
    ZeroCopy,
    /// Presumably routes the data through host memory.
    HostStaging,
    /// Presumably the default transfer path.
    Standard,
}
/// Chooses an execution device for tensor operations and retries on
/// alternative devices when an operation fails.
pub struct HybridExecutor {
/// Preferred device; `new` initialises it to `DeviceType::Auto`.
primary_device: DeviceType,
/// Ordered fallback chain produced by `build_fallback_chain`; its last entry is `DeviceType::Cpu`.
fallback_devices: Vec<DeviceType>,
/// Per-device capability records filled in by `initialize_device_capabilities`.
capability_cache: HashMap<DeviceType, DeviceCapability>,
/// Ordered candidate devices per operation type, filled in by `setup_operation_routing`.
operation_routing: HashMap<OpType, Vec<DeviceType>>,
// Both thresholds are compared against `TensorInfo::memory_size_bytes`:
// tensors below the small threshold always run on the CPU, and tensors above
// the large threshold are only considered efficient on CUDA, Metal or CoreML.
small_tensor_threshold: usize, large_tensor_threshold: usize,
/// Profile-driven selector constructed from the devices detected at start-up.
smart_selector: SmartDeviceSelector,
/// Cache queried for device availability by the smart-selection path.
device_cache: DeviceCache,
}
impl HybridExecutor {
/// Creates a fully initialised executor.
///
/// The device cache is created and warmed up first, the devices it reports
/// are handed to the smart selector, and then the capability table, the
/// fallback chain and the per-operation routing table are populated.
pub fn new() -> Self {
    let device_cache = DeviceCache::new();
    device_cache.warmup();

    let detected = Self::detect_available_devices(&device_cache);
    let smart_selector = SmartDeviceSelector::new(detected);

    let mut this = Self {
        primary_device: DeviceType::Auto,
        fallback_devices: Vec::new(),
        capability_cache: HashMap::new(),
        operation_routing: HashMap::new(),
        small_tensor_threshold: 1_000_000,
        large_tensor_threshold: 100_000_000,
        smart_selector,
        device_cache,
    };

    this.initialize_device_capabilities();
    this.build_fallback_chain();
    this.setup_operation_routing();
    this
}
/// Returns the process-wide executor, constructing it on first use.
///
/// The instance lives in a function-local `OnceLock`, so `new` runs at most
/// once and every caller receives the same reference.
pub fn global() -> &'static Self {
    use std::sync::OnceLock;

    static INSTANCE: OnceLock<HybridExecutor> = OnceLock::new();
    INSTANCE.get_or_init(Self::new)
}
/// Picks the device on which an operation should run.
///
/// Decision order:
/// 1. tensors smaller than `small_tensor_threshold` bytes go to the CPU;
/// 2. CoreML device 0 is used when the operation/tensor is CoreML-compatible
///    and the device is available (only in builds with the `coreml` feature);
/// 3. otherwise the result of `select_gpu_device` is used;
/// 4. the CPU is the final default.
pub fn select_device(&self, op_type: OpType, tensor_info: &TensorInfo) -> DeviceType {
    let is_small = tensor_info.memory_size_bytes < self.small_tensor_threshold;
    if is_small {
        return DeviceType::Cpu;
    }

    if self.is_coreml_supported(&op_type, tensor_info) {
        // The availability probe only exists when the feature is compiled in.
        #[cfg(feature = "coreml")]
        if self.is_device_available(DeviceType::CoreML(0)) {
            return DeviceType::CoreML(0);
        }
    }

    self.select_gpu_device(&op_type, tensor_info)
        .unwrap_or(DeviceType::Cpu)
}
/// Reports whether CoreML can handle `op_type` on the described tensor.
///
/// With the `coreml` feature enabled this requires all of:
/// a cached capability entry for CoreML device 0 that supports the operation,
/// a `Float16` or `Float32` dtype, a shape of at most five dimensions, and no
/// custom-kernel requirement. Without the feature the answer is always `false`.
fn is_coreml_supported(&self, op_type: &OpType, tensor_info: &TensorInfo) -> bool {
    #[cfg(feature = "coreml")]
    {
        let Some(capability) = self.capability_cache.get(&DeviceType::CoreML(0)) else {
            return false;
        };

        capability.supports_operation(op_type)
            && matches!(tensor_info.dtype, DType::Float16 | DType::Float32)
            && tensor_info.shape.len() <= 5
            && !tensor_info.requires_custom_kernel
    }
    #[cfg(not(feature = "coreml"))]
    {
        false
    }
}
/// Finds an accelerator (or routed device) for the operation, if any.
///
/// The routing table entry for `op_type` is consulted first: the first listed
/// device that is both available and efficient for the tensor wins. If no
/// routed device qualifies, the first available device among CUDA, Metal and
/// OpenCL (each only when its feature is compiled in) is returned. `None`
/// means the caller should fall back to its own default.
fn select_gpu_device(&self, op_type: &OpType, tensor_info: &TensorInfo) -> Option<DeviceType> {
    let routed = self
        .operation_routing
        .get(op_type)
        .into_iter()
        .flatten()
        .copied()
        .find(|&candidate| {
            self.is_device_available(candidate)
                && self.is_operation_efficient(candidate, tensor_info)
        });
    if routed.is_some() {
        return routed;
    }

    // Generic preference order; the efficiency check is not applied here.
    #[cfg(feature = "cuda")]
    if self.is_device_available(DeviceType::Cuda(0)) {
        return Some(DeviceType::Cuda(0));
    }
    #[cfg(feature = "metal")]
    if self.is_device_available(DeviceType::Metal(0)) {
        return Some(DeviceType::Metal(0));
    }
    #[cfg(feature = "opencl")]
    if self.is_device_available(DeviceType::OpenCL(0)) {
        return Some(DeviceType::OpenCL(0));
    }

    None
}
/// Reports whether `device` can be used by this build.
///
/// The CPU is always available. Each accelerator arm exists only when its
/// Cargo feature is enabled and defers to the matching `DeviceManager` probe;
/// the device index is ignored. Every other device — including accelerators
/// whose feature is disabled — is reported as unavailable.
fn is_device_available(&self, device: DeviceType) -> bool {
match device {
DeviceType::Cpu => true,
#[cfg(feature = "cuda")]
DeviceType::Cuda(_) => crate::backends::DeviceManager::is_cuda_available(),
#[cfg(feature = "metal")]
DeviceType::Metal(_) => crate::backends::DeviceManager::is_metal_available(),
#[cfg(feature = "opencl")]
DeviceType::OpenCL(_) => crate::backends::DeviceManager::is_opencl_available(),
#[cfg(feature = "coreml")]
DeviceType::CoreML(_) => crate::backends::DeviceManager::is_coreml_available(),
// Catch-all: devices without a compiled-in backend are never available.
_ => false,
}
}
/// Reports whether running on `device` is considered efficient for the tensor.
///
/// Tensors up to `large_tensor_threshold` bytes are efficient everywhere;
/// larger tensors are only efficient on CUDA, Metal or CoreML devices.
fn is_operation_efficient(&self, device: DeviceType, tensor_info: &TensorInfo) -> bool {
    let fits_anywhere = tensor_info.memory_size_bytes <= self.large_tensor_threshold;

    fits_anywhere
        || matches!(
            device,
            DeviceType::Cuda(_) | DeviceType::Metal(_) | DeviceType::CoreML(_)
        )
}
/// Returns the device that follows `failed_device` in the fallback chain.
///
/// The first occurrence of `failed_device` in `fallback_devices` is located
/// and its successor is returned. When the device is not in the chain, or is
/// the last entry, the CPU is returned.
pub fn next_fallback_device(&self, failed_device: DeviceType) -> DeviceType {
    self.fallback_devices
        .iter()
        // Drop everything before the failed device ...
        .skip_while(|&&candidate| candidate != failed_device)
        // ... then step over the failed device itself.
        .nth(1)
        .copied()
        .unwrap_or(DeviceType::Cpu)
}
/// Populates `capability_cache` with the statically known device capabilities.
///
/// The CPU entry is always registered; it lists every operation type except
/// `CustomKernel` and reports no f16 support. The CoreML entry is added only
/// when the `coreml` feature is compiled in.
fn initialize_device_capabilities(&mut self) {
    let cpu_ops = HashSet::from([
        OpType::LinearAlgebra,
        OpType::Convolution,
        OpType::Activation,
        OpType::Reduction,
        OpType::Normalization,
        OpType::ComplexMath,
        OpType::Distribution,
        OpType::DistributedOps,
    ]);

    let cpu_capability = DeviceCapability {
        device_type: DeviceType::Cpu,
        supports_f16: false,
        supports_f32: true,
        supports_f64: true,
        supports_complex: true,
        supports_distributed: true,
        max_memory_gb: 32.0,
        supported_operations: cpu_ops,
    };
    self.capability_cache
        .insert(DeviceType::Cpu, cpu_capability);

    #[cfg(feature = "coreml")]
    {
        let coreml_capability = DeviceCapability::coreml_capability();
        self.capability_cache
            .insert(DeviceType::CoreML(0), coreml_capability);
    }
}
/// Lists the devices the cache reports as available.
///
/// The CPU is always first; CoreML, Metal and CUDA (device 0 each) follow in
/// that order when `device_cache` says they are available.
fn detect_available_devices(device_cache: &DeviceCache) -> Vec<DeviceType> {
    let accelerators = [
        DeviceType::CoreML(0),
        DeviceType::Metal(0),
        DeviceType::Cuda(0),
    ];

    let mut devices = vec![DeviceType::Cpu];
    devices.extend(
        accelerators
            .iter()
            .copied()
            .filter(|candidate| device_cache.is_device_available(candidate)),
    );
    devices
}
/// Asks the smart selector for the best device for an operation.
///
/// `op_type` is translated to the selector's `OperationType`, a profile is
/// built from the tensor shape and element size, and the selector's choice is
/// returned if the device cache reports it as available. Otherwise the head
/// of the fallback chain is used, or the CPU when the chain is empty.
pub fn select_optimal_device(&self, tensor_info: &TensorInfo, op_type: OpType) -> DeviceType {
    let kind = match op_type {
        OpType::LinearAlgebra => OperationType::MatrixMultiplication,
        OpType::Activation => OperationType::Activation,
        OpType::Convolution => OperationType::Convolution,
        OpType::Reduction | OpType::Normalization => OperationType::ElementWise,
        OpType::ComplexMath => OperationType::ComplexNumber,
        OpType::Distribution => OperationType::StatisticalDistribution,
        OpType::CustomKernel => OperationType::CustomKernel,
        OpType::DistributedOps => OperationType::DistributedOp,
    };

    let element_size = self.get_dtype_size(&tensor_info.dtype);
    let profile = OperationProfile::new(kind, &tensor_info.shape, element_size);

    let choice = self.smart_selector.select_device(&profile);
    if self.device_cache.is_device_available(&choice) {
        return choice;
    }

    self.fallback_devices
        .first()
        .copied()
        .unwrap_or(DeviceType::Cpu)
}
/// Returns the smart selector's ordered fallback devices for an operation.
///
/// `op_type` is translated to the selector's `OperationType` with the same
/// mapping that `select_optimal_device` uses, so the chain is computed for
/// the operation class that was actually requested. A profile is then built
/// from the tensor shape and element size and passed to the selector.
pub fn get_operation_fallback_chain(
    &self,
    tensor_info: &TensorInfo,
    op_type: OpType,
) -> Vec<DeviceType> {
    // Deliberately exhaustive (no `_` arm): a catch-all previously collapsed
    // complex, distribution, custom-kernel and distributed operations into
    // `ElementWise`, and would also hide any `OpType` variant added later.
    let operation_type = match op_type {
        OpType::LinearAlgebra => OperationType::MatrixMultiplication,
        OpType::Activation => OperationType::Activation,
        OpType::Convolution => OperationType::Convolution,
        OpType::Reduction | OpType::Normalization => OperationType::ElementWise,
        OpType::ComplexMath => OperationType::ComplexNumber,
        OpType::Distribution => OperationType::StatisticalDistribution,
        OpType::CustomKernel => OperationType::CustomKernel,
        OpType::DistributedOps => OperationType::DistributedOp,
    };

    let profile = OperationProfile::new(
        operation_type,
        &tensor_info.shape,
        self.get_dtype_size(&tensor_info.dtype),
    );
    self.smart_selector.get_fallback_chain(&profile)
}
/// Returns the size in bytes of one element of `dtype`.
fn get_dtype_size(&self, dtype: &DType) -> usize {
    // Arms are grouped by width; the match stays exhaustive over `DType`.
    match dtype {
        DType::Int8 | DType::UInt8 | DType::Bool => 1,
        DType::Float16 | DType::BFloat16 | DType::Int16 | DType::UInt16 => 2,
        DType::Float32 | DType::Int32 | DType::UInt32 => 4,
        DType::Float64 | DType::Int64 | DType::UInt64 | DType::Complex64 => 8,
        DType::Complex128 => 16,
    }
}
/// Rebuilds `fallback_devices` from scratch.
///
/// Order: CoreML, then Metal on Apple Silicon or CUDA elsewhere, then OpenCL,
/// and finally the CPU. Each accelerator is added only when its feature is
/// compiled in and the device is available; the CPU is always appended last.
/// Debug builds print the resulting chain to stderr.
fn build_fallback_chain(&mut self) {
    self.fallback_devices.clear();

    #[cfg(feature = "coreml")]
    if self.is_device_available(DeviceType::CoreML(0)) {
        self.fallback_devices.push(DeviceType::CoreML(0));
    }

    let prefers_metal = self.is_apple_silicon();
    if prefers_metal {
        #[cfg(feature = "metal")]
        if self.is_device_available(DeviceType::Metal(0)) {
            self.fallback_devices.push(DeviceType::Metal(0));
        }
    } else {
        #[cfg(feature = "cuda")]
        if self.is_device_available(DeviceType::Cuda(0)) {
            self.fallback_devices.push(DeviceType::Cuda(0));
        }
    }

    #[cfg(feature = "opencl")]
    if self.is_device_available(DeviceType::OpenCL(0)) {
        self.fallback_devices.push(DeviceType::OpenCL(0));
    }

    // The CPU terminates every chain so a fallback always exists.
    self.fallback_devices.push(DeviceType::Cpu);

    if cfg!(debug_assertions) {
        eprintln!("🔄 Fallback chain: {:?}", self.fallback_devices);
    }
}
/// Reports whether this binary was compiled for macOS on `aarch64`.
///
/// The answer is fixed at compile time from the target configuration.
fn is_apple_silicon(&self) -> bool {
    cfg!(all(target_os = "macos", target_arch = "aarch64"))
}
/// Fills the per-operation routing table.
///
/// `ComplexMath` and `CustomKernel` try CUDA, Metal, OpenCL and then the CPU;
/// `DistributedOps` is routed to the CPU only. Existing entries for these
/// operation types are overwritten.
fn setup_operation_routing(&mut self) {
    let accelerators_then_cpu = vec![
        DeviceType::Cuda(0),
        DeviceType::Metal(0),
        DeviceType::OpenCL(0),
        DeviceType::Cpu,
    ];

    self.operation_routing.extend([
        (OpType::ComplexMath, accelerators_then_cpu.clone()),
        (OpType::DistributedOps, vec![DeviceType::Cpu]),
        (OpType::CustomKernel, accelerators_then_cpu),
    ]);
}
/// Reports whether any of CUDA, Metal or OpenCL (device 0) is available.
///
/// Always `false` when none of the `cuda`, `metal` or `opencl` features is
/// compiled in.
pub fn has_gpu_support(&self) -> bool {
    #[cfg(any(feature = "cuda", feature = "metal", feature = "opencl"))]
    {
        let gpus = [
            DeviceType::Cuda(0),
            DeviceType::Metal(0),
            DeviceType::OpenCL(0),
        ];
        gpus.iter().any(|&gpu| self.is_device_available(gpu))
    }
    #[cfg(not(any(feature = "cuda", feature = "metal", feature = "opencl")))]
    {
        false
    }
}
}
impl Default for HybridExecutor {
fn default() -> Self {
Self::new()
}
}
impl HybridExecutor {
/// Runs `operation` on the selected device, falling back to other devices on failure.
///
/// This is an alias for `execute`: the body used to be a verbatim copy, so it
/// now delegates to keep both entry points behaving identically.
///
/// # Errors
///
/// Returns the error from the first attempt when that attempt and every
/// fallback attempt fail.
pub fn hybrid_operation<F, R>(
    &self,
    op_type: OpType,
    tensor_info: TensorInfo,
    operation: F,
) -> RusTorchResult<R>
where
    F: Fn(DeviceType) -> RusTorchResult<R>,
{
    self.execute(op_type, tensor_info, operation)
}
/// Runs `operation` on the selected device, falling back to other devices on failure.
///
/// The device comes from `select_device`. If the first attempt fails, each
/// device returned by `get_fallback_chain` is tried in order and the first
/// success is returned.
///
/// # Errors
///
/// Returns the error from the first attempt when that attempt and every
/// fallback attempt fail; errors from the fallback attempts are discarded.
pub fn execute<F, R>(
    &self,
    op_type: OpType,
    tensor_info: TensorInfo,
    operation: F,
) -> RusTorchResult<R>
where
    F: Fn(DeviceType) -> RusTorchResult<R>,
{
    let primary = self.select_device(op_type, &tensor_info);

    operation(primary).or_else(|primary_err| {
        self.get_fallback_chain(primary)
            .into_iter()
            // Stops at the first fallback device that succeeds.
            .find_map(|candidate| operation(candidate).ok())
            .ok_or(primary_err)
    })
}
/// Returns the devices to try, in order, after an operation failed on `device`.
///
/// Every non-empty chain ends with the CPU. The chain for the CPU itself is
/// empty: it is the last resort, and listing it as its own fallback made a
/// failed CPU operation run a second time on the same device.
fn get_fallback_chain(&self, device: DeviceType) -> Vec<DeviceType> {
    match device {
        DeviceType::CoreML(_) => {
            vec![DeviceType::Metal(0), DeviceType::Cuda(0), DeviceType::Cpu]
        }
        DeviceType::Metal(_) => vec![DeviceType::Cuda(0), DeviceType::Cpu],
        DeviceType::Cuda(_) => vec![DeviceType::Metal(0), DeviceType::Cpu],
        DeviceType::OpenCL(_) => vec![DeviceType::Cpu],
        // Nothing is left to try once the CPU has failed.
        DeviceType::Cpu => Vec::new(),
        // Any other device kind falls straight back to the CPU.
        _ => vec![DeviceType::Cpu],
    }
}
}
/// Interface for values that can run operations through hybrid device selection.
///
/// NOTE(review): the type parameter `T` is not used by either method here;
/// presumably it is the element type of the implementing tensor — confirm
/// against the implementations.
pub trait HybridExecution<T> {
/// Runs `operation` for the given operation type; the closure receives the
/// device to execute on and returns the operation's result.
fn hybrid_operation<F, R>(&self, op_type: OpType, operation: F) -> RusTorchResult<R>
where
F: Fn(DeviceType) -> RusTorchResult<R>;
/// Returns the `TensorInfo` describing `self`.
fn tensor_info(&self) -> TensorInfo;
}