use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::impl_generic::{
DTypeSupport, MultinomialSamplingOps, dirichlet_impl, multinomial_samples_impl,
multivariate_normal_impl, wishart_impl,
};
use crate::ops::traits::multivariate::MultivariateRandomOps;
use crate::ops::{BinaryOps, CumulativeOps, RandomOps, ReduceOps};
use crate::runtime::RuntimeClient;
use crate::runtime::wgpu::shaders::{MultinomialCountParams, launch_multinomial_count};
use crate::runtime::wgpu::{WgpuClient, WgpuRuntime};
use crate::tensor::Tensor;
/// WebGPU implementation of the multivariate random sampling operations.
///
/// Every method forwards to the matching backend-generic routine from
/// `crate::ops::impl_generic`, passing this client as the executor.
impl MultivariateRandomOps<WgpuRuntime> for WgpuClient {
    /// Samples via `multivariate_normal_impl`, forwarding `mean`, `cov` and
    /// `n_samples` unchanged and requesting `DTypeSupport::F32_ONLY`.
    fn multivariate_normal(
        &self,
        mean: &Tensor<WgpuRuntime>,
        cov: &Tensor<WgpuRuntime>,
        n_samples: usize,
    ) -> Result<Tensor<WgpuRuntime>> {
        let dtype_support = DTypeSupport::F32_ONLY;
        multivariate_normal_impl(self, mean, cov, n_samples, dtype_support)
    }

    /// Samples via `wishart_impl`, forwarding `scale`, `df` and `n_samples`
    /// unchanged and requesting `DTypeSupport::F32_ONLY`.
    fn wishart(
        &self,
        scale: &Tensor<WgpuRuntime>,
        df: usize,
        n_samples: usize,
    ) -> Result<Tensor<WgpuRuntime>> {
        let dtype_support = DTypeSupport::F32_ONLY;
        wishart_impl(self, scale, df, n_samples, dtype_support)
    }

    /// Samples via `dirichlet_impl`, forwarding `alpha` and `n_samples` unchanged.
    fn dirichlet(
        &self,
        alpha: &Tensor<WgpuRuntime>,
        n_samples: usize,
    ) -> Result<Tensor<WgpuRuntime>> {
        let client = self;
        dirichlet_impl(client, alpha, n_samples)
    }

    /// Samples via `multinomial_samples_impl`, forwarding `probs`, `n_trials`
    /// and `n_samples` unchanged.
    fn multinomial_samples(
        &self,
        probs: &Tensor<WgpuRuntime>,
        n_trials: usize,
        n_samples: usize,
    ) -> Result<Tensor<WgpuRuntime>> {
        let client = self;
        multinomial_samples_impl(client, probs, n_trials, n_samples)
    }
}
impl MultinomialSamplingOps<WgpuRuntime> for WgpuClient {
fn multinomial_sample_kernel(
&self,
probs: &Tensor<WgpuRuntime>,
n_trials: usize,
n_samples: usize,
) -> Result<Tensor<WgpuRuntime>> {
let dtype = probs.dtype();
let k = probs.shape()[0];
if dtype != DType::F32 {
return Err(Error::UnsupportedDType {
dtype,
op: "multinomial_samples (WebGPU only supports F32)",
});
}
let sum_probs = self.sum(probs, &[0], false)?;
let normalized = self.div(probs, &sum_probs)?;
let cdf = self.cumsum(&normalized, 0)?;
let uniforms = self.rand(&[n_samples, n_trials], dtype)?;
dispatch_multinomial_count_shader(self, &cdf, &uniforms, n_samples, n_trials, k)
}
}
fn dispatch_multinomial_count_shader(
client: &WgpuClient,
cdf: &Tensor<WgpuRuntime>,
uniforms: &Tensor<WgpuRuntime>,
n_samples: usize,
n_trials: usize,
k: usize,
) -> Result<Tensor<WgpuRuntime>> {
use crate::runtime::wgpu::client::get_buffer;
let output = Tensor::<WgpuRuntime>::empty(&[n_samples, k], DType::F32, client.device());
let cdf_buf =
get_buffer(cdf.ptr()).ok_or_else(|| Error::Internal("CDF buffer not found".to_string()))?;
let uniforms_buf = get_buffer(uniforms.ptr())
.ok_or_else(|| Error::Internal("Uniforms buffer not found".to_string()))?;
let output_buf = get_buffer(output.ptr())
.ok_or_else(|| Error::Internal("Output buffer not found".to_string()))?;
let params = MultinomialCountParams {
k: k as u32,
n_trials: n_trials as u32,
n_samples: n_samples as u32,
_pad: 0,
};
let params_buf = client.wgpu_device.create_buffer(&wgpu::BufferDescriptor {
label: Some("multinomial_count_params"),
size: std::mem::size_of::<MultinomialCountParams>() as u64,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
client
.queue
.write_buffer(¶ms_buf, 0, bytemuck::bytes_of(¶ms));
launch_multinomial_count(
client.pipeline_cache(),
&client.queue,
&cdf_buf,
&uniforms_buf,
&output_buf,
¶ms_buf,
n_samples,
DType::F32,
)?;
Ok(output)
}