Skip to main content

rlx_flow/
weight.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Weight loading trait — implemented by model-builder `WeightLoader` adapters.

18use anyhow::Result;
19
20/// Abstract weight source for block emission. Keeps `rlx-flow` independent of
21/// safetensors / GGUF file formats.
22pub trait WeightSource {
23    fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)>;
24
25    /// Optional probe for arch-specific key layout detection.
26    fn has(&self, key: &str) -> bool {
27        let _ = key;
28        false
29    }
30}
31
32impl<T: WeightSource + ?Sized> WeightSource for &mut T {
33    fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
34        (*self).take(key, transpose)
35    }
36}
37
/// In-memory weight map for tests and tooling.
///
/// `tensors` maps each weight name to its `(data, shape)` pair. The field is public,
/// so callers may also read or populate it directly.
#[derive(Debug, Default, Clone)]
pub struct MapWeights {
    pub tensors: std::collections::HashMap<String, (Vec<f32>, Vec<usize>)>,
}

impl MapWeights {
    /// Stores `data` with `shape` under `key`.
    ///
    /// An existing entry with the same key is overwritten and its old value dropped.
    pub fn insert(&mut self, key: impl Into<String>, data: Vec<f32>, shape: Vec<usize>) {
        let name: String = key.into();
        let tensor = (data, shape);
        self.tensors.insert(name, tensor);
    }
}

50impl WeightSource for MapWeights {
51    fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
52        let (data, shape) = self
53            .tensors
54            .remove(key)
55            .ok_or_else(|| anyhow::anyhow!("missing weight: {key}"))?;
56        if !transpose {
57            return Ok((data, shape));
58        }
59        if shape.len() != 2 {
60            return Err(anyhow::anyhow!("transpose requires rank-2 weight: {key}"));
61        }
62        let rows = shape[0];
63        let cols = shape[1];
64        let mut out = vec![0f32; rows * cols];
65        for r in 0..rows {
66            for c in 0..cols {
67                out[c * rows + r] = data[r * cols + c];
68            }
69        }
70        Ok((out, vec![cols, rows]))
71    }
72}