1#![allow(clippy::cast_precision_loss)]
2
3use std::{collections::HashSet, fmt::Debug, sync::Arc};
4
5use hanzo_ml::{quantized::QTensor, IndexOp, Result, Tensor, D};
6use hanzo_nn::{Linear, Module};
7use hanzo_quant::{QuantMethod, ShardedVarBuilder};
8use loralinear::LoraLinear;
9pub use qloralinear::QLoraLinear;
10use serde::Deserialize;
11
12mod loralinear;
13mod qloralinear;
14
15use std::collections::HashMap;
16
17use crate::layers;
18
19#[derive(Clone, Debug, Deserialize)]
20pub struct PreloadAdapter {
21 pub name: String,
22 pub adapter_model_id: String,
23}
24
25#[derive(Clone, Debug, Deserialize)]
26pub struct Ordering {
28 #[serde(rename = "order")]
29 pub adapters: Option<Vec<String>>,
30 pub layers: Option<HashMap<String, usize>>,
31 pub base_model_id: String,
32 pub preload_adapters: Option<Vec<PreloadAdapter>>,
33}
34
/// Dimensions of the linear layer a LoRA adapter is attached to.
#[derive(Debug, Clone)]
pub struct LoraLinearConfig {
    /// Input size of the layer; the LoRA `A` weight is `(rank, in_features)`.
    in_features: usize,
    /// Output size of the layer; the LoRA `B` weight is `(out_features, rank)`.
    out_features: usize,
}

