Skip to main content

rlx_flow/
weight.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Weight loading trait — implemented by model-builder `WeightLoader` adapters.
17
18use anyhow::Result;
19
20/// Abstract weight source for block emission. Keeps `rlx-flow` independent of
21/// safetensors / GGUF file formats.
22pub trait WeightSource {
23    fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)>;
24
25    /// Optional probe for arch-specific key layout detection.
26    fn has(&self, key: &str) -> bool {
27        let _ = key;
28        false
29    }
30}
31
32impl<T: WeightSource + ?Sized> WeightSource for &mut T {
33    fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
34        (*self).take(key, transpose)
35    }
36    // Forward `has` too; otherwise a `&mut dyn WeightSource` silently falls
37    // back to the trait default (`false`) and key-layout probing breaks.
38    fn has(&self, key: &str) -> bool {
39        (**self).has(key)
40    }
41}
42
/// Simple in-memory tensor store, intended for tests and tooling.
#[derive(Clone, Debug, Default)]
pub struct MapWeights {
    /// Tensor name mapped to its `(data, shape)` pair.
    pub tensors: std::collections::HashMap<String, (Vec<f32>, Vec<usize>)>,
}

impl MapWeights {
    /// Registers a tensor under `key`.
    ///
    /// Any entry already stored under the same name is replaced. No check is
    /// made that `data.len()` agrees with the product of `shape`.
    pub fn insert(&mut self, key: impl Into<String>, data: Vec<f32>, shape: Vec<usize>) {
        let name: String = key.into();
        let entry = (data, shape);
        self.tensors.insert(name, entry);
    }
}
54
55impl WeightSource for MapWeights {
56    fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
57        let (data, shape) = self
58            .tensors
59            .remove(key)
60            .ok_or_else(|| anyhow::anyhow!("missing weight: {key}"))?;
61        if !transpose {
62            return Ok((data, shape));
63        }
64        if shape.len() != 2 {
65            return Err(anyhow::anyhow!("transpose requires rank-2 weight: {key}"));
66        }
67        let rows = shape[0];
68        let cols = shape[1];
69        let mut out = vec![0f32; rows * cols];
70        for r in 0..rows {
71            for c in 0..cols {
72                out[c * rows + r] = data[r * cols + c];
73            }
74        }
75        Ok((out, vec![cols, rows]))
76    }
77}