1use ferrum_kernels::backend::Backend;
12use ferrum_quantization::{DenseLinear, Linear};
13use ferrum_types::{FerrumError, Result};
14use half::{bf16, f16};
15use safetensors::{Dtype, SafeTensors};
16use serde::{Deserialize, Serialize};
17use std::collections::{HashMap, HashSet};
18use std::path::{Path, PathBuf};
19
/// Subset of a LoRA adapter's `adapter_config.json` read by the loaders in this module.
///
/// Deserialized by [`load_startup_lora_adapter`] and then checked by `validate_lora_config`.
/// Keys not listed here are ignored (the derive does not use `deny_unknown_fields`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct LoraAdapterConfig {
    /// LoRA rank. Must be > 0, and every A/B tensor pair must use this rank.
    pub r: usize,
    /// LoRA alpha. Each runtime projection is scaled by `lora_alpha / r`.
    pub lora_alpha: usize,
    /// Projection names to adapt. Each entry must be listed in `SUPPORTED_TARGET_MODULES`.
    pub target_modules: Vec<String>,
    /// Base model recorded by the adapter. Not read by the loader code in this file.
    // NOTE(review): this is a required, non-optional `String`, so configs where the key is
    // missing or `null` fail to deserialize. Confirm that is intended.
    pub base_model_name_or_path: String,
}
27
/// One matched LoRA A/B weight pair discovered in `adapter_model.safetensors`.
///
/// Produced by `collect_lora_tensor_pairs`. That function checks that both tensors are 2-D
/// and that A's first dimension and B's second dimension equal the configured rank.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoraTensorPair {
    /// Entry of `target_modules` whose name matched the A tensor.
    pub target_module: String,
    /// Full safetensors name of the A tensor, shape `[rank, in_features]`.
    pub a_tensor: String,
    /// Full safetensors name of the B tensor, shape `[out_features, rank]`.
    pub b_tensor: String,
    /// Second dimension of the A tensor.
    pub in_features: usize,
    /// First dimension of the B tensor.
    pub out_features: usize,
    /// First dimension of the A tensor (equal to the config's `r`).
    pub rank: usize,
}
37
/// A LoRA adapter validated at startup.
///
/// Holds the parsed config and the discovered tensor-pair metadata. The tensor data itself
/// is not retained.
#[derive(Debug, Clone)]
pub struct StartupLoraAdapter {
    /// Adapter name, restricted to `[A-Za-z0-9_.-]`.
    pub name: String,
    /// Model id under which the adapter is exposed.
    pub public_model_id: String,
    /// Adapter directory holding `adapter_config.json` and `adapter_model.safetensors`.
    pub path: PathBuf,
    /// Parsed and validated `adapter_config.json`.
    pub config: LoraAdapterConfig,
    /// A/B tensor pairs found for the configured target modules.
    pub tensors: Vec<LoraTensorPair>,
}
46
/// Request to load the adapter in directory `path` under `name`.
///
/// Consumed by [`load_startup_lora_adapters`].
#[derive(Debug, Clone)]
pub struct StartupLoraSpec {
    /// Adapter name. Must be unique among the specs and match `[A-Za-z0-9_.-]`.
    pub name: String,
    /// Directory containing the adapter files.
    pub path: PathBuf,
}
52
/// Identifies an adapter to load onto a backend with [`load_runtime_lora_adapter`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ActiveLoraAdapter {
    /// Adapter name (validated when loaded).
    pub name: String,
    /// Directory containing `adapter_config.json` and `adapter_model.safetensors`.
    pub path: PathBuf,
}
58
/// A single LoRA projection whose A and B weights are loaded as backend dense layers.
///
/// Applying it runs `a` and then `b` on the input and accumulates the result into an
/// existing projection output, scaled by `scaling`.
pub struct RuntimeLoraLinear<B: Backend> {
    /// Layer index parsed from the tensor name. `None` matches every layer.
    pub layer_index: Option<usize>,
    /// Projection name this delta targets (e.g. `q_proj`).
    pub target_module: String,
    /// Input width of the projection.
    pub in_features: usize,
    /// Output width of the projection.
    pub out_features: usize,
    /// LoRA rank (width of the intermediate buffer).
    pub rank: usize,
    /// `lora_alpha / r` from the adapter config.
    pub scaling: f32,
    // Built via `DenseLinear::from_rows(a_data, rank, in_features)`.
    a: DenseLinear<B>,
    // Built via `DenseLinear::from_rows(b_data, out_features, rank)`.
    b: DenseLinear<B>,
}
69
/// A LoRA adapter whose projections have been loaded onto backend `B`.
pub struct RuntimeLoraAdapter<B: Backend> {
    /// Adapter name.
    pub name: String,
    /// Directory the adapter was loaded from.
    pub path: PathBuf,
    /// Parsed adapter config.
    pub config: LoraAdapterConfig,
    /// Loaded projections, applied in this order by `apply_projection`.
    pub linears: Vec<RuntimeLoraLinear<B>>,
}
76
77impl<B: Backend> RuntimeLoraAdapter<B> {
78 pub fn apply_projection(
79 &self,
80 ctx: &mut B::Context,
81 layer_index: usize,
82 target_module: &str,
83 input: &B::Buffer,
84 out: &mut B::Buffer,
85 m: usize,
86 ) -> Result<usize> {
87 let mut applied = 0usize;
88 for linear in self.linears.iter().filter(|linear| {
89 linear.target_module == target_module
90 && linear
91 .layer_index
92 .map(|idx| idx == layer_index)
93 .unwrap_or(true)
94 }) {
95 linear.apply(ctx, input, out, m)?;
96 applied += 1;
97 }
98 Ok(applied)
99 }
100
101 pub fn supports_projection(
102 &self,
103 layer_index: usize,
104 target_module: &str,
105 in_features: usize,
106 out_features: usize,
107 ) -> bool {
108 self.linears.iter().any(|linear| {
109 linear.target_module == target_module
110 && linear.in_features == in_features
111 && linear.out_features == out_features
112 && linear
113 .layer_index
114 .map(|idx| idx == layer_index)
115 .unwrap_or(true)
116 })
117 }
118}
119
120impl<B: Backend> RuntimeLoraLinear<B> {
121 fn apply(
122 &self,
123 ctx: &mut B::Context,
124 input: &B::Buffer,
125 out: &mut B::Buffer,
126 m: usize,
127 ) -> Result<()> {
128 let mut low_rank = B::alloc(m * self.rank);
129 let mut delta = B::alloc(m * self.out_features);
130 self.a.forward(ctx, input, &mut low_rank, m);
131 self.b.forward(ctx, &low_rank, &mut delta, m);
132 B::scaled_add_inplace(ctx, out, &delta, self.scaling, m * self.out_features);
133 Ok(())
134 }
135}
136
/// Projection names accepted in `adapter_config.json`'s `target_modules`.
///
/// An A tensor is matched to a target when its name ends in `.<target>.lora_A.weight` or
/// equals `<target>.lora_A.weight`.
const SUPPORTED_TARGET_MODULES: &[&str] = &[
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "qkv_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
    "gate_up_proj",
    "linear",
];
149
/// Builds the default public model id for an adapter: the base model id, a `:`, then the
/// adapter name.
pub fn default_lora_model_id(base_model_id: &str, name: &str) -> String {
    [base_model_id, name].join(":")
}
153
/// Renders a public model id for a LoRA adapter from `template`.
///
/// Every `<base>` placeholder is replaced with `base_model_id` and every `<name>` with
/// `name`. All other text is copied verbatim.
///
/// Substitution is a single left-to-right pass over the template. Placeholder-like text
/// inside the substituted values is therefore left untouched: a base model id containing
/// `<name>` stays intact instead of having the adapter name spliced into it.
pub fn render_lora_model_id(template: &str, base_model_id: &str, name: &str) -> String {
    const BASE: &str = "<base>";
    const NAME: &str = "<name>";

    let mut rendered = String::with_capacity(template.len() + base_model_id.len() + name.len());
    let mut rest = template;
    while let Some(pos) = rest.find('<') {
        rendered.push_str(&rest[..pos]);
        let tail = &rest[pos..];
        if let Some(after) = tail.strip_prefix(BASE) {
            rendered.push_str(base_model_id);
            rest = after;
        } else if let Some(after) = tail.strip_prefix(NAME) {
            rendered.push_str(name);
            rest = after;
        } else {
            // A lone '<' that does not start a placeholder; '<' is one byte, so slicing past
            // it stays on a char boundary.
            rendered.push('<');
            rest = &tail[1..];
        }
    }
    rendered.push_str(rest);
    rendered
}
159
160pub fn load_startup_lora_adapter(
161 name: impl Into<String>,
162 path: impl AsRef<Path>,
163 public_model_id: impl Into<String>,
164) -> Result<StartupLoraAdapter> {
165 let name = name.into();
166 validate_lora_name(&name)?;
167 let path = path.as_ref().to_path_buf();
168 if !path.is_dir() {
169 return Err(FerrumError::config(format!(
170 "LoRA adapter path does not exist or is not a directory: {}",
171 path.display()
172 )));
173 }
174
175 let config_path = path.join("adapter_config.json");
176 let weights_path = path.join("adapter_model.safetensors");
177 if !config_path.is_file() {
178 return Err(FerrumError::config(format!(
179 "LoRA adapter missing adapter_config.json: {}",
180 path.display()
181 )));
182 }
183 if !weights_path.is_file() {
184 return Err(FerrumError::config(format!(
185 "LoRA adapter missing adapter_model.safetensors: {}",
186 path.display()
187 )));
188 }
189
190 let config: LoraAdapterConfig = serde_json::from_slice(
191 &std::fs::read(&config_path)
192 .map_err(|e| FerrumError::config(format!("read adapter_config.json: {e}")))?,
193 )
194 .map_err(|e| FerrumError::config(format!("invalid adapter_config.json: {e}")))?;
195 validate_lora_config(&config)?;
196
197 let weights = std::fs::read(&weights_path)
198 .map_err(|e| FerrumError::config(format!("read adapter_model.safetensors: {e}")))?;
199 let tensors = SafeTensors::deserialize(&weights)
200 .map_err(|e| FerrumError::config(format!("invalid adapter_model.safetensors: {e}")))?;
201 let pairs = collect_lora_tensor_pairs(&config, &tensors)?;
202
203 Ok(StartupLoraAdapter {
204 name,
205 public_model_id: public_model_id.into(),
206 path,
207 config,
208 tensors: pairs,
209 })
210}
211
212pub fn load_startup_lora_adapters(
213 base_model_id: &str,
214 template: Option<&str>,
215 specs: &[StartupLoraSpec],
216) -> Result<Vec<StartupLoraAdapter>> {
217 let mut seen_names = HashSet::new();
218 let mut seen_ids = HashSet::new();
219 let template = template.unwrap_or("<base>:<name>");
220 let mut out = Vec::with_capacity(specs.len());
221 for spec in specs {
222 if !seen_names.insert(spec.name.clone()) {
223 return Err(FerrumError::config(format!(
224 "duplicate LoRA adapter name: {}",
225 spec.name
226 )));
227 }
228 let public_model_id = render_lora_model_id(template, base_model_id, &spec.name);
229 if !seen_ids.insert(public_model_id.clone()) {
230 return Err(FerrumError::config(format!(
231 "duplicate LoRA public model id: {public_model_id}"
232 )));
233 }
234 out.push(load_startup_lora_adapter(
235 spec.name.clone(),
236 &spec.path,
237 public_model_id,
238 )?);
239 }
240 Ok(out)
241}
242
243pub fn load_runtime_lora_adapter<B: Backend>(
244 adapter: &ActiveLoraAdapter,
245) -> Result<RuntimeLoraAdapter<B>> {
246 let startup = load_startup_lora_adapter(
247 adapter.name.clone(),
248 &adapter.path,
249 default_lora_model_id("base", &adapter.name),
250 )?;
251 let weights_path = adapter.path.join("adapter_model.safetensors");
252 let weights = std::fs::read(&weights_path)
253 .map_err(|e| FerrumError::config(format!("read adapter_model.safetensors: {e}")))?;
254 let tensors = SafeTensors::deserialize(&weights)
255 .map_err(|e| FerrumError::config(format!("invalid adapter_model.safetensors: {e}")))?;
256
257 let mut linears = Vec::with_capacity(startup.tensors.len());
258 for pair in &startup.tensors {
259 let a = tensors
260 .tensor(&pair.a_tensor)
261 .map_err(|e| FerrumError::config(format!("read LoRA tensor {}: {e}", pair.a_tensor)))?;
262 let b = tensors
263 .tensor(&pair.b_tensor)
264 .map_err(|e| FerrumError::config(format!("read LoRA tensor {}: {e}", pair.b_tensor)))?;
265 let a_data = tensor_data_to_f32(a.dtype(), a.data())?;
266 let b_data = tensor_data_to_f32(b.dtype(), b.data())?;
267 linears.push(RuntimeLoraLinear {
268 layer_index: parse_lora_layer_index(&pair.a_tensor),
269 target_module: pair.target_module.clone(),
270 in_features: pair.in_features,
271 out_features: pair.out_features,
272 rank: pair.rank,
273 scaling: startup.config.lora_alpha as f32 / startup.config.r as f32,
274 a: DenseLinear::<B>::from_rows(&a_data, pair.rank, pair.in_features),
275 b: DenseLinear::<B>::from_rows(&b_data, pair.out_features, pair.rank),
276 });
277 }
278
279 Ok(RuntimeLoraAdapter {
280 name: adapter.name.clone(),
281 path: adapter.path.clone(),
282 config: startup.config,
283 linears,
284 })
285}
286
287fn validate_lora_name(name: &str) -> Result<()> {
288 if name.is_empty()
289 || !name
290 .chars()
291 .all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '.'))
292 {
293 return Err(FerrumError::config(format!(
294 "invalid LoRA adapter name {name:?}; use [A-Za-z0-9_.-]"
295 )));
296 }
297 Ok(())
298}
299
300fn validate_lora_config(config: &LoraAdapterConfig) -> Result<()> {
301 if config.r == 0 {
302 return Err(FerrumError::config("LoRA adapter rank r must be > 0"));
303 }
304 if config.target_modules.is_empty() {
305 return Err(FerrumError::config(
306 "LoRA adapter target_modules must not be empty",
307 ));
308 }
309 for target in &config.target_modules {
310 if !SUPPORTED_TARGET_MODULES.contains(&target.as_str()) {
311 return Err(FerrumError::config(format!(
312 "unsupported LoRA target module: {target}"
313 )));
314 }
315 }
316 Ok(())
317}
318
319fn collect_lora_tensor_pairs(
320 config: &LoraAdapterConfig,
321 tensors: &SafeTensors<'_>,
322) -> Result<Vec<LoraTensorPair>> {
323 let names: Vec<String> = tensors.names().iter().map(|s| (*s).to_string()).collect();
324 let name_set: HashSet<&str> = names.iter().map(String::as_str).collect();
325 let mut pairs = Vec::new();
326
327 for target in &config.target_modules {
328 for a_name in names
329 .iter()
330 .filter(|name| is_lora_a_for_target(name, target))
331 {
332 let b_name = a_name.replace(".lora_A.weight", ".lora_B.weight");
333 if !name_set.contains(b_name.as_str()) {
334 return Err(FerrumError::config(format!(
335 "LoRA tensor pair missing B tensor for {a_name}"
336 )));
337 }
338 let a = tensors
339 .tensor(a_name)
340 .map_err(|e| FerrumError::config(format!("read LoRA tensor {a_name}: {e}")))?;
341 let b = tensors
342 .tensor(&b_name)
343 .map_err(|e| FerrumError::config(format!("read LoRA tensor {b_name}: {e}")))?;
344 validate_lora_dtype(a_name, a.dtype())?;
345 validate_lora_dtype(&b_name, b.dtype())?;
346 let a_shape = a.shape();
347 let b_shape = b.shape();
348 if a_shape.len() != 2 || b_shape.len() != 2 {
349 return Err(FerrumError::config(format!(
350 "LoRA tensors must be 2-D: {a_name} shape={a_shape:?}, {b_name} shape={b_shape:?}"
351 )));
352 }
353 let rank = a_shape[0];
354 if rank != config.r || b_shape[1] != config.r {
355 return Err(FerrumError::config(format!(
356 "LoRA rank mismatch for {target}: config r={}, A shape={a_shape:?}, B shape={b_shape:?}",
357 config.r
358 )));
359 }
360 pairs.push(LoraTensorPair {
361 target_module: target.clone(),
362 a_tensor: a_name.clone(),
363 b_tensor: b_name,
364 in_features: a_shape[1],
365 out_features: b_shape[0],
366 rank,
367 });
368 }
369 }
370
371 if pairs.is_empty() {
372 return Err(FerrumError::config(
373 "LoRA adapter did not contain any supported target module tensor pairs",
374 ));
375 }
376
377 let mut unique = HashMap::<String, LoraTensorPair>::new();
378 for pair in pairs {
379 unique.entry(pair.a_tensor.clone()).or_insert(pair);
380 }
381 Ok(unique.into_values().collect())
382}
383
/// Returns `true` when `name` is the LoRA A weight of module `target`.
///
/// That is, `name` is either exactly `<target>.lora_A.weight` or ends with
/// `.<target>.lora_A.weight`.
fn is_lora_a_for_target(name: &str, target: &str) -> bool {
    // Peel off the fixed suffix, then the target. Whatever precedes the target must be empty
    // (a top-level tensor) or end with a `.` path separator.
    match name
        .strip_suffix(".lora_A.weight")
        .and_then(|stem| stem.strip_suffix(target))
    {
        Some(prefix) => prefix.is_empty() || prefix.ends_with('.'),
        None => false,
    }
}
388
389fn validate_lora_dtype(name: &str, dtype: Dtype) -> Result<()> {
390 match dtype {
391 Dtype::F16 | Dtype::BF16 | Dtype::F32 => Ok(()),
392 other => Err(FerrumError::config(format!(
393 "unsupported LoRA tensor dtype for {name}: {other:?}"
394 ))),
395 }
396}
397
398fn tensor_data_to_f32(dtype: Dtype, data: &[u8]) -> Result<Vec<f32>> {
399 match dtype {
400 Dtype::F32 => {
401 if data.len() % 4 != 0 {
402 return Err(FerrumError::config(
403 "F32 LoRA tensor byte length is not divisible by 4",
404 ));
405 }
406 Ok(data
407 .chunks_exact(4)
408 .map(|bytes| f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
409 .collect())
410 }
411 Dtype::F16 => {
412 if data.len() % 2 != 0 {
413 return Err(FerrumError::config(
414 "F16 LoRA tensor byte length is not divisible by 2",
415 ));
416 }
417 Ok(data
418 .chunks_exact(2)
419 .map(|bytes| f16::from_le_bytes([bytes[0], bytes[1]]).to_f32())
420 .collect())
421 }
422 Dtype::BF16 => {
423 if data.len() % 2 != 0 {
424 return Err(FerrumError::config(
425 "BF16 LoRA tensor byte length is not divisible by 2",
426 ));
427 }
428 Ok(data
429 .chunks_exact(2)
430 .map(|bytes| bf16::from_le_bytes([bytes[0], bytes[1]]).to_f32())
431 .collect())
432 }
433 other => Err(FerrumError::config(format!(
434 "unsupported LoRA tensor dtype for runtime load: {other:?}"
435 ))),
436 }
437}
438
/// Extracts the layer index from a tensor name.
///
/// For example, `base_model.model.model.layers.12.self_attn.q_proj.lora_A.weight` yields
/// `Some(12)`.
///
/// The name is split on `.`. The index comes from the first `layers` segment that is
/// immediately followed by an all-digit segment. This also covers names that start with
/// `layers.` and names where an earlier `layers` segment is not numeric.
///
/// Returns `None` when no such segment pair exists, or when the number does not fit in
/// `usize`. `RuntimeLoraAdapter` treats `None` as "applies to every layer".
fn parse_lora_layer_index(name: &str) -> Option<usize> {
    name.split('.')
        .zip(name.split('.').skip(1))
        .find_map(|(marker, index)| {
            // Require plain digits: `usize::from_str` would also accept a leading '+'.
            if marker != "layers" || index.is_empty() || !index.bytes().all(|b| b.is_ascii_digit())
            {
                return None;
            }
            index.parse().ok()
        })
}