rlx_flow/weight.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Weight loading trait — implemented by model-builder `WeightLoader` adapters.
17
18use anyhow::Result;
19
20/// Abstract weight source for block emission. Keeps `rlx-flow` independent of
21/// safetensors / GGUF file formats.
22pub trait WeightSource {
23 fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)>;
24
25 /// Optional probe for arch-specific key layout detection.
26 fn has(&self, key: &str) -> bool {
27 let _ = key;
28 false
29 }
30}
31
32impl<T: WeightSource + ?Sized> WeightSource for &mut T {
33 fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
34 (*self).take(key, transpose)
35 }
36 // Forward `has` too; otherwise a `&mut dyn WeightSource` silently falls
37 // back to the trait default (`false`) and key-layout probing breaks.
38 fn has(&self, key: &str) -> bool {
39 (**self).has(key)
40 }
41}
42
/// In-memory weight map for tests and tooling.
#[derive(Debug, Default, Clone)]
pub struct MapWeights {
    /// Weight name -> `(flat data, shape)`.
    pub tensors: std::collections::HashMap<String, (Vec<f32>, Vec<usize>)>,
}

impl MapWeights {
    /// Stores `data` with the given `shape` under `key`.
    ///
    /// An existing entry with the same key is replaced. No consistency check
    /// between `data.len()` and `shape` is done here.
    pub fn insert(&mut self, key: impl Into<String>, data: Vec<f32>, shape: Vec<usize>) {
        let name: String = key.into();
        let entry = (data, shape);
        self.tensors.insert(name, entry);
    }
}
54
55impl WeightSource for MapWeights {
56 fn take(&mut self, key: &str, transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
57 let (data, shape) = self
58 .tensors
59 .remove(key)
60 .ok_or_else(|| anyhow::anyhow!("missing weight: {key}"))?;
61 if !transpose {
62 return Ok((data, shape));
63 }
64 if shape.len() != 2 {
65 return Err(anyhow::anyhow!("transpose requires rank-2 weight: {key}"));
66 }
67 let rows = shape[0];
68 let cols = shape[1];
69 let mut out = vec![0f32; rows * cols];
70 for r in 0..rows {
71 for c in 0..cols {
72 out[c * rows + r] = data[r * cols + c];
73 }
74 }
75 Ok((out, vec![cols, rows]))
76 }
77}