Skip to main content

ai_hwaccel/
training.rs

1//! Training method types and memory estimation.
2
3use std::fmt;
4
5use serde::{Deserialize, Serialize};
6
7use crate::requirement::AcceleratorRequirement;
8
9/// Fine-tuning / training method.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
11#[non_exhaustive]
12pub enum TrainingMethod {
13    #[default]
14    FullFineTune,
15    LoRA,
16    QLoRA {
17        bits: u8,
18    },
19    Prefix,
20    DPO,
21    RLHF,
22    Distillation,
23}
24
25impl TrainingMethod {
26    /// Preferred accelerator requirement for this training method.
27    #[must_use]
28    #[inline]
29    pub fn preferred_accelerator(&self) -> AcceleratorRequirement {
30        match self {
31            // LoRA/QLoRA benefit from custom CUDA kernels
32            Self::LoRA | Self::QLoRA { .. } => AcceleratorRequirement::Gpu,
33            // Full fine-tune, DPO, RLHF, distillation work well on GPU or TPU
34            Self::FullFineTune | Self::DPO | Self::RLHF | Self::Distillation | Self::Prefix => {
35                AcceleratorRequirement::GpuOrTpu
36            }
37        }
38    }
39}
40
41impl fmt::Display for TrainingMethod {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        match self {
44            Self::FullFineTune => write!(f, "full"),
45            Self::LoRA => write!(f, "lora"),
46            Self::QLoRA { bits } => write!(f, "qlora-{}bit", bits),
47            Self::Prefix => write!(f, "prefix"),
48            Self::DPO => write!(f, "dpo"),
49            Self::RLHF => write!(f, "rlhf"),
50            Self::Distillation => write!(f, "distillation"),
51        }
52    }
53}
54
/// Target accelerator family for training memory estimation.
///
/// Selects which heuristic table `estimate_training_memory` applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub enum TrainingTarget {
    /// GPU. This is the default target.
    #[default]
    Gpu,
    /// Tensor Processing Unit.
    Tpu,
    /// Gaudi-family accelerator.
    Gaudi,
    /// CPU. There is no CPU-specific table yet; estimation reuses the GPU heuristics.
    Cpu,
}
65
66/// Estimated device memory breakdown for a training/fine-tuning job.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct MemoryEstimate {
69    /// Model weights in GB.
70    pub model_gb: f64,
71    /// Optimizer states in GB.
72    pub optimizer_gb: f64,
73    /// Activations / KV cache in GB.
74    pub activation_gb: f64,
75    /// Total device memory needed in GB.
76    pub total_gb: f64,
77}
78
79impl fmt::Display for MemoryEstimate {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        write!(
82            f,
83            "Model: {:.1} GB, Optimizer: {:.1} GB, Activations: {:.1} GB (total: {:.1} GB)",
84            self.model_gb, self.optimizer_gb, self.activation_gb, self.total_gb
85        )
86    }
87}
88
89/// Estimate device memory for a fine-tuning job on a specific accelerator family.
90///
91/// This is a heuristic approximation assuming:
92/// - FP16/BF16 model weights (2 bytes per parameter)
93/// - Adam optimizer states (2x model weights for full fine-tune)
94/// - Batch size 1, moderate sequence length
95///
96/// Real memory usage depends on batch size, sequence length, gradient
97/// checkpointing, and precision mixing. Use these estimates as lower-bound
98/// guidance, not exact predictions.
99///
100/// # Examples
101///
102/// ```rust
103/// use ai_hwaccel::{estimate_training_memory, TrainingMethod, TrainingTarget};
104///
105/// let est = estimate_training_memory(7000, TrainingMethod::LoRA, TrainingTarget::Gpu);
106/// assert!(est.total_gb > 0.0);
107/// assert!((est.model_gb + est.optimizer_gb + est.activation_gb - est.total_gb).abs() < 0.001);
108/// ```
109#[must_use]
110pub fn estimate_training_memory(
111    model_params_millions: u64,
112    method: TrainingMethod,
113    target: TrainingTarget,
114) -> MemoryEstimate {
115    let base_gb = (model_params_millions as f64
116        * crate::units::PARAMS_PER_MILLION
117        * crate::units::FP16_BYTES_PER_PARAM)
118        / crate::units::BYTES_PER_GIB;
119
120    match target {
121        TrainingTarget::Tpu => estimate_tpu_training(base_gb, method),
122        TrainingTarget::Gaudi => estimate_gaudi_training(base_gb, method),
123        TrainingTarget::Gpu | TrainingTarget::Cpu => estimate_gpu_training(base_gb, method),
124    }
125}
126
127fn estimate_gpu_training(base_gb: f64, method: TrainingMethod) -> MemoryEstimate {
128    let (model, optimizer, activation) = match method {
129        TrainingMethod::FullFineTune => (base_gb, base_gb * 2.0, base_gb * 1.0),
130        TrainingMethod::LoRA => (base_gb, base_gb * 0.1, base_gb * 0.1),
131        TrainingMethod::QLoRA { bits } => {
132            let qf = if bits <= 4 { 0.25 } else { 0.5 };
133            (base_gb * qf, base_gb * 0.1, base_gb * 0.1 * qf)
134        }
135        TrainingMethod::Prefix => (base_gb, base_gb * 0.05, base_gb * 0.05),
136        TrainingMethod::DPO | TrainingMethod::RLHF => {
137            // DPO/RLHF: two model copies + optimizer
138            (base_gb * 2.0, base_gb * 2.0, base_gb * 1.5)
139        }
140        TrainingMethod::Distillation => {
141            // Teacher + student
142            (base_gb * 1.5, base_gb * 1.0, base_gb * 0.8)
143        }
144    };
145    MemoryEstimate {
146        model_gb: model,
147        optimizer_gb: optimizer,
148        activation_gb: activation,
149        total_gb: model + optimizer + activation,
150    }
151}
152
153fn estimate_tpu_training(base_gb: f64, method: TrainingMethod) -> MemoryEstimate {
154    // TPU: BF16 native, XLA fuses activations, BF16 optimizer states (1.5x not 2x)
155    let (model, optimizer, activation) = match method {
156        TrainingMethod::FullFineTune => (base_gb, base_gb * 1.5, base_gb * 0.8),
157        TrainingMethod::LoRA => (base_gb, base_gb * 0.15, base_gb * 0.12),
158        TrainingMethod::QLoRA { bits } => {
159            let qf = if bits <= 4 { 0.4 } else { 0.6 };
160            (base_gb * qf, base_gb * 0.15, base_gb * 0.12 * qf)
161        }
162        TrainingMethod::Prefix => (base_gb, base_gb * 0.05, base_gb * 0.05),
163        TrainingMethod::DPO | TrainingMethod::RLHF => (base_gb * 2.0, base_gb * 1.5, base_gb * 1.2),
164        TrainingMethod::Distillation => (base_gb * 1.5, base_gb * 0.8, base_gb * 0.6),
165    };
166    MemoryEstimate {
167        model_gb: model,
168        optimizer_gb: optimizer,
169        activation_gb: activation,
170        total_gb: model + optimizer + activation,
171    }
172}
173
174fn estimate_gaudi_training(base_gb: f64, method: TrainingMethod) -> MemoryEstimate {
175    // Gaudi: BF16 native like TPU, but with different memory controller.
176    // Similar to TPU estimates but with slightly higher activation overhead.
177    let (model, optimizer, activation) = match method {
178        TrainingMethod::FullFineTune => (base_gb, base_gb * 1.5, base_gb * 0.9),
179        TrainingMethod::LoRA => (base_gb, base_gb * 0.12, base_gb * 0.12),
180        TrainingMethod::QLoRA { bits } => {
181            let qf = if bits <= 4 { 0.35 } else { 0.55 };
182            (base_gb * qf, base_gb * 0.12, base_gb * 0.12 * qf)
183        }
184        TrainingMethod::Prefix => (base_gb, base_gb * 0.05, base_gb * 0.06),
185        TrainingMethod::DPO | TrainingMethod::RLHF => (base_gb * 2.0, base_gb * 1.5, base_gb * 1.3),
186        TrainingMethod::Distillation => (base_gb * 1.5, base_gb * 0.9, base_gb * 0.7),
187    };
188    MemoryEstimate {
189        model_gb: model,
190        optimizer_gb: optimizer,
191        activation_gb: activation,
192        total_gb: model + optimizer + activation,
193    }
194}