rlx_driver/buffer.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Explicit host ↔ device buffer abstraction (plan #59).
17//!
18//! Borrowed from MAX's `max/python/max/driver/buffer.py`. The point
19//! is *explicitness*: every host↔device transfer is a method call,
//! not implicit conversion. Misreading "device" data as host (or
//! vice versa) is caught by an explicit check (today a runtime panic,
//! e.g. in `host_bytes`), not left as a silent perf bug.
22//!
23//! Today this wraps a host-side `Vec<u8>` plus a `Device` tag. The
24//! `to_host` / `to_device` calls are no-ops on Device::Cpu; on
25//! Device::Metal a future commit will route through the existing
26//! `rlx-metal::arena` for actual MTLBuffer transfers.
27
28use crate::Device;
29use rlx_ir::{DType, Shape};
30
/// A buffer that knows where its bytes live.
///
/// Pairs raw byte storage with the [`Shape`] that describes it and a
/// [`Device`] tag recording where the contents are considered to reside.
#[derive(Debug, Clone)]
pub struct Buffer {
    /// Backing storage. This is a host-side `Vec<u8>` regardless of the
    /// `device` tag.
    bytes: Vec<u8>,
    /// Logical shape of the data; also the source of the element dtype.
    shape: Shape,
    /// Device the contents are tagged as residing on.
    device: Device,
}
38
39impl Buffer {
40 /// Create a buffer holding `data`, tagged as residing on `device`.
41 /// On Cpu this is just storage; on Metal callers will use
42 /// `to_device` to move it into device memory once the runtime
43 /// integrates with the arena.
44 pub fn new_host(shape: Shape, data: Vec<u8>) -> Self {
45 Self {
46 bytes: data,
47 shape,
48 device: Device::Cpu,
49 }
50 }
51
52 /// Create a zero-initialized host buffer of `shape`.
53 pub fn zeros(shape: Shape) -> Self {
54 let n = shape.size_bytes().unwrap_or(0);
55 Self {
56 bytes: vec![0u8; n],
57 shape,
58 device: Device::Cpu,
59 }
60 }
61
62 pub fn shape(&self) -> &Shape {
63 &self.shape
64 }
65 pub fn device(&self) -> Device {
66 self.device
67 }
68 pub fn dtype(&self) -> DType {
69 self.shape.dtype()
70 }
71 pub fn num_elements(&self) -> usize {
72 self.shape.num_elements().unwrap_or(0)
73 }
74 pub fn byte_size(&self) -> usize {
75 self.bytes.len()
76 }
77
78 /// Read as `&[f32]`. Panics if dtype isn't F32 — explicit type
79 /// check (plan #59 spirit: don't let mismatches go silent).
80 pub fn as_f32(&self) -> &[f32] {
81 assert_eq!(self.dtype(), DType::F32, "as_f32 on non-F32 buffer");
82 let n = self.num_elements();
83 unsafe { std::slice::from_raw_parts(self.bytes.as_ptr() as *const f32, n) }
84 }
85
86 pub fn as_f32_mut(&mut self) -> &mut [f32] {
87 assert_eq!(self.dtype(), DType::F32, "as_f32_mut on non-F32 buffer");
88 let n = self.num_elements();
89 unsafe { std::slice::from_raw_parts_mut(self.bytes.as_mut_ptr() as *mut f32, n) }
90 }
91
92 /// "Move to device" — explicit transfer call. CPU is a no-op
93 /// (the bytes are already where they need to be). Metal routes
94 /// through the backend (TODO: wire to rlx-metal::arena once a
95 /// caller needs it).
96 pub fn to_device(self, device: Device) -> Self {
97 match (self.device, device) {
98 (a, b) if a == b => self,
99 (Device::Cpu, Device::Metal) => Self {
100 device: Device::Metal,
101 ..self
102 },
103 (Device::Metal, Device::Cpu) => Self {
104 device: Device::Cpu,
105 ..self
106 },
107 _ => self,
108 }
109 }
110
111 /// "Read back to host" — explicit transfer. CPU no-op; Metal
112 /// blocks on completion + memcpy back from MTLBuffer (TODO).
113 pub fn to_host(self) -> Self {
114 self.to_device(Device::Cpu)
115 }
116
117 /// Raw bytes — host-side access. Panics if the buffer is on a
118 /// non-host device (would silently read uninitialized host
119 /// memory otherwise).
120 pub fn host_bytes(&self) -> &[u8] {
121 assert_eq!(
122 self.device,
123 Device::Cpu,
124 "host_bytes on non-host buffer; call .to_host() first"
125 );
126 &self.bytes
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// `zeros` sizes the storage from the shape and fills it with 0.0.
    #[test]
    fn zeros_initializes() {
        let b = Buffer::zeros(Shape::new(&[2, 3], DType::F32));
        assert_eq!(b.num_elements(), 6);
        assert_eq!(b.byte_size(), 24);
        assert!(b.as_f32().iter().all(|v| *v == 0.0));
    }

    /// `new_host` keeps the caller's bytes and tags the buffer as host.
    #[test]
    fn new_host_keeps_bytes() {
        let raw = 1.5f32.to_ne_bytes().to_vec();
        let b = Buffer::new_host(Shape::new(&[1], DType::F32), raw.clone());
        assert_eq!(b.device(), Device::Cpu);
        assert_eq!(b.host_bytes(), raw.as_slice());
        assert_eq!(b.as_f32().to_vec(), vec![1.5f32]);
    }

    /// Writes through the mutable f32 view land in the backing bytes.
    #[test]
    fn as_f32_mut_writes_through() {
        let mut b = Buffer::zeros(Shape::new(&[2], DType::F32));
        b.as_f32_mut()[1] = 2.5;
        assert_eq!(b.as_f32().to_vec(), vec![0.0f32, 2.5]);
        assert_eq!(b.host_bytes()[4..].to_vec(), 2.5f32.to_ne_bytes().to_vec());
    }

    /// Viewing a non-F32 buffer as f32 must panic with the documented message.
    #[test]
    #[should_panic(expected = "as_f32 on non-F32 buffer")]
    fn dtype_mismatch_panics() {
        let b = Buffer::zeros(Shape::new(&[4], DType::I32));
        let _ = b.as_f32();
    }

    /// The mutable view enforces the same dtype check.
    #[test]
    #[should_panic(expected = "as_f32_mut on non-F32 buffer")]
    fn dtype_mismatch_panics_mut() {
        let mut b = Buffer::zeros(Shape::new(&[4], DType::I32));
        let _ = b.as_f32_mut();
    }

    /// Cpu -> Metal -> Cpu updates the device tag at each step.
    #[test]
    fn to_device_round_trip() {
        let b = Buffer::zeros(Shape::new(&[4], DType::F32));
        let on_metal = b.to_device(Device::Metal);
        assert_eq!(on_metal.device(), Device::Metal);
        let back = on_metal.to_host();
        assert_eq!(back.device(), Device::Cpu);
    }

    /// Raw host access is refused while the buffer is tagged as on-device.
    #[test]
    #[should_panic(expected = "host_bytes on non-host buffer")]
    fn host_bytes_requires_host() {
        let b = Buffer::zeros(Shape::new(&[1], DType::F32)).to_device(Device::Metal);
        let _ = b.host_bytes();
    }
}
159}