use super::classification::{
load_multi_label_corpus, load_safety_corpus, ClassificationHead, MultiLabelSafetySample,
SafetySample, TokenizedSample,
};
use crate::autograd::matmul;
use crate::lora::LoRAConfig;
use crate::lora::LoRALayer;
use crate::optim::{clip_grad_norm_refs, AdamW, Optimizer};
use crate::tokenizer::HfTokenizer;
use crate::transformer::Transformer;
use crate::transformer::TransformerConfig;
use crate::Tensor;
use std::path::{Path, PathBuf};
#[cfg(feature = "cuda")]
use crate::autograd::cuda_backward::pre_warm_lora_backward_kernels as pre_warm_backward_cache_kernels;
#[cfg(feature = "cuda")]
use crate::autograd::cuda_forward::{pre_warm_forward_kernels, pre_warm_lora_backward_kernels};
#[cfg(feature = "cuda")]
use crate::autograd::cuda_optim::pre_warm_lora_adamw_kernels;
#[cfg(feature = "cuda")]
use crate::autograd::cuda_training::{cuda_training_available, CudaTrainer};
#[cfg(feature = "cuda")]
use crate::gpu::guard::VramGuard;
#[cfg(feature = "cuda")]
use crate::transformer::{
CudaBlock, CudaBlockScratch, CudaGradWorkspace, CudaLoraGradWorkspace, CudaTransformerBlock,
GpuBlockOptimizerState, GpuLoraOptimizerState,
};
#[cfg(feature = "cuda")]
use std::sync::Arc;
#[cfg(feature = "cuda")]
use trueno_gpu::driver::GpuBuffer;
/// Hyperparameters for LoRA-based classification fine-tuning.
///
/// Defaults follow QLoRA guidance; see `qlora_default` and the
/// `validate_hyperparameters` contracts (C-HP-001..C-HP-008).
#[derive(Debug, Clone)]
pub struct ClassifyConfig {
/// Number of output classes for the classification head.
pub num_classes: usize,
/// LoRA adapter rank (r).
pub lora_rank: usize,
/// LoRA scaling factor (alpha); conventionally 2×rank (see C-HP-003).
pub lora_alpha: f32,
/// AdamW learning rate.
pub learning_rate: f32,
/// Number of passes over the training corpus.
pub epochs: usize,
/// Maximum token sequence length; longer inputs are truncated in `tokenize`.
pub max_seq_len: usize,
/// Log progress every N steps.
pub log_interval: usize,
/// Samples per optimizer batch.
pub batch_size: usize,
/// Gradient accumulation steps (effective batch = batch_size × this).
pub accumulation_steps: usize,
/// Max gradient L2 norm for clipping; `None` disables clipping (warned by C-HP-006).
pub gradient_clip_norm: Option<f32>,
/// Optional per-class loss weights — presumably length == num_classes; not validated here (TODO confirm).
pub class_weights: Option<Vec<f32>>,
/// Quantize base weights to NF4 (QLoRA path) when true.
pub quantize_nf4: bool,
}
impl Default for ClassifyConfig {
/// Conservative general-purpose defaults (no NF4 quantization);
/// contrast `qlora_default`, which applies QLoRA paper settings.
fn default() -> Self {
Self {
num_classes: 5,
lora_rank: 16,
// alpha = rank here; note validate_hyperparameters prefers 2×rank.
lora_alpha: 16.0,
learning_rate: 1e-4,
epochs: 3,
max_seq_len: 512,
log_interval: 100,
batch_size: 32,
accumulation_steps: 1,
gradient_clip_norm: Some(1.0),
class_weights: None,
quantize_nf4: false,
}
}
}
/// A single finding from hyperparameter validation.
#[derive(Debug, Clone)]
pub struct HyperparamDiagnostic {
/// Stable contract identifier, e.g. "C-HP-001".
pub contract_id: &'static str,
/// How serious the finding is; `Error` blocks training by convention.
pub severity: DiagSeverity,
/// Human-readable description of the problem.
pub message: String,
/// Suggested fix, formatted as a "key: value" config line.
pub recommendation: String,
}
/// Severity level of a hyperparameter diagnostic.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagSeverity {
/// Purely informational.
Info,
/// Suboptimal setting; training can still proceed.
Warn,
/// Invalid setting (e.g. zero batch size); see `HyperparamDiagnostics::has_errors`.
Error,
}
/// Collection of diagnostics produced by `ClassifyConfig::validate_hyperparameters`
/// and `ClassifyConfig::validate_with_data`.
#[derive(Debug, Clone, Default)]
pub struct HyperparamDiagnostics {
/// Findings in the order they were produced.
pub items: Vec<HyperparamDiagnostic>,
}
impl HyperparamDiagnostics {
pub fn has_warning(&self, contract_id: &str) -> bool {
self.items.iter().any(|d| {
d.contract_id == contract_id
&& matches!(d.severity, DiagSeverity::Warn | DiagSeverity::Error)
})
}
pub fn has_errors(&self) -> bool {
self.items.iter().any(|d| matches!(d.severity, DiagSeverity::Error))
}
pub fn print_all(&self) {
for d in &self.items {
let prefix = match d.severity {
DiagSeverity::Info => "[HP-INFO]",
DiagSeverity::Warn => "[HP-WARN]",
DiagSeverity::Error => "[HP-ERROR]",
};
eprintln!("{prefix} {}: {} → {}", d.contract_id, d.message, d.recommendation);
}
}
}
/// Corpus statistics consumed by `ClassifyConfig::validate_with_data`.
pub struct DataStats {
/// 99th-percentile tokenized sample length, in tokens.
pub p99_token_length: usize,
/// Majority:minority class-count ratio (e.g. 10.0 for 10:1).
pub imbalance_ratio: f32,
/// Number of samples in the minority class.
pub minority_count: usize,
}
impl ClassifyConfig {
/// QLoRA-paper defaults for a model of `model_params` parameters:
/// lr = 2e-4 up to 13B (1e-4 above), rank 16 with alpha = 2×rank,
/// NF4 quantization enabled, binary classification.
pub fn qlora_default(model_params: u64) -> Self {
    // Dettmers et al. 2023: the higher learning rate is safe up to 13B.
    const THIRTEEN_B: u64 = 13_000_000_000;
    let rank: usize = 16;
    Self {
        num_classes: 2,
        lora_rank: rank,
        lora_alpha: (2 * rank) as f32,
        learning_rate: if model_params <= THIRTEEN_B { 2e-4 } else { 1e-4 },
        epochs: 3,
        max_seq_len: 256,
        log_interval: 100,
        batch_size: 16,
        accumulation_steps: 1,
        gradient_clip_norm: Some(1.0),
        class_weights: None,
        quantize_nf4: true,
    }
}
/// Validate hyperparameters against published QLoRA guidance
/// (Dettmers et al. 2023; Lightning AI LoRA studies).
///
/// Contracts emitted:
/// - C-HP-001: learning rate (warn: too low for NF4 ≤13B; error: ≤ 0)
/// - C-HP-002: effective batch size (warn: ≠ 16; error: batch_size == 0)
/// - C-HP-003: lora_alpha should be 2×rank
/// - C-HP-006: gradient clipping should be enabled
///
/// All findings are advisory; callers decide whether to abort via
/// `HyperparamDiagnostics::has_errors`.
pub fn validate_hyperparameters(&self, model_params: u64) -> HyperparamDiagnostics {
    let mut diags = HyperparamDiagnostics::default();
    // C-HP-001 (warn): NF4-quantized models ≤13B should use the higher 2e-4 LR.
    if self.quantize_nf4 && model_params <= 13_000_000_000 && self.learning_rate < 1.5e-4 {
        diags.items.push(HyperparamDiagnostic {
            contract_id: "C-HP-001",
            severity: DiagSeverity::Warn,
            // Integer division truncates sub-billion remainders in the display.
            message: format!(
                "lr={:.0e} too low for {}B model (Dettmers 2023: use 2e-4 for ≤13B)",
                self.learning_rate,
                model_params / 1_000_000_000
            ),
            recommendation: "learning_rate: 0.0002".to_string(),
        });
    }
    // C-HP-002 (warn): effective batch (batch × accumulation) of 16 is optimal ≤13B.
    let eff_batch = self.batch_size * self.accumulation_steps;
    if eff_batch != 16 {
        // Keep the user's batch_size and recommend the accumulation count that
        // gets closest to 16. Clamp to ≥1: the previous `16 / batch_size`
        // suggestion produced an invalid `accumulation_steps: 0` whenever
        // batch_size > 16.
        let recommended_accum = (16 / self.batch_size.max(1)).max(1);
        diags.items.push(HyperparamDiagnostic {
            contract_id: "C-HP-002",
            severity: DiagSeverity::Warn,
            message: format!(
                "effective_batch={eff_batch} ({}×{}), Dettmers 2023 recommends 16 for ≤13B",
                self.batch_size, self.accumulation_steps
            ),
            recommendation: format!(
                "batch_size: {}, accumulation_steps: {recommended_accum}",
                self.batch_size
            ),
        });
    }
    // C-HP-003 (warn): alpha = 2×rank is the empirically preferred scaling.
    let expected_alpha = 2.0 * self.lora_rank as f32;
    if (self.lora_alpha - expected_alpha).abs() > 0.5 {
        diags.items.push(HyperparamDiagnostic {
            contract_id: "C-HP-003",
            severity: DiagSeverity::Warn,
            message: format!(
                "lora_alpha={} with rank={} (ratio={:.1}), Lightning AI: alpha=2×rank={} optimal",
                self.lora_alpha, self.lora_rank,
                self.lora_alpha / self.lora_rank as f32,
                expected_alpha
            ),
            recommendation: format!("lora_alpha: {expected_alpha}"),
        });
    }
    // C-HP-006 (warn): unclipped gradients have spiked badly in past runs.
    if self.gradient_clip_norm.is_none() {
        diags.items.push(HyperparamDiagnostic {
            contract_id: "C-HP-006",
            severity: DiagSeverity::Warn,
            message: "No gradient clipping — SSC v2.2 saw grad norms up to 115.1".to_string(),
            recommendation: "gradient_clip_norm: 1.0".to_string(),
        });
    }
    // Hard errors: values that make training impossible.
    if self.learning_rate <= 0.0 {
        diags.items.push(HyperparamDiagnostic {
            contract_id: "C-HP-001",
            severity: DiagSeverity::Error,
            message: "learning_rate must be > 0".to_string(),
            recommendation: "learning_rate: 0.0002".to_string(),
        });
    }
    if self.batch_size == 0 {
        diags.items.push(HyperparamDiagnostic {
            contract_id: "C-HP-002",
            severity: DiagSeverity::Error,
            message: "batch_size must be > 0".to_string(),
            recommendation: "batch_size: 4".to_string(),
        });
    }
    diags
}
/// Cross-check hyperparameters against observed corpus statistics.
///
/// - C-HP-004: max_seq_len far above the p99 token length wastes O(n²) attention compute.
/// - C-HP-008: heavy class imbalance with a single epoch starves the minority class of updates.
pub fn validate_with_data(&self, stats: &DataStats) -> HyperparamDiagnostics {
    let mut out = HyperparamDiagnostics::default();
    let p99 = stats.p99_token_length;
    // C-HP-004: only meaningful when we actually measured token lengths.
    if p99 > 0 && self.max_seq_len > 2 * p99 {
        let waste_factor = (self.max_seq_len as f64 / p99 as f64).powi(2);
        out.items.push(HyperparamDiagnostic {
            contract_id: "C-HP-004",
            severity: DiagSeverity::Warn,
            message: format!(
                "max_seq_len={} but p99(tokens)={} — attention is O(n²), wasting {:.0}× compute",
                self.max_seq_len, p99, waste_factor
            ),
            recommendation: format!(
                "max_seq_len: {} (next_pow2 of p99)",
                p99.next_power_of_two()
            ),
        });
    }
    // C-HP-008: imbalanced data needs more epochs so the minority class is seen enough.
    if stats.imbalance_ratio > 5.0 && self.epochs < 2 {
        let per_epoch =
            stats.minority_count / (self.batch_size * self.accumulation_steps).max(1);
        out.items.push(HyperparamDiagnostic {
            contract_id: "C-HP-008",
            severity: DiagSeverity::Warn,
            message: format!(
                "epochs={} with {:.1}:1 imbalance — minority gets only {} gradient updates",
                self.epochs,
                stats.imbalance_ratio,
                per_epoch * self.epochs
            ),
            recommendation: format!("epochs: 3 (minority gets {} updates)", per_epoch * 3),
        });
    }
    out
}
}
/// Metrics from one training (or accumulation) batch.
#[derive(Debug, Clone)]
pub struct BatchResult {
/// Mean loss over the batch.
pub avg_loss: f32,
/// Number of correct predictions.
pub correct: usize,
/// Number of samples processed.
pub total: usize,
/// Gradient L2 norm; 0.0 when not computed (e.g. during gradient accumulation).
pub grad_norm: f32,
}
impl BatchResult {
    /// Fraction of samples predicted correctly; an empty batch yields 0.0.
    #[must_use]
    pub fn accuracy(&self) -> f32 {
        // Clamp the denominator so total == 0 divides by 1 instead of 0.
        let denom = self.total.max(1);
        self.correct as f32 / denom as f32
    }
}
/// GPU-resident buffers and optimizer state for the CUDA training path
/// (allocated by `try_init_gpu_training`).
///
/// NOTE(review): per-field semantics are inferred from names; the code that
/// reads/writes these buffers is in the included gpu.rs — confirm there.
#[cfg(feature = "cuda")]
struct GpuTrainingState {
// Presumably per-layer forward activations saved for the backward pass.
layer_inputs: Vec<GpuBuffer<f32>>,
final_norm_weight: GpuBuffer<f32>,
blocks_output: GpuBuffer<f32>,
// Paired gradient buffers — likely ping-pong between layers (TODO confirm).
grad_buf_a: GpuBuffer<f32>,
grad_buf_b: GpuBuffer<f32>,
grad_final_norm_weight: GpuBuffer<f32>,
// One optimizer state per transformer block.
optimizer_states: Vec<GpuBlockOptimizerState>,
// Monotonic optimizer step counter.
step: u32,
output_scratch: GpuBuffer<f32>,
grad_upload_buf: GpuBuffer<f32>,
fwd_scratch_a: GpuBuffer<f32>,
fwd_scratch_b: GpuBuffer<f32>,
// Host-side staging area for device→host transfers during backward.
backward_cpu_staging: Vec<f32>,
}
/// End-to-end LoRA classification fine-tuning pipeline: base transformer,
/// trainable classification head, LoRA adapters, AdamW optimizer, and
/// optional CUDA / wgpu acceleration state.
pub struct ClassifyPipeline {
/// Base transformer model.
pub model: Transformer,
/// Classification head over the final hidden state; always stepped by the
/// CPU optimizer in `train_batch_tokenized`.
pub classifier: ClassificationHead,
/// LoRA adapters; stepped by the CPU optimizer unless `quantize_nf4`
/// routes their update through the GPU path.
pub lora_layers: Vec<LoRALayer>,
/// Training hyperparameters.
pub config: ClassifyConfig,
optimizer: AdamW,
/// BPE tokenizer; `None` makes `tokenize` fall back to byte-level encoding.
tokenizer: Option<HfTokenizer>,
/// Where the model was loaded from: a safetensors directory
/// (`from_pretrained`) or the `.apr` file path (`from_apr`).
model_dir: Option<PathBuf>,
// --- CUDA acceleration state; all best-effort, None ⇒ fall back to wgpu/CPU. ---
#[cfg(feature = "cuda")]
cuda_trainer: Option<CudaTrainer>,
#[cfg(feature = "cuda")]
cuda_blocks: Option<Vec<CudaBlock>>,
#[cfg(feature = "cuda")]
shared_scratch: Option<CudaBlockScratch>,
// Incremented on CUDA NaN events — TODO confirm usage in gpu.rs.
#[cfg(feature = "cuda")]
cuda_nan_count: usize,
#[cfg(feature = "cuda")]
gpu_training: Option<GpuTrainingState>,
// Full-precision grad workspace; only allocated when quantize_nf4 is off.
#[cfg(feature = "cuda")]
cuda_grad_workspace: Option<CudaGradWorkspace>,
// NF4/QLoRA-only LoRA workspaces and optimizer state; allocated together.
#[cfg(feature = "cuda")]
cuda_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
#[cfg(feature = "cuda")]
cuda_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
#[cfg(feature = "cuda")]
cuda_lora_grad_accum: Option<Vec<CudaLoraGradWorkspace>>,
// Step counter for the NF4 LoRA optimizer path.
#[cfg(feature = "cuda")]
nf4_lora_step: u32,
/// wgpu batched forward pass; only initialized when CUDA is unavailable.
#[cfg(feature = "gpu")]
wgpu_forward_pass: Option<crate::transformer::WgpuForwardPass>,
#[cfg(feature = "cuda")]
#[allow(dead_code)]
vram_guard: Option<VramGuard>,
}
impl ClassifyPipeline {
/// Construct a pipeline around a freshly initialized (untrained) transformer.
///
/// Builds the classification head and LoRA adapters, enables gradients on
/// the LoRA parameters, then sets up accelerators best-effort: CUDA state
/// first (when compiled with the `cuda` feature), a wgpu forward pass only
/// if CUDA is unavailable, otherwise CPU. No tokenizer is attached, so
/// `tokenize` falls back to byte-level encoding.
pub fn new(model_config: &TransformerConfig, classify_config: ClassifyConfig) -> Self {
let model = Transformer::new(model_config);
let classifier =
ClassificationHead::new(model_config.hidden_size, classify_config.num_classes);
let mut lora_layers = Self::build_lora_layers(&model, model_config, &classify_config);
// Only the LoRA adapters (and the head) are trained; enable autograd on them.
for lora in &mut lora_layers {
for param in lora.trainable_params() {
param.set_requires_grad(true);
}
}
let optimizer = AdamW::default_params(classify_config.learning_rate);
// Best-effort CUDA setup — presumably each init yields None on failure so
// later paths can fall back (confirm in the included gpu.rs/training.rs).
#[cfg(feature = "cuda")]
let (cuda_trainer, cuda_blocks, shared_scratch, vram_guard) =
Self::try_init_cuda(&model, model_config, &classify_config, &lora_layers);
#[cfg(feature = "cuda")]
let gpu_training = Self::try_init_gpu_training(
&model,
model_config,
classify_config.max_seq_len,
cuda_trainer.as_ref(),
cuda_blocks.as_ref(),
);
// Full-precision grad workspace is skipped under NF4; the QLoRA path
// allocates its own LoRA-only workspaces below instead.
#[cfg(feature = "cuda")]
let cuda_grad_workspace = if classify_config.quantize_nf4 {
None
} else {
cuda_trainer.as_ref().and_then(|t| {
CudaGradWorkspace::new(t.context(), model_config)
.map_err(|e| eprintln!("[CUDA] Failed to allocate grad workspace: {e}"))
.ok()
})
};
#[cfg(feature = "cuda")]
let (cuda_lora_grad_workspace, cuda_lora_optimizer_states, cuda_lora_grad_accum) =
if classify_config.quantize_nf4 {
Self::try_init_nf4_lora_training(
cuda_trainer.as_ref(),
cuda_blocks.as_ref(),
model_config,
&classify_config,
)
} else {
(None, None, None)
};
// wgpu is only a fallback when CUDA is absent: resident weights first,
// then per-call upload, then CPU.
#[cfg(feature = "gpu")]
let wgpu_forward_pass = {
#[cfg(feature = "cuda")]
let has_cuda = cuda_trainer.is_some();
#[cfg(not(feature = "cuda"))]
let has_cuda = false;
if has_cuda {
None } else {
match crate::transformer::WgpuForwardPass::with_resident_weights(&model) {
Ok(pass) => {
eprintln!("[wgpu] GPU forward pass initialized (resident weights)");
Some(pass)
}
Err(e) => {
eprintln!("[wgpu] GPU resident init failed, trying default: {e}");
match crate::transformer::WgpuForwardPass::new_default(model_config) {
Ok(pass) => {
eprintln!("[wgpu] GPU forward pass initialized (upload per call)");
Some(pass)
}
Err(e2) => {
eprintln!("[wgpu] GPU initialization failed, using CPU: {e2}");
None
}
}
}
}
}
};
Self {
model,
classifier,
lora_layers,
config: classify_config,
optimizer,
tokenizer: None,
model_dir: None,
#[cfg(feature = "cuda")]
cuda_trainer,
#[cfg(feature = "cuda")]
cuda_blocks,
#[cfg(feature = "cuda")]
shared_scratch,
#[cfg(feature = "cuda")]
cuda_nan_count: 0,
#[cfg(feature = "cuda")]
gpu_training,
#[cfg(feature = "cuda")]
cuda_grad_workspace,
#[cfg(feature = "cuda")]
cuda_lora_grad_workspace,
#[cfg(feature = "cuda")]
cuda_lora_optimizer_states,
#[cfg(feature = "cuda")]
cuda_lora_grad_accum,
#[cfg(feature = "cuda")]
nf4_lora_step: 0,
#[cfg(feature = "gpu")]
wgpu_forward_pass,
#[cfg(feature = "cuda")]
vram_guard,
}
}
/// Load a pipeline from a safetensors checkpoint directory.
///
/// Requires `tokenizer.json` inside `model_dir`; training is refused
/// without a BPE tokenizer. Accelerator setup mirrors `new`: CUDA state
/// when available, wgpu as a fallback, CPU otherwise.
///
/// # Errors
/// Returns an error when the model weights fail to load, or when
/// `tokenizer.json` is missing or cannot be parsed.
pub fn from_pretrained(
model_dir: impl AsRef<Path>,
model_config: &TransformerConfig,
classify_config: ClassifyConfig,
) -> crate::Result<Self> {
let model_dir = model_dir.as_ref();
let model = Transformer::from_safetensors(model_dir, model_config)?;
let classifier =
ClassificationHead::new(model_config.hidden_size, classify_config.num_classes);
let mut lora_layers = Self::build_lora_layers(&model, model_config, &classify_config);
// Only the LoRA adapters (and the head) are trained.
for lora in &mut lora_layers {
for param in lora.trainable_params() {
param.set_requires_grad(true);
}
}
// A real BPE tokenizer is mandatory when training from a checkpoint.
let tokenizer_path = model_dir.join("tokenizer.json");
let tokenizer = if tokenizer_path.exists() {
Some(
HfTokenizer::from_file(&tokenizer_path)
.map_err(|e| crate::Error::Io(format!("Failed to load tokenizer: {e}")))?,
)
} else {
return Err(crate::Error::ConfigError(format!(
"No tokenizer.json found in '{}'. Training requires a BPE tokenizer.",
model_dir.display(),
)));
};
let optimizer = AdamW::default_params(classify_config.learning_rate);
// Best-effort CUDA setup; mirrors `new`.
#[cfg(feature = "cuda")]
let (cuda_trainer, cuda_blocks, shared_scratch, vram_guard) =
Self::try_init_cuda(&model, model_config, &classify_config, &lora_layers);
#[cfg(feature = "cuda")]
let gpu_training = Self::try_init_gpu_training(
&model,
model_config,
classify_config.max_seq_len,
cuda_trainer.as_ref(),
cuda_blocks.as_ref(),
);
// Full-precision grad workspace is unused under NF4 (LoRA-only workspaces below).
#[cfg(feature = "cuda")]
let cuda_grad_workspace = if classify_config.quantize_nf4 {
None } else {
cuda_trainer.as_ref().and_then(|t| {
CudaGradWorkspace::new(t.context(), model_config)
.map_err(|e| eprintln!("[CUDA] Failed to allocate grad workspace: {e}"))
.ok()
})
};
#[cfg(feature = "cuda")]
let (cuda_lora_grad_workspace, cuda_lora_optimizer_states, cuda_lora_grad_accum) =
if classify_config.quantize_nf4 {
Self::try_init_nf4_lora_training(
cuda_trainer.as_ref(),
cuda_blocks.as_ref(),
model_config,
&classify_config,
)
} else {
(None, None, None)
};
// wgpu only when CUDA is absent: resident weights → per-call upload → CPU.
#[cfg(feature = "gpu")]
let wgpu_forward_pass = {
#[cfg(feature = "cuda")]
let has_cuda = cuda_trainer.is_some();
#[cfg(not(feature = "cuda"))]
let has_cuda = false;
if has_cuda {
None
} else {
match crate::transformer::WgpuForwardPass::with_resident_weights(&model) {
Ok(pass) => {
eprintln!(
"[wgpu] Batched forward pass initialized ({} layers, resident weights)",
model_config.num_hidden_layers
);
Some(pass)
}
Err(e) => {
eprintln!("[wgpu] Resident init failed, trying default: {e}");
match crate::transformer::WgpuForwardPass::new_default(model_config) {
Ok(pass) => {
eprintln!("[wgpu] Batched forward pass initialized ({} layers, upload per call)", model_config.num_hidden_layers);
Some(pass)
}
Err(e2) => {
eprintln!("[wgpu] GPU init failed, using CPU: {e2}");
None
}
}
}
}
}
};
Ok(Self {
model,
classifier,
lora_layers,
config: classify_config,
optimizer,
tokenizer,
model_dir: Some(model_dir.to_path_buf()),
#[cfg(feature = "cuda")]
cuda_trainer,
#[cfg(feature = "cuda")]
cuda_blocks,
#[cfg(feature = "cuda")]
shared_scratch,
#[cfg(feature = "cuda")]
cuda_nan_count: 0,
#[cfg(feature = "cuda")]
gpu_training,
#[cfg(feature = "cuda")]
cuda_grad_workspace,
#[cfg(feature = "cuda")]
cuda_lora_grad_workspace,
#[cfg(feature = "cuda")]
cuda_lora_optimizer_states,
#[cfg(feature = "cuda")]
cuda_lora_grad_accum,
#[cfg(feature = "cuda")]
nf4_lora_step: 0,
#[cfg(feature = "gpu")]
wgpu_forward_pass,
#[cfg(feature = "cuda")]
vram_guard,
})
}
/// Load a pipeline from a single-file `.apr` model.
///
/// Requires a sibling tokenizer named `<stem>.tokenizer.json` next to the
/// `.apr` file. Accelerator setup mirrors `new`: CUDA when available,
/// wgpu as a fallback, CPU otherwise.
///
/// # Errors
/// Returns an error when the model fails to load, or when the sibling
/// tokenizer is missing or cannot be parsed.
pub fn from_apr(
apr_path: &Path,
model_config: &TransformerConfig,
classify_config: ClassifyConfig,
) -> crate::Result<Self> {
let model = Transformer::from_apr(apr_path, model_config)?;
let classifier =
ClassificationHead::new(model_config.hidden_size, classify_config.num_classes);
let mut lora_layers = Self::build_lora_layers(&model, model_config, &classify_config);
// Only the LoRA adapters (and the head) are trained.
for lora in &mut lora_layers {
for param in lora.trainable_params() {
param.set_requires_grad(true);
}
}
// Resolve "<parent>/<stem>.tokenizer.json"; missing sibling is a hard error.
let tokenizer = {
let sibling = apr_path.file_stem().and_then(|stem| {
apr_path
.parent()
.map(|p| p.join(format!("{}.tokenizer.json", stem.to_str().unwrap_or(""))))
});
match sibling {
Some(ref path) if path.exists() => {
let tok = HfTokenizer::from_file(path).map_err(|e| {
crate::Error::ConfigError(format!(
"Failed to load tokenizer from '{}': {e}. \
Training requires a BPE tokenizer.",
path.display(),
))
})?;
Some(tok)
}
_ => {
return Err(crate::Error::ConfigError(format!(
"No sibling tokenizer found for '{}'. Expected \
'{}.tokenizer.json' next to the .apr file. Training \
requires a BPE tokenizer.",
apr_path.display(),
apr_path.file_stem().unwrap_or_default().to_str().unwrap_or(""),
)));
}
}
};
let optimizer = AdamW::default_params(classify_config.learning_rate);
// Best-effort CUDA setup; mirrors `new`.
#[cfg(feature = "cuda")]
let (cuda_trainer, cuda_blocks, shared_scratch, vram_guard) =
Self::try_init_cuda(&model, model_config, &classify_config, &lora_layers);
#[cfg(feature = "cuda")]
let gpu_training = Self::try_init_gpu_training(
&model,
model_config,
classify_config.max_seq_len,
cuda_trainer.as_ref(),
cuda_blocks.as_ref(),
);
// Full-precision grad workspace is unused under NF4 (LoRA-only workspaces below).
#[cfg(feature = "cuda")]
let cuda_grad_workspace = if classify_config.quantize_nf4 {
None
} else {
cuda_trainer.as_ref().and_then(|t| {
CudaGradWorkspace::new(t.context(), model_config)
.map_err(|e| eprintln!("[CUDA] Failed to allocate grad workspace: {e}"))
.ok()
})
};
#[cfg(feature = "cuda")]
let (cuda_lora_grad_workspace, cuda_lora_optimizer_states, cuda_lora_grad_accum) =
if classify_config.quantize_nf4 {
Self::try_init_nf4_lora_training(
cuda_trainer.as_ref(),
cuda_blocks.as_ref(),
model_config,
&classify_config,
)
} else {
(None, None, None)
};
// wgpu only when CUDA is absent: resident weights → per-call upload → CPU.
#[cfg(feature = "gpu")]
let wgpu_forward_pass = {
#[cfg(feature = "cuda")]
let has_cuda = cuda_trainer.is_some();
#[cfg(not(feature = "cuda"))]
let has_cuda = false;
if has_cuda {
None
} else {
crate::transformer::WgpuForwardPass::with_resident_weights(&model)
.or_else(|e| {
eprintln!("[wgpu] Resident init failed: {e}, trying default");
crate::transformer::WgpuForwardPass::new_default(model_config)
})
.map_err(|e| eprintln!("[wgpu] GPU init failed: {e}"))
.ok()
}
};
Ok(Self {
model,
classifier,
lora_layers,
config: classify_config,
optimizer,
tokenizer,
// NOTE(review): this stores the .apr *file* path, while from_pretrained
// stores a directory — confirm downstream consumers of model_dir handle both.
model_dir: Some(apr_path.to_path_buf()),
#[cfg(feature = "cuda")]
cuda_trainer,
#[cfg(feature = "cuda")]
cuda_blocks,
#[cfg(feature = "cuda")]
shared_scratch,
#[cfg(feature = "cuda")]
cuda_nan_count: 0,
#[cfg(feature = "cuda")]
gpu_training,
#[cfg(feature = "cuda")]
cuda_grad_workspace,
#[cfg(feature = "cuda")]
cuda_lora_grad_workspace,
#[cfg(feature = "cuda")]
cuda_lora_optimizer_states,
#[cfg(feature = "cuda")]
cuda_lora_grad_accum,
#[cfg(feature = "cuda")]
nf4_lora_step: 0,
#[cfg(feature = "gpu")]
wgpu_forward_pass,
#[cfg(feature = "cuda")]
vram_guard,
})
}
/// Encode `text` into token ids: BPE when a tokenizer is loaded, otherwise
/// a byte-level fallback. The result is truncated to `config.max_seq_len`
/// and never empty (a single 0 token is emitted for empty input).
pub(crate) fn tokenize(&self, text: &str) -> Vec<u32> {
    let mut ids = self.tokenizer.as_ref().map_or_else(
        // No tokenizer: each raw byte becomes one token id.
        || text.bytes().map(u32::from).collect(),
        |tok| tok.encode(text),
    );
    ids.truncate(self.config.max_seq_len);
    if ids.is_empty() {
        ids.push(0);
    }
    ids
}
/// Tokenize a whole corpus up front so the training loop never re-encodes.
///
/// With a loaded tokenizer this defers to `tokenize`; without one it uses
/// each sample's precomputed byte-level ids, applying the same truncation
/// and never-empty padding rules.
pub fn pre_tokenize(&self, samples: &[SafetySample]) -> Vec<TokenizedSample> {
    let use_bpe = self.tokenizer.is_some();
    let mut tokenized = Vec::with_capacity(samples.len());
    for sample in samples {
        let token_ids = if use_bpe {
            self.tokenize(&sample.input)
        } else {
            // Byte-level fallback, mirroring `tokenize`'s truncate/pad behavior.
            let mut ids = sample.input_ids();
            ids.truncate(self.config.max_seq_len);
            if ids.is_empty() {
                ids.push(0);
            }
            ids
        };
        tokenized.push(TokenizedSample { token_ids, label: sample.label });
    }
    tokenized
}
/// Run one full optimization step over `samples`.
///
/// Order: zero grads → forward/backward (batched wgpu path when available,
/// else per-sample CPU) → average grads over the batch → optional clipping
/// → GPU optimizer steps (full-precision or NF4 LoRA, CUDA only) → CPU
/// AdamW on LoRA params (skipped under NF4) and the classification head.
///
/// Returns averaged loss, correct-prediction count, and the gradient norm
/// as reported by the clipper (or computed directly when clipping is off).
pub fn train_batch_tokenized(&mut self, samples: &[TokenizedSample]) -> BatchResult {
if samples.is_empty() {
return BatchResult { avg_loss: 0.0, correct: 0, total: 0, grad_norm: 0.0 };
}
let batch_size = samples.len();
// Fresh gradients every call — contrast `accumulate_gradients_tokenized`,
// which deliberately leaves them accumulating across micro-batches.
self.zero_all_gradients();
#[cfg(feature = "gpu")]
let (total_loss, correct) = self
.try_train_batch_wgpu_tokenized(samples)
.unwrap_or_else(|| self.train_batch_per_sample_tokenized(samples));
#[cfg(not(feature = "gpu"))]
let (total_loss, correct) = self.train_batch_per_sample_tokenized(samples);
// Per-sample grads were summed; average them over the batch.
self.scale_all_gradients(1.0 / batch_size as f32);
let grad_norm = if let Some(max_norm) = self.config.gradient_clip_norm {
let mut params = self.trainable_parameters_mut();
clip_grad_norm_refs(&mut params, max_norm)
} else {
self.compute_grad_norm()
};
#[cfg(feature = "cuda")]
{
if self.gpu_training.is_some() && !self.config.quantize_nf4 {
let lr = self.optimizer.lr();
self.gpu_optimizer_step(lr);
}
}
#[cfg(feature = "cuda")]
{
if self.gpu_training.is_some() && self.config.quantize_nf4 {
self.nf4_lora_batch_optimizer_step(batch_size);
}
}
// CPU AdamW: LoRA params are stepped here only in the non-NF4 path (the
// NF4 LoRA update ran on the GPU above); the head is always stepped here.
let mut params: Vec<&mut Tensor> = Vec::new();
if !self.config.quantize_nf4 {
for lora in &mut self.lora_layers {
params.extend(lora.trainable_params());
}
}
params.extend(self.classifier.parameters_mut());
self.optimizer.step_refs(&mut params);
BatchResult {
avg_loss: total_loss / batch_size as f32,
correct,
total: batch_size,
grad_norm,
}
}
/// Sequential CPU path: forward + backward one sample at a time, letting
/// gradients accumulate in place. Returns (summed loss, correct count).
fn train_batch_per_sample_tokenized(&mut self, samples: &[TokenizedSample]) -> (f32, usize) {
    samples.iter().fold((0.0f32, 0usize), |(loss_sum, num_correct), sample| {
        let (loss, predicted) = self.forward_backward_single(&sample.token_ids, sample.label);
        (
            loss_sum + loss,
            num_correct + usize::from(predicted == sample.label),
        )
    })
}
/// Batched forward pass on the wgpu device, with the classification
/// loss/backward computed on the CPU per sample.
///
/// Returns `None` — signalling the caller to fall back to the per-sample
/// CPU path — when no wgpu forward pass is initialized or the batched
/// forward itself fails.
#[cfg(feature = "gpu")]
fn try_train_batch_wgpu_tokenized(
&mut self,
samples: &[TokenizedSample],
) -> Option<(f32, usize)> {
// Early-out when wgpu is unavailable; the pass is re-borrowed below
// rather than bound here (presumably to keep borrows short — confirm).
self.wgpu_forward_pass.as_ref()?;
let batch_token_ids: Vec<Vec<u32>> = samples.iter().map(|s| s.token_ids.clone()).collect();
let lora_ref =
if self.lora_layers.is_empty() { None } else { Some(self.lora_layers.as_slice()) };
let hiddens = self
.wgpu_forward_pass
.as_ref()
.expect("checked is_none above")
.forward_hidden_batch(&self.model, &batch_token_ids, lora_ref)
.map_err(|e| {
eprintln!("[wgpu] Batched forward failed, falling back to per-sample: {e}");
})
.ok()?;
// Head loss + backward per sample on the CPU, over GPU-produced hiddens.
let mut total_loss = 0.0f32;
let mut correct = 0usize;
for (i, hidden) in hiddens.iter().enumerate() {
let (loss, predicted) = self.classify_backward_from_hidden(
hidden,
batch_token_ids[i].len(),
samples[i].label,
);
total_loss += loss;
if predicted == samples[i].label {
correct += 1;
}
}
Some((total_loss, correct))
}
/// Run forward + backward over a micro-batch WITHOUT zeroing or applying
/// gradients, so several micro-batches can accumulate before a real
/// optimizer step (`train_batch_tokenized` is the full step).
///
/// Returns per-micro-batch loss/accuracy; `grad_norm` is reported as 0.0
/// because clipping/norm computation happens only at the optimizer step.
pub fn accumulate_gradients_tokenized(
    &mut self,
    micro_batch: &[TokenizedSample],
) -> BatchResult {
    if micro_batch.is_empty() {
        return BatchResult { avg_loss: 0.0, correct: 0, total: 0, grad_norm: 0.0 };
    }
    // Reuse the shared per-sample forward/backward loop instead of
    // duplicating it here (keeps the two paths from drifting apart).
    let (total_loss, correct) = self.train_batch_per_sample_tokenized(micro_batch);
    BatchResult {
        avg_loss: total_loss / micro_batch.len() as f32,
        correct,
        total: micro_batch.len(),
        grad_norm: 0.0,
    }
}
/// Inference-only forward pass: returns `(loss, predicted_class)` without
/// accumulating gradients. Thin wrapper over `forward_only`.
pub fn forward_only_tokenized(&mut self, token_ids: &[u32], label: usize) -> (f32, usize) {
self.forward_only(token_ids, label)
}
}
include!("gpu.rs");
include!("training.rs");
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests;