#[cfg(feature = "cuda")]
use axonml_core::Device;
#[cfg(feature = "cuda")]
use axonml_core::backends::cuda::get_cuda_backend;
#[cfg(feature = "cuda")]
use axonml_core::backends::cuda_pool::pool_alloc;
#[cfg(feature = "cuda")]
use axonml_core::error::Result;
#[cfg(feature = "cuda")]
use axonml_core::storage::Storage;
#[cfg(feature = "cuda")]
use crate::shape::{Shape, contiguous_strides};
#[cfg(feature = "cuda")]
use crate::tensor::Tensor;
#[cfg(feature = "cuda")]
impl Tensor<f32> {
/// Elementwise addition on the GPU: both operands are made contiguous,
/// the `add_f32` kernel writes into a pool-allocated buffer, and the
/// result takes `self`'s shape with fresh contiguous strides.
pub(crate) fn add_cuda(&self, other: &Self) -> Result<Self> {
    let lhs = self.contiguous_gpu();
    let rhs = other.contiguous_gpu();
    let n = lhs.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let lhs_view = lhs.storage.as_cuda_slice();
        let rhs_view = rhs.storage.as_cuda_slice();
        backend
            .add_f32(&mut dst, lhs_view.slice(), rhs_view.slice(), n)
            .expect("CUDA add_f32 failed");
    }
    Ok(Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    })
}
/// Elementwise subtraction (`self - other`) on the GPU; same staging as
/// the other binary elementwise kernels: contiguous inputs, pooled output.
pub(crate) fn sub_cuda(&self, other: &Self) -> Result<Self> {
    let lhs = self.contiguous_gpu();
    let rhs = other.contiguous_gpu();
    let n = lhs.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let lhs_view = lhs.storage.as_cuda_slice();
        let rhs_view = rhs.storage.as_cuda_slice();
        backend
            .sub_f32(&mut dst, lhs_view.slice(), rhs_view.slice(), n)
            .expect("CUDA sub_f32 failed");
    }
    Ok(Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    })
}
/// Elementwise multiplication on the GPU (same-shape operands; broadcasting
/// is handled separately by `broadcast_mul_cuda`).
pub(crate) fn mul_cuda(&self, other: &Self) -> Result<Self> {
    let lhs = self.contiguous_gpu();
    let rhs = other.contiguous_gpu();
    let n = lhs.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let lhs_view = lhs.storage.as_cuda_slice();
        let rhs_view = rhs.storage.as_cuda_slice();
        backend
            .mul_f32(&mut dst, lhs_view.slice(), rhs_view.slice(), n)
            .expect("CUDA mul_f32 failed");
    }
    Ok(Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    })
}
/// Elementwise division (`self / other`) on the GPU. Division by zero is
/// left to the kernel / IEEE-754 semantics — no host-side check here.
pub(crate) fn div_cuda(&self, other: &Self) -> Result<Self> {
    let lhs = self.contiguous_gpu();
    let rhs = other.contiguous_gpu();
    let n = lhs.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let lhs_view = lhs.storage.as_cuda_slice();
        let rhs_view = rhs.storage.as_cuda_slice();
        backend
            .div_f32(&mut dst, lhs_view.slice(), rhs_view.slice(), n)
            .expect("CUDA div_f32 failed");
    }
    Ok(Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    })
}
/// Elementwise add with shape broadcasting on the GPU.
///
/// The kernels only handle the "smaller operand is repeated over the larger"
/// layout, so if the larger side's element count does not already equal the
/// output count it is explicitly materialized via `broadcast_to` first.
/// `*_rev_f32` is the variant where `other` is the larger operand.
///
/// Fix: the previous version recomputed `broadcast_shape` (a second fallible
/// call, after the kernel had already run) just to derive output strides;
/// the already-computed `result_shape` is reused instead.
pub(crate) fn broadcast_add_cuda(&self, other: &Self) -> Result<Self> {
    let a = self.contiguous_gpu();
    let b = other.contiguous_gpu();
    let a_n = a.numel();
    let b_n = b.numel();
    let cuda = get_cuda_backend().expect("CUDA backend not available");
    let result_shape = crate::shape::broadcast_shape(&self.shape, &other.shape)?;
    let out_n = crate::shape::numel(&result_shape);
    let mut out = pool_alloc(out_n).expect("GPU pool alloc failed");
    let a_guard = a.storage.as_cuda_slice();
    let b_guard = b.storage.as_cuda_slice();
    if a_n >= b_n {
        if a_n == out_n {
            cuda.broadcast_add_f32(&mut out, a_guard.slice(), b_guard.slice(), out_n, b_n)
                .expect("CUDA broadcast_add_f32 failed");
        } else {
            // `a` itself still needs expanding to the output shape.
            let a_bcast = a.broadcast_to(result_shape.as_slice()).contiguous_gpu();
            let a2_guard = a_bcast.storage.as_cuda_slice();
            cuda.broadcast_add_f32(&mut out, a2_guard.slice(), b_guard.slice(), out_n, b_n)
                .expect("CUDA broadcast_add_f32 failed");
        }
    } else if b_n == out_n {
        cuda.broadcast_add_rev_f32(&mut out, a_guard.slice(), b_guard.slice(), out_n, a_n)
            .expect("CUDA broadcast_add_rev_f32 failed");
    } else {
        // `b` itself still needs expanding to the output shape.
        let b_bcast = b.broadcast_to(result_shape.as_slice()).contiguous_gpu();
        let b2_guard = b_bcast.storage.as_cuda_slice();
        cuda.broadcast_add_rev_f32(&mut out, a_guard.slice(), b2_guard.slice(), out_n, a_n)
            .expect("CUDA broadcast_add_rev_f32 failed");
    }
    // Reuse `result_shape` for strides instead of recomputing the broadcast.
    let strides = contiguous_strides(&result_shape);
    let storage = Storage::from_cuda_slice(out, out_n, self.device());
    Ok(Self {
        storage,
        shape: result_shape,
        strides,
        offset: 0,
    })
}
/// Elementwise subtract with shape broadcasting on the GPU.
/// Mirrors `broadcast_add_cuda`; `broadcast_sub_rev_f32` handles the case
/// where `other` is the larger operand (operand order matters for `-`).
///
/// Fix: reuse the already-computed `result_shape` for the output strides
/// instead of re-running `broadcast_shape` after the kernel.
pub(crate) fn broadcast_sub_cuda(&self, other: &Self) -> Result<Self> {
    let a = self.contiguous_gpu();
    let b = other.contiguous_gpu();
    let a_n = a.numel();
    let b_n = b.numel();
    let cuda = get_cuda_backend().expect("CUDA backend not available");
    let result_shape = crate::shape::broadcast_shape(&self.shape, &other.shape)?;
    let out_n = crate::shape::numel(&result_shape);
    let mut out = pool_alloc(out_n).expect("GPU pool alloc failed");
    let a_guard = a.storage.as_cuda_slice();
    let b_guard = b.storage.as_cuda_slice();
    if a_n >= b_n {
        if a_n == out_n {
            cuda.broadcast_sub_f32(&mut out, a_guard.slice(), b_guard.slice(), out_n, b_n)
                .expect("CUDA broadcast_sub_f32 failed");
        } else {
            // `a` itself still needs expanding to the output shape.
            let a_bcast = a.broadcast_to(result_shape.as_slice()).contiguous_gpu();
            let a2_guard = a_bcast.storage.as_cuda_slice();
            cuda.broadcast_sub_f32(&mut out, a2_guard.slice(), b_guard.slice(), out_n, b_n)
                .expect("CUDA broadcast_sub_f32 failed");
        }
    } else if b_n == out_n {
        cuda.broadcast_sub_rev_f32(&mut out, a_guard.slice(), b_guard.slice(), out_n, a_n)
            .expect("CUDA broadcast_sub_rev_f32 failed");
    } else {
        // `b` itself still needs expanding to the output shape.
        let b_bcast = b.broadcast_to(result_shape.as_slice()).contiguous_gpu();
        let b2_guard = b_bcast.storage.as_cuda_slice();
        cuda.broadcast_sub_rev_f32(&mut out, a_guard.slice(), b2_guard.slice(), out_n, a_n)
            .expect("CUDA broadcast_sub_rev_f32 failed");
    }
    let strides = contiguous_strides(&result_shape);
    let storage = Storage::from_cuda_slice(out, out_n, self.device());
    Ok(Self {
        storage,
        shape: result_shape,
        strides,
        offset: 0,
    })
}
/// Elementwise multiply with shape broadcasting on the GPU.
/// Mirrors `broadcast_add_cuda`.
///
/// Fix: reuse the already-computed `result_shape` for the output strides
/// instead of re-running `broadcast_shape` after the kernel.
pub(crate) fn broadcast_mul_cuda(&self, other: &Self) -> Result<Self> {
    let a = self.contiguous_gpu();
    let b = other.contiguous_gpu();
    let a_n = a.numel();
    let b_n = b.numel();
    let cuda = get_cuda_backend().expect("CUDA backend not available");
    let result_shape = crate::shape::broadcast_shape(&self.shape, &other.shape)?;
    let out_n = crate::shape::numel(&result_shape);
    let mut out = pool_alloc(out_n).expect("GPU pool alloc failed");
    let a_guard = a.storage.as_cuda_slice();
    let b_guard = b.storage.as_cuda_slice();
    if a_n >= b_n {
        if a_n == out_n {
            cuda.broadcast_mul_f32(&mut out, a_guard.slice(), b_guard.slice(), out_n, b_n)
                .expect("CUDA broadcast_mul_f32 failed");
        } else {
            // `a` itself still needs expanding to the output shape.
            let a_bcast = a.broadcast_to(result_shape.as_slice()).contiguous_gpu();
            let a2_guard = a_bcast.storage.as_cuda_slice();
            cuda.broadcast_mul_f32(&mut out, a2_guard.slice(), b_guard.slice(), out_n, b_n)
                .expect("CUDA broadcast_mul_f32 failed");
        }
    } else if b_n == out_n {
        cuda.broadcast_mul_rev_f32(&mut out, a_guard.slice(), b_guard.slice(), out_n, a_n)
            .expect("CUDA broadcast_mul_rev_f32 failed");
    } else {
        // `b` itself still needs expanding to the output shape.
        let b_bcast = b.broadcast_to(result_shape.as_slice()).contiguous_gpu();
        let b2_guard = b_bcast.storage.as_cuda_slice();
        cuda.broadcast_mul_rev_f32(&mut out, a_guard.slice(), b2_guard.slice(), out_n, a_n)
            .expect("CUDA broadcast_mul_rev_f32 failed");
    }
    let strides = contiguous_strides(&result_shape);
    let storage = Storage::from_cuda_slice(out, out_n, self.device());
    Ok(Self {
        storage,
        shape: result_shape,
        strides,
        offset: 0,
    })
}
/// Elementwise divide with shape broadcasting on the GPU.
/// Mirrors `broadcast_add_cuda`; `broadcast_div_rev_f32` handles the case
/// where `other` is the larger operand (operand order matters for `/`).
///
/// Fix: reuse the already-computed `result_shape` for the output strides
/// instead of re-running `broadcast_shape` after the kernel.
pub(crate) fn broadcast_div_cuda(&self, other: &Self) -> Result<Self> {
    let a = self.contiguous_gpu();
    let b = other.contiguous_gpu();
    let a_n = a.numel();
    let b_n = b.numel();
    let cuda = get_cuda_backend().expect("CUDA backend not available");
    let result_shape = crate::shape::broadcast_shape(&self.shape, &other.shape)?;
    let out_n = crate::shape::numel(&result_shape);
    let mut out = pool_alloc(out_n).expect("GPU pool alloc failed");
    let a_guard = a.storage.as_cuda_slice();
    let b_guard = b.storage.as_cuda_slice();
    if a_n >= b_n {
        if a_n == out_n {
            cuda.broadcast_div_f32(&mut out, a_guard.slice(), b_guard.slice(), out_n, b_n)
                .expect("CUDA broadcast_div_f32 failed");
        } else {
            // `a` itself still needs expanding to the output shape.
            let a_bcast = a.broadcast_to(result_shape.as_slice()).contiguous_gpu();
            let a2_guard = a_bcast.storage.as_cuda_slice();
            cuda.broadcast_div_f32(&mut out, a2_guard.slice(), b_guard.slice(), out_n, b_n)
                .expect("CUDA broadcast_div_f32 failed");
        }
    } else if b_n == out_n {
        cuda.broadcast_div_rev_f32(&mut out, a_guard.slice(), b_guard.slice(), out_n, a_n)
            .expect("CUDA broadcast_div_rev_f32 failed");
    } else {
        // `b` itself still needs expanding to the output shape.
        let b_bcast = b.broadcast_to(result_shape.as_slice()).contiguous_gpu();
        let b2_guard = b_bcast.storage.as_cuda_slice();
        cuda.broadcast_div_rev_f32(&mut out, a_guard.slice(), b2_guard.slice(), out_n, a_n)
            .expect("CUDA broadcast_div_rev_f32 failed");
    }
    let strides = contiguous_strides(&result_shape);
    let storage = Storage::from_cuda_slice(out, out_n, self.device());
    Ok(Self {
        storage,
        shape: result_shape,
        strides,
        offset: 0,
    })
}
/// Elementwise negation on the GPU, returning a new contiguous tensor.
pub(crate) fn neg_cuda(&self) -> Self {
    let src = self.contiguous_gpu();
    let n = src.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .neg_f32(&mut dst, src_view.slice(), n)
            .expect("CUDA neg_f32 failed");
    }
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// Elementwise ReLU on the GPU (kernel `relu_f32`), returning a new
/// contiguous tensor with `self`'s shape.
pub(crate) fn relu_cuda(&self) -> Self {
    let src = self.contiguous_gpu();
    let n = src.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .relu_f32(&mut dst, src_view.slice(), n)
            .expect("CUDA relu_f32 failed");
    }
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// Elementwise sigmoid on the GPU (kernel `sigmoid_f32`).
pub(crate) fn sigmoid_cuda(&self) -> Self {
    let src = self.contiguous_gpu();
    let n = src.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .sigmoid_f32(&mut dst, src_view.slice(), n)
            .expect("CUDA sigmoid_f32 failed");
    }
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// Elementwise tanh on the GPU (kernel `tanh_f32`).
pub(crate) fn tanh_cuda(&self) -> Self {
    let src = self.contiguous_gpu();
    let n = src.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .tanh_f32(&mut dst, src_view.slice(), n)
            .expect("CUDA tanh_f32 failed");
    }
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// Elementwise natural exponential on the GPU (kernel `exp_f32`).
pub(crate) fn exp_cuda(&self) -> Self {
    let src = self.contiguous_gpu();
    let n = src.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .exp_f32(&mut dst, src_view.slice(), n)
            .expect("CUDA exp_f32 failed");
    }
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// Elementwise natural logarithm on the GPU. Note the backend names the
/// kernel `log_f32` while the tensor API calls this `ln`.
pub(crate) fn ln_cuda(&self) -> Self {
    let src = self.contiguous_gpu();
    let n = src.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .log_f32(&mut dst, src_view.slice(), n)
            .expect("CUDA log_f32 failed");
    }
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// Elementwise square root on the GPU (kernel `sqrt_f32`).
pub(crate) fn sqrt_cuda(&self) -> Self {
    let src = self.contiguous_gpu();
    let n = src.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .sqrt_f32(&mut dst, src_view.slice(), n)
            .expect("CUDA sqrt_f32 failed");
    }
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// Elementwise power with a scalar exponent on the GPU
/// (kernel `pow_scalar_f32`).
pub(crate) fn pow_cuda(&self, exp: f32) -> Self {
    let src = self.contiguous_gpu();
    let n = src.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .pow_scalar_f32(&mut dst, src_view.slice(), exp, n)
            .expect("CUDA pow_scalar_f32 failed");
    }
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// Elementwise GELU activation on the GPU (kernel `gelu_f32`; exact vs.
/// tanh approximation is decided inside the kernel).
pub(crate) fn gelu_cuda(&self) -> Self {
    let src = self.contiguous_gpu();
    let n = src.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .gelu_f32(&mut dst, src_view.slice(), n)
            .expect("CUDA gelu_f32 failed");
    }
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// Elementwise SiLU activation on the GPU (kernel `silu_f32`).
pub(crate) fn silu_cuda(&self) -> Self {
    let src = self.contiguous_gpu();
    let n = src.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .silu_f32(&mut dst, src_view.slice(), n)
            .expect("CUDA silu_f32 failed");
    }
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// Scalar multiply on the GPU. The backend exposes only an in-place
/// `scale_f32`, so the source is first copied into a fresh pooled buffer
/// (a `broadcast_copy_f32` with src_len == out_len is a plain device copy)
/// and then scaled in place.
pub(crate) fn mul_scalar_cuda(&self, scalar: f32) -> Self {
    let src = self.contiguous_gpu();
    let n = src.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .broadcast_copy_f32(&mut dst, src_view.slice(), n, n)
            .expect("CUDA broadcast_copy_f32 failed");
    }
    backend
        .scale_f32(&mut dst, scalar, n)
        .expect("CUDA scale_f32 failed");
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// Scalar addition on the GPU (kernel `add_scalar_f32`); unlike
/// `mul_scalar_cuda` this kernel writes out-of-place, so no staging copy
/// is needed.
pub(crate) fn add_scalar_cuda(&self, scalar: f32) -> Self {
    let src = self.contiguous_gpu();
    let n = src.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .add_scalar_f32(&mut dst, src_view.slice(), scalar, n)
            .expect("CUDA add_scalar_f32 failed");
    }
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// Softmax along `dim` on the GPU.
///
/// Negative `dim` counts from the end (PyTorch-style). The row kernel only
/// handles the last dimension, so any other axis is swapped to the back via
/// `permute`, reduced recursively, then swapped back (the swap permutation
/// is its own inverse, so reusing `perm` restores the original order).
///
/// NOTE(review): `d` is not range-checked; an out-of-range `dim` would panic
/// on the shape index below rather than return an `Err` — confirm callers
/// validate `dim`.
pub(crate) fn softmax_cuda(&self, dim: i32) -> Result<Self> {
let data = self.contiguous_gpu();
let ndim = data.shape.len();
let total = data.numel();
// Normalize a negative axis to its positive equivalent.
let d = if dim < 0 { ndim as i32 + dim } else { dim } as usize;
if d == ndim - 1 {
// Fast path: rows are contiguous. Copy source into the output buffer
// (broadcast_copy with equal lengths is a plain device copy), then run
// the row-softmax kernel in place on that buffer.
let row_size = data.shape[ndim - 1];
let num_rows = total / row_size;
let cuda = get_cuda_backend().expect("CUDA backend not available");
let src_guard = data.storage.as_cuda_slice();
let mut out = pool_alloc(total).expect("GPU pool alloc failed");
cuda.broadcast_copy_f32(&mut out, src_guard.slice(), total, total)
.expect("CUDA broadcast_copy_f32 failed");
cuda.softmax_row_f32(&mut out, num_rows, row_size)
.expect("CUDA softmax_row_f32 failed");
let storage = Storage::from_cuda_slice(out, total, self.device());
Ok(Self {
storage,
shape: data.shape.clone(),
strides: contiguous_strides(&data.shape),
offset: 0,
})
} else {
// General path: swap `d` with the last axis, recurse on the last axis,
// then apply the same swap again to undo it.
let mut perm: Vec<usize> = (0..ndim).collect();
perm.swap(d, ndim - 1);
let transposed = data.permute(&perm)?;
let t_contig = transposed.contiguous_gpu();
let t_result = t_contig.softmax_cuda(ndim as i32 - 1)?;
Ok(t_result.permute(&perm)?.contiguous_gpu())
}
}
/// Materializes `self` broadcast to `target_shape` as a new contiguous GPU
/// tensor.
///
/// Two paths:
/// 1. When `out_len` is a multiple of `src_len`, `broadcast_copy_f32` does
///    the expansion entirely on-device.
///    NOTE(review): this fast path assumes the kernel's repetition pattern
///    matches the requested broadcast. Divisibility alone does not imply
///    that (e.g. `[3,1] -> [3,4]` has 12 % 3 == 0 but is not a simple tile
///    of the flat source) — confirm `broadcast_copy_f32` handles this, or
///    tighten the condition.
/// 2. Otherwise a per-element source-index table is built on the host
///    (O(out_len) host work plus an htod transfer) and a gather kernel
///    expands the data.
pub(crate) fn broadcast_to_cuda(&self, target_shape: &[usize]) -> Result<Self> {
let data = self.contiguous_gpu();
let src_len = data.numel();
let out_len = crate::shape::numel(target_shape);
let cuda = get_cuda_backend().expect("CUDA backend not available");
if out_len % src_len == 0 {
// Fast path: on-device repetition (see caveat in the doc comment).
let src_guard = data.storage.as_cuda_slice();
let mut out = pool_alloc(out_len).expect("GPU pool alloc failed");
cuda.broadcast_copy_f32(&mut out, src_guard.slice(), out_len, src_len)
.expect("CUDA broadcast_copy_f32 failed");
let storage = Storage::from_cuda_slice(out, out_len, self.device());
return Ok(Self {
storage,
shape: crate::shape::Shape::from_slice(target_shape),
strides: contiguous_strides(&crate::shape::Shape::from_slice(target_shape)),
offset: 0,
});
}
// Slow path: compute, on the host, the flat source index feeding each
// output element (broadcast strides put stride 0 on expanded axes).
let result_shape: crate::shape::Shape = target_shape.into();
let src_strides =
crate::shape::broadcast_strides(&data.shape, &data.strides, &result_shape);
let indices: Vec<u32> = (0..out_len)
.map(|i| {
let coords = crate::shape::unravel_index(i, &result_shape);
let src_idx = data.offset + crate::shape::linear_index(&coords, &src_strides);
src_idx as u32
})
.collect();
let idx_gpu = cuda.htod_copy(&indices).expect("htod indices failed");
let src_guard = data.storage.as_cuda_slice();
let mut out = pool_alloc(out_len).expect("GPU pool alloc failed");
cuda.gather_contiguous_f32(&mut out, src_guard.slice(), &idx_gpu, out_len)
.expect("CUDA gather_contiguous_f32 failed");
let storage = Storage::from_cuda_slice(out, out_len, self.device());
Ok(Self {
storage,
shape: result_shape,
strides: contiguous_strides(&crate::shape::Shape::from_slice(target_shape)),
offset: 0,
})
}
/// Matrix multiplication on the GPU via cuBLAS GEMM.
///
/// Handles 2-D x 2-D with a single GEMM, and >2-D (batched) inputs with one
/// GEMM per batch. Tensors whose last two axes are a transposed view (with
/// contiguous batch layout and zero offset) are passed to cuBLAS using the
/// transpose op instead of being materialized.
///
/// NOTE(review): the batched path round-trips every batch matrix through
/// host memory (`to_vec` + per-batch htod/dtoh) — correct but slow; the
/// unused `_stride_*` variables below suggest a strided-batched GEMM was
/// planned. Batch dims are taken from `a` only; `b`'s batch dims are
/// presumably validated by the caller.
pub(crate) fn matmul_cuda(&self, other: &Self) -> Result<Self> {
let cuda = get_cuda_backend().expect("CUDA backend not available");
// True when the last two axes are swapped relative to row-major layout
// (innermost stride larger than the row stride).
fn is_last2_transposed(t: &Tensor<f32>) -> bool {
let nd = t.ndim();
if nd < 2 {
return false;
}
let strides = t.strides.as_slice();
strides[nd - 1] > strides[nd - 2]
}
// True when the batch dimensions are densely packed over the trailing
// matrix, i.e. each batch stride equals the product of the dims after it.
fn batch_contiguous(t: &Tensor<f32>) -> bool {
let nd = t.ndim();
if nd <= 2 {
return true;
}
let strides = t.strides.as_slice();
let shape = t.shape.as_slice();
let mat_size = shape[nd - 2] * shape[nd - 1];
let mut expected = mat_size as isize;
for i in (0..nd - 2).rev() {
if strides[i] != expected {
return false;
}
expected *= shape[i] as isize;
}
true
}
// A transposed view can be fed to cuBLAS directly only if the rest of the
// layout is dense and the view starts at storage offset 0.
let a_transposed = is_last2_transposed(self) && batch_contiguous(self) && self.offset == 0;
let b_transposed =
is_last2_transposed(other) && batch_contiguous(other) && other.offset == 0;
let a = if a_transposed {
self.clone()
} else {
self.contiguous_gpu()
};
let b = if b_transposed {
other.clone()
} else {
other.contiguous_gpu()
};
// Logical GEMM dimensions: (m x k) * (k x n) = (m x n).
let m = a.shape[a.shape.len() - 2];
let k = a.shape[a.shape.len() - 1];
let n = b.shape[b.shape.len() - 1];
if m == 0 || k == 0 || n == 0 {
// Degenerate size: skip cuBLAS and return a zero-filled host tensor.
let out_shape: Vec<usize> = if a.shape.len() == 2 {
vec![m, n]
} else {
let mut s: Vec<usize> = a.shape[..a.shape.len() - 2].to_vec();
s.push(m);
s.push(n);
s
};
let total: usize = out_shape.iter().product();
return Ok(Self::from_vec(vec![0.0f32; total], &out_shape)?);
}
if a.shape.len() == 2 && b.shape.len() == 2 {
// 2-D fast path: one GEMM. cuBLAS is column-major, so row-major
// C = A*B is computed as column-major C^T = B^T * A^T — operands and
// the m/n dimensions are passed swapped.
let a_guard = a.storage.as_cuda_slice();
let b_guard = b.storage.as_cuda_slice();
let mut c_gpu = pool_alloc(m * n).map_err(|e| crate::Error::InvalidOperation {
message: format!("GPU OOM in 2D matmul ({}x{}x{}): {}", m, k, n, e),
})?;
// Leading dimension is the stride of the stored matrix: k (or m when
// the view is transposed) for A, n (or k) for B.
let (lda, op_a) = if a_transposed { (m, true) } else { (k, false) };
let (ldb, op_b) = if b_transposed { (k, true) } else { (n, false) };
// Sanity-check cuBLAS leading-dimension requirements before the call.
let lda_min = if op_a { m } else { k };
let ldb_min = if op_b { k } else { n };
assert!(
lda >= lda_min.max(1),
"cuBLAS lda={} < min={} (m={}, k={}, op_a={})",
lda,
lda_min,
m,
k,
op_a
);
assert!(
ldb >= ldb_min.max(1),
"cuBLAS ldb={} < min={} (k={}, n={}, op_b={})",
ldb,
ldb_min,
k,
n,
op_b
);
assert!(n >= 1, "cuBLAS ldc=n={} must be >= 1", n);
// Note the swapped operand order (B first) for the C^T trick above.
cuda.gemm_f32(
op_b,
op_a,
n,
m,
k,
1.0,
b_guard.slice(),
ldb,
a_guard.slice(),
lda,
0.0,
&mut c_gpu,
n,
)
.expect("cuBLAS gemm failed");
let storage = Storage::from_cuda_slice(c_gpu, m * n, self.device());
return Ok(Self {
storage,
shape: Shape::from_slice(&[m, n]),
strides: contiguous_strides(&Shape::from_slice(&[m, n])),
offset: 0,
});
}
// Batched path: batch dims come from `a`.
let batch_dims: Vec<usize> = a.shape[..a.shape.len() - 2].to_vec();
let batch_size: usize = batch_dims.iter().product();
if batch_size == 0 || m == 0 || k == 0 || n == 0 {
// Only batch_size == 0 is new here (m/k/n == 0 returned above).
// NOTE(review): `total.max(1)` hands from_vec a 1-element vec for a
// zero-numel shape — confirm from_vec accepts that length mismatch.
let mut out_shape = batch_dims.clone();
out_shape.push(m);
out_shape.push(n);
let total: usize = out_shape.iter().product();
return Ok(Self::from_vec(vec![0.0f32; total.max(1)], &out_shape)?);
}
let total = batch_size * m * n;
let a_guard = a.storage.as_cuda_slice();
let b_guard = b.storage.as_cuda_slice();
// Same column-major swap as the 2-D path: "transa" belongs to B, etc.
// NOTE(review): the per-batch data below comes from to_vec(); if to_vec
// yields logical (already de-transposed) order for transposed views,
// these op flags would double-apply the transpose — confirm.
let (cublas_transa, cublas_lda) = if b_transposed {
(true, k) } else {
(false, n) };
let (cublas_transb, cublas_ldb) = if a_transposed {
(true, m) } else {
(false, k) };
let cublas_ldc = n;
// Computed but unused — kept for a future strided-batched GEMM.
let _stride_a = (k * n) as i64;
let _stride_b = (m * k) as i64;
let _stride_c = (m * n) as i64;
let a_mat_size = m * k;
let b_mat_size = k * n;
let c_mat_size = m * n;
// Host round-trip: pull both operands down, GEMM each batch slice
// through its own device buffers, collect results on the host.
let a_vec = a.to_vec();
let b_vec = b.to_vec();
let mut c_vec = vec![0.0f32; total];
for bi in 0..batch_size {
let a_slice = &a_vec[bi * a_mat_size..(bi + 1) * a_mat_size];
let b_slice = &b_vec[bi * b_mat_size..(bi + 1) * b_mat_size];
let a_gpu_i = cuda
.htod_copy(a_slice)
.map_err(|e| crate::Error::InvalidOperation {
message: format!("htod A batch {}: {:?}", bi, e),
})?;
let b_gpu_i = cuda
.htod_copy(b_slice)
.map_err(|e| crate::Error::InvalidOperation {
message: format!("htod B batch {}: {:?}", bi, e),
})?;
let mut c_gpu_i =
cuda.alloc::<f32>(c_mat_size)
.map_err(|e| crate::Error::InvalidOperation {
message: format!("alloc C batch {}: {:?}", bi, e),
})?;
cuda.gemm_f32(
cublas_transa,
cublas_transb,
n,
m,
k,
1.0,
&b_gpu_i,
cublas_lda,
&a_gpu_i,
cublas_ldb,
0.0,
&mut c_gpu_i,
cublas_ldc,
)
.map_err(|e| crate::Error::InvalidOperation {
message: format!("cuBLAS gemm batch {}/{} failed: {:?}", bi, batch_size, e),
})?;
let c_result =
cuda.dtoh_copy(&c_gpu_i)
.map_err(|e| crate::Error::InvalidOperation {
message: format!("dtoh C batch {}: {:?}", bi, e),
})?;
c_vec[bi * c_mat_size..(bi + 1) * c_mat_size].copy_from_slice(&c_result);
}
drop(a_guard);
drop(b_guard);
// Push the assembled host result back up, staging through a pooled
// buffer so the output storage comes from the pool allocator.
let c_gpu_src = cuda
.htod_copy(&c_vec)
.map_err(|e| crate::Error::InvalidOperation {
message: format!("htod C result: {:?}", e),
})?;
let mut c_gpu = pool_alloc(total).map_err(|e| crate::Error::InvalidOperation {
message: format!("pool alloc C result: {:?}", e),
})?;
cuda.broadcast_copy_f32(&mut c_gpu, &c_gpu_src, total, total)
.map_err(|e| crate::Error::InvalidOperation {
message: format!("copy C to pool: {:?}", e),
})?;
let mut output_shape = batch_dims;
output_shape.push(m);
output_shape.push(n);
let storage = Storage::from_cuda_slice(c_gpu, total, self.device());
Ok(Self {
storage,
shape: Shape::from_slice(&output_shape),
strides: contiguous_strides(&Shape::from_slice(&output_shape)),
offset: 0,
})
}
/// Copies this tensor's GPU storage back to the host as a flat `Vec<f32>`.
/// NOTE(review): this reads raw storage order, ignoring strides/offset —
/// callers appear to pass contiguous tensors; confirm for views.
pub(crate) fn to_vec_gpu(&self) -> Vec<f32> {
self.storage.to_vec_f32()
}
/// Returns a contiguous copy of this tensor on the GPU.
///
/// If the tensor is already contiguous with zero offset it is cheaply
/// cloned. Otherwise the shape and strides are uploaded to the device and
/// a strided-gather kernel packs the elements into a fresh pooled buffer.
pub(crate) fn contiguous_gpu(&self) -> Self {
if self.is_contiguous() && self.offset == 0 {
// Fast path: nothing to repack.
return self.clone();
}
let total = self.numel();
let cuda = get_cuda_backend().expect("CUDA backend not available");
let ndim = self.shape.len();
let offset = self.offset;
let shape = self.shape.as_slice();
let strides = self.strides.as_slice();
// The kernel takes shape as u32 and strides as i64; convert and upload.
let shape_u32: Vec<u32> = shape.iter().map(|&s| s as u32).collect();
let strides_i64: Vec<i64> = strides.iter().map(|&s| s as i64).collect();
let shape_gpu = cuda.htod_copy(&shape_u32).expect("htod shape failed");
let strides_gpu = cuda.htod_copy(&strides_i64).expect("htod strides failed");
let src_guard = self.storage.as_cuda_slice();
let mut out = pool_alloc(total).expect("GPU pool alloc failed");
cuda.strided_gather_f32(
src_guard.slice(),
&mut out,
&strides_gpu,
&shape_gpu,
ndim,
offset,
total,
)
.expect("CUDA strided_gather_f32 failed");
let storage = Storage::from_cuda_slice(out, total, self.device());
Self {
storage,
shape: self.shape.clone(),
strides: contiguous_strides(&self.shape),
offset: 0,
}
}
/// Moves this tensor to `device`, returning a new tensor (or a cheap clone
/// if it is already there).
///
/// The tensor is made contiguous first — on whichever side it currently
/// lives — so the cross-device copy is a single flat transfer; the result
/// therefore always has contiguous strides and zero offset.
pub fn to_device_f32(&self, device: Device) -> Result<Self> {
if self.device() == device {
return Ok(self.clone());
}
let contig = if self.storage.is_gpu() {
self.contiguous_gpu()
} else {
self.contiguous()
};
let new_storage = contig.storage.to_device_f32(device)?;
Ok(Self {
storage: new_storage,
shape: self.shape.clone(),
strides: contiguous_strides(&self.shape),
offset: 0,
})
}
/// LayerNorm forward on the GPU.
///
/// The input is treated as `numel / norm_size` rows of `norm_size` elements
/// each; the kernel normalizes every row using `gamma` (scale), `beta`
/// (shift) and `eps`.
/// NOTE(review): `gamma`/`beta` are used as-is — presumably already
/// contiguous GPU tensors of length `norm_size`; confirm at call sites.
pub fn layer_norm_cuda(
    &self,
    gamma: &Self,
    beta: &Self,
    norm_size: usize,
    eps: f32,
) -> Result<Self> {
    let x = self.contiguous_gpu();
    let total = x.numel();
    let rows = total / norm_size;
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(total).expect("GPU pool alloc failed for LayerNorm");
    {
        let x_view = x.storage.as_cuda_slice();
        let gamma_view = gamma.storage.as_cuda_slice();
        let beta_view = beta.storage.as_cuda_slice();
        backend
            .layer_norm_f32(
                &mut dst,
                x_view.slice(),
                gamma_view.slice(),
                beta_view.slice(),
                norm_size,
                eps,
                rows,
            )
            .expect("CUDA layer_norm_f32 failed");
    }
    Ok(Self {
        storage: Storage::from_cuda_slice(dst, total, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    })
}
/// Gathers elements of `self` (the embedding weight) into a tensor of
/// `output_shape`, driven by precomputed flat `gather_indices` — one source
/// index per output element, uploaded to the device before the gather.
pub fn embedding_gather_cuda(&self, gather_indices: &[u32], output_shape: &[usize]) -> Self {
    let out_len: usize = output_shape.iter().product();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let idx_dev = backend
        .htod_copy(gather_indices)
        .expect("htod gather indices failed");
    let mut dst = pool_alloc(out_len).expect("GPU pool alloc failed");
    {
        let weight_view = self.storage.as_cuda_slice();
        backend
            .gather_contiguous_f32(&mut dst, weight_view.slice(), &idx_dev, out_len)
            .expect("CUDA gather_contiguous_f32 failed");
    }
    let shape = crate::shape::Shape::from_slice(output_shape);
    Self {
        storage: Storage::from_cuda_slice(dst, out_len, self.device()),
        strides: contiguous_strides(&shape),
        shape,
        offset: 0,
    }
}
/// Scatter-add of gradient rows back into an embedding-gradient buffer.
///
/// `self` holds `indices.len()` gradient rows of `emb_dim` values; the
/// result is a zero-initialized `[num_embeddings, emb_dim]` tensor where
/// each row of `self` is accumulated at its index (duplicate indices
/// accumulate — handled by the kernel).
pub fn embedding_scatter_add_cuda(
    &self,
    indices: &[u32],
    num_embeddings: usize,
    emb_dim: usize,
) -> Self {
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let total_n = indices.len() * emb_dim;
    let out_len = num_embeddings * emb_dim;
    let idx_dev = backend.htod_copy(indices).expect("htod indices failed");
    let mut dst = pool_alloc(out_len).expect("GPU pool alloc failed");
    // Pool buffers may be recycled, so zero the accumulator explicitly.
    backend
        .memset_zeros_f32(&mut dst)
        .expect("memset zeros failed");
    let grad = self.contiguous_gpu();
    {
        let grad_view = grad.storage.as_cuda_slice();
        backend
            .embedding_scatter_add_f32(grad_view.slice(), &idx_dev, &mut dst, total_n, emb_dim)
            .expect("CUDA embedding_scatter_add_f32 failed");
    }
    let shape = crate::shape::Shape::from_slice(&[num_embeddings, emb_dim]);
    Self {
        storage: Storage::from_cuda_slice(dst, out_len, self.device()),
        strides: contiguous_strides(&shape),
        shape,
        offset: 0,
    }
}
/// Fused Adam optimizer step on the GPU, updating `self` (the parameter),
/// `exp_avg` (first moment) and `exp_avg_sq` (second moment) **in place**
/// through interior mutability — note all four tensors are taken by `&`
/// yet three are mutated via `as_cuda_slice_mut`.
///
/// `bias_correction1/2` are precomputed by the caller (the step count is
/// not passed in). `weight_decay` semantics (coupled vs. AdamW-style
/// decoupled) live in the kernel — not visible here.
///
/// Panics if the CUDA backend is unavailable or the kernel fails.
#[allow(clippy::too_many_arguments)]
pub fn adam_step_inplace(
&self,
grad: &Self,
exp_avg: &Self,
exp_avg_sq: &Self,
lr: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
bias_correction1: f32,
bias_correction2: f32,
) {
let n = self.numel();
let cuda = get_cuda_backend().expect("CUDA backend not available");
// Mutable guards on param and both moment buffers; read-only on grad.
let mut param_guard = self.storage.as_cuda_slice_mut();
let grad_guard = grad.storage.as_cuda_slice();
let mut avg_guard = exp_avg.storage.as_cuda_slice_mut();
let mut sq_guard = exp_avg_sq.storage.as_cuda_slice_mut();
cuda.adam_step_f32(
param_guard.slice_mut(),
grad_guard.slice(),
avg_guard.slice_mut(),
sq_guard.slice_mut(),
n,
lr,
beta1,
beta2,
eps,
weight_decay,
bias_correction1,
bias_correction2,
)
.expect("CUDA adam_step_f32 failed");
}
/// Global gradient-norm clipping across a set of GPU gradient tensors.
///
/// Accumulates the squared L2 norm of all gradients into a single-element
/// device buffer, takes the square root on the host, and — if the total
/// norm exceeds `max_norm` — scales every gradient in place by
/// `max_norm / (norm + 1e-6)`. Returns the (pre-clip) total norm.
pub fn clip_grad_norm_cuda(grads: &[Self], max_norm: f32) -> f32 {
if grads.is_empty() {
return 0.0;
}
let cuda = get_cuda_backend().expect("CUDA backend not available");
// One-element accumulator for the running sum of squares.
let mut acc = pool_alloc(1).expect("GPU pool alloc failed");
cuda.memset_zeros_f32(&mut acc).expect("memset failed");
for grad in grads {
let data = grad.contiguous_gpu();
let n = data.numel();
let guard = data.storage.as_cuda_slice();
cuda.grad_norm_sq_f32(guard.slice(), &mut acc, n)
.expect("CUDA grad_norm_sq_f32 failed");
}
// Single dtoh copy of the scalar result.
let result = cuda.dtoh_copy(&acc).expect("dtoh failed");
let total_norm = result[0].sqrt();
if total_norm > max_norm {
// Epsilon guards against division by a vanishing norm.
let scale = max_norm / (total_norm + 1e-6);
for grad in grads {
let n = grad.numel();
let mut guard = grad.storage.as_cuda_slice_mut();
cuda.grad_scale_f32(guard.slice_mut(), n, scale)
.expect("CUDA grad_scale_f32 failed");
}
}
total_norm
}
/// Scales this tensor's GPU buffer by `scale` **in place** (mutation goes
/// through interior mutability despite the `&self` receiver).
pub fn grad_scale_inplace(&self, scale: f32) {
    let n = self.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut buf = self.storage.as_cuda_slice_mut();
    backend
        .grad_scale_f32(buf.slice_mut(), n, scale)
        .expect("CUDA grad_scale_f32 failed");
}
/// Sum-reduction along axis `dim` on the GPU, dropping that axis from the
/// result (a full reduction to a scalar keeps shape `[1]`).
///
/// The contiguous input is viewed as `outer x reduce x inner`, where
/// `outer`/`inner` are the products of the dims before/after `dim`.
pub(crate) fn sum_dim_cuda(&self, dim: usize) -> Self {
    let src = self.contiguous_gpu();
    let outer: usize = src.shape[..dim].iter().product();
    let reduce = src.shape[dim];
    let inner: usize = src.shape[dim + 1..].iter().product();
    let out_len = outer * inner;
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(out_len).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .sum_dim_f32(&mut dst, src_view.slice(), outer, reduce, inner)
            .expect("CUDA sum_dim_f32 failed");
    }
    // Output shape = input shape with `dim` removed; never empty.
    let mut dims: Vec<usize> = src
        .shape
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != dim)
        .map(|(_, &s)| s)
        .collect();
    if dims.is_empty() {
        dims.push(1);
    }
    let shape = Shape::from_slice(&dims);
    Self {
        storage: Storage::from_cuda_slice(dst, out_len, self.device()),
        strides: contiguous_strides(&shape),
        shape,
        offset: 0,
    }
}
/// Sum-reduction along axis `dim` on the GPU, keeping that axis with
/// size 1 (keepdim semantics). Same kernel and `outer x reduce x inner`
/// decomposition as `sum_dim_cuda`; only the result shape differs.
pub(crate) fn sum_dim_keepdim_cuda(&self, dim: usize) -> Self {
    let src = self.contiguous_gpu();
    let outer: usize = src.shape[..dim].iter().product();
    let reduce = src.shape[dim];
    let inner: usize = src.shape[dim + 1..].iter().product();
    let out_len = outer * inner;
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(out_len).expect("GPU pool alloc failed");
    {
        let src_view = src.storage.as_cuda_slice();
        backend
            .sum_dim_f32(&mut dst, src_view.slice(), outer, reduce, inner)
            .expect("CUDA sum_dim_f32 failed");
    }
    // Keep the reduced axis as size 1.
    let mut dims: Vec<usize> = src.shape.to_vec();
    dims[dim] = 1;
    let shape = Shape::from_slice(&dims);
    Self {
        storage: Storage::from_cuda_slice(dst, out_len, self.device()),
        strides: contiguous_strides(&shape),
        shape,
        offset: 0,
    }
}
/// ReLU backward on the GPU: `self` is the upstream gradient, `input` the
/// forward input; the derivative mask is applied by `relu_backward_f32`.
pub fn relu_backward_cuda(&self, input: &Self) -> Self {
    let upstream = self.contiguous_gpu();
    let fwd_input = input.contiguous_gpu();
    let n = upstream.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    {
        let grad_view = upstream.storage.as_cuda_slice();
        let input_view = fwd_input.storage.as_cuda_slice();
        backend
            .relu_backward_f32(&mut dst, grad_view.slice(), input_view.slice(), n)
            .expect("CUDA relu_backward_f32 failed");
    }
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// GPU sigmoid backward: combines the upstream gradient (`self`) with the
/// saved forward output via the `sigmoid_backward_f32` kernel.
pub fn sigmoid_backward_cuda(&self, output: &Self) -> Self {
    let upstream = self.contiguous_gpu();
    let saved_out = output.contiguous_gpu();
    let n = upstream.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let up_guard = upstream.storage.as_cuda_slice();
    let out_guard = saved_out.storage.as_cuda_slice();
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    backend
        .sigmoid_backward_f32(&mut dst, up_guard.slice(), out_guard.slice(), n)
        .expect("CUDA sigmoid_backward_f32 failed");
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// GPU softmax backward, applied row-wise: `self` is the upstream gradient
/// and `softmax_output` the forward result; a "row" is one slice of the
/// trailing dimension.
pub fn softmax_backward_cuda(&self, softmax_output: &Self) -> Self {
    let upstream = self.contiguous_gpu();
    let probs = softmax_output.contiguous_gpu();
    let n = upstream.numel();
    // Row extent is the last axis; everything before it is the row count.
    let row_size = upstream.shape[upstream.shape.len() - 1];
    let rows = n / row_size;
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let up_guard = upstream.storage.as_cuda_slice();
    let probs_guard = probs.storage.as_cuda_slice();
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    backend
        .softmax_backward_row_f32(
            &mut dst,
            probs_guard.slice(),
            up_guard.slice(),
            rows,
            row_size,
        )
        .expect("CUDA softmax_backward_row_f32 failed");
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// GPU LayerNorm backward w.r.t. the input.
///
/// `self` is the upstream gradient, `input` the saved forward input, and
/// `gamma` the scale parameter. `norm_size` is the normalized extent per row
/// and `eps` is forwarded to the kernel (presumably the same epsilon used in
/// the forward pass — confirm at call sites).
pub fn layer_norm_backward_dinput_cuda(
    &self,
    input: &Self,
    gamma: &Self,
    norm_size: usize,
    eps: f32,
) -> Self {
    let upstream = self.contiguous_gpu();
    let saved_input = input.contiguous_gpu();
    let n = upstream.numel();
    let rows = n / norm_size;
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let up_guard = upstream.storage.as_cuda_slice();
    let in_guard = saved_input.storage.as_cuda_slice();
    let gamma_guard = gamma.storage.as_cuda_slice();
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    backend
        .layer_norm_backward_dinput_f32(
            &mut dst,
            up_guard.slice(),
            in_guard.slice(),
            gamma_guard.slice(),
            norm_size,
            eps,
            rows,
        )
        .expect("CUDA layer_norm_backward_dinput_f32 failed");
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// GPU LayerNorm backward w.r.t. the affine parameters.
///
/// `self` is the upstream gradient and `input` the saved forward input.
/// Returns `(d_weight, d_bias)`, both 1-D tensors of shape `[norm_size]`.
pub fn layer_norm_backward_dweight_dbias_cuda(
    &self,
    input: &Self,
    norm_size: usize,
    eps: f32,
) -> (Self, Self) {
    let upstream = self.contiguous_gpu();
    let saved_input = input.contiguous_gpu();
    let n = upstream.numel();
    let rows = n / norm_size;
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let up_guard = upstream.storage.as_cuda_slice();
    let in_guard = saved_input.storage.as_cuda_slice();
    let mut d_weight = pool_alloc(norm_size).expect("GPU pool alloc failed");
    let mut d_bias = pool_alloc(norm_size).expect("GPU pool alloc failed");
    backend
        .layer_norm_backward_dweight_dbias_f32(
            &mut d_weight,
            &mut d_bias,
            up_guard.slice(),
            in_guard.slice(),
            norm_size,
            eps,
            rows,
        )
        .expect("CUDA layer_norm_backward_dweight_dbias_f32 failed");
    // Both outputs share the same 1-D [norm_size] shape.
    let shape = Shape::from_slice(&[norm_size]);
    let wrap = |buf| Self {
        storage: Storage::from_cuda_slice(buf, norm_size, self.device()),
        shape: shape.clone(),
        strides: contiguous_strides(&shape),
        offset: 0,
    };
    (wrap(d_weight), wrap(d_bias))
}
/// GPU tanh backward: combines the upstream gradient (`self`) with the saved
/// forward output via the `tanh_backward_f32` kernel.
pub fn tanh_backward_cuda(&self, output: &Self) -> Self {
    let upstream = self.contiguous_gpu();
    let saved_out = output.contiguous_gpu();
    let n = upstream.numel();
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let up_guard = upstream.storage.as_cuda_slice();
    let out_guard = saved_out.storage.as_cuda_slice();
    let mut dst = pool_alloc(n).expect("GPU pool alloc failed");
    backend
        .tanh_backward_f32(&mut dst, up_guard.slice(), out_guard.slice(), n)
        .expect("CUDA tanh_backward_f32 failed");
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    }
}
/// GPU cross-entropy forward over `[batch, classes]` logits (`self`).
///
/// Returns `(losses, softmax)`: per-sample losses of shape `[batch]` and the
/// softmax probabilities of shape `[batch, classes]` (consumed by
/// `cross_entropy_bwd_cuda`).
pub fn cross_entropy_fwd_cuda(&self, targets: &Self) -> (Self, Self) {
    let logits = self.contiguous_gpu();
    let target_data = targets.contiguous_gpu();
    let batch = logits.shape[0];
    let classes = logits.shape[1];
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let logits_guard = logits.storage.as_cuda_slice();
    let target_guard = target_data.storage.as_cuda_slice();
    let mut losses_buf = pool_alloc(batch).expect("GPU pool alloc");
    let mut softmax_buf = pool_alloc(batch * classes).expect("GPU pool alloc");
    backend
        .cross_entropy_fwd_f32(
            logits_guard.slice(),
            target_guard.slice(),
            &mut losses_buf,
            &mut softmax_buf,
            batch,
            classes,
        )
        .expect("CUDA cross_entropy_fwd_f32 failed");
    let losses_shape = Shape::from_slice(&[batch]);
    let losses = Self {
        storage: Storage::from_cuda_slice(losses_buf, batch, self.device()),
        strides: contiguous_strides(&losses_shape),
        shape: losses_shape,
        offset: 0,
    };
    let softmax_shape = Shape::from_slice(&[batch, classes]);
    let softmax = Self {
        storage: Storage::from_cuda_slice(softmax_buf, batch * classes, self.device()),
        strides: contiguous_strides(&softmax_shape),
        shape: softmax_shape,
        offset: 0,
    };
    (losses, softmax)
}
/// GPU cross-entropy backward.
///
/// `self` is the `[batch, classes]` softmax saved by the forward pass;
/// returns the gradient w.r.t. the logits with the same shape.
pub fn cross_entropy_bwd_cuda(&self, targets: &Self, grad_output: &Self) -> Self {
    let probs = self.contiguous_gpu();
    let target_data = targets.contiguous_gpu();
    let upstream = grad_output.contiguous_gpu();
    let batch = probs.shape[0];
    let classes = probs.shape[1];
    let n = batch * classes;
    let backend = get_cuda_backend().expect("CUDA backend not available");
    let probs_guard = probs.storage.as_cuda_slice();
    let target_guard = target_data.storage.as_cuda_slice();
    let up_guard = upstream.storage.as_cuda_slice();
    let mut dst = pool_alloc(n).expect("GPU pool alloc");
    backend
        .cross_entropy_bwd_f32(
            probs_guard.slice(),
            target_guard.slice(),
            up_guard.slice(),
            &mut dst,
            batch,
            classes,
        )
        .expect("CUDA cross_entropy_bwd_f32 failed");
    let shape = Shape::from_slice(&[batch, classes]);
    Self {
        storage: Storage::from_cuda_slice(dst, n, self.device()),
        strides: contiguous_strides(&shape),
        shape,
        offset: 0,
    }
}
/// GPU backward for `narrow`: scatters the gradient (`self`) into a
/// zero-filled tensor of the original `input_shape`, at offset `start`
/// along `dim`.
pub fn narrow_backward_cuda(&self, input_shape: &[usize], dim: usize, start: usize) -> Self {
    let total: usize = input_shape.iter().product();
    let backend = get_cuda_backend().expect("CUDA backend");
    // Everything outside the narrowed window receives zero gradient.
    let mut dst = pool_alloc(total).expect("GPU pool alloc for narrow_backward");
    backend
        .memset_zeros_f32(&mut dst)
        .expect("CUDA memset_zeros failed");
    let upstream = self.contiguous_gpu();
    let up_guard = upstream.storage.as_cuda_slice();
    let inner: usize = input_shape[dim + 1..].iter().product::<usize>().max(1);
    let outer: usize = input_shape[..dim].iter().product::<usize>().max(1);
    let window_offset = start * inner;
    let narrow_len = self.shape()[dim];
    let src_block = narrow_len * inner;
    let dst_block = input_shape[dim] * inner;
    if outer == 1 {
        // Single contiguous run: one device-to-device copy of the whole grad.
        backend
            .memcpy_dtod_f32(
                &mut dst,
                window_offset,
                up_guard.slice(),
                0,
                upstream.shape.iter().product::<usize>(),
            )
            .expect("CUDA memcpy_dtod failed");
    } else {
        // One contiguous [narrow_len, inner] run per outer index, offset by
        // `start * inner` within each destination block.
        for o in 0..outer {
            backend
                .memcpy_dtod_f32(
                    &mut dst,
                    o * dst_block + window_offset,
                    up_guard.slice(),
                    o * src_block,
                    src_block,
                )
                .expect("CUDA memcpy_dtod failed");
        }
    }
    let shape = Shape::from_slice(input_shape);
    Self {
        storage: Storage::from_cuda_slice(dst, total, self.device()),
        strides: contiguous_strides(&shape),
        shape,
        offset: 0,
    }
}
/// Expands an attention mask to `output_shape` on the GPU.
///
/// Accepts a causal-style mask of shape `[tgt_len, src_len]` or a
/// padding-style mask of shape `[batch_size, src_len]`; any other layout, a
/// missing backend, or a CUDA failure yields `None`.
pub fn mask_expand_cuda(
    &self,
    output_shape: &[usize],
    batch_size: usize,
    num_heads: usize,
    tgt_len: usize,
    src_len: usize,
) -> Option<Self> {
    let backend = get_cuda_backend()?;
    let mask = self.contiguous_gpu();
    let dims = &mask.shape;
    let total: usize = output_shape.iter().product();
    let mask_guard = mask.storage.as_cuda_slice();
    let mut dst = pool_alloc(total).ok()?;
    // The causal interpretation wins when both layouts would match.
    let launch = if dims.len() == 2 && dims[0] == tgt_len && dims[1] == src_len {
        backend.mask_expand_causal_f32(mask_guard.slice(), &mut dst, total, tgt_len, src_len)
    } else if dims.len() == 2 && dims[0] == batch_size && dims[1] == src_len {
        backend.mask_expand_padding_f32(
            mask_guard.slice(),
            &mut dst,
            total,
            num_heads,
            tgt_len,
            src_len,
        )
    } else {
        return None;
    };
    launch.ok()?;
    let shape = Shape::from_slice(output_shape);
    Some(Self {
        storage: Storage::from_cuda_slice(dst, total, self.device()),
        strides: contiguous_strides(&shape),
        shape,
        offset: 0,
    })
}
/// Fused LSTM cell pointwise step on the GPU.
///
/// `self` holds the pre-activation gates for the step and `c_prev` the
/// previous cell state; returns `(h_new, c_new)`, each `[batch, hidden]`.
/// `None` if the CUDA backend is unavailable or any GPU call fails.
pub fn lstm_gates_fused(&self, c_prev: &Self, hidden_size: usize) -> Option<(Self, Self)> {
    let batch = self.shape()[0];
    let total = batch * hidden_size;
    let backend = get_cuda_backend()?;
    let gates = self.contiguous_gpu();
    let cell = c_prev.contiguous_gpu();
    let gates_guard = gates.storage.as_cuda_slice();
    let cell_guard = cell.storage.as_cuda_slice();
    let mut h_buf = pool_alloc(total).ok()?;
    let mut c_buf = pool_alloc(total).ok()?;
    backend
        .lstm_gates_f32(
            gates_guard.slice(),
            cell_guard.slice(),
            &mut h_buf,
            &mut c_buf,
            hidden_size,
            total,
        )
        .ok()?;
    // Both outputs share the [batch, hidden] shape.
    let shape = Shape::from_slice(&[batch, hidden_size]);
    let wrap = |buf| Self {
        storage: Storage::from_cuda_slice(buf, total, self.device()),
        shape: shape.clone(),
        strides: contiguous_strides(&shape),
        offset: 0,
    };
    Some((wrap(h_buf), wrap(c_buf)))
}
/// Fused GRU cell pointwise step on the GPU.
///
/// `self` holds the input-to-hidden gate pre-activations, `gates_hh` the
/// hidden-to-hidden ones, and `h_prev` the previous hidden state; returns
/// the new hidden state `[batch, hidden]`, or `None` on any GPU failure.
pub fn gru_gates_fused(
    &self,
    gates_hh: &Self,
    h_prev: &Self,
    hidden_size: usize,
) -> Option<Self> {
    let batch = self.shape()[0];
    let total = batch * hidden_size;
    let backend = get_cuda_backend()?;
    let ih = self.contiguous_gpu();
    let hh = gates_hh.contiguous_gpu();
    let hidden = h_prev.contiguous_gpu();
    let ih_guard = ih.storage.as_cuda_slice();
    let hh_guard = hh.storage.as_cuda_slice();
    let hidden_guard = hidden.storage.as_cuda_slice();
    let mut dst = pool_alloc(total).ok()?;
    backend
        .gru_gates_f32(
            ih_guard.slice(),
            hh_guard.slice(),
            hidden_guard.slice(),
            &mut dst,
            hidden_size,
            total,
        )
        .ok()?;
    let shape = Shape::from_slice(&[batch, hidden_size]);
    Some(Self {
        storage: Storage::from_cuda_slice(dst, total, self.device()),
        strides: contiguous_strides(&shape),
        shape,
        offset: 0,
    })
}
/// Fused LSTM cell backward on the GPU.
///
/// `self` holds the forward gate pre-activations; `c_prev`/`c_new` are the
/// saved cell states and `grad_h`/`grad_c_next` the incoming gradients.
/// Returns `(grad_gates [batch, 4*hidden], grad_c_prev [batch, hidden])`,
/// or `None` on any GPU failure.
pub fn lstm_gates_backward_fused(
    &self,
    c_prev: &Self,
    c_new: &Self,
    grad_h: &Self,
    grad_c_next: &Self,
    hidden_size: usize,
) -> Option<(Self, Self)> {
    let batch = grad_h.shape()[0];
    let total = batch * hidden_size;
    let gates_len = batch * 4 * hidden_size;
    let backend = get_cuda_backend()?;
    let gates = self.contiguous_gpu();
    let c_prev_data = c_prev.contiguous_gpu();
    let c_new_data = c_new.contiguous_gpu();
    let gh = grad_h.contiguous_gpu();
    let gc = grad_c_next.contiguous_gpu();
    let gates_guard = gates.storage.as_cuda_slice();
    let c_prev_guard = c_prev_data.storage.as_cuda_slice();
    let c_new_guard = c_new_data.storage.as_cuda_slice();
    let gh_guard = gh.storage.as_cuda_slice();
    let gc_guard = gc.storage.as_cuda_slice();
    let mut grad_gates_buf = pool_alloc(gates_len).ok()?;
    let mut grad_c_prev_buf = pool_alloc(total).ok()?;
    backend
        .lstm_gates_backward_f32(
            gates_guard.slice(),
            c_prev_guard.slice(),
            c_new_guard.slice(),
            gh_guard.slice(),
            gc_guard.slice(),
            &mut grad_gates_buf,
            &mut grad_c_prev_buf,
            hidden_size,
            total,
        )
        .ok()?;
    let gates_shape = Shape::from_slice(&[batch, 4 * hidden_size]);
    let hidden_shape = Shape::from_slice(&[batch, hidden_size]);
    let grad_gates = Self {
        storage: Storage::from_cuda_slice(grad_gates_buf, gates_len, self.device()),
        strides: contiguous_strides(&gates_shape),
        shape: gates_shape,
        offset: 0,
    };
    let grad_c_prev = Self {
        storage: Storage::from_cuda_slice(grad_c_prev_buf, total, self.device()),
        strides: contiguous_strides(&hidden_shape),
        shape: hidden_shape,
        offset: 0,
    };
    Some((grad_gates, grad_c_prev))
}
/// Fused GRU cell backward on the GPU.
///
/// `self` holds the input-to-hidden gate pre-activations, `gates_hh` the
/// hidden-to-hidden ones, `h_prev` the previous hidden state, and
/// `grad_h_new` the incoming gradient. Returns
/// `(grad_ih [batch, 3*hidden], grad_hh [batch, 3*hidden],
/// grad_h_prev [batch, hidden])`, or `None` on any GPU failure.
pub fn gru_gates_backward_fused(
    &self,
    gates_hh: &Self,
    h_prev: &Self,
    grad_h_new: &Self,
    hidden_size: usize,
) -> Option<(Self, Self, Self)> {
    let batch = grad_h_new.shape()[0];
    let total = batch * hidden_size;
    let gate_len = batch * 3 * hidden_size;
    let backend = get_cuda_backend()?;
    let ih = self.contiguous_gpu();
    let hh = gates_hh.contiguous_gpu();
    let hidden = h_prev.contiguous_gpu();
    let upstream = grad_h_new.contiguous_gpu();
    let ih_guard = ih.storage.as_cuda_slice();
    let hh_guard = hh.storage.as_cuda_slice();
    let hidden_guard = hidden.storage.as_cuda_slice();
    let up_guard = upstream.storage.as_cuda_slice();
    let mut grad_ih_buf = pool_alloc(gate_len).ok()?;
    let mut grad_hh_buf = pool_alloc(gate_len).ok()?;
    let mut grad_h_prev_buf = pool_alloc(total).ok()?;
    backend
        .gru_gates_backward_f32(
            ih_guard.slice(),
            hh_guard.slice(),
            hidden_guard.slice(),
            up_guard.slice(),
            &mut grad_ih_buf,
            &mut grad_hh_buf,
            &mut grad_h_prev_buf,
            hidden_size,
            total,
        )
        .ok()?;
    let gate_shape = Shape::from_slice(&[batch, 3 * hidden_size]);
    let hidden_shape = Shape::from_slice(&[batch, hidden_size]);
    // The two gate gradients share a shape; build them with one closure.
    let wrap_gate = |buf| Self {
        storage: Storage::from_cuda_slice(buf, gate_len, self.device()),
        shape: gate_shape.clone(),
        strides: contiguous_strides(&gate_shape),
        offset: 0,
    };
    let grad_ih = wrap_gate(grad_ih_buf);
    let grad_hh = wrap_gate(grad_hh_buf);
    let grad_h_prev = Self {
        storage: Storage::from_cuda_slice(grad_h_prev_buf, total, self.device()),
        strides: contiguous_strides(&hidden_shape),
        shape: hidden_shape,
        offset: 0,
    };
    Some((grad_ih, grad_hh, grad_h_prev))
}
/// Fused BatchNorm forward on the GPU.
///
/// Per-channel sums and sums-of-squares are reduced on the device, the
/// mean/variance arithmetic is finished on the host, and normalization with
/// `gamma`/`beta` runs on the device. Returns the normalized tensor plus the
/// host-side per-channel `(mean, var)` vectors; `None` on any GPU failure.
pub fn batchnorm_fused(
    &self,
    gamma: &Self,
    beta: &Self,
    eps: f32,
    channels: usize,
    spatial: usize,
) -> Option<(Self, Vec<f32>, Vec<f32>)> {
    let backend = get_cuda_backend()?;
    let total = self.numel();
    let batch = total / (channels * spatial);
    let input = self.contiguous_gpu();
    let gamma_data = gamma.contiguous_gpu();
    let beta_data = beta.contiguous_gpu();
    let input_guard = input.storage.as_cuda_slice();
    let gamma_guard = gamma_data.storage.as_cuda_slice();
    let beta_guard = beta_data.storage.as_cuda_slice();
    // Zero-initialized per-channel accumulators for sum and sum of squares.
    let channel_zeros = vec![0.0f32; channels];
    let mut sum_gpu = backend.htod_copy(&channel_zeros).ok()?;
    let mut sum_sq_gpu = backend.htod_copy(&channel_zeros).ok()?;
    backend
        .batchnorm_stats_f32(
            input_guard.slice(),
            &mut sum_gpu,
            &mut sum_sq_gpu,
            batch,
            channels,
            spatial,
        )
        .ok()?;
    // Finish on the host: mean = sum/count, var = sum_sq/count - mean^2.
    let sums = backend.dtoh_copy::<f32>(&sum_gpu).ok()?;
    let sum_sqs = backend.dtoh_copy::<f32>(&sum_sq_gpu).ok()?;
    let count = (batch * spatial) as f32;
    let mean_cpu: Vec<f32> = sums.iter().map(|&s| s / count).collect();
    let var_cpu: Vec<f32> = sum_sqs
        .iter()
        .zip(&mean_cpu)
        .map(|(&sq, &m)| sq / count - m * m)
        .collect();
    let mean_gpu = backend.htod_copy(&mean_cpu).ok()?;
    let var_gpu = backend.htod_copy(&var_cpu).ok()?;
    let mut dst = pool_alloc(total).ok()?;
    backend
        .batchnorm_norm_f32(
            input_guard.slice(),
            &mean_gpu,
            &var_gpu,
            gamma_guard.slice(),
            beta_guard.slice(),
            &mut dst,
            eps,
            channels,
            spatial,
            total,
        )
        .ok()?;
    let normalized = Self {
        storage: Storage::from_cuda_slice(dst, total, self.device()),
        shape: self.shape.clone(),
        strides: contiguous_strides(&self.shape),
        offset: 0,
    };
    Some((normalized, mean_cpu, var_cpu))
}
/// Conv2d forward on the GPU via im2col + GEMM (NCHW, groups == 1).
///
/// `self` is `[batch, in_channels, in_h, in_w]`; `weight` is
/// `[out_channels, in_channels, kernel_h, kernel_w]`. Returns `None` when
/// any operand is off-GPU, the backend is missing, or any CUDA call fails —
/// callers treat that as "use a fallback path". Batch elements are processed
/// one at a time through shared scratch buffers.
pub fn conv2d_cuda(
    &self,
    weight: &Self,
    bias: Option<&Self>,
    stride: (usize, usize),
    padding: (usize, usize),
) -> Option<Self> {
    // All operands must already live on the GPU.
    if !self.device().is_gpu() || !weight.device().is_gpu() {
        return None;
    }
    if let Some(b) = bias {
        if !b.device().is_gpu() {
            return None;
        }
    }
    let cuda = get_cuda_backend()?;
    let batch_size = self.shape[0];
    let in_channels = self.shape[1];
    let in_height = self.shape[2];
    let in_width = self.shape[3];
    let out_channels = weight.shape[0];
    let kernel_h = weight.shape[2];
    let kernel_w = weight.shape[3];
    let (stride_h, stride_w) = stride;
    let (pad_h, pad_w) = padding;
    // Standard convolution output geometry.
    let out_h = (in_height + 2 * pad_h - kernel_h) / stride_h + 1;
    let out_w = (in_width + 2 * pad_w - kernel_w) / stride_w + 1;
    // im2col matrix is [col_h, col_w] = [C*kh*kw, out_h*out_w] per element.
    let col_h = in_channels * kernel_h * kernel_w;
    let col_w = out_h * out_w;
    let col_n = col_h * col_w;
    let spatial = out_h * out_w;
    let out_per_batch = out_channels * spatial;
    let in_per_batch = in_channels * in_height * in_width;
    let input_data = self.contiguous_gpu();
    let weight_data = weight.contiguous_gpu();
    let input_guard = input_data.storage.as_cuda_slice();
    let weight_guard = weight_data.storage.as_cuda_slice();
    // Geometry parameters consumed by the im2col kernel.
    let im2col_params: [u32; 10] = [
        in_height as u32,
        in_width as u32,
        kernel_h as u32,
        kernel_w as u32,
        pad_h as u32,
        pad_w as u32,
        stride_h as u32,
        stride_w as u32,
        out_h as u32,
        out_w as u32,
    ];
    let params_gpu = cuda.htod_copy(&im2col_params[..]).ok()?;
    let bias_data = bias.map(|b| b.contiguous_gpu());
    let bias_guard = bias_data.as_ref().map(|b| b.storage.as_cuda_slice());
    // Scratch buffers reused across batch iterations.
    let mut col_gpu = pool_alloc(col_n).ok()?;
    let mut input_batch_gpu = pool_alloc(in_per_batch).ok()?;
    let mut batch_out_gpu = pool_alloc(out_per_batch).ok()?;
    let total_out = batch_size * out_per_batch;
    let mut out_gpu = pool_alloc(total_out).ok()?;
    for b in 0..batch_size {
        // Stage this batch element, unfold it, then GEMM against the weights.
        cuda.memcpy_dtod_f32(
            &mut input_batch_gpu,
            0,
            input_guard.slice(),
            b * in_per_batch,
            in_per_batch,
        )
        .ok()?;
        cuda.im2col_f32(&input_batch_gpu, &mut col_gpu, &params_gpu, col_n)
            .ok()?;
        cuda.gemm_f32(
            false,
            false,
            col_w,
            out_channels,
            col_h,
            1.0,
            &col_gpu,
            col_w,
            weight_guard.slice(),
            col_h,
            0.0,
            &mut batch_out_gpu,
            col_w,
        )
        .ok()?;
        // Optional per-channel bias added across the spatial positions.
        if let Some(ref bg) = bias_guard {
            cuda.bias_add_channels_f32(&mut batch_out_gpu, bg.slice(), spatial, out_per_batch)
                .ok()?;
        }
        // Copy this element's result into its slot of the full output.
        cuda.memcpy_dtod_f32(
            &mut out_gpu,
            b * out_per_batch,
            &batch_out_gpu,
            0,
            out_per_batch,
        )
        .ok()?;
    }
    let out_shape = Shape::from_slice(&[batch_size, out_channels, out_h, out_w]);
    Some(Self {
        storage: Storage::from_cuda_slice(out_gpu, total_out, self.device()),
        shape: out_shape.clone(),
        strides: contiguous_strides(&out_shape),
        offset: 0,
    })
}
/// Conv2d forward through cuDNN (supports `groups`).
///
/// Thin wrapper: validates device placement, derives the output geometry,
/// and delegates the convolution itself to
/// `cudnn_ops::cudnn_conv2d_forward`, wrapping the returned device buffer in
/// a `Tensor`. Returns `None` when any operand is off-GPU or the CUDA/cuDNN
/// backend is unavailable.
#[cfg(feature = "cudnn")]
pub fn conv2d_cudnn(
    &self,
    weight: &Self,
    bias: Option<&Self>,
    stride: (usize, usize),
    padding: (usize, usize),
    groups: usize,
) -> Option<Self> {
    // All operands must already live on the GPU.
    if !self.device().is_gpu() || !weight.device().is_gpu() {
        return None;
    }
    if let Some(b) = bias {
        if !b.device().is_gpu() {
            return None;
        }
    }
    let cuda = get_cuda_backend()?;
    let cudnn_handle = cuda.cudnn()?;
    // self: [batch, in_channels, in_h, in_w];
    // weight: [out_channels, ..., kernel_h, kernel_w].
    let batch_size = self.shape[0];
    let in_channels = self.shape[1];
    let in_height = self.shape[2];
    let in_width = self.shape[3];
    let out_channels = weight.shape[0];
    let kernel_h = weight.shape[2];
    let kernel_w = weight.shape[3];
    let (stride_h, stride_w) = stride;
    let (pad_h, pad_w) = padding;
    // Standard convolution output geometry.
    let out_h = (in_height + 2 * pad_h - kernel_h) / stride_h + 1;
    let out_w = (in_width + 2 * pad_w - kernel_w) / stride_w + 1;
    let input_contig = self.contiguous_gpu();
    let weight_contig = weight.contiguous_gpu();
    let input_guard = input_contig.storage.as_cuda_slice();
    let weight_guard = weight_contig.storage.as_cuda_slice();
    let bias_contig = bias.map(|b| b.contiguous_gpu());
    let bias_guard = bias_contig.as_ref().map(|b| b.storage.as_cuda_slice());
    let output_slice = axonml_core::backends::cudnn_ops::cudnn_conv2d_forward(
        cudnn_handle,
        cuda.stream(),
        cuda,
        input_guard.slice(),
        weight_guard.slice(),
        bias_guard.as_ref().map(|g| g.slice()),
        batch_size,
        in_channels,
        in_height,
        in_width,
        out_channels,
        kernel_h,
        kernel_w,
        stride,
        padding,
        groups,
    )?;
    let total_out = batch_size * out_channels * out_h * out_w;
    let out_shape = Shape::from_slice(&[batch_size, out_channels, out_h, out_w]);
    Some(Self {
        storage: Storage::from_cuda_slice(output_slice, total_out, self.device()),
        shape: out_shape.clone(),
        strides: contiguous_strides(&out_shape),
        offset: 0,
    })
}
/// Grouped Conv2d forward on the GPU via per-group im2col + GEMM (NCHW).
///
/// Like `conv2d_cuda`, but the channel dimensions are split into `groups`
/// slices: each group convolves `in_channels / groups` input channels with
/// `out_channels / groups` filters. Returns `None` when any operand is
/// off-GPU, the backend is missing, or any CUDA call fails.
pub fn conv2d_grouped_cuda(
    &self,
    weight: &Self,
    bias: Option<&Self>,
    stride: (usize, usize),
    padding: (usize, usize),
    groups: usize,
) -> Option<Self> {
    // All operands must already live on the GPU.
    if !self.device().is_gpu() || !weight.device().is_gpu() {
        return None;
    }
    if let Some(b) = bias {
        if !b.device().is_gpu() {
            return None;
        }
    }
    let cuda = get_cuda_backend()?;
    let batch_size = self.shape[0];
    let in_channels = self.shape[1];
    let in_height = self.shape[2];
    let in_width = self.shape[3];
    let out_channels = weight.shape[0];
    let kernel_h = weight.shape[2];
    let kernel_w = weight.shape[3];
    let (stride_h, stride_w) = stride;
    let (pad_h, pad_w) = padding;
    let in_channels_per_group = in_channels / groups;
    let out_channels_per_group = out_channels / groups;
    // Standard convolution output geometry.
    let out_h = (in_height + 2 * pad_h - kernel_h) / stride_h + 1;
    let out_w = (in_width + 2 * pad_w - kernel_w) / stride_w + 1;
    // Per-group im2col matrix: [Cg*kh*kw, out_h*out_w].
    let col_h = in_channels_per_group * kernel_h * kernel_w;
    let col_w = out_h * out_w;
    let col_n = col_h * col_w;
    let spatial = out_h * out_w;
    let in_spatial = in_height * in_width;
    let out_per_batch = out_channels * spatial;
    let input_data = self.contiguous_gpu();
    let weight_data = weight.contiguous_gpu();
    let input_guard = input_data.storage.as_cuda_slice();
    let weight_guard = weight_data.storage.as_cuda_slice();
    // Geometry parameters consumed by the im2col kernel.
    let params_arr: [u32; 10] = [
        in_height as u32,
        in_width as u32,
        kernel_h as u32,
        kernel_w as u32,
        pad_h as u32,
        pad_w as u32,
        stride_h as u32,
        stride_w as u32,
        out_h as u32,
        out_w as u32,
    ];
    let params_gpu = cuda.htod_copy(&params_arr[..]).ok()?;
    let bias_data = bias.map(|b| b.contiguous_gpu());
    let bias_guard = bias_data.as_ref().map(|b| b.storage.as_cuda_slice());
    // Scratch buffers reused across (batch, group) iterations.
    let mut col_gpu = pool_alloc(col_n).ok()?;
    let mut input_group_gpu = pool_alloc(in_channels_per_group * in_spatial).ok()?;
    let mut group_out_gpu = pool_alloc(out_channels_per_group * spatial).ok()?;
    let total_out = batch_size * out_per_batch;
    let mut out_gpu = pool_alloc(total_out).ok()?;
    for b in 0..batch_size {
        for g in 0..groups {
            // Channel offsets of this group within input/output tensors.
            let ic_start = g * in_channels_per_group;
            let oc_start = g * out_channels_per_group;
            let in_group_size = in_channels_per_group * in_spatial;
            let in_offset = b * in_channels * in_spatial + ic_start * in_spatial;
            // Stage this group's input channels and unfold them.
            cuda.memcpy_dtod_f32(
                &mut input_group_gpu,
                0,
                input_guard.slice(),
                in_offset,
                in_group_size,
            )
            .ok()?;
            cuda.im2col_f32(&input_group_gpu, &mut col_gpu, &params_gpu, col_n)
                .ok()?;
            // Stage this group's filter slab, then GEMM.
            let w_offset = oc_start * in_channels_per_group * kernel_h * kernel_w;
            let w_size = out_channels_per_group * col_h;
            let mut weight_group_gpu = pool_alloc(w_size).ok()?;
            cuda.memcpy_dtod_f32(
                &mut weight_group_gpu,
                0,
                weight_guard.slice(),
                w_offset,
                w_size,
            )
            .ok()?;
            cuda.gemm_f32(
                false,
                false,
                col_w,
                out_channels_per_group,
                col_h,
                1.0,
                &col_gpu,
                col_w,
                &weight_group_gpu,
                col_h,
                0.0,
                &mut group_out_gpu,
                col_w,
            )
            .ok()?;
            // Optional bias: only this group's slice of the bias vector.
            if let Some(ref bg) = bias_guard {
                let mut bias_group = pool_alloc(out_channels_per_group).ok()?;
                cuda.memcpy_dtod_f32(
                    &mut bias_group,
                    0,
                    bg.slice(),
                    oc_start,
                    out_channels_per_group,
                )
                .ok()?;
                cuda.bias_add_channels_f32(
                    &mut group_out_gpu,
                    &bias_group,
                    spatial,
                    out_channels_per_group * spatial,
                )
                .ok()?;
            }
            // Scatter the group result into its channel range of the output.
            let out_offset = b * out_per_batch + oc_start * spatial;
            cuda.memcpy_dtod_f32(
                &mut out_gpu,
                out_offset,
                &group_out_gpu,
                0,
                out_channels_per_group * spatial,
            )
            .ok()?;
        }
    }
    let out_shape = Shape::from_slice(&[batch_size, out_channels, out_h, out_w]);
    Some(Self {
        storage: Storage::from_cuda_slice(out_gpu, total_out, self.device()),
        shape: out_shape.clone(),
        strides: contiguous_strides(&out_shape),
        offset: 0,
    })
}
/// Conv2d backward on the GPU via col2im/im2col + GEMM (NCHW, groups == 1).
///
/// `self` is the output gradient `[batch, out_channels, out_h, out_w]`;
/// `saved_input`/`saved_weight` are the tensors captured in the forward
/// pass. Returns `(grad_input, grad_weight, grad_bias)` where `grad_bias`
/// is `Some` only when `has_bias` is set; `None` on any CUDA failure or if
/// any tensor is not on the GPU.
///
/// Changes vs. the previous revision: device buffers are zeroed with
/// `memset_zeros_f32` instead of uploading host zero vectors (consistent
/// with `narrow_backward_cuda`), the per-batch scratch buffer for the input
/// gradient is allocated once outside the loop, a redundant zero-fill of
/// `grad_input_gpu` (immediately overwritten by the col2im result) is
/// removed, and the bias gradient is accumulated on the host and uploaded
/// once after the loop instead of round-tripping per batch element.
pub fn conv2d_backward_cuda(
    &self,
    saved_input: &Self,
    saved_weight: &Self,
    input_shape: &[usize],
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    has_bias: bool,
) -> Option<(Self, Self, Option<Self>)> {
    // All operands must already live on the GPU.
    if !self.device().is_gpu()
        || !saved_input.device().is_gpu()
        || !saved_weight.device().is_gpu()
    {
        return None;
    }
    let cuda = get_cuda_backend()?;
    let batch_size = input_shape[0];
    let in_h = input_shape[2];
    let in_w = input_shape[3];
    let (kh, kw) = kernel_size;
    let (sh, sw) = stride;
    let (ph, pw) = padding;
    let out_h = self.shape[2];
    let out_w = self.shape[3];
    // im2col matrix: [in_channels*kh*kw, out_h*out_w] per batch element.
    let col_h = in_channels * kh * kw;
    let col_w = out_h * out_w;
    let col_n = col_h * col_w;
    let spatial = out_h * out_w;
    let in_per_batch = in_channels * in_h * in_w;
    let out_per_batch = out_channels * spatial;
    let grad_out_data = self.contiguous_gpu();
    let input_data = saved_input.contiguous_gpu();
    let weight_data = saved_weight.contiguous_gpu();
    let grad_out_guard = grad_out_data.storage.as_cuda_slice();
    let input_guard = input_data.storage.as_cuda_slice();
    let weight_guard = weight_data.storage.as_cuda_slice();
    // Geometry parameters shared by the im2col/col2im kernels.
    let params_arr: [u32; 10] = [
        in_h as u32,
        in_w as u32,
        kh as u32,
        kw as u32,
        ph as u32,
        pw as u32,
        sh as u32,
        sw as u32,
        out_h as u32,
        out_w as u32,
    ];
    let params_gpu = cuda.htod_copy(&params_arr[..]).ok()?;
    let mut col_gpu = pool_alloc(col_n).ok()?;
    let mut grad_out_batch = pool_alloc(out_per_batch).ok()?;
    let mut input_batch = pool_alloc(in_per_batch).ok()?;
    let weight_n = out_channels * col_h;
    // grad_weight accumulates across the batch (GEMM beta = 1.0), so it
    // must start zeroed.
    let mut grad_weight_gpu = pool_alloc(weight_n).ok()?;
    cuda.memset_zeros_f32(&mut grad_weight_gpu).ok()?;
    let total_input = batch_size * in_per_batch;
    let mut grad_input_gpu = pool_alloc(total_input).ok()?;
    // Per-element scratch for the input gradient; re-zeroed each iteration
    // before col2im writes into it.
    let mut gi_batch = pool_alloc(in_per_batch).ok()?;
    // Bias gradient is accumulated on the host and uploaded once at the end.
    let mut bias_acc = if has_bias {
        Some(vec![0.0f32; out_channels])
    } else {
        None
    };
    for b in 0..batch_size {
        // Stage this element's output gradient and saved input.
        cuda.memcpy_dtod_f32(
            &mut grad_out_batch,
            0,
            grad_out_guard.slice(),
            b * out_per_batch,
            out_per_batch,
        )
        .ok()?;
        cuda.memcpy_dtod_f32(
            &mut input_batch,
            0,
            input_guard.slice(),
            b * in_per_batch,
            in_per_batch,
        )
        .ok()?;
        // Input gradient: GEMM against the weights into column form, then
        // fold the columns back into image layout with col2im.
        cuda.gemm_f32(
            false,
            false,
            spatial,
            col_h,
            out_channels,
            1.0,
            &grad_out_batch,
            spatial,
            weight_guard.slice(),
            out_channels,
            0.0,
            &mut col_gpu,
            spatial,
        )
        .ok()?;
        cuda.memset_zeros_f32(&mut gi_batch).ok()?;
        cuda.col2im_f32(&col_gpu, &mut gi_batch, &params_gpu, col_n)
            .ok()?;
        cuda.memcpy_dtod_f32(
            &mut grad_input_gpu,
            b * in_per_batch,
            &gi_batch,
            0,
            in_per_batch,
        )
        .ok()?;
        // Weight gradient: transposed GEMM of the unfolded input against the
        // output gradient, accumulated across batch elements (beta = 1.0).
        cuda.im2col_f32(&input_batch, &mut col_gpu, &params_gpu, col_n)
            .ok()?;
        cuda.gemm_f32(
            true,
            false,
            col_h,
            out_channels,
            spatial,
            1.0,
            &col_gpu,
            spatial,
            &grad_out_batch,
            spatial,
            1.0,
            &mut grad_weight_gpu,
            col_h,
        )
        .ok()?;
        // Bias gradient: per-channel sum of this element's output gradient.
        if let Some(acc) = bias_acc.as_mut() {
            let go_cpu = cuda.dtoh_copy(&grad_out_batch).ok()?;
            for oc in 0..out_channels {
                let mut sum = 0.0f32;
                for s in 0..spatial {
                    sum += go_cpu[oc * spatial + s];
                }
                acc[oc] += sum;
            }
        }
    }
    // Upload the finished bias gradient (all zeros when batch_size == 0).
    let grad_bias_gpu = match bias_acc {
        Some(acc) => {
            let mut gb = pool_alloc(out_channels).ok()?;
            let staged = cuda.htod_copy(&acc).ok()?;
            cuda.memcpy_dtod_f32(&mut gb, 0, &staged, 0, out_channels)
                .ok()?;
            Some(gb)
        }
        None => None,
    };
    let gi_shape = Shape::from_slice(input_shape);
    let grad_input_t = Self {
        storage: Storage::from_cuda_slice(grad_input_gpu, total_input, self.device()),
        shape: gi_shape.clone(),
        strides: contiguous_strides(&gi_shape),
        offset: 0,
    };
    let gw_shape = Shape::from_slice(&[out_channels, in_channels, kh, kw]);
    let grad_weight_t = Self {
        storage: Storage::from_cuda_slice(grad_weight_gpu, weight_n, self.device()),
        shape: gw_shape.clone(),
        strides: contiguous_strides(&gw_shape),
        offset: 0,
    };
    let grad_bias_t = grad_bias_gpu.map(|gb| {
        let gb_shape = Shape::from_slice(&[out_channels]);
        Self {
            storage: Storage::from_cuda_slice(gb, out_channels, self.device()),
            shape: gb_shape.clone(),
            strides: contiguous_strides(&gb_shape),
            offset: 0,
        }
    });
    Some((grad_input_t, grad_weight_t, grad_bias_t))
}
/// MaxPool2d forward on the GPU (NCHW).
///
/// Returns the pooled tensor together with the per-output i32 indices
/// produced by the kernel, copied back to the host (presumably argmax
/// positions consumed by the backward pass). `None` when `self` is not on
/// the GPU or any CUDA call fails.
pub fn maxpool2d_cuda(
    &self,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
) -> Option<(Self, Vec<i32>)> {
    if !self.device().is_gpu() {
        return None;
    }
    let backend = get_cuda_backend()?;
    let (batch, channels) = (self.shape[0], self.shape[1]);
    let (in_h, in_w) = (self.shape[2], self.shape[3]);
    let (kh, kw) = kernel_size;
    let (sh, sw) = stride;
    let (ph, pw) = padding;
    // Standard pooling output geometry.
    let out_h = (in_h + 2 * ph - kh) / sh + 1;
    let out_w = (in_w + 2 * pw - kw) / sw + 1;
    let total = batch * channels * out_h * out_w;
    let input = self.contiguous_gpu();
    let input_guard = input.storage.as_cuda_slice();
    // Pooling geometry packed for the kernel as a u32 array.
    let params: [u32; 8] = [
        in_h as u32,
        in_w as u32,
        kh as u32,
        kw as u32,
        sh as u32,
        sw as u32,
        ph as u32,
        pw as u32,
    ];
    let params_gpu = backend.htod_copy(&params[..]).ok()?;
    let mut out_buf = pool_alloc(total).ok()?;
    let mut idx_buf = backend.alloc::<i32>(total).ok()?;
    backend
        .maxpool2d_fwd_f32(
            input_guard.slice(),
            &mut out_buf,
            &mut idx_buf,
            &params_gpu,
            channels,
            out_h,
            out_w,
            total,
        )
        .ok()?;
    let indices = backend.dtoh_copy(&idx_buf).ok()?;
    let shape = Shape::from_slice(&[batch, channels, out_h, out_w]);
    let output = Self {
        storage: Storage::from_cuda_slice(out_buf, total, self.device()),
        strides: contiguous_strides(&shape),
        shape,
        offset: 0,
    };
    Some((output, indices))
}
/// AvgPool2d forward on the GPU (NCHW).
///
/// `count_include_pad` is forwarded to the kernel as a flag (presumably
/// controls whether padded positions contribute to each window's divisor —
/// confirm against the kernel). `None` when `self` is not on the GPU or any
/// CUDA call fails.
pub fn avgpool2d_cuda(
    &self,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    count_include_pad: bool,
) -> Option<Self> {
    if !self.device().is_gpu() {
        return None;
    }
    let backend = get_cuda_backend()?;
    let (batch, channels) = (self.shape[0], self.shape[1]);
    let (in_h, in_w) = (self.shape[2], self.shape[3]);
    let (kh, kw) = kernel_size;
    let (sh, sw) = stride;
    let (ph, pw) = padding;
    // Standard pooling output geometry.
    let out_h = (in_h + 2 * ph - kh) / sh + 1;
    let out_w = (in_w + 2 * pw - kw) / sw + 1;
    let total = batch * channels * out_h * out_w;
    let input = self.contiguous_gpu();
    let input_guard = input.storage.as_cuda_slice();
    // Pooling geometry plus the include-pad flag, packed for the kernel.
    let params: [u32; 9] = [
        in_h as u32,
        in_w as u32,
        kh as u32,
        kw as u32,
        sh as u32,
        sw as u32,
        ph as u32,
        pw as u32,
        count_include_pad as u32,
    ];
    let params_gpu = backend.htod_copy(&params[..]).ok()?;
    let mut dst = pool_alloc(total).ok()?;
    backend
        .avgpool2d_fwd_f32(
            input_guard.slice(),
            &mut dst,
            &params_gpu,
            channels,
            out_h,
            out_w,
            total,
        )
        .ok()?;
    let shape = Shape::from_slice(&[batch, channels, out_h, out_w]);
    Some(Self {
        storage: Storage::from_cuda_slice(dst, total, self.device()),
        strides: contiguous_strides(&shape),
        shape,
        offset: 0,
    })
}
/// Fused attention forward on the GPU.
///
/// `self` is Q with shape `[B, H, Tq, D]`; `k`/`v` supply the source length
/// from their third axis. `scale` and `is_causal` are forwarded to the
/// kernel. `None` when the backend is unavailable or a CUDA call fails.
///
/// # Panics
/// Panics if Q is not 4-dimensional.
pub fn fused_attention_cuda(
    &self,
    k: &Self,
    v: &Self,
    scale: f32,
    is_causal: bool,
) -> Option<Self> {
    let backend = get_cuda_backend()?;
    let dims = self.shape();
    assert!(dims.len() == 4, "Q must be [B, H, Tq, D]");
    let (batch, heads, tgt_len, head_dim) = (dims[0], dims[1], dims[2], dims[3]);
    let src_len = k.shape()[2];
    let total = batch * heads * tgt_len * head_dim;
    let q = self.contiguous_gpu();
    let key = k.contiguous_gpu();
    let value = v.contiguous_gpu();
    let q_guard = q.storage.as_cuda_slice();
    let k_guard = key.storage.as_cuda_slice();
    let v_guard = value.storage.as_cuda_slice();
    let mut dst = pool_alloc(total).ok()?;
    backend
        .fused_attention_fwd_f32(
            q_guard.slice(),
            k_guard.slice(),
            v_guard.slice(),
            &mut dst,
            scale,
            batch,
            heads,
            tgt_len,
            src_len,
            head_dim,
            is_causal,
        )
        .ok()?;
    let shape = Shape::from_slice(&[batch, heads, tgt_len, head_dim]);
    Some(Self {
        storage: Storage::from_cuda_slice(dst, total, self.device()),
        strides: contiguous_strides(&shape),
        shape,
        offset: 0,
    })
}
/// Fused attention backward pass on the GPU.
///
/// `self` is Q `[B, H, Tq, D]` (asserted); `k`/`v` are assumed
/// `[B, H, Ts, D]` (only `k.shape()[2]` is read — TODO confirm).
/// `output` and `grad_output` are the forward result and its upstream
/// gradient, both `[B, H, Tq, D]`. Returns `(grad_q, grad_k, grad_v)`,
/// or `None` if the tensor is not on a GPU, the backend is unavailable,
/// or any CUDA step fails.
pub fn fused_attention_bwd_cuda(
    &self,
    k: &Self,
    v: &Self,
    output: &Self,
    grad_output: &Self,
    scale: f32,
    is_causal: bool,
) -> Option<(Self, Self, Self)> {
    // Consistent with `avgpool2d_cuda`: refuse non-GPU tensors up front so
    // callers can fall back to the CPU path.
    if !self.device().is_gpu() {
        return None;
    }
    let cuda = get_cuda_backend()?;
    let q_shape = self.shape();
    assert!(q_shape.len() == 4, "Q must be [B, H, Tq, D]");
    let batch_size = q_shape[0];
    let num_heads = q_shape[1];
    let tgt_len = q_shape[2];
    let head_dim = q_shape[3];
    let src_len = k.shape()[2];
    let total_q = batch_size * num_heads * tgt_len * head_dim;
    let total_kv = batch_size * num_heads * src_len * head_dim;
    // Kernel requires contiguous inputs; guards keep slices alive across
    // the launch.
    let q_contig = self.contiguous_gpu();
    let k_contig = k.contiguous_gpu();
    let v_contig = v.contiguous_gpu();
    let o_contig = output.contiguous_gpu();
    let go_contig = grad_output.contiguous_gpu();
    let q_guard = q_contig.storage.as_cuda_slice();
    let k_guard = k_contig.storage.as_cuda_slice();
    let v_guard = v_contig.storage.as_cuda_slice();
    let o_guard = o_contig.storage.as_cuda_slice();
    let go_guard = go_contig.storage.as_cuda_slice();
    // Gradient buffers are zero-initialized via a host copy rather than
    // `pool_alloc` — presumably the backward kernel accumulates into them,
    // so uninitialized pool memory would be incorrect. TODO confirm; if
    // the pool can zero-fill, that would avoid the host allocation + copy.
    let mut gq_gpu = cuda.htod_copy(&vec![0.0f32; total_q]).ok()?;
    let mut gk_gpu = cuda.htod_copy(&vec![0.0f32; total_kv]).ok()?;
    let mut gv_gpu = cuda.htod_copy(&vec![0.0f32; total_kv]).ok()?;
    cuda.fused_attention_bwd_f32(
        q_guard.slice(),
        k_guard.slice(),
        v_guard.slice(),
        o_guard.slice(),
        go_guard.slice(),
        &mut gq_gpu,
        &mut gk_gpu,
        &mut gv_gpu,
        scale,
        batch_size,
        num_heads,
        tgt_len,
        src_len,
        head_dim,
        is_causal,
    )
    .ok()?;
    let q_out_shape = Shape::from_slice(&[batch_size, num_heads, tgt_len, head_dim]);
    let kv_out_shape = Shape::from_slice(&[batch_size, num_heads, src_len, head_dim]);
    let grad_q = Self {
        storage: Storage::from_cuda_slice(gq_gpu, total_q, self.device()),
        shape: q_out_shape.clone(),
        strides: contiguous_strides(&q_out_shape),
        offset: 0,
    };
    let grad_k = Self {
        storage: Storage::from_cuda_slice(gk_gpu, total_kv, self.device()),
        shape: kv_out_shape.clone(),
        strides: contiguous_strides(&kv_out_shape),
        offset: 0,
    };
    let grad_v = Self {
        storage: Storage::from_cuda_slice(gv_gpu, total_kv, self.device()),
        shape: kv_out_shape.clone(),
        strides: contiguous_strides(&kv_out_shape),
        offset: 0,
    };
    Some((grad_q, grad_k, grad_v))
}
}