Skip to main content

TransformerTrainConfig

Struct TransformerTrainConfig 

Source
pub struct TransformerTrainConfig {
    pub base: TrainConfig,
    pub model_config: TransformerConfig,
    pub checkpoint_config: CheckpointConfig,
    pub precision_config: MixedPrecisionConfig,
    pub max_seq_len: usize,
    pub accumulation_steps: usize,
    pub warmup_steps: usize,
    pub lr: f32,
    pub max_steps: Option<usize>,
    pub use_cuda: bool,
    pub beta1: f32,
    pub beta2: f32,
    pub weight_decay: f32,
    pub distributed: Option<DistributedTrainConfig>,
    pub deterministic: bool,
    pub seed: u64,
    pub profile_interval: usize,
    pub lora_rank: Option<usize>,
    pub lora_alpha: Option<f32>,
    pub lora_target_modules: Option<Vec<String>>,
    pub lora_plus_ratio: f32,
    pub double_quantize: bool,
    pub quantize_nf4: bool,
}
Expand description

Configuration for transformer training

Fields§

§base: TrainConfig

Base training configuration

§model_config: TransformerConfig

Transformer architecture configuration

§checkpoint_config: CheckpointConfig

Checkpoint configuration for memory efficiency

§precision_config: MixedPrecisionConfig

Mixed-precision configuration

§max_seq_len: usize

Maximum sequence length

§accumulation_steps: usize

Accumulation steps for gradient accumulation

§warmup_steps: usize

Warmup steps for learning rate scheduler

§lr: f32

Learning rate

§max_steps: Option<usize>

Maximum training steps (stop after this many optimizer steps)

§use_cuda: bool

Use CUDA GPU training when available (default: true = auto-detect)

§beta1: f32

AdamW beta1 (default: 0.9)

§beta2: f32

AdamW beta2 (default: 0.999)

§weight_decay: f32

AdamW weight decay (default: 0.01)

§distributed: Option<DistributedTrainConfig>

Distributed training configuration (None = single-GPU)

§deterministic: bool

Enable bitwise deterministic training (CUBLAS_WORKSPACE_CONFIG, cuDNN deterministic mode). Contract: C-DETERM-001.

§seed: u64

Random seed for reproducibility

§profile_interval: usize

KAIZEN-047: Step profiler report interval (0 = disabled, N = print every N steps)

§lora_rank: Option<usize>

LoRA rank (None = full fine-tuning, Some(r) = LoRA with rank r)

§lora_alpha: Option<f32>

LoRA alpha scaling factor (default: 2 * rank)

§lora_target_modules: Option<Vec<String>>

LoRA target modules (e.g., q_proj, v_proj)

§lora_plus_ratio: f32

LoRA+ ratio: learning-rate multiplier for the B matrices (ENT-LoRA-006). Default 1.0 = standard LoRA; 16.0 = LoRA+ (Hayou et al., ICML 2024).

§double_quantize: bool

Double quantization for QLoRA (ENT-LoRA-008). Quantizes the FP32 absmax constants to 8-bit, saving ~0.37 bits/param.

§quantize_nf4: bool

Quantize frozen base weights to NF4 (4-bit) for QLoRA pretraining (ENT-263)

When enabled with LoRA, uses CudaNf4TransformerBlock instead of CudaTransformerBlock, achieving ~8x VRAM compression on frozen weights. Only LoRA adapters and norm weights remain trainable in fp32.

Implementations§

Source§

impl TransformerTrainConfig

Source

pub fn new(model_config: TransformerConfig) -> Self

Create new config with defaults

Source

pub fn with_checkpointing(self, num_segments: usize) -> Self

Enable gradient checkpointing

Source

pub fn with_bf16(self) -> Self

Enable bf16 mixed precision

Source

pub fn with_fp16(self) -> Self

Enable fp16 mixed precision with dynamic loss scaling

Source

pub fn with_max_seq_len(self, len: usize) -> Self

Set maximum sequence length

Source

pub fn with_accumulation_steps(self, steps: usize) -> Self

Set gradient accumulation steps

Source

pub fn with_warmup_steps(self, steps: usize) -> Self

Set warmup steps

Source

pub fn with_lr(self, lr: f32) -> Self

Set learning rate

Source

pub fn with_grad_clip(self, clip: f32) -> Self

Set gradient clipping

Source

