use std::sync::{Arc, Mutex};

#[cfg(feature = "log")]
use log;
use metal::{CommandBuffer, CommandQueue, Device, DeviceRef, MTLGPUFamily, MTLResourceOptions};

use crate::metal::error::{MetalError, Result};
/// Handle to a Metal device, a command queue created for it, and a cached
/// [`DeviceInfo`] snapshot of its properties.
///
/// Cloning shares the command queue: every clone holds the same
/// `Arc<Mutex<CommandQueue>>`.
#[derive(Clone)]
pub struct MetalDevice {
    /// Underlying Metal device handle.
    device: Device,
    /// Command queue shared (via `Arc`) by all clones; the `Mutex` serializes
    /// command-buffer creation in `new_command_buffer`.
    command_queue: Arc<Mutex<CommandQueue>>,
    /// Properties captured once in `with_device`.
    info: DeviceInfo,
}
/// Snapshot of device properties captured when a [`MetalDevice`] is constructed.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared (e.g. in tests or when
/// checking whether two handles describe the same hardware configuration).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceInfo {
    /// Device name as reported by Metal.
    pub name: String,
    /// Whether Metal reports the device as low-power.
    pub is_low_power: bool,
    /// Whether Metal reports the device as removable.
    pub is_removable: bool,
    /// Whether Metal reports unified memory for the device.
    pub has_unified_memory: bool,
    /// Maximum threadgroup memory length, in bytes.
    pub max_threadgroup_memory: u64,
    /// Width component only of the device's maximum threads-per-threadgroup size.
    pub max_threads_per_threadgroup: usize,
}
impl MetalDevice {
pub fn new() -> Result<Self> {
Self::with_device(
Device::system_default()
.ok_or_else(|| MetalError::InvalidArgument("No Metal device found".to_string()))?,
)
}
pub fn with_device(device: Device) -> Result<Self> {
let command_queue = device.new_command_queue();
let info = DeviceInfo {
name: device.name().to_string(),
is_low_power: device.is_low_power(),
is_removable: device.is_removable(),
has_unified_memory: device.has_unified_memory(),
max_threadgroup_memory: device.max_threadgroup_memory_length(),
max_threads_per_threadgroup: device.max_threads_per_threadgroup().width as usize,
};
#[cfg(feature = "log")]
{
log::info!("Initialized Metal device: {}", info.name);
log::info!("Unified memory: {}", info.has_unified_memory);
log::info!(
"Max threadgroup memory: {} bytes",
info.max_threadgroup_memory
);
}
Ok(Self {
device,
command_queue: Arc::new(Mutex::new(command_queue)),
info,
})
}
pub fn device(&self) -> &DeviceRef {
&self.device
}
pub fn device_ref(&self) -> &Device {
&self.device
}
pub fn info(&self) -> &DeviceInfo {
&self.info
}
pub fn name(&self) -> &str {
&self.info.name
}
pub fn max_buffer_length(&self) -> usize {
self.device.max_buffer_length() as usize
}
pub fn max_threads_per_threadgroup(&self) -> (usize, usize, usize) {
let size = self.device.max_threads_per_threadgroup();
(
size.width as usize,
size.height as usize,
size.depth as usize,
)
}
pub fn new_command_buffer(&self) -> Result<CommandBuffer> {
let queue = self.command_queue.lock().map_err(|e| {
MetalError::InvalidArgument(format!("Failed to lock command queue: {}", e))
})?;
let buffer = queue.new_command_buffer();
Ok(buffer.to_owned())
}
pub fn synchronize(&self) -> Result<()> {
let buffer = self.new_command_buffer()?;
buffer.commit();
buffer.wait_until_completed();
Ok(())
}
pub fn supports_feature(&self, feature: MetalFeature) -> bool {
match feature {
MetalFeature::Float16 => true, MetalFeature::RayTracing => self.device.supports_raytracing(),
MetalFeature::MemorylessRenderTargets => true,
MetalFeature::TileShading => true,
}
}
pub fn resource_options(&self) -> MTLResourceOptions {
if self.info.has_unified_memory {
MTLResourceOptions::StorageModeShared
} else {
MTLResourceOptions::StorageModeManaged
}
}
}
/// Optional device capabilities that can be queried with
/// [`MetalDevice::supports_feature`].
///
/// Fieldless, so it is `Copy` and can be compared, hashed and debug-printed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetalFeature {
    /// 16-bit floating-point support.
    Float16,
    /// Ray tracing support.
    RayTracing,
    /// Memoryless render targets.
    MemorylessRenderTargets,
    /// Tile shading.
    TileShading,
}
impl std::fmt::Debug for MetalDevice {
    /// Writes a compact summary containing only the device name and its
    /// unified-memory flag; the raw device and queue handles are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let info = &self.info;
        let mut summary = f.debug_struct("MetalDevice");
        summary.field("name", &info.name);
        summary.field("unified_memory", &info.has_unified_memory);
        summary.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction succeeds whenever the system reports a default Metal device,
    /// and the cached name is populated. Skips on machines without Metal.
    #[test]
    fn test_device_creation() {
        let Some(_) = Device::system_default() else {
            return;
        };
        // `expect` (rather than `assert!(is_ok())`) surfaces the actual error on failure.
        let device = MetalDevice::new()
            .expect("a system default Metal device exists, so MetalDevice::new should succeed");
        assert!(!device.info().name.is_empty());
    }

    /// An empty round-trip through the command queue completes, and the
    /// unconditionally-supported feature is reported as such.
    #[test]
    fn test_synchronize_and_float16() {
        let Some(_) = Device::system_default() else {
            return;
        };
        let device = MetalDevice::new()
            .expect("a system default Metal device exists, so MetalDevice::new should succeed");
        device
            .synchronize()
            .expect("synchronizing an idle queue should succeed");
        assert!(device.supports_feature(MetalFeature::Float16));
    }
}