1use std::fmt;
4
5use serde::{Deserialize, Serialize};
6
7use crate::requirement::AcceleratorRequirement;
8
/// Fine-tuning / post-training strategies recognized by the memory estimator.
///
/// `#[non_exhaustive]` so new methods can be added without a semver break;
/// `FullFineTune` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum TrainingMethod {
    /// Update every model parameter.
    #[default]
    FullFineTune,
    /// Low-rank adapter training (base weights frozen).
    LoRA,
    /// LoRA on top of a quantized base model.
    QLoRA {
        /// Quantization width of the frozen base weights; the estimators
        /// only distinguish `<= 4` bits from wider widths.
        bits: u8,
    },
    /// Prefix tuning.
    Prefix,
    /// Direct preference optimization.
    DPO,
    /// Reinforcement learning from human feedback.
    RLHF,
    /// Distillation from a teacher model.
    Distillation,
}
24
25impl TrainingMethod {
26 #[must_use]
28 #[inline]
29 pub fn preferred_accelerator(&self) -> AcceleratorRequirement {
30 match self {
31 Self::LoRA | Self::QLoRA { .. } => AcceleratorRequirement::Gpu,
33 Self::FullFineTune | Self::DPO | Self::RLHF | Self::Distillation | Self::Prefix => {
35 AcceleratorRequirement::GpuOrTpu
36 }
37 }
38 }
39}
40
41impl fmt::Display for TrainingMethod {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 match self {
44 Self::FullFineTune => write!(f, "full"),
45 Self::LoRA => write!(f, "lora"),
46 Self::QLoRA { bits } => write!(f, "qlora-{}bit", bits),
47 Self::Prefix => write!(f, "prefix"),
48 Self::DPO => write!(f, "dpo"),
49 Self::RLHF => write!(f, "rlhf"),
50 Self::Distillation => write!(f, "distillation"),
51 }
52 }
53}
54
/// Hardware family a training run is planned for.
///
/// NOTE(review): unlike `TrainingMethod`, this enum does not derive
/// `Serialize`/`Deserialize` — confirm that asymmetry is intentional.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum TrainingTarget {
    /// Discrete GPU (default target).
    #[default]
    Gpu,
    /// TPU.
    Tpu,
    /// Intel Gaudi accelerator.
    Gaudi,
    /// CPU-only; estimated with the GPU cost model (see `estimate_training_memory`).
    Cpu,
}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct MemoryEstimate {
69 pub model_gb: f64,
71 pub optimizer_gb: f64,
73 pub activation_gb: f64,
75 pub total_gb: f64,
77}
78
79impl fmt::Display for MemoryEstimate {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 write!(
82 f,
83 "Model: {:.1} GB, Optimizer: {:.1} GB, Activations: {:.1} GB (total: {:.1} GB)",
84 self.model_gb, self.optimizer_gb, self.activation_gb, self.total_gb
85 )
86 }
87}
88
89#[must_use]
110pub fn estimate_training_memory(
111 model_params_millions: u64,
112 method: TrainingMethod,
113 target: TrainingTarget,
114) -> MemoryEstimate {
115 let base_gb = (model_params_millions as f64
116 * crate::units::PARAMS_PER_MILLION
117 * crate::units::FP16_BYTES_PER_PARAM)
118 / crate::units::BYTES_PER_GIB;
119
120 match target {
121 TrainingTarget::Tpu => estimate_tpu_training(base_gb, method),
122 TrainingTarget::Gaudi => estimate_gaudi_training(base_gb, method),
123 TrainingTarget::Gpu | TrainingTarget::Cpu => estimate_gpu_training(base_gb, method),
124 }
125}
126
127fn estimate_gpu_training(base_gb: f64, method: TrainingMethod) -> MemoryEstimate {
128 let (model, optimizer, activation) = match method {
129 TrainingMethod::FullFineTune => (base_gb, base_gb * 2.0, base_gb * 1.0),
130 TrainingMethod::LoRA => (base_gb, base_gb * 0.1, base_gb * 0.1),
131 TrainingMethod::QLoRA { bits } => {
132 let qf = if bits <= 4 { 0.25 } else { 0.5 };
133 (base_gb * qf, base_gb * 0.1, base_gb * 0.1 * qf)
134 }
135 TrainingMethod::Prefix => (base_gb, base_gb * 0.05, base_gb * 0.05),
136 TrainingMethod::DPO | TrainingMethod::RLHF => {
137 (base_gb * 2.0, base_gb * 2.0, base_gb * 1.5)
139 }
140 TrainingMethod::Distillation => {
141 (base_gb * 1.5, base_gb * 1.0, base_gb * 0.8)
143 }
144 };
145 MemoryEstimate {
146 model_gb: model,
147 optimizer_gb: optimizer,
148 activation_gb: activation,
149 total_gb: model + optimizer + activation,
150 }
151}
152
153fn estimate_tpu_training(base_gb: f64, method: TrainingMethod) -> MemoryEstimate {
154 let (model, optimizer, activation) = match method {
156 TrainingMethod::FullFineTune => (base_gb, base_gb * 1.5, base_gb * 0.8),
157 TrainingMethod::LoRA => (base_gb, base_gb * 0.15, base_gb * 0.12),
158 TrainingMethod::QLoRA { bits } => {
159 let qf = if bits <= 4 { 0.4 } else { 0.6 };
160 (base_gb * qf, base_gb * 0.15, base_gb * 0.12 * qf)
161 }
162 TrainingMethod::Prefix => (base_gb, base_gb * 0.05, base_gb * 0.05),
163 TrainingMethod::DPO | TrainingMethod::RLHF => (base_gb * 2.0, base_gb * 1.5, base_gb * 1.2),
164 TrainingMethod::Distillation => (base_gb * 1.5, base_gb * 0.8, base_gb * 0.6),
165 };
166 MemoryEstimate {
167 model_gb: model,
168 optimizer_gb: optimizer,
169 activation_gb: activation,
170 total_gb: model + optimizer + activation,
171 }
172}
173
174fn estimate_gaudi_training(base_gb: f64, method: TrainingMethod) -> MemoryEstimate {
175 let (model, optimizer, activation) = match method {
178 TrainingMethod::FullFineTune => (base_gb, base_gb * 1.5, base_gb * 0.9),
179 TrainingMethod::LoRA => (base_gb, base_gb * 0.12, base_gb * 0.12),
180 TrainingMethod::QLoRA { bits } => {
181 let qf = if bits <= 4 { 0.35 } else { 0.55 };
182 (base_gb * qf, base_gb * 0.12, base_gb * 0.12 * qf)
183 }
184 TrainingMethod::Prefix => (base_gb, base_gb * 0.05, base_gb * 0.06),
185 TrainingMethod::DPO | TrainingMethod::RLHF => (base_gb * 2.0, base_gb * 1.5, base_gb * 1.3),
186 TrainingMethod::Distillation => (base_gb * 1.5, base_gb * 0.9, base_gb * 0.7),
187 };
188 MemoryEstimate {
189 model_gb: model,
190 optimizer_gb: optimizer,
191 activation_gb: activation,
192 total_gb: model + optimizer + activation,
193 }
194}