use super::device::GpuDevice;
use super::kernel::KernelCache;
use super::module::{GpuAdam, GpuTrainModule};
use super::nn::{relu, relu_backward};
use super::realize::map_elementwise;
use super::reduce::reduce_sum_all;
use super::tensor::GpuTensor;
/// Elementwise `tanh` activation layer with no trainable parameters.
///
/// The layer keeps a copy of its most recent training-mode output so the
/// backward pass can form the derivative `1 - tanh(x)^2` from the output alone.
pub struct GpuTanhLayer {
    /// Output of the latest `forward_train` call; `None` before the first call
    /// and after `zero_grad`.
    cached_output: Option<GpuTensor>,
}

impl GpuTanhLayer {
    /// Creates a layer with an empty activation cache.
    pub fn new() -> Self {
        Self { cached_output: None }
    }
}

impl Default for GpuTanhLayer {
    /// Equivalent to [`GpuTanhLayer::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Training-mode behaviour of the `tanh` activation.
impl GpuTrainModule for GpuTanhLayer {
    /// Applies `tanh` to every element of `input`, stores a copy of the result
    /// for the backward pass, and returns the result.
    fn forward_train(
        &mut self,
        device: &GpuDevice,
        cache: &mut KernelCache,
        input: &GpuTensor,
    ) -> GpuTensor {
        let activated = map_elementwise(device, cache, &[input], |operands| {
            use crate::Scalar;
            operands[0].tanh()
        });
        let snapshot = activated.clone_gpu_batched(device, cache);
        self.cached_output = Some(snapshot);
        activated
    }

    /// Returns `grad_output * (1 - y * y)`, where `y` is the output cached by
    /// the preceding `forward_train` call.
    ///
    /// # Panics
    ///
    /// Panics if no output is cached (no `forward_train` since construction or
    /// since the last `zero_grad`).
    fn backward(
        &mut self,
        device: &GpuDevice,
        cache: &mut KernelCache,
        grad_output: &GpuTensor,
    ) -> GpuTensor {
        let activated = match self.cached_output.as_ref() {
            Some(tensor) => tensor,
            None => panic!("must call forward_train before backward"),
        };
        map_elementwise(device, cache, &[grad_output, activated], |operands| {
            use crate::Scalar;
            use crate::expr::ExprId;
            let (upstream, y) = (operands[0], operands[1]);
            let one = ExprId::from_f64(1.0);
            let y_squared = y * y;
            upstream * (one - y_squared)
        })
    }

    /// The activation has no trainable parameters.
    fn parameters(&self) -> Vec<&GpuTensor> {
        Vec::new()
    }

    /// The activation has no trainable parameters.
    fn parameters_mut(&mut self) -> Vec<&mut GpuTensor> {
        Vec::new()
    }

    /// The activation accumulates no parameter gradients.
    fn gradients(&self) -> Vec<Option<&GpuTensor>> {
        Vec::new()
    }

    /// Drops the cached forward output.
    fn zero_grad(&mut self) {
        self.cached_output = None;
    }
}
/// ReLU activation layer with no trainable parameters.
///
/// The layer keeps a copy of its most recent training-mode input, which the
/// backward pass hands to `relu_backward` together with the upstream gradient.
pub struct GpuReLULayer {
    /// Input of the latest `forward_train` call; `None` before the first call
    /// and after `zero_grad`.
    cached_input: Option<GpuTensor>,
}

impl GpuReLULayer {
    /// Creates a layer with an empty activation cache.
    pub fn new() -> Self {
        Self { cached_input: None }
    }
}

impl Default for GpuReLULayer {
    /// Equivalent to [`GpuReLULayer::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Training-mode behaviour of the ReLU activation.
impl GpuTrainModule for GpuReLULayer {
    /// Stores a copy of `input` for the backward pass, then returns
    /// `relu(input)`.
    fn forward_train(
        &mut self,
        device: &GpuDevice,
        cache: &mut KernelCache,
        input: &GpuTensor,
    ) -> GpuTensor {
        let snapshot = input.clone_gpu_batched(device, cache);
        self.cached_input = Some(snapshot);
        relu(device, cache, input)
    }

