Crate rlx_voxtral_tts_train

Native RLX training for Voxtral voice cloning.

Phase 1 trains the codec encoder with reconstruction and VQ auxiliary losses. Phase 2 trains LoRA adapters on the 4B language model via embedding distillation. The trained weights are then exported and injected into consolidated.safetensors for inference.
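This page does not show how the LoRA adapters are built internally. As a rough illustration of the standard LoRA formulation that Phase 2 names, the dependency-free sketch below applies y = Wx + (alpha / r) * B(Ax) to one linear layer and then folds the adapter back into a single dense weight, which is the kind of merge an export/inject step could perform. `Mat`, `LoraLinear`, and the alpha / r scaling convention are hypothetical and are not this crate's API.

```rust
/// Hypothetical row-major dense matrix, kept minimal for the sketch.
#[derive(Clone, Debug)]
struct Mat {
    rows: usize,
    cols: usize,
    data: Vec<f32>,
}

impl Mat {
    fn new(rows: usize, cols: usize, data: Vec<f32>) -> Self {
        assert_eq!(data.len(), rows * cols);
        Mat { rows, cols, data }
    }

    /// Matrix-vector product: returns self * x.
    fn matvec(&self, x: &[f32]) -> Vec<f32> {
        assert_eq!(x.len(), self.cols);
        (0..self.rows)
            .map(|r| (0..self.cols).map(|c| self.data[r * self.cols + c] * x[c]).sum::<f32>())
            .collect()
    }
}

/// Hypothetical LoRA layer: frozen base weight W (d_out x d_in) plus trainable
/// low-rank factors A (r x d_in) and B (d_out x r), scaled by alpha / r.
struct LoraLinear {
    w: Mat,
    a: Mat,
    b: Mat,
    alpha: f32,
}

impl LoraLinear {
    /// Scaling factor alpha / r, where r is the adapter rank (rows of A).
    fn scale(&self) -> f32 {
        self.alpha / self.a.rows as f32
    }

    /// Adapter forward pass: W x + (alpha / r) * B (A x).
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        let base = self.w.matvec(x);
        let delta = self.b.matvec(&self.a.matvec(x));
        let s = self.scale();
        base.iter().zip(&delta).map(|(y, d)| y + s * d).collect()
    }

    /// Fold the adapter into one dense weight: W' = W + (alpha / r) * B A.
    fn merged(&self) -> Mat {
        let (d_out, d_in, r) = (self.w.rows, self.w.cols, self.a.rows);
        let s = self.scale();
        let mut data = self.w.data.clone();
        for i in 0..d_out {
            for j in 0..d_in {
                let mut acc = 0.0f32;
                for k in 0..r {
                    acc += self.b.data[i * r + k] * self.a.data[k * d_in + j];
                }
                data[i * d_in + j] += s * acc;
            }
        }
        Mat::new(d_out, d_in, data)
    }
}

fn main() {
    // d_in = 3, d_out = 2, rank r = 1, alpha = 2.0 (so the scale is 2.0).
    let layer = LoraLinear {
        w: Mat::new(2, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]),
        a: Mat::new(1, 3, vec![1.0, 1.0, 1.0]),
        b: Mat::new(2, 1, vec![0.5, -0.5]),
        alpha: 2.0,
    };
    let x = [1.0, 2.0, 3.0];
    // The adapter path and the merged weight should agree.
    println!("{:?} {:?}", layer.forward(&x), layer.merged().matvec(&x));
}
```

Merging is exact up to floating-point rounding, so a merged checkpoint needs no adapter code at inference time; the trade-off is that the low-rank factors can no longer be swapped independently.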
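Phase 2 is described only as embedding distillation against frozen teacher hidden states, and the exact objective is not documented on this page. The sketch below shows two common per-position choices, mean squared error and one minus the mean cosine similarity, over hidden states stored as `seq_len` rows of `dim` values. Both function names and the row layout are hypothetical; the crate may use a different loss or weighting.

```rust
/// Hypothetical distillation loss: mean squared error between student and
/// teacher hidden states, averaged over every element of every row.
fn mse_distill(student: &[Vec<f32>], teacher: &[Vec<f32>]) -> f32 {
    assert_eq!(student.len(), teacher.len());
    let mut sum = 0.0f32;
    let mut count = 0usize;
    for (s_row, t_row) in student.iter().zip(teacher) {
        assert_eq!(s_row.len(), t_row.len());
        for (s, t) in s_row.iter().zip(t_row) {
            let d = s - t;
            sum += d * d;
            count += 1;
        }
    }
    sum / count as f32
}

/// Hypothetical distillation loss: 1 - mean cosine similarity per position.
/// The norm product is clamped to avoid division by zero.
fn cosine_distill(student: &[Vec<f32>], teacher: &[Vec<f32>]) -> f32 {
    assert_eq!(student.len(), teacher.len());
    let mut total = 0.0f32;
    for (s_row, t_row) in student.iter().zip(teacher) {
        let dot: f32 = s_row.iter().zip(t_row).map(|(a, b)| a * b).sum();
        let ns = s_row.iter().map(|a| a * a).sum::<f32>().sqrt();
        let nt = t_row.iter().map(|b| b * b).sum::<f32>().sqrt();
        total += dot / (ns * nt).max(1e-8);
    }
    1.0 - total / student.len() as f32
}

fn main() {
    // Two positions, dim = 2. The second student row has the right direction
    // but the wrong magnitude: MSE penalizes it, cosine distance does not.
    let teacher = vec![vec![1.0f32, 0.0], vec![0.0, 1.0]];
    let student = vec![vec![1.0f32, 0.0], vec![0.0, 2.0]];
    println!(
        "mse {:.3}, cosine {:.3}",
        mse_distill(&student, &teacher),
        cosine_distill(&student, &teacher)
    );
}
```

The example highlights the design difference: MSE matches magnitudes as well as directions, while the cosine form ignores scale.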

Re-exports

pub use backward_prep::needs_portable_backward_prep;
pub use backward_prep::prepare_backward_for_device;
pub use checkpoint::export_encoder_weights;
pub use checkpoint::export_lora_weights;
pub use checkpoint::inject_weights;
pub use checkpoint::load_encoder_weights;
pub use checkpoint::load_lora_weights;
pub use compile::TrainSession;
pub use compile::backward_cpu_only_from_env;
pub use compile::compile_train_session;
pub use config::EncoderTrainConfig;
pub use config::LoraTrainConfig;
pub use config::TrainProfile;
pub use device::pick_auto_device;
pub use device::resolve_train_device;
pub use encoder_train::EncoderTrainResult;
pub use encoder_train::train_encoder;
pub use lora_train::train_lora;
pub use train_pipeline::TrainAllConfig;
pub use train_pipeline::TrainAllResult;
pub use train_pipeline::default_train_all;
pub use train_pipeline::train_all;
pub use weights::codec_has_encoder;
pub use weights::merge_codec_encoder_overlay;

Modules

adam
Host-side AdamW optimizer (lifted from rlx-umap).
asr_loss
ASR auxiliary loss — mel proxy + optional Whisper CER.
audio_metrics
Lightweight audio similarity metrics for rig tests.
backward_prep
Lower training backward ops for Metal / MLX / wgpu / Vulkan before compile.
checkpoint
Safetensors export + merge into consolidated.safetensors.
codec_graph
Codec encoder + decoder as rlx_ir::Graph (trainable encoder, frozen decoder).
compile
Session compile with per-graph device selection (GPU forward + CPU backward when needed).
config
Training hyperparameters and LOW_VRAM profile.
dataset
WAV dataset + manifest writer.
device
Training device selection (--device, RLX_DEVICE, auto GPU pick).
discriminator
Multi-resolution STFT discriminators (hinge + feature matching).
distill_dataset
Rotating reference-audio + text prompts for LoRA distillation.
distill_text
Rotating distillation prompts — transcript variants + default pool.
early_stop
Epoch-level early stopping on eval (or train) metrics.
encoder_loss
Combined encoder training loss graph.
encoder_report
Training bench report — timing, loss curves, epoch checkpoints for ablation.
encoder_train
Phase 1 codec encoder training loop.
lm_lora_graph
LoRA adapter graph on the 4B Ministral backbone (Q/K/V/O + FFN gate/up/down).
lora_train
Phase 2 LoRA distillation training loop.
teacher
Frozen teacher hidden states for Phase 2 LoRA distillation.
train_pipeline
Full voice-clone training pipeline (encoder → LoRA → inject).
weights
Named parameter tensors for compiled training graphs.
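The `adam` module is a host-side AdamW optimizer. For reference, the standard AdamW update (bias-corrected Adam moments plus decoupled weight decay) can be written over flat `f32` parameter buffers as below. This is a generic sketch, not the code lifted from rlx-umap: the `AdamW` struct, its defaults (beta1 = 0.9, beta2 = 0.999, eps = 1e-8), and applying decay before the Adam step (the ordering PyTorch's `AdamW` uses) are assumptions.

```rust
/// Hypothetical host-side AdamW state over one flat parameter buffer.
struct AdamW {
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    t: i32,      // step counter, used for bias correction
    m: Vec<f32>, // first-moment (mean) estimate
    v: Vec<f32>, // second-moment (uncentered variance) estimate
}

impl AdamW {
    fn new(n_params: usize, lr: f32, weight_decay: f32) -> Self {
        AdamW {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay,
            t: 0,
            m: vec![0.0; n_params],
            v: vec![0.0; n_params],
        }
    }

    /// One optimizer step, updating `params` in place from `grads`.
    fn step(&mut self, params: &mut [f32], grads: &[f32]) {
        assert_eq!(params.len(), grads.len());
        assert_eq!(params.len(), self.m.len());
        self.t += 1;
        let bc1 = 1.0 - self.beta1.powi(self.t);
        let bc2 = 1.0 - self.beta2.powi(self.t);
        for i in 0..params.len() {
            let g = grads[i];
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
            let m_hat = self.m[i] / bc1;
            let v_hat = self.v[i] / bc2;
            // Decoupled weight decay: shrink the weight directly, not via the gradient.
            params[i] -= self.lr * self.weight_decay * params[i];
            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }
}

fn main() {
    // Minimize f(x) = (x - 3)^2, whose gradient is 2 (x - 3).
    let mut opt = AdamW::new(1, 0.05, 0.0);
    let mut x = [0.0f32];
    for _ in 0..500 {
        let g = 2.0 * (x[0] - 3.0);
        opt.step(&mut x, &[g]);
    }
    println!("x = {:.4}", x[0]);
}
```

Keeping decay out of the moment estimates is what distinguishes AdamW from Adam with L2 regularization: the decay strength does not get rescaled by the per-parameter `v_hat`.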
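The `early_stop` module performs epoch-level early stopping on eval (or train) metrics. A typical patience-based rule is sketched below, assuming lower is better and that a new value only counts as an improvement if it beats the best so far by at least `min_delta`. `EarlyStop`, `patience`, and `min_delta` are illustrative names, not the crate's types.

```rust
/// Hypothetical patience-based early stopper for a lower-is-better metric.
struct EarlyStop {
    patience: usize,   // epochs without improvement before stopping
    min_delta: f64,    // minimum decrease that counts as an improvement
    best: f64,         // best metric seen so far
    best_epoch: usize, // epoch at which `best` was recorded
    bad_epochs: usize, // consecutive epochs without improvement
}

impl EarlyStop {
    fn new(patience: usize, min_delta: f64) -> Self {
        EarlyStop {
            patience,
            min_delta,
            best: f64::INFINITY,
            best_epoch: 0,
            bad_epochs: 0,
        }
    }

    /// Record one epoch's metric; returns true when training should stop.
    fn observe(&mut self, epoch: usize, metric: f64) -> bool {
        if metric < self.best - self.min_delta {
            self.best = metric;
            self.best_epoch = epoch;
            self.bad_epochs = 0;
        } else {
            self.bad_epochs += 1;
        }
        self.bad_epochs >= self.patience
    }
}

fn main() {
    let eval_losses = [1.0, 0.8, 0.7, 0.71, 0.695, 0.72];
    let mut stopper = EarlyStop::new(2, 0.01);
    for (epoch, &loss) in eval_losses.iter().enumerate() {
        if stopper.observe(epoch, loss) {
            println!(
                "stop at epoch {}, best {} at epoch {}",
                epoch, stopper.best, stopper.best_epoch
            );
            break;
        }
    }
}
```

Tracking `best_epoch` lets a training loop restore or export the checkpoint from the best epoch rather than the last one.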
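The `discriminator` module is described as multi-resolution STFT discriminators trained with a hinge loss plus feature matching. The sketch below shows those loss terms on already-computed logits and feature maps; the STFT front end and the networks themselves are out of scope. It assumes the usual hinge discriminator loss, the max(0, 1 - D(fake)) generator variant (some setups use -D(fake) instead), per-layer mean L1 feature matching, and a plain average across resolutions. All function names are hypothetical.

```rust
/// Mean of a non-empty slice.
fn mean(xs: &[f32]) -> f32 {
    assert!(!xs.is_empty(), "mean of empty slice");
    xs.iter().sum::<f32>() / xs.len() as f32
}

/// Discriminator hinge loss: mean(max(0, 1 - D(real))) + mean(max(0, 1 + D(fake))).
fn disc_hinge_loss(real_logits: &[f32], fake_logits: &[f32]) -> f32 {
    let real: Vec<f32> = real_logits.iter().map(|&d| (1.0 - d).max(0.0)).collect();
    let fake: Vec<f32> = fake_logits.iter().map(|&d| (1.0 + d).max(0.0)).collect();
    mean(&real) + mean(&fake)
}

/// Generator adversarial term (hinge variant): mean(max(0, 1 - D(fake))).
fn gen_hinge_loss(fake_logits: &[f32]) -> f32 {
    let terms: Vec<f32> = fake_logits.iter().map(|&d| (1.0 - d).max(0.0)).collect();
    mean(&terms)
}

/// Feature matching: mean L1 distance between the discriminator's intermediate
/// feature maps on real and generated audio, per layer, then averaged over layers.
fn feature_matching_loss(real_feats: &[Vec<f32>], fake_feats: &[Vec<f32>]) -> f32 {
    assert_eq!(real_feats.len(), fake_feats.len());
    let per_layer: Vec<f32> = real_feats
        .iter()
        .zip(fake_feats)
        .map(|(r, f)| {
            assert_eq!(r.len(), f.len());
            let diffs: Vec<f32> = r.iter().zip(f).map(|(a, b)| (a - b).abs()).collect();
            mean(&diffs)
        })
        .collect();
    mean(&per_layer)
}

fn main() {
    // (real, fake) logits from two hypothetical STFT resolutions.
    let resolutions = [
        (vec![2.0f32, 0.5], vec![-2.0f32, 0.0]),
        (vec![1.5, 1.0], vec![-0.5, -1.5]),
    ];
    let d_losses: Vec<f32> = resolutions.iter().map(|(r, f)| disc_hinge_loss(r, f)).collect();
    let g_losses: Vec<f32> = resolutions.iter().map(|(_, f)| gen_hinge_loss(f)).collect();
    println!("D loss {:.3}, G adv loss {:.3}", mean(&d_losses), mean(&g_losses));

    let fm = feature_matching_loss(&[vec![1.0f32, 2.0], vec![0.0]], &[vec![1.0f32, 4.0], vec![1.0]]);
    println!("feature matching {:.3}", fm);
}
```

The hinge form stops pushing once a logit clears the margin of 1, which tends to keep discriminator gradients bounded compared with an unbounded linear objective.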