mod accessors;
mod backward;
mod constructors;
mod cuda_forward;
mod cuda_init;
mod generate;
mod training;
mod wgpu;
#[cfg(test)]
mod tests;
#[cfg(test)]
mod tests_cov3;
#[cfg(test)]
mod tests_cov3b;
use crate::lora::LoRALayer;
use crate::optim::{clip_grad_norm_refs, AdamW, Optimizer};
use crate::tokenizer::HfTokenizer;
use crate::train::transformer_trainer::step_profiler::StepProfiler;
use crate::transformer::{Transformer, TransformerConfig};
use crate::Tensor;
use std::path::{Path, PathBuf};
#[cfg(feature = "cuda")]
use crate::autograd::cuda_training::CudaTrainer;
#[cfg(feature = "cuda")]
use crate::gpu::guard::VramGuard;
#[cfg(feature = "cuda")]
use crate::transformer::{
CudaBlock, CudaBlockScratch, CudaLoraGradWorkspace, GpuLoraOptimizerState,
};
#[cfg(feature = "cuda")]
use trueno_gpu::driver::GpuBuffer;
/// Hyper-parameters for LoRA instruction fine-tuning.
#[derive(Debug, Clone)]
pub struct InstructConfig {
    /// Rank of the LoRA adapter matrices.
    pub lora_rank: usize,
    /// LoRA alpha scaling factor.
    pub lora_alpha: f32,
    /// Learning rate passed to the optimizer.
    pub learning_rate: f32,
    /// Number of passes over the training data.
    pub epochs: usize,
    /// Maximum sequence length used during training.
    pub max_seq_len: usize,
    /// Global gradient-norm clip threshold; `None` disables clipping.
    pub gradient_clip_norm: Option<f32>,
    /// When true, enables NF4 quantization of base weights.
    pub quantize_nf4: bool,
}

