use crate::{DType, Result};
#[cfg(feature = "ug")]
use candle_metal_kernels::metal::ComputePipeline;
use candle_metal_kernels::{
metal::{
BlitCommandEncoder, Buffer, BufferMap, Commands, ComputeCommandEncoder, Device,
MTLResourceOptions,
},
Kernels,
};
use objc2_foundation::NSURL;
use objc2_metal::{MTLCaptureDescriptor, MTLCaptureDestination, MTLCaptureManager};
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock};
use super::MetalError;
/// Process-unique identifier for a metal device instance.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);

impl DeviceId {
    /// Hands out a fresh identifier from a global atomic counter.
    pub(crate) fn new() -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};
        // Start at 1 so that 0 never collides with a live device id.
        static COUNTER: AtomicUsize = AtomicUsize::new(1);
        DeviceId(COUNTER.fetch_add(1, Ordering::Relaxed))
    }
}
/// A Metal GPU device handle together with the shared state (command
/// stream, buffer pools, kernel set, RNG seed) needed to run kernels.
///
/// Cloning shares the underlying state: all mutable fields live behind
/// `Arc`s.
#[derive(Clone)]
pub struct MetalDevice {
    /// Process-unique identifier handed out by `DeviceId::new`.
    pub(crate) id: DeviceId,
    /// The raw Metal device.
    pub(crate) device: Device,
    /// Command-stream state; compute/blit encoders and flushes go
    /// through this.
    pub(crate) commands: Arc<RwLock<Commands>>,
    /// Pool of shared-storage buffers, keyed by byte size.
    pub(crate) buffers: Arc<RwLock<BufferMap>>,
    /// Pool used by `new_buffer` (private storage, except shared on iOS —
    /// see `PRIVATE_RESOURCE_OPTIONS`).
    pub(crate) private_buffers: Arc<RwLock<BufferMap>>,
    /// Kernel set shared across clones (see `candle_metal_kernels::Kernels`).
    pub(crate) kernels: Arc<Kernels>,
    /// Buffer presumably holding the RNG seed read by random-number
    /// kernels — confirm against the kernel call sites.
    pub(crate) seed: Arc<Mutex<Buffer>>,
    /// Host-side copy of the current seed value.
    pub(crate) seed_value: Arc<RwLock<u64>>,
}
/// Options for host-visible allocations: shared storage, visible to both
/// CPU and GPU.
// NOTE(review): the value is rebuilt from `.bits()` — presumably to get a
// const-constructible value out of the bitflags type; confirm against the
// objc2_metal API.
pub const RESOURCE_OPTIONS: MTLResourceOptions =
    objc2_metal::MTLResourceOptions(MTLResourceOptions::StorageModeShared.bits());
/// On iOS, "private" allocations still use shared storage.
// NOTE(review): likely due to iOS unified memory / tooling constraints —
// confirm.
#[cfg(target_os = "ios")]
pub const PRIVATE_RESOURCE_OPTIONS: MTLResourceOptions = MTLResourceOptions::StorageModeShared;
/// Everywhere else, GPU-only allocations use private storage.
#[cfg(not(target_os = "ios"))]
pub const PRIVATE_RESOURCE_OPTIONS: MTLResourceOptions = MTLResourceOptions::StorageModePrivate;
impl std::fmt::Debug for MetalDevice {
    // Compact form: only the id is shown — the raw Metal handles have no
    // useful textual representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { id, .. } = self;
        write!(f, "MetalDevice({id:?})")
    }
}
// `Deref` to the raw metal `Device` so call sites can invoke device
// methods directly on a `MetalDevice`.
impl std::ops::Deref for MetalDevice {
    type Target = Device;

    fn deref(&self) -> &Self::Target {
        &self.device
    }
}
impl MetalDevice {
    /// Compiles a `candle_ug` SSA kernel to Metal source and builds a
    /// compute pipeline for it on this device.
    ///
    /// Only compiled in with the `ug` feature, and never on wasm32 or iOS.
    ///
    /// # Errors
    /// Fails if code generation emits invalid UTF-8, if the Metal source
    /// does not compile, or if the entry point cannot be found.
    #[cfg(all(feature = "ug", not(target_arch = "wasm32"), not(target_os = "ios")))]
    pub fn compile(
        &self,
        func_name: &'static str,
        kernel: candle_ug::lang::ssa::Kernel,
    ) -> Result<ComputePipeline> {
        // Generate the Metal source for the kernel into an in-memory buffer.
        let mut buf = vec![];
        candle_ug::metal::code_gen::gen(&mut buf, func_name, &kernel)?;
        let metal_code = String::from_utf8(buf)?;
        // Build a library from the source, look up the entry point by name,
        // then turn it into a pipeline state.
        let lib = self
            .device
            .new_library_with_source(&metal_code, None)
            .map_err(MetalError::from)?;
        let func = lib
            .get_function(func_name, None)
            .map_err(MetalError::from)?;
        let pl = self
            .device
            .new_compute_pipeline_state_with_function(&func)
            .map_err(MetalError::from)?;
        Ok(pl)
    }

