use cudarc::driver::CudaDevice;
use std::sync::Arc;
use crate::backend::{BackendError, BackendResult};
use crate::model::{LlamaModel, ModelConfig, TransformerLayer};
use super::gpu_model::*;
use super::gpu_ops::GpuOps;
pub struct GpuInference {
ops: GpuOps,
model: GpuModel,
pos: usize,
}
impl GpuInference {
pub fn from_model(model: &LlamaModel, max_seq_len: usize) -> BackendResult<Self> {
let config = model.config();
let device = Arc::new(CudaDevice::new(0)
.map_err(|e| BackendError::InitializationFailed(format!("CUDA init failed: {}", e)))?);
let ops = GpuOps::new(Arc::clone(&device))?;
eprintln!("Uploading model to GPU...");
let token_embedding = GpuWeight::from_tensor(&device, model.token_embedding())?;
eprintln!(" Token embedding: {:?}", model.token_embedding().shape());
let mut gpu_layers = Vec::with_capacity(config.num_layers);
for (i, layer) in model.layers().iter().enumerate() {
if i % 4 == 0 {
eprintln!(" Layer {}/{}", i + 1, config.num_layers);
}
gpu_layers.push(upload_transformer_layer(&device, layer, config)?);
}
let norm = GpuRMSNorm::from_rms_norm(&device, &model.norm().weight, model.norm().eps)?;
let output = GpuLinear::from_linear(&device, &model.output().weight, model.output().bias.as_ref())?;
let mut kv_caches = Vec::with_capacity(config.num_layers);
for _ in 0..config.num_layers {
kv_caches.push(GpuKVCache::new(
&device,
config.num_kv_heads,
max_seq_len,
config.head_dim,
)?);
}
let intermediate_size = config.intermediate_size;
let scratch = GpuScratchBuffers::new(
&device,
config.hidden_size,
intermediate_size,
config.num_heads,
config.num_kv_heads,
config.head_dim,
max_seq_len,
config.vocab_size,
)?;
let gpu_model = GpuModel {
device: Arc::clone(&device),
token_embedding,
layers: gpu_layers,
norm,
output,
kv_caches,
hidden_size: config.hidden_size,
num_heads: config.num_heads,
num_kv_heads: config.num_kv_heads,
head_dim: config.head_dim,
vocab_size: config.vocab_size,
num_layers: config.num_layers,
freq_base: config.rope_config.freq_base,
freq_scale: config.rope_config.freq_scale,
scratch,
};
eprintln!("Model uploaded to GPU");
Ok(Self {
ops,
model: gpu_model,
pos: 0,
})
}
pub fn forward(&mut self, token_id: u32) -> BackendResult<Vec<f32>> {
let hidden_size = self.model.hidden_size;
let num_heads = self.model.num_heads;
let num_kv_heads = self.model.num_kv_heads;
let head_dim = self.model.head_dim;
let freq_base = self.model.freq_base;
let freq_scale = self.model.freq_scale;
self.ops.embed_token(
token_id,
&self.model.token_embedding,
&mut self.model.scratch.hidden,
hidden_size,
)?;
self.ops.copy_gpu(&self.model.scratch.hidden, &mut self.model.scratch.residual)?;
for layer_idx in 0..self.model.num_layers {
let layer = &self.model.layers[layer_idx];
let kv_cache = &mut self.model.kv_caches[layer_idx];
self.ops.rms_norm_gpu(
&self.model.scratch.hidden,
&layer.attention_norm.weight,
&mut self.model.scratch.attn_out,
hidden_size,
layer.attention_norm.eps,
)?;
self.ops.linear_gpu(&self.model.scratch.attn_out, &layer.attention.wq, &mut self.model.scratch.q)?;
self.ops.linear_gpu(&self.model.scratch.attn_out, &layer.attention.wk, &mut self.model.scratch.k)?;
self.ops.linear_gpu(&self.model.scratch.attn_out, &layer.attention.wv, &mut self.model.scratch.v)?;
self.ops.rope_gpu(
&mut self.model.scratch.q,
&mut self.model.scratch.k,
num_heads,
head_dim,
self.pos,
freq_base,
freq_scale,
layer.attention.use_neox_rope,
)?;
let kv_offset = self.pos * num_kv_heads * head_dim;
self.ops.linear_gpu(&self.model.scratch.q, &layer.attention.wo, &mut self.model.scratch.attn_out)?;
self.ops.add_gpu(
&self.model.scratch.residual,
&self.model.scratch.attn_out,
&mut self.model.scratch.hidden,
hidden_size,
)?;
self.ops.copy_gpu(&self.model.scratch.hidden, &mut self.model.scratch.residual)?;
self.ops.rms_norm_gpu(
&self.model.scratch.hidden,
&layer.ffn_norm.weight,
&mut self.model.scratch.attn_out,
hidden_size,
layer.ffn_norm.eps,
)?;
self.ops.linear_gpu(&self.model.scratch.attn_out, &layer.ffn.w1, &mut self.model.scratch.ffn_gate)?;
self.ops.linear_gpu(&self.model.scratch.attn_out, &layer.ffn.w3, &mut self.model.scratch.ffn_up)?;
let intermediate_size = layer.ffn.w1.out_features;
{
let mut temp = self.model.device.alloc_zeros::<f32>(intermediate_size)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?;
self.ops.silu_gpu(&self.model.scratch.ffn_gate, &mut temp, intermediate_size)?;
self.ops.copy_gpu(&temp, &mut self.model.scratch.ffn_gate)?;
}
{
let mut temp = self.model.device.alloc_zeros::<f32>(intermediate_size)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?;
self.ops.mul_gpu(
&self.model.scratch.ffn_gate,
&self.model.scratch.ffn_up,
&mut temp,
intermediate_size,
)?;
self.ops.copy_gpu(&temp, &mut self.model.scratch.ffn_gate)?;
}
self.ops.linear_gpu(&self.model.scratch.ffn_gate, &layer.ffn.w2, &mut self.model.scratch.ffn_out)?;
self.ops.add_gpu(
&self.model.scratch.residual,
&self.model.scratch.ffn_out,
&mut self.model.scratch.hidden,
hidden_size,
)?;
self.ops.copy_gpu(&self.model.scratch.hidden, &mut self.model.scratch.residual)?;
}
self.ops.rms_norm_gpu(
&self.model.scratch.hidden,
&self.model.norm.weight,
&mut self.model.scratch.attn_out,
hidden_size,
self.model.norm.eps,
)?;
self.ops.linear_gpu(&self.model.scratch.attn_out, &self.model.output, &mut self.model.scratch.logits)?;
let logits = self.model.device.dtoh_sync_copy(&self.model.scratch.logits)
.map_err(|e| BackendError::OperationFailed(format!("{}", e)))?;
self.pos += 1;
Ok(logits)
}
pub fn reset(&mut self) {
self.pos = 0;
for cache in &mut self.model.kv_caches {
cache.reset();
}
}
pub fn position(&self) -> usize {
self.pos
}
}
fn upload_transformer_layer(
device: &Arc<CudaDevice>,
layer: &TransformerLayer,
config: &ModelConfig,
) -> BackendResult<GpuTransformerLayer> {
let attention_norm = GpuRMSNorm::from_rms_norm(device, &layer.attn_norm.weight, layer.attn_norm.eps)?;
let attention = GpuAttention {
wq: GpuLinear::from_linear(device, &layer.attention.wq.weight, layer.attention.wq.bias.as_ref())?,
wk: GpuLinear::from_linear(device, &layer.attention.wk.weight, layer.attention.wk.bias.as_ref())?,
wv: GpuLinear::from_linear(device, &layer.attention.wv.weight, layer.attention.wv.bias.as_ref())?,
wo: GpuLinear::from_linear(device, &layer.attention.wo.weight, layer.attention.wo.bias.as_ref())?,
num_heads: config.num_heads,
num_kv_heads: config.num_kv_heads,
head_dim: config.head_dim,
use_neox_rope: layer.attention.use_neox_rope,
};
let ffn_norm = GpuRMSNorm::from_rms_norm(device, &layer.ffn_norm.weight, layer.ffn_norm.eps)?;
let ffn = GpuFFN {
w1: GpuLinear::from_linear(device, &layer.ffn.w_gate.weight, layer.ffn.w_gate.bias.as_ref())?,
w2: GpuLinear::from_linear(device, &layer.ffn.w_down.weight, layer.ffn.w_down.bias.as_ref())?,
w3: GpuLinear::from_linear(device, &layer.ffn.w_up.weight, layer.ffn.w_up.bias.as_ref())?,
};
Ok(GpuTransformerLayer {
attention_norm,
attention,
ffn_norm,
ffn,
})
}