Skip to main content

rlx_runtime/
weights.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Weight-loading abstraction.
17//!
18//! Native targets typically `mmap` a `.safetensors` file and read tensors
19//! via byte-offset slices into the mapping. WASM has no `mmap`; weights
20//! arrive as `Vec<u8>` from `fetch()` or `Response.arrayBuffer()`. Both
21//! paths produce the same shape: a name → byte-slice lookup.
22//!
23//! `WeightLoader` is the contract. Concrete implementations live in
24//! in-tree model loaders (mmap-based) and here (`BytesWeightLoader` — works on
25//! every target including WASM).
26
/// Read-only lookup from a tensor name to that tensor's raw bytes.
///
/// Every slice handed out borrows from the loader itself, so an
/// implementation must keep its backing storage valid for as long as
/// `&self` is borrowed. On native targets that storage is the mmap region;
/// on WASM it is the in-memory `Vec<u8>` the loader owns.
pub trait WeightLoader {
    /// Look up the tensor called `name`.
    ///
    /// Returns `None` when the loader holds no tensor with that name.
    /// Otherwise the bytes come back in the storage order of the source
    /// file (typically row-major, dtype-native).
    fn tensor_bytes(&self, name: &str) -> Option<&[u8]>;

    /// List the name of every tensor, for iteration or discovery.
    ///
    /// The ordering is up to the implementation, but it is stable for a
    /// given loader instance.
    fn names(&self) -> Vec<String>;
}
42
/// Owned, in-memory [`WeightLoader`] backed by one contiguous byte buffer.
///
/// This is the simplest and most portable variant: it works on every
/// target, including WASM. Build one with
/// [`BytesWeightLoader::from_pairs`]; a `from_safetensors(bytes)`
/// constructor is planned for when model builders integrate, but it does
/// not exist yet.
///
/// `Default` yields an empty loader. `Clone` copies the whole backing
/// buffer, so it is as expensive as the weights are large.
#[derive(Clone, Default)]
pub struct BytesWeightLoader {
    /// `(name, start_offset, len)` triples; each one addresses the range
    /// `start_offset..start_offset + len` of `data`.
    entries: Vec<(String, usize, usize)>,
    /// Backing buffer that the `entries` offsets index into.
    data: Vec<u8>,
}

impl std::fmt::Debug for BytesWeightLoader {
    /// Prints only the tensor count and the total byte size.
    ///
    /// Written by hand instead of derived: a derived impl would print every
    /// byte of `data`, which is useless (and enormous) for real weights.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BytesWeightLoader")
            .field("tensors", &self.entries.len())
            .field("bytes", &self.data.len())
            .finish()
    }
}
54
55impl BytesWeightLoader {
56    /// Build a loader from a list of `(name, bytes)` pairs. Each tensor
57    /// is appended into a single backing `Vec<u8>`; `tensor_bytes` returns
58    /// a borrow into that vec.
59    pub fn from_pairs(pairs: Vec<(String, Vec<u8>)>) -> Self {
60        let total: usize = pairs.iter().map(|(_, b)| b.len()).sum();
61        let mut data = Vec::with_capacity(total);
62        let mut entries = Vec::with_capacity(pairs.len());
63        for (name, bytes) in pairs {
64            let start = data.len();
65            let len = bytes.len();
66            data.extend_from_slice(&bytes);
67            entries.push((name, start, len));
68        }
69        Self { entries, data }
70    }
71}
72
73impl WeightLoader for BytesWeightLoader {
74    fn tensor_bytes(&self, name: &str) -> Option<&[u8]> {
75        self.entries
76            .iter()
77            .find(|(n, _, _)| n == name)
78            .map(|(_, off, len)| &self.data[*off..*off + *len])
79    }
80
81    fn names(&self) -> Vec<String> {
82        self.entries.iter().map(|(n, _, _)| n.clone()).collect()
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Payloads come back unchanged under their own names, unknown names
    /// yield `None`, and `names` reports insertion order.
    #[test]
    fn round_trip() {
        let loader = BytesWeightLoader::from_pairs(vec![
            ("w".into(), vec![1, 2, 3, 4]),
            ("b".into(), vec![5, 6]),
        ]);
        assert_eq!(loader.tensor_bytes("w"), Some(&[1u8, 2, 3, 4][..]));
        assert_eq!(loader.tensor_bytes("b"), Some(&[5u8, 6][..]));
        assert_eq!(loader.tensor_bytes("missing"), None);
        assert_eq!(loader.names(), vec!["w".to_string(), "b".to_string()]);
    }

    /// A tensor with no bytes is still present (`Some(&[])`, not `None`)
    /// and does not disturb the offsets of its neighbours.
    #[test]
    fn zero_length_tensor_is_present() {
        let loader = BytesWeightLoader::from_pairs(vec![
            ("head".into(), vec![7]),
            ("empty".into(), Vec::new()),
            ("tail".into(), vec![8, 9]),
        ]);
        let nothing: &[u8] = &[];
        assert_eq!(loader.tensor_bytes("head"), Some(&[7u8][..]));
        assert_eq!(loader.tensor_bytes("empty"), Some(nothing));
        assert_eq!(loader.tensor_bytes("tail"), Some(&[8u8, 9][..]));
    }

    /// A loader built from no pairs reports no names and finds nothing.
    #[test]
    fn empty_loader_has_no_tensors() {
        let loader = BytesWeightLoader::from_pairs(Vec::new());
        assert!(loader.names().is_empty());
        assert_eq!(loader.tensor_bytes("w"), None);
    }
}