use crate::lora::LoRAConfig;
use serde::{Deserialize, Serialize};
use super::teacher::default_revision;
/// Serializable settings for the student model: which model to load, at
/// which revision, and optional LoRA / 4-bit loading options.
///
/// NOTE(review): presumably the "student" half of a teacher/student setup,
/// given the sibling `teacher` module — confirm against the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StudentConfig {
/// Identifier of the model to use. Required: deserialization fails when the
/// key is missing, since no serde default is declared for it.
pub model_id: String,
/// Model revision. When the key is absent in the input, the value is
/// produced by `default_revision` from the sibling `teacher` module.
#[serde(default = "default_revision")]
pub revision: String,
/// Optional LoRA settings; `None` means no LoRA section was provided.
pub lora: Option<LoRAYamlConfig>,
/// 4-bit loading flag. Falls back to `false` (`bool::default()`) when the
/// key is absent in the input.
/// NOTE(review): presumably requests quantized weight loading — confirm
/// with the code that consumes this flag.
#[serde(default)]
pub load_in_4bit: bool,
}
/// Serde-facing description of a LoRA setup. It is converted into the runtime
/// `LoRAConfig` through the `From<&LoRAYamlConfig>` implementation in this
/// file.
///
/// NOTE(review): the name suggests it is read from a YAML file, but nothing
/// here ties it to a specific serde format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoRAYamlConfig {
/// Rank value; passed as the first argument to `LoRAConfig::new`.
pub rank: usize,
/// Alpha value; passed as the second argument to `LoRAConfig::new`.
pub alpha: f32,
/// Names of the modules to target. When the key is absent in the input,
/// `default_target_modules` supplies `["q_proj", "v_proj"]`.
#[serde(default = "default_target_modules")]
pub target_modules: Vec<String>,
/// Optional list of layer indices. When `Some`, it is forwarded to
/// `LoRAConfig::target_layers`; when `None`, that call is skipped and the
/// `LoRAConfig` keeps whatever layer selection it already has.
pub layers: Option<Vec<usize>>,
}
/// Serde fallback for `LoRAYamlConfig::target_modules`.
///
/// Returns the owned module names `"q_proj"` and `"v_proj"`, in that order.
fn default_target_modules() -> Vec<String> {
    const DEFAULT_MODULES: [&str; 2] = ["q_proj", "v_proj"];
    DEFAULT_MODULES
        .iter()
        .map(|name| (*name).to_owned())
        .collect()
}
/// Builds a runtime [`LoRAConfig`] from its serde-facing counterpart.
impl From<&LoRAYamlConfig> for LoRAConfig {
    /// Creates a `LoRAConfig` from `rank` and `alpha`, applies the target
    /// module names, and applies the layer list only when one was given.
    fn from(yaml: &LoRAYamlConfig) -> Self {
        // The builder takes borrowed names, so view the owned strings as
        // `&str` for the duration of the call.
        let mut built = LoRAConfig::new(yaml.rank, yaml.alpha).target_modules(
            &yaml
                .target_modules
                .iter()
                .map(|name| name.as_str())
                .collect::<Vec<&str>>(),
        );

        // Without an explicit layer list, `target_layers` is never called.
        if let Some(layer_ids) = &yaml.layers {
            built = built.target_layers(layer_ids);
        }

        built
    }
}