// mlx_native/device.rs

//! [`MlxDevice`] — Metal device and command queue wrapper.
//!
//! This is the entry point for all GPU work. Create one with
//! [`MlxDevice::new()`] and use it to allocate buffers and create
//! command encoders.

use metal::{CommandQueue, Device, MTLResourceOptions};

use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::residency::{macos_15_or_newer, residency_disabled_by_env, ResidencySet};

15/// Wraps a Metal device and its command queue.
16///
17/// # Thread Safety
18///
19/// `MlxDevice` is `Send + Sync` — you can share it across threads. The
20/// underlying Metal device and command queue are thread-safe on Apple Silicon.
21pub struct MlxDevice {
22    device: Device,
23    queue: CommandQueue,
24    residency_set: Option<ResidencySet>,
25}
26
27// metal::Device and metal::CommandQueue are both Send + Sync.
28crate::static_assertions_send_sync!(MlxDevice);
29
30impl MlxDevice {
31    /// Initialize the Metal GPU device and create a command queue.
32    ///
33    /// Returns `Err(MlxError::DeviceNotFound)` if no Metal device is available
34    /// (e.g. running on a non-Apple-Silicon machine or in a headless Linux VM).
35    pub fn new() -> Result<Self> {
36        let device = Device::system_default().ok_or(MlxError::DeviceNotFound)?;
37        let queue = device.new_command_queue();
38        let log_init = std::env::var("MLX_NATIVE_LOG_INIT").as_deref() == Ok("1");
39
40        let residency_set = if residency_disabled_by_env() {
41            if log_init {
42                eprintln!("[mlx-native] residency sets = false (reason: HF2Q_NO_RESIDENCY=1)");
43            }
44            None
45        } else if !macos_15_or_newer() {
46            if log_init {
47                eprintln!("[mlx-native] residency sets = false (reason: macOS < 15.0)");
48            }
49            None
50        } else {
51            let set = ResidencySet::new(&device)?;
52            if set.is_noop() {
53                if log_init {
54                    eprintln!("[mlx-native] residency sets = false (reason: macOS < 15.0)");
55                }
56                None
57            } else {
58                set.register_with_queue(&queue);
59                if log_init {
60                    eprintln!("[mlx-native] residency sets = true");
61                }
62                Some(set)
63            }
64        };
65
66        Ok(Self {
67            device,
68            queue,
69            residency_set,
70        })
71    }
72
73    /// Create a [`CommandEncoder`] for batching GPU dispatches.
74    ///
75    /// The encoder wraps a fresh Metal command buffer from the device's command
76    /// queue.  Encode one or more kernel dispatches, then call
77    /// [`CommandEncoder::commit_and_wait`] to submit and block until completion.
78    ///
79    /// ADR-015 iter8e (Phase 3b): the encoder is bound to the device's
80    /// residency set so every `commit*` boundary flushes deferred
81    /// add/remove staging (one `[set commit]` per CB submission instead
82    /// of per-allocation). When residency sets are disabled
83    /// (HF2Q_NO_RESIDENCY=1, macOS<15) the binding is `None` and the
84    /// flush is a no-op.
85    pub fn command_encoder(&self) -> Result<CommandEncoder> {
86        CommandEncoder::new_with_residency(&self.queue, self.residency_set.clone())
87    }
88
89    /// Allocate a new GPU buffer with `StorageModeShared`.
90    ///
91    /// # Arguments
92    ///
93    /// * `byte_len` — Size of the buffer in bytes.  Must be > 0.
94    /// * `dtype`    — Element data type for metadata tracking.
95    /// * `shape`    — Tensor dimensions for metadata tracking.
96    ///
97    /// # Errors
98    ///
99    /// Returns `MlxError::InvalidArgument` if `byte_len` is zero.
100    /// Returns `MlxError::BufferAllocationError` if Metal cannot allocate.
101    pub fn alloc_buffer(
102        &self,
103        byte_len: usize,
104        dtype: DType,
105        shape: Vec<usize>,
106    ) -> Result<MlxBuffer> {
107        if byte_len == 0 {
108            return Err(MlxError::InvalidArgument(
109                "Buffer byte length must be > 0".into(),
110            ));
111        }
112        let metal_buf = self
113            .device
114            .new_buffer(byte_len as u64, MTLResourceOptions::StorageModeShared);
115        // Metal returns a non-null buffer on success; a null pointer indicates
116        // failure (typically out-of-memory).
117        if metal_buf.contents().is_null() {
118            return Err(MlxError::BufferAllocationError { bytes: byte_len });
119        }
120        // ADR-015 iter61a (broken-window B-W-1 fix): explicitly zero every
121        // newly-allocated GPU buffer. `MTLResourceOptions::StorageModeShared`
122        // does NOT guarantee zeroed pages on Apple Silicon — Metal's allocator
123        // recycles pages from recently-freed allocations within the device's
124        // private heap before the OS sees the free, so a fresh buffer can
125        // contain residual bytes from prior allocations in the same process.
126        // In a cold process this surfaces as run-to-run non-determinism: the
127        // heap state at the moment Metal services `newBufferWithLength`
128        // differs across cold invocations, and any kernel that reads a buffer
129        // before fully populating it (e.g. DeltaNet's `ssm_conv` reads
130        // conv_state, MoE expert routing reads scratch, attn-output buffers
131        // before the final write barrier) propagates that garbage into
132        // logits → argmax → divergent generations across cold runs.
133        // The cost is one memset per allocation; on workloads dominated by
134        // weight-load (one-time) and kvcache (one-time), this is negligible.
135        // Per `feedback_no_broken_windows` + mantra "No fallback. No stub.
136        // Just pure excellence." — fix at the source.
137        //
138        // Safety: `metal_buf.contents()` is non-null (verified above), points
139        // to exactly `byte_len` bytes of `StorageModeShared` memory we just
140        // allocated and have exclusive access to (no other thread or GPU
141        // dispatch references it yet — we haven't returned the MlxBuffer
142        // wrapper yet, and the underlying CB queue is not in flight on this
143        // allocation). Writing zero bytes is well-defined for any DType.
144        unsafe {
145            std::ptr::write_bytes(metal_buf.contents() as *mut u8, 0, byte_len);
146        }
147        // ADR-015 iter8e (Phase 3b): auto-register the new allocation with the
148        // device's residency set so it gets the MTLResidencySet hint on the
149        // next dispatch. The `with_residency` path stages the addAllocation
150        // but DEFERS the `[set commit]` to the next CommandEncoder::commit*
151        // boundary via flush_pending — mirrors llama.cpp's batch-add /
152        // single-commit pattern in ggml-metal-device.m:1378-1382.
153        //
154        // No-op when residency_set is None (HF2Q_NO_RESIDENCY=1, macOS<15,
155        // or no Metal device).
156        match self.residency_set.as_ref() {
157            Some(set) => Ok(MlxBuffer::with_residency(
158                metal_buf,
159                dtype,
160                shape,
161                set.clone(),
162            )),
163            None => Ok(MlxBuffer::from_raw(metal_buf, dtype, shape)),
164        }
165    }
166
167    /// Borrow the underlying `metal::Device` for direct Metal API calls
168    /// (e.g. kernel compilation in [`KernelRegistry`](crate::KernelRegistry)).
169    #[inline]
170    pub fn metal_device(&self) -> &metal::DeviceRef {
171        &self.device
172    }
173
174    /// Borrow the underlying `metal::CommandQueue`.
175    #[inline]
176    pub fn metal_queue(&self) -> &CommandQueue {
177        &self.queue
178    }
179
180    /// Borrow the device-level residency set, if residency support is enabled.
181    #[inline]
182    pub(crate) fn residency_set(&self) -> Option<&ResidencySet> {
183        self.residency_set.as_ref()
184    }
185
186    /// Return whether this device has an active Metal residency set.
187    #[inline]
188    pub fn residency_sets_enabled(&self) -> bool {
189        self.residency_set.is_some()
190    }
191
192    /// Human-readable name of the GPU (e.g. "Apple M2 Max").
193    pub fn name(&self) -> String {
194        self.device.name().to_string()
195    }
196}
197
198impl std::fmt::Debug for MlxDevice {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        f.debug_struct("MlxDevice")
201            .field("name", &self.device.name())
202            .finish()
203    }
204}