// peft_rs/adapters/ia3.rs
1//! IA³ (Infused Adapter by Inhibiting and Amplifying Inner Activations) implementation.
2//!
3//! IA³ is an extremely parameter-efficient fine-tuning method that learns
4//! rescaling vectors for keys, values, and feedforward layers.
5//!
6//! Reference: <https://arxiv.org/abs/2205.05638>
7
8#![allow(clippy::doc_markdown)]
9#![allow(clippy::uninlined_format_args)]
10
11use candle_core::{DType, Device, Tensor};
12use candle_nn::VarMap;
13use serde::{Deserialize, Serialize};
14
15use crate::error::{PeftError, Result};
16use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};
17
/// Configuration for IA³ adapters.
///
/// Deserialized via serde; the `#[serde(default = ...)]` attributes make
/// every field optional in config files, falling back to the `default_*`
/// helpers below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ia3Config {
    /// Target modules to apply IA³ to (module names, e.g. "k_proj").
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,

    /// Modules treated as feedforward (scaling applied to the *input*
    /// instead of the output).
    /// Must be a subset of `target_modules`; enforced by `validate()`.
    #[serde(default)]
    pub feedforward_modules: Vec<String>,

    /// Whether to initialize the vectors in IA³ layers to ones
    /// (identity transform at the start of training).
    /// Setting this to false is discouraged.
    #[serde(default = "default_true")]
    pub init_ia3_weights: bool,

    /// Set to true if the layer stores weight like (fan_in, fan_out),
    /// i.e. transposed relative to the usual (out_features, in_features).
    #[serde(default)]
    pub fan_in_fan_out: bool,
}
39
/// Serde default for `target_modules`: the attention key/value projections
/// plus the MLP down-projection, matching common transformer naming.
fn default_target_modules() -> Vec<String> {
    ["k_proj", "v_proj", "down_proj"]
        .iter()
        .map(|name| (*name).to_string())
        .collect()
}
43
/// Serde default helper: makes `init_ia3_weights` default to `true`
/// when absent from a serialized config.
fn default_true() -> bool {
    true
}
47
48impl Default for Ia3Config {
49    fn default() -> Self {
50        Self {
51            target_modules: default_target_modules(),
52            feedforward_modules: vec!["down_proj".into()],
53            init_ia3_weights: true,
54            fan_in_fan_out: false,
55        }
56    }
57}
58
59impl AdapterConfig for Ia3Config {
60    fn validate(&self) -> Result<()> {
61        if self.target_modules.is_empty() {
62            return Err(PeftError::InvalidConfig(
63                "target_modules cannot be empty".into(),
64            ));
65        }
66        // Check that feedforward_modules is a subset of target_modules
67        for ff_module in &self.feedforward_modules {
68            if !self.target_modules.contains(ff_module) {
69                return Err(PeftError::InvalidConfig(format!(
70                    "feedforward_module '{}' must be in target_modules",
71                    ff_module
72                )));
73            }
74        }
75        Ok(())
76    }
77}
78
/// IA³ layer implementing learned rescaling vectors.
///
/// For non-feedforward modules: `output = base_output * ia3_vector`
/// For feedforward modules: `output = base_layer(input * ia3_vector)`
pub struct Ia3Layer {
    /// The learned scaling vector.
    /// Shape: [out_features, 1] for non-feedforward, [1, in_features] for feedforward.
    ia3_l: Tensor,
    /// Configuration this layer was constructed from.
    config: Ia3Config,
    /// Input dimension of the wrapped base layer.
    in_features: usize,
    /// Output dimension of the wrapped base layer.
    out_features: usize,
    /// Whether this is a feedforward layer (scales input instead of output).
    is_feedforward: bool,
    /// Whether gradients are disabled (bookkeeping flag; see `Trainable` impl).
    frozen: bool,
}
98
99impl Ia3Layer {
100    /// Create a new IA³ layer.
101    ///
102    /// # Arguments
103    /// * `in_features` - Input dimension
104    /// * `out_features` - Output dimension
105    /// * `is_feedforward` - Whether this is a feedforward layer (scales input vs output)
106    /// * `config` - IA³ configuration
107    /// * `device` - Device to create tensors on
108    ///
109    /// # Errors
110    /// Returns error if configuration is invalid or tensor initialization fails.
111    pub fn new(
112        in_features: usize,
113        out_features: usize,
114        is_feedforward: bool,
115        config: Ia3Config,
116        device: &Device,
117    ) -> Result<Self> {
118        config.validate()?;
119
120        // Initialize the scaling vector
121        let ia3_l = if config.init_ia3_weights {
122            // Initialize to ones (identity transform)
123            if is_feedforward {
124                Tensor::ones((1, in_features), DType::F32, device)?
125            } else {
126                Tensor::ones((out_features, 1), DType::F32, device)?
127            }
128        } else {
129            // Random initialization
130            if is_feedforward {
131                Tensor::randn(0.0f32, 0.02, (1, in_features), device)?
132            } else {
133                Tensor::randn(0.0f32, 0.02, (out_features, 1), device)?
134            }
135        };
136
137        Ok(Self {
138            ia3_l,
139            config,
140            in_features,
141            out_features,
142            is_feedforward,
143            frozen: false,
144        })
145    }
146
147    /// Get the scaling vector.
148    #[must_use]
149    pub fn scaling_vector(&self) -> &Tensor {
150        &self.ia3_l
151    }
152
153    /// Check if this is a feedforward layer.
154    #[must_use]
155    pub fn is_feedforward(&self) -> bool {
156        self.is_feedforward
157    }
158
159    /// Apply IA³ scaling to input (for feedforward layers).
160    ///
161    /// # Arguments
162    /// * `input` - Input tensor [batch, seq_len, in_features]
163    ///
164    /// # Returns
165    /// Scaled input tensor
166    ///
167    /// # Errors
168    /// Returns error if called on non-feedforward layer or tensor operations fail.
169    pub fn scale_input(&self, input: &Tensor) -> Result<Tensor> {
170        if !self.is_feedforward {
171            return Err(PeftError::InvalidConfig(
172                "scale_input called on non-feedforward IA³ layer".into(),
173            ));
174        }
175        // ia3_l shape: [1, in_features], input shape: [batch, seq_len, in_features]
176        // Need to reshape for broadcasting
177        let scaling = self.ia3_l.reshape((1, 1, self.in_features))?;
178        Ok(input.broadcast_mul(&scaling)?)
179    }
180
181    /// Apply IA³ scaling to output (for non-feedforward layers).
182    ///
183    /// # Arguments
184    /// * `output` - Output tensor from base layer [batch, seq_len, out_features]
185    ///
186    /// # Returns
187    /// Scaled output tensor
188    ///
189    /// # Errors
190    /// Returns error if called on feedforward layer or tensor operations fail.
191    pub fn scale_output(&self, output: &Tensor) -> Result<Tensor> {
192        if self.is_feedforward {
193            return Err(PeftError::InvalidConfig(
194                "scale_output called on feedforward IA³ layer".into(),
195            ));
196        }
197        // ia3_l shape: [out_features, 1], output shape: [batch, seq_len, out_features]
198        // Reshape to [1, 1, out_features] for broadcasting
199        let scaling = self.ia3_l.reshape((1, 1, self.out_features))?;
200        Ok(output.broadcast_mul(&scaling)?)
201    }
202}
203
204impl Adapter for Ia3Layer {
205    type Config = Ia3Config;
206
207    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
208        if self.is_feedforward {
209            // For feedforward: scale the input before passing to base layer
210            // The base layer computation should happen externally
211            self.scale_input(input)
212        } else {
213            // For non-feedforward: scale the output from base layer
214            match base_output {
215                Some(output) => self.scale_output(output),
216                None => Err(PeftError::InvalidConfig(
217                    "Non-feedforward IA³ requires base_output".into(),
218                )),
219            }
220        }
221    }
222
223    fn num_parameters(&self) -> usize {
224        if self.is_feedforward {
225            self.in_features
226        } else {
227            self.out_features
228        }
229    }
230
231    fn config(&self) -> &Self::Config {
232        &self.config
233    }
234}
235
236impl Mergeable for Ia3Layer {
237    fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
238        // For IA³, merging means multiplying base weights by the scaling vector
239        // Weight shape: [out_features, in_features]
240        // For feedforward: scale along in_features (column-wise)
241        // For non-feedforward: scale along out_features (row-wise)
242
243        if self.is_feedforward {
244            // ia3_l shape: [1, in_features]
245            // Broadcast multiply: each column scaled by corresponding element
246            Ok(base_weight.broadcast_mul(&self.ia3_l)?)
247        } else {
248            // ia3_l shape: [out_features, 1]
249            // Broadcast multiply: each row scaled by corresponding element
250            Ok(base_weight.broadcast_mul(&self.ia3_l)?)
251        }
252    }
253
254    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
255        // Unmerging IA³ can be inaccurate due to potential division by values close to zero
256        // Add tolerance to avoid division by zero
257        let tolerance = 1e-8_f32;
258        let tolerance_tensor = Tensor::new(tolerance, self.ia3_l.device())?;
259        let safe_divisor = self.ia3_l.broadcast_add(&tolerance_tensor)?;
260
261        Ok(merged_weight.broadcast_div(&safe_divisor)?)
262    }
263}
264
impl Trainable for Ia3Layer {
    /// Register trainable parameters with a `VarMap`.
    ///
    /// Note: In the current design, tensors are created directly.
    /// For full training support, tensors should be created via VarBuilder
    /// during construction, which automatically registers them.
    /// This is a simplified implementation suitable for inference, so this
    /// method intentionally registers nothing.
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> Result<()> {
        Ok(())
    }

