1use anyhow::{Context, Result, bail, ensure};
19use rlx_core::weight_map::WeightMap;
20use std::collections::HashMap;
21use std::path::Path;
22
/// Which factor of a low-rank pair a LoRA tensor holds.
///
/// The pair is combined as `B · A` by `lora_delta`, where `A` has shape `[rank, in_dim]`
/// and `B` has shape `[out_dim, rank]`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum LoraSide {
    /// Down-projection factor (`lora_A` / `lora_down` / `lora_a`), shape `[rank, in_dim]`.
    A,
    /// Up-projection factor (`lora_B` / `lora_up` / `lora_b`), shape `[out_dim, rank]`.
    B,
}
29
30fn normalize_lora_key(key: &str) -> String {
32 crate::adapt::normalize_flux2_key(key)
33}
34
35fn parse_lora_side(key: &str) -> Option<(&str, LoraSide)> {
36 for (suffix, side) in [
37 (".lora_A.weight", LoraSide::A),
38 (".lora_B.weight", LoraSide::B),
39 (".lora_down.weight", LoraSide::A),
40 (".lora_up.weight", LoraSide::B),
41 (".lora_a.weight", LoraSide::A),
42 (".lora_b.weight", LoraSide::B),
43 ] {
44 if let Some(base) = key.strip_suffix(suffix) {
45 return Some((base, side));
46 }
47 }
48 None
49}
50
51fn lora_delta(
52 a: &[f32],
53 a_shape: &[usize],
54 b: &[f32],
55 b_shape: &[usize],
56 scale: f32,
57) -> Result<Vec<f32>> {
58 ensure!(
59 a_shape.len() == 2 && b_shape.len() == 2,
60 "LoRA A/B must be rank-2"
61 );
62 let (rank_a, in_dim) = (a_shape[0], a_shape[1]);
63 let (out_dim, rank_b) = (b_shape[0], b_shape[1]);
64 ensure!(
65 rank_a == rank_b,
66 "LoRA rank mismatch: A {a_shape:?} vs B {b_shape:?}"
67 );
68 ensure!(
69 a.len() == rank_a * in_dim && b.len() == out_dim * rank_b,
70 "LoRA tensor size mismatch"
71 );
72 let mut delta = vec![0.0f32; out_dim * in_dim];
73 for o in 0..out_dim {
74 for i in 0..in_dim {
75 let mut acc = 0.0f32;
76 for r in 0..rank_a {
77 acc += b[o * rank_b + r] * a[r * in_dim + i];
78 }
79 delta[o * in_dim + i] = scale * acc;
80 }
81 }
82 Ok(delta)
83}
84
85pub fn apply_flux2_lora(base: &mut WeightMap, lora: &WeightMap, scale: f32) -> Result<usize> {
87 if scale == 0.0 {
88 return Ok(0);
89 }
90
91 #[allow(clippy::type_complexity)]
92 let mut pairs: HashMap<
93 String,
94 (
95 Option<(Vec<f32>, Vec<usize>)>,
96 Option<(Vec<f32>, Vec<usize>)>,
97 ),
98 > = HashMap::new();
99
100 for key in lora.keys() {
101 let norm = normalize_lora_key(key);
102 let Some((base_prefix, side)) = parse_lora_side(&norm) else {
103 continue;
104 };
105 let Some((data, shape)) = lora.get(key) else {
106 continue;
107 };
108 let entry = pairs.entry(base_prefix.to_string()).or_default();
109 match side {
110 LoraSide::A => entry.0 = Some((data.to_vec(), shape.to_vec())),
111 LoraSide::B => entry.1 = Some((data.to_vec(), shape.to_vec())),
112 }
113 }
114
115 let mut merged = 0usize;
116 for (prefix, (a, b)) in pairs {
117 let (a, b) = match (a, b) {
118 (Some(a), Some(b)) => (a, b),
119 _ => continue,
120 };
121 let weight_key = format!("{prefix}.weight");
122 if !base.has(&weight_key) {
123 continue;
124 }
125 let delta = lora_delta(&a.0, &a.1, &b.0, &b.1, scale)?;
126 base.merge_add_weight(&weight_key, &delta)?;
127 merged += 1;
128 }
129 Ok(merged)
130}
131
132pub fn load_and_apply_flux2_lora(
134 base: &mut WeightMap,
135 lora_path: &Path,
136 scale: f32,
137) -> Result<usize> {
138 let path = lora_path
139 .to_str()
140 .with_context(|| format!("non-utf8 LoRA path {lora_path:?}"))?;
141 let lora = WeightMap::from_file(path)?;
142 apply_flux2_lora(base, &lora, scale)
143}
144
145pub fn load_and_apply_flux2_lora_dir(
147 base: &mut WeightMap,
148 lora_dir: &Path,
149 scale: f32,
150) -> Result<usize> {
151 let lora = WeightMap::from_safetensors_dir(lora_dir)?;
152 apply_flux2_lora(base, &lora, scale)
153}
154
155pub fn parse_lora_scale(s: &str) -> Result<f32> {
157 let v: f32 = s.parse().context("lora scale: f32")?;
158 if !v.is_finite() {
159 bail!("lora scale must be finite");
160 }
161 Ok(v)
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Base map with a single 2x1 weight `proj.weight = [10, 20]`.
    fn base_with_proj() -> WeightMap {
        WeightMap::from_tensors(HashMap::from([(
            "proj.weight".to_string(),
            (vec![10.0, 20.0], vec![2, 1]),
        )]))
    }

    #[test]
    fn lora_delta_rank1_matches_manual() {
        let a = vec![1.0f32, 2.0];
        let b = vec![3.0f32, 4.0];
        let delta = lora_delta(&a, &[1, 2], &b, &[2, 1], 1.0).unwrap();
        assert_eq!(delta, vec![3.0, 6.0, 4.0, 8.0]);
    }

    #[test]
    fn lora_delta_applies_scale() {
        let delta = lora_delta(&[1.0, 2.0], &[1, 2], &[3.0, 4.0], &[2, 1], 0.5).unwrap();
        assert_eq!(delta, vec![1.5, 3.0, 2.0, 4.0]);
    }

    #[test]
    fn lora_delta_rejects_bad_shapes() {
        // A is [rank=2, in=1] but B is [out=1, rank=1].
        assert!(lora_delta(&[1.0, 2.0], &[2, 1], &[1.0], &[1, 1], 1.0).is_err());
        // Not rank-2.
        assert!(lora_delta(&[1.0], &[1], &[1.0], &[1, 1], 1.0).is_err());
        // Buffer length disagrees with the declared shape.
        assert!(lora_delta(&[1.0], &[1, 2], &[1.0, 1.0], &[2, 1], 1.0).is_err());
    }

    #[test]
    fn parse_lora_side_recognises_aliases() {
        assert_eq!(parse_lora_side("x.lora_down.weight"), Some(("x", LoraSide::A)));
        assert_eq!(parse_lora_side("x.lora_up.weight"), Some(("x", LoraSide::B)));
        assert_eq!(parse_lora_side("x.weight"), None);
    }

    #[test]
    fn apply_lora_merges_into_base_weight() {
        let mut base = base_with_proj();
        let lora = WeightMap::from_tensors(HashMap::from([
            ("proj.lora_A.weight".to_string(), (vec![2.0], vec![1, 1])),
            (
                "proj.lora_B.weight".to_string(),
                (vec![3.0, 4.0], vec![2, 1]),
            ),
        ]));
        let merged = apply_flux2_lora(&mut base, &lora, 1.0).unwrap();
        assert_eq!(merged, 1);
        let (w, _) = base.get("proj.weight").unwrap();
        assert!((w[0] - 16.0).abs() < 1e-5);
        assert!((w[1] - 28.0).abs() < 1e-5);
    }

    #[test]
    fn apply_lora_skips_incomplete_pair() {
        let mut base = base_with_proj();
        let lora = WeightMap::from_tensors(HashMap::from([(
            "proj.lora_A.weight".to_string(),
            (vec![2.0], vec![1, 1]),
        )]));
        assert_eq!(apply_flux2_lora(&mut base, &lora, 1.0).unwrap(), 0);
        let (w, _) = base.get("proj.weight").unwrap();
        assert_eq!((w[0], w[1]), (10.0, 20.0));
    }

    #[test]
    fn apply_lora_zero_scale_is_noop() {
        let mut base = base_with_proj();
        let lora = WeightMap::from_tensors(HashMap::from([
            ("proj.lora_A.weight".to_string(), (vec![2.0], vec![1, 1])),
            (
                "proj.lora_B.weight".to_string(),
                (vec![3.0, 4.0], vec![2, 1]),
            ),
        ]));
        assert_eq!(apply_flux2_lora(&mut base, &lora, 0.0).unwrap(), 0);
        let (w, _) = base.get("proj.weight").unwrap();
        assert_eq!((w[0], w[1]), (10.0, 20.0));
    }
}