Skip to main content

entrenar/lora/adapter/
peft_config.rs

//! PEFT-compatible `adapter_config.json` generation
//!
//! Generates adapter configuration files compatible with the HuggingFace PEFT library,
//! enabling direct loading with the `transformers` and `peft` Python packages.
5
6use crate::lora::LoRAConfig;
7use serde::{Deserialize, Serialize};
8
9/// PEFT adapter configuration matching the HuggingFace PEFT schema
10///
11/// This struct serializes to `adapter_config.json` format that can be loaded by
12/// `peft.PeftModel.from_pretrained()`.
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14pub struct PeftAdapterConfig {
15    /// PEFT method type (always "LORA" for LoRA adapters)
16    pub peft_type: String,
17    /// LoRA rank
18    pub r: usize,
19    /// LoRA alpha scaling parameter
20    pub lora_alpha: f32,
21    /// Target module names for LoRA adaptation
22    pub target_modules: Vec<String>,
23    /// LoRA dropout rate (0.0 if not used)
24    pub lora_dropout: f32,
25    /// Bias handling: "none", "all", or "lora_only"
26    pub bias: String,
27    /// Base model name or path (optional)
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub base_model_name_or_path: Option<String>,
30    /// Task type (e.g., "CAUSAL_LM", "SEQ_CLS")
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub task_type: Option<String>,
33    /// Fan-in/fan-out setting
34    #[serde(default)]
35    pub fan_in_fan_out: bool,
36    /// Inference mode
37    #[serde(default)]
38    pub inference_mode: bool,
39}
40
41impl PeftAdapterConfig {
42    /// Convert from entrenar's LoRAConfig to PEFT-compatible config
43    pub fn from_lora_config(config: &LoRAConfig, base_model: Option<&str>) -> Self {
44        let mut target_modules: Vec<String> = config.target_modules.iter().cloned().collect();
45        target_modules.sort();
46
47        Self {
48            peft_type: "LORA".to_string(),
49            r: config.rank,
50            lora_alpha: config.alpha,
51            target_modules,
52            lora_dropout: 0.0,
53            bias: "none".to_string(),
54            base_model_name_or_path: base_model.map(String::from),
55            task_type: None,
56            fan_in_fan_out: false,
57            inference_mode: false,
58        }
59    }
60
61    /// Set bias handling mode
62    pub fn with_bias(mut self, bias: impl Into<String>) -> Self {
63        self.bias = bias.into();
64        self
65    }
66
67    /// Set dropout rate
68    pub fn with_dropout(mut self, dropout: f32) -> Self {
69        self.lora_dropout = dropout;
70        self
71    }
72
73    /// Set task type
74    pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
75        self.task_type = Some(task_type.into());
76        self
77    }
78
79    /// Set inference mode
80    pub fn with_inference_mode(mut self, inference_mode: bool) -> Self {
81        self.inference_mode = inference_mode;
82        self
83    }
84
85    /// Serialize to JSON string
86    pub fn to_json(&self) -> Result<String, serde_json::Error> {
87        serde_json::to_string_pretty(self)
88    }
89
90    /// Deserialize from JSON string
91    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
92        serde_json::from_str(json)
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Rank-16, alpha-32 LoRA config built with `target_attention_projections()`.
    fn attention_lora_config() -> LoRAConfig {
        LoRAConfig::new(16, 32.0).target_attention_projections()
    }

    #[test]
    fn test_from_lora_config() {
        let base = "meta-llama/Llama-2-7b";
        let peft = PeftAdapterConfig::from_lora_config(&attention_lora_config(), Some(base));

        assert_eq!(peft.peft_type, "LORA");
        assert_eq!(peft.r, 16);
        assert_eq!(peft.lora_alpha, 32.0);
        assert_eq!(peft.target_modules.len(), 4);
        // Every attention projection must be present, regardless of position.
        for expected in ["q_proj", "k_proj", "v_proj", "o_proj"] {
            assert!(peft
                .target_modules
                .iter()
                .any(|module| module.as_str() == expected));
        }
        assert_eq!(peft.bias, "none");
        assert_eq!(peft.base_model_name_or_path.as_deref(), Some(base));
    }

    #[test]
    fn test_from_lora_config_no_base_model() {
        let source = LoRAConfig::new(8, 8.0).target_qv_projections();
        let peft = PeftAdapterConfig::from_lora_config(&source, None);

        assert_eq!(peft.r, 8);
        assert!(peft.base_model_name_or_path.is_none());
        assert_eq!(peft.target_modules.len(), 2);
    }

    #[test]
    fn test_json_roundtrip() {
        let original =
            PeftAdapterConfig::from_lora_config(&attention_lora_config(), Some("test/model"));

        let json = original.to_json().expect("operation should succeed");
        let restored =
            PeftAdapterConfig::from_json(&json).expect("deserialization should succeed");

        // Serializing and parsing back must reproduce the exact same config.
        assert_eq!(original, restored);
    }

    #[test]
    fn test_json_schema_keys() {
        let peft =
            PeftAdapterConfig::from_lora_config(&attention_lora_config(), Some("test/model"));
        let json = peft.to_json().expect("operation should succeed");

        // Each expected PEFT schema key must appear as a quoted JSON key.
        let expected_keys = [
            "peft_type",
            "r",
            "lora_alpha",
            "target_modules",
            "lora_dropout",
            "bias",
            "base_model_name_or_path",
        ];
        for key in expected_keys {
            let quoted = format!("\"{}\"", key);
            assert!(json.contains(&quoted));
        }
    }

    #[test]
    fn test_json_no_base_model_omitted() {
        let source = LoRAConfig::new(4, 4.0);
        let json = PeftAdapterConfig::from_lora_config(&source, None)
            .to_json()
            .expect("operation should succeed");

        // With no base model supplied, the key must not be emitted at all.
        assert!(!json.contains("base_model_name_or_path"));
    }

    #[test]
    fn test_builder_methods() {
        let source = LoRAConfig::new(8, 8.0).target_qv_projections();
        let base = PeftAdapterConfig::from_lora_config(&source, None);

        let peft = base
            .with_bias("lora_only")
            .with_dropout(0.1)
            .with_task_type("CAUSAL_LM")
            .with_inference_mode(true);

        assert_eq!(peft.bias, "lora_only");
        assert_eq!(peft.lora_dropout, 0.1);
        assert_eq!(peft.task_type.as_deref(), Some("CAUSAL_LM"));
        assert!(peft.inference_mode);
    }

    #[test]
    fn test_target_modules_sorted() {
        let source = LoRAConfig::new(8, 8.0).target_attention_projections();
        let peft = PeftAdapterConfig::from_lora_config(&source, None);

        // Deterministic output requires the names to be in non-decreasing order.
        let in_order = peft
            .target_modules
            .windows(2)
            .all(|pair| pair[0] <= pair[1]);
        assert!(in_order);
    }

    #[test]
    fn test_empty_target_modules() {
        let source = LoRAConfig::new(8, 8.0);
        let peft = PeftAdapterConfig::from_lora_config(&source, None);

        assert!(peft.target_modules.is_empty());
    }
}
201}