use super::super::kernels::{
    AccumulationPrecision, launch_binary_op, launch_broadcast_binary_op,
    launch_broadcast_compare_op, launch_compare_op, launch_gemv_kernel_bt_mr,
    launch_matmul_batched_kernel, launch_matmul_bias_batched_kernel, launch_matmul_bias_kernel,
    launch_matmul_kernel, launch_reduce_dim_op, launch_scalar_op_c64, launch_scalar_op_c128,
    launch_scalar_op_f32, launch_scalar_op_f64, launch_scalar_op_half, launch_scalar_op_i32,
    launch_scalar_op_i64, launch_semiring_matmul_batched_kernel, launch_semiring_matmul_kernel,
    launch_unary_op,
};
use super::super::{CudaClient, CudaRuntime};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{matmul_bias_output_shape, matmul_output_shape, reduce_output_shape};
use crate::runtime::{compute_broadcast_shape, ensure_contiguous, validate_binary_dtypes};
use crate::tensor::Tensor;
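/// Returns true if `tensor` is a 2-D view stored column-major, i.e. the simple
/// transpose of a contiguous row-major matrix.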
fn is_simple_transpose_2d(tensor: &Tensor<CudaRuntime>) -> bool {
let shape = tensor.shape();
let strides = tensor.strides();
if shape.len() != 2 {
return false;
}
strides[0] == 1 && strides[1] == shape[0] as isize
}
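/// Non-batched matrix multiply `a (m x k) @ b (k x n)`.
///
/// For skinny outputs (`m <= 16`) with `b` stored transposed, the GEMV-style
/// `launch_gemv_kernel_bt_mr` kernel is used directly on `b`'s existing layout;
/// otherwise both operands are made contiguous and the general matmul kernel runs.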
pub(crate) fn matmul_native(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
b: &Tensor<CudaRuntime>,
dtype: DType,
m: usize,
k: usize,
n: usize,
) -> Result<Tensor<CudaRuntime>> {
let out_shape = matmul_output_shape(a.shape(), b.shape()).ok_or(Error::ShapeMismatch {
expected: a.shape().to_vec(),
got: b.shape().to_vec(),
})?;
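    // Fast path: skinny output with a transposed `b` maps well onto the GEMV
    // kernel and avoids materializing a contiguous copy of `b`.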
if m <= 16 && is_simple_transpose_2d(b) {
let a_contig = ensure_contiguous(a);
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_gemv_kernel_bt_mr(
&client.context,
&client.stream,
client.device.index,
dtype,
a_contig.ptr(),
                b.ptr(),
                out.ptr(),
                1, // batch
                m,
                n,
                k,
                1, // a_batch
                1, // b_batch
            )?;
}
return Ok(out);
}
let a_contig = ensure_contiguous(a);
let b_contig = ensure_contiguous(b);
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_matmul_kernel(
&client.context,
&client.stream,
client.device.index,
dtype,
a_contig.ptr(),
b_contig.ptr(),
out.ptr(),
m,
n,
k,
)?;
}
Ok(out)
}
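/// Returns true if `tensor` is a 3-D view `[batch, k, n]` whose last two dims are
/// stored transposed (column-major) and whose batches are laid out contiguously.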
fn is_batched_transpose_last2(tensor: &Tensor<CudaRuntime>) -> bool {
let shape = tensor.shape();
let strides = tensor.strides();
if shape.len() != 3 {
return false;
}
let k = shape[1];
let n = shape[2];
strides[1] == 1 && strides[2] == k as isize && strides[0] == (n * k) as isize
}
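/// Number of batch matrices in `a` and `b`: the product of all but the last two
/// dims of each shape, clamped to at least 1.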
fn compute_batch_counts(a_shape: &[usize], b_shape: &[usize]) -> (usize, usize) {
let a_batch: usize = a_shape
.iter()
.take(a_shape.len().saturating_sub(2))
.product();
let b_batch: usize = b_shape
.iter()
.take(b_shape.len().saturating_sub(2))
.product();
(a_batch.max(1), b_batch.max(1))
}
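/// Batched matrix multiply.
///
/// The per-operand batch counts from `compute_batch_counts` are forwarded to the
/// kernel alongside the overall `batch` size. A small-`m` GEMV fast path is taken
/// when `b` stores its last two dims transposed.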
pub(crate) fn matmul_batched_native(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
b: &Tensor<CudaRuntime>,
dtype: DType,
batch: usize,
m: usize,
k: usize,
n: usize,
) -> Result<Tensor<CudaRuntime>> {
let out_shape = matmul_output_shape(a.shape(), b.shape()).ok_or(Error::ShapeMismatch {
expected: a.shape().to_vec(),
got: b.shape().to_vec(),
})?;
let (a_batch, b_batch) = compute_batch_counts(a.shape(), b.shape());
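    // Same skinny-output GEMV fast path as `matmul_native`, applied per batch.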
if m <= 16 && is_batched_transpose_last2(b) {
let a_contig = ensure_contiguous(a);
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_gemv_kernel_bt_mr(
&client.context,
&client.stream,
client.device.index,
dtype,
a_contig.ptr(),
b.ptr(),
out.ptr(),
batch,
m,
n,
k,
a_batch,
b_batch,
)?;
}
return Ok(out);
}
let a_contig = ensure_contiguous(a);
let b_contig = ensure_contiguous(b);
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_matmul_batched_kernel(
&client.context,
&client.stream,
client.device.index,
dtype,
a_contig.ptr(),
b_contig.ptr(),
out.ptr(),
batch,
m,
n,
k,
a_batch,
b_batch,
)?;
}
Ok(out)
}
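/// Matrix multiply with bias: computes `a @ b` plus `bias`; accepted shapes are
/// validated by `matmul_bias_output_shape`.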
pub(crate) fn matmul_bias_native(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
b: &Tensor<CudaRuntime>,
bias: &Tensor<CudaRuntime>,
dtype: DType,
m: usize,
k: usize,
n: usize,
) -> Result<Tensor<CudaRuntime>> {
let a_contig = ensure_contiguous(a);
let b_contig = ensure_contiguous(b);
let bias_contig = ensure_contiguous(bias);
let out_shape = matmul_bias_output_shape(a.shape(), b.shape(), bias.shape()).ok_or(
Error::ShapeMismatch {
expected: a.shape().to_vec(),
got: b.shape().to_vec(),
},
)?;
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_matmul_bias_kernel(
&client.context,
&client.stream,
client.device.index,
dtype,
a_contig.ptr(),
b_contig.ptr(),
bias_contig.ptr(),
out.ptr(),
m,
n,
k,
)?;
}
Ok(out)
}
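/// Batched variant of `matmul_bias_native`, forwarding the per-operand batch counts.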
pub(crate) fn matmul_bias_batched_native(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
b: &Tensor<CudaRuntime>,
bias: &Tensor<CudaRuntime>,
dtype: DType,
batch: usize,
m: usize,
k: usize,
n: usize,
) -> Result<Tensor<CudaRuntime>> {
let a_contig = ensure_contiguous(a);
let b_contig = ensure_contiguous(b);
let bias_contig = ensure_contiguous(bias);
let out_shape = matmul_bias_output_shape(a.shape(), b.shape(), bias.shape()).ok_or(
Error::ShapeMismatch {
expected: a.shape().to_vec(),
got: b.shape().to_vec(),
},
)?;
let (a_batch, b_batch) = compute_batch_counts(a.shape(), b.shape());
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_matmul_bias_batched_kernel(
&client.context,
&client.stream,
client.device.index,
dtype,
a_contig.ptr(),
b_contig.ptr(),
bias_contig.ptr(),
out.ptr(),
batch,
m,
n,
k,
a_batch,
b_batch,
)?;
}
Ok(out)
}
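/// Elementwise binary op dispatched by kernel name `op`.
///
/// Same-shape operands take the flat `launch_binary_op` path; otherwise the inputs
/// are broadcast to a common shape via `launch_broadcast_binary_op`.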
pub(crate) fn native_binary_op(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
b: &Tensor<CudaRuntime>,
op: &'static str,
) -> Result<Tensor<CudaRuntime>> {
let dtype = validate_binary_dtypes(a, b)?;
let out_shape = compute_broadcast_shape(a, b)?;
if a.shape() == b.shape() {
let a_contig = ensure_contiguous(a);
let b_contig = ensure_contiguous(b);
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_binary_op(
&client.context,
&client.stream,
client.device.index,
op,
dtype,
a_contig.ptr(),
b_contig.ptr(),
out.ptr(),
out.numel(),
)?;
}
return Ok(out);
}
let a_contig = ensure_contiguous(a);
let b_contig = ensure_contiguous(b);
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_broadcast_binary_op(
&client.context,
&client.stream,
client.device.index,
&client.device,
op,
dtype,
a_contig.ptr(),
b_contig.ptr(),
out.ptr(),
a.shape(),
b.shape(),
&out_shape,
)?;
}
Ok(out)
}
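/// Elementwise unary op dispatched by kernel name `op`; output keeps the input's
/// shape and dtype.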
pub(crate) fn native_unary_op(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
op: &'static str,
) -> Result<Tensor<CudaRuntime>> {
let dtype = a.dtype();
let a_contig = ensure_contiguous(a);
let out = Tensor::<CudaRuntime>::empty(a.shape(), dtype, &client.device);
unsafe {
launch_unary_op(
&client.context,
&client.stream,
client.device.index,
op,
dtype,
a_contig.ptr(),
out.ptr(),
out.numel(),
)?;
}
Ok(out)
}
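/// Elementwise op between a tensor and a host scalar.
///
/// The `f64` scalar is narrowed to match the tensor's dtype before launch;
/// `pow_scalar` on integer dtypes and any dtype without a scalar kernel return
/// `UnsupportedDType`.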
pub(crate) fn native_scalar_op(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
op: &'static str,
scalar: f64,
) -> Result<Tensor<CudaRuntime>> {
    let dtype = a.dtype();
    // pow_scalar is not supported for integer dtypes; bail out before doing any
    // copies or allocations.
    if op == "pow_scalar" && matches!(dtype, DType::I32 | DType::I64) {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "pow_scalar",
        });
    }
    let a_contig = ensure_contiguous(a);
    let out = Tensor::<CudaRuntime>::empty(a.shape(), dtype, &client.device);
unsafe {
match dtype {
DType::F32 => launch_scalar_op_f32(
&client.context,
&client.stream,
client.device.index,
op,
a_contig.ptr(),
scalar as f32,
out.ptr(),
out.numel(),
)?,
DType::F64 => launch_scalar_op_f64(
&client.context,
&client.stream,
client.device.index,
op,
a_contig.ptr(),
scalar,
out.ptr(),
out.numel(),
)?,
DType::I32 => launch_scalar_op_i32(
&client.context,
&client.stream,
client.device.index,
op,
a_contig.ptr(),
scalar as i32,
out.ptr(),
out.numel(),
)?,
DType::I64 => launch_scalar_op_i64(
&client.context,
&client.stream,
client.device.index,
op,
a_contig.ptr(),
scalar as i64,
out.ptr(),
out.numel(),
)?,
#[cfg(feature = "f16")]
DType::F16 | DType::BF16 => launch_scalar_op_half(
&client.context,
&client.stream,
client.device.index,
op,
dtype,
a_contig.ptr(),
scalar as f32,
out.ptr(),
out.numel(),
)?,
DType::FP8E4M3 | DType::FP8E5M2 => launch_scalar_op_half(
&client.context,
&client.stream,
client.device.index,
op,
dtype,
a_contig.ptr(),
scalar as f32,
out.ptr(),
out.numel(),
)?,
DType::Complex64 => launch_scalar_op_c64(
&client.context,
&client.stream,
client.device.index,
op,
a_contig.ptr(),
scalar as f32,
out.ptr(),
out.numel(),
)?,
DType::Complex128 => launch_scalar_op_c128(
&client.context,
&client.stream,
client.device.index,
op,
a_contig.ptr(),
scalar,
out.ptr(),
out.numel(),
)?,
_ => {
return Err(Error::UnsupportedDType { dtype, op });
}
}
}
Ok(out)
}
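/// Reduction over one or more dimensions.
///
/// A single dim maps directly onto `launch_reduce_dim_op`, viewing the input as
/// `[outer, reduce, inner]`; multiple dims are reduced one at a time from the
/// highest index down. `precision` selects the accumulation precision and defaults
/// when `None`.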
pub(crate) fn native_reduce_op(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
op: &'static str,
dims: &[usize],
keepdim: bool,
precision: Option<AccumulationPrecision>,
) -> Result<Tensor<CudaRuntime>> {
let dtype = a.dtype();
let out_shape = reduce_output_shape(a.shape(), dims, keepdim);
let acc_precision = precision.unwrap_or_default();
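    // Single-dim case: collapse the tensor to [outer, reduce, inner] around `dim`
    // and launch one kernel.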
if dims.len() == 1 {
let dim = dims[0];
let shape = a.shape();
let outer_size: usize = shape[..dim].iter().product();
let reduce_size = shape[dim];
let inner_size: usize = shape[dim + 1..].iter().product();
let outer_size = outer_size.max(1);
let inner_size = inner_size.max(1);
let a_contig = ensure_contiguous(a);
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_reduce_dim_op(
&client.context,
&client.stream,
client.device.index,
op,
dtype,
a_contig.ptr(),
out.ptr(),
outer_size,
reduce_size,
inner_size,
acc_precision,
)?;
}
return Ok(out);
}
    // Multi-dim reduction: reduce one dim at a time, walking the dims from the
    // highest index down so that removing a dimension never shifts the indices
    // of the dims still to be reduced.
    let mut sorted_dims: Vec<usize> = dims.to_vec();
    sorted_dims.sort_unstable();
    sorted_dims.reverse();
    let mut current = a.clone();
    for &dim in &sorted_dims {
        current = native_reduce_op(client, &current, op, &[dim], keepdim, precision)?;
    }
    Ok(current)
}
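/// Elementwise comparison op; mirrors `native_binary_op` with a same-shape fast
/// path and a broadcasting fallback.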
pub(crate) fn native_compare_op(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
b: &Tensor<CudaRuntime>,
op: &'static str,
) -> Result<Tensor<CudaRuntime>> {
let dtype = validate_binary_dtypes(a, b)?;
let out_shape = compute_broadcast_shape(a, b)?;
if a.shape() == b.shape() {
let a_contig = ensure_contiguous(a);
let b_contig = ensure_contiguous(b);
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_compare_op(
&client.context,
&client.stream,
client.device.index,
op,
dtype,
a_contig.ptr(),
b_contig.ptr(),
out.ptr(),
out.numel(),
)?;
}
return Ok(out);
}
let a_contig = ensure_contiguous(a);
let b_contig = ensure_contiguous(b);
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_broadcast_compare_op(
&client.context,
&client.stream,
client.device.index,
&client.device,
op,
dtype,
a_contig.ptr(),
b_contig.ptr(),
out.ptr(),
a.shape(),
b.shape(),
&out_shape,
)?;
}
Ok(out)
}
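/// Matrix multiply under a generalized semiring; `semiring_op` is an opaque code
/// interpreted by `launch_semiring_matmul_kernel`.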
pub(crate) fn semiring_matmul_native(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
b: &Tensor<CudaRuntime>,
dtype: DType,
m: usize,
k: usize,
n: usize,
semiring_op: u32,
) -> Result<Tensor<CudaRuntime>> {
let a_contig = ensure_contiguous(a);
let b_contig = ensure_contiguous(b);
let out_shape = matmul_output_shape(a.shape(), b.shape()).ok_or(Error::ShapeMismatch {
expected: a.shape().to_vec(),
got: b.shape().to_vec(),
})?;
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_semiring_matmul_kernel(
&client.context,
&client.stream,
client.device.index,
dtype,
a_contig.ptr(),
b_contig.ptr(),
out.ptr(),
m,
n,
k,
semiring_op,
)?;
}
Ok(out)
}
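/// Batched variant of `semiring_matmul_native`, forwarding the per-operand batch
/// counts to the batched kernel.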
pub(crate) fn semiring_matmul_batched_native(
client: &CudaClient,
a: &Tensor<CudaRuntime>,
b: &Tensor<CudaRuntime>,
dtype: DType,
batch: usize,
m: usize,
k: usize,
n: usize,
semiring_op: u32,
) -> Result<Tensor<CudaRuntime>> {
let a_contig = ensure_contiguous(a);
let b_contig = ensure_contiguous(b);
let out_shape = matmul_output_shape(a.shape(), b.shape()).ok_or(Error::ShapeMismatch {
expected: a.shape().to_vec(),
got: b.shape().to_vec(),
})?;
let (a_batch, b_batch) = compute_batch_counts(a.shape(), b.shape());
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_semiring_matmul_batched_kernel(
&client.context,
&client.stream,
client.device.index,
dtype,
a_contig.ptr(),
b_contig.ptr(),
out.ptr(),
batch,
m,
n,
k,
semiring_op,
a_batch,
b_batch,
)?;
}
Ok(out)
}