    /// Delegates to `relu_backward` with the cached input and `grad_output`.
    ///
    /// # Panics
    ///
    /// Panics if no input is cached (no `forward_train` since construction or
    /// since the last `zero_grad`).
    fn backward(
        &mut self,
        device: &GpuDevice,
        cache: &mut KernelCache,
        grad_output: &GpuTensor,
    ) -> GpuTensor {
        let pre_activation = match self.cached_input.as_ref() {
            Some(tensor) => tensor,
            None => panic!("must call forward_train before backward"),
        };
        relu_backward(device, cache, pre_activation, grad_output)
    }

    /// The activation has no trainable parameters.
    fn parameters(&self) -> Vec<&GpuTensor> {
        Vec::new()
    }

    /// The activation has no trainable parameters.
    fn parameters_mut(&mut self) -> Vec<&mut GpuTensor> {
        Vec::new()
    }

    /// The activation accumulates no parameter gradients.
    fn gradients(&self) -> Vec<Option<&GpuTensor>> {
        Vec::new()
    }

    /// Drops the cached forward input.
    fn zero_grad(&mut self) {
        self.cached_input = None;
    }
}
/// Container that chains trainable modules, feeding each module's output into
/// the next one.
pub struct GpuSequential {
    /// Modules in forward-execution order.
    layers: Vec<Box<dyn GpuTrainModule>>,
}

impl GpuSequential {
    /// Builds a container from `layers`, which are run first-to-last in the
    /// forward pass.
    pub fn new(layers: Vec<Box<dyn GpuTrainModule>>) -> Self {
        GpuSequential { layers }
    }
}
/// Training-mode behaviour of the sequential container.
impl GpuTrainModule for GpuSequential {
    /// Runs `input` through every layer in order and returns the final output.
    ///
    /// The first layer reads the caller's tensor directly, so no device copy of
    /// `input` is made unless the container is empty, in which case a copy of
    /// `input` is returned.
    fn forward_train(
        &mut self,
        device: &GpuDevice,
        cache: &mut KernelCache,
        input: &GpuTensor,
    ) -> GpuTensor {
        match self.layers.split_first_mut() {
            // An owned tensor must be returned, so an empty container still copies.
            None => input.clone_gpu_batched(device, cache),
            Some((first, rest)) => {
                let mut x = first.forward_train(device, cache, input);
                for layer in rest {
                    x = layer.forward_train(device, cache, &x);
                }
                x
            }
        }
    }

    /// Propagates `grad_output` through the layers in reverse order and returns
    /// the gradient with respect to the container's input.
    ///
    /// The last layer reads the caller's tensor directly; an empty container
    /// returns a copy of `grad_output`.
    fn backward(
        &mut self,
        device: &GpuDevice,
        cache: &mut KernelCache,
        grad_output: &GpuTensor,
    ) -> GpuTensor {
        match self.layers.split_last_mut() {
            None => grad_output.clone_gpu_batched(device, cache),
            Some((last, earlier)) => {
                let mut grad = last.backward(device, cache, grad_output);
                for layer in earlier.iter_mut().rev() {
                    grad = layer.backward(device, cache, &grad);
                }
                grad
            }
        }
    }

    /// Parameters of all layers, concatenated in layer order.
    fn parameters(&self) -> Vec<&GpuTensor> {
        self.layers.iter().flat_map(|l| l.parameters()).collect()
    }

    /// Mutable parameters of all layers, concatenated in layer order.
    fn parameters_mut(&mut self) -> Vec<&mut GpuTensor> {
        self.layers.iter_mut().flat_map(|l| l.parameters_mut()).collect()
    }

    /// Gradient slots of all layers, concatenated in layer order.
    fn gradients(&self) -> Vec<Option<&GpuTensor>> {
        self.layers.iter().flat_map(|l| l.gradients()).collect()
    }