impl Default for InstructConfig {
    /// Defaults: rank-16 LoRA, alpha 32, lr 2e-4, 3 epochs,
    /// 512-token context, clip at 1.0, NF4 off.
    fn default() -> Self {
        InstructConfig {
            lora_rank: 16,
            lora_alpha: 32.0,
            learning_rate: 2e-4,
            epochs: 3,
            max_seq_len: 512,
            gradient_clip_norm: Some(1.0),
            quantize_nf4: false,
        }
    }
}
/// Outcome of a single instruction-tuning step.
#[derive(Debug, Clone)]
pub struct InstructStepResult {
    /// Loss value for this step.
    pub loss: f32,
    /// Number of response tokens the loss was computed over.
    pub num_response_tokens: usize,
    /// Perplexity for this step (presumably `exp(loss)` — computed by the caller; confirm in training.rs).
    pub perplexity: f32,
}
/// Aggregated outcome of a training batch (multiple steps).
#[derive(Debug, Clone)]
pub struct InstructBatchResult {
    /// Average loss across the batch.
    pub avg_loss: f32,
    /// Total response tokens contributing to the batch loss.
    pub total_response_tokens: usize,
    /// Batch-level perplexity (presumably derived from `avg_loss` — confirm in training.rs).
    pub perplexity: f32,
    /// Gradient norm observed for this batch (before or after clipping — confirm at the call site).
    pub grad_norm: f32,
}
/// GPU-resident buffers and profiling state for the CUDA training path.
///
/// All buffers live in device memory (`GpuBuffer<f32>`); they are
/// preallocated once and reused across steps — see cuda_init.rs for setup.
#[cfg(feature = "cuda")]
pub(super) struct InstructGpuTrainingState {
    /// Per-layer input activations (one buffer per transformer block).
    layer_inputs: Vec<GpuBuffer<f32>>,
    final_norm_weight: GpuBuffer<f32>,
    blocks_output: GpuBuffer<f32>,
    grad_buf_a: GpuBuffer<f32>,
    grad_buf_b: GpuBuffer<f32>,
    grad_final_norm_weight: GpuBuffer<f32>,
    // NOTE(review): these three fields were crammed onto a single line;
    // split one-per-line to match rustfmt and the rest of the struct.
    embed_transposed: GpuBuffer<f32>,
    embed_original: GpuBuffer<f32>,
    logits_buf: GpuBuffer<f32>,
    grad_hidden_buf: GpuBuffer<f32>,
    output_scratch: GpuBuffer<f32>,
    grad_upload_buf: GpuBuffer<f32>,
    fwd_scratch_a: GpuBuffer<f32>,
    fwd_scratch_b: GpuBuffer<f32>,
    lm_head_hidden_buf: GpuBuffer<f32>,
    /// Captured CUDA graph for the forward pass, if one has been recorded.
    forward_graph_exec: Option<trueno_gpu::driver::CudaGraphExec>,
    /// Sequence length the forward graph was captured at — presumably used
    /// to invalidate the graph when the length changes; confirm in cuda_forward.rs.
    graph_cached_seq_len: usize,
    backward_graph_state: Option<super::backward_graph::BackwardGraphState>,
    cublas_workspace: Option<GpuBuffer<f32>>,
    /// Per-layer forward-pass timings, microseconds.
    profiler_layer_fwd_us: Vec<u64>,
    /// Per-layer backward-pass timings, microseconds.
    profiler_layer_bwd_us: Vec<u64>,
    profiler_layer_start: Option<std::time::Instant>,
    /// Accumulated per-op timings, microseconds (16 op slots).
    profiler_op_us: [u64; 16],
    profiler_op_start: Option<std::time::Instant>,
}
/// End-to-end LoRA instruction-tuning pipeline: base model, adapters,
/// optimizer, tokenizer, and feature-gated GPU training state.
pub struct InstructPipeline {
    /// Base transformer model being adapted.
    pub model: Transformer,
    /// One LoRA adapter per adapted layer.
    pub lora_layers: Vec<LoRALayer>,
    /// Training hyper-parameters.
    pub config: InstructConfig,
    /// AdamW optimizer state (CPU side).
    optimizer: AdamW,
    /// Tokenizer, when one has been loaded.
    tokenizer: Option<HfTokenizer>,
    /// Directory the model was loaded from, if any.
    model_dir: Option<PathBuf>,
    /// Per-step timing profiler.
    pub profiler: StepProfiler,
    // CUDA path: all of the following are `None` until CUDA training is
    // initialized — see cuda_init.rs.
    #[cfg(feature = "cuda")]
    cuda_trainer: Option<CudaTrainer>,
    #[cfg(feature = "cuda")]
    cuda_blocks: Option<Vec<CudaBlock>>,
    /// Scratch buffers shared across CUDA blocks.
    #[cfg(feature = "cuda")]
    shared_scratch: Option<CudaBlockScratch>,
    #[cfg(feature = "cuda")]
    #[allow(dead_code)]
    cuda_nan_count: usize,
    /// Device-side buffers for GPU training (see `InstructGpuTrainingState`).
    #[cfg(feature = "cuda")]
    gpu_training: Option<InstructGpuTrainingState>,
    #[cfg(feature = "cuda")]
    cuda_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
    /// Fused gradient-clipping state for LoRA params on GPU.
    #[cfg(feature = "cuda")]
    lora_fused_clip: Option<crate::autograd::cuda_optim::FusedClipState>,
    #[cfg(feature = "cuda")]
    cuda_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
    /// Step counter for the NF4 LoRA path — presumably the optimizer's
    /// bias-correction step; confirm where it is incremented.
    #[cfg(feature = "cuda")]
    nf4_lora_step: u32,
    #[cfg(feature = "cuda")]
    #[allow(dead_code)]
    vram_guard: Option<VramGuard>,
    // WebGPU path (mutually independent of the CUDA fields above).
    #[cfg(feature = "gpu")]
    wgpu_training: Option<WgpuTrainingState>,
}
/// Device buffers and pipeline objects for the WebGPU training path.
#[cfg(feature = "gpu")]
struct WgpuTrainingState {
    /// WGSL forward-pass executor.
    fwd: trueno::backends::gpu::WgslForwardPass,
    /// WGSL cross-entropy kernel wrapper.
    cross_entropy: crate::autograd::wgpu_cross_entropy::WgslCrossEntropy,
    trainer: crate::autograd::wgpu_training::WgpuTrainer,
    // GPU buffers; sizes presumably derived from the dims below — confirm in wgpu.rs.
    logits_buf: trueno::backends::gpu::wgpu::Buffer,
    labels_buf: trueno::backends::gpu::wgpu::Buffer,
    losses_buf: trueno::backends::gpu::wgpu::Buffer,
    logsumexp_buf: trueno::backends::gpu::wgpu::Buffer,
    /// LM-head weight matrix on device.
    lm_head_gpu: trueno::backends::gpu::wgpu::Buffer,
    /// Transposed LM-head weights on device.
    lm_head_t_gpu: trueno::backends::gpu::wgpu::Buffer,
    num_layers: usize,
    hidden_dim: usize,
    vocab_size: usize,
}
/// Decoding parameters for text generation.
#[derive(Debug, Clone)]
pub struct GenerateConfig {
    /// Maximum number of tokens to generate.
    pub max_new_tokens: usize,
    /// Softmax temperature; `<= 0.0` selects greedy decoding (see `sample_token`).
    pub temperature: f32,
    /// Top-k cutoff for sampling; `0` means no cutoff, `1` is greedy.
    pub top_k: usize,
    /// Token ids that terminate generation when produced.
    pub stop_tokens: Vec<u32>,
}
/// Samples a token id from `logits` using temperature scaling and top-k
/// filtering.
///
/// - Returns `0` when `logits` is empty.
/// - `temperature <= 0.0` or `top_k == 1` selects the argmax (greedy).
/// - Otherwise logits are divided by `temperature`, the `top_k` largest are
///   kept (all of them when `top_k == 0` or `top_k >= logits.len()`), and a
///   token is drawn from the softmax distribution over that subset.
fn sample_token(logits: &[f32], temperature: f32, top_k: usize) -> u32 {
    // Bug fix: an empty logit slice previously panicked on `top[0]` in the
    // sampling path below (the greedy path already returned 0 via `map_or`).
    if logits.is_empty() {
        return 0;
    }
    if temperature <= 0.0 || top_k == 1 {
        return logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(idx, _)| idx as u32);
    }
    let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
    let mut indices_and_logits: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
    // Sort descending by scaled logit; NaN comparisons are treated as Equal.
    indices_and_logits
        .sort_unstable_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
    let k = if top_k > 0 && top_k < indices_and_logits.len() {
        top_k
    } else {
        indices_and_logits.len()
    };
    let top = &indices_and_logits[..k];
    // Subtract the max logit before exponentiating for numerical stability.
    let max_logit = top[0].1;
    let exps: Vec<f32> = top.iter().map(|(_, l)| (l - max_logit).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
    // Inverse-CDF draw over the top-k softmax probabilities.
    let r: f32 = simple_random();
    let mut cumulative = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        cumulative += p;
        if r < cumulative {
            return top[i].0 as u32;
        }
    }
    // Floating-point round-off can leave `cumulative` slightly below 1.0
    // even though the probabilities sum to 1; fall back to the top token.
    top[0].0 as u32
}

/// Returns a pseudo-random `f32` in `[0.0, 1.0)` from a per-thread
/// xorshift64 generator seeded from the system clock.
///
/// Not cryptographically secure; intended only for sampling.
fn simple_random() -> f32 {
    use std::cell::Cell;
    thread_local! {
        // Seed from wall-clock nanoseconds; 42 is the fallback if the clock
        // reads before the Unix epoch.
        static STATE: Cell<u64> = Cell::new(
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos() as u64)
                .unwrap_or(42)
        );
    }
    STATE.with(|s| {
        // One xorshift64 step (Marsaglia); full period for any nonzero seed.
        let mut x = s.get();
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        s.set(x);
        // Keep the top 24 bits and scale into [0, 1).
        (x >> 40) as f32 / (1u64 << 24) as f32
    })
}