// mlx_native/device.rs
1//! [`MlxDevice`] — Metal device and command queue wrapper.
2//!
3//! This is the entry-point for all GPU work. Create one with
4//! [`MlxDevice::new()`] and use it to allocate buffers and create
5//! command encoders.
6
7use metal::{CommandQueue, Device, MTLResourceOptions};
8
9use crate::buffer::MlxBuffer;
10use crate::dtypes::DType;
11use crate::encoder::CommandEncoder;
12use crate::encoder_session::EncoderSession;
13use crate::error::{MlxError, Result};
14use crate::residency::{macos_15_or_newer, residency_disabled_by_env, ResidencySet};
15
16/// Wraps a Metal device and its command queue.
17///
18/// # Thread Safety
19///
20/// `MlxDevice` is `Send + Sync` — you can share it across threads. The
21/// underlying Metal device and command queue are thread-safe on Apple Silicon.
22///
23/// `Clone` is a cheap Arc-bump on the underlying handles: `metal::Device`
24/// + `metal::CommandQueue` wrap NSObject Arc-pointers internally, and
25/// [`ResidencySet`] is `#[derive(Clone)]` over an `Arc<ResidencySetInner>`.
26/// Cloning yields a SECOND handle pointing at the SAME GPU device, command
27/// queue, and residency-set NSObject — multiple owners (e.g. an
28/// `AdamOptimizer` + a per-step `GpuTape`) can register allocations
29/// against the same residency set without double-create. ADR-020
30/// iter-13b dependency.
#[derive(Clone)]
pub struct MlxDevice {
    /// The system Metal GPU device handle (Arc-backed NSObject pointer,
    /// so `Clone` is a cheap refcount bump — see struct-level docs).
    device: Device,
    /// Command queue created from `device`; every command buffer / encoder
    /// this wrapper hands out is drawn from this queue.
    queue: CommandQueue,
    /// Device-level residency set. `None` when disabled at construction time
    /// (HF2Q_NO_RESIDENCY=1, macOS < 15, or a runtime no-op set).
    residency_set: Option<ResidencySet>,
}

