Skip to main content

rlx_driver/
handle.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Persistent buffer handles — state that survives across forward passes.
17//!
18//! For inference of stateful models (KV-cache, beam search) and for
19//! training (gradient accumulators, optimizer state), the runtime needs
20//! buffers that persist beyond a single `compiled.run()`. The arena
21//! is rebuilt every compile, so it can't carry state.
22//!
23//! `BufferHandle` is an opaque, stable identifier the user creates once
24//! and binds at compile time. The backend allocates a separate "handles"
25//! region (independent of the arena) and routes reads/writes there.
26//!
27//! Workflow:
28//!
29//! ```rust,ignore
//! let kv_cache = BufferHandle::new("kv", &[batch, max_seq, num_heads, head_dim], dtype);
31//! let session = Session::new(Device::Metal);
32//! let mut compiled = session.compile_with(graph, &CompileOptions::new()
33//!     .bind_handle(&kv_cache));
34//!
35//! for token in tokens {
36//!     compiled.bind_handle("kv", &kv_cache_data); // initial value
37//!     let logits = compiled.run(&[("token", &[token])]);
38//!     kv_cache_data = compiled.read_handle("kv").unwrap();
39//! }
40//! ```
41
42use rlx_ir::{DType, Shape};
43
44/// External, persistent buffer reference. Created once, bound at compile,
45/// carried across many `compiled.run()` invocations.
46#[derive(Debug, Clone)]
47pub struct BufferHandle {
48    pub name: String,
49    pub shape: Shape,
50}
51
52impl BufferHandle {
53    pub fn new(name: impl Into<String>, dims: &[usize], dtype: DType) -> Self {
54        Self {
55            name: name.into(),
56            shape: Shape::new(dims, dtype),
57        }
58    }
59
60    /// Total byte size of this handle.
61    pub fn byte_size(&self) -> usize {
62        self.shape.size_bytes().unwrap_or(0)
63    }
64
65    /// Number of elements (for slicing as &[f32] etc.).
66    pub fn num_elements(&self) -> usize {
67        self.shape.num_elements().unwrap_or(0)
68    }
69}