// rlx_driver/buffer.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Explicit host ↔ device buffer abstraction (plan #59).
//!
//! Borrowed from MAX's `max/python/max/driver/buffer.py`. The point
//! is *explicitness*: every host↔device transfer is a method call,
//! not an implicit conversion. Misreading "device" data as host
//! trips a runtime assertion (see `Buffer::host_bytes`) instead of
//! becoming a silent perf bug.
//!
//! Today this wraps a host-side `Vec<u8>` plus a `Device` tag. The
//! `to_host` / `to_device` calls move no bytes yet (they only update
//! the `Device` tag); on Device::Metal a future commit will route
//! through the existing `rlx-metal::arena` for actual MTLBuffer
//! transfers.
27
28use crate::Device;
29use rlx_ir::{DType, Shape};
30
31/// A buffer that knows where its bytes live.
32#[derive(Debug, Clone)]
33pub struct Buffer {
34    bytes: Vec<u8>,
35    shape: Shape,
36    device: Device,
37}
38
39impl Buffer {
40    /// Create a buffer holding `data`, tagged as residing on `device`.
41    /// On Cpu this is just storage; on Metal callers will use
42    /// `to_device` to move it into device memory once the runtime
43    /// integrates with the arena.
44    pub fn new_host(shape: Shape, data: Vec<u8>) -> Self {
45        Self {
46            bytes: data,
47            shape,
48            device: Device::Cpu,
49        }
50    }
51
52    /// Create a zero-initialized host buffer of `shape`.
53    pub fn zeros(shape: Shape) -> Self {
54        let n = shape.size_bytes().unwrap_or(0);
55        Self {
56            bytes: vec![0u8; n],
57            shape,
58            device: Device::Cpu,
59        }
60    }
61
62    pub fn shape(&self) -> &Shape {
63        &self.shape
64    }
65    pub fn device(&self) -> Device {
66        self.device
67    }
68    pub fn dtype(&self) -> DType {
69        self.shape.dtype()
70    }
71    pub fn num_elements(&self) -> usize {
72        self.shape.num_elements().unwrap_or(0)
73    }
74    pub fn byte_size(&self) -> usize {
75        self.bytes.len()
76    }
77
78    /// Read as `&[f32]`. Panics if dtype isn't F32 — explicit type
79    /// check (plan #59 spirit: don't let mismatches go silent).
80    pub fn as_f32(&self) -> &[f32] {
81        assert_eq!(self.dtype(), DType::F32, "as_f32 on non-F32 buffer");
82        let n = self.num_elements();
83        unsafe { std::slice::from_raw_parts(self.bytes.as_ptr() as *const f32, n) }
84    }
85
86    pub fn as_f32_mut(&mut self) -> &mut [f32] {
87        assert_eq!(self.dtype(), DType::F32, "as_f32_mut on non-F32 buffer");
88        let n = self.num_elements();
89        unsafe { std::slice::from_raw_parts_mut(self.bytes.as_mut_ptr() as *mut f32, n) }
90    }
91
92    /// "Move to device" — explicit transfer call. CPU is a no-op
93    /// (the bytes are already where they need to be). Metal routes
94    /// through the backend (TODO: wire to rlx-metal::arena once a
95    /// caller needs it).
96    pub fn to_device(self, device: Device) -> Self {
97        match (self.device, device) {
98            (a, b) if a == b => self,
99            (Device::Cpu, Device::Metal) => Self {
100                device: Device::Metal,
101                ..self
102            },
103            (Device::Metal, Device::Cpu) => Self {
104                device: Device::Cpu,
105                ..self
106            },
107            _ => self,
108        }
109    }
110
111    /// "Read back to host" — explicit transfer. CPU no-op; Metal
112    /// blocks on completion + memcpy back from MTLBuffer (TODO).
113    pub fn to_host(self) -> Self {
114        self.to_device(Device::Cpu)
115    }
116
117    /// Raw bytes — host-side access. Panics if the buffer is on a
118    /// non-host device (would silently read uninitialized host
119    /// memory otherwise).
120    pub fn host_bytes(&self) -> &[u8] {
121        assert_eq!(
122            self.device,
123            Device::Cpu,
124            "host_bytes on non-host buffer; call .to_host() first"
125        );
126        &self.bytes
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh `zeros` buffer reports the right sizes and holds only 0.0.
    #[test]
    fn zeros_initializes() {
        let buf = Buffer::zeros(Shape::new(&[2, 3], DType::F32));
        assert_eq!(buf.num_elements(), 6);
        assert_eq!(buf.byte_size(), 24);
        assert!(buf.as_f32().iter().all(|&x| x == 0.0));
    }

    /// Viewing an I32 buffer as f32 must panic rather than reinterpret.
    #[test]
    fn dtype_mismatch_panics() {
        let buf = Buffer::zeros(Shape::new(&[4], DType::I32));
        let outcome = std::panic::catch_unwind(|| buf.as_f32());
        assert!(outcome.is_err());
    }

    /// Cpu → Metal → Cpu updates the device tag at each hop.
    #[test]
    fn to_device_round_trip() {
        let device_side = Buffer::zeros(Shape::new(&[4], DType::F32)).to_device(Device::Metal);
        assert_eq!(device_side.device(), Device::Metal);

        let host_side = device_side.to_host();
        assert_eq!(host_side.device(), Device::Cpu);
    }
}