    /// Calls `zero_grad` on every layer.
    fn zero_grad(&mut self) {
        for layer in &mut self.layers {
            layer.zero_grad();
        }
    }
}
/// Mean-squared-error loss and its gradient with respect to `pred`.
///
/// Returns `(loss, grad)` where `loss` is `sum((pred - target)^2) / n` as
/// produced by `reduce_sum_all` followed by a scale, and `grad` is
/// `(pred - target) * 2 / n`, with `n = pred.numel()`.
///
/// # Panics
///
/// Panics if `pred` and `target` hold a different number of elements.
pub fn gpu_mse_loss(
    device: &GpuDevice,
    cache: &mut KernelCache,
    pred: &GpuTensor,
    target: &GpuTensor,
) -> (GpuTensor, GpuTensor) {
    let count = pred.numel();
    assert_eq!(count, target.numel());
    let denom = count as f32;

    let residual = map_elementwise(device, cache, &[pred, target], |operands| {
        operands[0] - operands[1]
    });
    let squared = map_elementwise(device, cache, &[&residual], |operands| {
        let r = operands[0];
        r * r
    });
    let squared_sum = reduce_sum_all(device, cache, &squared);

    let loss = squared_sum.scale(device, cache, 1.0 / denom);
    let grad = residual.scale(device, cache, 2.0 / denom);
    (loss, grad)
}
/// Softmax cross-entropy loss and its gradient with respect to the logits.
///
/// `pred` holds logits: a 2-D tensor is read as `[batch, num_classes]`, any
/// other rank is read as a single sample with `shape()[0]` classes. `target`
/// holds one class index per sample, stored as a float and truncated to an
/// integer.
///
/// The computation runs on the host: pending kernels are flushed, both tensors
/// are read back, and the results are uploaded as new tensors. Returns
/// `(loss, grad)` where `loss` is a one-element tensor with the batch-mean loss
/// and `grad` has the shape of `pred`.
///
/// # Panics
///
/// Panics if `target` has fewer entries than there are samples, if `pred` has
/// fewer than `batch * num_classes` elements, or if any class index is
/// negative, non-finite, or not below `num_classes`.
pub fn gpu_cross_entropy_loss(
    device: &GpuDevice,
    cache: &mut KernelCache,
    pred: &GpuTensor,
    target: &GpuTensor,
) -> (GpuTensor, GpuTensor) {
    // The readbacks below must observe the results of any queued kernels.
    cache.flush(device);
    let logits_data = pred.buffer.to_vec_sync(device);
    let target_data = target.buffer.to_vec_sync(device);
    let (batch, num_classes) = if pred.ndim() == 2 {
        (pred.shape()[0], pred.shape()[1])
    } else {
        (1, pred.shape()[0])
    };

    let (avg_loss, grad) =
        softmax_cross_entropy_host(&logits_data, &target_data, batch, num_classes);

    let loss = GpuTensor::from_slice(device, &[avg_loss], &[1]);
    let grad_tensor = GpuTensor::from_slice(device, &grad, pred.shape());
    (loss, grad_tensor)
}

/// Host-side softmax cross-entropy over `batch` rows of `num_classes` logits.
///
/// Returns the batch-mean loss and the row-major gradient
/// `(softmax(logits) - one_hot(label)) / batch`.
fn softmax_cross_entropy_host(
    logits: &[f32],
    labels: &[f32],
    batch: usize,
    num_classes: usize,
) -> (f32, Vec<f32>) {
    assert!(
        logits.len() >= batch * num_classes,
        "cross-entropy: expected at least {} logits, got {}",
        batch * num_classes,
        logits.len()
    );
    assert!(
        labels.len() >= batch,
        "cross-entropy: expected at least {batch} target entries, got {}",
        labels.len()
    );

    let mut total_loss = 0.0f32;
    let mut grad = vec![0.0f32; batch * num_classes];
    for (b, &label) in labels[..batch].iter().enumerate() {
        // `as usize` saturates negative and NaN values to 0, which would
        // silently train against class 0; reject such labels explicitly.
        assert!(
            label.is_finite() && label >= 0.0 && (label as usize) < num_classes,
            "cross-entropy: target class {label} for sample {b} is outside 0..{num_classes}"
        );
        let target_idx = label as usize;

        let offset = b * num_classes;
        let row = &logits[offset..offset + num_classes];
        let probs = &mut grad[offset..offset + num_classes];

        // Subtracting the row maximum keeps `exp` from overflowing.
        let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for (p, &x) in probs.iter_mut().zip(row) {
            *p = (x - max_val).exp();
            sum += *p;
        }
        for p in probs.iter_mut() {
            *p /= sum;
        }

        // The floor bounds the per-sample loss when the probability underflows.
        total_loss -= probs[target_idx].max(1e-12).ln();
        probs[target_idx] -= 1.0;
    }

    let inv_batch = 1.0 / batch as f32;
    for g in &mut grad {
        *g *= inv_batch;
    }
    (total_loss / batch as f32, grad)
}
/// Sequential mini-batch loader over host-resident samples.
///
/// Samples are stored row-major in flat buffers and handed out in storage
/// order; each batch is uploaded to the device when it is requested.
pub struct GpuDataLoader {
    /// Flat input buffer, `n_samples * input_dim` values.
    inputs: Vec<f32>,
    /// Flat target buffer, `n_samples * target_dim` values.
    targets: Vec<f32>,
    /// Number of values per input sample.
    input_dim: usize,
    /// Number of values per target sample.
    target_dim: usize,
    /// Total number of samples.
    n_samples: usize,
    /// Maximum number of samples per batch.
    batch_size: usize,
    /// Index of the first sample of the next batch.
    position: usize,
}

impl GpuDataLoader {
    /// Creates a loader over `inputs` and `targets`.
    ///
    /// # Panics
    ///
    /// Panics if `input_dim` or `batch_size` is zero, if `inputs.len()` is not
    /// a multiple of `input_dim`, or if `targets.len()` does not equal
    /// `n_samples * target_dim`.
    pub fn new(
        inputs: Vec<f32>,
        targets: Vec<f32>,
        input_dim: usize,
        target_dim: usize,
        batch_size: usize,
    ) -> Self {
        assert!(input_dim > 0, "GpuDataLoader: input_dim must be non-zero");
        // A zero batch size would make `next_batch` yield empty batches forever
        // and `n_batches` divide by zero.
        assert!(batch_size > 0, "GpuDataLoader: batch_size must be non-zero");
        let n_samples = inputs.len() / input_dim;
        assert_eq!(
            inputs.len(),
            n_samples * input_dim,
            "GpuDataLoader: inputs length is not a multiple of input_dim"
        );
        assert_eq!(
            targets.len(),
            n_samples * target_dim,
            "GpuDataLoader: targets length does not match the number of samples"
        );
        Self {
            inputs,
            targets,
            input_dim,
            target_dim,
            n_samples,
            batch_size,
            position: 0,
        }
    }

