Skip to main content

rlx_runtime/
weight_registry.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Named weights registry (plan #24).
17//!
18//! Borrowed from MAX's `weights_registry/weights_registry.mojo`.
19//! Promotes the passive `weights.rs` loader contract to an active
20//! per-process registry: named handles + reference counts + LoRA
21//! adapter accounting.
22//!
23//! Why a registry beyond the existing `WeightLoader`?
24//!   - **LoRA hot-swap.** Multiple adapters share the same base
25//!     weights. The registry owns the bytes; adapters are
26//!     handle-pointers with their own metadata.
27//!   - **Tied embeddings.** GPT-2 / LLaMA / Gemma tie input
28//!     embedding to output projection. With a registry both
29//!     positions resolve to the same `WeightEntry`.
30//!   - **Weight streaming.** Future "load layer N when running
31//!     layer N" patterns need to refcount which weights are
32//!     in-memory.
//!   - **Memory accounting.** Memory estimation (plan #35) queries
//!     `total_bytes()` to gate model loads against the unified-memory
//!     budget.
36
37use rlx_ir::Shape;
38use std::collections::HashMap;
39use std::sync::Arc;
40use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
41
/// Opaque, `Copy`-able identifier for an entry in a [`WeightRegistry`].
///
/// Passing a handle around, rather than a reference into the registry,
/// lets callers name a weight without keeping the registry borrowed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WeightHandle(u64);

impl WeightHandle {
    /// The raw numeric id backing this handle.
    pub fn id(self) -> u64 {
        let WeightHandle(raw) = self;
        raw
    }
}
53
54/// What role a weight plays. Drives downstream scheduling
55/// (LoRA-aware request grouping) and accounting.
56#[derive(Debug, Clone)]
57pub enum WeightKind {
58    /// A base model weight — independent storage.
59    Base,
60    /// A LoRA adapter's slice (down-proj A or up-proj B).
61    /// Multiple adapters can attach to the same base; the
62    /// scheduler groups requests by `adapter` name.
63    LoraAdapter { adapter: String },
64    /// A view that resolves to another weight's storage. Used for
65    /// tied embeddings — `embed_tokens.weight` and
66    /// `lm_head.weight` are the same buffer, two names.
67    TiedAlias { target: WeightHandle },
68}
69
70/// One entry in the registry.
71#[derive(Debug)]
72pub struct WeightEntry {
73    pub name: String,
74    pub shape: Shape,
75    pub kind: WeightKind,
76    /// `Arc` so multiple consumers (different graphs / different
77    /// LoRA combinations / hot-reload) can hold the same bytes
78    /// without copying.
79    pub bytes: Arc<[u8]>,
80    /// Ref count. Goes up on `pin`, down on `release`. Hitting
81    /// zero on `release` keeps the entry in the registry; explicit
82    /// `unregister` is required to drop. This separation matters
83    /// for "weight streaming" use cases that re-pin frequently.
84    pub refs: AtomicUsize,
85}
86
87/// The registry itself. One per process is the typical setup; a
88/// `Session` can borrow a registry to resolve weight names during
89/// graph compile.
90pub struct WeightRegistry {
91    by_name: HashMap<String, WeightHandle>,
92    by_handle: HashMap<u64, Arc<WeightEntry>>,
93    next_id: AtomicU64,
94}
95
96impl WeightRegistry {
97    pub fn new() -> Self {
98        Self {
99            by_name: HashMap::new(),
100            by_handle: HashMap::new(),
101            next_id: AtomicU64::new(0),
102        }
103    }
104
105    fn alloc_id(&self) -> u64 {
106        self.next_id.fetch_add(1, Ordering::Relaxed)
107    }
108
109    /// Register a fresh weight under `name`. If `name` already
110    /// exists, returns the existing handle (idempotent — useful for
111    /// weight loaders that may see the same tensor twice).
112    pub fn register(
113        &mut self,
114        name: impl Into<String>,
115        shape: Shape,
116        bytes: Arc<[u8]>,
117        kind: WeightKind,
118    ) -> WeightHandle {
119        let name = name.into();
120        if let Some(&h) = self.by_name.get(&name) {
121            return h;
122        }
123        let id = self.alloc_id();
124        let h = WeightHandle(id);
125        let entry = Arc::new(WeightEntry {
126            name: name.clone(),
127            shape,
128            kind,
129            bytes,
130            refs: AtomicUsize::new(0),
131        });
132        self.by_name.insert(name, h);
133        self.by_handle.insert(id, entry);
134        h
135    }
136
137    /// Resolve a name → handle.
138    pub fn lookup(&self, name: &str) -> Option<WeightHandle> {
139        self.by_name.get(name).copied()
140    }
141
142    /// Read the entry for `handle`. Resolves a TiedAlias one step.
143    pub fn get(&self, handle: WeightHandle) -> Option<&Arc<WeightEntry>> {
144        let entry = self.by_handle.get(&handle.0)?;
145        if let WeightKind::TiedAlias { target } = entry.kind {
146            return self.by_handle.get(&target.0);
147        }
148        Some(entry)
149    }
150
151    /// Increment the entry's refcount. Returns the new count.
152    pub fn pin(&self, handle: WeightHandle) -> Option<usize> {
153        let entry = self.by_handle.get(&handle.0)?;
154        Some(entry.refs.fetch_add(1, Ordering::Relaxed) + 1)
155    }
156
157    /// Decrement the entry's refcount. Returns the new count.
158    /// Hitting zero does NOT drop the entry — call `unregister` to
159    /// drop. Returns `None` if the handle is unknown.
160    pub fn release(&self, handle: WeightHandle) -> Option<usize> {
161        let entry = self.by_handle.get(&handle.0)?;
162        let prev = entry.refs.fetch_sub(1, Ordering::Relaxed);
163        debug_assert!(prev >= 1, "release on a zero-refcount entry");
164        Some(prev - 1)
165    }
166
167    /// Drop an entry. Caller must have already `release`d to zero
168    /// (debug-asserted). Returns the entry's name on success.
169    pub fn unregister(&mut self, handle: WeightHandle) -> Option<String> {
170        let entry = self.by_handle.remove(&handle.0)?;
171        debug_assert_eq!(
172            entry.refs.load(Ordering::Relaxed),
173            0,
174            "unregister on a still-referenced entry: refs={}",
175            entry.refs.load(Ordering::Relaxed)
176        );
177        self.by_name.remove(&entry.name);
178        Some(entry.name.clone())
179    }
180
181    /// Total registered weight bytes — sums each base entry's
182    /// `bytes.len()` plus each adapter; tied-alias entries don't
183    /// double-count (they share storage).
184    pub fn total_bytes(&self) -> usize {
185        self.by_handle
186            .values()
187            .filter(|e| !matches!(e.kind, WeightKind::TiedAlias { .. }))
188            .map(|e| e.bytes.len())
189            .sum()
190    }
191
192    /// All handles whose kind is `LoraAdapter { adapter: <name> }`.
193    /// Used by LoRA-aware scheduling (#33) to group requests.
194    pub fn lora_adapter_handles(&self, adapter: &str) -> Vec<WeightHandle> {
195        let mut v: Vec<WeightHandle> = self
196            .by_handle
197            .iter()
198            .filter_map(|(&id, e)| match &e.kind {
199                WeightKind::LoraAdapter { adapter: a } if a == adapter => Some(WeightHandle(id)),
200                _ => None,
201            })
202            .collect();
203        v.sort_by_key(|h| h.0);
204        v
205    }
206
207    /// All distinct LoRA adapter names currently registered.
208    pub fn lora_adapter_names(&self) -> Vec<String> {
209        let mut s: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
210        for e in self.by_handle.values() {
211            if let WeightKind::LoraAdapter { adapter } = &e.kind {
212                s.insert(adapter.clone());
213            }
214        }
215        s.into_iter().collect()
216    }
217
218    pub fn len(&self) -> usize {
219        self.by_handle.len()
220    }
221    pub fn is_empty(&self) -> bool {
222        self.by_handle.is_empty()
223    }
224}
225
226impl Default for WeightRegistry {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::DType;

    /// Every test weight uses the same small 8x8 f32 shape.
    fn square_shape() -> Shape {
        Shape::new(&[8, 8], DType::F32)
    }

    /// A zero-filled shared buffer of `len` bytes.
    fn zeroed(len: usize) -> Arc<[u8]> {
        Arc::from(vec![0u8; len])
    }

    #[test]
    fn register_and_lookup() {
        let mut reg = WeightRegistry::new();
        let handle = reg.register("w", square_shape(), zeroed(256), WeightKind::Base);
        assert_eq!(reg.lookup("w"), Some(handle));

        let found = reg.get(handle).expect("freshly registered handle must resolve");
        assert_eq!(found.name, "w");
        assert_eq!(found.bytes.len(), 256);
    }

    #[test]
    fn register_is_idempotent() {
        let mut reg = WeightRegistry::new();
        let first = reg.register("w", square_shape(), zeroed(128), WeightKind::Base);
        let second = reg.register("w", square_shape(), zeroed(999), WeightKind::Base);
        // Re-registering a known name hands back the original handle and
        // keeps the original bytes.
        assert_eq!(first, second);
        assert_eq!(reg.get(first).unwrap().bytes.len(), 128);
    }

    #[test]
    fn pin_release_balance() {
        let mut reg = WeightRegistry::new();
        let handle = reg.register("w", square_shape(), zeroed(64), WeightKind::Base);
        for expected in [1, 2] {
            assert_eq!(reg.pin(handle), Some(expected));
        }
        for expected in [1, 0] {
            assert_eq!(reg.release(handle), Some(expected));
        }
        // Back at zero, so unregistering is allowed.
        assert_eq!(reg.unregister(handle), Some("w".to_string()));
        assert!(reg.lookup("w").is_none());
    }

    #[test]
    fn tied_alias_resolves_to_target() {
        let mut reg = WeightRegistry::new();
        let embed = reg.register("embed", square_shape(), zeroed(128), WeightKind::Base);
        // The alias carries no bytes of its own.
        let lm_head = reg.register(
            "lm_head",
            square_shape(),
            zeroed(0),
            WeightKind::TiedAlias { target: embed },
        );
        let resolved = reg.get(lm_head).unwrap();
        assert_eq!(resolved.name, "embed");
        assert_eq!(resolved.bytes.len(), 128);
    }

    #[test]
    fn total_bytes_skips_aliases() {
        let mut reg = WeightRegistry::new();
        let embed = reg.register("embed", square_shape(), zeroed(100), WeightKind::Base);
        reg.register(
            "lm_head",
            square_shape(),
            zeroed(0),
            WeightKind::TiedAlias { target: embed },
        );
        reg.register("ffn", square_shape(), zeroed(200), WeightKind::Base);
        assert_eq!(reg.total_bytes(), 300, "alias must not double-count");
    }

    #[test]
    fn lora_grouping() {
        let mut reg = WeightRegistry::new();
        reg.register("ffn", square_shape(), zeroed(100), WeightKind::Base);
        let slices = [
            ("ffn.lora.a", "code"),
            ("ffn.lora.b", "code"),
            ("attn.lora.a", "math"),
        ];
        for (name, adapter) in slices {
            reg.register(
                name,
                square_shape(),
                zeroed(8),
                WeightKind::LoraAdapter {
                    adapter: adapter.to_string(),
                },
            );
        }

        let mut adapters = reg.lora_adapter_names();
        adapters.sort();
        assert_eq!(adapters, vec!["code".to_string(), "math".to_string()]);

        assert_eq!(reg.lora_adapter_handles("code").len(), 2);
    }
}