use crate::Result;
pub mod benchmarks;
pub mod device;
pub mod mps_integration;
pub mod operations;
pub mod types;
/// Metal kernel source code embedded into the binary at compile time.
pub mod shaders {
    /// Full text of `shaders/metal_kernels.metal`, embedded via `include_str!`
    /// (the path is resolved relative to this source file at compile time).
    pub const METAL_KERNELS_SOURCE: &str = include_str!("shaders/metal_kernels.metal");
}
pub use device::{DeviceCapabilities, MetalDevice};
pub use types::{
ActivationType, BenchmarkResult, ConvConfig, DispatchConfig, ElementwiseOp, LayerConfig,
LayerType, MemoryAccessPattern, MetalKernelConfig, ReductionOp,
};
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use benchmarks::{ConvConfig as BenchmarkConvConfig, MetalBenchmark};
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use mps_integration::{LayerConfig as MPSLayerConfig, LayerType as MPSLayerType, MPSNeuralOps};
#[cfg(not(all(target_os = "macos", feature = "metal")))]
pub mod metal_stub {
    //! Fallback types used when Metal support is not compiled in (non-macOS
    //! target or the `metal` feature disabled). Every constructor here fails
    //! with the error produced by [`metal_not_available`].

    use crate::{Result, TensorError};

    /// Returns the error reported by every Metal entry point on unsupported builds.
    ///
    /// # Errors
    /// Always returns a device error; this function never succeeds.
    pub fn metal_not_available() -> Result<()> {
        Err(TensorError::device_error_simple(
            "Metal kernels are only available on macOS with the 'metal' feature enabled"
                .to_string(),
        ))
    }

    /// Stand-in for the Metal device on unsupported builds.
    ///
    /// Derives `Debug` to match the real device type, which must be `Debug`
    /// because the macOS `MetalKernels` derives `Debug` over it.
    // NOTE(review): the parent module also does `pub use device::MetalDevice`
    // explicitly, which shadows this glob-re-exported stub at the parent level —
    // confirm which type callers are meant to see on non-Metal builds.
    #[derive(Debug)]
    pub struct MetalDevice;

    impl MetalDevice {
        /// Always fails on builds without Metal support.
        ///
        /// # Errors
        /// Always returns the error from [`metal_not_available`].
        pub fn new() -> Result<Self> {
            metal_not_available().map(|()| MetalDevice)
        }
    }

    /// Stand-in for the Metal benchmark harness on unsupported builds.
    #[derive(Debug)]
    pub struct MetalBenchmark;

    impl MetalBenchmark {
        /// Always fails on builds without Metal support.
        ///
        /// # Errors
        /// Always returns the error from [`metal_not_available`].
        pub fn new() -> Result<Self> {
            metal_not_available().map(|()| MetalBenchmark)
        }
    }

    /// Stand-in for the MPS neural-network operations on unsupported builds.
    #[derive(Debug)]
    pub struct MPSNeuralOps;

    impl MPSNeuralOps {
        /// Always fails on builds without Metal support.
        ///
        /// # Errors
        /// Always returns the error from [`metal_not_available`].
        pub fn new() -> Result<Self> {
            metal_not_available().map(|()| MPSNeuralOps)
        }
    }
}
#[cfg(not(all(target_os = "macos", feature = "metal")))]
pub use metal_stub::*;
/// Facade over the Metal backend: owns a [`MetalDevice`] plus an optional
/// benchmark harness and optional MPS neural ops.
///
/// `MetalKernels::new` creates only the device; the optional components are
/// attached through the `with_benchmarking` / `with_mps_ops` builder methods.
#[cfg(all(target_os = "macos", feature = "metal"))]
#[derive(Debug)]
pub struct MetalKernels {
    /// Device every kernel runs on; always present.
    device: MetalDevice,
    /// Benchmark harness; `None` until `with_benchmarking` is called.
    benchmark: Option<MetalBenchmark>,
    /// MPS neural ops; `None` until `with_mps_ops` is called.
    mps_ops: Option<MPSNeuralOps>,
}
#[cfg(all(target_os = "macos", feature = "metal"))]
impl MetalKernels {
    /// Creates a `MetalKernels` around a newly created [`MetalDevice`].
    ///
    /// Benchmarking and MPS neural ops start out disabled; enable them with
    /// [`Self::with_benchmarking`] and [`Self::with_mps_ops`].
    ///
    /// # Errors
    /// Propagates any error from [`MetalDevice::new`].
    pub fn new() -> Result<Self> {
        Ok(MetalKernels {
            device: MetalDevice::new()?,
            benchmark: None,
            mps_ops: None,
        })
    }

    /// Mutable access to the underlying device.
    pub fn device_mut(&mut self) -> &mut MetalDevice {
        &mut self.device
    }

    /// Shared access to the underlying device.
    pub fn device(&self) -> &MetalDevice {
        &self.device
    }

    /// Builder step that attaches a newly created [`MetalBenchmark`], replacing
    /// any previously attached one.
    ///
    /// # Errors
    /// Propagates any error from [`MetalBenchmark::new`]; `self` is consumed and
    /// dropped in that case.
    pub fn with_benchmarking(mut self) -> Result<Self> {
        self.benchmark = Some(MetalBenchmark::new()?);
        Ok(self)
    }

    /// Builder step that attaches newly created [`MPSNeuralOps`], replacing any
    /// previously attached ones.
    ///
    /// # Errors
    /// Propagates any error from [`MPSNeuralOps::new`]; `self` is consumed and
    /// dropped in that case.
    pub fn with_mps_ops(mut self) -> Result<Self> {
        self.mps_ops = Some(MPSNeuralOps::new()?);
        Ok(self)
    }

