use rlx_ir::Shape;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
/// Opaque, copyable identifier for an entry stored in a [`WeightRegistry`].
///
/// The wrapped value is the numeric id the registry allocated when the weight
/// was registered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WeightHandle(u64);

impl WeightHandle {
    /// Returns the raw numeric id wrapped by this handle.
    pub fn id(self) -> u64 {
        let Self(raw) = self;
        raw
    }
}
/// Classifies how a registered weight entry relates to other entries.
#[derive(Debug, Clone)]
pub enum WeightKind {
    /// A plain weight with no extra association.
    Base,
    /// A weight that belongs to the adapter named `adapter`; the registry
    /// groups entries by this name.
    LoraAdapter { adapter: String },
    /// An alias for the entry identified by `target`. The registry resolves
    /// aliases through `target` on lookup and leaves them out of its byte
    /// total.
    TiedAlias { target: WeightHandle },
}
/// A single registered weight: its metadata, payload and pin counter.
#[derive(Debug)]
pub struct WeightEntry {
    /// Name the weight was registered under.
    pub name: String,
    /// Shape supplied at registration time.
    pub shape: Shape,
    /// How this entry relates to other entries (base, adapter, alias).
    pub kind: WeightKind,
    /// Shared, immutable payload bytes.
    pub bytes: Arc<[u8]>,
    /// Pin counter; starts at zero and is adjusted by the registry's
    /// `pin` / `release` methods.
    pub refs: AtomicUsize,
}
/// Registry of weight entries, addressable by name or by [`WeightHandle`].
///
/// Derives `Debug` so the registry can be embedded in other `Debug` types and
/// inspected in logs and assertion output.
#[derive(Debug)]
pub struct WeightRegistry {
    /// Maps a registered name to the handle allocated for it.
    by_name: HashMap<String, WeightHandle>,
    /// Maps a handle's raw id to the shared entry it identifies.
    by_handle: HashMap<u64, Arc<WeightEntry>>,
    /// Source of fresh ids; each allocation takes the current value and
    /// advances the counter by one.
    next_id: AtomicU64,
}
impl WeightRegistry {
/// Creates an empty registry whose first allocated id will be `0`.
pub fn new() -> Self {
    let by_name = HashMap::new();
    let by_handle = HashMap::new();
    let next_id = AtomicU64::new(0);
    Self {
        by_name,
        by_handle,
        next_id,
    }
}
/// Returns the current value of the id counter and advances it by one.
fn alloc_id(&self) -> u64 {
    let counter = &self.next_id;
    // `fetch_add` yields the value held *before* the increment, so
    // consecutive calls hand out consecutive ids.
    counter.fetch_add(1, Ordering::Relaxed)
}
/// Registers a weight under `name` and returns its handle.
///
/// Registration is keyed by name. When `name` is already present, the
/// existing handle is returned and the supplied `shape`, `bytes` and `kind`
/// are dropped without touching the stored entry. Otherwise a fresh id is
/// allocated and the new entry is stored with a pin count of zero.
pub fn register(
    &mut self,
    name: impl Into<String>,
    shape: Shape,
    bytes: Arc<[u8]>,
    kind: WeightKind,
) -> WeightHandle {
    let name = name.into();
    match self.by_name.get(&name).copied() {
        // Already known: hand back the original handle unchanged.
        Some(existing) => existing,
        None => {
            let handle = WeightHandle(self.alloc_id());
            let entry = WeightEntry {
                name: name.clone(),
                shape,
                kind,
                bytes,
                refs: AtomicUsize::new(0),
            };
            self.by_name.insert(name, handle);
            self.by_handle.insert(handle.0, Arc::new(entry));
            handle
        }
    }
}
/// Returns the handle registered under `name`, or `None` if the name is
/// unknown.
pub fn lookup(&self, name: &str) -> Option<WeightHandle> {
    let handle = self.by_name.get(name)?;
    Some(*handle)
}
/// Returns the entry behind `handle`, resolving tied aliases.
///
/// A `TiedAlias` entry is followed through its `target` until a non-alias
/// entry is reached, so an alias of an alias still resolves to the real
/// weight rather than to an intermediate alias entry.
///
/// Returns `None` when `handle` is not registered, when any alias along the
/// way points at an unregistered target, or when the aliases form a cycle.
pub fn get(&self, handle: WeightHandle) -> Option<&Arc<WeightEntry>> {
    let mut entry = self.by_handle.get(&handle.0)?;
    // Each iteration inspects one entry. A cycle-free chain visits every
    // stored entry at most once, so the entry count bounds the walk and a
    // cycle cannot spin forever.
    for _ in 0..self.by_handle.len() {
        match entry.kind {
            WeightKind::TiedAlias { target } => {
                entry = self.by_handle.get(&target.0)?;
            }
            WeightKind::Base | WeightKind::LoraAdapter { .. } => return Some(entry),
        }
    }
    // Bound exhausted without reaching a non-alias entry: alias cycle.
    None
}
/// Increments the pin count of the entry stored under `handle` and returns
/// the new count, or `None` if the handle is not registered.
///
/// The count is kept on the entry looked up directly by `handle`; aliases
/// are not followed here.
pub fn pin(&self, handle: WeightHandle) -> Option<usize> {
    self.by_handle
        .get(&handle.0)
        .map(|entry| entry.refs.fetch_add(1, Ordering::Relaxed) + 1)
}
/// Decrements the pin count of the entry stored under `handle` and returns
/// the new count, or `None` if the handle is not registered.
///
/// Releasing an entry whose count is already zero is a caller bug. Debug
/// builds panic on it; release builds leave the count at zero and return
/// `Some(0)` instead of wrapping the counter around to `usize::MAX`.
///
/// # Panics
///
/// In debug builds, panics when the entry's pin count is already zero.
pub fn release(&self, handle: WeightHandle) -> Option<usize> {
    let entry = self.by_handle.get(&handle.0)?;
    // Checked decrement: the stored value only changes when it is >= 1, so an
    // unbalanced release can never corrupt the counter.
    let outcome = entry
        .refs
        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |count| {
            count.checked_sub(1)
        });
    match outcome {
        Ok(prev) => Some(prev - 1),
        Err(_) => {
            debug_assert!(false, "release on a zero-refcount entry");
            Some(0)
        }
    }
}
/// Removes the entry stored under `handle` together with its name mapping
/// and returns the entry's name, or `None` if the handle is not registered.
///
/// Only the entry itself is removed; alias entries that name this handle as
/// their target stay registered.
///
/// # Panics
///
/// In debug builds, panics when the entry's pin count is not zero.
pub fn unregister(&mut self, handle: WeightHandle) -> Option<String> {
    let removed = self.by_handle.remove(&handle.0)?;
    debug_assert_eq!(
        removed.refs.load(Ordering::Relaxed),
        0,
        "unregister on a still-referenced entry: refs={}",
        removed.refs.load(Ordering::Relaxed)
    );
    let name = removed.name.clone();
    self.by_name.remove(&name);
    Some(name)
}
/// Sums the payload sizes of all registered entries, leaving out tied
/// aliases so an aliased weight is not counted twice.
pub fn total_bytes(&self) -> usize {
    let mut total = 0;
    for entry in self.by_handle.values() {
        match entry.kind {
            WeightKind::TiedAlias { .. } => {}
            WeightKind::Base | WeightKind::LoraAdapter { .. } => total += entry.bytes.len(),
        }
    }
    total
}
/// Returns the handles of every entry whose kind is `LoraAdapter` with the
/// given `adapter` name, sorted by ascending id.
pub fn lora_adapter_handles(&self, adapter: &str) -> Vec<WeightHandle> {
    let mut handles = Vec::new();
    for (&id, entry) in &self.by_handle {
        if let WeightKind::LoraAdapter { adapter: owner } = &entry.kind {
            if owner == adapter {
                handles.push(WeightHandle(id));
            }
        }
    }
    // Map iteration order is unspecified; sort so callers get a stable order.
    handles.sort_by_key(|handle| handle.0);
    handles
}
/// Returns the distinct adapter names found on `LoraAdapter` entries, in
/// ascending order.
pub fn lora_adapter_names(&self) -> Vec<String> {
    // The ordered set both de-duplicates and sorts the names.
    let unique: std::collections::BTreeSet<&str> = self
        .by_handle
        .values()
        .filter_map(|entry| match &entry.kind {
            WeightKind::LoraAdapter { adapter } => Some(adapter.as_str()),
            WeightKind::Base | WeightKind::TiedAlias { .. } => None,
        })
        .collect();
    unique.into_iter().map(String::from).collect()
}
/// Returns the number of registered entries, aliases included.
pub fn len(&self) -> usize {
    let entries = &self.by_handle;
    entries.len()
}
/// Returns `true` when no entry is registered.
pub fn is_empty(&self) -> bool {
    let entries = &self.by_handle;
    entries.is_empty()
}
}
impl Default for WeightRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use rlx_ir::DType;
/// Test fixture: the one `Shape` used by every test in this module, built
/// from `[8, 8]` and `DType::F32`.
fn shape() -> Shape {
Shape::new(&[8, 8], DType::F32)
}
/// Test fixture: a shared buffer of `n` zero bytes.
fn bytes(n: usize) -> Arc<[u8]> {
    let zeros = vec![0u8; n];
    Arc::from(zeros)
}
/// A registered weight is reachable by name and by handle.
#[test]
fn register_and_lookup() {
    let mut registry = WeightRegistry::new();
    let handle = registry.register("w", shape(), bytes(256), WeightKind::Base);

    assert_eq!(registry.lookup("w"), Some(handle));

    let stored = registry.get(handle).unwrap();
    assert_eq!(stored.name, "w");
    assert_eq!(stored.bytes.len(), 256);
}
/// Re-registering a name returns the first handle and keeps the first payload.
#[test]
fn register_is_idempotent() {
    let mut registry = WeightRegistry::new();
    let first = registry.register("w", shape(), bytes(128), WeightKind::Base);
    let second = registry.register("w", shape(), bytes(999), WeightKind::Base);

    assert_eq!(first, second);

    // The second payload (999 bytes) must not have replaced the first.
    let stored_len = registry.get(first).unwrap().bytes.len();
    assert_eq!(stored_len, 128);
}
/// Pins and releases report the running count; a balanced entry can be
/// unregistered, after which its name no longer resolves.
#[test]
fn pin_release_balance() {
    let mut registry = WeightRegistry::new();
    let handle = registry.register("w", shape(), bytes(64), WeightKind::Base);

    // Count goes up with each pin...
    assert_eq!(registry.pin(handle), Some(1));
    assert_eq!(registry.pin(handle), Some(2));
    // ...and back down with each release.
    assert_eq!(registry.release(handle), Some(1));
    assert_eq!(registry.release(handle), Some(0));

    let removed_name = registry.unregister(handle);
    assert_eq!(removed_name, Some("w".to_string()));
    assert!(registry.lookup("w").is_none());
}
/// Fetching an alias handle yields the target's entry, not the alias entry.
#[test]
fn tied_alias_resolves_to_target() {
    let mut registry = WeightRegistry::new();
    let target = registry.register("embed", shape(), bytes(128), WeightKind::Base);

    let alias_kind = WeightKind::TiedAlias { target };
    let alias = registry.register("lm_head", shape(), bytes(0), alias_kind);

    let resolved = registry.get(alias).unwrap();
    assert_eq!(resolved.name, "embed");
    assert_eq!(resolved.bytes.len(), 128);
}
/// The byte total covers base weights only; alias entries add nothing.
#[test]
fn total_bytes_skips_aliases() {
    let mut registry = WeightRegistry::new();
    let _embed = registry.register("embed", shape(), bytes(100), WeightKind::Base);

    // Resolve the alias target by name, as a caller without the handle would.
    let target = registry.lookup("embed").unwrap();
    let _alias = registry.register(
        "lm_head",
        shape(),
        bytes(0),
        WeightKind::TiedAlias { target },
    );

    let _ffn = registry.register("ffn", shape(), bytes(200), WeightKind::Base);

    assert_eq!(registry.total_bytes(), 300, "alias must not double-count");
}
/// Adapter entries are grouped by adapter name; base weights are ignored.
#[test]
fn lora_grouping() {
    let mut registry = WeightRegistry::new();
    let _base = registry.register("ffn", shape(), bytes(100), WeightKind::Base);

    // (weight name, owning adapter), registered in this order.
    let adapter_weights = [
        ("ffn.lora.a", "code"),
        ("ffn.lora.b", "code"),
        ("attn.lora.a", "math"),
    ];
    for &(weight, adapter) in adapter_weights.iter() {
        registry.register(
            weight,
            shape(),
            bytes(8),
            WeightKind::LoraAdapter {
                adapter: adapter.into(),
            },
        );
    }

    let mut adapters = registry.lora_adapter_names();
    adapters.sort();
    assert_eq!(adapters, vec!["code".to_string(), "math".to_string()]);

    let code_handles = registry.lora_adapter_handles("code");
    assert_eq!(code_handles.len(), 2);
}
}