use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::RandomOps;
use crate::runtime::RuntimeClient;
use crate::runtime::wgpu::WgpuClient;
use crate::runtime::wgpu::WgpuRuntime;
use crate::runtime::wgpu::ops::helpers::{
BernoulliParams, BetaDistParams, BinomialParams, ChiSquaredParams, ExponentialParams,
FDistributionParams, GammaDistParams, LaplaceParams, MultinomialWithReplacementParams,
MultinomialWithoutReplacementParams, PoissonParams, RandParams, RandintParamsI32,
RandintParamsU32, RandnParams, StudentTParams, alloc_output, create_params_buffer,
generate_wgpu_seed, get_tensor_buffer,
};
use crate::runtime::wgpu::shaders::{distributions, shape};
use crate::tensor::Tensor;
impl RandomOps<WgpuRuntime> for WgpuClient {
fn rand(&self, shape: &[usize], dtype: DType) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType { dtype, op: "rand" });
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
use std::sync::atomic::{AtomicU32, Ordering};
static SEED_COUNTER: AtomicU32 = AtomicU32::new(0);
let counter = SEED_COUNTER.fetch_add(1, Ordering::Relaxed);
let time_seed = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos() as u32)
.unwrap_or(12345u32);
let seed = time_seed.wrapping_add(counter);
let params = RandParams {
numel: numel as u32,
seed,
_pad1: 0,
_pad2: 0,
};
let params_buf = create_params_buffer(self, ¶ms);
shape::launch_rand(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
fn rand_seeded(&self, shape: &[usize], dtype: DType, seed: u64) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType {
dtype,
op: "rand_seeded",
});
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
let seed = seed as u32;
let params = RandParams {
numel: numel as u32,
seed,
_pad1: 0,
_pad2: 0,
};
let params_buf = create_params_buffer(self, ¶ms);
shape::launch_rand(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
fn randn(&self, shape: &[usize], dtype: DType) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType { dtype, op: "randn" });
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
use std::sync::atomic::{AtomicU32, Ordering};
static SEED_COUNTER: AtomicU32 = AtomicU32::new(0);
let counter = SEED_COUNTER.fetch_add(1, Ordering::Relaxed);
let time_seed = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos() as u32)
.unwrap_or(12345u32);
let seed = time_seed.wrapping_add(counter);
let params = RandnParams {
numel: numel as u32,
seed,
_pad1: 0,
_pad2: 0,
};
let params_buf = create_params_buffer(self, ¶ms);
shape::launch_randn(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
fn randint(
&self,
low: i64,
high: i64,
shape: &[usize],
dtype: DType,
) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::I32 | DType::U32) {
return Err(Error::UnsupportedDType {
dtype,
op: "randint",
});
}
if high <= low {
return Err(Error::InvalidArgument {
arg: "high",
reason: format!(
"randint requires high > low, got low={}, high={}",
low, high
),
});
}
if matches!(dtype, DType::U32) && low < 0 {
return Err(Error::InvalidArgument {
arg: "low",
reason: format!(
"randint with unsigned dtype requires low >= 0, got low={}",
low
),
});
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
use std::sync::atomic::{AtomicU32, Ordering};
static SEED_COUNTER: AtomicU32 = AtomicU32::new(0);
let counter = SEED_COUNTER.fetch_add(1, Ordering::Relaxed);
let time_seed = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos() as u32)
.unwrap_or(12345u32);
let seed = time_seed.wrapping_add(counter);
let range = (high - low) as u32;
let params_buf = match dtype {
DType::I32 => {
let params = RandintParamsI32 {
numel: numel as u32,
low: low as i32, range,
seed,
};
create_params_buffer(self, ¶ms)
}
DType::U32 => {
let params = RandintParamsU32 {
numel: numel as u32,
low: low as u32,
range,
seed,
};
create_params_buffer(self, ¶ms)
}
_ => unreachable!("randint only supports I32 and U32, validated above"),
};
shape::launch_randint(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
fn multinomial(
&self,
probs: &Tensor<WgpuRuntime>,
num_samples: usize,
replacement: bool,
) -> Result<Tensor<WgpuRuntime>> {
let dtype = probs.dtype();
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType {
dtype,
op: "multinomial (WebGPU only supports F32 probabilities)",
});
}
if num_samples == 0 {
return Err(Error::InvalidArgument {
arg: "num_samples",
reason: "multinomial requires num_samples > 0".to_string(),
});
}
let shape = probs.shape();
let ndim = shape.len();
let (num_distributions, num_categories) = match ndim {
1 => (1usize, shape[0]),
2 => (shape[0], shape[1]),
_ => {
return Err(Error::InvalidArgument {
arg: "probs",
reason: format!(
"multinomial requires 1D or 2D probability tensor, got {}D",
ndim
),
});
}
};
if num_categories == 0 {
return Err(Error::InvalidArgument {
arg: "probs",
reason: "multinomial requires at least 1 category".to_string(),
});
}
if !replacement && num_samples > num_categories {
return Err(Error::InvalidArgument {
arg: "num_samples",
reason: format!(
"multinomial without replacement: num_samples ({}) cannot exceed num_categories ({})",
num_samples, num_categories
),
});
}
const MAX_CATEGORIES_WITHOUT_REPLACEMENT: usize = 1024;
if !replacement && num_categories > MAX_CATEGORIES_WITHOUT_REPLACEMENT {
return Err(Error::backend_limitation(
"WebGPU",
"multinomial",
format!(
"without replacement supports max {} categories, got {}",
MAX_CATEGORIES_WITHOUT_REPLACEMENT, num_categories
),
));
}
let probs_contig = if probs.is_contiguous() {
probs.clone()
} else {
probs.contiguous()
};
let out_dtype = DType::I32;
let out_shape = if ndim == 1 {
vec![num_samples]
} else {
vec![num_distributions, num_samples]
};
let out = alloc_output(self, &out_shape, out_dtype);
let out_buf = get_tensor_buffer(&out)?;
let probs_buf = get_tensor_buffer(&probs_contig)?;
use std::sync::atomic::{AtomicU32, Ordering};
static SEED_COUNTER: AtomicU32 = AtomicU32::new(0);
let counter = SEED_COUNTER.fetch_add(1, Ordering::Relaxed);
let time_seed = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos() as u32)
.unwrap_or(54321u32);
let seed = time_seed.wrapping_add(counter);
if replacement {
let params = MultinomialWithReplacementParams {
num_distributions: num_distributions as u32,
num_categories: num_categories as u32,
num_samples: num_samples as u32,
seed,
};
let params_buf = create_params_buffer(self, ¶ms);
let total_samples = num_distributions * num_samples;
shape::launch_multinomial_with_replacement(
self.pipeline_cache(),
self.wgpu_queue(),
&probs_buf,
&out_buf,
¶ms_buf,
total_samples,
dtype,
)?;
} else {
let params = MultinomialWithoutReplacementParams {
num_distributions: num_distributions as u32,
num_categories: num_categories as u32,
num_samples: num_samples as u32,
seed,
};
let params_buf = create_params_buffer(self, ¶ms);
shape::launch_multinomial_without_replacement(
self.pipeline_cache(),
self.wgpu_queue(),
&probs_buf,
&out_buf,
¶ms_buf,
num_distributions,
dtype,
)?;
}
Ok(out)
}
fn bernoulli(&self, p: f64, shape: &[usize], dtype: DType) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType {
dtype,
op: "bernoulli (WebGPU only supports F32)",
});
}
if !(0.0..=1.0).contains(&p) {
return Err(Error::InvalidArgument {
arg: "p",
reason: format!("bernoulli requires p in [0, 1], got {}", p),
});
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
let seed = generate_wgpu_seed();
let params = BernoulliParams {
numel: numel as u32,
seed,
p: p as f32,
_pad: 0,
};
let params_buf = create_params_buffer(self, ¶ms);
distributions::launch_bernoulli(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
fn beta(
&self,
alpha: f64,
beta: f64,
shape: &[usize],
dtype: DType,
) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType {
dtype,
op: "beta (WebGPU only supports F32)",
});
}
if alpha <= 0.0 {
return Err(Error::InvalidArgument {
arg: "alpha",
reason: format!("beta requires alpha > 0, got {}", alpha),
});
}
if beta <= 0.0 {
return Err(Error::InvalidArgument {
arg: "beta",
reason: format!("beta requires beta > 0, got {}", beta),
});
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
let seed = generate_wgpu_seed();
let params = BetaDistParams {
numel: numel as u32,
seed,
alpha: alpha as f32,
beta: beta as f32,
};
let params_buf = create_params_buffer(self, ¶ms);
distributions::launch_beta_dist(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
fn gamma(
&self,
shape_param: f64,
scale: f64,
shape: &[usize],
dtype: DType,
) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType {
dtype,
op: "gamma (WebGPU only supports F32)",
});
}
if shape_param <= 0.0 {
return Err(Error::InvalidArgument {
arg: "shape_param",
reason: format!("gamma requires shape_param > 0, got {}", shape_param),
});
}
if scale <= 0.0 {
return Err(Error::InvalidArgument {
arg: "scale",
reason: format!("gamma requires scale > 0, got {}", scale),
});
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
let seed = generate_wgpu_seed();
let params = GammaDistParams {
numel: numel as u32,
seed,
shape: shape_param as f32,
scale: scale as f32,
};
let params_buf = create_params_buffer(self, ¶ms);
distributions::launch_gamma_dist(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
fn exponential(&self, rate: f64, shape: &[usize], dtype: DType) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType {
dtype,
op: "exponential (WebGPU only supports F32)",
});
}
if rate <= 0.0 {
return Err(Error::InvalidArgument {
arg: "rate",
reason: format!("exponential requires rate > 0, got {}", rate),
});
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
let seed = generate_wgpu_seed();
let params = ExponentialParams {
numel: numel as u32,
seed,
rate: rate as f32,
_pad: 0,
};
let params_buf = create_params_buffer(self, ¶ms);
distributions::launch_exponential(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
fn poisson(&self, lambda: f64, shape: &[usize], dtype: DType) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType {
dtype,
op: "poisson (WebGPU only supports F32)",
});
}
if lambda <= 0.0 {
return Err(Error::InvalidArgument {
arg: "lambda",
reason: format!("poisson requires lambda > 0, got {}", lambda),
});
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
let seed = generate_wgpu_seed();
let params = PoissonParams {
numel: numel as u32,
seed,
lambda: lambda as f32,
_pad: 0,
};
let params_buf = create_params_buffer(self, ¶ms);
distributions::launch_poisson(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
fn binomial(
&self,
n: u64,
p: f64,
shape: &[usize],
dtype: DType,
) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType {
dtype,
op: "binomial (WebGPU only supports F32)",
});
}
if n == 0 {
return Err(Error::InvalidArgument {
arg: "n",
reason: "binomial requires n > 0".to_string(),
});
}
if !(0.0..=1.0).contains(&p) {
return Err(Error::InvalidArgument {
arg: "p",
reason: format!("binomial requires p in [0, 1], got {}", p),
});
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
let seed = generate_wgpu_seed();
let params = BinomialParams {
numel: numel as u32,
seed,
n_trials: n as u32, p: p as f32,
};
let params_buf = create_params_buffer(self, ¶ms);
distributions::launch_binomial(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
fn laplace(
&self,
loc: f64,
scale: f64,
shape: &[usize],
dtype: DType,
) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType {
dtype,
op: "laplace (WebGPU only supports F32)",
});
}
if scale <= 0.0 {
return Err(Error::InvalidArgument {
arg: "scale",
reason: format!("laplace requires scale > 0, got {}", scale),
});
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
let seed = generate_wgpu_seed();
let params = LaplaceParams {
numel: numel as u32,
seed,
loc: loc as f32,
scale: scale as f32,
};
let params_buf = create_params_buffer(self, ¶ms);
distributions::launch_laplace(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
fn chi_squared(&self, df: f64, shape: &[usize], dtype: DType) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType {
dtype,
op: "chi_squared (WebGPU only supports F32)",
});
}
if df <= 0.0 {
return Err(Error::InvalidArgument {
arg: "df",
reason: format!("chi_squared requires df > 0, got {}", df),
});
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
let seed = generate_wgpu_seed();
let params = ChiSquaredParams {
numel: numel as u32,
seed,
df: df as f32,
_pad: 0,
};
let params_buf = create_params_buffer(self, ¶ms);
distributions::launch_chi_squared(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
fn student_t(&self, df: f64, shape: &[usize], dtype: DType) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType {
dtype,
op: "student_t (WebGPU only supports F32)",
});
}
if df <= 0.0 {
return Err(Error::InvalidArgument {
arg: "df",
reason: format!("student_t requires df > 0, got {}", df),
});
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
let seed = generate_wgpu_seed();
let params = StudentTParams {
numel: numel as u32,
seed,
df: df as f32,
_pad: 0,
};
let params_buf = create_params_buffer(self, ¶ms);
distributions::launch_student_t(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
fn randperm(&self, n: usize) -> Result<Tensor<WgpuRuntime>> {
crate::ops::impl_generic::randperm_impl(self, n)
}
fn f_distribution(
&self,
df1: f64,
df2: f64,
shape: &[usize],
dtype: DType,
) -> Result<Tensor<WgpuRuntime>> {
if !matches!(dtype, DType::F32) {
return Err(Error::UnsupportedDType {
dtype,
op: "f_distribution (WebGPU only supports F32)",
});
}
if df1 <= 0.0 {
return Err(Error::InvalidArgument {
arg: "df1",
reason: format!("f_distribution requires df1 > 0, got {}", df1),
});
}
if df2 <= 0.0 {
return Err(Error::InvalidArgument {
arg: "df2",
reason: format!("f_distribution requires df2 > 0, got {}", df2),
});
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(Tensor::empty(shape, dtype, self.device()));
}
let out = alloc_output(self, shape, dtype);
let out_buf = get_tensor_buffer(&out)?;
let seed = generate_wgpu_seed();
let params = FDistributionParams {
numel: numel as u32,
seed,
df1: df1 as f32,
df2: df2 as f32,
};
let params_buf = create_params_buffer(self, ¶ms);
distributions::launch_f_distribution(
self.pipeline_cache(),
self.wgpu_queue(),
&out_buf,
¶ms_buf,
numel,
dtype,
)?;
Ok(out)
}
}