pub fn with_max_steps(self, steps: usize) -> Self

Set maximum training steps

Source

pub fn with_use_cuda(self, use_cuda: bool) -> Self

Enable or disable CUDA GPU training (default: true = auto-detect)

Source

pub fn with_beta2(self, beta2: f32) -> Self

Set AdamW beta2 (default: 0.999)

Source

pub fn with_weight_decay(self, wd: f32) -> Self

Set AdamW weight decay (default: 0.01)

Source

pub fn with_deterministic(self, deterministic: bool) -> Self

Enable bitwise deterministic training (C-DETERM-001)

Sets CUBLAS_WORKSPACE_CONFIG, enables cuDNN deterministic mode, and disables cuDNN benchmarking. May reduce throughput but guarantees reproducibility.

Source

pub fn with_seed(self, seed: u64) -> Self

Set random seed for reproducibility

Source

pub fn apply_deterministic_settings(&self)

Apply deterministic settings to the CUDA environment.

Must be called before any cuBLAS/cuDNN operations. Uses ReproducibilityConfig from the finetune infrastructure.

§Contract (C-DETERM-001)

After calling this, CUBLAS_WORKSPACE_CONFIG=:4096:8 and CUDNN_DETERMINISTIC=1 are guaranteed to be set in the process environment.

Source

pub fn with_profile_interval(self, interval: usize) -> Self

Set step profiler report interval (0 = disabled, N = print every N steps)

Source

pub fn with_lora(self, rank: usize, alpha: f32, target_modules: Vec<String>) -> Self

Enable LoRA fine-tuning with rank, alpha, and target modules

When LoRA is enabled, only LoRA adapter weights (A, B matrices) and layer norms are trainable. Base model weights are frozen.

§Contract (ENT-LoRA-001)
  • Base weights frozen (requires_grad=false)
  • Only LoRA A/B + norms are optimizer targets
  • scale = alpha / rank
Source

pub fn with_lora_plus_ratio(self, ratio: f32) -> Self

Set LoRA+ ratio (ENT-LoRA-006)

LR multiplier for the B matrices. Default 1.0 = standard LoRA; 16.0 = LoRA+ (Hayou et al., ICML 2024), in which B is trained with a learning rate 16x that of A.

Source

pub fn with_double_quantize(self, enabled: bool) -> Self

Enable double quantization for QLoRA (ENT-LoRA-008)

Source

pub fn with_quantize_nf4(self, enabled: bool) -> Self

Enable NF4 quantization for QLoRA pretraining (ENT-263)

When enabled with LoRA, frozen base weights are quantized to 4-bit NF4, achieving ~8x VRAM compression. Only LoRA adapters and norm weights are trainable. Requires lora_rank to be set.

Source

pub fn is_nf4(&self) -> bool

Check if NF4 quantization is enabled for QLoRA

Source

pub fn is_lora(&self) -> bool

Check if LoRA fine-tuning is enabled

Source

pub fn with_distributed(self, config: DistributedTrainConfig) -> Self

Enable distributed training with the given configuration

Source

pub fn is_distributed(&self) -> bool

Check if distributed training is enabled

Source

pub fn world_size(&self) -> usize

Get world size (1 for single-GPU)

Source

pub fn rank(&self) -> usize

Get this worker’s rank (0 for single-GPU)

Trait Implementations§

Source§

impl Clone for TransformerTrainConfig

Source§

fn clone(&self) -> TransformerTrainConfig

Returns a duplicate of the value. Read more
1.0.0 (const: unstable) · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Debug for TransformerTrainConfig

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Allocation for T
where T: RefUnwindSafe + Send + Sync,

Source§

impl<T> Allocation for T
where T: RefUnwindSafe + Send + Sync,

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> Conv for T

Source§

fn conv<T>(self) -> T
where Self: Into<T>,

Converts self into T using Into<T>. Read more
Source§

impl<T> Downcast<T> for T

Source§

fn downcast(&self) -> &T

Source§

impl<T> FmtForward for T

Source§

fn fmt_binary(self) -> FmtBinary<Self>
where Self: Binary,

Causes self to use its Binary implementation when Debug-formatted.
Source§

fn fmt_display(self) -> FmtDisplay<Self>
where Self: Display,

Causes self to use its Display implementation when Debug-formatted.
Source§

fn fmt_lower_exp(self) -> FmtLowerExp<Self>
where Self: LowerExp,

