Skip to main content

rlx_flow/
weight.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Weight loading trait — implemented by model-builder `WeightLoader` adapters.
5
6use anyhow::Result;
7
8/// Abstract weight source for block emission. Keeps `rlx-flow` independent of
9/// safetensors / GGUF file formats.
10pub trait WeightSource {
11    fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)>;
12
13    /// Optional probe for arch-specific key layout detection.
14    fn has(&self, key: &str) -> bool {
15        let _ = key;
16        false
17    }
18}
19
20impl<T: WeightSource + ?Sized> WeightSource for &mut T {
21    fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
22        (*self).take(key, transpose)
23    }
24}
25
/// Simple in-memory [`WeightSource`] backed by a hash map, intended for tests and tooling.
///
/// Each entry maps a weight name to its element data and its shape.
#[derive(Clone, Debug, Default)]
pub struct MapWeights {
    /// Weight name → `(data, shape)`.
    pub tensors: ::std::collections::HashMap<String, (Vec<f32>, Vec<usize>)>,
}
31
32impl MapWeights {
33    pub fn insert(&mut self, key: impl Into<String>, data: Vec<f32>, shape: Vec<usize>) {
34        self.tensors.insert(key.into(), (data, shape));
35    }
36}
37
38impl WeightSource for MapWeights {
39    fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
40        let (data, shape) = self
41            .tensors
42            .remove(key)
43            .ok_or_else(|| anyhow::anyhow!("missing weight: {key}"))?;
44        if !transpose {
45            return Ok((data, shape));
46        }
47        if shape.len() != 2 {
48            return Err(anyhow::anyhow!("transpose requires rank-2 weight: {key}"));
49        }
50        let rows = shape[0];
51        let cols = shape[1];
52        let mut out = vec![0f32; rows * cols];
53        for r in 0..rows {
54            for c in 0..cols {
55                out[c * rows + r] = data[r * cols + c];
56            }
57        }
58        Ok((out, vec![cols, rows]))
59    }
60}