    /// Mark the layer as frozen (bookkeeping flag only — no autograd hookup here).
    fn freeze(&mut self) {
        self.frozen = true;
    }

    /// Clear the frozen flag.
    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    fn is_frozen(&self) -> bool {
        self.frozen
    }
}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config is non-empty, uses ones-init, and validates cleanly.
    #[test]
    fn test_ia3_config_default() {
        let config = Ia3Config::default();
        assert!(!config.target_modules.is_empty());
        assert!(config.init_ia3_weights);
        assert!(config.validate().is_ok());
    }

    /// A feedforward module missing from target_modules must fail validation.
    #[test]
    fn test_ia3_config_invalid_feedforward() {
        let config = Ia3Config {
            target_modules: vec!["q_proj".into()],
            feedforward_modules: vec!["not_in_targets".into()],
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_ia3_layer_creation_non_feedforward() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 768, false, config, &device);
        assert!(layer.is_ok());

        let layer = layer.unwrap();
        assert!(!layer.is_feedforward());
        // Non-feedforward: scaling vector has shape [out_features, 1]
        assert_eq!(layer.scaling_vector().dims(), &[768, 1]);
    }

    #[test]
    fn test_ia3_layer_creation_feedforward() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 3072, true, config, &device);
        assert!(layer.is_ok());

        let layer = layer.unwrap();
        assert!(layer.is_feedforward());
        // Feedforward: scaling vector has shape [1, in_features]
        assert_eq!(layer.scaling_vector().dims(), &[1, 768]);
    }

    #[test]
    fn test_ia3_num_parameters_non_feedforward() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 512, false, config, &device).unwrap();
        // Non-feedforward uses out_features
        assert_eq!(layer.num_parameters(), 512);
    }

    #[test]
    fn test_ia3_num_parameters_feedforward() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 3072, true, config, &device).unwrap();
        // Feedforward uses in_features
        assert_eq!(layer.num_parameters(), 768);
    }

    /// Non-feedforward forward scales base_output and preserves its shape.
    #[test]
    fn test_ia3_forward_non_feedforward() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 768, false, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let base_output = Tensor::ones(&[1, 10, 768], DType::F32, &device).unwrap();

        let output = layer.forward(&input, Some(&base_output)).unwrap();
        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    /// Feedforward forward scales the input directly (no base_output needed).
    #[test]
    fn test_ia3_forward_feedforward() {
        let config = Ia3Config::default();
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 3072, true, config, &device).unwrap();

        let input = Tensor::ones(&[1, 10, 768], DType::F32, &device).unwrap();

        let output = layer.forward(&input, None).unwrap();
        // For feedforward, output has same shape as input
        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    /// Ones-init must act as an identity: output == base_output exactly.
    #[test]
    fn test_ia3_initialized_to_ones() {
        let config = Ia3Config {
            init_ia3_weights: true,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = Ia3Layer::new(768, 768, false, config, &device).unwrap();

        // With init_ia3_weights=true, scaling should be all ones
        // So forward pass should return output unchanged
        let base_output = Tensor::full(2.0f32, &[1, 10, 768], &device).unwrap();
        let output = layer
            .forward(
                &Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap(),
                Some(&base_output),
            )
            .unwrap();

        // Output should equal base_output (scaled by 1):
        // sum = 2.0 * (1 * 10 * 768) elements
        let output_sum: f32 = output.sum_all().unwrap().to_scalar().unwrap();
        let expected_sum = 2.0f32 * 1.0 * 10.0 * 768.0;
        assert!((output_sum - expected_sum).abs() < 1e-3);
    }
}