    /// Rewinds the loader to the first sample.
    pub fn reset(&mut self) {
        self.position = 0;
    }

    /// Uploads and returns the next `(input, target)` batch, or `None` once all
    /// samples have been consumed.
    ///
    /// The final batch may hold fewer than `batch_size` samples. A batch of
    /// exactly one sample is returned as 1-D tensors (`[dim]`); larger batches
    /// are 2-D (`[batch, dim]`).
    pub fn next_batch(&mut self, device: &GpuDevice) -> Option<(GpuTensor, GpuTensor)> {
        if self.position >= self.n_samples {
            return None;
        }
        let start = self.position;
        let end = (start + self.batch_size).min(self.n_samples);
        let batch = end - start;

        let shape_for = |dim: usize| {
            if batch == 1 {
                vec![dim]
            } else {
                vec![batch, dim]
            }
        };
        let in_shape = shape_for(self.input_dim);
        let tgt_shape = shape_for(self.target_dim);

        let in_values = &self.inputs[start * self.input_dim..end * self.input_dim];
        let tgt_values = &self.targets[start * self.target_dim..end * self.target_dim];
        let input = GpuTensor::from_slice(device, in_values, &in_shape);
        let target = GpuTensor::from_slice(device, tgt_values, &tgt_shape);

        self.position = end;
        Some((input, target))
    }

    /// Total number of samples.
    pub fn len(&self) -> usize {
        self.n_samples
    }

    /// Returns `true` when the loader holds no samples.
    pub fn is_empty(&self) -> bool {
        self.n_samples == 0
    }

