use crate::dtype::DType;
use crate::error::Result;
use crate::ops::{BinaryOps, UtilityOps};
use crate::runtime::cuda::kernels::{
launch_arange, launch_eye, launch_fill_with_f64, launch_linspace,
};
use crate::runtime::cuda::{CudaClient, CudaRuntime};
use crate::runtime::{validate_arange, validate_eye};
use crate::tensor::Tensor;
impl UtilityOps<CudaRuntime> for CudaClient {
fn clamp(
&self,
a: &Tensor<CudaRuntime>,
min_val: f64,
max_val: f64,
) -> Result<Tensor<CudaRuntime>> {
let min_scalar = self.fill(&[], min_val, a.dtype())?;
let max_scalar = self.fill(&[], max_val, a.dtype())?;
let clamped_low = self.maximum(a, &min_scalar)?;
self.minimum(&clamped_low, &max_scalar)
}
fn fill(&self, shape: &[usize], value: f64, dtype: DType) -> Result<Tensor<CudaRuntime>> {
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::<CudaRuntime>::empty(shape, dtype, &self.device));
}
let out = Tensor::<CudaRuntime>::empty(shape, dtype, &self.device);
unsafe {
launch_fill_with_f64(
&self.context,
&self.stream,
self.device.index,
dtype,
value,
out.ptr(),
numel,
)?;
}
Ok(out)
}
fn arange(
&self,
start: f64,
stop: f64,
step: f64,
dtype: DType,
) -> Result<Tensor<CudaRuntime>> {
let numel = validate_arange(start, stop, step)?;
if numel == 0 {
return Ok(Tensor::<CudaRuntime>::empty(&[0], dtype, &self.device));
}
let out = Tensor::<CudaRuntime>::empty(&[numel], dtype, &self.device);
unsafe {
launch_arange(
&self.context,
&self.stream,
self.device.index,
dtype,
start,
step,
out.ptr(),
numel,
)?;
}
Ok(out)
}
fn linspace(
&self,
start: f64,
stop: f64,
steps: usize,
dtype: DType,
) -> Result<Tensor<CudaRuntime>> {
if steps == 0 {
return Ok(Tensor::<CudaRuntime>::empty(&[0], dtype, &self.device));
}
if steps == 1 {
return self.fill(&[1], start, dtype);
}
let out = Tensor::<CudaRuntime>::empty(&[steps], dtype, &self.device);
unsafe {
launch_linspace(
&self.context,
&self.stream,
self.device.index,
dtype,
start,
stop,
out.ptr(),
steps,
)?;
}
Ok(out)
}
fn one_hot(
&self,
indices: &Tensor<CudaRuntime>,
num_classes: usize,
) -> Result<Tensor<CudaRuntime>> {
crate::ops::impl_generic::one_hot_impl(self, indices, num_classes)
}
fn meshgrid(
&self,
tensors: &[&Tensor<CudaRuntime>],
indexing: crate::ops::MeshgridIndexing,
) -> Result<Vec<Tensor<CudaRuntime>>> {
crate::ops::impl_generic::meshgrid_impl(tensors, indexing)
}
fn eye(&self, n: usize, m: Option<usize>, dtype: DType) -> Result<Tensor<CudaRuntime>> {
let (rows, cols) = validate_eye(n, m);
if rows == 0 || cols == 0 {
return Ok(Tensor::<CudaRuntime>::empty(
&[rows, cols],
dtype,
&self.device,
));
}
let out = Tensor::<CudaRuntime>::empty(&[rows, cols], dtype, &self.device);
unsafe {
launch_eye(
&self.context,
&self.stream,
self.device.index,
dtype,
rows,
cols,
out.ptr(),
)?;
}
Ok(out)
}
}