entrenar/hf_pipeline/config/
student.rs1use crate::lora::LoRAConfig;
4use serde::{Deserialize, Serialize};
5
6use super::teacher::default_revision;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct StudentConfig {
11 pub model_id: String,
13 #[serde(default = "default_revision")]
15 pub revision: String,
16 pub lora: Option<LoRAYamlConfig>,
18 #[serde(default)]
20 pub load_in_4bit: bool,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct LoRAYamlConfig {
26 pub rank: usize,
28 pub alpha: f32,
30 #[serde(default = "default_target_modules")]
32 pub target_modules: Vec<String>,
33 pub layers: Option<Vec<usize>>,
35}
36
/// Serde default for `LoRAYamlConfig::target_modules`.
///
/// Returns the two module names `"q_proj"` and `"v_proj"`, in that order.
fn default_target_modules() -> Vec<String> {
    ["q_proj", "v_proj"]
        .iter()
        .map(|name| (*name).to_owned())
        .collect()
}
40
41impl From<&LoRAYamlConfig> for LoRAConfig {
42 fn from(yaml: &LoRAYamlConfig) -> Self {
43 let mut config = LoRAConfig::new(yaml.rank, yaml.alpha);
44 let modules: Vec<&str> = yaml.target_modules.iter().map(String::as_str).collect();
45 config = config.target_modules(&modules);
46 if let Some(ref layers) = yaml.layers {
47 config = config.target_layers(layers);
48 }
49 config
50 }
51}