use metal::{Device, CommandQueue, MTLResourceOptions};
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
/// A Metal device handle paired with a command queue created on that device.
///
/// Instances are built with [`MlxDevice::new`], which selects the system default
/// device. Buffers are allocated through [`MlxDevice::alloc_buffer`], and command
/// encoders are created on the owned queue through [`MlxDevice::command_encoder`].
pub struct MlxDevice {
    /// The underlying Metal device.
    device: Device,
    /// Command queue created from `device`; every `CommandEncoder` produced by this
    /// type is built on this queue.
    queue: CommandQueue,
}
// NOTE(review): presumably a crate-level compile-time assertion that `MlxDevice`
// is `Send + Sync` (macro defined elsewhere in the crate) — confirm against its
// definition.
crate::static_assertions_send_sync!(MlxDevice);
impl MlxDevice {
    /// Opens the system default Metal device and creates a command queue on it.
    ///
    /// # Errors
    ///
    /// Returns [`MlxError::DeviceNotFound`] when `Device::system_default()` yields
    /// no device.
    pub fn new() -> Result<Self> {
        let device = Device::system_default().ok_or(MlxError::DeviceNotFound)?;
        let queue = device.new_command_queue();
        Ok(Self { device, queue })
    }

    /// Creates a [`CommandEncoder`] bound to this device's command queue.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `CommandEncoder::new`.
    pub fn command_encoder(&self) -> Result<CommandEncoder> {
        CommandEncoder::new(&self.queue)
    }

    /// Allocates a `byte_len`-byte Metal buffer in `StorageModeShared` and wraps it
    /// in an [`MlxBuffer`] tagged with `dtype` and `shape`.
    ///
    /// `dtype` and `shape` are forwarded unchanged to `MlxBuffer::from_raw`. This
    /// method does not check that they are consistent with `byte_len`.
    ///
    /// # Errors
    ///
    /// - [`MlxError::InvalidArgument`] if `byte_len` is zero.
    /// - [`MlxError::BufferAllocationError`] if `byte_len` exceeds the device's
    ///   maximum buffer length, or if the allocated buffer reports a null
    ///   contents pointer.
    pub fn alloc_buffer(
        &self,
        byte_len: usize,
        dtype: DType,
        shape: Vec<usize>,
    ) -> Result<MlxBuffer> {
        if byte_len == 0 {
            return Err(MlxError::InvalidArgument(
                "Buffer byte length must be > 0".into(),
            ));
        }

        // Lossless widening: `usize` is never wider than 64 bits on the targets
        // Metal runs on.
        let requested = byte_len as u64;

        // Metal cannot create a single buffer larger than `maxBufferLength`.
        // Reject such requests up front with a clear error instead of asking Metal
        // for a buffer it will refuse to create.
        if requested > self.device.max_buffer_length() {
            return Err(MlxError::BufferAllocationError { bytes: byte_len });
        }

        let metal_buf = self
            .device
            .new_buffer(requested, MTLResourceOptions::StorageModeShared);

        // Fallback for allocation failures that the size check above cannot
        // predict (e.g. memory pressure).
        if metal_buf.contents().is_null() {
            return Err(MlxError::BufferAllocationError { bytes: byte_len });
        }

        Ok(MlxBuffer::from_raw(metal_buf, dtype, shape))
    }

    /// Borrows the underlying Metal device.
    #[inline]
    pub fn metal_device(&self) -> &metal::DeviceRef {
        &self.device
    }

    /// Borrows the command queue owned by this device.
    #[inline]
    pub fn metal_queue(&self) -> &CommandQueue {
        &self.queue
    }

    /// Returns the Metal device's name as an owned `String`.
    pub fn name(&self) -> String {
        self.device.name().to_string()
    }
}
impl std::fmt::Debug for MlxDevice {
    /// Renders the device as `MlxDevice { name: "<device name>" }`.
    ///
    /// Only the device name is shown; the command queue is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let device_name = self.device.name();
        let mut builder = f.debug_struct("MlxDevice");
        builder.field("name", &device_name);
        builder.finish()
    }
}