use crate::expr::{trace, ExprId};
use super::device::GpuDevice;
use super::kernel::KernelCache;
use super::tensor::GpuTensor;
/// Evaluates an element-wise expression over `inputs` and returns the result as a new tensor.
///
/// `f` is traced once. It is handed one [`ExprId`] variable per input, created as
/// `ExprId::var(0)`, `ExprId::var(1)`, ... in input order, and returns the expression to
/// evaluate. The traced graph is simplified, lowered to WGSL with `to_wgsl`, and
/// dispatched with the element count of `inputs[0]`, reading from a single buffer that
/// `interleave_inputs` builds from all inputs.
///
/// The result is allocated with the shape of `inputs[0]`. Only element counts are
/// compared across inputs; shapes themselves are not checked.
///
/// # Panics
///
/// Panics if `inputs` is empty, if any input's element count differs from that of
/// `inputs[0]`, if there are more inputs than fit in a `u16`, or if the element count
/// does not fit in a `u32`.
pub fn map_elementwise(
    device: &GpuDevice,
    cache: &mut KernelCache,
    inputs: &[&GpuTensor],
    f: impl FnOnce(&[ExprId]) -> ExprId,
) -> GpuTensor {
    // Fail with a clear message instead of an index-out-of-bounds at `inputs[0]`.
    assert!(
        !inputs.is_empty(),
        "map_elementwise requires at least one input tensor"
    );
    let n_inputs = inputs.len();
    // Variable ids are u16; a plain `as` cast would silently wrap for huge input lists.
    let var_count = <u16 as std::convert::TryFrom<usize>>::try_from(n_inputs)
        .expect("map_elementwise: number of inputs must fit in a u16 variable index");

    let numel = inputs[0].numel();
    for (i, t) in inputs.iter().enumerate().skip(1) {
        assert_eq!(
            t.numel(),
            numel,
            "input {i} has {} elements, expected {numel}",
            t.numel()
        );
    }
    // The dispatch size is a u32; refuse to truncate rather than process only part of the data.
    let dispatch_count = <u32 as std::convert::TryFrom<usize>>::try_from(numel)
        .expect("map_elementwise: element count must fit in a u32 dispatch size");

    // Record the expression graph produced by the caller's closure.
    let (mut graph, output) = trace(|| {
        let vars: Vec<ExprId> = (0..var_count).map(ExprId::var).collect();
        f(&vars)
    });
    let output = graph.simplify(output);
    let kernel = graph.to_wgsl(&[output], n_inputs);

    let interleaved = interleave_inputs(device, cache, inputs, numel, n_inputs);
    let out_tensor = GpuTensor::uninit(device, inputs[0].shape());
    cache.dispatch(
        device,
        &kernel.source,
        &interleaved,
        &out_tensor.buffer,
        dispatch_count,
    );
    out_tensor
}
/// Evaluates several element-wise expressions over `inputs` in one kernel dispatch and
/// returns one tensor per expression.
///
/// `f` is traced once. It is handed one [`ExprId`] variable per input, created as
/// `ExprId::var(0)`, `ExprId::var(1)`, ... in input order, and must return exactly
/// `n_outputs` expressions. Each expression is simplified, the graph is lowered to WGSL
/// with `to_wgsl`, and the kernel is dispatched with the element count of `inputs[0]`.
///
/// All outputs land in one buffer of `numel * n_outputs` values. That buffer is read
/// back with `to_vec_sync` and split so that value `i * n_outputs + k` becomes element
/// `i` of output tensor `k`; each output tensor is rebuilt with `GpuTensor::from_slice`
/// using the shape of `inputs[0]`.
///
/// # Panics
///
/// Panics if `inputs` is empty, if any input's element count differs from that of
/// `inputs[0]`, if `f` does not return `n_outputs` expressions, if there are more inputs
/// than fit in a `u16`, or if the element count does not fit in a `u32`.
pub fn map_elementwise_multi(
    device: &GpuDevice,
    cache: &mut KernelCache,
    inputs: &[&GpuTensor],
    n_outputs: usize,
    f: impl FnOnce(&[ExprId]) -> Vec<ExprId>,
) -> Vec<GpuTensor> {
    // Fail with a clear message instead of an index-out-of-bounds at `inputs[0]`.
    assert!(
        !inputs.is_empty(),
        "map_elementwise_multi requires at least one input tensor"
    );
    let n_inputs = inputs.len();
    // Variable ids are u16; a plain `as` cast would silently wrap for huge input lists.
    let var_count = <u16 as std::convert::TryFrom<usize>>::try_from(n_inputs)
        .expect("map_elementwise_multi: number of inputs must fit in a u16 variable index");

    // Same precondition as `map_elementwise`: every input must supply one value per element,
    // otherwise the interleave step would read past the end of the shorter inputs.
    let numel = inputs[0].numel();
    for (i, t) in inputs.iter().enumerate().skip(1) {
        assert_eq!(
            t.numel(),
            numel,
            "input {i} has {} elements, expected {numel}",
            t.numel()
        );
    }
    // The dispatch size is a u32; refuse to truncate rather than process only part of the data.
    let dispatch_count = <u32 as std::convert::TryFrom<usize>>::try_from(numel)
        .expect("map_elementwise_multi: element count must fit in a u32 dispatch size");

    // Record the expression graph produced by the caller's closure.
    let (mut graph, outputs) = trace(|| {
        let vars: Vec<ExprId> = (0..var_count).map(ExprId::var).collect();
        f(&vars)
    });
    assert_eq!(
        outputs.len(),
        n_outputs,
        "closure returned {} outputs, expected {n_outputs}",
        outputs.len()
    );
    let outputs: Vec<ExprId> = outputs.into_iter().map(|o| graph.simplify(o)).collect();
    let kernel = graph.to_wgsl(&outputs, n_inputs);

    let interleaved = interleave_inputs(device, cache, inputs, numel, n_inputs);
    let out_buf = super::buffer::GpuBuffer::uninit(device, numel * n_outputs);
    cache.dispatch(
        device,
        &kernel.source,
        &interleaved,
        &out_buf,
        dispatch_count,
    );

    // Read the combined buffer back and pull out every `n_outputs`-th value per output.
    let all_data = out_buf.to_vec_sync(device);
    let shape = inputs[0].shape();
    (0..n_outputs)
        .map(|k| {
            let data: Vec<f32> = (0..numel).map(|i| all_data[i * n_outputs + k]).collect();
            GpuTensor::from_slice(device, &data, shape)
        })
        .collect()
}
/// Packs `inputs` into a single buffer for an element-wise kernel.
///
/// Three strategies are used, chosen by `n_inputs`:
///
/// * `1` — returns `clone_gpu_batched` of the only input's buffer (presumably a
///   GPU-side copy; confirm against `GpuBuffer`).
/// * `2` — runs a small WGSL kernel that writes `a[i]` to `output[2 * i]` and `b[i]`
///   to `output[2 * i + 1]`.
/// * anything else — reads every input back with `to_vec_sync`, interleaves on the
///   CPU so that element `i` of every input is stored contiguously in input order,
///   and builds the buffer from that with `GpuBuffer::from_slice`.
///
/// `numel` is the per-input element count and `n_inputs` is expected to equal
/// `inputs.len()`; neither is validated here.
fn interleave_inputs(
    device: &GpuDevice,
    cache: &mut KernelCache,
    inputs: &[&GpuTensor],
    numel: usize,
    n_inputs: usize,
) -> super::buffer::GpuBuffer {
    // Shader for the two-input case. The string content is kept flush-left because it
    // is a raw literal: any leading whitespace would become part of the source text.
    const INTERLEAVE_PAIR_WGSL: &str = r#"// Interleave 2 inputs: output[2*i+0] = a[i], output[2*i+1] = b[i]
struct Params {
count: u32,
_pad1: u32,
_pad2: u32,
_pad3: u32,
}
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if (idx >= params.count) { return; }
output[idx * 2u + 0u] = a[idx];
output[idx * 2u + 1u] = b[idx];
}
"#;

    match n_inputs {
        // Single input: nothing to interleave.
        1 => inputs[0].buffer.clone_gpu_batched(device, cache),

        // Two inputs: interleave with the kernel above.
        2 => {
            let packed = super::buffer::GpuBuffer::uninit(device, numel * 2);
            // The four u32 values line up with the shader's `Params` fields:
            // `count` followed by three padding words.
            cache.dispatch_rr_w(
                device,
                INTERLEAVE_PAIR_WGSL,
                &inputs[0].buffer,
                &inputs[1].buffer,
                &packed,
                &[numel as u32, 0, 0, 0],
            );
            packed
        }

        // Any other count: read everything back and interleave on the CPU.
        _ => {
            let columns: Vec<Vec<f32>> = inputs
                .iter()
                .map(|tensor| tensor.buffer.to_vec_sync(device))
                .collect();
            let mut packed = Vec::with_capacity(numel * n_inputs);
            for elem in 0..numel {
                // Element `elem` of every input, in input order.
                packed.extend(columns.iter().map(|column| column[elem]));
            }
            super::buffer::GpuBuffer::from_slice(device, &packed)
        }
    }
}