    /// Number of batches per pass, counting a trailing partial batch.
    pub fn n_batches(&self) -> usize {
        (self.n_samples + self.batch_size - 1) / self.batch_size
    }
}
/// Training-loop driver that pairs a `GpuAdam` optimizer with a loss function.
pub struct GpuTrainer {
    /// Optimizer applied after every batch.
    optimizer: GpuAdam,
    /// Loss function returning `(loss, gradient with respect to the prediction)`.
    loss_fn: fn(&GpuDevice, &mut KernelCache, &GpuTensor, &GpuTensor) -> (GpuTensor, GpuTensor),
    /// Number of passes over the data loader performed by `fit`.
    num_epochs: usize,
}

impl GpuTrainer {
    /// Creates a trainer with learning rate `lr`, `num_epochs` epochs, and
    /// `gpu_mse_loss` as the loss function.
    pub fn new(lr: f32, num_epochs: usize) -> Self {
        let optimizer = GpuAdam::new(lr);
        GpuTrainer {
            optimizer,
            loss_fn: gpu_mse_loss,
            num_epochs,
        }
    }

    /// Replaces the loss function and returns the trainer (builder style).
    pub fn with_loss_fn(
        mut self,
        f: fn(&GpuDevice, &mut KernelCache, &GpuTensor, &GpuTensor) -> (GpuTensor, GpuTensor),
    ) -> Self {
        self.loss_fn = f;
        self
    }

    /// Trains `model` on `loader` for the configured number of epochs and
    /// returns the mean batch loss of each epoch.
    ///
    /// Loss tensors are collected during an epoch and only read back after the
    /// end-of-epoch flush. `begin_batch` is called on `cache` before the first
    /// epoch and again after every epoch, including the last one.
    ///
    /// # Panics
    ///
    /// Panics if the model reports a missing gradient after `backward`.
    pub fn fit(
        &mut self,
        device: &GpuDevice,
        cache: &mut KernelCache,
        model: &mut dyn GpuTrainModule,
        loader: &mut GpuDataLoader,
    ) -> Vec<f32> {
        let mut epoch_losses = Vec::with_capacity(self.num_epochs);
        cache.begin_batch();
        for _epoch in 0..self.num_epochs {
            loader.reset();
            let mut batch_losses: Vec<GpuTensor> = Vec::new();
            while let Some((features, labels)) = loader.next_batch(device) {
                model.zero_grad();
                let prediction = model.forward_train(device, cache, &features);
                let (loss, loss_grad) = (self.loss_fn)(device, cache, &prediction, &labels);
                batch_losses.push(loss);
                model.backward(device, cache, &loss_grad);

                // Gradients are copied so the shared borrow of `model` ends
                // before its parameters are borrowed mutably.
                let mut grad_copies: Vec<GpuTensor> = Vec::new();
                for slot in model.gradients() {
                    let g = slot.expect("gradient missing after backward");
                    grad_copies.push(g.clone_gpu_batched(device, cache));
                }
                let mut params = model.parameters_mut();
                self.optimizer.step(device, cache, &mut params, &grad_copies);
            }
            cache.flush(device);

            // `max(1)` avoids dividing by zero when the loader yields no batch.
            let divisor = batch_losses.len().max(1) as f32;
            let summed = batch_losses
                .iter()
                .map(|loss| loss.to_vec_sync(device)[0])
                .sum::<f32>();
            epoch_losses.push(summed / divisor);
            cache.begin_batch();
        }
        epoch_losses
    }
}
/// Dynamic loss scaler for training with gradients that may overflow.
///
/// The loss is multiplied by `scale` before the backward pass; gradients are
/// divided by it again before the optimizer step. Steps whose gradients contain
/// inf/NaN are skipped and shrink the scale; a run of `growth_interval` clean
/// steps grows it.
pub struct GradScaler {
    /// Current loss-scale factor.
    scale: f32,
    /// Multiplier applied to `scale` after `growth_interval` clean steps.
    growth_factor: f32,
    /// Multiplier applied to `scale` after a step with inf/NaN gradients.
    backoff_factor: f32,
    /// Number of consecutive clean steps required before growing the scale.
    growth_interval: u32,
    /// Clean steps counted since the last scale change.
    growth_step: u32,
    /// Whether the latest `unscale_and_step` saw inf/NaN gradients.
    found_inf: bool,
}

impl GradScaler {
    /// Creates a scaler with scale `65536`, growth factor `2`, backoff factor
    /// `0.5` and a growth interval of `2000` steps.
    pub fn new() -> Self {
        Self {
            scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            growth_step: 0,
            found_inf: false,
        }
    }

