// mlx_native/device.rs
1//! [`MlxDevice`] — Metal device and command queue wrapper.
2//!
3//! This is the entry-point for all GPU work. Create one with
4//! [`MlxDevice::new()`] and use it to allocate buffers and create
5//! command encoders.
6
7use metal::{Device, CommandQueue, MTLResourceOptions};
8
9use crate::buffer::MlxBuffer;
10use crate::dtypes::DType;
11use crate::encoder::CommandEncoder;
12use crate::error::{MlxError, Result};
13
14/// Wraps a Metal device and its command queue.
15///
16/// # Thread Safety
17///
18/// `MlxDevice` is `Send + Sync` — you can share it across threads. The
19/// underlying Metal device and command queue are thread-safe on Apple Silicon.
20pub struct MlxDevice {
21 device: Device,
22 queue: CommandQueue,
23}
24
// Compile-time guarantee that `MlxDevice` is `Send + Sync`, as promised in the
// struct's docs — both `metal::Device` and `metal::CommandQueue` are Send + Sync.
// NOTE(review): `static_assertions_send_sync!` is a crate-local macro defined
// elsewhere; assumed to expand to a const trait-bound check — confirm there.
crate::static_assertions_send_sync!(MlxDevice);
27
28impl MlxDevice {
29 /// Initialize the Metal GPU device and create a command queue.
30 ///
31 /// Returns `Err(MlxError::DeviceNotFound)` if no Metal device is available
32 /// (e.g. running on a non-Apple-Silicon machine or in a headless Linux VM).
33 pub fn new() -> Result<Self> {
34 let device = Device::system_default().ok_or(MlxError::DeviceNotFound)?;
35 let queue = device.new_command_queue();
36 Ok(Self { device, queue })
37 }
38
39 /// Create a [`CommandEncoder`] for batching GPU dispatches.
40 ///
41 /// The encoder wraps a fresh Metal command buffer from the device's command
42 /// queue. Encode one or more kernel dispatches, then call
43 /// [`CommandEncoder::commit_and_wait`] to submit and block until completion.
44 pub fn command_encoder(&self) -> Result<CommandEncoder> {
45 CommandEncoder::new(&self.queue)
46 }
47
48 /// Allocate a new GPU buffer with `StorageModeShared`.
49 ///
50 /// # Arguments
51 ///
52 /// * `byte_len` — Size of the buffer in bytes. Must be > 0.
53 /// * `dtype` — Element data type for metadata tracking.
54 /// * `shape` — Tensor dimensions for metadata tracking.
55 ///
56 /// # Errors
57 ///
58 /// Returns `MlxError::InvalidArgument` if `byte_len` is zero.
59 /// Returns `MlxError::BufferAllocationError` if Metal cannot allocate.
60 pub fn alloc_buffer(
61 &self,
62 byte_len: usize,
63 dtype: DType,
64 shape: Vec<usize>,
65 ) -> Result<MlxBuffer> {
66 if byte_len == 0 {
67 return Err(MlxError::InvalidArgument(
68 "Buffer byte length must be > 0".into(),
69 ));
70 }
71 let metal_buf = self.device.new_buffer(
72 byte_len as u64,
73 MTLResourceOptions::StorageModeShared,
74 );
75 // Metal returns a non-null buffer on success; a null pointer indicates
76 // failure (typically out-of-memory).
77 if metal_buf.contents().is_null() {
78 return Err(MlxError::BufferAllocationError { bytes: byte_len });
79 }
80 Ok(MlxBuffer::from_raw(metal_buf, dtype, shape))
81 }
82
83 /// Borrow the underlying `metal::Device` for direct Metal API calls
84 /// (e.g. kernel compilation in [`KernelRegistry`](crate::KernelRegistry)).
85 #[inline]
86 pub fn metal_device(&self) -> &metal::DeviceRef {
87 &self.device
88 }
89
90 /// Borrow the underlying `metal::CommandQueue`.
91 #[inline]
92 pub fn metal_queue(&self) -> &CommandQueue {
93 &self.queue
94 }
95
96 /// Human-readable name of the GPU (e.g. "Apple M2 Max").
97 pub fn name(&self) -> String {
98 self.device.name().to_string()
99 }
100}
101
102impl std::fmt::Debug for MlxDevice {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("MlxDevice")
105 .field("name", &self.device.name())
106 .finish()
107 }
108}