// mlx_native/device.rs
1//! [`MlxDevice`] — Metal device and command queue wrapper.
2//!
3//! This is the entry-point for all GPU work.  Create one with
4//! [`MlxDevice::new()`] and use it to allocate buffers and create
5//! command encoders.
6
7use metal::{CommandQueue, Device, MTLResourceOptions};
8
9use crate::buffer::MlxBuffer;
10use crate::dtypes::DType;
11use crate::encoder::CommandEncoder;
12use crate::encoder_session::EncoderSession;
13use crate::error::{MlxError, Result};
14use crate::residency::{macos_15_or_newer, residency_disabled_by_env, ResidencySet};
15
/// Wraps a Metal device and its command queue.
///
/// # Thread Safety
///
/// `MlxDevice` is `Send + Sync` — you can share it across threads. The
/// underlying Metal device and command queue are thread-safe on Apple Silicon.
pub struct MlxDevice {
    /// Handle to the system-default Metal GPU (see `MlxDevice::new`).
    device: Device,
    /// Command queue that all command buffers / encoders are created from.
    queue: CommandQueue,
    /// Device-level residency set; `None` when disabled by env
    /// (HF2Q_NO_RESIDENCY=1), on macOS < 15, or when the set is a no-op.
    residency_set: Option<ResidencySet>,
}