Causes self to use its LowerExp implementation when Debug-formatted.
Source§

fn fmt_lower_hex(self) -> FmtLowerHex<Self>
where Self: LowerHex,

Causes self to use its LowerHex implementation when Debug-formatted.
Source§

fn fmt_octal(self) -> FmtOctal<Self>
where Self: Octal,

Causes self to use its Octal implementation when Debug-formatted.
Source§

fn fmt_pointer(self) -> FmtPointer<Self>
where Self: Pointer,

Causes self to use its Pointer implementation when Debug-formatted.
Source§

fn fmt_upper_exp(self) -> FmtUpperExp<Self>
where Self: UpperExp,

Causes self to use its UpperExp implementation when Debug-formatted.
Source§

fn fmt_upper_hex(self) -> FmtUpperHex<Self>
where Self: UpperHex,

Causes self to use its UpperHex implementation when Debug-formatted.
Source§

fn fmt_list(self) -> FmtList<Self>
where for<'a> &'a Self: IntoIterator,

Formats each item in a sequence. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T> FromRef<T> for T
where T: Clone,

Source§

fn from_ref(input: &T) -> T

Converts to this type from a reference to the input type.
Source§

impl<T> FromRef<T> for T
where T: Clone,

Source§

fn from_ref(input: &T) -> T

Converts to this type from a reference to the input type.
Source§

impl<T> Instrument for T

Source§

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper. Read more
Source§

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper. Read more
Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pipe for T
where T: ?Sized,

Source§

fn pipe<R>(self, func: impl FnOnce(Self) -> R) -> R
where Self: Sized,

Pipes by value. This is generally the method you want to use. Read more
Source§

