1use rlx_ir::Shape;
38use std::collections::HashMap;
39use std::sync::Arc;
40use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
41
/// Opaque, copyable identifier for an entry in a `WeightRegistry`.
///
/// Ids come from a per-registry counter that starts at 0, so a handle minted by
/// one registry may name an unrelated entry when presented to another.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WeightHandle(u64);

impl WeightHandle {
    /// Returns the raw numeric id backing this handle.
    pub fn id(self) -> u64 {
        let WeightHandle(raw) = self;
        raw
    }
}
53
54#[derive(Debug, Clone)]
57pub enum WeightKind {
58 Base,
60 LoraAdapter { adapter: String },
64 TiedAlias { target: WeightHandle },
68}
69
70#[derive(Debug)]
72pub struct WeightEntry {
73 pub name: String,
74 pub shape: Shape,
75 pub kind: WeightKind,
76 pub bytes: Arc<[u8]>,
80 pub refs: AtomicUsize,
85}
86
87pub struct WeightRegistry {
91 by_name: HashMap<String, WeightHandle>,
92 by_handle: HashMap<u64, Arc<WeightEntry>>,
93 next_id: AtomicU64,
94}
95
96impl WeightRegistry {
97 pub fn new() -> Self {
98 Self {
99 by_name: HashMap::new(),
100 by_handle: HashMap::new(),
101 next_id: AtomicU64::new(0),
102 }
103 }
104
105 fn alloc_id(&self) -> u64 {
106 self.next_id.fetch_add(1, Ordering::Relaxed)
107 }
108
109 pub fn register(
113 &mut self,
114 name: impl Into<String>,
115 shape: Shape,
116 bytes: Arc<[u8]>,
117 kind: WeightKind,
118 ) -> WeightHandle {
119 let name = name.into();
120 if let Some(&h) = self.by_name.get(&name) {
121 return h;
122 }
123 let id = self.alloc_id();
124 let h = WeightHandle(id);
125 let entry = Arc::new(WeightEntry {
126 name: name.clone(),
127 shape,
128 kind,
129 bytes,
130 refs: AtomicUsize::new(0),
131 });
132 self.by_name.insert(name, h);
133 self.by_handle.insert(id, entry);
134 h
135 }
136
137 pub fn lookup(&self, name: &str) -> Option<WeightHandle> {
139 self.by_name.get(name).copied()
140 }
141
142 pub fn get(&self, handle: WeightHandle) -> Option<&Arc<WeightEntry>> {
144 let entry = self.by_handle.get(&handle.0)?;
145 if let WeightKind::TiedAlias { target } = entry.kind {
146 return self.by_handle.get(&target.0);
147 }
148 Some(entry)
149 }
150
151 pub fn pin(&self, handle: WeightHandle) -> Option<usize> {
153 let entry = self.by_handle.get(&handle.0)?;
154 Some(entry.refs.fetch_add(1, Ordering::Relaxed) + 1)
155 }
156
157 pub fn release(&self, handle: WeightHandle) -> Option<usize> {
161 let entry = self.by_handle.get(&handle.0)?;
162 let prev = entry.refs.fetch_sub(1, Ordering::Relaxed);
163 debug_assert!(prev >= 1, "release on a zero-refcount entry");
164 Some(prev - 1)
165 }
166
167 pub fn unregister(&mut self, handle: WeightHandle) -> Option<String> {
170 let entry = self.by_handle.remove(&handle.0)?;
171 debug_assert_eq!(
172 entry.refs.load(Ordering::Relaxed),
173 0,
174 "unregister on a still-referenced entry: refs={}",
175 entry.refs.load(Ordering::Relaxed)
176 );
177 self.by_name.remove(&entry.name);
178 Some(entry.name.clone())
179 }
180
181 pub fn total_bytes(&self) -> usize {
185 self.by_handle
186 .values()
187 .filter(|e| !matches!(e.kind, WeightKind::TiedAlias { .. }))
188 .map(|e| e.bytes.len())
189 .sum()
190 }
191
192 pub fn lora_adapter_handles(&self, adapter: &str) -> Vec<WeightHandle> {
195 let mut v: Vec<WeightHandle> = self
196 .by_handle
197 .iter()
198 .filter_map(|(&id, e)| match &e.kind {
199 WeightKind::LoraAdapter { adapter: a } if a == adapter => Some(WeightHandle(id)),
200 _ => None,
201 })
202 .collect();
203 v.sort_by_key(|h| h.0);
204 v
205 }
206
207 pub fn lora_adapter_names(&self) -> Vec<String> {
209 let mut s: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
210 for e in self.by_handle.values() {
211 if let WeightKind::LoraAdapter { adapter } = &e.kind {
212 s.insert(adapter.clone());
213 }
214 }
215 s.into_iter().collect()
216 }
217
218 pub fn len(&self) -> usize {
219 self.by_handle.len()
220 }
221 pub fn is_empty(&self) -> bool {
222 self.by_handle.is_empty()
223 }
224}
225
226impl Default for WeightRegistry {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::DType;

    /// Fixture shape shared by every test: an 8x8 f32 tensor.
    fn square_f32() -> Shape {
        Shape::new(&[8, 8], DType::F32)
    }

    /// Fixture payload: `len` zero bytes behind an `Arc`.
    fn zeroed(len: usize) -> Arc<[u8]> {
        Arc::from(vec![0u8; len])
    }

    /// A registered name maps back to its handle, and the entry keeps its data.
    #[test]
    fn register_and_lookup() {
        let mut reg = WeightRegistry::new();
        let handle = reg.register("w", square_f32(), zeroed(256), WeightKind::Base);
        assert_eq!(reg.lookup("w"), Some(handle));

        let entry = reg.get(handle).expect("handle was just registered");
        assert_eq!(entry.name, "w");
        assert_eq!(entry.bytes.len(), 256);
    }

    /// Re-registering a name returns the original handle and keeps the first payload.
    #[test]
    fn register_is_idempotent() {
        let mut reg = WeightRegistry::new();
        let first = reg.register("w", square_f32(), zeroed(128), WeightKind::Base);
        let second = reg.register("w", square_f32(), zeroed(999), WeightKind::Base);
        assert_eq!(first, second);
        // The second registration must not replace the stored bytes.
        assert_eq!(reg.get(first).unwrap().bytes.len(), 128);
    }

    /// Pins and releases balance out, after which the entry can be unregistered.
    #[test]
    fn pin_release_balance() {
        let mut reg = WeightRegistry::new();
        let handle = reg.register("w", square_f32(), zeroed(64), WeightKind::Base);
        for expected in [1, 2] {
            assert_eq!(reg.pin(handle), Some(expected));
        }
        for expected in [1, 0] {
            assert_eq!(reg.release(handle), Some(expected));
        }
        assert_eq!(reg.unregister(handle), Some("w".to_string()));
        assert!(reg.lookup("w").is_none());
    }

    /// Looking up an alias yields the entry it is tied to.
    #[test]
    fn tied_alias_resolves_to_target() {
        let mut reg = WeightRegistry::new();
        let embed = reg.register("embed", square_f32(), zeroed(128), WeightKind::Base);
        let head = reg.register(
            "lm_head",
            square_f32(),
            zeroed(0),
            WeightKind::TiedAlias { target: embed },
        );

        let resolved = reg.get(head).unwrap();
        assert_eq!(resolved.name, "embed");
        assert_eq!(resolved.bytes.len(), 128);
    }

    /// Alias payloads are not added to the byte total.
    #[test]
    fn total_bytes_skips_aliases() {
        let mut reg = WeightRegistry::new();
        reg.register("embed", square_f32(), zeroed(100), WeightKind::Base);
        let target = reg.lookup("embed").unwrap();
        reg.register(
            "lm_head",
            square_f32(),
            zeroed(0),
            WeightKind::TiedAlias { target },
        );
        reg.register("ffn", square_f32(), zeroed(200), WeightKind::Base);
        assert_eq!(reg.total_bytes(), 300, "alias must not double-count");
    }

    /// Adapter weights are grouped by adapter name.
    #[test]
    fn lora_grouping() {
        let mut reg = WeightRegistry::new();
        reg.register("ffn", square_f32(), zeroed(100), WeightKind::Base);

        let lora_weights = [
            ("ffn.lora.a", "code"),
            ("ffn.lora.b", "code"),
            ("attn.lora.a", "math"),
        ];
        for (name, adapter) in lora_weights {
            reg.register(
                name,
                square_f32(),
                zeroed(8),
                WeightKind::LoraAdapter {
                    adapter: adapter.into(),
                },
            );
        }

        let mut adapters = reg.lora_adapter_names();
        adapters.sort();
        assert_eq!(adapters, vec!["code".to_string(), "math".to_string()]);

        assert_eq!(reg.lora_adapter_handles("code").len(), 2);
    }
}