pmetal-trainer 0.3.13

Training loops and optimization for PMetal

pmetal-trainer

Training loops and optimization strategies for LLM fine-tuning.

Overview

This crate provides the training infrastructure for PMetal, including various training methods, learning rate scheduling, checkpointing, and callback systems.

Training Methods

| Method | Description | Use Case |
|--------|-------------|----------|
| SFT | Supervised Fine-Tuning | General instruction tuning |
| LoRA | Low-Rank Adaptation | Parameter-efficient fine-tuning |
| DPO | Direct Preference Optimization | Preference-based alignment |
| GRPO | Group Relative Policy Optimization | Efficient PPO alternative |
| GSPO | Group Sequence Policy Optimization | Fixes GRPO's length bias |
| DAPO | Decoupled Clip and Dynamic Sampling Policy Optimization | ByteDance's four improvements over GRPO |
| PPO | Proximal Policy Optimization | RLHF with a reward model |
| ORPO | Odds Ratio Preference Optimization | Reference-free alignment |
| SimPO | Simple Preference Optimization | Simplified preference learning |
| KTO | Kahneman-Tversky Optimization | Unpaired preference data |
| Online DPO | Online Direct Preference Optimization | DPO with online sampling |
| Distillation | Knowledge distillation | Teacher→student transfer |
| ANE | Apple Neural Engine training | Power-efficient on-device training |
| RLKD | RL with Knowledge Distillation | GRPO + teacher distillation |
| Embedding | Sentence-transformer training | Contrastive losses (InfoNCE, Triplet, CoSENT) |
| Diffusion | LLaDA-style diffusion training | Experimental |
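
Each method has a dedicated trainer module (see Modules below). As a rough illustration of selecting one of them, here is a minimal sketch of a preference-alignment run; DpoConfig, DpoTrainer, and the field and method names are assumptions for illustration, not confirmed API:

use pmetal_trainer::dpo::{DpoConfig, DpoTrainer}; // the dpo module is listed under Modules below; type names assumed

let dpo_config = DpoConfig {
    beta: 0.1, // typical DPO KL-penalty strength; field name assumed
    ..Default::default()
};

// DPO compares a trainable policy model against a frozen reference model.
let mut dpo_trainer = DpoTrainer::new(policy_model, reference_model, dpo_config)?;
dpo_trainer.train(&preference_pairs)?;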

Usage

Basic Training Loop

use pmetal_trainer::{TrainingLoop, TrainingConfig};

let config = TrainingConfig {
    batch_size: 4,
    gradient_accumulation_steps: 4,
    learning_rate: 2e-4,
    epochs: 1,
    max_grad_norm: 1.0,
    ..Default::default()
};

let mut trainer = TrainingLoop::new(model, optimizer, config)?;

// Train with optional callbacks
trainer.train(&dataloader, callbacks)?;
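
The callbacks value above is not shown being constructed. A minimal sketch, assuming the MetricsJsonCallback type listed under the callbacks module takes an output path and that callbacks are passed as a boxed collection (both are assumptions rather than confirmed API):

use pmetal_trainer::callbacks::MetricsJsonCallback;

// Writes per-step metrics (StepMetrics) to a JSON file; constructor signature assumed.
let callbacks = vec![Box::new(MetricsJsonCallback::new("output/metrics.json"))];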

With Checkpointing

use pmetal_trainer::CheckpointManager;

let checkpoint_mgr = CheckpointManager::new("output/checkpoints");

// Resume from checkpoint if available
if let Some(ckpt) = checkpoint_mgr.latest()? {
    trainer.load_checkpoint(&ckpt)?;
}

// Save a checkpoint every 500 steps during training
trainer.train_with_checkpoints(&dataloader, &checkpoint_mgr, 500)?;

Optimizers

| Optimizer | Description |
|-----------|-------------|
| AdamW Groups | AdamW with per-parameter-group learning rates |
| Adam 8-bit | Memory-efficient 8-bit Adam optimizer |
| Schedule-Free | Optimizer without learning rate schedules |
| Metal Fused | GPU-accelerated AdamW parameter updates |
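
These optimizers correspond to the adamw_groups, adam8bit, schedule_free, and metal_fused modules listed below. As a minimal sketch of per-group learning rates, assuming AdamWGroups and ParamGroup types along these lines (the type and field names are illustrative assumptions):

use pmetal_trainer::adamw_groups::AdamWGroups; // type name assumed
use pmetal_trainer::param_groups::ParamGroup;  // type name assumed

// Two parameter groups with different learning rates, e.g. a lower LR for
// embedding weights than for LoRA adapter weights. Field names are illustrative.
let optimizer = AdamWGroups::new(vec![
    ParamGroup { params: embedding_params, lr: 1e-5, weight_decay: 0.0 },
    ParamGroup { params: lora_params, lr: 2e-4, weight_decay: 0.01 },
]);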

Learning Rate Schedulers

| Scheduler | Description |
|-----------|-------------|
| Constant | Fixed learning rate |
| Linear | Linear warmup and decay |
| Cosine | Cosine annealing |
| Cosine with Restarts | Cosine with periodic warm restarts |
| Polynomial | Polynomial decay |
| WSD | Warmup-Stable-Decay schedule |
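
Schedulers combine with the warmup_steps setting from the configuration below. A minimal sketch of a cosine schedule with warmup, assuming the scheduler module exposes a constructor roughly like this (the type name and argument order are assumptions):

use pmetal_trainer::scheduler::CosineScheduler; // type name assumed

// Warm up linearly for 100 steps, then decay the learning rate along a
// cosine curve from the peak of 2e-4 toward zero over the remaining steps.
let scheduler = CosineScheduler::new(2e-4, /* warmup_steps */ 100, /* total_steps */ 10_000);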

Modules

| Module | Description |
|--------|-------------|
| training_loop | Main training orchestration |
| sft | Supervised fine-tuning trainer |
| lora_trainer | LoRA-specific training |
| dpo | Direct Preference Optimization |
| grpo | Group Relative Policy Optimization |
| gspo | Group Sequence Policy Optimization |
| dapo | Decoupled Clip and Dynamic Sampling Policy Optimization |
| ane_training | ANE training loop (feature-gated: ane) |
| ppo | Proximal Policy Optimization |
| orpo | Odds Ratio Preference Optimization |
| simpo | Simple Preference Optimization |
| kto | Kahneman-Tversky Optimization |
| online_dpo | Online DPO with sampling |
| distillation | Knowledge distillation orchestration |
| rlkd | Reinforcement Learning with Knowledge Distillation |
| embedding_trainer | Sentence-transformer fine-tuning |
| contrastive_loss | InfoNCE, Triplet, and CoSENT loss functions |
| diffusion | Diffusion-based training |
| orchestrator | Unified training pipeline (shared across CLI/GUI/TUI/easy) |
| adamw_groups | AdamW with parameter groups |
| adam8bit | 8-bit Adam optimizer |
| schedule_free | Schedule-free optimizer |
| metal_fused | Metal-accelerated optimizer |
| adaptive_lr | EMA-based adaptive learning rate control |
| checkpoint | Checkpoint save/load |
| checkpointing | Gradient checkpointing |
| scheduler | Learning rate schedulers |
| callbacks | Training callbacks (MetricsJsonCallback, StepMetrics) |
| param_groups | Per-layer learning rates |
| distributed_bridge | Distributed training sync (feature-gated: distributed) |
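
Beyond MetricsJsonCallback, a custom callback can observe training directly. A minimal sketch, assuming the callbacks module exposes a Callback trait with a per-step hook receiving StepMetrics; the trait surface and the field names used here are assumptions, not confirmed API:

use pmetal_trainer::callbacks::{Callback, StepMetrics}; // Callback trait name assumed

struct ConsoleLogger;

impl Callback for ConsoleLogger {
    // Hypothetical hook; the real trait may expose different methods and fields.
    fn on_step(&mut self, metrics: &StepMetrics) {
        println!("step {}: loss {:.4}", metrics.step, metrics.loss);
    }
}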

Configuration

| Parameter | Description | Default |
|-----------|-------------|---------|
| batch_size | Micro-batch size | 4 |
| gradient_accumulation_steps | Gradient accumulation steps | 1 |
| learning_rate | Initial learning rate | 2e-4 |
| max_grad_norm | Gradient clipping norm | 1.0 |
| warmup_steps | LR warmup steps | 0 |
| weight_decay | Weight decay (L2 regularization) | 0.0 |
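
The effective batch size is batch_size × gradient_accumulation_steps: the defaults above give 4 × 1 = 4, while the Usage example's settings give 4 × 4 = 16. A sketch of overriding just those two fields, assuming TrainingConfig::default() matches the table:

use pmetal_trainer::TrainingConfig;

let config = TrainingConfig {
    batch_size: 8,
    gradient_accumulation_steps: 4, // effective batch size = 8 * 4 = 32
    ..Default::default() // learning_rate 2e-4, max_grad_norm 1.0, warmup_steps 0, weight_decay 0.0
};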

License

MIT OR Apache-2.0