    /// Process-unique identifier of this device instance.
    pub fn id(&self) -> DeviceId {
        self.id
    }

    /// Borrow the underlying raw Metal device.
    pub fn metal_device(&self) -> &Device {
        &self.device
    }

    /// Shrinks the shared buffer pool: keeps only sub-buffers still
    /// referenced outside the pool (strong count > 1) and drops the ones
    /// held solely by the pool, releasing their memory.
    // NOTE(review): only `self.buffers` is cleaned here; `private_buffers`
    // (used by `new_buffer`) is never shrunk — confirm this is intentional.
    fn drop_unused_buffers(&self) -> Result<()> {
        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
        for subbuffers in buffers.values_mut() {
            let newbuffers = subbuffers
                .iter()
                .filter(|s| Arc::strong_count(*s) > 1)
                .map(Arc::clone)
                .collect();
            *subbuffers = newbuffers;
        }
        Ok(())
    }

    /// Returns a compute command encoder from the shared command stream.
    ///
    /// When the command stream reports that a flush happened, unused
    /// pooled buffers are garbage-collected.
    pub fn command_encoder(&self) -> Result<ComputeCommandEncoder> {
        let commands = self.commands.write().map_err(MetalError::from)?;
        let (flush, command_encoder) = commands.command_encoder().map_err(MetalError::from)?;
        if flush {
            self.drop_unused_buffers()?
        }
        Ok(command_encoder)
    }

    /// Returns a blit command encoder from the shared command stream.
    ///
    /// Same flush / garbage-collection behavior as `command_encoder`.
    pub fn blit_command_encoder(&self) -> Result<BlitCommandEncoder> {
        let commands = self.commands.write().map_err(MetalError::from)?;
        let (flush, command_encoder) = commands.blit_command_encoder().map_err(MetalError::from)?;
        if flush {
            self.drop_unused_buffers()?
        }
        Ok(command_encoder)
    }

    /// Blocks until the previously submitted command buffers complete.
    pub fn wait_until_completed(&self) -> Result<()> {
        let commands = self.commands.write().map_err(MetalError::from)?;
        commands.wait_until_completed().map_err(MetalError::from)?;
        Ok(())
    }

    /// The kernel set shared by all clones of this device.
    pub fn kernels(&self) -> &Kernels {
        &self.kernels
    }

    /// Borrow the underlying raw Metal device (same as `metal_device`).
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Allocates (or reuses from the private pool) a buffer able to hold
    /// `element_count` elements of `dtype`.
    ///
    /// Sizes are rounded up to the next power of two so allocations fall
    /// into a small set of reusable buckets. Storage mode comes from
    /// `PRIVATE_RESOURCE_OPTIONS` (private, except shared on iOS).
    /// `_name` is currently unused.
    pub fn new_buffer(
        &self,
        element_count: usize,
        dtype: DType,
        _name: &str,
    ) -> Result<Arc<Buffer>> {
        let size = element_count * dtype.size_in_bytes();
        // Reuse a free pooled buffer from a big-enough bucket if possible.
        let mut buffers = self.private_buffers.write().map_err(MetalError::from)?;
        if let Some(b) = find_available_buffer(size, &buffers) {
            return Ok(b.clone());
        }
        // Otherwise allocate a fresh buffer and register it in its bucket.
        let size = buf_size(size);
        let subbuffers = buffers.entry(size).or_insert(vec![]);
        let new_buffer = self
            .device
            .new_buffer(size, PRIVATE_RESOURCE_OPTIONS)
            .map_err(MetalError::from)?;
        let new_buffer = Arc::new(new_buffer);
        subbuffers.push(new_buffer.clone());
        Ok(new_buffer)
    }

    /// Allocates an exact-size, unpooled buffer with the private storage
    /// options. `_name` is currently unused.
    pub fn new_private_buffer(
        &self,
        element_count: usize,
        dtype: DType,
        _name: &str,
    ) -> Result<Arc<Buffer>> {
        let size = element_count * dtype.size_in_bytes();
        let buffer = self
            .device
            .new_buffer(size, PRIVATE_RESOURCE_OPTIONS)
            .map_err(MetalError::from)?;
        Ok(Arc::new(buffer))
    }