fn pipe_ref<'a, R>(&'a self, func: impl FnOnce(&'a Self) -> R) -> R
where R: 'a,

Borrows self and passes that borrow into the pipe function. Read more
Source§

fn pipe_ref_mut<'a, R>(&'a mut self, func: impl FnOnce(&'a mut Self) -> R) -> R
where R: 'a,

Mutably borrows self and passes that borrow into the pipe function. Read more
Source§

fn pipe_borrow<'a, B, R>(&'a self, func: impl FnOnce(&'a B) -> R) -> R
where Self: Borrow<B>, B: 'a + ?Sized, R: 'a,

Borrows self, then passes self.borrow() into the pipe function. Read more
Source§

fn pipe_borrow_mut<'a, B, R>( &'a mut self, func: impl FnOnce(&'a mut B) -> R, ) -> R
where Self: BorrowMut<B>, B: 'a + ?Sized, R: 'a,

Mutably borrows self, then passes self.borrow_mut() into the pipe function. Read more
Source§

fn pipe_as_ref<'a, U, R>(&'a self, func: impl FnOnce(&'a U) -> R) -> R
where Self: AsRef<U>, U: 'a + ?Sized, R: 'a,

Borrows self, then passes self.as_ref() into the pipe function.
Source§

fn pipe_as_mut<'a, U, R>(&'a mut self, func: impl FnOnce(&'a mut U) -> R) -> R
where Self: AsMut<U>, U: 'a + ?Sized, R: 'a,

Mutably borrows self, then passes self.as_mut() into the pipe function.
Source§

fn pipe_deref<'a, T, R>(&'a self, func: impl FnOnce(&'a T) -> R) -> R
where Self: Deref<Target = T>, T: 'a + ?Sized, R: 'a,

Borrows self, then passes self.deref() into the pipe function.
Source§

fn pipe_deref_mut<'a, T, R>( &'a mut self, func: impl FnOnce(&'a mut T) -> R, ) -> R
where Self: DerefMut<Target = T> + Deref, T: 'a + ?Sized, R: 'a,

Mutably borrows self, then passes self.deref_mut() into the pipe function.
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> PolicyExt for T
where T: ?Sized,

Source§

fn and<P, B, E>(self, other: P) -> And<T, P>
where T: Sized + Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow only if self and other return Action::Follow. Read more
Source§

fn or<P, B, E>(self, other: P) -> Or<T, P>
where T: Sized + Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow if either self or other returns Action::Follow. Read more
Source§

impl<T> Same for T

Source§

type Output = T

Should always be Self
Source§

impl<T> Tap for T

Source§

fn tap(self, func: impl FnOnce(&Self)) -> Self

Immutable access to a value. Read more
Source§

fn tap_mut(self, func: impl FnOnce(&mut Self)) -> Self

Mutable access to a value. Read more
Source§

fn tap_borrow<B>(self, func: impl FnOnce(&B)) -> Self
where Self: Borrow<B>, B: ?Sized,

Immutable access to the Borrow<B> of a value. Read more
Source§

fn tap_borrow_mut<B>(self, func: impl FnOnce(&mut B)) -> Self
where Self: BorrowMut<B>, B: ?Sized,

Mutable access to the BorrowMut<B> of a value. Read more
Source§

fn tap_ref<R>(self, func: impl FnOnce(&R)) -> Self
where Self: AsRef<R>, R: ?Sized,

Immutable access to the AsRef<R> view of a value. Read more
Source§

fn tap_ref_mut<R>(self, func: impl FnOnce(&mut R)) -> Self
where Self: AsMut<R>, R: ?Sized,

Mutable access to the AsMut<R> view of a value. Read more
Source§

fn tap_deref<T>(self, func: impl FnOnce(&T)) -> Self
where Self: Deref<Target = T>, T: ?Sized,

Immutable access to the Deref::Target of a value. Read more
Source§

fn tap_deref_mut<T>(self, func: impl FnOnce(&mut T)) -> Self
where Self: DerefMut<Target = T> + Deref, T: ?Sized,

Mutable access to the Deref::Target of a value. Read more
Source§

fn tap_dbg(self, func: impl FnOnce(&Self)) -> Self

Calls .tap() only in debug builds, and is erased in release builds.
Source§

fn tap_mut_dbg(self, func: impl FnOnce(&mut Self)) -> Self

Calls .tap_mut() only in debug builds, and is erased in release builds.
Source§

fn tap_borrow_dbg<B>(self, func: impl FnOnce(&B)) -> Self
where Self: Borrow<B>, B: ?Sized,

Calls .tap_borrow() only in debug builds, and is erased in release builds.
Source§

fn tap_borrow_mut_dbg<B>(self, func: impl FnOnce(&mut B)) -> Self
where Self: BorrowMut<B>, B: ?Sized,

Calls .tap_borrow_mut() only in debug builds, and is erased in release builds.
Source§

fn tap_ref_dbg<R>(self, func: impl FnOnce(&R)) -> Self
where Self: AsRef<R>, R: ?Sized,

Calls .tap_ref() only in debug builds, and is erased in release builds.
Source§

fn tap_ref_mut_dbg<R>(self, func: impl FnOnce(&mut R)) -> Self
where Self: AsMut<R>, R: ?Sized,

Calls .tap_ref_mut() only in debug builds, and is erased in release builds.
Source§

fn tap_deref_dbg<T>(self, func: impl FnOnce(&T)) -> Self
where Self: Deref<Target = T>, T: ?Sized,

Calls .tap_deref() only in debug builds, and is erased in release builds.
Source§

fn tap_deref_mut_dbg<T>(self, func: impl FnOnce(&mut T)) -> Self
where Self: DerefMut<Target = T> + Deref, T: ?Sized,

Calls .tap_deref_mut() only in debug builds, and is erased in release builds.
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T> TryConv for T

Source§

fn try_conv<T>(self) -> Result<T, Self::Error>
where Self: TryInto<T>,

Attempts to convert self into T using TryInto<T>. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<T> Upcast<T> for T

Source§

fn upcast(&self) -> Option<&T>

Source§

impl<S, T> Upcast<T> for S
where T: UpcastFrom<S> + ?Sized, S: ?Sized,

Source§

fn upcast(&self) -> &T
where Self: ErasableGeneric, T: Sized + ErasableGeneric<Repr = Self::Repr>,

Perform a zero-cost type-safe upcast to a wider ref type within the Wasm bindgen generics type system. Read more
Source§

fn upcast_into(self) -> T
where Self: Sized + ErasableGeneric, T: Sized + ErasableGeneric<Repr = Self::Repr>,

Perform a zero-cost type-safe upcast to a wider type within the Wasm bindgen generics type system. Read more
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V

Source§

impl<T> WasmNotSend for T
where T: Send,

Source§

impl<T> WasmNotSendSync for T

Source§

impl<T> WasmNotSync for T
where T: Sync,

Source§

impl<T> WithSubscriber for T

Source§

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper. Read more
Source§

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper. Read more