use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::ComplexOps;
use crate::ops::common::{validate_complex_real_inputs, validate_make_complex_inputs};
use crate::runtime::cuda::kernels::{
launch_angle, launch_angle_real, launch_complex_div_real, launch_complex_mul_real, launch_conj,
launch_fill_with_f64, launch_from_real_imag, launch_imag, launch_real,
};
use crate::runtime::cuda::{CudaClient, CudaRuntime};
use crate::runtime::ensure_contiguous;
use crate::tensor::Tensor;
impl ComplexOps<CudaRuntime> for CudaClient {
fn conj(&self, a: &Tensor<CudaRuntime>) -> Result<Tensor<CudaRuntime>> {
let dtype = a.dtype();
if !dtype.is_complex() {
return Ok(a.clone());
}
let a_contig = ensure_contiguous(a);
let out = Tensor::<CudaRuntime>::empty(a.shape(), dtype, &self.device);
unsafe {
launch_conj(
&self.context,
&self.stream,
self.device.index,
dtype,
a_contig.ptr(),
out.ptr(),
a.numel(),
)?;
}
Ok(out)
}
fn real(&self, a: &Tensor<CudaRuntime>) -> Result<Tensor<CudaRuntime>> {
let dtype = a.dtype();
if !dtype.is_complex() {
return Ok(a.clone());
}
let out_dtype = match dtype {
DType::Complex64 => DType::F32,
DType::Complex128 => DType::F64,
_ => return Err(Error::UnsupportedDType { dtype, op: "real" }),
};
let a_contig = ensure_contiguous(a);
let out = Tensor::<CudaRuntime>::empty(a.shape(), out_dtype, &self.device);
unsafe {
launch_real(
&self.context,
&self.stream,
self.device.index,
dtype,
a_contig.ptr(),
out.ptr(),
a.numel(),
)?;
}
Ok(out)
}
fn imag(&self, a: &Tensor<CudaRuntime>) -> Result<Tensor<CudaRuntime>> {
let dtype = a.dtype();
let out_dtype = if dtype.is_complex() {
match dtype {
DType::Complex64 => DType::F32,
DType::Complex128 => DType::F64,
_ => return Err(Error::UnsupportedDType { dtype, op: "imag" }),
}
} else {
dtype
};
let out = Tensor::<CudaRuntime>::empty(a.shape(), out_dtype, &self.device);
if !dtype.is_complex() {
unsafe {
launch_fill_with_f64(
&self.context,
&self.stream,
self.device.index,
out_dtype,
0.0,
out.ptr(),
out.numel(),
)?;
}
return Ok(out);
}
let a_contig = ensure_contiguous(a);
unsafe {
launch_imag(
&self.context,
&self.stream,
self.device.index,
dtype,
a_contig.ptr(),
out.ptr(),
a.numel(),
)?;
}
Ok(out)
}
fn angle(&self, a: &Tensor<CudaRuntime>) -> Result<Tensor<CudaRuntime>> {
let dtype = a.dtype();
let out_dtype = if dtype.is_complex() {
match dtype {
DType::Complex64 => DType::F32,
DType::Complex128 => DType::F64,
_ => return Err(Error::UnsupportedDType { dtype, op: "angle" }),
}
} else {
dtype
};
let out = Tensor::<CudaRuntime>::empty(a.shape(), out_dtype, &self.device);
let a_contig = ensure_contiguous(a);
if !dtype.is_complex() {
match dtype {
DType::F32 | DType::F64 => unsafe {
launch_angle_real(
&self.context,
&self.stream,
self.device.index,
dtype,
a_contig.ptr(),
out.ptr(),
a.numel(),
)?;
},
_ => {
unsafe {
launch_fill_with_f64(
&self.context,
&self.stream,
self.device.index,
out_dtype,
0.0,
out.ptr(),
out.numel(),
)?;
}
}
}
return Ok(out);
}
unsafe {
launch_angle(
&self.context,
&self.stream,
self.device.index,
dtype,
a_contig.ptr(),
out.ptr(),
a.numel(),
)?;
}
Ok(out)
}
fn make_complex(
&self,
real: &Tensor<CudaRuntime>,
imag: &Tensor<CudaRuntime>,
) -> Result<Tensor<CudaRuntime>> {
validate_make_complex_inputs(real, imag)?;
let input_dtype = real.dtype();
let shape = real.shape();
let numel = real.numel();
let out_dtype = match input_dtype {
DType::F32 => DType::Complex64,
DType::F64 => DType::Complex128,
_ => unreachable!("validated above"),
};
let real_contig = ensure_contiguous(real);
let imag_contig = ensure_contiguous(imag);
let out = Tensor::<CudaRuntime>::empty(shape, out_dtype, &self.device);
if numel == 0 {
return Ok(out);
}
unsafe {
launch_from_real_imag(
&self.context,
&self.stream,
self.device.index,
input_dtype,
real_contig.ptr(),
imag_contig.ptr(),
out.ptr(),
numel,
)?;
}
Ok(out)
}
fn complex_mul_real(
&self,
complex: &Tensor<CudaRuntime>,
real: &Tensor<CudaRuntime>,
) -> Result<Tensor<CudaRuntime>> {
validate_complex_real_inputs(complex, real, "complex_mul_real")?;
let dtype = complex.dtype();
let shape = complex.shape();
let numel = complex.numel();
let complex_contig = ensure_contiguous(complex);
let real_contig = ensure_contiguous(real);
let out = Tensor::<CudaRuntime>::empty(shape, dtype, &self.device);
if numel == 0 {
return Ok(out);
}
unsafe {
launch_complex_mul_real(
&self.context,
&self.stream,
self.device.index,
dtype,
complex_contig.ptr(),
real_contig.ptr(),
out.ptr(),
numel,
)?;
}
Ok(out)
}
fn complex_div_real(
&self,
complex: &Tensor<CudaRuntime>,
real: &Tensor<CudaRuntime>,
) -> Result<Tensor<CudaRuntime>> {
validate_complex_real_inputs(complex, real, "complex_div_real")?;
let dtype = complex.dtype();
let shape = complex.shape();
let numel = complex.numel();
let complex_contig = ensure_contiguous(complex);
let real_contig = ensure_contiguous(real);
let out = Tensor::<CudaRuntime>::empty(shape, dtype, &self.device);
if numel == 0 {
return Ok(out);
}
unsafe {
launch_complex_div_real(
&self.context,
&self.stream,
self.device.index,
dtype,
complex_contig.ptr(),
real_contig.ptr(),
out.ptr(),
numel,
)?;
}
Ok(out)
}
}