// metal::Device and metal::CommandQueue are both Send + Sync.
crate::static_assertions_send_sync!(MlxDevice);
30
31impl MlxDevice {
32    /// Initialize the Metal GPU device and create a command queue.
33    ///
34    /// Returns `Err(MlxError::DeviceNotFound)` if no Metal device is available
35    /// (e.g. running on a non-Apple-Silicon machine or in a headless Linux VM).
36    pub fn new() -> Result<Self> {
37        let device = Device::system_default().ok_or(MlxError::DeviceNotFound)?;
38        let queue = device.new_command_queue();
39        let log_init = std::env::var("MLX_NATIVE_LOG_INIT").as_deref() == Ok("1");
40
41        let residency_set = if residency_disabled_by_env() {
42            if log_init {
43                eprintln!("[mlx-native] residency sets = false (reason: HF2Q_NO_RESIDENCY=1)");
44            }
45            None
46        } else if !macos_15_or_newer() {
47            if log_init {
48                eprintln!("[mlx-native] residency sets = false (reason: macOS < 15.0)");
49            }
50            None
51        } else {
52            let set = ResidencySet::new(&device)?;
53            if set.is_noop() {
54                if log_init {
55                    eprintln!("[mlx-native] residency sets = false (reason: macOS < 15.0)");
56                }
57                None
58            } else {
59                set.register_with_queue(&queue);
60                if log_init {
61                    eprintln!("[mlx-native] residency sets = true");
62                }
63                Some(set)
64            }
65        };
66
67        Ok(Self {
68            device,
69            queue,
70            residency_set,
71        })
72    }
73
74    /// Create a [`CommandEncoder`] for batching GPU dispatches.
75    ///
76    /// The encoder wraps a fresh Metal command buffer from the device's command
77    /// queue.  Encode one or more kernel dispatches, then call
78    /// [`CommandEncoder::commit_and_wait`] to submit and block until completion.
79    ///
80    /// ADR-015 iter8e (Phase 3b): the encoder is bound to the device's
81    /// residency set so every `commit*` boundary flushes deferred
82    /// add/remove staging (one `[set commit]` per CB submission instead
83    /// of per-allocation). When residency sets are disabled
84    /// (HF2Q_NO_RESIDENCY=1, macOS<15) the binding is `None` and the
85    /// flush is a no-op.
86    pub fn command_encoder(&self) -> Result<CommandEncoder> {
87        CommandEncoder::new_with_residency(&self.queue, self.residency_set.clone())
88    }
89
90    /// Create an [`EncoderSession`] (ADR-019 Phase 0b iter89e2-A — bare
91    /// struct) for one transformer stage's worth of GPU work.
92    ///
93    /// Gated on `HF2Q_ENCODER_SESSION=1` (default OFF). When the gate is
94    /// unset, returns `Ok(None)` so callers can fall back to
95    /// [`Self::command_encoder`] without an extra conditional. When set,
96    /// returns `Ok(Some(EncoderSession))` carrying a fresh
97    /// [`CommandEncoder`] — same construction path as `command_encoder()`,
98    /// just wrapped in the session shell.
99    ///
100    /// In iter89e2-A no production code path consumes this method; it
101    /// exists so the env-gate has a callable factory and the lifecycle
102    /// tests have a public entry point. Phase 1+ migrations
103    /// (`forward_gpu.rs`, `gpu_full_attn.rs`, `gpu_delta_net.rs`) opt in
104    /// per-call site.
105    ///
106    /// # Errors
107    ///
108    /// Surfaces any error from the underlying `EncoderSession::new`
109    /// — currently infallible past metal-rs's `new_command_buffer`,
110    /// preserved for future-proofing.
111    pub fn encoder_session(&self) -> Result<Option<EncoderSession>> {
112        if !EncoderSession::env_enabled() {
113            return Ok(None);
114        }
115        EncoderSession::new(&self.device, &self.queue, self.residency_set.clone()).map(Some)
116    }
117
118    /// Allocate a new GPU buffer with `StorageModeShared`.
119    ///
120    /// # Arguments
121    ///
122    /// * `byte_len` — Size of the buffer in bytes.  Must be > 0.
123    /// * `dtype`    — Element data type for metadata tracking.
124    /// * `shape`    — Tensor dimensions for metadata tracking.
125    ///
126    /// # Errors
127    ///
128    /// Returns `MlxError::InvalidArgument` if `byte_len` is zero.
129    /// Returns `MlxError::BufferAllocationError` if Metal cannot allocate.
130    pub fn alloc_buffer(
131        &self,
132        byte_len: usize,
133        dtype: DType,
134        shape: Vec<usize>,
135    ) -> Result<MlxBuffer> {
136        if byte_len == 0 {
137            return Err(MlxError::InvalidArgument(
138                "Buffer byte length must be > 0".into(),
139            ));
140        }
141        let metal_buf = self
142            .device
143            .new_buffer(byte_len as u64, MTLResourceOptions::StorageModeShared);
144        // Metal returns a non-null buffer on success; a null pointer indicates
145        // failure (typically out-of-memory).
146        if metal_buf.contents().is_null() {
147            return Err(MlxError::BufferAllocationError { bytes: byte_len });
148        }
149        // ADR-015 iter61a (broken-window B-W-1 fix): explicitly zero every
150        // newly-allocated GPU buffer. `MTLResourceOptions::StorageModeShared`
151        // does NOT guarantee zeroed pages on Apple Silicon — Metal's allocator
152        // recycles pages from recently-freed allocations within the device's
153        // private heap before the OS sees the free, so a fresh buffer can
154        // contain residual bytes from prior allocations in the same process.
155        // In a cold process this surfaces as run-to-run non-determinism: the
156        // heap state at the moment Metal services `newBufferWithLength`
157        // differs across cold invocations, and any kernel that reads a buffer
158        // before fully populating it (e.g. DeltaNet's `ssm_conv` reads
159        // conv_state, MoE expert routing reads scratch, attn-output buffers
160        // before the final write barrier) propagates that garbage into
161        // logits → argmax → divergent generations across cold runs.
162        // The cost is one memset per allocation; on workloads dominated by
163        // weight-load (one-time) and kvcache (one-time), this is negligible.
164        // Per `feedback_no_broken_windows` + mantra "No fallback. No stub.
165        // Just pure excellence." — fix at the source.
166        //
167        // Safety: `metal_buf.contents()` is non-null (verified above), points
168        // to exactly `byte_len` bytes of `StorageModeShared` memory we just
169        // allocated and have exclusive access to (no other thread or GPU
170        // dispatch references it yet — we haven't returned the MlxBuffer
171        // wrapper yet, and the underlying CB queue is not in flight on this
172        // allocation). Writing zero bytes is well-defined for any DType.
173        unsafe {
174            std::ptr::write_bytes(metal_buf.contents() as *mut u8, 0, byte_len);
175        }
176        // ADR-015 iter8e (Phase 3b): auto-register the new allocation with the
177        // device's residency set so it gets the MTLResidencySet hint on the
178        // next dispatch. The `with_residency` path stages the addAllocation
179        // but DEFERS the `[set commit]` to the next CommandEncoder::commit*
180        // boundary via flush_pending — mirrors llama.cpp's batch-add /
181        // single-commit pattern in ggml-metal-device.m:1378-1382.
182        //
183        // No-op when residency_set is None (HF2Q_NO_RESIDENCY=1, macOS<15,
184        // or no Metal device).
185        match self.residency_set.as_ref() {
186            Some(set) => Ok(MlxBuffer::with_residency(
187                metal_buf,
188                dtype,
189                shape,
190                set.clone(),
191            )),
192            None => Ok(MlxBuffer::from_raw(metal_buf, dtype, shape)),
193        }
194    }
195
196    /// Borrow the underlying `metal::Device` for direct Metal API calls
197    /// (e.g. kernel compilation in [`KernelRegistry`](crate::KernelRegistry)).
198    #[inline]
199    pub fn metal_device(&self) -> &metal::DeviceRef {
200        &self.device
201    }
202
203    /// Borrow the underlying `metal::CommandQueue`.
204    #[inline]
205    pub fn metal_queue(&self) -> &CommandQueue {
206        &self.queue
207    }
208
209    /// Borrow the device-level residency set, if residency support is enabled.
210    #[inline]
211    pub(crate) fn residency_set(&self) -> Option<&ResidencySet> {
212        self.residency_set.as_ref()
213    }
214
215    /// Return whether this device has an active Metal residency set.
216    #[inline]
217    pub fn residency_sets_enabled(&self) -> bool {
218        self.residency_set.is_some()
219    }
220
221    /// Human-readable name of the GPU (e.g. "Apple M2 Max").
222    pub fn name(&self) -> String {
223        self.device.name().to_string()
224    }
225}
226
227impl std::fmt::Debug for MlxDevice {
228    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229        f.debug_struct("MlxDevice")
230            .field("name", &self.device.name())
231            .finish()
232    }
233}