1use super::classification::{
18 load_multi_label_corpus, load_safety_corpus, ClassificationHead, MultiLabelSafetySample,
19 SafetySample, TokenizedSample,
20};
21use crate::autograd::matmul;
22use crate::lora::LoRAConfig;
23use crate::lora::LoRALayer;
24use crate::optim::{clip_grad_norm_refs, AdamW, Optimizer};
25use crate::tokenizer::HfTokenizer;
26use crate::transformer::Transformer;
27use crate::transformer::TransformerConfig;
28use crate::Tensor;
29use std::path::{Path, PathBuf};
30
31#[cfg(feature = "cuda")]
32use crate::autograd::cuda_backward::pre_warm_lora_backward_kernels as pre_warm_backward_cache_kernels;
33#[cfg(feature = "cuda")]
34use crate::autograd::cuda_forward::{pre_warm_forward_kernels, pre_warm_lora_backward_kernels};
35#[cfg(feature = "cuda")]
36use crate::autograd::cuda_optim::pre_warm_lora_adamw_kernels;
37#[cfg(feature = "cuda")]
38use crate::autograd::cuda_training::{cuda_training_available, CudaTrainer};
39#[cfg(feature = "cuda")]
40use crate::gpu::guard::VramGuard;
41#[cfg(feature = "cuda")]
42use crate::transformer::{
43 CudaBlock, CudaBlockScratch, CudaGradWorkspace, CudaLoraGradWorkspace, CudaTransformerBlock,
44 GpuBlockOptimizerState, GpuLoraOptimizerState,
45};
46#[cfg(feature = "cuda")]
47use std::sync::Arc;
48#[cfg(feature = "cuda")]
49use trueno_gpu::driver::GpuBuffer;
50
/// Hyperparameters and feature switches for the LoRA classification pipeline.
#[derive(Debug, Clone)]
pub struct ClassifyConfig {
    /// Number of output classes of the classification head.
    pub num_classes: usize,
    /// Rank of the LoRA adapters.
    pub lora_rank: usize,
    /// LoRA `alpha` scaling value (the validator recommends `2 × lora_rank`).
    pub lora_alpha: f32,
    /// Learning rate handed to the AdamW optimizer.
    pub learning_rate: f32,
    /// Number of training epochs.
    pub epochs: usize,
    /// Maximum number of token ids kept per sample; longer inputs are truncated.
    pub max_seq_len: usize,
    /// Logging interval — presumably in optimizer steps; confirm against the training loop.
    pub log_interval: usize,
    /// Number of samples per (micro-)batch.
    pub batch_size: usize,
    /// Micro-batches accumulated per optimizer step; the effective batch is
    /// `batch_size × accumulation_steps`.
    pub accumulation_steps: usize,
    /// Maximum gradient norm used for clipping; `None` disables clipping.
    pub gradient_clip_norm: Option<f32>,
    /// Optional per-class weights — NOTE(review): the consumer is not visible in this
    /// part of the file; confirm how they enter the loss.
    pub class_weights: Option<Vec<f32>>,
    /// Selects the NF4-quantized (QLoRA) training path.
    pub quantize_nf4: bool,
}
97
98impl Default for ClassifyConfig {
99 fn default() -> Self {
100 Self {
101 num_classes: 5,
102 lora_rank: 16,
103 lora_alpha: 16.0,
104 learning_rate: 1e-4,
105 epochs: 3,
106 max_seq_len: 512,
107 log_interval: 100,
108 batch_size: 32,
109 accumulation_steps: 1,
110 gradient_clip_norm: Some(1.0),
111 class_weights: None,
112 quantize_nf4: false,
113 }
114 }
115}
116
117#[derive(Debug, Clone)]
121pub struct HyperparamDiagnostic {
122 pub contract_id: &'static str,
123 pub severity: DiagSeverity,
124 pub message: String,
125 pub recommendation: String,
126}
127
/// Severity level attached to a [`HyperparamDiagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagSeverity {
    /// Informational note; printed with the `[HP-INFO]` prefix.
    Info,
    /// Questionable setting; printed with the `[HP-WARN]` prefix.
    Warn,
    /// Invalid setting; printed with the `[HP-ERROR]` prefix.
    Error,
}
138
139#[derive(Debug, Clone, Default)]
141pub struct HyperparamDiagnostics {
142 pub items: Vec<HyperparamDiagnostic>,
143}
144
145impl HyperparamDiagnostics {
146 pub fn has_warning(&self, contract_id: &str) -> bool {
148 self.items.iter().any(|d| {
149 d.contract_id == contract_id
150 && matches!(d.severity, DiagSeverity::Warn | DiagSeverity::Error)
151 })
152 }
153
154 pub fn has_errors(&self) -> bool {
156 self.items.iter().any(|d| matches!(d.severity, DiagSeverity::Error))
157 }
158
159 pub fn print_all(&self) {
161 for d in &self.items {
162 let prefix = match d.severity {
163 DiagSeverity::Info => "[HP-INFO]",
164 DiagSeverity::Warn => "[HP-WARN]",
165 DiagSeverity::Error => "[HP-ERROR]",
166 };
167 eprintln!("{prefix} {}: {} → {}", d.contract_id, d.message, d.recommendation);
168 }
169 }
170}
171
/// Summary statistics of a training corpus, consumed by
/// `ClassifyConfig::validate_with_data`.
///
/// Derives `Debug` and `Clone` like the other public data types in this file so
/// it can be logged and copied by callers.
#[derive(Debug, Clone)]
pub struct DataStats {
    /// 99th-percentile sample length in tokens; `0` disables the sequence-length check.
    pub p99_token_length: usize,
    /// Class imbalance expressed as an `N:1` ratio (values above 5.0 trigger the
    /// epoch-count check).
    pub imbalance_ratio: f32,
    /// Number of samples in the minority class.
    pub minority_count: usize,
}
183
184impl ClassifyConfig {
185 pub fn qlora_default(model_params: u64) -> Self {
200 let learning_rate = if model_params <= 13_000_000_000 { 2e-4 } else { 1e-4 };
202 let lora_rank = 16;
203 Self {
204 num_classes: 2,
205 lora_rank,
206 lora_alpha: (2 * lora_rank) as f32,
208 learning_rate,
209 epochs: 3,
211 max_seq_len: 256,
213 log_interval: 100,
214 batch_size: 16,
217 accumulation_steps: 1,
218 gradient_clip_norm: Some(1.0),
220 class_weights: None,
221 quantize_nf4: true,
222 }
223 }
224
225 pub fn validate_hyperparameters(&self, model_params: u64) -> HyperparamDiagnostics {
232 let mut diags = HyperparamDiagnostics::default();
233
234 if self.quantize_nf4 && model_params <= 13_000_000_000 && self.learning_rate < 1.5e-4 {
236 diags.items.push(HyperparamDiagnostic {
237 contract_id: "C-HP-001",
238 severity: DiagSeverity::Warn,
239 message: format!(
240 "lr={:.0e} too low for {}B model (Dettmers 2023: use 2e-4 for ≤13B)",
241 self.learning_rate,
242 model_params / 1_000_000_000
243 ),
244 recommendation: "learning_rate: 0.0002".to_string(),
245 });
246 }
247
248 let eff_batch = self.batch_size * self.accumulation_steps;
250 if eff_batch != 16 {
251 diags.items.push(HyperparamDiagnostic {
252 contract_id: "C-HP-002",
253 severity: DiagSeverity::Warn,
254 message: format!(
255 "effective_batch={eff_batch} ({}×{}), Dettmers 2023 recommends 16 for ≤13B",
256 self.batch_size, self.accumulation_steps
257 ),
258 recommendation: format!(
259 "batch_size: {}, accumulation_steps: {}",
260 self.batch_size,
261 16 / self.batch_size.max(1)
262 ),
263 });
264 }
265
266 let expected_alpha = 2.0 * self.lora_rank as f32;
268 if (self.lora_alpha - expected_alpha).abs() > 0.5 {
269 diags.items.push(HyperparamDiagnostic {
270 contract_id: "C-HP-003",
271 severity: DiagSeverity::Warn,
272 message: format!(
273 "lora_alpha={} with rank={} (ratio={:.1}), Lightning AI: alpha=2×rank={} optimal",
274 self.lora_alpha, self.lora_rank,
275 self.lora_alpha / self.lora_rank as f32,
276 expected_alpha
277 ),
278 recommendation: format!("lora_alpha: {expected_alpha}"),
279 });
280 }
281
282 if self.gradient_clip_norm.is_none() {
284 diags.items.push(HyperparamDiagnostic {
285 contract_id: "C-HP-006",
286 severity: DiagSeverity::Warn,
287 message: "No gradient clipping — SSC v2.2 saw grad norms up to 115.1".to_string(),
288 recommendation: "gradient_clip_norm: 1.0".to_string(),
289 });
290 }
291
292 if self.learning_rate <= 0.0 {
294 diags.items.push(HyperparamDiagnostic {
295 contract_id: "C-HP-001",
296 severity: DiagSeverity::Error,
297 message: "learning_rate must be > 0".to_string(),
298 recommendation: "learning_rate: 0.0002".to_string(),
299 });
300 }
301 if self.batch_size == 0 {
302 diags.items.push(HyperparamDiagnostic {
303 contract_id: "C-HP-002",
304 severity: DiagSeverity::Error,
305 message: "batch_size must be > 0".to_string(),
306 recommendation: "batch_size: 4".to_string(),
307 });
308 }
309
310 diags
311 }
312
313 pub fn validate_with_data(&self, stats: &DataStats) -> HyperparamDiagnostics {
319 let mut diags = HyperparamDiagnostics::default();
320
321 if self.max_seq_len > 2 * stats.p99_token_length && stats.p99_token_length > 0 {
323 diags.items.push(HyperparamDiagnostic {
324 contract_id: "C-HP-004",
325 severity: DiagSeverity::Warn,
326 message: format!(
327 "max_seq_len={} but p99(tokens)={} — attention is O(n²), wasting {:.0}× compute",
328 self.max_seq_len,
329 stats.p99_token_length,
330 (self.max_seq_len as f64 / stats.p99_token_length as f64).powi(2)
331 ),
332 recommendation: format!(
333 "max_seq_len: {} (next_pow2 of p99)",
334 stats.p99_token_length.next_power_of_two()
335 ),
336 });
337 }
338
339 if stats.imbalance_ratio > 5.0 && self.epochs < 2 {
341 let eff_batch = self.batch_size * self.accumulation_steps;
342 let updates_per_epoch = stats.minority_count / eff_batch.max(1);
343 diags.items.push(HyperparamDiagnostic {
344 contract_id: "C-HP-008",
345 severity: DiagSeverity::Warn,
346 message: format!(
347 "epochs={} with {:.1}:1 imbalance — minority gets only {} gradient updates",
348 self.epochs,
349 stats.imbalance_ratio,
350 updates_per_epoch * self.epochs
351 ),
352 recommendation: format!(
353 "epochs: 3 (minority gets {} updates)",
354 updates_per_epoch * 3
355 ),
356 });
357 }
358
359 diags
360 }
361}
362
/// Outcome of processing one batch (or micro-batch) of samples.
#[derive(Debug, Clone)]
pub struct BatchResult {
    /// Mean loss over the samples in the batch.
    pub avg_loss: f32,
    /// Number of samples whose predicted class matched the label.
    pub correct: usize,
    /// Number of samples in the batch.
    pub total: usize,
    /// Gradient norm reported for the step; `0.0` when no step was taken
    /// (empty batch or accumulation-only call).
    pub grad_norm: f32,
}
375
376impl BatchResult {
377 #[must_use]
381 pub fn accuracy(&self) -> f32 {
382 contract_pre_accuracy!();
383 self.correct as f32 / self.total.max(1) as f32
384 }
385}
386
/// Device buffers and optimizer state kept alive across CUDA training steps.
///
/// NOTE(review): the code that reads and writes these fields is not in this
/// part of the file (see the included `gpu.rs`); the per-field notes below are
/// inferred from the field names and types and should be confirmed there.
#[cfg(feature = "cuda")]
struct GpuTrainingState {
    /// One device buffer per entry — presumably the saved input of each layer.
    layer_inputs: Vec<GpuBuffer<f32>>,
    /// Device copy of the final normalization weight (by name).
    final_norm_weight: GpuBuffer<f32>,
    /// Device buffer for the output of the transformer blocks (by name).
    blocks_output: GpuBuffer<f32>,
    /// First of two gradient buffers — presumably used in alternation; confirm.
    grad_buf_a: GpuBuffer<f32>,
    /// Second of two gradient buffers.
    grad_buf_b: GpuBuffer<f32>,
    /// Gradient buffer for the final normalization weight (by name).
    grad_final_norm_weight: GpuBuffer<f32>,
    /// Optimizer state, one entry per block (by type name).
    optimizer_states: Vec<GpuBlockOptimizerState>,
    /// Optimizer step counter.
    step: u32,
    /// Scratch device buffer for outputs (by name).
    output_scratch: GpuBuffer<f32>,
    /// Device buffer for uploading gradients (by name).
    grad_upload_buf: GpuBuffer<f32>,
    /// First forward-pass scratch buffer.
    fwd_scratch_a: GpuBuffer<f32>,
    /// Second forward-pass scratch buffer.
    fwd_scratch_b: GpuBuffer<f32>,
    /// Host-side staging vector used by the backward pass (by name).
    backward_cpu_staging: Vec<f32>,
}
430
431pub struct ClassifyPipeline {
441 pub model: Transformer,
443 pub classifier: ClassificationHead,
445 pub lora_layers: Vec<LoRALayer>,
447 pub config: ClassifyConfig,
449 optimizer: AdamW,
451 tokenizer: Option<HfTokenizer>,
453 model_dir: Option<PathBuf>,
455 #[cfg(feature = "cuda")]
457 cuda_trainer: Option<CudaTrainer>,
458 #[cfg(feature = "cuda")]
460 cuda_blocks: Option<Vec<CudaBlock>>,
461 #[cfg(feature = "cuda")]
464 shared_scratch: Option<CudaBlockScratch>,
465 #[cfg(feature = "cuda")]
468 cuda_nan_count: usize,
469 #[cfg(feature = "cuda")]
472 gpu_training: Option<GpuTrainingState>,
473 #[cfg(feature = "cuda")]
476 cuda_grad_workspace: Option<CudaGradWorkspace>,
477 #[cfg(feature = "cuda")]
480 cuda_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
481 #[cfg(feature = "cuda")]
483 cuda_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
484 #[cfg(feature = "cuda")]
486 cuda_lora_grad_accum: Option<Vec<CudaLoraGradWorkspace>>,
487 #[cfg(feature = "cuda")]
489 nf4_lora_step: u32,
490 #[cfg(feature = "gpu")]
492 wgpu_forward_pass: Option<crate::transformer::WgpuForwardPass>,
493 #[cfg(feature = "cuda")]
496 #[allow(dead_code)]
497 vram_guard: Option<VramGuard>,
498}
499
500impl ClassifyPipeline {
501 pub fn new(model_config: &TransformerConfig, classify_config: ClassifyConfig) -> Self {
507 let model = Transformer::new(model_config);
508 let classifier =
509 ClassificationHead::new(model_config.hidden_size, classify_config.num_classes);
510 let mut lora_layers = Self::build_lora_layers(&model, model_config, &classify_config);
511
512 for lora in &mut lora_layers {
514 for param in lora.trainable_params() {
515 param.set_requires_grad(true);
516 }
517 }
518
519 let optimizer = AdamW::default_params(classify_config.learning_rate);
520
521 #[cfg(feature = "cuda")]
523 let (cuda_trainer, cuda_blocks, shared_scratch, vram_guard) =
524 Self::try_init_cuda(&model, model_config, &classify_config, &lora_layers);
525
526 #[cfg(feature = "cuda")]
528 let gpu_training = Self::try_init_gpu_training(
529 &model,
530 model_config,
531 classify_config.max_seq_len,
532 cuda_trainer.as_ref(),
533 cuda_blocks.as_ref(),
534 );
535
536 #[cfg(feature = "cuda")]
538 let cuda_grad_workspace = if classify_config.quantize_nf4 {
539 None
540 } else {
541 cuda_trainer.as_ref().and_then(|t| {
542 CudaGradWorkspace::new(t.context(), model_config)
543 .map_err(|e| eprintln!("[CUDA] Failed to allocate grad workspace: {e}"))
544 .ok()
545 })
546 };
547
548 #[cfg(feature = "cuda")]
550 let (cuda_lora_grad_workspace, cuda_lora_optimizer_states, cuda_lora_grad_accum) =
551 if classify_config.quantize_nf4 {
552 Self::try_init_nf4_lora_training(
553 cuda_trainer.as_ref(),
554 cuda_blocks.as_ref(),
555 model_config,
556 &classify_config,
557 )
558 } else {
559 (None, None, None)
560 };
561
562 #[cfg(feature = "gpu")]
564 let wgpu_forward_pass = {
565 #[cfg(feature = "cuda")]
566 let has_cuda = cuda_trainer.is_some();
567 #[cfg(not(feature = "cuda"))]
568 let has_cuda = false;
569
570 if has_cuda {
571 None } else {
573 match crate::transformer::WgpuForwardPass::with_resident_weights(&model) {
575 Ok(pass) => {
576 eprintln!("[wgpu] GPU forward pass initialized (resident weights)");
577 Some(pass)
578 }
579 Err(e) => {
580 eprintln!("[wgpu] GPU resident init failed, trying default: {e}");
581 match crate::transformer::WgpuForwardPass::new_default(model_config) {
582 Ok(pass) => {
583 eprintln!("[wgpu] GPU forward pass initialized (upload per call)");
584 Some(pass)
585 }
586 Err(e2) => {
587 eprintln!("[wgpu] GPU initialization failed, using CPU: {e2}");
588 None
589 }
590 }
591 }
592 }
593 }
594 };
595
596 Self {
597 model,
598 classifier,
599 lora_layers,
600 config: classify_config,
601 optimizer,
602 tokenizer: None,
603 model_dir: None,
604 #[cfg(feature = "cuda")]
605 cuda_trainer,
606 #[cfg(feature = "cuda")]
607 cuda_blocks,
608 #[cfg(feature = "cuda")]
609 shared_scratch,
610 #[cfg(feature = "cuda")]
611 cuda_nan_count: 0,
612 #[cfg(feature = "cuda")]
613 gpu_training,
614 #[cfg(feature = "cuda")]
615 cuda_grad_workspace,
616 #[cfg(feature = "cuda")]
617 cuda_lora_grad_workspace,
618 #[cfg(feature = "cuda")]
619 cuda_lora_optimizer_states,
620 #[cfg(feature = "cuda")]
621 cuda_lora_grad_accum,
622 #[cfg(feature = "cuda")]
623 nf4_lora_step: 0,
624 #[cfg(feature = "gpu")]
625 wgpu_forward_pass,
626 #[cfg(feature = "cuda")]
627 vram_guard,
628 }
629 }
630
631 pub fn from_pretrained(
644 model_dir: impl AsRef<Path>,
645 model_config: &TransformerConfig,
646 classify_config: ClassifyConfig,
647 ) -> crate::Result<Self> {
648 let model_dir = model_dir.as_ref();
649
650 let model = Transformer::from_safetensors(model_dir, model_config)?;
651 let classifier =
652 ClassificationHead::new(model_config.hidden_size, classify_config.num_classes);
653 let mut lora_layers = Self::build_lora_layers(&model, model_config, &classify_config);
654
655 for lora in &mut lora_layers {
656 for param in lora.trainable_params() {
657 param.set_requires_grad(true);
658 }
659 }
660
661 let tokenizer_path = model_dir.join("tokenizer.json");
663 let tokenizer = if tokenizer_path.exists() {
664 Some(
665 HfTokenizer::from_file(&tokenizer_path)
666 .map_err(|e| crate::Error::Io(format!("Failed to load tokenizer: {e}")))?,
667 )
668 } else {
669 return Err(crate::Error::ConfigError(format!(
670 "No tokenizer.json found in '{}'. Training requires a BPE tokenizer.",
671 model_dir.display(),
672 )));
673 };
674
675 let optimizer = AdamW::default_params(classify_config.learning_rate);
676
677 #[cfg(feature = "cuda")]
679 let (cuda_trainer, cuda_blocks, shared_scratch, vram_guard) =
680 Self::try_init_cuda(&model, model_config, &classify_config, &lora_layers);
681
682 #[cfg(feature = "cuda")]
684 let gpu_training = Self::try_init_gpu_training(
685 &model,
686 model_config,
687 classify_config.max_seq_len,
688 cuda_trainer.as_ref(),
689 cuda_blocks.as_ref(),
690 );
691
692 #[cfg(feature = "cuda")]
694 let cuda_grad_workspace = if classify_config.quantize_nf4 {
695 None } else {
697 cuda_trainer.as_ref().and_then(|t| {
698 CudaGradWorkspace::new(t.context(), model_config)
699 .map_err(|e| eprintln!("[CUDA] Failed to allocate grad workspace: {e}"))
700 .ok()
701 })
702 };
703
704 #[cfg(feature = "cuda")]
706 let (cuda_lora_grad_workspace, cuda_lora_optimizer_states, cuda_lora_grad_accum) =
707 if classify_config.quantize_nf4 {
708 Self::try_init_nf4_lora_training(
709 cuda_trainer.as_ref(),
710 cuda_blocks.as_ref(),
711 model_config,
712 &classify_config,
713 )
714 } else {
715 (None, None, None)
716 };
717
718 #[cfg(feature = "gpu")]
720 let wgpu_forward_pass = {
721 #[cfg(feature = "cuda")]
722 let has_cuda = cuda_trainer.is_some();
723 #[cfg(not(feature = "cuda"))]
724 let has_cuda = false;
725
726 if has_cuda {
727 None
728 } else {
729 match crate::transformer::WgpuForwardPass::with_resident_weights(&model) {
731 Ok(pass) => {
732 eprintln!(
733 "[wgpu] Batched forward pass initialized ({} layers, resident weights)",
734 model_config.num_hidden_layers
735 );
736 Some(pass)
737 }
738 Err(e) => {
739 eprintln!("[wgpu] Resident init failed, trying default: {e}");
740 match crate::transformer::WgpuForwardPass::new_default(model_config) {
741 Ok(pass) => {
742 eprintln!("[wgpu] Batched forward pass initialized ({} layers, upload per call)", model_config.num_hidden_layers);
743 Some(pass)
744 }
745 Err(e2) => {
746 eprintln!("[wgpu] GPU init failed, using CPU: {e2}");
747 None
748 }
749 }
750 }
751 }
752 }
753 };
754
755 Ok(Self {
756 model,
757 classifier,
758 lora_layers,
759 config: classify_config,
760 optimizer,
761 tokenizer,
762 model_dir: Some(model_dir.to_path_buf()),
763 #[cfg(feature = "cuda")]
764 cuda_trainer,
765 #[cfg(feature = "cuda")]
766 cuda_blocks,
767 #[cfg(feature = "cuda")]
768 shared_scratch,
769 #[cfg(feature = "cuda")]
770 cuda_nan_count: 0,
771 #[cfg(feature = "cuda")]
772 gpu_training,
773 #[cfg(feature = "cuda")]
774 cuda_grad_workspace,
775 #[cfg(feature = "cuda")]
776 cuda_lora_grad_workspace,
777 #[cfg(feature = "cuda")]
778 cuda_lora_optimizer_states,
779 #[cfg(feature = "cuda")]
780 cuda_lora_grad_accum,
781 #[cfg(feature = "cuda")]
782 nf4_lora_step: 0,
783 #[cfg(feature = "gpu")]
784 wgpu_forward_pass,
785 #[cfg(feature = "cuda")]
786 vram_guard,
787 })
788 }
789
790 pub fn from_apr(
799 apr_path: &Path,
800 model_config: &TransformerConfig,
801 classify_config: ClassifyConfig,
802 ) -> crate::Result<Self> {
803 let model = Transformer::from_apr(apr_path, model_config)?;
804 let classifier =
805 ClassificationHead::new(model_config.hidden_size, classify_config.num_classes);
806 let mut lora_layers = Self::build_lora_layers(&model, model_config, &classify_config);
807
808 for lora in &mut lora_layers {
809 for param in lora.trainable_params() {
810 param.set_requires_grad(true);
811 }
812 }
813
814 let tokenizer = {
816 let sibling = apr_path.file_stem().and_then(|stem| {
817 apr_path
818 .parent()
819 .map(|p| p.join(format!("{}.tokenizer.json", stem.to_str().unwrap_or(""))))
820 });
821
822 match sibling {
823 Some(ref path) if path.exists() => {
824 let tok = HfTokenizer::from_file(path).map_err(|e| {
825 crate::Error::ConfigError(format!(
826 "Failed to load tokenizer from '{}': {e}. \
827 Training requires a BPE tokenizer.",
828 path.display(),
829 ))
830 })?;
831 Some(tok)
832 }
833 _ => {
834 return Err(crate::Error::ConfigError(format!(
835 "No sibling tokenizer found for '{}'. Expected \
836 '{}.tokenizer.json' next to the .apr file. Training \
837 requires a BPE tokenizer.",
838 apr_path.display(),
839 apr_path.file_stem().unwrap_or_default().to_str().unwrap_or(""),
840 )));
841 }
842 }
843 };
844
845 let optimizer = AdamW::default_params(classify_config.learning_rate);
846
847 #[cfg(feature = "cuda")]
848 let (cuda_trainer, cuda_blocks, shared_scratch, vram_guard) =
849 Self::try_init_cuda(&model, model_config, &classify_config, &lora_layers);
850
851 #[cfg(feature = "cuda")]
852 let gpu_training = Self::try_init_gpu_training(
853 &model,
854 model_config,
855 classify_config.max_seq_len,
856 cuda_trainer.as_ref(),
857 cuda_blocks.as_ref(),
858 );
859
860 #[cfg(feature = "cuda")]
861 let cuda_grad_workspace = if classify_config.quantize_nf4 {
862 None
863 } else {
864 cuda_trainer.as_ref().and_then(|t| {
865 CudaGradWorkspace::new(t.context(), model_config)
866 .map_err(|e| eprintln!("[CUDA] Failed to allocate grad workspace: {e}"))
867 .ok()
868 })
869 };
870
871 #[cfg(feature = "cuda")]
872 let (cuda_lora_grad_workspace, cuda_lora_optimizer_states, cuda_lora_grad_accum) =
873 if classify_config.quantize_nf4 {
874 Self::try_init_nf4_lora_training(
875 cuda_trainer.as_ref(),
876 cuda_blocks.as_ref(),
877 model_config,
878 &classify_config,
879 )
880 } else {
881 (None, None, None)
882 };
883
884 #[cfg(feature = "gpu")]
886 let wgpu_forward_pass = {
887 #[cfg(feature = "cuda")]
888 let has_cuda = cuda_trainer.is_some();
889 #[cfg(not(feature = "cuda"))]
890 let has_cuda = false;
891
892 if has_cuda {
893 None
894 } else {
895 crate::transformer::WgpuForwardPass::with_resident_weights(&model)
897 .or_else(|e| {
898 eprintln!("[wgpu] Resident init failed: {e}, trying default");
899 crate::transformer::WgpuForwardPass::new_default(model_config)
900 })
901 .map_err(|e| eprintln!("[wgpu] GPU init failed: {e}"))
902 .ok()
903 }
904 };
905
906 Ok(Self {
907 model,
908 classifier,
909 lora_layers,
910 config: classify_config,
911 optimizer,
912 tokenizer,
913 model_dir: Some(apr_path.to_path_buf()),
914 #[cfg(feature = "cuda")]
915 cuda_trainer,
916 #[cfg(feature = "cuda")]
917 cuda_blocks,
918 #[cfg(feature = "cuda")]
919 shared_scratch,
920 #[cfg(feature = "cuda")]
921 cuda_nan_count: 0,
922 #[cfg(feature = "cuda")]
923 gpu_training,
924 #[cfg(feature = "cuda")]
925 cuda_grad_workspace,
926 #[cfg(feature = "cuda")]
927 cuda_lora_grad_workspace,
928 #[cfg(feature = "cuda")]
929 cuda_lora_optimizer_states,
930 #[cfg(feature = "cuda")]
931 cuda_lora_grad_accum,
932 #[cfg(feature = "cuda")]
933 nf4_lora_step: 0,
934 #[cfg(feature = "gpu")]
935 wgpu_forward_pass,
936 #[cfg(feature = "cuda")]
937 vram_guard,
938 })
939 }
940
941 pub(crate) fn tokenize(&self, text: &str) -> Vec<u32> {
949 let mut ids = match self.tokenizer.as_ref() {
950 Some(tok) => tok.encode(text),
951 None => {
952 text.bytes().map(u32::from).collect()
954 }
955 };
956 ids.truncate(self.config.max_seq_len);
957 if ids.is_empty() {
958 ids.push(0);
959 }
960 ids
961 }
962
963 pub fn pre_tokenize(&self, samples: &[SafetySample]) -> Vec<TokenizedSample> {
974 let has_tokenizer = self.tokenizer.is_some();
975 samples
976 .iter()
977 .map(|s| {
978 let token_ids = if has_tokenizer {
979 self.tokenize(&s.input)
980 } else {
981 let mut ids = s.input_ids();
983 ids.truncate(self.config.max_seq_len);
984 if ids.is_empty() {
985 ids.push(0);
986 }
987 ids
988 };
989 TokenizedSample { token_ids, label: s.label }
990 })
991 .collect()
992 }
993
994 pub fn train_batch_tokenized(&mut self, samples: &[TokenizedSample]) -> BatchResult {
999 if samples.is_empty() {
1000 return BatchResult { avg_loss: 0.0, correct: 0, total: 0, grad_norm: 0.0 };
1001 }
1002
1003 let batch_size = samples.len();
1004
1005 self.zero_all_gradients();
1007
1008 #[cfg(feature = "gpu")]
1010 let (total_loss, correct) = self
1011 .try_train_batch_wgpu_tokenized(samples)
1012 .unwrap_or_else(|| self.train_batch_per_sample_tokenized(samples));
1013
1014 #[cfg(not(feature = "gpu"))]
1015 let (total_loss, correct) = self.train_batch_per_sample_tokenized(samples);
1016
1017 self.scale_all_gradients(1.0 / batch_size as f32);
1019
1020 let grad_norm = if let Some(max_norm) = self.config.gradient_clip_norm {
1022 let mut params = self.trainable_parameters_mut();
1023 clip_grad_norm_refs(&mut params, max_norm)
1024 } else {
1025 self.compute_grad_norm()
1026 };
1027
1028 #[cfg(feature = "cuda")]
1030 {
1031 if self.gpu_training.is_some() && !self.config.quantize_nf4 {
1032 let lr = self.optimizer.lr();
1033 self.gpu_optimizer_step(lr);
1034 }
1035 }
1036
1037 #[cfg(feature = "cuda")]
1038 {
1039 if self.gpu_training.is_some() && self.config.quantize_nf4 {
1040 self.nf4_lora_batch_optimizer_step(batch_size);
1041 }
1042 }
1043
1044 let mut params: Vec<&mut Tensor> = Vec::new();
1045 if !self.config.quantize_nf4 {
1046 for lora in &mut self.lora_layers {
1047 params.extend(lora.trainable_params());
1048 }
1049 }
1050 params.extend(self.classifier.parameters_mut());
1051 self.optimizer.step_refs(&mut params);
1052
1053 BatchResult {
1054 avg_loss: total_loss / batch_size as f32,
1055 correct,
1056 total: batch_size,
1057 grad_norm,
1058 }
1059 }
1060
1061 fn train_batch_per_sample_tokenized(&mut self, samples: &[TokenizedSample]) -> (f32, usize) {
1063 let mut total_loss = 0.0f32;
1064 let mut correct = 0usize;
1065 for sample in samples {
1066 let (loss, predicted) = self.forward_backward_single(&sample.token_ids, sample.label);
1067 total_loss += loss;
1068 if predicted == sample.label {
1069 correct += 1;
1070 }
1071 }
1072 (total_loss, correct)
1073 }
1074
/// Attempts the batched wgpu forward pass followed by per-sample classification
/// and backward.
///
/// Returns `Some((summed loss, correct predictions))` on success. Returns
/// `None` — signalling the caller to use the per-sample path — when no wgpu
/// forward pass is configured, when the batched forward fails, or when the
/// backend returns a number of hidden states different from the number of
/// samples. In every `None` case no backward pass has run yet, so gradients
/// are untouched.
#[cfg(feature = "gpu")]
fn try_train_batch_wgpu_tokenized(
    &mut self,
    samples: &[TokenizedSample],
) -> Option<(f32, usize)> {
    let pass = self.wgpu_forward_pass.as_ref()?;

    let batch_token_ids: Vec<Vec<u32>> = samples.iter().map(|s| s.token_ids.clone()).collect();

    let lora_ref =
        if self.lora_layers.is_empty() { None } else { Some(self.lora_layers.as_slice()) };

    let hiddens = pass
        .forward_hidden_batch(&self.model, &batch_token_ids, lora_ref)
        .map_err(|e| {
            eprintln!("[wgpu] Batched forward failed, falling back to per-sample: {e}");
        })
        .ok()?;

    // Hidden states are paired with samples by position; a count mismatch would
    // either panic or silently skip samples, so treat it as a failed batch.
    if hiddens.len() != samples.len() {
        eprintln!(
            "[wgpu] Batched forward returned {} hidden states for {} samples, \
             falling back to per-sample",
            hiddens.len(),
            samples.len()
        );
        return None;
    }

    let mut total_loss = 0.0f32;
    let mut correct = 0usize;
    for (hidden, sample) in hiddens.iter().zip(samples) {
        let (loss, predicted) =
            self.classify_backward_from_hidden(hidden, sample.token_ids.len(), sample.label);
        total_loss += loss;
        if predicted == sample.label {
            correct += 1;
        }
    }
    Some((total_loss, correct))
}
1113
1114 pub fn accumulate_gradients_tokenized(
1118 &mut self,
1119 micro_batch: &[TokenizedSample],
1120 ) -> BatchResult {
1121 if micro_batch.is_empty() {
1122 return BatchResult { avg_loss: 0.0, correct: 0, total: 0, grad_norm: 0.0 };
1123 }
1124
1125 let mut total_loss = 0.0f32;
1126 let mut correct = 0usize;
1127
1128 for sample in micro_batch {
1129 let (loss, predicted) = self.forward_backward_single(&sample.token_ids, sample.label);
1130 total_loss += loss;
1131 if predicted == sample.label {
1132 correct += 1;
1133 }
1134 }
1135
1136 BatchResult {
1137 avg_loss: total_loss / micro_batch.len() as f32,
1138 correct,
1139 total: micro_batch.len(),
1140 grad_norm: 0.0,
1141 }
1142 }
1143
1144 pub fn forward_only_tokenized(&mut self, token_ids: &[u32], label: usize) -> (f32, usize) {
1148 self.forward_only(token_ids, label)
1149 }
1150}
1151
1152include!("gpu.rs");
1154
1155include!("training.rs");
1157
1158#[cfg(test)]
1159#[allow(clippy::unwrap_used)]
1160mod tests;