use super::helpers::*;
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::wgpu::shaders::index;
use crate::runtime::wgpu::{WgpuClient, WgpuRuntime};
use crate::runtime::{RuntimeClient, ensure_contiguous};
use crate::tensor::Tensor;
use wgpu::BufferUsages;
/// Fills elements of `a` with `value` wherever the (broadcast) `mask` is
/// non-zero, returning a new tensor with `a`'s shape and dtype.
///
/// `mask` must have dtype `U32` and be broadcastable to `a.shape()`.
///
/// # Errors
/// - [`Error::DTypeMismatch`] if `mask` is not `U32`.
/// - [`Error::ShapeMismatch`] if `mask` cannot be broadcast to `a`'s shape.
/// - Any error propagated from buffer access or the kernel launch.
pub(crate) fn native_masked_fill(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
    mask: &Tensor<WgpuRuntime>,
    value: f64,
) -> Result<Tensor<WgpuRuntime>> {
    let dtype = a.dtype();
    let numel = a.numel();
    if mask.dtype() != DType::U32 {
        return Err(Error::DTypeMismatch {
            lhs: DType::U32,
            rhs: mask.dtype(),
        });
    }
    // Broadcast the mask up to `a`'s shape; on failure report the original
    // (un-broadcast) mask shape so the caller sees what they passed in.
    let mask_broadcast = mask
        .broadcast_to(a.shape())
        .map_err(|_| Error::ShapeMismatch {
            expected: a.shape().to_vec(),
            got: mask.shape().to_vec(),
        })?;
    // The fill kernel assumes densely packed, row-major inputs.
    let a_contig = ensure_contiguous(a);
    let mask_contig = ensure_contiguous(&mask_broadcast);
    let out = alloc_output(client, a.shape(), dtype);
    let a_buf = get_tensor_buffer(&a_contig)?;
    let mask_buf = get_tensor_buffer(&mask_contig)?;
    let out_buf = get_tensor_buffer(&out)?;
    let params = MaskedFillParams {
        numel: numel as u32,
        // NOTE(review): narrowing f64 -> f32 loses precision; presumably the
        // shader re-interprets the fill value per dtype — confirm for I32/U32.
        fill_value: value as f32,
    };
    // Fix: `&params` / `&params_buf` had been mangled into `¶ms` (HTML
    // `&para;` entity corruption); the reference operator is restored here.
    let params_buf = create_params_buffer(client, &params);
    index::launch_masked_fill(
        client.pipeline_cache(),
        client.wgpu_queue(),
        &a_buf,
        &mask_buf,
        &out_buf,
        &params_buf,
        numel,
        dtype,
    )?;
    Ok(out)
}
/// Gathers rows of a rank-2 `embeddings` table by `indices`, producing a
/// tensor of shape `indices.shape() + [embedding_dim]` in `embeddings`'
/// dtype.
///
/// # Errors
/// - [`Error::ShapeMismatch`] if `embeddings` is not rank-2 (the
///   `expected: vec![0, 0]` payload is a placeholder meaning "any 2-D shape").
/// - Any error from converting `indices` to `I32`.
/// - [`Error::UnsupportedDType`] unless the table dtype is F32, I32, or U32.
pub(crate) fn native_embedding_lookup(
    client: &WgpuClient,
    embeddings: &Tensor<WgpuRuntime>,
    indices: &Tensor<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
    let dtype = embeddings.dtype();
    let emb_shape = embeddings.shape();
    if emb_shape.len() != 2 {
        return Err(Error::ShapeMismatch {
            expected: vec![0, 0],
            got: emb_shape.to_vec(),
        });
    }
    // Normalize index dtype to I32 — the lookup shader expects i32 indices.
    let indices_i32 = ensure_i32_indices(client, indices)?;
    if !matches!(dtype, DType::F32 | DType::I32 | DType::U32) {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "embedding_lookup",
        });
    }
    let vocab_size = emb_shape[0];
    let embedding_dim = emb_shape[1];
    let num_indices = indices_i32.numel();
    // Output keeps the index tensor's shape with the embedding dim appended.
    let mut out_shape = indices_i32.shape().to_vec();
    out_shape.push(embedding_dim);
    // The gather kernel assumes densely packed, row-major inputs.
    let emb_contig = ensure_contiguous(embeddings);
    let idx_contig = ensure_contiguous(&indices_i32);
    let out = alloc_output(client, &out_shape, dtype);
    let emb_buf = get_tensor_buffer(&emb_contig)?;
    let idx_buf = get_tensor_buffer(&idx_contig)?;
    let out_buf = get_tensor_buffer(&out)?;
    let params = EmbeddingLookupParams {
        num_indices: num_indices as u32,
        vocab_size: vocab_size as u32,
        embedding_dim: embedding_dim as u32,
        _pad0: 0,
    };
    // Fix: `&params` / `&params_buf` had been mangled into `¶ms` (HTML
    // `&para;` entity corruption); the reference operator is restored here.
    let params_buf = create_params_buffer(client, &params);
    index::launch_embedding_lookup(
        client.pipeline_cache(),
        client.wgpu_queue(),
        &emb_buf,
        &idx_buf,
        &out_buf,
        &params_buf,
        num_indices,
        dtype,
    )?;
    Ok(out)
}
/// Returns a 1-D tensor containing the elements of `a` at positions where the
/// (broadcast) `mask` is non-zero, in row-major order.
///
/// Implemented as three GPU passes with one blocking CPU round trip:
/// 1. count the set mask entries,
/// 2. read the count back to the host (to size the output allocation),
/// 3. prefix-sum the mask into per-element output offsets, then scatter the
///    selected elements into the output.
///
/// # Errors
/// - `Error::DTypeMismatch` if `mask` is not `U32`.
/// - `Error::ShapeMismatch` if `mask` cannot be broadcast to `a`'s shape.
/// - Any error propagated from buffer access or a kernel launch.
pub(crate) fn native_masked_select(
client: &WgpuClient,
a: &Tensor<WgpuRuntime>,
mask: &Tensor<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
let dtype = a.dtype();
let numel = a.numel();
if mask.dtype() != DType::U32 {
return Err(Error::DTypeMismatch {
lhs: DType::U32,
rhs: mask.dtype(),
});
}
// Broadcast the mask up to `a`'s shape; on failure report the original
// (un-broadcast) mask shape so the caller sees what they passed in.
let mask_broadcast = mask
.broadcast_to(a.shape())
.map_err(|_| Error::ShapeMismatch {
expected: a.shape().to_vec(),
got: mask.shape().to_vec(),
})?;
// The kernels assume densely packed, row-major inputs.
let a_contig = ensure_contiguous(a);
let mask_contig = ensure_contiguous(&mask_broadcast);
let a_buf = get_tensor_buffer(&a_contig)?;
let mask_buf = get_tensor_buffer(&mask_contig)?;
// Single u32 cell the count kernel atomically accumulates into.
// COPY_SRC: copied to the staging buffer below; COPY_DST: zeroed first.
let count_buffer = client.wgpu_device.create_buffer(&wgpu::BufferDescriptor {
label: Some("masked_count_result"),
size: 4,
usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
// Zero the accumulator before launch.
// NOTE(review): uses `client.queue` directly here but `client.wgpu_queue()`
// for kernel launches — presumably the same queue; confirm and unify.
client.queue.write_buffer(&count_buffer, 0, &[0u8; 4]);
let count_params = MaskedCountParams {
numel: numel as u32,
};
let count_params_buf = create_params_buffer(client, &count_params);
// Pass 1: count set mask bits into `count_buffer`.
// NOTE(review): `dtype` is `a`'s dtype even though this pass reads only the
// mask — verify the launcher actually needs it.
index::launch_masked_count(
client.pipeline_cache(),
client.wgpu_queue(),
&mask_buf,
&count_buffer,
&count_params_buf,
numel,
dtype,
)?;
// Host-visible staging buffer for reading the count back.
let staging_buffer = client.wgpu_device.create_buffer(&wgpu::BufferDescriptor {
label: Some("count_staging"),
size: 4,
usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let mut encoder = client
.wgpu_device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("copy_count"),
});
encoder.copy_buffer_to_buffer(&count_buffer, 0, &staging_buffer, 0, 4);
client.queue.submit(std::iter::once(encoder.finish()));
// Blocking readback: map the staging buffer, wait for the GPU (up to 60 s),
// then surface the mapping result delivered via the channel.
let slice = staging_buffer.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |result| {
sender.send(result).unwrap();
});
let _ = client.wgpu_device.poll(wgpu::PollType::Wait {
submission_index: None,
timeout: Some(std::time::Duration::from_secs(60)),
});
// Panics if mapping failed — a broken device is treated as a bug here.
receiver.recv().unwrap().unwrap();
let count = {
// Scope the mapped range so it is released before the buffer is dropped.
let data = slice.get_mapped_range();
u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize
};
drop(staging_buffer);
// Nothing selected: return an empty 1-D tensor without launching passes 2/3.
if count == 0 {
return Ok(Tensor::empty(&[0], dtype, client.device()));
}
// One u32 offset per input element for the scatter pass.
let prefix_sum_buffer = client.wgpu_device.create_buffer(&wgpu::BufferDescriptor {
label: Some("prefix_sum"),
size: (numel * 4) as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let prefix_params = MaskedCountParams {
numel: numel as u32,
};
let prefix_params_buf = create_params_buffer(client, &prefix_params);
// Pass 2: prefix-sum the mask into output offsets.
index::launch_masked_prefix_sum(
client.pipeline_cache(),
client.wgpu_queue(),
&mask_buf,
&prefix_sum_buffer,
&prefix_params_buf,
numel,
dtype,
)?;
// Pass 3: scatter the selected elements into a tensor of length `count`.
let out = alloc_output(client, &[count], dtype);
let out_buf = get_tensor_buffer(&out)?;
let select_params = MaskedSelectParams {
numel: numel as u32,
};
let select_params_buf = create_params_buffer(client, &select_params);
index::launch_masked_select(
client.pipeline_cache(),
client.wgpu_queue(),
&a_buf,
&mask_buf,
&prefix_sum_buffer,
&out_buf,
&select_params_buf,
numel,
dtype,
)?;
Ok(out)
}