entrenar/lora/adapter/
peft_config.rs1use crate::lora::LoRAConfig;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14pub struct PeftAdapterConfig {
15 pub peft_type: String,
17 pub r: usize,
19 pub lora_alpha: f32,
21 pub target_modules: Vec<String>,
23 pub lora_dropout: f32,
25 pub bias: String,
27 #[serde(skip_serializing_if = "Option::is_none")]
29 pub base_model_name_or_path: Option<String>,
30 #[serde(skip_serializing_if = "Option::is_none")]
32 pub task_type: Option<String>,
33 #[serde(default)]
35 pub fan_in_fan_out: bool,
36 #[serde(default)]
38 pub inference_mode: bool,
39}
40
41impl PeftAdapterConfig {
42 pub fn from_lora_config(config: &LoRAConfig, base_model: Option<&str>) -> Self {
44 let mut target_modules: Vec<String> = config.target_modules.iter().cloned().collect();
45 target_modules.sort();
46
47 Self {
48 peft_type: "LORA".to_string(),
49 r: config.rank,
50 lora_alpha: config.alpha,
51 target_modules,
52 lora_dropout: 0.0,
53 bias: "none".to_string(),
54 base_model_name_or_path: base_model.map(String::from),
55 task_type: None,
56 fan_in_fan_out: false,
57 inference_mode: false,
58 }
59 }
60
61 pub fn with_bias(mut self, bias: impl Into<String>) -> Self {
63 self.bias = bias.into();
64 self
65 }
66
67 pub fn with_dropout(mut self, dropout: f32) -> Self {
69 self.lora_dropout = dropout;
70 self
71 }
72
73 pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
75 self.task_type = Some(task_type.into());
76 self
77 }
78
79 pub fn with_inference_mode(mut self, inference_mode: bool) -> Self {
81 self.inference_mode = inference_mode;
82 self
83 }
84
85 pub fn to_json(&self) -> Result<String, serde_json::Error> {
87 serde_json::to_string_pretty(self)
88 }
89
90 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
92 serde_json::from_str(json)
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Rank-16 / alpha-32 LoRA config with the attention projections targeted.
    fn make_test_lora_config() -> LoRAConfig {
        LoRAConfig::new(16, 32.0).target_attention_projections()
    }

    /// All fields derived from the LoRA config and the base model are populated.
    #[test]
    fn test_from_lora_config() {
        let base = "meta-llama/Llama-2-7b";
        let peft = PeftAdapterConfig::from_lora_config(&make_test_lora_config(), Some(base));

        assert_eq!(peft.peft_type, "LORA");
        assert_eq!(peft.r, 16);
        assert_eq!(peft.lora_alpha, 32.0);
        assert_eq!(peft.target_modules.len(), 4);
        for expected in ["q_proj", "k_proj", "v_proj", "o_proj"] {
            assert!(peft.target_modules.iter().any(|module| module == expected));
        }
        assert_eq!(peft.bias, "none");
        assert_eq!(peft.base_model_name_or_path.as_deref(), Some(base));
    }

    /// Passing `None` leaves the base model unset.
    #[test]
    fn test_from_lora_config_no_base_model() {
        let source = LoRAConfig::new(8, 8.0).target_qv_projections();
        let peft = PeftAdapterConfig::from_lora_config(&source, None);

        assert_eq!(peft.r, 8);
        assert!(peft.base_model_name_or_path.is_none());
        assert_eq!(peft.target_modules.len(), 2);
    }

    /// Serializing and then parsing yields an equal value.
    #[test]
    fn test_json_roundtrip() {
        let original =
            PeftAdapterConfig::from_lora_config(&make_test_lora_config(), Some("test/model"));

        let encoded = original.to_json().expect("operation should succeed");
        let decoded =
            PeftAdapterConfig::from_json(&encoded).expect("deserialization should succeed");

        assert_eq!(original, decoded);
    }

    /// The serialized form contains the expected keys.
    #[test]
    fn test_json_schema_keys() {
        let peft =
            PeftAdapterConfig::from_lora_config(&make_test_lora_config(), Some("test/model"));
        let json = peft.to_json().expect("operation should succeed");

        for key in [
            "peft_type",
            "r",
            "lora_alpha",
            "target_modules",
            "lora_dropout",
            "bias",
            "base_model_name_or_path",
        ] {
            // Look for the quoted key, exactly as it appears in the JSON text.
            assert!(json.contains(&format!("\"{key}\"")));
        }
    }

    /// A missing base model is left out of the JSON entirely.
    #[test]
    fn test_json_no_base_model_omitted() {
        let source = LoRAConfig::new(4, 4.0);
        let json = PeftAdapterConfig::from_lora_config(&source, None)
            .to_json()
            .expect("operation should succeed");

        assert!(!json.contains("base_model_name_or_path"));
    }

    /// Each builder overrides its corresponding field.
    #[test]
    fn test_builder_methods() {
        let source = LoRAConfig::new(8, 8.0).target_qv_projections();
        let base = PeftAdapterConfig::from_lora_config(&source, None);
        let peft = base
            .with_bias("lora_only")
            .with_dropout(0.1)
            .with_task_type("CAUSAL_LM")
            .with_inference_mode(true);

        assert_eq!(peft.bias, "lora_only");
        assert_eq!(peft.lora_dropout, 0.1);
        assert_eq!(peft.task_type.as_deref(), Some("CAUSAL_LM"));
        assert!(peft.inference_mode);
    }

    /// Target modules come out in ascending order.
    #[test]
    fn test_target_modules_sorted() {
        let source = LoRAConfig::new(8, 8.0).target_attention_projections();
        let peft = PeftAdapterConfig::from_lora_config(&source, None);

        // A list equals its sorted copy exactly when every adjacent pair is
        // already in non-decreasing order.
        assert!(peft.target_modules.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    /// A LoRA config without targets produces an empty module list.
    #[test]
    fn test_empty_target_modules() {
        let source = LoRAConfig::new(8, 8.0);
        let peft = PeftAdapterConfig::from_lora_config(&source, None);

        assert!(peft.target_modules.is_empty());
    }
}