use crate::error::{HiveGpuError, Result};
use crate::traits::{GpuBackend, GpuContext};
use crate::types::{GpuCapabilities, GpuDeviceInfo, GpuMemoryStats};
use objc2::rc::Retained;
use objc2::runtime::ProtocolObject;
use objc2_metal::{
MTLCommandQueue, MTLCompileOptions, MTLComputePipelineState, MTLCreateSystemDefaultDevice,
MTLDevice, MTLGPUFamily, MTLLibrary, MTLSize,
};
use std::process::Command;
use std::sync::Arc;
use tracing::{debug, warn};
/// Native Metal GPU context: owns the default device, a command queue, and
/// the compiled compute-shader library built from the bundled `.metal` source.
///
/// `Clone` is cheap: each field is a `Retained` (reference-counted
/// Objective-C) handle, so cloning only bumps refcounts.
#[cfg(all(target_os = "macos", feature = "metal-native"))]
#[derive(Debug, Clone)]
pub struct MetalNativeContext {
    // The system default Metal device (GPU).
    device: Retained<ProtocolObject<dyn MTLDevice>>,
    // Queue on which command buffers for this context are scheduled.
    command_queue: Retained<ProtocolObject<dyn MTLCommandQueue>>,
    // Library holding the compiled HNSW compute kernels (see `load_metal_library`).
    library: Retained<ProtocolObject<dyn MTLLibrary>>,
}
#[cfg(all(target_os = "macos", feature = "metal-native"))]
impl MetalNativeContext {
    /// Creates a context on the system default Metal device: acquires the
    /// device, builds a command queue, and compiles the bundled shaders.
    ///
    /// # Errors
    /// * [`HiveGpuError::NoDeviceAvailable`] — no Metal device on this machine.
    /// * [`HiveGpuError::Other`] — the device refused to create a command queue.
    /// * [`HiveGpuError::ShaderCompilationFailed`] — the embedded shader
    ///   source failed to compile (see [`Self::load_metal_library`]).
    pub fn new() -> Result<Self> {
        // `ok_or` (not `ok_or_else`): the error is a unit variant, so there
        // is nothing to defer (clippy: unnecessary_lazy_evaluations).
        let device = MTLCreateSystemDefaultDevice().ok_or(HiveGpuError::NoDeviceAvailable)?;
        let command_queue = device
            .newCommandQueue()
            .ok_or_else(|| HiveGpuError::Other("Failed to create command queue".to_string()))?;
        let library = Self::load_metal_library(&device)?;
        debug!("✅ Metal native context created: {}", device.name());
        Ok(Self {
            device,
            command_queue,
            library,
        })
    }

    /// Borrows the underlying Metal device.
    pub fn device(&self) -> &ProtocolObject<dyn MTLDevice> {
        &self.device
    }

    /// Borrows the shared command queue.
    pub fn command_queue(&self) -> &ProtocolObject<dyn MTLCommandQueue> {
        &self.command_queue
    }

    /// Human-readable device name as reported by Metal (e.g. "Apple M2").
    pub fn device_name(&self) -> String {
        self.device.name().to_string()
    }

    /// Whether the device belongs to GPU family Apple7 or newer — used here
    /// as the capability gate for MPS-style workloads.
    pub fn supports_mps(&self) -> bool {
        // SAFETY: `supportsFamily` is a read-only query on a valid device
        // handle, passed a well-formed `MTLGPUFamily` constant.
        unsafe { self.device.supportsFamily(MTLGPUFamily::Apple7) }
    }

    /// Maximum threads per threadgroup supported by the device.
    pub fn max_threadgroup_size(&self) -> MTLSize {
        self.device.maxThreadsPerThreadgroup()
    }

    /// Maximum single-buffer size this backend will allocate, in bytes.
    ///
    /// Conservative fixed cap of 1 GiB rather than a device query;
    /// NOTE(review): the real per-device limit is `MTLDevice::maxBufferLength`
    /// — consider querying it instead. TODO confirm with callers.
    pub fn max_buffer_size(&self) -> u64 {
        1024 * 1024 * 1024
    }

    /// Borrows the compiled shader library.
    pub fn library(&self) -> &ProtocolObject<dyn MTLLibrary> {
        &self.library
    }

    /// Builds a compute pipeline state for `function_name`, which must be a
    /// kernel defined in the compiled library.
    ///
    /// # Errors
    /// [`HiveGpuError::ShaderCompilationFailed`] when the function is absent
    /// from the library or pipeline creation fails.
    pub fn compute_pipeline(
        &self,
        function_name: &str,
    ) -> Result<Retained<ProtocolObject<dyn MTLComputePipelineState>>> {
        let ns_name = objc2_foundation::NSString::from_str(function_name);
        let function = self.library.newFunctionWithName(&ns_name).ok_or_else(|| {
            HiveGpuError::ShaderCompilationFailed(format!(
                "Metal function '{function_name}' not found in compiled library",
            ))
        })?;
        let pipeline = self
            .device
            .newComputePipelineStateWithFunction_error(&function)
            .map_err(|e| {
                HiveGpuError::ShaderCompilationFailed(format!(
                    "Failed to build pipeline for '{function_name}': {e:?}",
                ))
            })?;
        Ok(pipeline)
    }

