use crate::error::RusTorchResult;
use crate::hybrid_f32::tensor::core::F32Tensor;
#[cfg(target_os = "macos")]
pub mod coreml;
#[cfg(target_os = "macos")]
pub mod metal;
#[cfg(target_os = "macos")]
pub use coreml::F32CoreMLExecutor;
#[cfg(target_os = "macos")]
pub use metal::F32MetalExecutor;
/// Operations an f32 accelerator backend exposes to the rest of this module.
///
/// NOTE(review): `F32UnifiedGPUContext` calls methods with these names on its
/// Metal and CoreML executors, so those types presumably implement this
/// trait — confirm in the `metal` / `coreml` submodules.
pub trait F32GPUExecutor {
/// Prepares the executor for the device identified by `device_id`.
///
/// Returns an error if the backend could not be set up.
fn initialize(&mut self, device_id: usize) -> RusTorchResult<()>;
/// Transfers `tensor` to the GPU.
///
/// The tensor is borrowed mutably, presumably so its storage/location can be
/// updated in place — TODO confirm with the implementors.
fn transfer_to_gpu(&self, tensor: &mut F32Tensor) -> RusTorchResult<()>;
/// Transfers `tensor` back to the CPU; the counterpart of `transfer_to_gpu`.
fn transfer_to_cpu(&self, tensor: &mut F32Tensor) -> RusTorchResult<()>;
/// Matrix-multiplies `a` by `b` and returns the product as a new tensor.
fn matmul_f32(&self, a: &F32Tensor, b: &F32Tensor) -> RusTorchResult<F32Tensor>;
/// Runs a 2-D convolution of `input` with `kernel` and returns the result.
///
/// `stride` and `padding` are pairs; which element is the vertical and which
/// the horizontal component is not visible here — TODO confirm with the
/// implementors.
fn conv2d_f32(
&self,
input: &F32Tensor,
kernel: &F32Tensor,
stride: (usize, usize),
padding: (usize, usize),
) -> RusTorchResult<F32Tensor>;
/// Returns the performance description of the device behind this executor.
fn get_performance_info(&self) -> DevicePerformanceInfo;
/// Reduces `tensor` to a single `f32` using the reduction named by `operation`.
///
/// The CPU path of `F32UnifiedGPUContext::execute_parallel_reduction` accepts
/// `"sum"`, `"mean"`, `"min"` and `"max"`; backends presumably accept the same
/// names — confirm with the implementors.
fn parallel_reduction_f32(&self, tensor: &F32Tensor, operation: &str) -> RusTorchResult<f32>;
/// Computes the scalar statistic named by `operation` over `tensor`.
///
/// The CPU path of `F32UnifiedGPUContext::execute_statistical_processing`
/// accepts `"std"` and `"variance"`; backends presumably accept the same
/// names — confirm with the implementors.
fn statistical_processing_f32(
&self,
tensor: &F32Tensor,
operation: &str,
) -> RusTorchResult<f32>;
}
/// Performance characteristics reported for a compute device.
///
/// Every constructor in this file returns fixed, hard-coded figures; nothing
/// here is measured at runtime.
#[derive(Debug, Clone)]
pub struct DevicePerformanceInfo {
    /// Human-readable device name.
    pub device_name: String,
    /// Memory bandwidth of the device (unit is not stated in this file;
    /// presumably GB/s — TODO confirm).
    pub memory_bandwidth: f64,
    /// Number of compute units reported for the device.
    pub compute_units: usize,
    /// Whether the device is reported as supporting f32 computation.
    pub supports_f32: bool,
    /// Whether the device is reported as supporting float16 computation.
    pub supports_float16: bool,
    /// Estimated f32 throughput in TFLOPS.
    pub estimated_tflops_f32: f64,
    /// Estimated f16 throughput in TFLOPS.
    pub estimated_tflops_f16: f64,
}

impl DevicePerformanceInfo {
    /// Returns the fixed profile used for the CPU fallback device.
    pub fn cpu_baseline() -> Self {
        let device_name = String::from("CPU");
        Self {
            device_name,
            compute_units: 8,
            memory_bandwidth: 50.0,
            estimated_tflops_f32: 0.5,
            estimated_tflops_f16: 0.0,
            supports_f32: true,
            supports_float16: false,
        }
    }

    /// Returns the fixed profile used for the Apple M1 GPU (Metal).
    pub fn metal_gpu_m1() -> Self {
        let device_name = String::from("Apple M1 GPU");
        Self {
            device_name,
            compute_units: 8,
            memory_bandwidth: 68.25,
            estimated_tflops_f32: 2.6,
            estimated_tflops_f16: 5.2,
            supports_f32: true,
            supports_float16: true,
        }
    }

