Skip to main content

entrenar/hf_pipeline/config/
student.rs

1//! Student model configuration with LoRA/QLoRA
2
3use crate::lora::LoRAConfig;
4use serde::{Deserialize, Serialize};
5
6use super::teacher::default_revision;
7
8/// Student model configuration with LoRA/QLoRA
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct StudentConfig {
11    /// Model ID (can be same as teacher or smaller model)
12    pub model_id: String,
13    /// Revision/branch
14    #[serde(default = "default_revision")]
15    pub revision: String,
16    /// LoRA configuration (if None, full fine-tuning)
17    pub lora: Option<LoRAYamlConfig>,
18    /// Use 4-bit quantization (QLoRA)
19    #[serde(default)]
20    pub load_in_4bit: bool,
21}
22
23/// LoRA configuration in YAML format
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct LoRAYamlConfig {
26    /// LoRA rank
27    pub rank: usize,
28    /// LoRA alpha (scaling factor)
29    pub alpha: f32,
30    /// Target modules
31    #[serde(default = "default_target_modules")]
32    pub target_modules: Vec<String>,
33    /// Target layers (optional)
34    pub layers: Option<Vec<usize>>,
35}
36
/// Serde fallback for `LoRAYamlConfig::target_modules`.
///
/// Returns the owned module names `"q_proj"` and `"v_proj"`, in that order.
fn default_target_modules() -> Vec<String> {
    const DEFAULT_MODULES: [&str; 2] = ["q_proj", "v_proj"];
    DEFAULT_MODULES
        .iter()
        .map(|name| (*name).to_owned())
        .collect()
}
40
41impl From<&LoRAYamlConfig> for LoRAConfig {
42    fn from(yaml: &LoRAYamlConfig) -> Self {
43        let mut config = LoRAConfig::new(yaml.rank, yaml.alpha);
44        let modules: Vec<&str> = yaml.target_modules.iter().map(String::as_str).collect();
45        config = config.target_modules(&modules);
46        if let Some(ref layers) = yaml.layers {
47            config = config.target_layers(layers);
48        }
49        config
50    }
51}