use crate::cuda_sys::cudaStream_t;
use crate::tensor::CudaTensor;
use crate::DType;
use tl_backend::{BackendError, BackendResult};
// FFI declarations for the CUDA kernel launchers (implemented in the .cu side).
// All launchers operate on f32 device buffers and enqueue work on `stream`;
// completion is not guaranteed until the caller synchronizes the stream.
extern "C" {
// Transposes a rows x cols matrix: output[c][r] = input[r][c].
fn launch_transpose_2d_kernel(
input: *const f32,
output: *mut f32,
rows: i32,
cols: i32,
stream: cudaStream_t,
);
// Copies a contiguous sub-range [start, start+new_dim) of one dimension.
// The tensor is viewed as (outer, old_dim, inner); output is (outer, new_dim, inner).
fn launch_narrow_kernel(
input: *const f32,
output: *mut f32,
outer: i32,
inner: i32,
old_dim: i32,
new_dim: i32,
start: i32,
stream: cudaStream_t,
);
// Concatenates two tensors along one dimension. Both are viewed as
// (outer, {a_dim|b_dim}, inner); output is (outer, a_dim + b_dim, inner).
fn launch_cat_kernel(
a: *const f32,
b: *const f32,
output: *mut f32,
outer: i32,
inner: i32,
a_dim: i32,
b_dim: i32,
stream: cudaStream_t,
);
// Broadcasts `input` (shape `src_shape`, already left-padded with 1s to `ndim`
// entries) to `target_shape`. Shape arrays are i32 device pointers; `total` is
// the element count of the output.
fn launch_broadcast_to_kernel(
input: *const f32,
output: *mut f32,
target_shape: *const i32,
src_shape: *const i32,
ndim: i32,
total: i32,
stream: cudaStream_t,
);
// General N-d transpose swapping dims (dim0, dim1). Shape/stride arrays are
// i32 device pointers of length `ndim`; `total` is the element count.
// NOTE(review): exact stride convention is defined by the kernel source — the
// Rust side passes row-major strides of the old and new shapes; confirm the
// kernel agrees.
fn launch_transpose_nd_kernel(
input: *const f32,
output: *mut f32,
old_shape: *const i32,
old_strides: *const i32,
new_strides: *const i32,
ndim: i32,
total: i32,
dim0: i32,
dim1: i32,
stream: cudaStream_t,
);
}
impl CudaTensor {
/// Returns a view of this tensor with `new_shape`.
///
/// No data is copied: the element count must match exactly, otherwise a
/// `ShapeMismatch` error is returned.
pub fn reshape_impl(&self, new_shape: &[usize]) -> BackendResult<CudaTensor> {
    let current: usize = self.elem_count();
    let requested: usize = new_shape.iter().product();
    if current == requested {
        Ok(self.view_with_shape(new_shape))
    } else {
        Err(BackendError::ShapeMismatch(format!(
            "Cannot reshape from {:?} ({}) to {:?} ({})",
            self.shape(),
            current,
            new_shape,
            requested
        )))
    }
}
/// Removes dimension `dim` if its extent is 1, returning a view.
///
/// If `dim` is out of range or the extent is not 1 this is a no-op copy
/// (mirrors lenient squeeze semantics). Squeezing the last dimension of a
/// 1-element tensor yields shape `[1]` rather than a 0-d shape.
pub fn squeeze_impl(&self, dim: usize) -> BackendResult<CudaTensor> {
    let shape = self.shape().to_vec();
    // Only squeeze a real size-1 dimension; anything else is a plain copy.
    let squeezable = shape.get(dim).map_or(false, |&extent| extent == 1);
    if !squeezable {
        return self.clone_data();
    }
    let mut reduced: Vec<usize> = shape
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != dim)
        .map(|(_, &extent)| extent)
        .collect();
    if reduced.is_empty() {
        // Keep at least one dimension so the result is never rank-0.
        reduced.push(1);
    }
    Ok(self.view_with_shape(&reduced))
}
/// Inserts a new dimension of extent 1 at position `dim`, returning a view.
///
/// `dim` may be anywhere in `0..=ndim`; values beyond that produce an
/// `ArgumentError`.
pub fn unsqueeze_impl(&self, dim: usize) -> BackendResult<CudaTensor> {
    let old = self.shape().to_vec();
    let rank = old.len();
    if dim > rank {
        return Err(BackendError::ArgumentError(format!(
            "unsqueeze dim {} out of range for ndim {}",
            dim, rank
        )));
    }
    // Build the expanded shape by splicing a 1 between the two halves.
    let mut expanded = Vec::with_capacity(rank + 1);
    expanded.extend_from_slice(&old[..dim]);
    expanded.push(1);
    expanded.extend_from_slice(&old[dim..]);
    Ok(self.view_with_shape(&expanded))
}
// Returns a freshly-allocated copy of this tensor's data.
// NOTE(review): delegates straight to clone_data(); this assumes every
// CudaTensor's buffer is already laid out contiguously (views here only change
// the shape metadata) — confirm view_with_shape never produces strided views.
pub fn contiguous_impl(&self) -> BackendResult<CudaTensor> {
self.clone_data()
}
// Swaps dimensions `dim0` and `dim1`, materializing the result in a new
// contiguous buffer (this is a copy, not a view).
//
// Fast path: rank-2 tensors use a dedicated 2-D transpose kernel. General
// path: an N-d kernel driven by the old shape plus row-major strides of the
// old and new shapes, all uploaded to device memory.
//
// Errors with ArgumentError if either dim is out of range; `dim0 == dim1`
// degenerates to a plain copy.
pub fn transpose_impl(&self, dim0: usize, dim1: usize) -> BackendResult<CudaTensor> {
let shape = self.shape().to_vec();
let ndim = shape.len();
if dim0 >= ndim || dim1 >= ndim {
return Err(BackendError::ArgumentError(format!(
"transpose dims ({}, {}) out of range for ndim {}",
dim0, dim1, ndim
)));
}
// Swapping a dim with itself changes nothing — just copy.
if dim0 == dim1 {
return self.clone_data();
}
let mut new_shape = shape.clone();
new_shape.swap(dim0, dim1);
// Rank-2 fast path: dims must be (0, 1), so a plain matrix transpose suffices.
if ndim == 2 {
let output = CudaTensor::uninit(&new_shape, DType::F32);
let stream = crate::stream::get_stream().raw();
unsafe {
launch_transpose_2d_kernel(
self.buffer.ptr() as *const f32,
output.buffer.ptr() as *mut f32,
shape[0] as i32,
shape[1] as i32,
stream,
);
}
crate::stream::sync_stream();
return Ok(output);
}
let total = self.elem_count();
// Row-major strides of the input shape (innermost dim has stride 1).
let mut old_strides = vec![1i32; ndim];
for d in (0..ndim - 1).rev() {
old_strides[d] = old_strides[d + 1] * shape[d + 1] as i32;
}
// Row-major strides of the output shape.
let mut new_strides = vec![1i32; ndim];
for d in (0..ndim - 1).rev() {
new_strides[d] = new_strides[d + 1] * new_shape[d + 1] as i32;
}
let old_shape_i32: Vec<i32> = shape.iter().map(|&s| s as i32).collect();
// NOTE(review): the three uploads below reinterpret i32 host slices as f32 so
// they can ride through from_slice into a nominal F32 device buffer; the
// kernel then reads them back as *const i32. This is a raw byte smuggle
// (i32 and f32 are both 4 bytes) — it only holds if from_slice does a plain
// memcpy and never converts/normalizes values. Confirm, or add an I32 dtype.
let shape_gpu = CudaTensor::from_slice(
unsafe { std::slice::from_raw_parts(old_shape_i32.as_ptr() as *const f32, ndim) },
&[ndim],
DType::F32,
);
let old_strides_gpu = CudaTensor::from_slice(
unsafe { std::slice::from_raw_parts(old_strides.as_ptr() as *const f32, ndim) },
&[ndim],
DType::F32,
);
let new_strides_gpu = CudaTensor::from_slice(
unsafe { std::slice::from_raw_parts(new_strides.as_ptr() as *const f32, ndim) },
&[ndim],
DType::F32,
);
let output = CudaTensor::uninit(&new_shape, DType::F32);
let stream = crate::stream::get_stream().raw();
unsafe {
launch_transpose_nd_kernel(
self.buffer.ptr() as *const f32,
output.buffer.ptr() as *mut f32,
shape_gpu.buffer.ptr() as *const i32,
old_strides_gpu.buffer.ptr() as *const i32,
new_strides_gpu.buffer.ptr() as *const i32,
ndim as i32,
total as i32,
dim0 as i32,
dim1 as i32,
stream,
);
}
// Block until the kernel finishes so `output` is safe to read and the
// temporary shape/stride buffers may be dropped.
crate::stream::sync_stream();
Ok(output)
}
/// Returns a new tensor containing elements `start..start + len` of
/// dimension `dim` (a copy, not a view).
///
/// # Errors
/// `ArgumentError` if `dim` is out of range, if `start + len` overflows, or
/// if the requested range exceeds the dimension's extent.
pub fn narrow_impl(&self, dim: usize, start: usize, len: usize) -> BackendResult<CudaTensor> {
    let shape = self.shape().to_vec();
    let ndim = shape.len();
    if dim >= ndim {
        return Err(BackendError::ArgumentError(format!(
            "narrow dim {} >= ndim {}",
            dim, ndim
        )));
    }
    // checked_add: a wrapping `start + len` would slip past the bounds check
    // below and let the kernel read out of bounds.
    let end = start.checked_add(len).ok_or_else(|| {
        BackendError::ArgumentError(format!(
            "narrow: start {} + len {} overflows usize",
            start, len
        ))
    })?;
    if end > shape[dim] {
        return Err(BackendError::ArgumentError(format!(
            "narrow: start {} + len {} > dim size {}",
            start, len, shape[dim]
        )));
    }
    // View the tensor as (outer, old_dim, inner) around the narrowed dim;
    // max(1) keeps empty products well-defined for edge dims.
    let outer: usize = shape[..dim].iter().product::<usize>().max(1);
    let inner: usize = shape[dim + 1..].iter().product::<usize>().max(1);
    let old_dim = shape[dim];
    let mut new_shape = shape.clone();
    new_shape[dim] = len;
    let output = CudaTensor::uninit(&new_shape, DType::F32);
    let stream = crate::stream::get_stream().raw();
    unsafe {
        launch_narrow_kernel(
            self.buffer.ptr() as *const f32,
            output.buffer.ptr() as *mut f32,
            outer as i32,
            inner as i32,
            old_dim as i32,
            len as i32,
            start as i32,
            stream,
        );
    }
    // Ensure the copy has finished before handing the buffer to the caller.
    crate::stream::sync_stream();
    Ok(output)
}
// Alias for narrow: copies elements `start..start + len` of dimension `dim`.
// Kept as a separate entry point so callers using either name resolve to the
// same implementation.
pub fn slice_impl(&self, dim: usize, start: usize, len: usize) -> BackendResult<CudaTensor> {
self.narrow_impl(dim, start, len)
}
/// Concatenates `self` and `other` along dimension `dim` into a new tensor.
///
/// All dimensions except `dim` must match between the two inputs.
///
/// # Errors
/// `ArgumentError` if `dim` is out of range; `ShapeMismatch` on rank or
/// extent disagreement.
pub fn cat_impl(&self, other: &CudaTensor, dim: usize) -> BackendResult<CudaTensor> {
    let a_shape = self.shape().to_vec();
    let b_shape = other.shape().to_vec();
    if a_shape.len() != b_shape.len() {
        return Err(BackendError::ShapeMismatch(format!(
            "cat: rank mismatch {:?} vs {:?}",
            a_shape, b_shape
        )));
    }
    // Validate `dim` before it is used to slice/index the shapes below —
    // without this check an out-of-range dim panics instead of erroring.
    if dim >= a_shape.len() {
        return Err(BackendError::ArgumentError(format!(
            "cat dim {} out of range for ndim {}",
            dim,
            a_shape.len()
        )));
    }
    for (i, (&a, &b)) in a_shape.iter().zip(b_shape.iter()).enumerate() {
        if i != dim && a != b {
            return Err(BackendError::ShapeMismatch(format!(
                "cat: dim {} mismatch: {} vs {}",
                i, a, b
            )));
        }
    }
    // View both tensors as (outer, {a_dim|b_dim}, inner) around the cat dim.
    let outer: usize = a_shape[..dim].iter().product::<usize>().max(1);
    let inner: usize = a_shape[dim + 1..].iter().product::<usize>().max(1);
    let a_dim = a_shape[dim];
    let b_dim = b_shape[dim];
    let mut out_shape = a_shape.clone();
    out_shape[dim] = a_dim + b_dim;
    let output = CudaTensor::uninit(&out_shape, DType::F32);
    let stream = crate::stream::get_stream().raw();
    unsafe {
        launch_cat_kernel(
            self.buffer.ptr() as *const f32,
            other.buffer.ptr() as *const f32,
            output.buffer.ptr() as *mut f32,
            outer as i32,
            inner as i32,
            a_dim as i32,
            b_dim as i32,
            stream,
        );
    }
    // Ensure the kernel has finished before the output escapes.
    crate::stream::sync_stream();
    Ok(output)
}
pub fn broadcast_to_impl(&self, target_shape: &[usize]) -> BackendResult<CudaTensor> {
let src_shape = self.shape();
let ndim = target_shape.len();
let out_count: usize = target_shape.iter().product();
let mut padded_src = vec![1i32; ndim];
let offset = ndim - src_shape.len();
for (i, &s) in src_shape.iter().enumerate() {
padded_src[offset + i] = s as i32;
}
let target_i32: Vec<i32> = target_shape.iter().map(|&s| s as i32).collect();
let target_gpu = CudaTensor::from_slice(
unsafe { std::slice::from_raw_parts(target_i32.as_ptr() as *const f32, ndim) },
&[ndim],
DType::F32,
);
let src_gpu = CudaTensor::from_slice(
unsafe { std::slice::from_raw_parts(padded_src.as_ptr() as *const f32, ndim) },
&[ndim],
DType::F32,
);
let output = CudaTensor::uninit(target_shape, DType::F32);
let stream = crate::stream::get_stream().raw();
unsafe {
launch_broadcast_to_kernel(
self.buffer.ptr() as *const f32,
output.buffer.ptr() as *mut f32,
target_gpu.buffer.ptr() as *const i32,
src_gpu.buffer.ptr() as *const i32,
ndim as i32,
out_count as i32,
stream,
);
}
crate::stream::sync_stream();
Ok(output)
}
}