    /// Returns the fixed profile used for the Apple M1 Neural Engine.
    pub fn neural_engine_m1() -> Self {
        let device_name = String::from("Apple M1 Neural Engine");
        Self {
            device_name,
            compute_units: 16,
            memory_bandwidth: 68.25,
            estimated_tflops_f32: 7.0,
            estimated_tflops_f16: 15.8,
            supports_f32: true,
            supports_float16: true,
        }
    }
}
/// Dispatch context that routes f32 tensor operations to the currently
/// selected device (CPU, or on macOS a Metal / CoreML executor).
#[derive(Debug)]
pub struct F32UnifiedGPUContext {
/// Metal executor (macOS only); `None` until `initialize_device` succeeds
/// for a `GPUDevice::Metal` device.
#[cfg(target_os = "macos")]
metal_executor: Option<F32MetalExecutor>,
/// CoreML executor (macOS only); `None` until `initialize_device` succeeds
/// for a `GPUDevice::CoreML` device.
#[cfg(target_os = "macos")]
coreml_executor: Option<F32CoreMLExecutor>,
/// Device the `execute_*` methods dispatch to; starts as `GPUDevice::CPU`.
current_device: GPUDevice,
}
/// Execution target that `F32UnifiedGPUContext` can dispatch to.
///
/// The accelerator variants exist only when compiling for macOS.
#[derive(Debug, Clone)]
pub enum GPUDevice {
/// Host CPU; available on every platform and the initial device of a new
/// context.
CPU,
/// Metal GPU. The payload is forwarded as `device_id` to the Metal
/// executor's `initialize` call.
#[cfg(target_os = "macos")]
Metal(usize),
/// CoreML backend. The payload is forwarded as `device_id` to the CoreML
/// executor's `initialize` call.
#[cfg(target_os = "macos")]
CoreML(usize),
}
impl F32UnifiedGPUContext {
/// Creates a context that targets the CPU and has no accelerator executor
/// initialized yet.
pub fn new() -> Self {
    // Every context starts on the CPU; accelerators are opted into later via
    // `initialize_device`.
    let current_device = GPUDevice::CPU;
    Self {
        current_device,
        #[cfg(target_os = "macos")]
        coreml_executor: None,
        #[cfg(target_os = "macos")]
        metal_executor: None,
    }
}
/// Lists the devices this build can target, each paired with its fixed
/// performance profile.
///
/// The CPU entry is always first. A Metal entry is added on macOS, and a
/// CoreML entry is added on macOS builds with the `coreml` feature enabled.
/// The list is decided at compile time; no hardware probing happens here.
pub fn detect_available_devices(&self) -> Vec<(GPUDevice, DevicePerformanceInfo)> {
    let cpu_entry = (GPUDevice::CPU, DevicePerformanceInfo::cpu_baseline());
    let mut devices = vec![cpu_entry];

    #[cfg(target_os = "macos")]
    devices.push((GPUDevice::Metal(0), DevicePerformanceInfo::metal_gpu_m1()));

    #[cfg(all(target_os = "macos", feature = "coreml"))]
    devices.push((
        GPUDevice::CoreML(0),
        DevicePerformanceInfo::neural_engine_m1(),
    ));

    devices
}
/// Picks a device for `operation` using fixed size thresholds.
///
/// On macOS:
/// * `"matmul"` with `tensor_size > 50000` selects `GPUDevice::Metal(0)`;
/// * `"conv2d"` with `tensor_size > 1000` selects `GPUDevice::CoreML(0)`;
/// * `"activation"` always selects `GPUDevice::CoreML(0)`;
/// * everything else selects `GPUDevice::CPU`.
///
/// On every other platform the result is always `GPUDevice::CPU`.
///
/// This only returns a recommendation: it does not initialize the device and
/// does not check that an executor for it exists.
pub fn select_optimal_device(&self, operation: &str, tensor_size: usize) -> GPUDevice {
    match (operation, tensor_size) {
        // `GPUDevice::Metal` / `GPUDevice::CoreML` only exist on macOS, so the
        // arms that name them must be compiled out elsewhere; without the
        // guards this function does not build on other targets.
        #[cfg(target_os = "macos")]
        ("matmul", size) if size > 50_000 => GPUDevice::Metal(0),
        #[cfg(target_os = "macos")]
        ("conv2d", size) if size > 1_000 => GPUDevice::CoreML(0),
        #[cfg(target_os = "macos")]
        ("activation", _) => GPUDevice::CoreML(0),
        _ => GPUDevice::CPU,
    }
}
/// Initializes the executor for `device` (if it needs one) and makes
/// `device` the current dispatch target.
///
/// # Errors
///
/// Propagates the error returned by the executor's `initialize` call. In that
/// case neither the stored executor nor the current device is changed.
pub fn initialize_device(&mut self, device: GPUDevice) -> RusTorchResult<()> {
    match &device {
        #[cfg(target_os = "macos")]
        GPUDevice::Metal(device_id) => {
            let mut metal = F32MetalExecutor::new();
            metal.initialize(*device_id)?;
            self.metal_executor = Some(metal);
        }
        #[cfg(target_os = "macos")]
        GPUDevice::CoreML(device_id) => {
            let mut coreml = F32CoreMLExecutor::new();
            coreml.initialize(*device_id)?;
            self.coreml_executor = Some(coreml);
        }
        // The CPU needs no executor.
        GPUDevice::CPU => {}
    }

    // Reached only when any required executor was set up successfully.
    self.current_device = device;
    Ok(())
}
/// Multiplies `a` by `b` on the currently selected device.
///
/// On the CPU this calls `F32Tensor::matmul`; on Metal / CoreML it forwards
/// to the stored executor's `matmul_f32`.
///
/// # Errors
///
/// Returns `RusTorchError::BackendUnavailable` when the current device is an
/// accelerator whose executor is not stored in this context; otherwise
/// returns whatever the underlying multiplication returns.
pub fn execute_matmul(&self, a: &F32Tensor, b: &F32Tensor) -> RusTorchResult<F32Tensor> {
    crate::hybrid_f32_experimental!();
    match &self.current_device {
        GPUDevice::CPU => a.matmul(b),
        #[cfg(target_os = "macos")]
        GPUDevice::Metal(_) => match self.metal_executor.as_ref() {
            Some(metal) => metal.matmul_f32(a, b),
            None => Err(crate::error::RusTorchError::BackendUnavailable {
                backend: "Metal".to_string(),
            }),
        },
        #[cfg(target_os = "macos")]
        GPUDevice::CoreML(_) => match self.coreml_executor.as_ref() {
            Some(coreml) => coreml.matmul_f32(a, b),
            None => Err(crate::error::RusTorchError::BackendUnavailable {
                backend: "CoreML".to_string(),
            }),
        },
    }
}
pub fn execute_parallel_reduction(
&self,
tensor: &F32Tensor,
operation: &str,
) -> RusTorchResult<f32> {
crate::hybrid_f32_experimental!();
match &self.current_device {
#[cfg(target_os = "macos")]
GPUDevice::Metal(_) => {
if let Some(executor) = &self.metal_executor {
executor.parallel_reduction_f32(tensor, operation)
} else {
Err(crate::error::RusTorchError::BackendUnavailable {
backend: "Metal".to_string(),
})
}
}
#[cfg(target_os = "macos")]
GPUDevice::CoreML(_) => {
if let Some(executor) = &self.coreml_executor {
executor.parallel_reduction_f32(tensor, operation)
} else {
Err(crate::error::RusTorchError::BackendUnavailable {
backend: "CoreML".to_string(),
})
}
}
GPUDevice::CPU => {
match operation {
"sum" => tensor.sum(),
"mean" => tensor.mean(),
"min" => tensor.min(),
"max" => tensor.max(),
_ => Err(crate::error::RusTorchError::tensor_op(&format!(
"Unsupported reduction operation: {}",
operation
))),
}
}
}
}
pub fn execute_statistical_processing(
&self,
tensor: &F32Tensor,
operation: &str,
) -> RusTorchResult<f32> {
crate::hybrid_f32_experimental!();
match &self.current_device {
#[cfg(target_os = "macos")]
GPUDevice::Metal(_) => {
if let Some(executor) = &self.metal_executor {
executor.statistical_processing_f32(tensor, operation)
} else {
Err(crate::error::RusTorchError::BackendUnavailable {
backend: "Metal".to_string(),
})
}
}
#[cfg(target_os = "macos")]
GPUDevice::CoreML(_) => {
if let Some(executor) = &self.coreml_executor {
executor.statistical_processing_f32(tensor, operation)
} else {
Err(crate::error::RusTorchError::BackendUnavailable {
backend: "CoreML".to_string(),
})
}
}
GPUDevice::CPU => {
match operation {
"std" => {
let mean_val = tensor.mean()?;
let variance = tensor
.data
.iter()
.map(|&x| (x - mean_val).powi(2))
.sum::<f32>()
/ (tensor.data.len() as f32);
Ok(variance.sqrt())
}
"variance" => {
let mean_val = tensor.mean()?;
let variance = tensor
.data
.iter()
.map(|&x| (x - mean_val).powi(2))
.sum::<f32>()
/ (tensor.data.len() as f32);
Ok(variance)
}
_ => Err(crate::error::RusTorchError::tensor_op(&format!(
"Unsupported statistics operation: {}",
operation
))),
}
}
}
}
}