rlx_runtime/weights.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Weight-loading abstraction.
17//!
18//! Native targets typically `mmap` a `.safetensors` file and read tensors
19//! via byte-offset slices into the mapping. WASM has no `mmap`; weights
20//! arrive as `Vec<u8>` from `fetch()` or `Response.arrayBuffer()`. Both
21//! paths produce the same shape: a name → byte-slice lookup.
22//!
23//! `WeightLoader` is the contract. Concrete implementations live in
24//! in-tree model loaders (mmap-based) and here (`BytesWeightLoader` — works on
25//! every target including WASM).
26
/// Read-only, name-addressed access to the raw bytes of weight tensors.
///
/// Every slice handed out borrows from the loader itself, so it stays valid
/// for as long as the `&self` borrow lives. Native implementations usually
/// back this with a memory-mapped file; on WASM the bytes sit in an owned
/// `Vec<u8>` held by the loader.
pub trait WeightLoader {
    /// Look up the tensor called `name` and return its raw bytes, or `None`
    /// when no such tensor exists. The bytes are exactly as stored in the
    /// source (typically row-major, in the tensor's native dtype).
    fn tensor_bytes(&self, name: &str) -> Option<&[u8]>;

    /// List every tensor name, e.g. for discovery or iteration. The order is
    /// up to the implementation but must not change for a given instance.
    fn names(&self) -> Vec<String>;
}
42
43/// Owned in-memory weight loader. The simplest, most portable variant —
44/// works on every target including WASM.
45///
46/// Construct via `BytesWeightLoader::from_safetensors(bytes)` once
47/// model builders integrate. For now the bare struct lets external callers
48/// build their own name → bytes mapping.
49pub struct BytesWeightLoader {
50 /// `(name, start_offset, len)` triples into `data`.
51 entries: Vec<(String, usize, usize)>,
52 data: Vec<u8>,
53}
54
55impl BytesWeightLoader {
56 /// Build a loader from a list of `(name, bytes)` pairs. Each tensor
57 /// is appended into a single backing `Vec<u8>`; `tensor_bytes` returns
58 /// a borrow into that vec.
59 pub fn from_pairs(pairs: Vec<(String, Vec<u8>)>) -> Self {
60 let total: usize = pairs.iter().map(|(_, b)| b.len()).sum();
61 let mut data = Vec::with_capacity(total);
62 let mut entries = Vec::with_capacity(pairs.len());
63 for (name, bytes) in pairs {
64 let start = data.len();
65 let len = bytes.len();
66 data.extend_from_slice(&bytes);
67 entries.push((name, start, len));
68 }
69 Self { entries, data }
70 }
71}
72
73impl WeightLoader for BytesWeightLoader {
74 fn tensor_bytes(&self, name: &str) -> Option<&[u8]> {
75 self.entries
76 .iter()
77 .find(|(n, _, _)| n == name)
78 .map(|(_, off, len)| &self.data[*off..*off + *len])
79 }
80
81 fn names(&self) -> Vec<String> {
82 self.entries.iter().map(|(n, _, _)| n.clone()).collect()
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn round_trip() {
        let loader = BytesWeightLoader::from_pairs(vec![
            ("w".into(), vec![1, 2, 3, 4]),
            ("b".into(), vec![5, 6]),
        ]);
        assert_eq!(loader.tensor_bytes("w"), Some(&[1u8, 2, 3, 4][..]));
        assert_eq!(loader.tensor_bytes("b"), Some(&[5u8, 6][..]));
        assert_eq!(loader.tensor_bytes("missing"), None);
        assert_eq!(loader.names(), vec!["w".to_string(), "b".to_string()]);
    }

    /// A loader built from no pairs exposes no tensors at all.
    #[test]
    fn empty_loader_has_no_tensors() {
        let loader = BytesWeightLoader::from_pairs(Vec::new());
        assert!(loader.names().is_empty());
        assert_eq!(loader.tensor_bytes("w"), None);
    }

    /// A zero-length tensor is still present (`Some(&[])`), and it does not
    /// disturb the byte range of the tensor stored after it.
    #[test]
    fn zero_length_tensor_is_present() {
        let loader = BytesWeightLoader::from_pairs(vec![
            ("empty".into(), Vec::new()),
            ("x".into(), vec![7]),
        ]);
        assert_eq!(loader.tensor_bytes("empty"), Some(&[0u8; 0][..]));
        assert_eq!(loader.tensor_bytes("x"), Some(&[7u8][..]));
    }

    /// Duplicate names: every entry is listed, lookups resolve to the first.
    #[test]
    fn duplicate_name_resolves_to_first() {
        let loader = BytesWeightLoader::from_pairs(vec![
            ("w".into(), vec![1]),
            ("w".into(), vec![2, 3]),
        ]);
        assert_eq!(loader.tensor_bytes("w"), Some(&[1u8][..]));
        assert_eq!(loader.names(), vec!["w".to_string(), "w".to_string()]);
    }
}