    /// Creates a scaler that starts at `initial_scale`; all other settings
    /// match [`GradScaler::new`].
    pub fn with_scale(initial_scale: f32) -> Self {
        Self {
            scale: initial_scale,
            ..Self::new()
        }
    }

    /// Returns the current scale factor.
    pub fn get_scale(&self) -> f32 {
        self.scale
    }

    /// Returns `loss` multiplied by the current scale factor.
    pub fn scale_loss(
        &self,
        device: &GpuDevice,
        cache: &mut KernelCache,
        loss: &GpuTensor,
    ) -> GpuTensor {
        loss.scale(device, cache, self.scale)
    }

    /// Divides `grads` by the current scale and, if every resulting value is
    /// finite, applies `optimizer.step` with the unscaled gradients.
    ///
    /// Returns `true` when the optimizer step was taken and `false` when it was
    /// skipped because an inf or NaN was found. The outcome is remembered for
    /// the next call to [`GradScaler::update`].
    pub fn unscale_and_step(
        &mut self,
        device: &GpuDevice,
        cache: &mut KernelCache,
        optimizer: &mut GpuAdam,
        params: &mut [&mut GpuTensor],
        grads: &[GpuTensor],
    ) -> bool {
        let inv_scale = 1.0 / self.scale;
        // Queue every unscale operation first so a single flush covers all of
        // them instead of flushing once per gradient tensor.
        let unscaled_grads: Vec<GpuTensor> = grads
            .iter()
            .map(|grad| grad.scale(device, cache, inv_scale))
            .collect();
        cache.flush(device);

        // `any` stops reading tensors back at the first non-finite value.
        let has_inf = unscaled_grads
            .iter()
            .any(|grad| grad.to_vec_sync(device).iter().any(|v| !v.is_finite()));
        self.found_inf = has_inf;
        if has_inf {
            return false;
        }
        optimizer.step(device, cache, params, &unscaled_grads);
        true
    }

    /// Adjusts the scale according to the outcome of the latest
    /// `unscale_and_step` call.
    ///
    /// After an overflow the scale is multiplied by the backoff factor; after
    /// `growth_interval` consecutive clean steps it is multiplied by the growth
    /// factor. A change is skipped when it would push the scale to infinity,
    /// zero, or a subnormal value, because the scaler could not recover from
    /// such a state (every later step would be seen as overflowing).
    pub fn update(&mut self) {
        if self.found_inf {
            let shrunk = self.scale * self.backoff_factor;
            if shrunk.is_normal() {
                self.scale = shrunk;
            }
            self.growth_step = 0;
        } else {
            self.growth_step += 1;
            if self.growth_step >= self.growth_interval {
                let grown = self.scale * self.growth_factor;
                if grown.is_finite() {
                    self.scale = grown;
                }
                self.growth_step = 0;
            }
        }
    }
}
impl Default for GradScaler {
fn default() -> Self {
Self::new()
}
}
/// GPU-backed tests; every test except none requires a device, so all of them
/// panic in `get_device` when no GPU is available.
#[cfg(test)]
mod tests {
    use super::*;
    // `module` is a sibling of this file's module, so from inside `tests` it is
    // two levels up, not one.
    use super::super::module::{GpuLinear, GpuModule};

    /// Acquires a GPU device, panicking when none is available.
    fn get_device() -> GpuDevice {
        GpuDevice::new_sync().expect("GPU device required for tests")
    }

