use std::fmt;
use serde::{Deserialize, Serialize};
use crate::requirement::AcceleratorRequirement;
/// Training / fine-tuning strategy used to size a memory estimate.
///
/// Marked `#[non_exhaustive]` so new methods can be added later without a
/// breaking change for downstream matchers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum TrainingMethod {
    /// Full-parameter fine-tuning; the default method.
    #[default]
    FullFineTune,
    /// LoRA (low-rank adapter) fine-tuning.
    LoRA,
    /// Quantized LoRA. `bits` is the quantization width; the cost models
    /// only distinguish `bits <= 4` from wider widths.
    QLoRA {
        bits: u8,
    },
    /// Prefix tuning.
    Prefix,
    /// DPO (presumably Direct Preference Optimization — name-based gloss).
    DPO,
    /// RLHF (presumably Reinforcement Learning from Human Feedback — name-based gloss).
    RLHF,
    /// Knowledge distillation (teacher + student; see the cost models' 1.5x model term).
    Distillation,
}
impl TrainingMethod {
    /// Returns the accelerator class this method is best matched to.
    ///
    /// Adapter-style methods (LoRA / QLoRA) map to GPU-only; every other
    /// method accepts either GPU or TPU hardware.
    #[must_use]
    #[inline]
    pub fn preferred_accelerator(&self) -> AcceleratorRequirement {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::FullFineTune
            | Self::Prefix
            | Self::DPO
            | Self::RLHF
            | Self::Distillation => AcceleratorRequirement::GpuOrTpu,
            Self::LoRA | Self::QLoRA { .. } => AcceleratorRequirement::Gpu,
        }
    }
}
impl fmt::Display for TrainingMethod {
    /// Writes the lowercase identifier for the method, e.g. `full`,
    /// `lora`, or `qlora-4bit` for `QLoRA { bits: 4 }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // QLoRA is the only variant that needs runtime formatting; handle it
        // up front so the rest collapses to a static label table.
        let label = match self {
            Self::QLoRA { bits } => return write!(f, "qlora-{}bit", bits),
            Self::FullFineTune => "full",
            Self::LoRA => "lora",
            Self::Prefix => "prefix",
            Self::DPO => "dpo",
            Self::RLHF => "rlhf",
            Self::Distillation => "distillation",
        };
        f.write_str(label)
    }
}
/// Hardware family a training run is being planned for.
///
/// `#[non_exhaustive]` so new accelerator families can be added without a
/// breaking change. Note: `Cpu` is routed through the GPU cost model in
/// `estimate_training_memory`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum TrainingTarget {
    /// Default target.
    #[default]
    Gpu,
    /// Google TPU.
    Tpu,
    /// Intel Gaudi.
    Gaudi,
    /// CPU-only; shares the GPU cost model.
    Cpu,
}
/// Itemized training-memory estimate produced by `estimate_training_memory`.
///
/// NOTE(review): fields are named `_gb`, but the values are computed by
/// dividing byte counts by `BYTES_PER_GIB`, so they are actually GiB.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEstimate {
    // Footprint of the model weights.
    pub model_gb: f64,
    // Footprint of optimizer state.
    pub optimizer_gb: f64,
    // Footprint of activations.
    pub activation_gb: f64,
    // Sum of the three components above (precomputed at construction).
    pub total_gb: f64,
}
impl fmt::Display for MemoryEstimate {
    /// Renders a one-line human-readable breakdown with one decimal place
    /// per component.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            model_gb,
            optimizer_gb,
            activation_gb,
            total_gb,
        } = self;
        write!(
            f,
            "Model: {:.1} GB, Optimizer: {:.1} GB, Activations: {:.1} GB (total: {:.1} GB)",
            model_gb, optimizer_gb, activation_gb, total_gb
        )
    }
}
/// Estimates training memory for a model of `model_params_millions` million
/// parameters, trained with `method` on `target` hardware.
///
/// The shared baseline is the fp16 weight footprint in GiB; each target's
/// helper then applies method-specific multipliers. `Cpu` deliberately
/// reuses the GPU cost model.
#[must_use]
pub fn estimate_training_memory(
    model_params_millions: u64,
    method: TrainingMethod,
    target: TrainingTarget,
) -> MemoryEstimate {
    // fp16 weight bytes, then converted to GiB (same evaluation order as
    // the multipliers below expect).
    let weight_bytes = model_params_millions as f64
        * crate::units::PARAMS_PER_MILLION
        * crate::units::FP16_BYTES_PER_PARAM;
    let base_gb = weight_bytes / crate::units::BYTES_PER_GIB;
    match target {
        TrainingTarget::Gpu | TrainingTarget::Cpu => estimate_gpu_training(base_gb, method),
        TrainingTarget::Tpu => estimate_tpu_training(base_gb, method),
        TrainingTarget::Gaudi => estimate_gaudi_training(base_gb, method),
    }
}
/// GPU (and CPU-fallback) cost model: method-specific multipliers applied
/// to the fp16 weight baseline `base_gb`.
fn estimate_gpu_training(base_gb: f64, method: TrainingMethod) -> MemoryEstimate {
    // Single construction point: the total is always the component sum.
    let build = |model_gb: f64, optimizer_gb: f64, activation_gb: f64| MemoryEstimate {
        model_gb,
        optimizer_gb,
        activation_gb,
        total_gb: model_gb + optimizer_gb + activation_gb,
    };
    match method {
        TrainingMethod::FullFineTune => build(base_gb, base_gb * 2.0, base_gb * 1.0),
        TrainingMethod::LoRA => build(base_gb, base_gb * 0.1, base_gb * 0.1),
        TrainingMethod::QLoRA { bits } => {
            // Quantization shrinks the frozen base weights (and, with them,
            // the activation term); 4-bit or less quantizes hardest.
            let quant = if bits <= 4 { 0.25 } else { 0.5 };
            build(base_gb * quant, base_gb * 0.1, base_gb * 0.1 * quant)
        }
        TrainingMethod::Prefix => build(base_gb, base_gb * 0.05, base_gb * 0.05),
        TrainingMethod::DPO | TrainingMethod::RLHF => {
            // Preference methods hold a reference model alongside the policy.
            build(base_gb * 2.0, base_gb * 2.0, base_gb * 1.5)
        }
        TrainingMethod::Distillation => build(base_gb * 1.5, base_gb * 1.0, base_gb * 0.8),
    }
}
/// TPU cost model: method-specific multipliers applied to the fp16 weight
/// baseline `base_gb`. Optimizer/activation factors differ from the GPU model.
fn estimate_tpu_training(base_gb: f64, method: TrainingMethod) -> MemoryEstimate {
    // Single construction point: the total is always the component sum.
    let build = |model_gb: f64, optimizer_gb: f64, activation_gb: f64| MemoryEstimate {
        model_gb,
        optimizer_gb,
        activation_gb,
        total_gb: model_gb + optimizer_gb + activation_gb,
    };
    match method {
        TrainingMethod::FullFineTune => build(base_gb, base_gb * 1.5, base_gb * 0.8),
        TrainingMethod::LoRA => build(base_gb, base_gb * 0.15, base_gb * 0.12),
        TrainingMethod::QLoRA { bits } => {
            // TPU quantization factors are less aggressive than the GPU model's.
            let quant = if bits <= 4 { 0.4 } else { 0.6 };
            build(base_gb * quant, base_gb * 0.15, base_gb * 0.12 * quant)
        }
        TrainingMethod::Prefix => build(base_gb, base_gb * 0.05, base_gb * 0.05),
        TrainingMethod::DPO | TrainingMethod::RLHF => {
            build(base_gb * 2.0, base_gb * 1.5, base_gb * 1.2)
        }
        TrainingMethod::Distillation => build(base_gb * 1.5, base_gb * 0.8, base_gb * 0.6),
    }
}
/// Gaudi cost model: method-specific multipliers applied to the fp16 weight
/// baseline `base_gb`. Factors sit between the GPU and TPU models.
fn estimate_gaudi_training(base_gb: f64, method: TrainingMethod) -> MemoryEstimate {
    // Single construction point: the total is always the component sum.
    let build = |model_gb: f64, optimizer_gb: f64, activation_gb: f64| MemoryEstimate {
        model_gb,
        optimizer_gb,
        activation_gb,
        total_gb: model_gb + optimizer_gb + activation_gb,
    };
    match method {
        TrainingMethod::FullFineTune => build(base_gb, base_gb * 1.5, base_gb * 0.9),
        TrainingMethod::LoRA => build(base_gb, base_gb * 0.12, base_gb * 0.12),
        TrainingMethod::QLoRA { bits } => {
            let quant = if bits <= 4 { 0.35 } else { 0.55 };
            build(base_gb * quant, base_gb * 0.12, base_gb * 0.12 * quant)
        }
        TrainingMethod::Prefix => build(base_gb, base_gb * 0.05, base_gb * 0.06),
        TrainingMethod::DPO | TrainingMethod::RLHF => {
            build(base_gb * 2.0, base_gb * 1.5, base_gb * 1.3)
        }
        TrainingMethod::Distillation => build(base_gb * 1.5, base_gb * 0.9, base_gb * 0.7),
    }
}