use metal::{CommandQueue, Device, MTLResourceOptions};
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::residency::{macos_15_or_newer, residency_disabled_by_env, ResidencySet};
/// Handle to a Metal device together with the command queue created from it.
///
/// When residency sets are enabled, `residency_set` holds the set that
/// `MlxDevice::new` registered with `queue`; otherwise it is `None`.
pub struct MlxDevice {
    /// The underlying Metal device.
    device: Device,
    /// Command queue created from `device`.
    queue: CommandQueue,
    /// Residency set shared with the buffers and command encoders this device
    /// creates, or `None` when residency sets are disabled or unavailable.
    residency_set: Option<ResidencySet>,
}
// Compile-time check; presumably asserts `MlxDevice: Send + Sync` — confirm
// against the macro definition.
crate::static_assertions_send_sync!(MlxDevice);
impl MlxDevice {
pub fn new() -> Result<Self> {
let device = Device::system_default().ok_or(MlxError::DeviceNotFound)?;
let queue = device.new_command_queue();
let log_init = std::env::var("MLX_NATIVE_LOG_INIT").as_deref() == Ok("1");
let residency_set = if residency_disabled_by_env() {
if log_init {
eprintln!("[mlx-native] residency sets = false (reason: HF2Q_NO_RESIDENCY=1)");
}
None
} else if !macos_15_or_newer() {
if log_init {
eprintln!("[mlx-native] residency sets = false (reason: macOS < 15.0)");
}
None
} else {
let set = ResidencySet::new(&device)?;
if set.is_noop() {
if log_init {
eprintln!("[mlx-native] residency sets = false (reason: macOS < 15.0)");
}
None
} else {
set.register_with_queue(&queue);
if log_init {
eprintln!("[mlx-native] residency sets = true");
}
Some(set)
}
};
Ok(Self {
device,
queue,
residency_set,
})
}
pub fn command_encoder(&self) -> Result<CommandEncoder> {
CommandEncoder::new_with_residency(&self.queue, self.residency_set.clone())
}
pub fn alloc_buffer(
&self,
byte_len: usize,
dtype: DType,
shape: Vec<usize>,
) -> Result<MlxBuffer> {
if byte_len == 0 {
return Err(MlxError::InvalidArgument(
"Buffer byte length must be > 0".into(),
));
}
let metal_buf = self
.device
.new_buffer(byte_len as u64, MTLResourceOptions::StorageModeShared);
if metal_buf.contents().is_null() {
return Err(MlxError::BufferAllocationError { bytes: byte_len });
}
match self.residency_set.as_ref() {
Some(set) => Ok(MlxBuffer::with_residency(
metal_buf,
dtype,
shape,
set.clone(),
)),
None => Ok(MlxBuffer::from_raw(metal_buf, dtype, shape)),
}
}
#[inline]
pub fn metal_device(&self) -> &metal::DeviceRef {
&self.device
}
#[inline]
pub fn metal_queue(&self) -> &CommandQueue {
&self.queue
}
#[inline]
pub(crate) fn residency_set(&self) -> Option<&ResidencySet> {
self.residency_set.as_ref()
}
#[inline]
pub fn residency_sets_enabled(&self) -> bool {
self.residency_set.is_some()
}
pub fn name(&self) -> String {
self.device.name().to_string()
}
}
impl std::fmt::Debug for MlxDevice {
    /// Show the device name and whether residency sets are active.
    ///
    /// The raw Metal device and queue handles are intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MlxDevice")
            .field("name", &self.device.name())
            .field("residency_sets", &self.residency_set.is_some())
            .finish()
    }
}