use crate::{Device, Dtype, Quantization};
use serde::{Deserialize, Serialize};
/// Configuration describing which model checkpoint to load and how.
///
/// Only `model_id` is required; every other field has a serde default, so a
/// minimal config file needs just the id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
/// Identifier of the checkpoint to load (no default; must be provided).
/// Presumably a Hugging Face repo id or local path — confirm in the loader.
pub model_id: String,
/// Numeric precision for weights (falls back to `Dtype::default()`).
#[serde(default)]
pub dtype: Dtype,
/// Quantization scheme (falls back to `Quantization::default()`).
#[serde(default)]
pub quantization: Quantization,
/// Compute device (falls back to `Device::default()`).
#[serde(default)]
pub device: Device,
/// Maximum sequence length (default 8192).
#[serde(default = "default_max_seq_len")]
pub max_seq_len: usize,
/// Whether to use flash-attention kernels (default true).
#[serde(default = "default_true")]
pub use_flash_attention: bool,
/// Default false. Presumably gates execution of code shipped with the
/// checkpoint (HF-style `trust_remote_code`) — confirm against the loader.
#[serde(default)]
pub trust_remote_code: bool,
/// Optional checkpoint revision — presumably a hub branch/tag/commit;
/// confirm against the loader.
#[serde(default)]
pub revision: Option<String>,
/// Optional Hugging Face access token.
/// NOTE(review): the derived `Debug` and `Serialize` impls will expose this
/// secret in logs and config dumps — consider redacting or skipping it.
#[serde(default)]
pub hf_token: Option<String>,
}
impl Default for ModelConfig {
fn default() -> Self {
Self {
model_id: String::new(),
dtype: Dtype::default(),
quantization: Quantization::default(),
device: Device::default(),
max_seq_len: default_max_seq_len(),
use_flash_attention: true,
trust_remote_code: false,
revision: None,
hf_token: None,
}
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LoraBias {
#[default]
None,
All,
LoraOnly,
}
/// LoRA adapter hyperparameters.
///
/// Every field has a serde default, so `{}` deserializes to the same values
/// as `LoraConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraConfig {
/// Adapter rank (default 16).
#[serde(default = "default_lora_r")]
pub r: usize,
/// Scaling numerator; see `scaling()` (default 32.0).
#[serde(default = "default_lora_alpha")]
pub alpha: f32,
/// Dropout probability (default 0.0) — exact placement lives in the model
/// code, not visible here.
#[serde(default)]
pub dropout: f32,
/// Names of modules to adapt (default: the q/k/v/o attention projections).
#[serde(default = "default_target_modules")]
pub target_modules: Vec<String>,
/// Use rank-stabilized scaling `alpha / sqrt(r)` instead of `alpha / r`
/// (default false); see `scaling()`.
#[serde(default)]
pub use_rslora: bool,
/// Enable DoRA — semantics live in the model code (default false).
#[serde(default)]
pub use_dora: bool,
/// Which bias parameters to train (default: `LoraBias::None`).
#[serde(default)]
pub bias: LoraBias,
/// Whether to initialize adapter weights (default true) — the actual init
/// scheme is defined elsewhere; confirm.
#[serde(default = "default_true")]
pub init_lora_weights: bool,
}
impl Default for LoraConfig {
fn default() -> Self {
Self {
r: default_lora_r(),
alpha: default_lora_alpha(),
dropout: 0.0,
target_modules: default_target_modules(),
use_rslora: false,
use_dora: false,
bias: LoraBias::default(),
init_lora_weights: true,
}
}
}
impl LoraConfig {
    /// Scaling factor applied to the adapter output.
    ///
    /// Standard LoRA scales by `alpha / r`; with `use_rslora` the divisor
    /// becomes `sqrt(r)` so the effective scale stays stable as the rank
    /// grows.
    #[must_use]
    pub fn scaling(&self) -> f32 {
        let rank = self.r as f32;
        let divisor = if self.use_rslora { rank.sqrt() } else { rank };
        self.alpha / divisor
    }
}
/// Hyperparameters for a training run.
///
/// Every field has a serde default, so `{}` deserializes to the same values
/// as `TrainingConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
/// Peak learning rate (default 2e-4).
#[serde(default = "default_lr")]
pub learning_rate: f64,
/// Optional separate learning rate for embeddings — presumably falls back
/// to `learning_rate` when `None`; confirm in the trainer.
#[serde(default)]
pub embedding_learning_rate: Option<f64>,
/// Micro-batch size (default 1).
#[serde(default = "default_batch_size")]
pub batch_size: usize,
/// Micro-batches accumulated per optimizer step (default 4).
#[serde(default = "default_gradient_accumulation_steps")]
pub gradient_accumulation_steps: usize,
/// Number of passes over the dataset (default 3).
#[serde(default = "default_epochs")]
pub num_epochs: usize,
/// Optional hard cap on steps — precedence over `num_epochs` is decided by
/// the trainer, not here.
#[serde(default)]
pub max_steps: Option<usize>,
/// Warmup length in steps (default 100).
#[serde(default = "default_warmup")]
pub warmup_steps: usize,
/// Warmup as a fraction of total steps; interaction with `warmup_steps`
/// when both are set lives in the trainer — confirm.
#[serde(default)]
pub warmup_ratio: Option<f64>,
/// Weight-decay coefficient (default 0.01).
#[serde(default = "default_weight_decay")]
pub weight_decay: f64,
/// Gradient-norm clipping threshold (default 1.0).
#[serde(default = "default_grad_clip")]
pub max_grad_norm: f64,
/// Learning-rate schedule (default: `Cosine`).
#[serde(default)]
pub lr_scheduler: LrSchedulerType,
/// Gradient-checkpointing strategy (default: `None`).
#[serde(default)]
pub gradient_checkpointing: CheckpointStrategy,
/// Optimizer (default: `AdamW`).
#[serde(default)]
pub optimizer: OptimizerType,
/// RNG seed (default 42).
#[serde(default = "default_seed")]
pub seed: u64,
/// Log metrics every N steps (default 10).
#[serde(default = "default_logging_steps")]
pub logging_steps: usize,
/// Evaluate every N steps; `None` presumably disables periodic eval.
#[serde(default)]
pub eval_steps: Option<usize>,
/// Checkpoint every N steps; `None` presumably disables periodic saves.
#[serde(default)]
pub save_steps: Option<usize>,
/// Output directory for artifacts (default "./output").
#[serde(default = "default_output_dir")]
pub output_dir: String,
/// Default true — presumably sequence packing of short examples; confirm
/// in the data pipeline.
#[serde(default = "default_true")]
pub use_packing: bool,
/// Maximum sequence length for training (default 8192). Also present on
/// `ModelConfig` — confirm which one the trainer honors.
#[serde(default = "default_max_seq_len")]
pub max_seq_len: usize,
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
learning_rate: default_lr(),
embedding_learning_rate: None,
batch_size: default_batch_size(),
gradient_accumulation_steps: default_gradient_accumulation_steps(),
num_epochs: default_epochs(),
max_steps: None,
warmup_steps: default_warmup(),
warmup_ratio: None,
weight_decay: default_weight_decay(),
max_grad_norm: default_grad_clip(),
lr_scheduler: LrSchedulerType::default(),
gradient_checkpointing: CheckpointStrategy::default(),
optimizer: OptimizerType::default(),
seed: default_seed(),
logging_steps: default_logging_steps(),
eval_steps: None,
save_steps: None,
output_dir: default_output_dir(),
use_packing: true,
max_seq_len: default_max_seq_len(),
}
}
}
/// Learning-rate schedule shapes; serialized in snake_case
/// (e.g. `"cosine_with_restarts"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum LrSchedulerType {
/// Flat learning rate.
Constant,
/// Linear decay.
Linear,
/// Cosine decay (the default).
#[default]
Cosine,
/// Cosine decay with restarts.
CosineWithRestarts,
/// Polynomial decay — exponent is decided by the trainer, not visible here.
Polynomial,
}
/// Gradient-checkpointing strategy; the actual recomputation logic lives in
/// the trainer/model code, not here.
///
/// Adds `Copy`, `PartialEq`, `Eq` for consistency with the sibling
/// `LrSchedulerType` and `OptimizerType` enums (backward-compatible: adding
/// trait impls cannot break existing callers).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum CheckpointStrategy {
    /// No checkpointing (the default). Serializes as `"none"`.
    #[default]
    None,
    /// Presumably: checkpoint every N layers — confirm in the trainer.
    /// Externally tagged, so it serializes as `{"every_n": N}`.
    EveryN(usize),
    /// Heuristic placement — semantics defined elsewhere.
    Smart,
    /// Selective checkpointing of attention — semantics defined elsewhere.
    SelectiveAttention,
}
/// Optimizer selection for training.
///
/// NOTE(review): under `rename_all = "snake_case"`, `AdamW` serializes as
/// `"adam_w"`, not the ecosystem-common `"adamw"` — confirm this matches the
/// config files users actually write (a `#[serde(alias = "adamw")]` may be
/// warranted).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum OptimizerType {
/// AdamW (the default).
#[default]
AdamW,
/// Stochastic gradient descent.
Sgd,
/// Adafactor.
Adafactor,
/// Lion.
Lion,
}
/// Source dataset selection and preprocessing options.
///
/// Only `dataset_id` is required; every other field has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetConfig {
/// Dataset identifier (no default; must be provided). Presumably a hub
/// dataset id or local path — confirm in the data-loading code.
pub dataset_id: String,
/// Split to load (default "train").
#[serde(default = "default_split")]
pub split: String,
/// Column containing the training text (default "text").
#[serde(default = "default_text_column")]
pub text_column: String,
/// Optional cap on sample count; `None` presumably uses the full split.
#[serde(default)]
pub max_samples: Option<usize>,
/// Shuffle before training (default true).
#[serde(default = "default_true")]
pub shuffle: bool,
/// Seed used for shuffling (default 42).
#[serde(default = "default_seed")]
pub seed: u64,
}
impl Default for DatasetConfig {
fn default() -> Self {
Self {
dataset_id: String::new(),
split: default_split(),
text_column: default_text_column(),
max_samples: None,
shuffle: true,
seed: default_seed(),
}
}
}
// ---------------------------------------------------------------------------
// serde default helpers.
//
// `#[serde(default = "...")]` takes a path to a nullary function, so these
// defaults must be free functions rather than consts. The `Default` impls
// above reuse them so both default paths agree.
// ---------------------------------------------------------------------------

