// File: hanzo_engine/lora/mod.rs

1#![allow(clippy::cast_precision_loss)]
2
3use std::{collections::HashSet, fmt::Debug, sync::Arc};
4
5use hanzo_ml::{quantized::QTensor, IndexOp, Result, Tensor, D};
6use hanzo_nn::{Linear, Module};
7use hanzo_quant::{QuantMethod, ShardedVarBuilder};
8use loralinear::LoraLinear;
9pub use qloralinear::QLoraLinear;
10use serde::Deserialize;
11
12mod loralinear;
13mod qloralinear;
14
15use std::collections::HashMap;
16
17use crate::layers;
18
19#[derive(Clone, Debug, Deserialize)]
20pub struct PreloadAdapter {
21    pub name: String,
22    pub adapter_model_id: String,
23}
24
25#[derive(Clone, Debug, Deserialize)]
26/// Adapter model ordering information.
27pub struct Ordering {
28    #[serde(rename = "order")]
29    pub adapters: Option<Vec<String>>,
30    pub layers: Option<HashMap<String, usize>>,
31    pub base_model_id: String,
32    pub preload_adapters: Option<Vec<PreloadAdapter>>,
33}
34
/// Configuration for LoraLinear: the input and output feature counts of the wrapped layer.
#[derive(Debug, Clone)]
pub struct LoraLinearConfig {
    /// Number of input features.
    in_features: usize,
    /// Number of output features.
    out_features: usize,
}