impl LoraLinearConfig {
    /// Builds a config from the layer's input and output sizes.
    pub fn new(in_features: usize, out_features: usize) -> Self {
        Self {
            in_features,
            out_features,
        }
    }
}
50
51#[derive(Clone, Debug, Deserialize)]
52pub struct LoraConfig {
53 #[serde(rename = "r")]
54 rank: usize,
55 #[serde(rename = "lora_alpha")]
56 alpha: f64,
57 #[serde(rename = "lora_dropout")]
58 dropout: Option<f32>,
59 target_modules: HashSet<String>,
60}
61
62fn apply_scalings_to_x(x: Tensor, scalings_layer: &Tensor, adapter: usize) -> Result<Tensor> {
63 let scalings = scalings_layer.i((.., .., adapter))?.unsqueeze(D::Minus1)?;
64 let res = x.broadcast_mul(&scalings)?;
65 Ok(res)
66}
67
68#[derive(Debug)]
69struct Adapter {
70 a: Linear,
71 b: Linear,
72 scale: f64,
73}
74
75fn make_adapter(
76 a_vb: ShardedVarBuilder,
77 b_vb: ShardedVarBuilder,
78 cfg: &LoraConfig,
79 linear_cfg: &LoraLinearConfig,
80) -> Result<Adapter> {
81 assert!(a_vb.contains_tensor("weight"));
82 let a = a_vb.get((cfg.rank, linear_cfg.in_features), "weight")?;
83 assert!(b_vb.contains_tensor("weight"));
84 let b = b_vb.get((linear_cfg.out_features, cfg.rank), "weight")?;
85 let a = Linear::new(a, None);
86 let b = Linear::new(b, None);
87 let scale = if cfg.rank > 0 {
88 cfg.alpha / cfg.rank as f64
89 } else {
90 1.0
91 };
92 Ok(Adapter { a, b, scale })
93}
94
95pub trait LinearLayerLike: Merge {
97 fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod>;
98 fn is_lora(&self) -> bool;
99 fn weight(&self) -> &Tensor;
100 fn bias(&self) -> Option<&Tensor>;
101 fn lora_forward(
102 &self,
103 x: &Tensor,
104 scalings_layer: Option<Tensor>,
105 global_scaling_weight: f64,
106 is_scaling_pass: Option<f64>,
107 ) -> Result<Tensor>;
108}
109
110pub trait Merge {
111 fn get_delta_weight(&self, adapter: usize) -> Result<Tensor>;
113 fn merge_weights(&mut self) -> Result<()>;
115}
116
117impl Merge for Linear {
118 fn merge_weights(&mut self) -> Result<()> {
119 Ok(())
120 }
121 fn get_delta_weight(&self, _adapter: usize) -> Result<Tensor> {
122 unreachable!()
123 }
124}
125
126impl LinearLayerLike for Linear {
127 fn bias(&self) -> Option<&Tensor> {
128 self.bias()
129 }
130 fn quant_inner(&mut self) -> &mut Arc<dyn QuantMethod> {
131 unimplemented!("Linear layer has no reasonable quant inner!")
132 }
133 fn weight(&self) -> &Tensor {
134 self.weight()
135 }
136 fn lora_forward(
137 &self,
138 x: &Tensor,
139 _scalings_layer: Option<Tensor>,
140 _global_scaling_weight: f64,
141 _is_scaling_pass: Option<f64>,
142 ) -> Result<Tensor> {
143 self.forward(x)
144 }
145 fn is_lora(&self) -> bool {
146 false
147 }
148}
149
150#[allow(clippy::too_many_arguments)]
151pub fn linear(
152 d1: usize,
153 d2: usize,
154 base_vb: ShardedVarBuilder,
155 vb: ShardedVarBuilder,
156 lora_config: &[((String, String), LoraConfig)],
157 count: &mut usize,
158 ord: &Ordering,
159 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
160) -> Result<Arc<dyn LinearLayerLike + Send + Sync>> {
161 let prefix = vb.prefix();
162 let module = prefix.split('.').next_back().unwrap();
163
164 let linear_config = LoraLinearConfig::new(d1, d2);
165 let inner = layers::linear(d1, d2, base_vb.clone())?;
166
167 let target_modules = &lora_config.first().map(|c| &c.1.target_modules);
168 for (_, cfg) in lora_config {
169 if target_modules
170 .as_ref()
171 .is_some_and(|target_modules| &cfg.target_modules != *target_modules)
172 {
173 hanzo_ml::bail!("Expected all target modules to be the same.");
174 }
175 }
176
177 if !target_modules
178 .as_ref()
179 .is_some_and(|target_modules| target_modules.contains(module))
180 {
181 return Ok(Arc::new(inner));
182 }
183 let name = prefix.split("lora_A").last().unwrap();
184 let layer = if let Some(ref layers) = ord.layers {
185 *layers.get(name).unwrap()
186 } else {
187 0
188 };
189
190 let lorainner = LoraLinear::new(
191 &inner,
192 &linear_config,
193 lora_config,
194 &vb,
195 layer,
196 preload_adapters,
197 )?;
198 *count += 1;
199 Ok(Arc::new(lorainner))
200}
201
202#[allow(clippy::too_many_arguments)]
203pub fn linear_no_bias(
204 d1: usize,
205 d2: usize,
206 base_vb: ShardedVarBuilder,
207 vb: ShardedVarBuilder,
208 lora_config: &[((String, String), LoraConfig)],
209 count: &mut usize,
210 ord: &Ordering,
211 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
212) -> Result<Arc<dyn LinearLayerLike + Send + Sync>> {
213 let prefix = vb.prefix();
214 let module = prefix.split('.').next_back().unwrap();
215
216 let linear_config = LoraLinearConfig::new(d1, d2);
217 let inner = layers::linear_no_bias(d1, d2, base_vb.clone())?;
218
219 let target_modules = &lora_config.first().map(|c| &c.1.target_modules);
220 for (_, cfg) in lora_config {
221 if target_modules
222 .as_ref()
223 .is_some_and(|target_modules| &cfg.target_modules != *target_modules)
224 {
225 hanzo_ml::bail!("Expected all target modules to be the same.");
226 }
227 }
228
229 if !target_modules
230 .as_ref()
231 .is_some_and(|target_modules| target_modules.contains(module))
232 {
233 return Ok(Arc::new(inner));
234 }
235 let name = prefix.split("lora_A").last().unwrap();
236 let layer = if let Some(ref layers) = ord.layers {
237 *layers.get(name).unwrap()
238 } else {
239 0
240 };
241
242 let lorainner = LoraLinear::new(
243 &inner,
244 &linear_config,
245 lora_config,
246 &vb,
247 layer,
248 preload_adapters,
249 )?;
250 *count += 1;
251 Ok(Arc::new(lorainner))
252}
253
254fn get_maybe_topk_scalings(scalings: Tensor, layer: usize) -> Result<Tensor> {
255 scalings.i((.., .., layer, ..))
256}
257
258#[allow(clippy::too_many_arguments)]
259pub fn linear_b(
260 in_dim: usize,
261 out_dim: usize,
262 bias: bool,
263 base_vb: ShardedVarBuilder,
264 vb: ShardedVarBuilder,
265 lora_config: &[((String, String), LoraConfig)],
266 count: &mut usize,
267 ord: &Ordering,
268 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
269) -> Result<Arc<dyn LinearLayerLike + Send + Sync>> {
270 if bias {
271 linear(
272 in_dim,
273 out_dim,
274 base_vb,
275 vb,
276 lora_config,
277 count,
278 ord,
279 preload_adapters,
280 )
281 } else {
282 linear_no_bias(
283 in_dim,
284 out_dim,
285 base_vb,
286 vb,
287 lora_config,
288 count,
289 ord,
290 preload_adapters,
291 )
292 }
293}
294
295pub fn get_lora_cfg(tensor: &QTensor) -> LoraLinearConfig {
296 LoraLinearConfig::new(tensor.shape().dims()[1], tensor.shape().dims()[0])
297}