// metal::Device and metal::CommandQueue are both Send + Sync.
crate::static_assertions_send_sync!(MlxDevice);
40
41impl MlxDevice {
42 /// Initialize the Metal GPU device and create a command queue.
43 ///
44 /// Returns `Err(MlxError::DeviceNotFound)` if no Metal device is available
45 /// (e.g. running on a non-Apple-Silicon machine or in a headless Linux VM).
46 pub fn new() -> Result<Self> {
47 let device = Device::system_default().ok_or(MlxError::DeviceNotFound)?;
48 let queue = device.new_command_queue();
49 let log_init = std::env::var("MLX_NATIVE_LOG_INIT").as_deref() == Ok("1");
50
51 let residency_set = if residency_disabled_by_env() {
52 if log_init {
53 eprintln!("[mlx-native] residency sets = false (reason: HF2Q_NO_RESIDENCY=1)");
54 }
55 None
56 } else if !macos_15_or_newer() {
57 if log_init {
58 eprintln!("[mlx-native] residency sets = false (reason: macOS < 15.0)");
59 }
60 None
61 } else {
62 let set = ResidencySet::new(&device)?;
63 if set.is_noop() {
64 if log_init {
65 eprintln!("[mlx-native] residency sets = false (reason: macOS < 15.0)");
66 }
67 None
68 } else {
69 set.register_with_queue(&queue);
70 if log_init {
71 eprintln!("[mlx-native] residency sets = true");
72 }
73 Some(set)
74 }
75 };
76
77 Ok(Self {
78 device,
79 queue,
80 residency_set,
81 })
82 }
83
84 /// Create a [`CommandEncoder`] for batching GPU dispatches.
85 ///
86 /// The encoder wraps a fresh Metal command buffer from the device's command
87 /// queue. Encode one or more kernel dispatches, then call
88 /// [`CommandEncoder::commit_and_wait`] to submit and block until completion.
89 ///
90 /// ADR-015 iter8e (Phase 3b): the encoder is bound to the device's
91 /// residency set so every `commit*` boundary flushes deferred
92 /// add/remove staging (one `[set commit]` per CB submission instead
93 /// of per-allocation). When residency sets are disabled
94 /// (HF2Q_NO_RESIDENCY=1, macOS<15) the binding is `None` and the
95 /// flush is a no-op.
96 pub fn command_encoder(&self) -> Result<CommandEncoder> {
97 CommandEncoder::new_with_residency(&self.queue, self.residency_set.clone())
98 }
99
100 /// Create an [`EncoderSession`] (ADR-019 Phase 0b iter89e2-A — bare
101 /// struct) for one transformer stage's worth of GPU work.
102 ///
103 /// Gated on `HF2Q_ENCODER_SESSION=1` (default OFF). When the gate is
104 /// unset, returns `Ok(None)` so callers can fall back to
105 /// [`Self::command_encoder`] without an extra conditional. When set,
106 /// returns `Ok(Some(EncoderSession))` carrying a fresh
107 /// [`CommandEncoder`] — same construction path as `command_encoder()`,
108 /// just wrapped in the session shell.
109 ///
110 /// In iter89e2-A no production code path consumes this method; it
111 /// exists so the env-gate has a callable factory and the lifecycle
112 /// tests have a public entry point. Phase 1+ migrations
113 /// (`forward_gpu.rs`, `gpu_full_attn.rs`, `gpu_delta_net.rs`) opt in
114 /// per-call site.
115 ///
116 /// # Errors
117 ///
118 /// Surfaces any error from the underlying `EncoderSession::new`
119 /// — currently infallible past metal-rs's `new_command_buffer`,
120 /// preserved for future-proofing.
121 pub fn encoder_session(&self) -> Result<Option<EncoderSession>> {
122 if !EncoderSession::env_enabled() {
123 return Ok(None);
124 }
125 EncoderSession::new(&self.device, &self.queue, self.residency_set.clone()).map(Some)
126 }
127
128 /// Allocate a new GPU buffer with `StorageModeShared`.
129 ///
130 /// # Arguments
131 ///
132 /// * `byte_len` — Size of the buffer in bytes. Must be > 0.
133 /// * `dtype` — Element data type for metadata tracking.
134 /// * `shape` — Tensor dimensions for metadata tracking.
135 ///
136 /// # Errors
137 ///
138 /// Returns `MlxError::InvalidArgument` if `byte_len` is zero.
139 /// Returns `MlxError::BufferAllocationError` if Metal cannot allocate.
140 pub fn alloc_buffer(
141 &self,
142 byte_len: usize,
143 dtype: DType,
144 shape: Vec<usize>,
145 ) -> Result<MlxBuffer> {
146 if byte_len == 0 {
147 return Err(MlxError::InvalidArgument(
148 "Buffer byte length must be > 0".into(),
149 ));
150 }
151 let metal_buf = self
152 .device
153 .new_buffer(byte_len as u64, MTLResourceOptions::StorageModeShared);
154 // Metal returns a non-null buffer on success; a null pointer indicates
155 // failure (typically out-of-memory).
156 if metal_buf.contents().is_null() {
157 return Err(MlxError::BufferAllocationError { bytes: byte_len });
158 }
159 // ADR-015 iter61a (broken-window B-W-1 fix): explicitly zero every
160 // newly-allocated GPU buffer. `MTLResourceOptions::StorageModeShared`
161 // does NOT guarantee zeroed pages on Apple Silicon — Metal's allocator
162 // recycles pages from recently-freed allocations within the device's
163 // private heap before the OS sees the free, so a fresh buffer can
164 // contain residual bytes from prior allocations in the same process.
165 // In a cold process this surfaces as run-to-run non-determinism: the
166 // heap state at the moment Metal services `newBufferWithLength`
167 // differs across cold invocations, and any kernel that reads a buffer
168 // before fully populating it (e.g. DeltaNet's `ssm_conv` reads
169 // conv_state, MoE expert routing reads scratch, attn-output buffers
170 // before the final write barrier) propagates that garbage into
171 // logits → argmax → divergent generations across cold runs.
172 // The cost is one memset per allocation; on workloads dominated by
173 // weight-load (one-time) and kvcache (one-time), this is negligible.
174 // Per `feedback_no_broken_windows` + mantra "No fallback. No stub.
175 // Just pure excellence." — fix at the source.
176 //
177 // Safety: `metal_buf.contents()` is non-null (verified above), points
178 // to exactly `byte_len` bytes of `StorageModeShared` memory we just
179 // allocated and have exclusive access to (no other thread or GPU
180 // dispatch references it yet — we haven't returned the MlxBuffer
181 // wrapper yet, and the underlying CB queue is not in flight on this
182 // allocation). Writing zero bytes is well-defined for any DType.
183 unsafe {
184 std::ptr::write_bytes(metal_buf.contents() as *mut u8, 0, byte_len);
185 }
186 // ADR-015 iter8e (Phase 3b): auto-register the new allocation with the
187 // device's residency set so it gets the MTLResidencySet hint on the
188 // next dispatch. The `with_residency` path stages the addAllocation
189 // but DEFERS the `[set commit]` to the next CommandEncoder::commit*
190 // boundary via flush_pending — mirrors llama.cpp's batch-add /
191 // single-commit pattern in ggml-metal-device.m:1378-1382.
192 //
193 // No-op when residency_set is None (HF2Q_NO_RESIDENCY=1, macOS<15,
194 // or no Metal device).
195 match self.residency_set.as_ref() {
196 Some(set) => Ok(MlxBuffer::with_residency(
197 metal_buf,
198 dtype,
199 shape,
200 set.clone(),
201 )),
202 None => Ok(MlxBuffer::from_raw(metal_buf, dtype, shape)),
203 }
204 }
205
206 /// Borrow the underlying `metal::Device` for direct Metal API calls
207 /// (e.g. kernel compilation in [`KernelRegistry`](crate::KernelRegistry)).
208 #[inline]
209 pub fn metal_device(&self) -> &metal::DeviceRef {
210 &self.device
211 }
212
213 /// Borrow the underlying `metal::CommandQueue`.
214 #[inline]
215 pub fn metal_queue(&self) -> &CommandQueue {
216 &self.queue
217 }
218
219 /// Borrow the device-level residency set, if residency support is enabled.
220 #[inline]
221 pub(crate) fn residency_set(&self) -> Option<&ResidencySet> {
222 self.residency_set.as_ref()
223 }
224
225 /// Return whether this device has an active Metal residency set.
226 #[inline]
227 pub fn residency_sets_enabled(&self) -> bool {
228 self.residency_set.is_some()
229 }
230
231 /// Human-readable name of the GPU (e.g. "Apple M2 Max").
232 pub fn name(&self) -> String {
233 self.device.name().to_string()
234 }
235}
236
237impl std::fmt::Debug for MlxDevice {
238 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 f.debug_struct("MlxDevice")
240 .field("name", &self.device.name())
241 .finish()
242 }
243}