Skip to main content

rlx_flux2/
lora.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! FLUX.2 LoRA adapter merge (PEFT / diffusers safetensors → base weight map).
17
18use anyhow::{Context, Result, bail, ensure};
19use rlx_core::weight_map::WeightMap;
20use std::collections::HashMap;
21use std::path::Path;
22
/// Which factor of a low-rank LoRA pair a tensor holds.
///
/// `A` is the down-projection (`lora_A` / `lora_down` / `lora_a`, shape `[rank, in_dim]`);
/// `B` is the up-projection (`lora_B` / `lora_up` / `lora_b`, shape `[out_dim, rank]`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum LoraSide {
    /// Down-projection factor.
    A,
    /// Up-projection factor.
    B,
}
29
30/// Normalize HF / diffusers key prefixes so LoRA keys align with denoiser weights.
31fn normalize_lora_key(key: &str) -> String {
32    crate::adapt::normalize_flux2_key(key)
33}
34
35fn parse_lora_side(key: &str) -> Option<(&str, LoraSide)> {
36    for (suffix, side) in [
37        (".lora_A.weight", LoraSide::A),
38        (".lora_B.weight", LoraSide::B),
39        (".lora_down.weight", LoraSide::A),
40        (".lora_up.weight", LoraSide::B),
41        (".lora_a.weight", LoraSide::A),
42        (".lora_b.weight", LoraSide::B),
43    ] {
44        if let Some(base) = key.strip_suffix(suffix) {
45            return Some((base, side));
46        }
47    }
48    None
49}
50
51fn lora_delta(
52    a: &[f32],
53    a_shape: &[usize],
54    b: &[f32],
55    b_shape: &[usize],
56    scale: f32,
57) -> Result<Vec<f32>> {
58    ensure!(
59        a_shape.len() == 2 && b_shape.len() == 2,
60        "LoRA A/B must be rank-2"
61    );
62    let (rank_a, in_dim) = (a_shape[0], a_shape[1]);
63    let (out_dim, rank_b) = (b_shape[0], b_shape[1]);
64    ensure!(
65        rank_a == rank_b,
66        "LoRA rank mismatch: A {a_shape:?} vs B {b_shape:?}"
67    );
68    ensure!(
69        a.len() == rank_a * in_dim && b.len() == out_dim * rank_b,
70        "LoRA tensor size mismatch"
71    );
72    let mut delta = vec![0.0f32; out_dim * in_dim];
73    for o in 0..out_dim {
74        for i in 0..in_dim {
75            let mut acc = 0.0f32;
76            for r in 0..rank_a {
77                acc += b[o * rank_b + r] * a[r * in_dim + i];
78            }
79            delta[o * in_dim + i] = scale * acc;
80        }
81    }
82    Ok(delta)
83}
84
85/// Merge LoRA safetensors into `base` in-place. Returns the number of merged layers.
86pub fn apply_flux2_lora(base: &mut WeightMap, lora: &WeightMap, scale: f32) -> Result<usize> {
87    if scale == 0.0 {
88        return Ok(0);
89    }
90
91    #[allow(clippy::type_complexity)]
92    let mut pairs: HashMap<
93        String,
94        (
95            Option<(Vec<f32>, Vec<usize>)>,
96            Option<(Vec<f32>, Vec<usize>)>,
97        ),
98    > = HashMap::new();
99
100    for key in lora.keys() {
101        let norm = normalize_lora_key(key);
102        let Some((base_prefix, side)) = parse_lora_side(&norm) else {
103            continue;
104        };
105        let Some((data, shape)) = lora.get(key) else {
106            continue;
107        };
108        let entry = pairs.entry(base_prefix.to_string()).or_default();
109        match side {
110            LoraSide::A => entry.0 = Some((data.to_vec(), shape.to_vec())),
111            LoraSide::B => entry.1 = Some((data.to_vec(), shape.to_vec())),
112        }
113    }
114
115    let mut merged = 0usize;
116    for (prefix, (a, b)) in pairs {
117        let (a, b) = match (a, b) {
118            (Some(a), Some(b)) => (a, b),
119            _ => continue,
120        };
121        let weight_key = format!("{prefix}.weight");
122        if !base.has(&weight_key) {
123            continue;
124        }
125        let delta = lora_delta(&a.0, &a.1, &b.0, &b.1, scale)?;
126        base.merge_add_weight(&weight_key, &delta)?;
127        merged += 1;
128    }
129    Ok(merged)
130}
131
132/// Load LoRA from safetensors and merge into `base`.
133pub fn load_and_apply_flux2_lora(
134    base: &mut WeightMap,
135    lora_path: &Path,
136    scale: f32,
137) -> Result<usize> {
138    let path = lora_path
139        .to_str()
140        .with_context(|| format!("non-utf8 LoRA path {lora_path:?}"))?;
141    let lora = WeightMap::from_file(path)?;
142    apply_flux2_lora(base, &lora, scale)
143}
144
145/// Load LoRA from a directory of safetensors shards.
146pub fn load_and_apply_flux2_lora_dir(
147    base: &mut WeightMap,
148    lora_dir: &Path,
149    scale: f32,
150) -> Result<usize> {
151    let lora = WeightMap::from_safetensors_dir(lora_dir)?;
152    apply_flux2_lora(base, &lora, scale)
153}
154
155/// Parse `--lora-scale` style input; rejects NaN/inf.
156pub fn parse_lora_scale(s: &str) -> Result<f32> {
157    let v: f32 = s.parse().context("lora scale: f32")?;
158    if !v.is_finite() {
159        bail!("lora scale must be finite");
160    }
161    Ok(v)
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn lora_delta_rank1_matches_manual() {
        // Rank-1 case: scale * B @ A with B = [[3], [4]] (2x1) and A = [[1, 2]] (1x2)
        // is the outer product [[3, 6], [4, 8]], stored row-major.
        let down = vec![1.0f32, 2.0];
        let up = vec![3.0f32, 4.0];
        let expected: Vec<f32> = vec![3.0, 6.0, 4.0, 8.0];
        let delta = lora_delta(&down, &[1, 2], &up, &[2, 1], 1.0).unwrap();
        assert_eq!(delta, expected);
    }

    #[test]
    fn apply_lora_merges_into_base_weight() {
        // W = [10, 20]^T (2x1); B @ A = [3, 4]^T * 2 = [6, 8]^T; merged W = [16, 28]^T.
        let base_tensors = HashMap::from([(
            "proj.weight".to_string(),
            (vec![10.0, 20.0], vec![2, 1]),
        )]);
        let lora_tensors = HashMap::from([
            ("proj.lora_A.weight".to_string(), (vec![2.0], vec![1, 1])),
            (
                "proj.lora_B.weight".to_string(),
                (vec![3.0, 4.0], vec![2, 1]),
            ),
        ]);
        let mut base = WeightMap::from_tensors(base_tensors);
        let lora = WeightMap::from_tensors(lora_tensors);

        apply_flux2_lora(&mut base, &lora, 1.0).unwrap();

        let (merged, _) = base.get("proj.weight").unwrap();
        let want = [16.0f32, 28.0];
        for idx in 0..want.len() {
            assert!((merged[idx] - want[idx]).abs() < 1e-5);
        }
    }
}