pub struct TransformerTrainConfig {
pub base: TrainConfig,
pub model_config: TransformerConfig,
pub checkpoint_config: CheckpointConfig,
pub precision_config: MixedPrecisionConfig,
pub max_seq_len: usize,
pub accumulation_steps: usize,
pub warmup_steps: usize,
pub lr: f32,
pub max_steps: Option<usize>,
pub use_cuda: bool,
pub beta1: f32,
pub beta2: f32,
pub weight_decay: f32,
pub distributed: Option<DistributedTrainConfig>,
pub deterministic: bool,
pub seed: u64,
pub profile_interval: usize,
pub lora_rank: Option<usize>,
pub lora_alpha: Option<f32>,
pub lora_target_modules: Option<Vec<String>>,
pub lora_plus_ratio: f32,
pub double_quantize: bool,
pub quantize_nf4: bool,
}
Configuration for transformer training.
Fields§
base: TrainConfig — Base training configuration.
model_config: TransformerConfig — Transformer architecture configuration.
checkpoint_config: CheckpointConfig — Checkpoint configuration for memory efficiency.
precision_config: MixedPrecisionConfig — Mixed-precision configuration.
max_seq_len: usize — Maximum sequence length.
accumulation_steps: usize — Number of steps for gradient accumulation.
warmup_steps: usize — Warmup steps for the learning rate scheduler.
lr: f32 — Learning rate.
max_steps: Option<usize> — Maximum training steps (stop after this many optimizer steps).
use_cuda: bool — Use CUDA GPU training when available (default: true = auto-detect).
beta1: f32 — AdamW beta1 (default: 0.9).
beta2: f32 — AdamW beta2 (default: 0.999).
weight_decay: f32 — AdamW weight decay (default: 0.01).
distributed: Option<DistributedTrainConfig> — Distributed training configuration (None = single-GPU).
deterministic: bool — Enable bitwise deterministic training (CUBLAS_WORKSPACE_CONFIG, cuDNN deterministic). Contract: C-DETERM-001.
seed: u64 — Random seed for reproducibility.
profile_interval: usize — KAIZEN-047: Step profiler report interval (0 = disabled, N = print every N steps).
lora_rank: Option<usize> — LoRA rank (None = full fine-tuning, Some(r) = LoRA with rank r).
lora_alpha: Option<f32> — LoRA alpha scaling factor (default: 2 * rank).
lora_target_modules: Option<Vec<String>> — LoRA target modules (e.g., q_proj, v_proj).
lora_plus_ratio: f32 — LoRA+ ratio: LR multiplier for B matrices (ENT-LoRA-006). Default 1.0 = standard LoRA; 16.0 = LoRA+ (Hayou et al., ICML 2024).
double_quantize: bool — Double quantization for QLoRA (ENT-LoRA-008). Quantizes FP32 absmax constants to 8-bit, saving ~0.37 bits/param.
quantize_nf4: bool — Quantize frozen base weights to NF4 (4-bit) for QLoRA pretraining (ENT-263).
When enabled with LoRA, uses CudaNf4TransformerBlock instead of
CudaTransformerBlock, achieving ~8x VRAM compression on frozen weights.
Only LoRA adapters and norm weights remain trainable in fp32.
Implementations§
Source§impl TransformerTrainConfig
impl TransformerTrainConfig
Sourcepub fn new(model_config: TransformerConfig) -> Self
pub fn new(model_config: TransformerConfig) -> Self
Create new config with defaults
Sourcepub fn with_checkpointing(self, num_segments: usize) -> Self
pub fn with_checkpointing(self, num_segments: usize) -> Self
Enable gradient checkpointing
Sourcepub fn with_max_seq_len(self, len: usize) -> Self
pub fn with_max_seq_len(self, len: usize) -> Self
Set maximum sequence length
Sourcepub fn with_accumulation_steps(self, steps: usize) -> Self
pub fn with_accumulation_steps(self, steps: usize) -> Self
Set gradient accumulation steps
Sourcepub fn with_warmup_steps(self, steps: usize) -> Self
pub fn with_warmup_steps(self, steps: usize) -> Self
Set warmup steps
Sourcepub fn with_grad_clip(self, clip: f32) -> Self
pub fn with_grad_clip(self, clip: f32) -> Self
Set gradient clipping
Sourcepub fn with_max_steps(self, steps: usize) -> Self
pub fn with_max_steps(self, steps: usize) -> Self
Set maximum training steps
Sourcepub fn with_use_cuda(self, use_cuda: bool) -> Self
pub fn with_use_cuda(self, use_cuda: bool) -> Self
Enable or disable CUDA GPU training (default: true = auto-detect)
Sourcepub fn with_beta2(self, beta2: f32) -> Self
pub fn with_beta2(self, beta2: f32) -> Self
Set AdamW beta2 (default: 0.999)
Sourcepub fn with_weight_decay(self, wd: f32) -> Self
pub fn with_weight_decay(self, wd: f32) -> Self
Set AdamW weight decay (default: 0.01)
Sourcepub fn with_deterministic(self, deterministic: bool) -> Self
pub fn with_deterministic(self, deterministic: bool) -> Self
Enable bitwise deterministic training (C-DETERM-001).
Sets CUBLAS_WORKSPACE_CONFIG, enables cuDNN deterministic mode, and disables cuDNN benchmark. May reduce throughput but guarantees reproducibility.
Sourcepub fn apply_deterministic_settings(&self)
pub fn apply_deterministic_settings(&self)
Apply deterministic settings to the CUDA environment.
Must be called before any cuBLAS/cuDNN operations.
Uses ReproducibilityConfig from finetune infrastructure.
§Contract (C-DETERM-001)
After calling this, CUBLAS_WORKSPACE_CONFIG=:4096:8 and
CUDNN_DETERMINISTIC=1 are guaranteed to be set in the process environment.
Sourcepub fn with_profile_interval(self, interval: usize) -> Self
pub fn with_profile_interval(self, interval: usize) -> Self
Set step profiler report interval (0 = disabled, N = print every N steps)
Sourcepub fn with_lora(
self,
rank: usize,
alpha: f32,
target_modules: Vec<String>,
) -> Self
pub fn with_lora( self, rank: usize, alpha: f32, target_modules: Vec<String>, ) -> Self
Enable LoRA fine-tuning with rank, alpha, and target modules
When LoRA is enabled, only LoRA adapter weights (A, B matrices) and layer norms are trainable. Base model weights are frozen.
§Contract (ENT-LoRA-001)
- Base weights frozen (requires_grad=false)
- Only LoRA A/B + norms are optimizer targets
- scale = alpha / rank
Sourcepub fn with_lora_plus_ratio(self, ratio: f32) -> Self
pub fn with_lora_plus_ratio(self, ratio: f32) -> Self
Set LoRA+ ratio (ENT-LoRA-006).
LR multiplier for B matrices. Default 1.0 = standard LoRA; 16.0 = LoRA+ (Hayou et al., ICML 2024), where B learns 16x faster than A.
Sourcepub fn with_double_quantize(self, enabled: bool) -> Self
pub fn with_double_quantize(self, enabled: bool) -> Self
Enable double quantization for QLoRA (ENT-LoRA-008)
Sourcepub fn with_quantize_nf4(self, enabled: bool) -> Self
pub fn with_quantize_nf4(self, enabled: bool) -> Self
Enable NF4 quantization for QLoRA pretraining (ENT-263)
When enabled with LoRA, frozen base weights are quantized to 4-bit NF4,
achieving ~8x VRAM compression. Only LoRA adapters and norm weights are
trainable. Requires lora_rank to be set.
Sourcepub fn with_distributed(self, config: DistributedTrainConfig) -> Self
pub fn with_distributed(self, config: DistributedTrainConfig) -> Self
Enable distributed training with the given configuration
Sourcepub fn is_distributed(&self) -> bool
pub fn is_distributed(&self) -> bool
Check if distributed training is enabled
Sourcepub fn world_size(&self) -> usize
pub fn world_size(&self) -> usize
Get world size (1 for single-GPU)
Trait Implementations§
Source§impl Clone for TransformerTrainConfig
impl Clone for TransformerTrainConfig
Source§fn clone(&self) -> TransformerTrainConfig
fn clone(&self) -> TransformerTrainConfig
1.0.0 (const: unstable) · Source§fn clone_from(&mut self, source: &Self)
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source. Read more
Auto Trait Implementations§
impl Freeze for TransformerTrainConfig
impl RefUnwindSafe for TransformerTrainConfig
impl Send for TransformerTrainConfig
impl Sync for TransformerTrainConfig
impl Unpin for TransformerTrainConfig
impl UnsafeUnpin for TransformerTrainConfig
impl UnwindSafe for TransformerTrainConfig
Blanket Implementations§
Source§impl<T> BorrowMut<T> for Twhere
T: ?Sized,
impl<T> BorrowMut<T> for Twhere
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> CloneToUninit for Twhere
T: Clone,
impl<T> CloneToUninit for Twhere
T: Clone,
Source§impl<T> FmtForward for T
impl<T> FmtForward for T
Source§fn fmt_binary(self) -> FmtBinary<Self>where
Self: Binary,
fn fmt_binary(self) -> FmtBinary<Self>where
Self: Binary,
Causes self to use its Binary implementation when Debug-formatted.
Source§fn fmt_display(self) -> FmtDisplay<Self>where
Self: Display,
fn fmt_display(self) -> FmtDisplay<Self>where
Self: Display,
Causes self to use its Display implementation when
Debug-formatted.
Source§fn fmt_lower_exp(self) -> FmtLowerExp<Self>where
Self: LowerExp,
fn fmt_lower_exp(self) -> FmtLowerExp<Self>where
Self: LowerExp,
Causes self to use its LowerExp implementation when
Debug-formatted.
Source§fn fmt_lower_hex(self) -> FmtLowerHex<Self>where
Self: LowerHex,
fn fmt_lower_hex(self) -> FmtLowerHex<Self>where
Self: LowerHex,
Causes self to use its LowerHex implementation when
Debug-formatted.
Source§fn fmt_octal(self) -> FmtOctal<Self>where
Self: Octal,
fn fmt_octal(self) -> FmtOctal<Self>where
Self: Octal,
Causes self to use its Octal implementation when Debug-formatted.
Source§fn fmt_pointer(self) -> FmtPointer<Self>where
Self: Pointer,
fn fmt_pointer(self) -> FmtPointer<Self>where
Self: Pointer,
Causes self to use its Pointer implementation when
Debug-formatted.
Source§fn fmt_upper_exp(self) -> FmtUpperExp<Self>where
Self: UpperExp,
fn fmt_upper_exp(self) -> FmtUpperExp<Self>where
Self: UpperExp,
Causes self to use its UpperExp implementation when
Debug-formatted.
Source§fn fmt_upper_hex(self) -> FmtUpperHex<Self>where
Self: UpperHex,
fn fmt_upper_hex(self) -> FmtUpperHex<Self>where
Self: UpperHex,
Causes self to use its UpperHex implementation when
Debug-formatted.
Source§impl<T> Instrument for T
impl<T> Instrument for T
Source§fn instrument(self, span: Span) -> Instrumented<Self>
fn instrument(self, span: Span) -> Instrumented<Self>
Source§fn in_current_span(self) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more
Source§impl<T> Pipe for Twhere
T: ?Sized,
impl<T> Pipe for Twhere
T: ?Sized,
Source§fn pipe<R>(self, func: impl FnOnce(Self) -> R) -> Rwhere
Self: Sized,
fn pipe<R>(self, func: impl FnOnce(Self) -> R) -> Rwhere
Self: Sized,
Source§fn pipe_ref<'a, R>(&'a self, func: impl FnOnce(&'a Self) -> R) -> Rwhere
R: 'a,
fn pipe_ref<'a, R>(&'a self, func: impl FnOnce(&'a Self) -> R) -> Rwhere
R: 'a,
Borrows self and passes that borrow into the pipe function. Read more
Source§fn pipe_ref_mut<'a, R>(&'a mut self, func: impl FnOnce(&'a mut Self) -> R) -> Rwhere
R: 'a,
fn pipe_ref_mut<'a, R>(&'a mut self, func: impl FnOnce(&'a mut Self) -> R) -> Rwhere
R: 'a,
Mutably borrows self and passes that borrow into the pipe function. Read more
Source§fn pipe_borrow<'a, B, R>(&'a self, func: impl FnOnce(&'a B) -> R) -> R
fn pipe_borrow<'a, B, R>(&'a self, func: impl FnOnce(&'a B) -> R) -> R
Source§fn pipe_borrow_mut<'a, B, R>(
&'a mut self,
func: impl FnOnce(&'a mut B) -> R,
) -> R
fn pipe_borrow_mut<'a, B, R>( &'a mut self, func: impl FnOnce(&'a mut B) -> R, ) -> R
Source§fn pipe_as_ref<'a, U, R>(&'a self, func: impl FnOnce(&'a U) -> R) -> R
fn pipe_as_ref<'a, U, R>(&'a self, func: impl FnOnce(&'a U) -> R) -> R
Borrows self, then passes self.as_ref() into the pipe function.
Source§fn pipe_as_mut<'a, U, R>(&'a mut self, func: impl FnOnce(&'a mut U) -> R) -> R
fn pipe_as_mut<'a, U, R>(&'a mut self, func: impl FnOnce(&'a mut U) -> R) -> R
Mutably borrows self, then passes self.as_mut() into the pipe
function.
Source§fn pipe_deref<'a, T, R>(&'a self, func: impl FnOnce(&'a T) -> R) -> R
fn pipe_deref<'a, T, R>(&'a self, func: impl FnOnce(&'a T) -> R) -> R
Borrows self, then passes self.deref() into the pipe function.
Source§impl<T> Pointable for T
impl<T> Pointable for T
Source§impl<T> PolicyExt for Twhere
T: ?Sized,
impl<T> PolicyExt for Twhere
T: ?Sized,
Source§impl<T> Tap for T
impl<T> Tap for T
Source§fn tap_borrow<B>(self, func: impl FnOnce(&B)) -> Self
fn tap_borrow<B>(self, func: impl FnOnce(&B)) -> Self
Immutable access to the Borrow<B> of a value. Read more
Source§fn tap_borrow_mut<B>(self, func: impl FnOnce(&mut B)) -> Self
fn tap_borrow_mut<B>(self, func: impl FnOnce(&mut B)) -> Self
Mutable access to the BorrowMut<B> of a value. Read more
Source§fn tap_ref<R>(self, func: impl FnOnce(&R)) -> Self
fn tap_ref<R>(self, func: impl FnOnce(&R)) -> Self
Immutable access to the AsRef<R> view of a value. Read more
Source§fn tap_ref_mut<R>(self, func: impl FnOnce(&mut R)) -> Self
fn tap_ref_mut<R>(self, func: impl FnOnce(&mut R)) -> Self
Mutable access to the AsMut<R> view of a value. Read more
Source§fn tap_deref<T>(self, func: impl FnOnce(&T)) -> Self
fn tap_deref<T>(self, func: impl FnOnce(&T)) -> Self
Immutable access to the Deref::Target of a value. Read more
Source§fn tap_deref_mut<T>(self, func: impl FnOnce(&mut T)) -> Self
fn tap_deref_mut<T>(self, func: impl FnOnce(&mut T)) -> Self
Mutable access to the Deref::Target of a value. Read more
Source§fn tap_dbg(self, func: impl FnOnce(&Self)) -> Self
fn tap_dbg(self, func: impl FnOnce(&Self)) -> Self
Calls .tap() only in debug builds, and is erased in release builds.
Source§fn tap_mut_dbg(self, func: impl FnOnce(&mut Self)) -> Self
fn tap_mut_dbg(self, func: impl FnOnce(&mut Self)) -> Self
Calls .tap_mut() only in debug builds, and is erased in release
builds.
Source§fn tap_borrow_dbg<B>(self, func: impl FnOnce(&B)) -> Self
fn tap_borrow_dbg<B>(self, func: impl FnOnce(&B)) -> Self
Calls .tap_borrow() only in debug builds, and is erased in release
builds.
Source§fn tap_borrow_mut_dbg<B>(self, func: impl FnOnce(&mut B)) -> Self
fn tap_borrow_mut_dbg<B>(self, func: impl FnOnce(&mut B)) -> Self
Calls .tap_borrow_mut() only in debug builds, and is erased in release
builds.
Source§fn tap_ref_dbg<R>(self, func: impl FnOnce(&R)) -> Self
fn tap_ref_dbg<R>(self, func: impl FnOnce(&R)) -> Self
Calls .tap_ref() only in debug builds, and is erased in release
builds.
Source§fn tap_ref_mut_dbg<R>(self, func: impl FnOnce(&mut R)) -> Self
fn tap_ref_mut_dbg<R>(self, func: impl FnOnce(&mut R)) -> Self
Calls .tap_ref_mut() only in debug builds, and is erased in release
builds.
Source§fn tap_deref_dbg<T>(self, func: impl FnOnce(&T)) -> Self
fn tap_deref_dbg<T>(self, func: impl FnOnce(&T)) -> Self
Calls .tap_deref() only in debug builds, and is erased in release
builds.