    /// Mutable access to the benchmark harness, or `None` if benchmarking was not enabled.
    pub fn benchmark_mut(&mut self) -> Option<&mut MetalBenchmark> {
        self.benchmark.as_mut()
    }

    /// Shared access to the benchmark harness, or `None` if benchmarking was not enabled.
    pub fn benchmark(&self) -> Option<&MetalBenchmark> {
        self.benchmark.as_ref()
    }

    /// Mutable access to the MPS neural ops, or `None` if they were not enabled.
    pub fn mps_ops_mut(&mut self) -> Option<&mut MPSNeuralOps> {
        self.mps_ops.as_mut()
    }

    /// Shared access to the MPS neural ops, or `None` if they were not enabled.
    pub fn mps_ops(&self) -> Option<&MPSNeuralOps> {
        self.mps_ops.as_ref()
    }

    /// Returns the capabilities reported by the underlying device.
    pub fn get_capabilities(&self) -> DeviceCapabilities {
        self.device.get_device_capabilities()
    }

    /// Runs the full benchmark suite through the attached harness.
    ///
    /// # Errors
    /// Returns an invalid-operation error if [`Self::with_benchmarking`] was never
    /// called; otherwise propagates errors from the benchmark run itself.
    pub fn run_benchmarks(&mut self) -> Result<Vec<BenchmarkResult>> {
        match self.benchmark.as_mut() {
            Some(benchmark) => benchmark.run_comprehensive_benchmarks(),
            None => Self::benchmarking_not_initialized(),
        }
    }

    /// Produces the textual report from the attached benchmark harness.
    ///
    /// # Errors
    /// Returns an invalid-operation error if [`Self::with_benchmarking`] was never called.
    pub fn generate_performance_report(&self) -> Result<String> {
        match self.benchmark.as_ref() {
            Some(benchmark) => Ok(benchmark.generate_report()),
            None => Self::benchmarking_not_initialized(),
        }
    }

    /// Single source of the error returned by benchmark-dependent methods when
    /// no harness is attached, so the message cannot drift between call sites.
    fn benchmarking_not_initialized<T>() -> Result<T> {
        Err(crate::TensorError::invalid_operation_simple(
            "Benchmarking not initialized. Call with_benchmarking() first".to_string(),
        ))
    }
}
/// Placeholder used when Metal support is not compiled in (non-macOS target or
/// the `metal` feature disabled); constructing it always fails.
#[cfg(not(all(target_os = "macos", feature = "metal")))]
#[derive(Debug)]
pub struct MetalKernels;
#[cfg(not(all(target_os = "macos", feature = "metal")))]
impl MetalKernels {
pub fn new() -> Result<Self> {
metal_stub::metal_not_available()?;
Ok(MetalKernels)
}
}
/// Convenience constructor equivalent to [`MetalKernels::new`].
///
/// # Errors
/// Propagates whatever error [`MetalKernels::new`] returns (on builds without
/// Metal support this always fails).
pub fn create_metal_kernels() -> Result<MetalKernels> {
    let kernels = MetalKernels::new()?;
    Ok(kernels)
}
/// Creates [`MetalKernels`] with every optional component enabled.
///
/// On macOS with the `metal` feature this also attaches benchmarking and MPS
/// neural ops; on other builds it behaves like [`create_metal_kernels`].
///
/// # Errors
/// Propagates the first error from device, benchmark, or MPS initialization.
pub fn create_metal_kernels_full() -> Result<MetalKernels> {
    let kernels = MetalKernels::new()?;
    // The optional components only exist when Metal support is compiled in.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    let kernels = kernels.with_benchmarking()?.with_mps_ops()?;
    Ok(kernels)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg(all(target_os = "macos", feature = "metal"))]
    fn test_metal_kernels_creation() {
        let result = MetalKernels::new();
        // Hosts without a usable Metal device may fail, but only with that error.
        assert!(result.is_ok() || result.unwrap_err().to_string().contains("No Metal device"));
    }

    #[test]
    #[cfg(all(target_os = "macos", feature = "metal"))]
    fn test_metal_kernels_with_features() {
        // The builder methods take `self` by value, so the binding needs no `mut`
        // (the previous `mut` triggered the `unused_mut` warning).
        if let Ok(kernels) = MetalKernels::new() {
            let result = kernels.with_benchmarking().and_then(|k| k.with_mps_ops());
            assert!(result.is_ok() || result.unwrap_err().to_string().contains("No Metal device"));
        }
    }

    #[test]
    #[cfg(all(target_os = "macos", feature = "metal"))]
    fn test_convenience_functions() {
        let basic = create_metal_kernels();
        let full = create_metal_kernels_full();
        // The full constructor is expected to succeed exactly when the basic one does.
        assert_eq!(basic.is_ok(), full.is_ok());
    }

    #[test]
    #[cfg(not(all(target_os = "macos", feature = "metal")))]
    fn test_metal_not_available() {
        let result = MetalKernels::new();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Metal kernels are only available on macOS"));
    }

    #[test]
    fn test_shader_source_inclusion() {
        assert!(!shaders::METAL_KERNELS_SOURCE.is_empty());
        assert!(shaders::METAL_KERNELS_SOURCE.contains("elementwise_add"));
        assert!(shaders::METAL_KERNELS_SOURCE.contains("matrix_multiply_naive"));
    }

    #[test]
    #[cfg(all(target_os = "macos", feature = "metal"))]
    fn test_device_capabilities() {
        if let Ok(kernels) = MetalKernels::new() {
            let capabilities = kernels.get_capabilities();
            assert!(capabilities.max_threads_per_threadgroup > 0);
            assert!(capabilities.compute_units > 0);
            assert!(capabilities.memory_bandwidth_gbps > 0.0);
        }
    }
}