Skip to main content

ferrum_models/
lora.rs

1//! Startup LoRA adapter loader and validator.
2//!
3//! G4 supports PEFT-style startup adapters with:
4//! - adapter_config.json
5//! - adapter_model.safetensors
6//!
7//! The production server registry uses the validated metadata here. Actual
8//! hot-loading, GGUF LoRA, and multi-adapter composition are intentionally out
9//! of scope.
10
11use ferrum_kernels::backend::Backend;
12use ferrum_quantization::{DenseLinear, Linear};
13use ferrum_types::{FerrumError, Result};
14use half::{bf16, f16};
15use safetensors::{Dtype, SafeTensors};
16use serde::{Deserialize, Serialize};
17use std::collections::{HashMap, HashSet};
18use std::path::{Path, PathBuf};
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
21pub struct LoraAdapterConfig {
22    pub r: usize,
23    pub lora_alpha: usize,
24    pub target_modules: Vec<String>,
25    pub base_model_name_or_path: String,
26}
27
/// One validated `lora_A` / `lora_B` weight pair for a single target module.
///
/// `rank` and `in_features` come from A's `[rank, in_features]` shape;
/// `out_features` comes from B's `[out_features, rank]` shape.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct LoraTensorPair {
    /// Target module name from the adapter config (e.g. `q_proj`).
    pub target_module: String,
    /// Full safetensors name of the A (`rank x in_features`) tensor.
    pub a_tensor: String,
    /// Full safetensors name of the B (`out_features x rank`) tensor.
    pub b_tensor: String,
    /// Input width of the adapted projection.
    pub in_features: usize,
    /// Output width of the adapted projection.
    pub out_features: usize,
    /// Low-rank inner dimension shared by A and B.
    pub rank: usize,
}
37
38#[derive(Debug, Clone)]
39pub struct StartupLoraAdapter {
40    pub name: String,
41    pub public_model_id: String,
42    pub path: PathBuf,
43    pub config: LoraAdapterConfig,
44    pub tensors: Vec<LoraTensorPair>,
45}
46
/// A requested startup adapter: its name and the directory to load it from.
#[derive(Clone, Debug)]
pub struct StartupLoraSpec {
    /// Adapter name; validated later by the loader.
    pub name: String,
    /// Adapter directory on disk.
    pub path: PathBuf,
}
52
/// Identifies an adapter selected for runtime use by name and on-disk location.
#[derive(Eq, PartialEq, Clone, Debug)]
pub struct ActiveLoraAdapter {
    /// Adapter name.
    pub name: String,
    /// Adapter directory on disk.
    pub path: PathBuf,
}
58
59pub struct RuntimeLoraLinear<B: Backend> {
60    pub layer_index: Option<usize>,
61    pub target_module: String,
62    pub in_features: usize,
63    pub out_features: usize,
64    pub rank: usize,
65    pub scaling: f32,
66    a: DenseLinear<B>,
67    b: DenseLinear<B>,
68}
69
70pub struct RuntimeLoraAdapter<B: Backend> {
71    pub name: String,
72    pub path: PathBuf,
73    pub config: LoraAdapterConfig,
74    pub linears: Vec<RuntimeLoraLinear<B>>,
75}
76
77impl<B: Backend> RuntimeLoraAdapter<B> {
78    pub fn apply_projection(
79        &self,
80        ctx: &mut B::Context,
81        layer_index: usize,
82        target_module: &str,
83        input: &B::Buffer,
84        out: &mut B::Buffer,
85        m: usize,
86    ) -> Result<usize> {
87        let mut applied = 0usize;
88        for linear in self.linears.iter().filter(|linear| {
89            linear.target_module == target_module
90                && linear
91                    .layer_index
92                    .map(|idx| idx == layer_index)
93                    .unwrap_or(true)
94        }) {
95            linear.apply(ctx, input, out, m)?;
96            applied += 1;
97        }
98        Ok(applied)
99    }
100
101    pub fn supports_projection(
102        &self,
103        layer_index: usize,
104        target_module: &str,
105        in_features: usize,
106        out_features: usize,
107    ) -> bool {
108        self.linears.iter().any(|linear| {
109            linear.target_module == target_module
110                && linear.in_features == in_features
111                && linear.out_features == out_features
112                && linear
113                    .layer_index
114                    .map(|idx| idx == layer_index)
115                    .unwrap_or(true)
116        })
117    }
118}
119
120impl<B: Backend> RuntimeLoraLinear<B> {
121    fn apply(
122        &self,
123        ctx: &mut B::Context,
124        input: &B::Buffer,
125        out: &mut B::Buffer,
126        m: usize,
127    ) -> Result<()> {
128        let mut low_rank = B::alloc(m * self.rank);
129        let mut delta = B::alloc(m * self.out_features);
130        self.a.forward(ctx, input, &mut low_rank, m);
131        self.b.forward(ctx, &low_rank, &mut delta, m);
132        B::scaled_add_inplace(ctx, out, &delta, self.scaling, m * self.out_features);
133        Ok(())
134    }
135}
136
/// Target-module names accepted in `adapter_config.json`; anything else is rejected
/// during config validation.
const SUPPORTED_TARGET_MODULES: &[&str] = &[
    // Attention projections (separate and fused).
    "q_proj", "k_proj", "v_proj", "o_proj", "qkv_proj",
    // MLP projections (separate and fused).
    "gate_proj", "up_proj", "down_proj", "gate_up_proj",
    // Generic linear layer name.
    "linear",
];
149
/// Builds the default public model id for an adapter: `<base_model_id>:<name>`.
pub fn default_lora_model_id(base_model_id: &str, name: &str) -> String {
    let mut id = String::with_capacity(base_model_id.len() + 1 + name.len());
    id.push_str(base_model_id);
    id.push(':');
    id.push_str(name);
    id
}
153
/// Renders a public model id from `template`, substituting every `<base>` with
/// `base_model_id` and every `<name>` with `name`.
///
/// Substitution is a single left-to-right pass over the template, so placeholder-like
/// text inside the substituted values (e.g. a base id containing `<name>`) is copied
/// verbatim rather than expanded again. Any other `<...>` text is left as-is.
pub fn render_lora_model_id(template: &str, base_model_id: &str, name: &str) -> String {
    let mut rendered = String::with_capacity(template.len() + base_model_id.len() + name.len());
    let mut rest = template;
    while let Some(pos) = rest.find('<') {
        rendered.push_str(&rest[..pos]);
        let tail = &rest[pos..];
        if let Some(after) = tail.strip_prefix("<base>") {
            rendered.push_str(base_model_id);
            rest = after;
        } else if let Some(after) = tail.strip_prefix("<name>") {
            rendered.push_str(name);
            rest = after;
        } else {
            // Not a placeholder: keep the '<' (one ASCII byte) and keep scanning.
            rendered.push('<');
            rest = &tail[1..];
        }
    }
    rendered.push_str(rest);
    rendered
}
159
160pub fn load_startup_lora_adapter(
161    name: impl Into<String>,
162    path: impl AsRef<Path>,
163    public_model_id: impl Into<String>,
164) -> Result<StartupLoraAdapter> {
165    let name = name.into();
166    validate_lora_name(&name)?;
167    let path = path.as_ref().to_path_buf();
168    if !path.is_dir() {
169        return Err(FerrumError::config(format!(
170            "LoRA adapter path does not exist or is not a directory: {}",
171            path.display()
172        )));
173    }
174
175    let config_path = path.join("adapter_config.json");
176    let weights_path = path.join("adapter_model.safetensors");
177    if !config_path.is_file() {
178        return Err(FerrumError::config(format!(
179            "LoRA adapter missing adapter_config.json: {}",
180            path.display()
181        )));
182    }
183    if !weights_path.is_file() {
184        return Err(FerrumError::config(format!(
185            "LoRA adapter missing adapter_model.safetensors: {}",
186            path.display()
187        )));
188    }
189
190    let config: LoraAdapterConfig = serde_json::from_slice(
191        &std::fs::read(&config_path)
192            .map_err(|e| FerrumError::config(format!("read adapter_config.json: {e}")))?,
193    )
194    .map_err(|e| FerrumError::config(format!("invalid adapter_config.json: {e}")))?;
195    validate_lora_config(&config)?;
196
197    let weights = std::fs::read(&weights_path)
198        .map_err(|e| FerrumError::config(format!("read adapter_model.safetensors: {e}")))?;
199    let tensors = SafeTensors::deserialize(&weights)
200        .map_err(|e| FerrumError::config(format!("invalid adapter_model.safetensors: {e}")))?;
201    let pairs = collect_lora_tensor_pairs(&config, &tensors)?;
202
203    Ok(StartupLoraAdapter {
204        name,
205        public_model_id: public_model_id.into(),
206        path,
207        config,
208        tensors: pairs,
209    })
210}
211
212pub fn load_startup_lora_adapters(
213    base_model_id: &str,
214    template: Option<&str>,
215    specs: &[StartupLoraSpec],
216) -> Result<Vec<StartupLoraAdapter>> {
217    let mut seen_names = HashSet::new();
218    let mut seen_ids = HashSet::new();
219    let template = template.unwrap_or("<base>:<name>");
220    let mut out = Vec::with_capacity(specs.len());
221    for spec in specs {
222        if !seen_names.insert(spec.name.clone()) {
223            return Err(FerrumError::config(format!(
224                "duplicate LoRA adapter name: {}",
225                spec.name
226            )));
227        }
228        let public_model_id = render_lora_model_id(template, base_model_id, &spec.name);
229        if !seen_ids.insert(public_model_id.clone()) {
230            return Err(FerrumError::config(format!(
231                "duplicate LoRA public model id: {public_model_id}"
232            )));
233        }
234        out.push(load_startup_lora_adapter(
235            spec.name.clone(),
236            &spec.path,
237            public_model_id,
238        )?);
239    }
240    Ok(out)
241}
242
243pub fn load_runtime_lora_adapter<B: Backend>(
244    adapter: &ActiveLoraAdapter,
245) -> Result<RuntimeLoraAdapter<B>> {
246    let startup = load_startup_lora_adapter(
247        adapter.name.clone(),
248        &adapter.path,
249        default_lora_model_id("base", &adapter.name),
250    )?;
251    let weights_path = adapter.path.join("adapter_model.safetensors");
252    let weights = std::fs::read(&weights_path)
253        .map_err(|e| FerrumError::config(format!("read adapter_model.safetensors: {e}")))?;
254    let tensors = SafeTensors::deserialize(&weights)
255        .map_err(|e| FerrumError::config(format!("invalid adapter_model.safetensors: {e}")))?;
256
257    let mut linears = Vec::with_capacity(startup.tensors.len());
258    for pair in &startup.tensors {
259        let a = tensors
260            .tensor(&pair.a_tensor)
261            .map_err(|e| FerrumError::config(format!("read LoRA tensor {}: {e}", pair.a_tensor)))?;
262        let b = tensors
263            .tensor(&pair.b_tensor)
264            .map_err(|e| FerrumError::config(format!("read LoRA tensor {}: {e}", pair.b_tensor)))?;
265        let a_data = tensor_data_to_f32(a.dtype(), a.data())?;
266        let b_data = tensor_data_to_f32(b.dtype(), b.data())?;
267        linears.push(RuntimeLoraLinear {
268            layer_index: parse_lora_layer_index(&pair.a_tensor),
269            target_module: pair.target_module.clone(),
270            in_features: pair.in_features,
271            out_features: pair.out_features,
272            rank: pair.rank,
273            scaling: startup.config.lora_alpha as f32 / startup.config.r as f32,
274            a: DenseLinear::<B>::from_rows(&a_data, pair.rank, pair.in_features),
275            b: DenseLinear::<B>::from_rows(&b_data, pair.out_features, pair.rank),
276        });
277    }
278
279    Ok(RuntimeLoraAdapter {
280        name: adapter.name.clone(),
281        path: adapter.path.clone(),
282        config: startup.config,
283        linears,
284    })
285}
286
287fn validate_lora_name(name: &str) -> Result<()> {
288    if name.is_empty()
289        || !name
290            .chars()
291            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '.'))
292    {
293        return Err(FerrumError::config(format!(
294            "invalid LoRA adapter name {name:?}; use [A-Za-z0-9_.-]"
295        )));
296    }
297    Ok(())
298}
299
300fn validate_lora_config(config: &LoraAdapterConfig) -> Result<()> {
301    if config.r == 0 {
302        return Err(FerrumError::config("LoRA adapter rank r must be > 0"));
303    }
304    if config.target_modules.is_empty() {
305        return Err(FerrumError::config(
306            "LoRA adapter target_modules must not be empty",
307        ));
308    }
309    for target in &config.target_modules {
310        if !SUPPORTED_TARGET_MODULES.contains(&target.as_str()) {
311            return Err(FerrumError::config(format!(
312                "unsupported LoRA target module: {target}"
313            )));
314        }
315    }
316    Ok(())
317}
318
319fn collect_lora_tensor_pairs(
320    config: &LoraAdapterConfig,
321    tensors: &SafeTensors<'_>,
322) -> Result<Vec<LoraTensorPair>> {
323    let names: Vec<String> = tensors.names().iter().map(|s| (*s).to_string()).collect();
324    let name_set: HashSet<&str> = names.iter().map(String::as_str).collect();
325    let mut pairs = Vec::new();
326
327    for target in &config.target_modules {
328        for a_name in names
329            .iter()
330            .filter(|name| is_lora_a_for_target(name, target))
331        {
332            let b_name = a_name.replace(".lora_A.weight", ".lora_B.weight");
333            if !name_set.contains(b_name.as_str()) {
334                return Err(FerrumError::config(format!(
335                    "LoRA tensor pair missing B tensor for {a_name}"
336                )));
337            }
338            let a = tensors
339                .tensor(a_name)
340                .map_err(|e| FerrumError::config(format!("read LoRA tensor {a_name}: {e}")))?;
341            let b = tensors
342                .tensor(&b_name)
343                .map_err(|e| FerrumError::config(format!("read LoRA tensor {b_name}: {e}")))?;
344            validate_lora_dtype(a_name, a.dtype())?;
345            validate_lora_dtype(&b_name, b.dtype())?;
346            let a_shape = a.shape();
347            let b_shape = b.shape();
348            if a_shape.len() != 2 || b_shape.len() != 2 {
349                return Err(FerrumError::config(format!(
350                    "LoRA tensors must be 2-D: {a_name} shape={a_shape:?}, {b_name} shape={b_shape:?}"
351                )));
352            }
353            let rank = a_shape[0];
354            if rank != config.r || b_shape[1] != config.r {
355                return Err(FerrumError::config(format!(
356                    "LoRA rank mismatch for {target}: config r={}, A shape={a_shape:?}, B shape={b_shape:?}",
357                    config.r
358                )));
359            }
360            pairs.push(LoraTensorPair {
361                target_module: target.clone(),
362                a_tensor: a_name.clone(),
363                b_tensor: b_name,
364                in_features: a_shape[1],
365                out_features: b_shape[0],
366                rank,
367            });
368        }
369    }
370
371    if pairs.is_empty() {
372        return Err(FerrumError::config(
373            "LoRA adapter did not contain any supported target module tensor pairs",
374        ));
375    }
376
377    let mut unique = HashMap::<String, LoraTensorPair>::new();
378    for pair in pairs {
379        unique.entry(pair.a_tensor.clone()).or_insert(pair);
380    }
381    Ok(unique.into_values().collect())
382}
383
/// Returns `true` when `name` is the LoRA A weight for `target`: either exactly
/// `<target>.lora_A.weight` or `<anything>.<target>.lora_A.weight`.
fn is_lora_a_for_target(name: &str, target: &str) -> bool {
    // Peel the fixed suffix, then the target; what remains must be empty or end at a
    // path separator so `qkv_proj` never matches `kv_proj`-style partial names.
    match name
        .strip_suffix(".lora_A.weight")
        .and_then(|stem| stem.strip_suffix(target))
    {
        Some(prefix) => prefix.is_empty() || prefix.ends_with('.'),
        None => false,
    }
}
388
389fn validate_lora_dtype(name: &str, dtype: Dtype) -> Result<()> {
390    match dtype {
391        Dtype::F16 | Dtype::BF16 | Dtype::F32 => Ok(()),
392        other => Err(FerrumError::config(format!(
393            "unsupported LoRA tensor dtype for {name}: {other:?}"
394        ))),
395    }
396}
397
398fn tensor_data_to_f32(dtype: Dtype, data: &[u8]) -> Result<Vec<f32>> {
399    match dtype {
400        Dtype::F32 => {
401            if data.len() % 4 != 0 {
402                return Err(FerrumError::config(
403                    "F32 LoRA tensor byte length is not divisible by 4",
404                ));
405            }
406            Ok(data
407                .chunks_exact(4)
408                .map(|bytes| f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
409                .collect())
410        }
411        Dtype::F16 => {
412            if data.len() % 2 != 0 {
413                return Err(FerrumError::config(
414                    "F16 LoRA tensor byte length is not divisible by 2",
415                ));
416            }
417            Ok(data
418                .chunks_exact(2)
419                .map(|bytes| f16::from_le_bytes([bytes[0], bytes[1]]).to_f32())
420                .collect())
421        }
422        Dtype::BF16 => {
423            if data.len() % 2 != 0 {
424                return Err(FerrumError::config(
425                    "BF16 LoRA tensor byte length is not divisible by 2",
426                ));
427            }
428            Ok(data
429                .chunks_exact(2)
430                .map(|bytes| bf16::from_le_bytes([bytes[0], bytes[1]]).to_f32())
431                .collect())
432        }
433        other => Err(FerrumError::config(format!(
434            "unsupported LoRA tensor dtype for runtime load: {other:?}"
435        ))),
436    }
437}
438
/// Extracts `N` from the first `.layers.N` segment of a tensor name.
///
/// Returns `None` if there is no `.layers.` marker, no digits follow it, or the
/// digits overflow `usize`.
fn parse_lora_layer_index(name: &str) -> Option<usize> {
    const MARKER: &str = ".layers.";
    let (_, tail) = name.split_once(MARKER)?;
    let digit_end = tail
        .find(|ch: char| !ch.is_ascii_digit())
        .unwrap_or(tail.len());
    let digits = &tail[..digit_end];
    if digits.is_empty() {
        None
    } else {
        digits.parse().ok()
    }
}