    /// Compiles the HNSW shader source embedded at build time
    /// (`shaders/metal_hnsw.metal`) into a Metal library on `device`.
    ///
    /// # Errors
    /// [`HiveGpuError::ShaderCompilationFailed`] with the compiler diagnostics
    /// when the source does not compile.
    fn load_metal_library(
        device: &ProtocolObject<dyn MTLDevice>,
    ) -> Result<Retained<ProtocolObject<dyn MTLLibrary>>> {
        let shader_source = include_str!("../shaders/metal_hnsw.metal");
        let ns_source = objc2_foundation::NSString::from_str(shader_source);
        let options = MTLCompileOptions::new();
        // SAFETY: `ns_source` and `options` are valid, live Objective-C
        // objects for the duration of the call; the method only reads them.
        let library = unsafe {
            device
                .newLibraryWithSource_options_error(&ns_source, Some(&options))
                .map_err(|e| {
                    HiveGpuError::ShaderCompilationFailed(format!(
                        "Failed to compile Metal shaders: {e:?}"
                    ))
                })?
        };
        debug!(
            "✅ Metal library loaded with {} functions",
            library.functionNames().count()
        );
        Ok(library)
    }

    /// Returns the host OS version by shelling out to `sw_vers`, e.g.
    /// "macOS 14.5". Falls back to "macOS Unknown" (with a warning) when the
    /// command runs but exits unsuccessfully.
    ///
    /// # Errors
    /// [`HiveGpuError::Other`] only when `sw_vers` cannot be spawned at all.
    fn get_macos_version() -> Result<String> {
        let output = Command::new("sw_vers")
            .arg("-productVersion")
            .output()
            .map_err(|e| {
                HiveGpuError::Other(format!("Failed to execute sw_vers command: {e}"))
            })?;
        if output.status.success() {
            let version = String::from_utf8_lossy(&output.stdout).trim().to_string();
            Ok(format!("macOS {version}"))
        } else {
            warn!("Failed to get macOS version, using fallback");
            Ok("macOS Unknown".to_string())
        }
    }
}
// cfg gate added: `MetalNativeContext` itself is gated on macOS +
// "metal-native", so this impl must be too or the file fails to compile
// whenever the type is absent.
#[cfg(all(target_os = "macos", feature = "metal-native"))]
impl GpuBackend for MetalNativeContext {
    /// Snapshot of device identity and memory figures.
    ///
    /// Apple GPUs use unified memory, so "VRAM" here is the device's
    /// recommended maximum working set rather than dedicated video memory.
    fn device_info(&self) -> GpuDeviceInfo {
        let recommended_max = self.device.recommendedMaxWorkingSetSize();
        let current_allocated = self.device.currentAllocatedSize() as u64;
        // Never report negative availability if allocation exceeds the
        // recommended working set.
        let available = recommended_max.saturating_sub(current_allocated);
        let driver_version =
            Self::get_macos_version().unwrap_or_else(|_| "macOS Unknown".to_string());
        let max_threads = self.device.maxThreadsPerThreadgroup();
        let max_threads_per_block =
            (max_threads.width * max_threads.height * max_threads.depth) as u32;
        // 32 KiB threadgroup memory — NOTE(review): assumed for Apple-family
        // GPUs; confirm against the target device family.
        let max_shared_memory = 32 * 1024;
        GpuDeviceInfo {
            name: self.device_name(),
            backend: "Metal".to_string(),
            total_vram_bytes: recommended_max,
            available_vram_bytes: available,
            used_vram_bytes: current_allocated,
            driver_version,
            compute_capability: None, // CUDA-specific; not applicable to Metal.
            max_threads_per_block,
            max_shared_memory_per_block: max_shared_memory,
            device_id: 0,      // Single default device.
            pci_bus_id: None,  // Not exposed by Metal.
        }
    }

    /// Static capability flags for this backend.
    fn supports_operations(&self) -> GpuCapabilities {
        GpuCapabilities {
            supports_hnsw: true,
            supports_batch: true,
            max_dimension: 512,
            max_batch_size: 10000,
        }
    }

    /// Backend-level memory statistics.
    ///
    /// NOTE(review): allocation tracking is not wired up yet — allocated
    /// bytes, utilization, and buffer count are reported as zero.
    fn memory_stats(&self) -> GpuMemoryStats {
        GpuMemoryStats {
            total_allocated: 0,
            available: self.max_buffer_size() as usize,
            utilization: 0.0,
            buffer_count: 0,
        }
    }
}
// cfg gate added: must match the gate on `MetalNativeContext` itself so the
// crate still compiles on non-macOS targets / with the feature disabled.
#[cfg(all(target_os = "macos", feature = "metal-native"))]
impl GpuContext for MetalNativeContext {
    /// Creates HNSW vector storage for `dimension`-sized vectors compared
    /// with `metric`. The storage holds its own (cheap, refcounted) clone of
    /// this context.
    fn create_storage(
        &self,
        dimension: usize,
        metric: crate::types::GpuDistanceMetric,
    ) -> Result<Box<dyn crate::traits::GpuVectorStorage>> {
        use crate::metal::vector_storage::MetalNativeVectorStorage;
        let storage = MetalNativeVectorStorage::new(Arc::new(self.clone()), dimension, metric)?;
        Ok(Box::new(storage))
    }

    /// Creates storage with an explicit HNSW configuration.
    ///
    /// # Errors
    /// Always returns [`HiveGpuError::Other`]: custom-config storage is not
    /// implemented for the Metal backend yet. Parameters are underscore-
    /// prefixed until then to avoid unused-variable warnings.
    fn create_storage_with_config(
        &self,
        _dimension: usize,
        _metric: crate::types::GpuDistanceMetric,
        _config: crate::types::HnswConfig,
    ) -> Result<Box<dyn crate::traits::GpuVectorStorage>> {
        Err(HiveGpuError::Other("Not implemented yet".to_string()))
    }

    /// Delegates to the [`GpuBackend`] implementation.
    fn memory_stats(&self) -> GpuMemoryStats {
        GpuBackend::memory_stats(self)
    }

    /// Delegates to the [`GpuBackend`] implementation; infallible here.
    fn device_info(&self) -> Result<GpuDeviceInfo> {
        Ok(GpuBackend::device_info(self))
    }
}