#![allow(clippy::doc_markdown)]
#![allow(clippy::uninlined_format_args)]

//! IA³ (Infused Adapter by Inhibiting and Amplifying Inner Activations)
//! adapters: learned vectors that rescale a layer's activations.

use candle_core::{DType, Device, Tensor};
use candle_nn::VarMap;
use serde::{Deserialize, Serialize};

use crate::error::{PeftError, Result};
use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};

/// Configuration for IA³ (Infused Adapter by Inhibiting and Amplifying
/// Inner Activations) adapters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ia3Config {
    /// Names of the modules whose activations are rescaled by learned vectors.
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,

    /// Subset of `target_modules` treated as feedforward layers; their IA³
    /// vector rescales the layer input instead of the layer output.
    #[serde(default)]
    pub feedforward_modules: Vec<String>,

    /// Initialize the scaling vectors to ones, so the adapter starts as a no-op.
    #[serde(default = "default_true")]
    pub init_ia3_weights: bool,

    /// Set to `true` if the wrapped layer stores its weight as (fan_in, fan_out).
    #[serde(default)]
    pub fan_in_fan_out: bool,
}

fn default_target_modules() -> Vec<String> {
    vec!["k_proj".into(), "v_proj".into(), "down_proj".into()]
}

fn default_true() -> bool {
    true
}

impl Default for Ia3Config {
    fn default() -> Self {
        Self {
            target_modules: default_target_modules(),
            feedforward_modules: vec!["down_proj".into()],
            init_ia3_weights: true,
            fan_in_fan_out: false,
        }
    }
}

impl AdapterConfig for Ia3Config {
    fn validate(&self) -> Result<()> {
        if self.target_modules.is_empty() {
            return Err(PeftError::InvalidConfig(
                "target_modules cannot be empty".into(),
            ));
        }
        for ff_module in &self.feedforward_modules {
            if !self.target_modules.contains(ff_module) {
                return Err(PeftError::InvalidConfig(format!(
                    "feedforward_module '{}' must be in target_modules",
                    ff_module
                )));
            }
        }
        Ok(())
    }
}

/// An IA³ adapter layer holding a single learned scaling vector.
pub struct Ia3Layer {
    /// Learned scaling vector: shape `(1, in_features)` for feedforward
    /// layers, `(out_features, 1)` otherwise.
    ia3_l: Tensor,
    /// Adapter configuration.
    config: Ia3Config,
    /// Input dimension of the wrapped layer.
    in_features: usize,
    /// Output dimension of the wrapped layer.
    out_features: usize,
    /// Whether this layer rescales the input (feedforward) or the output.
    is_feedforward: bool,
    /// Whether the adapter parameters are currently frozen.
    frozen: bool,
}

impl Ia3Layer {
    /// Creates a new IA³ layer.
    ///
    /// Feedforward layers learn a vector of shape `(1, in_features)` that
    /// rescales the layer input; all other layers learn a vector of shape
    /// `(out_features, 1)` that rescales the layer output.
    ///
    /// # Errors
    ///
    /// Returns an error if `config` fails validation or if the scaling
    /// vector cannot be allocated on `device`.
    pub fn new(
        in_features: usize,
        out_features: usize,
        is_feedforward: bool,
        config: Ia3Config,
        device: &Device,
    ) -> Result<Self> {
        config.validate()?;

        let shape = if is_feedforward {
            (1, in_features)
        } else {
            (out_features, 1)
        };
        let ia3_l = if config.init_ia3_weights {
            // Ones make the adapter an identity until training updates it.
            Tensor::ones(shape, DType::F32, device)?
        } else {
            Tensor::randn(0.0f32, 0.02, shape, device)?
        };

        Ok(Self {
            ia3_l,
            config,
            in_features,
            out_features,
            is_feedforward,
            frozen: false,
        })
    }

    /// Returns the learned IA³ scaling vector.
    #[must_use]
    pub fn scaling_vector(&self) -> &Tensor {
        &self.ia3_l
    }

    /// Returns `true` if this layer rescales the input of a feedforward module.
    #[must_use]
    pub fn is_feedforward(&self) -> bool {
        self.is_feedforward
    }

    /// Rescales a feedforward layer's input; the input is expected to have
    /// `in_features` as its last dimension, e.g. `(batch, seq_len, in_features)`.
    ///
    /// # Errors
    ///
    /// Returns an error if called on a non-feedforward layer.
    pub fn scale_input(&self, input: &Tensor) -> Result<Tensor> {
        if !self.is_feedforward {
            return Err(PeftError::InvalidConfig(
                "scale_input called on non-feedforward IA³ layer".into(),
            ));
        }
        let scaling = self.ia3_l.reshape((1, 1, self.in_features))?;
        Ok(input.broadcast_mul(&scaling)?)
    }

    /// Rescales a base layer's output; the output is expected to have
    /// `out_features` as its last dimension, e.g. `(batch, seq_len, out_features)`.
    ///
    /// # Errors
    ///
    /// Returns an error if called on a feedforward layer.
    pub fn scale_output(&self, output: &Tensor) -> Result<Tensor> {
        if self.is_feedforward {
            return Err(PeftError::InvalidConfig(
                "scale_output called on feedforward IA³ layer".into(),
            ));
        }
        let scaling = self.ia3_l.reshape((1, 1, self.out_features))?;
        Ok(output.broadcast_mul(&scaling)?)
    }
}

impl Adapter for Ia3Layer {
    type Config = Ia3Config;

    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
        if self.is_feedforward {
            // Feedforward: rescale the input; the caller then applies the base layer.
            self.scale_input(input)
        } else {
            // Attention-style: rescale the output already produced by the base layer.
            match base_output {
                Some(output) => self.scale_output(output),
                None => Err(PeftError::InvalidConfig(
                    "Non-feedforward IA³ requires base_output".into(),
                )),
            }
        }
    }

    fn num_parameters(&self) -> usize {
        if self.is_feedforward {
            self.in_features
        } else {
            self.out_features
        }
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Mergeable for Ia3Layer {
    fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
        // The vector is stored as `(1, in_features)` for feedforward layers and
        // `(out_features, 1)` otherwise, so a single broadcast multiply scales
        // the correct axis of the base weight in both cases.
        Ok(base_weight.broadcast_mul(&self.ia3_l)?)
    }

    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
        // Add a small tolerance so unmerging never divides by exactly zero.
        let tolerance = 1e-8_f32;
        let tolerance_tensor = Tensor::new(tolerance, self.ia3_l.device())?;
        let safe_divisor = self.ia3_l.broadcast_add(&tolerance_tensor)?;

        Ok(merged_weight.broadcast_div(&safe_divisor)?)
    }
}

impl Trainable for Ia3Layer {
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> Result<()> {
        // The scaling vector is held as a plain `Tensor` rather than a
        // `candle_nn::Var`, so there is currently nothing to register.
        Ok(())
    }

    fn freeze(&mut self) {
        self.frozen = true;
    }

    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    fn is_frozen(&self) -> bool {
        self.frozen
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ia3_config_default() {
        let config = Ia3Config::default();
        assert!(!config.target_modules.is_empty());
        assert!(config.init_ia3_weights);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_ia3_config_invalid_feedforward() {
        let config = Ia3Config {
            target_modules: vec!["q_proj".into()],
            feedforward_modules: vec!["not_in_targets".into()],
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }
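
    // Sketch of the complementary `validate` failure mode: an empty
    // `target_modules` list is also rejected.
    #[test]
    fn test_ia3_config_empty_targets() {
        let config = Ia3Config {
            target_modules: vec![],
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }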

    #[test]
    fn test_ia3_layer_creation_non_feedforward() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 768, false, config, &device);
        assert!(layer.is_ok());

        let layer = layer.unwrap();
        assert!(!layer.is_feedforward());
        assert_eq!(layer.scaling_vector().dims(), &[768, 1]);
    }

    #[test]
    fn test_ia3_layer_creation_feedforward() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 3072, true, config, &device);
        assert!(layer.is_ok());

        let layer = layer.unwrap();
        assert!(layer.is_feedforward());
        assert_eq!(layer.scaling_vector().dims(), &[1, 768]);
    }
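
    // Sketch of the non-default initialisation path: with `init_ia3_weights`
    // disabled the vector is drawn from a small random normal, but its shape
    // matches the ones-initialised case.
    #[test]
    fn test_ia3_layer_creation_random_init() {
        let config = Ia3Config {
            init_ia3_weights: false,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 3072, true, config, &device).unwrap();
        assert_eq!(layer.scaling_vector().dims(), &[1, 768]);
    }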

    #[test]
    fn test_ia3_num_parameters_non_feedforward() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 512, false, config, &device).unwrap();
        assert_eq!(layer.num_parameters(), 512);
    }

    #[test]
    fn test_ia3_num_parameters_feedforward() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 3072, true, config, &device).unwrap();
        assert_eq!(layer.num_parameters(), 768);
    }
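
    // Sketch of the guard rails on the two scaling helpers: each refuses to
    // run on the wrong kind of layer.
    #[test]
    fn test_ia3_scaling_helpers_reject_wrong_layer_kind() {
        let device = Device::Cpu;
        let attn = Ia3Layer::new(768, 768, false, Ia3Config::default(), &device).unwrap();
        let ff = Ia3Layer::new(768, 3072, true, Ia3Config::default(), &device).unwrap();

        let x = Tensor::ones(&[1, 10, 768], DType::F32, &device).unwrap();
        assert!(attn.scale_input(&x).is_err());
        assert!(ff.scale_output(&x).is_err());
    }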

    #[test]
    fn test_ia3_forward_non_feedforward() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 768, false, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let base_output = Tensor::ones(&[1, 10, 768], DType::F32, &device).unwrap();

        let output = layer.forward(&input, Some(&base_output)).unwrap();
        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    #[test]
    fn test_ia3_forward_feedforward() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 3072, true, config, &device).unwrap();

        let input = Tensor::ones(&[1, 10, 768], DType::F32, &device).unwrap();

        let output = layer.forward(&input, None).unwrap();
        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }
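
    // Sketch of the error path in `forward`: a non-feedforward layer cannot
    // rescale anything without the base layer's output.
    #[test]
    fn test_ia3_forward_non_feedforward_requires_base_output() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 768, false, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        assert!(layer.forward(&input, None).is_err());
    }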

    #[test]
    fn test_ia3_initialized_to_ones() {
        let config = Ia3Config {
            init_ia3_weights: true,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 768, false, config, &device).unwrap();

        let base_output = Tensor::full(2.0f32, &[1, 10, 768], &device).unwrap();
        let output = layer
            .forward(
                &Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap(),
                Some(&base_output),
            )
            .unwrap();

        let output_sum: f32 = output.sum_all().unwrap().to_scalar().unwrap();
        // Base value 2.0 × batch 1 × seq_len 10 × hidden 768.
        let expected_sum = 2.0f32 * 1.0 * 10.0 * 768.0;
        assert!((output_sum - expected_sum).abs() < 1e-3);
    }
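
    // Sketch of a merge/unmerge round trip: with the vector initialised to
    // ones, merging leaves the base weight unchanged and unmerging divides it
    // back out (up to the small tolerance that guards against division by zero).
    #[test]
    fn test_ia3_merge_unmerge_roundtrip() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(4, 4, false, config, &device).unwrap();

        let base_weight = Tensor::full(3.0f32, &[4, 4], &device).unwrap();
        let merged = layer.merge(&base_weight).unwrap();
        let unmerged = layer.unmerge(&merged).unwrap();

        let merged_sum: f32 = merged.sum_all().unwrap().to_scalar().unwrap();
        let unmerged_sum: f32 = unmerged.sum_all().unwrap().to_scalar().unwrap();
        assert!((merged_sum - 48.0).abs() < 1e-3);
        assert!((unmerged_sum - 48.0).abs() < 1e-3);
    }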
}