impl LoraLinearConfig {
    /// Create a configuration from the given input and output feature counts.
    pub fn new(in_features: usize, out_features: usize) -> Self {
        Self {
            in_features,
            out_features,
        }
    }
}
50
51#[derive(Clone, Debug, Deserialize)]
52pub struct LoraConfig {
53    #[serde(rename = "r")]
54    rank: usize,
55    #[serde(rename = "lora_alpha")]
56    alpha: f64,
57    #[serde(rename = "lora_dropout")]
58    dropout: Option<f32>,
59    target_modules: HashSet<String>,
60}
61
62fn apply_scalings_to_x(x: Tensor, scalings_layer: &Tensor, adapter: usize) -> Result<Tensor> {
63    let scalings = scalings_layer.i((.., .., adapter))?.unsqueeze(D::Minus1)?;
64    let res = x.broadcast_mul(&scalings)?;
65    Ok(res)
66}
67
68#[derive(Debug)]
69struct Adapter {
70    a: Linear,
71    b: Linear,
72    scale: f64,
73}
74
75fn make_adapter(
76    a_vb: ShardedVarBuilder,
77    b_vb: ShardedVarBuilder,
78    cfg: &LoraConfig,
79    linear_cfg: &LoraLinearConfig,
80) -> Result<Adapter> {
81    assert!(a_vb.contains_tensor("weight"));
82    let a = a_vb.get((cfg.rank, linear_cfg.in_features), "weight")?;
83    assert!(b_vb.contains_tensor("weight"));
84    let b = b_vb.get((linear_cfg.out_features, cfg.rank), "weight")?;
85    let a = Linear::new(a, None);
86    let b = Linear::new(b, None);
87    let scale = if cfg.rank > 0 {
88        cfg.alpha / cfg.rank as f64
89    } else {
90        1.0
91    };
92    Ok(Adapter { a, b, scale })
93}
94
95/// Any layer that is linear-like.
96pub trait LinearLayerLike: Merge {
97    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod>;
98    fn is_lora(&self) -> bool;
99    fn weight(&self) -> &Tensor;
100    fn bias(&self) -> Option<&Tensor>;
101    fn lora_forward(
102        &self,
103        x: &Tensor,
104        scalings_layer: Option<Tensor>,
105        global_scaling_weight: f64,
106        is_scaling_pass: Option<f64>,
107    ) -> Result<Tensor>;
108}
109
110pub trait Merge {
111    /// Get the delta weight of the LoRA layer. This is meant to be an internal method.
112    fn get_delta_weight(&self, adapter: usize) -> Result<Tensor>;
113    /// Merge the LoRA weights.
114    fn merge_weights(&mut self) -> Result<()>;
115}
116
117impl Merge for Linear {
118    fn merge_weights(&mut self) -> Result<()> {
119        Ok(())
120    }
121    fn get_delta_weight(&self, _adapter: usize) -> Result<Tensor> {
122        unreachable!()
123    }
124}
125
126impl LinearLayerLike for Linear {
127    fn bias(&self) -> Option<&Tensor> {
128        self.bias()
129    }
130    fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod> {
131        unimplemented!("Linear layer has no reasonable quant inner!")
132    }
133    fn weight(&self) -> &Tensor {
134        self.weight()
135    }
136    fn lora_forward(
137        &self,
138        x: &Tensor,
139        _scalings_layer: Option<Tensor>,
140        _global_scaling_weight: f64,
141        _is_scaling_pass: Option<f64>,
142    ) -> Result<Tensor> {
143        self.forward(x)
144    }
145    fn is_lora(&self) -> bool {
146        false
147    }
148}
149
150#[allow(clippy::too_many_arguments)]
151pub fn linear(
152    d1: usize,
153    d2: usize,
154    base_vb: ShardedVarBuilder,
155    vb: ShardedVarBuilder,
156    lora_config: &[((String, String), LoraConfig)],
157    count: &mut usize,
158    ord: &Ordering,
159    preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
160) -> Result<Arc<dyn LinearLayerLike + Send + Sync>> {
161    let prefix = vb.prefix();
162    let module = prefix.split('.').next_back().unwrap();
163
164    let linear_config = LoraLinearConfig::new(d1, d2);
165    let inner = layers::linear(d1, d2, base_vb.clone())?;
166
167    let target_modules = &lora_config.first().map(|c| &c.1.target_modules);
168    for (_, cfg) in lora_config {
169        if target_modules
170            .as_ref()
171            .is_some_and(|target_modules| &cfg.target_modules != *target_modules)
172        {
173            hanzo_ml::bail!("Expected all target modules to be the same.");
174        }
175    }
176
177    if !target_modules
178        .as_ref()
179        .is_some_and(|target_modules| target_modules.contains(module))
180    {
181        return Ok(Arc::new(inner));
182    }
183    let name = prefix.split("lora_A").last().unwrap();
184    let layer = if let Some(ref layers) = ord.layers {
185        *layers.get(name).unwrap()
186    } else {
187        0
188    };
189
190    let lorainner = LoraLinear::new(
191        &inner,
192        &linear_config,
193        lora_config,
194        &vb,
195        layer,
196        preload_adapters,
197    )?;
198    *count += 1;
199    Ok(Arc::new(lorainner))
200}
201
202#[allow(clippy::too_many_arguments)]
203pub fn linear_no_bias(
204    d1: usize,
205    d2: usize,
206    base_vb: ShardedVarBuilder,
207    vb: ShardedVarBuilder,
208    lora_config: &[((String, String), LoraConfig)],
209    count: &mut usize,
210    ord: &Ordering,
211    preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
212) -> Result<Arc<dyn LinearLayerLike + Send + Sync>> {
213    let prefix = vb.prefix();
214    let module = prefix.split('.').next_back().unwrap();
215
216    let linear_config = LoraLinearConfig::new(d1, d2);
217    let inner = layers::linear_no_bias(d1, d2, base_vb.clone())?;
218
219    let target_modules = &lora_config.first().map(|c| &c.1.target_modules);
220    for (_, cfg) in lora_config {
221        if target_modules
222            .as_ref()
223            .is_some_and(|target_modules| &cfg.target_modules != *target_modules)
224        {
225            hanzo_ml::bail!("Expected all target modules to be the same.");
226        }
227    }
228
229    if !target_modules
230        .as_ref()
231        .is_some_and(|target_modules| target_modules.contains(module))
232    {
233        return Ok(Arc::new(inner));
234    }
235    let name = prefix.split("lora_A").last().unwrap();
236    let layer = if let Some(ref layers) = ord.layers {
237        *layers.get(name).unwrap()
238    } else {
239        0
240    };
241
242    let lorainner = LoraLinear::new(
243        &inner,
244        &linear_config,
245        lora_config,
246        &vb,
247        layer,
248        preload_adapters,
249    )?;
250    *count += 1;
251    Ok(Arc::new(lorainner))
252}
253
254fn get_maybe_topk_scalings(scalings: Tensor, layer: usize) -> Result<Tensor> {
255    scalings.i((.., .., layer, ..))
256}
257
258#[allow(clippy::too_many_arguments)]
259pub fn linear_b(
260    in_dim: usize,
261    out_dim: usize,
262    bias: bool,
263    base_vb: ShardedVarBuilder,
264    vb: ShardedVarBuilder,
265    lora_config: &[((String, String), LoraConfig)],
266    count: &mut usize,
267    ord: &Ordering,
268    preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
269) -> Result<Arc<dyn LinearLayerLike + Send + Sync>> {
270    if bias {
271        linear(
272            in_dim,
273            out_dim,
274            base_vb,
275            vb,
276            lora_config,
277            count,
278            ord,
279            preload_adapters,
280        )
281    } else {
282        linear_no_bias(
283            in_dim,
284            out_dim,
285            base_vb,
286            vb,
287            lora_config,
288            count,
289            ord,
290            preload_adapters,
291        )
292    }
293}
294
295pub fn get_lora_cfg(tensor: &QTensor) -> LoraLinearConfig {
296    LoraLinearConfig::new(tensor.shape().dims()[1], tensor.shape().dims()[0])
297}