// mlx_native/device.rs
//! [`MlxDevice`] — Metal device and command queue wrapper.
//!
//! This is the entry-point for all GPU work.  Create one with
//! [`MlxDevice::new()`] and use it to allocate buffers and create
//! command encoders.

use metal::{Device, CommandQueue, MTLResourceOptions};

use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
/// Wraps a Metal device and its command queue.
///
/// # Thread Safety
///
/// `MlxDevice` is `Send + Sync` — you can share it across threads. The
/// underlying Metal device and command queue are thread-safe on Apple Silicon.
pub struct MlxDevice {
    // Handle to the GPU; used for buffer allocation (`alloc_buffer`) and
    // exposed via `metal_device()` for direct Metal API calls.
    device: Device,
    // Command queue the device's command buffers are created from; exposed
    // via `metal_queue()` and consumed by `command_encoder()`.
    queue: CommandQueue,
}

// metal::Device and metal::CommandQueue are both Send + Sync.
// Compile-time check backing the thread-safety claim in the doc comment above.
crate::static_assertions_send_sync!(MlxDevice);
27
28impl MlxDevice {
29    /// Initialize the Metal GPU device and create a command queue.
30    ///
31    /// Returns `Err(MlxError::DeviceNotFound)` if no Metal device is available
32    /// (e.g. running on a non-Apple-Silicon machine or in a headless Linux VM).
33    pub fn new() -> Result<Self> {
34        let device = Device::system_default().ok_or(MlxError::DeviceNotFound)?;
35        let queue = device.new_command_queue();
36        Ok(Self { device, queue })
37    }
38
39    /// Create a [`CommandEncoder`] for batching GPU dispatches.
40    ///
41    /// The encoder wraps a fresh Metal command buffer from the device's command
42    /// queue.  Encode one or more kernel dispatches, then call
43    /// [`CommandEncoder::commit_and_wait`] to submit and block until completion.
44    pub fn command_encoder(&self) -> Result<CommandEncoder> {
45        CommandEncoder::new(&self.queue)
46    }
47
48    /// Allocate a new GPU buffer with `StorageModeShared`.
49    ///
50    /// # Arguments
51    ///
52    /// * `byte_len` — Size of the buffer in bytes.  Must be > 0.
53    /// * `dtype`    — Element data type for metadata tracking.
54    /// * `shape`    — Tensor dimensions for metadata tracking.
55    ///
56    /// # Errors
57    ///
58    /// Returns `MlxError::InvalidArgument` if `byte_len` is zero.
59    /// Returns `MlxError::BufferAllocationError` if Metal cannot allocate.
60    pub fn alloc_buffer(
61        &self,
62        byte_len: usize,
63        dtype: DType,
64        shape: Vec<usize>,
65    ) -> Result<MlxBuffer> {
66        if byte_len == 0 {
67            return Err(MlxError::InvalidArgument(
68                "Buffer byte length must be > 0".into(),
69            ));
70        }
71        let metal_buf = self.device.new_buffer(
72            byte_len as u64,
73            MTLResourceOptions::StorageModeShared,
74        );
75        // Metal returns a non-null buffer on success; a null pointer indicates
76        // failure (typically out-of-memory).
77        if metal_buf.contents().is_null() {
78            return Err(MlxError::BufferAllocationError { bytes: byte_len });
79        }
80        Ok(MlxBuffer::from_raw(metal_buf, dtype, shape))
81    }
82
83    /// Borrow the underlying `metal::Device` for direct Metal API calls
84    /// (e.g. kernel compilation in [`KernelRegistry`](crate::KernelRegistry)).
85    #[inline]
86    pub fn metal_device(&self) -> &metal::DeviceRef {
87        &self.device
88    }
89
90    /// Borrow the underlying `metal::CommandQueue`.
91    #[inline]
92    pub fn metal_queue(&self) -> &CommandQueue {
93        &self.queue
94    }
95
96    /// Human-readable name of the GPU (e.g. "Apple M2 Max").
97    pub fn name(&self) -> String {
98        self.device.name().to_string()
99    }
100}
101
102impl std::fmt::Debug for MlxDevice {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.debug_struct("MlxDevice")
105            .field("name", &self.device.name())
106            .finish()
107    }
108}