mlx_native/device.rs
1//! [`MlxDevice`] — Metal device and command queue wrapper.
2//!
3//! This is the entry-point for all GPU work. Create one with
4//! [`MlxDevice::new()`] and use it to allocate buffers and create
5//! command encoders.
6
7use metal::{CommandQueue, Device, MTLResourceOptions};
8
9use crate::buffer::MlxBuffer;
10use crate::dtypes::DType;
11use crate::encoder::CommandEncoder;
12use crate::error::{MlxError, Result};
13use crate::residency::{macos_15_or_newer, residency_disabled_by_env, ResidencySet};
14
15/// Wraps a Metal device and its command queue.
16///
17/// # Thread Safety
18///
19/// `MlxDevice` is `Send + Sync` — you can share it across threads. The
20/// underlying Metal device and command queue are thread-safe on Apple Silicon.
21pub struct MlxDevice {
22 device: Device,
23 queue: CommandQueue,
24 residency_set: Option<ResidencySet>,
25}
26
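// `metal::Device` and `metal::CommandQueue` are both `Send + Sync`.
// NOTE(review): `MlxDevice` also holds an `Option<ResidencySet>`, which
// presumably has to be `Send + Sync` as well for the assertion below to
// pass — confirm in `crate::residency`.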
28crate::static_assertions_send_sync!(MlxDevice);
29
30impl MlxDevice {
31 /// Initialize the Metal GPU device and create a command queue.
32 ///
33 /// Returns `Err(MlxError::DeviceNotFound)` if no Metal device is available
34 /// (e.g. running on a non-Apple-Silicon machine or in a headless Linux VM).
35 pub fn new() -> Result<Self> {
36 let device = Device::system_default().ok_or(MlxError::DeviceNotFound)?;
37 let queue = device.new_command_queue();
38 let log_init = std::env::var("MLX_NATIVE_LOG_INIT").as_deref() == Ok("1");
39
40 let residency_set = if residency_disabled_by_env() {
41 if log_init {
42 eprintln!("[mlx-native] residency sets = false (reason: HF2Q_NO_RESIDENCY=1)");
43 }
44 None
45 } else if !macos_15_or_newer() {
46 if log_init {
47 eprintln!("[mlx-native] residency sets = false (reason: macOS < 15.0)");
48 }
49 None
50 } else {
51 let set = ResidencySet::new(&device)?;
52 if set.is_noop() {
53 if log_init {
54 eprintln!("[mlx-native] residency sets = false (reason: macOS < 15.0)");
55 }
56 None
57 } else {
58 set.register_with_queue(&queue);
59 if log_init {
60 eprintln!("[mlx-native] residency sets = true");
61 }
62 Some(set)
63 }
64 };
65
66 Ok(Self {
67 device,
68 queue,
69 residency_set,
70 })
71 }
72
73 /// Create a [`CommandEncoder`] for batching GPU dispatches.
74 ///
75 /// The encoder wraps a fresh Metal command buffer from the device's command
76 /// queue. Encode one or more kernel dispatches, then call
77 /// [`CommandEncoder::commit_and_wait`] to submit and block until completion.
78 ///
79 /// ADR-015 iter8e (Phase 3b): the encoder is bound to the device's
80 /// residency set so every `commit*` boundary flushes deferred
81 /// add/remove staging (one `[set commit]` per CB submission instead
82 /// of per-allocation). When residency sets are disabled
83 /// (HF2Q_NO_RESIDENCY=1, macOS<15) the binding is `None` and the
84 /// flush is a no-op.
85 pub fn command_encoder(&self) -> Result<CommandEncoder> {
86 CommandEncoder::new_with_residency(&self.queue, self.residency_set.clone())
87 }
88
89 /// Allocate a new GPU buffer with `StorageModeShared`.
90 ///
91 /// # Arguments
92 ///
93 /// * `byte_len` — Size of the buffer in bytes. Must be > 0.
94 /// * `dtype` — Element data type for metadata tracking.
95 /// * `shape` — Tensor dimensions for metadata tracking.
96 ///
97 /// # Errors
98 ///
99 /// Returns `MlxError::InvalidArgument` if `byte_len` is zero.
100 /// Returns `MlxError::BufferAllocationError` if Metal cannot allocate.
101 pub fn alloc_buffer(
102 &self,
103 byte_len: usize,
104 dtype: DType,
105 shape: Vec<usize>,
106 ) -> Result<MlxBuffer> {
107 if byte_len == 0 {
108 return Err(MlxError::InvalidArgument(
109 "Buffer byte length must be > 0".into(),
110 ));
111 }
112 let metal_buf = self
113 .device
114 .new_buffer(byte_len as u64, MTLResourceOptions::StorageModeShared);
115 // Metal returns a non-null buffer on success; a null pointer indicates
116 // failure (typically out-of-memory).
117 if metal_buf.contents().is_null() {
118 return Err(MlxError::BufferAllocationError { bytes: byte_len });
119 }
120 // ADR-015 iter8e (Phase 3b): auto-register the new allocation with the
121 // device's residency set so it gets the MTLResidencySet hint on the
122 // next dispatch. The `with_residency` path stages the addAllocation
123 // but DEFERS the `[set commit]` to the next CommandEncoder::commit*
124 // boundary via flush_pending — mirrors llama.cpp's batch-add /
125 // single-commit pattern in ggml-metal-device.m:1378-1382.
126 //
127 // No-op when residency_set is None (HF2Q_NO_RESIDENCY=1, macOS<15,
128 // or no Metal device).
129 match self.residency_set.as_ref() {
130 Some(set) => Ok(MlxBuffer::with_residency(
131 metal_buf,
132 dtype,
133 shape,
134 set.clone(),
135 )),
136 None => Ok(MlxBuffer::from_raw(metal_buf, dtype, shape)),
137 }
138 }
139
140 /// Borrow the underlying `metal::Device` for direct Metal API calls
141 /// (e.g. kernel compilation in [`KernelRegistry`](crate::KernelRegistry)).
142 #[inline]
143 pub fn metal_device(&self) -> &metal::DeviceRef {
144 &self.device
145 }
146
147 /// Borrow the underlying `metal::CommandQueue`.
148 #[inline]
149 pub fn metal_queue(&self) -> &CommandQueue {
150 &self.queue
151 }
152
153 /// Borrow the device-level residency set, if residency support is enabled.
154 #[inline]
155 pub(crate) fn residency_set(&self) -> Option<&ResidencySet> {
156 self.residency_set.as_ref()
157 }
158
159 /// Return whether this device has an active Metal residency set.
160 #[inline]
161 pub fn residency_sets_enabled(&self) -> bool {
162 self.residency_set.is_some()
163 }
164
165 /// Human-readable name of the GPU (e.g. "Apple M2 Max").
166 pub fn name(&self) -> String {
167 self.device.name().to_string()
168 }
169}
170
171impl std::fmt::Debug for MlxDevice {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 f.debug_struct("MlxDevice")
174 .field("name", &self.device.name())
175 .finish()
176 }
177}