    /// Checks the linear layer's forward values and compares its analytical
    /// input gradient against central finite differences.
    #[test]
    fn linear_forward_backward_gradient_check() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let mut linear = GpuLinear::new(
            &device,
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[0.1, 0.2, 0.3],
            2,
            3,
        );
        let input = GpuTensor::from_slice(&device, &[1.0, 0.5], &[2]);
        let output = linear.forward_train(&device, &mut cache, &input);
        let out_data = output.to_vec_sync(&device);
        assert!((out_data[0] - 2.1).abs() < 0.01, "out[0]={}", out_data[0]);
        assert!((out_data[1] - 5.2).abs() < 0.01, "out[1]={}", out_data[1]);
        assert!((out_data[2] - 8.3).abs() < 0.01, "out[2]={}", out_data[2]);
        let grad_out = GpuTensor::from_slice(&device, &[1.0, 1.0, 1.0], &[3]);
        let grad_input = linear.backward(&device, &mut cache, &grad_out);
        let eps = 1e-3;
        let in_data = vec![1.0f32, 0.5];
        let gi_data = grad_input.to_vec_sync(&device);
        for dim in 0..2 {
            let mut plus = in_data.clone();
            plus[dim] += eps;
            let mut minus = in_data.clone();
            minus[dim] -= eps;
            let out_p = linear.forward(
                &device,
                &mut cache,
                &GpuTensor::from_slice(&device, &plus, &[2]),
            );
            let out_m = linear.forward(
                &device,
                &mut cache,
                &GpuTensor::from_slice(&device, &minus, &[2]),
            );
            let p_data = out_p.to_vec_sync(&device);
            let m_data = out_m.to_vec_sync(&device);
            // The upstream gradient is all ones, so the input gradient is the
            // sum of the per-output finite differences.
            let numerical_grad: f32 =
                p_data.iter().zip(m_data.iter()).map(|(p, m)| (p - m) / (2.0 * eps)).sum();
            assert!(
                (gi_data[dim] - numerical_grad).abs() < 0.05,
                "input grad[{dim}]: analytical={}, numerical={}",
                gi_data[dim],
                numerical_grad
            );
        }
    }

    /// Checks output and input-gradient shapes of a small sequential model.
    #[test]
    fn sequential_forward_backward_shapes() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let mut model = GpuSequential::new(vec![
            Box::new(GpuLinear::kaiming(&device, 2, 4, 42)),
            Box::new(GpuReLULayer::new()),
            Box::new(GpuLinear::kaiming(&device, 4, 1, 43)),
        ]);
        let input = GpuTensor::from_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let output = model.forward_train(&device, &mut cache, &input);
        assert_eq!(output.shape(), &[2, 1]);
        let grad = GpuTensor::from_slice(&device, &[1.0, 1.0], &[2, 1]);
        let grad_input = model.backward(&device, &mut cache, &grad);
        assert_eq!(grad_input.shape(), &[2, 2]);
    }

    /// Checks batch counts, batch shapes (including the 1-D single-sample
    /// tail batch), exhaustion, and reset.
    #[test]
    fn data_loader_batching() {
        let device = get_device();
        let inputs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let targets = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut loader = GpuDataLoader::new(inputs, targets, 2, 1, 2);
        assert_eq!(loader.len(), 5);
        assert_eq!(loader.n_batches(), 3);
        let (inp, tgt) = loader.next_batch(&device).unwrap();
        assert_eq!(inp.shape(), &[2, 2]);
        assert_eq!(tgt.shape(), &[2, 1]);
        let (inp, tgt) = loader.next_batch(&device).unwrap();
        assert_eq!(inp.shape(), &[2, 2]);
        assert_eq!(tgt.shape(), &[2, 1]);
        let (inp, tgt) = loader.next_batch(&device).unwrap();
        assert_eq!(inp.shape(), &[2]);
        assert_eq!(tgt.shape(), &[1]);
        assert!(loader.next_batch(&device).is_none());
        loader.reset();
        assert!(loader.next_batch(&device).is_some());
    }

    /// Trains a small MLP on XOR and requires a tenfold loss reduction.
    #[test]
    fn xor_training_convergence() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let inputs = vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0];
        let targets = vec![0.0, 1.0, 1.0, 0.0];
        let mut loader = GpuDataLoader::new(inputs, targets, 2, 1, 4);
        let mut model = GpuSequential::new(vec![
            Box::new(GpuLinear::kaiming(&device, 2, 8, 123)),
            Box::new(GpuReLULayer::new()),
            Box::new(GpuLinear::kaiming(&device, 8, 1, 456)),
        ]);
        let mut trainer = GpuTrainer::new(0.01, 500);
        let losses = trainer.fit(&device, &mut cache, &mut model, &mut loader);
        let first = losses[0];
        let last = *losses.last().unwrap();
        assert!(
            last < first * 0.1,
            "XOR training did not converge: first_loss={first}, last_loss={last}"
        );
    }
}