/// Default context length used by the model and training configs.
fn default_max_seq_len() -> usize { 8192 }

/// Default for boolean fields that start enabled.
fn default_true() -> bool { true }

/// Default LoRA rank.
fn default_lora_r() -> usize { 16 }

/// Default LoRA alpha (scaling numerator).
fn default_lora_alpha() -> f32 { 32.0 }

/// Default modules to adapt: the attention q/k/v/o projections.
fn default_target_modules() -> Vec<String> {
    ["q_proj", "k_proj", "v_proj", "o_proj"]
        .into_iter()
        .map(String::from)
        .collect()
}

/// Default peak learning rate.
fn default_lr() -> f64 { 2e-4 }

/// Default micro-batch size.
fn default_batch_size() -> usize { 1 }

/// Default number of gradient-accumulation steps.
fn default_gradient_accumulation_steps() -> usize { 4 }

/// Default number of training epochs.
fn default_epochs() -> usize { 3 }

/// Default warmup length in steps.
fn default_warmup() -> usize { 100 }

/// Default weight-decay coefficient.
fn default_weight_decay() -> f64 { 0.01 }

/// Default gradient-clipping norm.
fn default_grad_clip() -> f64 { 1.0 }

/// Default RNG seed.
fn default_seed() -> u64 { 42 }

/// Default metric-logging interval in steps.
fn default_logging_steps() -> usize { 10 }

/// Default output directory for training artifacts.
fn default_output_dir() -> String { "./output".into() }

/// Default dataset split name.
fn default_split() -> String { "train".into() }

/// Default dataset text-column name.
fn default_text_column() -> String { "text".into() }