    /// Creates a shared-storage buffer initialized with a copy of `data`,
    /// and registers it in the shared pool — keyed by its exact byte size,
    /// not a power-of-two bucket — so it can be reused once dropped.
    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
        let size = core::mem::size_of_val(data);
        let new_buffer = self
            .device
            .new_buffer_with_data(data.as_ptr().cast(), size, RESOURCE_OPTIONS)
            .map_err(MetalError::from)?;
        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
        let subbuffers = buffers.entry(size).or_insert(vec![]);
        let new_buffer = Arc::new(new_buffer);
        subbuffers.push(new_buffer.clone());
        Ok(new_buffer)
    }

    /// Allocates a pooled shared buffer of `size_in_bytes` and zero-fills
    /// it with a blit pass.
    ///
    /// The fill covers the buffer's full capacity (`buffer.length()`),
    /// which may exceed the requested size due to power-of-two bucketing.
    pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
        let buffer = self.allocate_buffer(size_in_bytes)?;
        let blit = self.blit_command_encoder()?;
        blit.set_label("zeros");
        blit.fill_buffer(&buffer, (0, buffer.length()), 0);
        blit.end_encoding();
        Ok(buffer)
    }

    /// Allocates (or reuses) a shared-storage buffer of at least `size`
    /// bytes from the shared pool, with power-of-two bucketing.
    pub fn allocate_buffer(&self, size: usize) -> Result<Arc<Buffer>> {
        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
        if let Some(b) = find_available_buffer(size, &buffers) {
            return Ok(b.clone());
        }
        let size = buf_size(size);
        let subbuffers = buffers.entry(size).or_insert(vec![]);
        let new_buffer = self
            .device
            .new_buffer(size, RESOURCE_OPTIONS)
            .map_err(MetalError::from)?;
        let new_buffer = Arc::new(new_buffer);
        subbuffers.push(new_buffer.clone());
        Ok(new_buffer)
    }

    /// Starts a Metal GPU capture that writes a GPU-trace document to
    /// `path`. Relative paths are resolved against the current working
    /// directory.
    pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        // NOTE(review): `sharedCaptureManager` is marked unsafe by the
        // objc2 bindings; presumably fine as a singleton accessor —
        // confirm the bindings' safety contract.
        let capture = unsafe { MTLCaptureManager::sharedCaptureManager() };
        let descriptor = MTLCaptureDescriptor::new();
        descriptor.setDestination(MTLCaptureDestination::GPUTraceDocument);
        descriptor.set_capture_device(self.device().as_ref());
        if path.as_ref().is_absolute() {
            let url = NSURL::from_file_path(path);
            descriptor.setOutputURL(url.as_deref());
        } else {
            // The output URL needs an absolute path.
            let path = std::env::current_dir()?.join(path);
            let url = NSURL::from_file_path(path);
            descriptor.setOutputURL(url.as_deref());
        }
        capture
            .startCaptureWithDescriptor_error(&descriptor)
            .map_err(|e| MetalError::from(e.to_string()))?;
        Ok(())
    }
}
/// Rounds a byte size up to the next power of two so pooled buffers fall
/// into a small number of reusable size buckets. A zero request maps to 1.
fn buf_size(size: usize) -> usize {
    if size.is_power_of_two() {
        size
    } else {
        size.next_power_of_two()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_buf_size_exact_powers_of_two() {
        // Powers of two must map to themselves (no over-allocation).
        for size in [1usize, 2, 4, 8, 16, 1024] {
            assert_eq!(buf_size(size), size);
        }
    }

    #[test]
    fn test_buf_size_rounds_up() {
        // Anything in between is rounded up to the next bucket.
        let cases = [
            (3usize, 4usize),
            (5, 8),
            (6, 8),
            (7, 8),
            (9, 16),
            (1000, 1024),
            (1025, 2048),
        ];
        for (input, expected) in cases {
            assert_eq!(buf_size(input), expected);
        }
    }

    #[test]
    fn test_buf_size_bf16_f16_scalar() {
        // A two-byte scalar (f16/bf16) lands in the two-byte bucket.
        assert_eq!(buf_size(2), 2);
    }
}
/// Looks for a reusable buffer in the pool: at least `size` bytes and not
/// referenced anywhere outside the pool map itself.
///
/// Among all big-enough buckets the tightest-fitting one wins; within a
/// bucket, the last free sub-buffer is taken.
fn find_available_buffer(size: usize, buffers: &BufferMap) -> Option<Arc<Buffer>> {
    let mut candidate: Option<(usize, Arc<Buffer>)> = None;
    for (&bucket_size, subbuffers) in buffers.iter() {
        // Only consider buckets that can hold the request and improve on
        // the best fit found so far.
        let improves = candidate
            .as_ref()
            .map_or(true, |(best_size, _)| bucket_size < *best_size);
        if bucket_size >= size && improves {
            // A sub-buffer is free when the pool holds the only strong ref.
            if let Some(free) = subbuffers.iter().rev().find(|b| Arc::strong_count(b) == 1) {
                candidate = Some((bucket_size, Arc::clone(free)));
            }
        }
    }
    candidate.map(|(_, buffer)| buffer)
}