use super::helpers::*;
use crate::error::Result;
use crate::runtime::ensure_contiguous;
use crate::runtime::wgpu::shaders::elementwise;
use crate::runtime::wgpu::{WgpuClient, WgpuRuntime};
use crate::tensor::Tensor;
/// Runs the elementwise unary shader named by `op` over tensor `a` and returns
/// the result in a newly allocated output tensor.
///
/// The output is allocated through `alloc_output` using `a`'s shape and dtype.
/// The input is first passed through `ensure_contiguous`, and it is that
/// (presumably densely laid out) tensor whose buffer is bound as the shader
/// input.
///
/// # Errors
///
/// Propagates any error returned by `get_tensor_buffer` (for either the input
/// or the output tensor) or by `elementwise::launch_unary_op`.
pub(crate) fn native_unary_op(
    client: &WgpuClient,
    op: &'static str,
    a: &Tensor<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
    let dtype = a.dtype();
    let a_contig = ensure_contiguous(a);
    let numel = a.numel();

    // The params struct is filled with a 32-bit element count, so a larger
    // tensor would be silently truncated by the cast below. Surface that in
    // debug builds instead of letting the shader process only part of the data.
    debug_assert!(
        u32::try_from(numel).is_ok(),
        "native_unary_op: element count {} does not fit in u32",
        numel
    );

    let out = alloc_output(client, a.shape(), dtype);

    let a_buf = get_tensor_buffer(&a_contig)?;
    let out_buf = get_tensor_buffer(&out)?;

    let params = UnaryParams {
        numel: numel as u32,
    };
    let params_buf = create_params_buffer(client, &params);

    elementwise::launch_unary_op(
        client.pipeline_cache(),
        client.wgpu_queue(),
        op,
        &a_buf,
        &out_buf,
        &params_buf,
        numel,
        dtype,
    )?;

    Ok(out)
}