use super::helpers::*;
use crate::error::Result;
use crate::ops::broadcast_shape;
use crate::runtime::wgpu::shaders::{activation_launcher, where_launcher};
use crate::runtime::wgpu::{WgpuClient, WgpuRuntime};
use crate::runtime::{compute_broadcast_shape, ensure_contiguous};
use crate::tensor::Tensor;
pub(crate) fn native_clamp(
client: &WgpuClient,
a: &Tensor<WgpuRuntime>,
min_val: f64,
max_val: f64,
) -> Result<Tensor<WgpuRuntime>> {
let dtype = a.dtype();
let a_contig = ensure_contiguous(a);
let numel = a.numel();
let out = alloc_output(client, a.shape(), dtype);
let a_buf = get_tensor_buffer(&a_contig)?;
let out_buf = get_tensor_buffer(&out)?;
let params = ClampParams {
numel: numel as u32,
min_val: min_val as f32,
max_val: max_val as f32,
_pad0: 0,
};
let params_buf = create_params_buffer(client, ¶ms);
activation_launcher::launch_clamp_op(
client.pipeline_cache(),
client.wgpu_queue(),
&a_buf,
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
pub(crate) fn native_where_cond(
client: &WgpuClient,
cond: &Tensor<WgpuRuntime>,
x: &Tensor<WgpuRuntime>,
y: &Tensor<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
let cond_dtype = cond.dtype();
let out_dtype = x.dtype();
if x.dtype() != y.dtype() {
return Err(crate::error::Error::DTypeMismatch {
lhs: x.dtype(),
rhs: y.dtype(),
});
}
let xy_shape = compute_broadcast_shape(x, y)?;
let out_shape = broadcast_shape(cond.shape(), &xy_shape).ok_or_else(|| {
crate::error::Error::BroadcastError {
lhs: cond.shape().to_vec(),
rhs: xy_shape.clone(),
}
})?;
let numel: usize = out_shape.iter().product();
if cond.shape() == x.shape() && x.shape() == y.shape() {
let cond_contig = ensure_contiguous(cond);
let x_contig = ensure_contiguous(x);
let y_contig = ensure_contiguous(y);
let out = alloc_output(client, &out_shape, out_dtype);
let cond_buf = get_tensor_buffer(&cond_contig)?;
let x_buf = get_tensor_buffer(&x_contig)?;
let y_buf = get_tensor_buffer(&y_contig)?;
let out_buf = get_tensor_buffer(&out)?;
let params = WhereParams {
numel: numel as u32,
};
let params_buf = create_params_buffer(client, ¶ms);
where_launcher::launch_where_generic_op(
client.pipeline_cache(),
client.wgpu_queue(),
&cond_buf,
&x_buf,
&y_buf,
&out_buf,
¶ms_buf,
numel,
cond_dtype,
out_dtype,
)?;
return Ok(out);
}
let cond_contig = ensure_contiguous(cond);
let x_contig = ensure_contiguous(x);
let y_contig = ensure_contiguous(y);
let cond_strides = compute_broadcast_strides(cond.shape(), &out_shape);
let x_strides = compute_broadcast_strides(x.shape(), &out_shape);
let y_strides = compute_broadcast_strides(y.shape(), &out_shape);
let shape_u32: Vec<u32> = out_shape.iter().map(|&s| s as u32).collect();
let out = alloc_output(client, &out_shape, out_dtype);
let cond_buf = get_tensor_buffer(&cond_contig)?;
let x_buf = get_tensor_buffer(&x_contig)?;
let y_buf = get_tensor_buffer(&y_contig)?;
let out_buf = get_tensor_buffer(&out)?;
let cond_strides_buf = create_storage_buffer(client, &cond_strides);
let x_strides_buf = create_storage_buffer(client, &x_strides);
let y_strides_buf = create_storage_buffer(client, &y_strides);
let shape_buf = create_storage_buffer(client, &shape_u32);
let params = WhereBroadcastParams {
numel: numel as u32,
ndim: out_shape.len() as u32,
_pad0: 0,
_pad1: 0,
};
let params_buf = create_params_buffer(client, ¶ms);
where_launcher::launch_where_broadcast_op(
client.pipeline_cache(),
client.wgpu_queue(),
&cond_buf,
&x_buf,
&y_buf,
&out_buf,
&cond_strides_buf,
&x_strides_buf,
&y_strides_buf,
&shape_buf,
¶ms_buf,
numel,
cond_dtype,
out_dtype,
)?;
Ok(out)
}