use std::sync::Arc;
use cudarc::curand::result::{LogNormalFill, NormalFill, UniformFill};
use cudarc::curand::sys;
use tokio::sync::oneshot;
use crate::completion::CompletionStrategy;
use crate::dtype::RngFloatSupported;
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::kernel::dispatch::RngDispatch;
use crate::kernel::envelope;
use super::LIB;
/// Probability distribution requested for a random fill of a buffer of `T`.
///
/// Parameters are expressed in `T::Scalar`, with two exceptions: `Poisson::lambda` is always
/// an `f64`, and `Discrete::weights` is a `GpuRef<f32>` buffer.
///
/// In this file's float fill path only `Uniform` (with bounds exactly `0` and `1`), `Normal`
/// and `LogNormal` are forwarded to cuRAND; every other variant is answered with a
/// `GpuError::LibraryError`.
pub enum Distribution<T: RngFloatSupported> {
    /// Uniform distribution between `lo` and `hi`.
    Uniform {
        /// Lower bound.
        lo: T::Scalar,
        /// Upper bound.
        hi: T::Scalar,
    },
    /// Normal distribution; `mean` and `std` are passed to the cuRAND normal fill.
    Normal {
        mean: T::Scalar,
        std: T::Scalar,
    },
    /// Log-normal distribution; `mean` and `std` are passed to the cuRAND log-normal fill.
    LogNormal {
        mean: T::Scalar,
        std: T::Scalar,
    },
    /// Poisson distribution with rate `lambda` (not wired for float buffers).
    Poisson {
        lambda: f64,
    },
    /// Exponential distribution parameterised by `lambda` (not wired yet).
    Exponential {
        lambda: T::Scalar,
    },
    /// Beta distribution parameterised by `alpha` and `beta` (not wired yet).
    Beta {
        alpha: T::Scalar,
        beta: T::Scalar,
    },
    /// Cauchy distribution parameterised by `loc` and `scale` (not wired yet).
    Cauchy {
        loc: T::Scalar,
        scale: T::Scalar,
    },
    /// Gamma distribution parameterised by `shape` and `scale` (not wired yet).
    Gamma {
        shape: T::Scalar,
        scale: T::Scalar,
    },
    /// Discrete distribution described by a buffer of `f32` weights (not wired yet).
    Discrete {
        weights: GpuRef<f32>,
    },
}
/// A request to fill `buf` with random values drawn from `dist`.
///
/// The outcome of the request is reported on `reply`.
pub struct FillRequest<T: RngFloatSupported> {
    /// Destination buffer to be filled.
    pub buf: GpuRef<T>,
    /// Distribution the values are drawn from.
    pub dist: Distribution<T>,
    /// One-shot channel that receives `Ok(())` or the error describing why the fill failed.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
impl RngDispatch for FillRequest<f32> {
    /// Unboxes the request and forwards it to the generic float fill path for `f32`.
    fn fill(
        self: Box<Self>,
        gen: sys::curandGenerator_t,
        stream: &Arc<cudarc::driver::CudaStream>,
        completion: &Arc<dyn CompletionStrategy>,
    ) -> Result<(), GpuError> {
        let request: FillRequest<f32> = *self;
        fill_float(request, gen, stream, completion)
    }
}
impl RngDispatch for FillRequest<f64> {
    /// Unboxes the request and forwards it to the generic float fill path for `f64`.
    fn fill(
        self: Box<Self>,
        gen: sys::curandGenerator_t,
        stream: &Arc<cudarc::driver::CudaStream>,
        completion: &Arc<dyn CompletionStrategy>,
    ) -> Result<(), GpuError> {
        let request: FillRequest<f64> = *self;
        fill_float(request, gen, stream, completion)
    }
}
/// Routes a float fill request to the matching cuRAND enqueue helper.
///
/// `Uniform`, `Normal` and `LogNormal` are handed to their `enqueue_*` helper, whose result is
/// returned unchanged. Every other distribution is not wired yet: the destination buffer is
/// released, a `GpuError::LibraryError` explaining why is sent on the request's reply channel,
/// and `Ok(())` is returned.
fn fill_float<T>(
    req: FillRequest<T>,
    gen: sys::curandGenerator_t,
    stream: &Arc<cudarc::driver::CudaStream>,
    completion: &Arc<dyn CompletionStrategy>,
) -> Result<(), GpuError>
where
    T: RngFloatSupported,
    sys::curandGenerator_t: UniformFill<T> + NormalFill<T> + LogNormalFill<T>,
    T::Scalar: Into<f64> + Copy,
    T: NormalParam<T::Scalar>,
{
    let FillRequest { buf, dist, reply } = req;

    // Wired distributions return straight from their arm; the remaining arms only pick the
    // explanation that is reported back to the requester below.
    let not_wired: &'static str = match dist {
        Distribution::Uniform { lo, hi } => {
            return enqueue_uniform::<T>(gen, stream, completion, buf, reply, lo, hi);
        }
        Distribution::Normal { mean, std } => {
            return enqueue_normal::<T>(gen, stream, completion, buf, mean, std, reply);
        }
        Distribution::LogNormal { mean, std } => {
            return enqueue_log_normal::<T>(gen, stream, completion, buf, mean, std, reply);
        }
        Distribution::Poisson { lambda: _ } => {
            "Poisson<T> not yet wired for floats (Phase 1: use FillRequest<u32> + Poisson)"
        }
        Distribution::Exponential { .. }
        | Distribution::Beta { .. }
        | Distribution::Cauchy { .. }
        | Distribution::Gamma { .. }
        | Distribution::Discrete { .. } => {
            "distribution not yet wired (Phase 1: needs custom kernel / NVRTC)"
        }
    };

    // The buffer handle is released before the requester is notified.
    drop(buf);
    // A closed receiver is not an error here: there is nobody left to inform.
    let _ = reply.send(Err(GpuError::LibraryError {
        lib: LIB,
        msg: not_wired.into(),
    }));
    Ok(())
}
/// Enqueues a cuRAND uniform fill of `dst` on `stream`.
///
/// Only the bounds `lo == 0` and `hi == 1` are accepted. Any other pair (including NaN) is
/// rejected with a `GpuError::LibraryError`, because rescaling would need an affine-transform
/// kernel that is not wired yet.
///
/// The bounds are validated *before* `dst` is accessed, so a rejected request never takes
/// ownership of the destination buffer and is always answered with the bounds error rather than
/// an unrelated buffer-access error.
///
/// # Parameters
/// - `gen`: cuRAND generator handle used for the fill.
/// - `stream`, `completion`: forwarded to `envelope::run_kernel`.
/// - `dst`: destination buffer; it must be accessible and have no other live references.
/// - `reply`: receives the outcome; every failure is reported here.
/// - `lo`, `hi`: requested bounds.
///
/// # Returns
/// Always `Ok(())`; errors are delivered through `reply`, never through the return value.
fn enqueue_uniform<T>(
    gen: sys::curandGenerator_t,
    stream: &Arc<cudarc::driver::CudaStream>,
    completion: &Arc<dyn CompletionStrategy>,
    dst: GpuRef<T>,
    reply: oneshot::Sender<Result<(), GpuError>>,
    lo: T::Scalar,
    hi: T::Scalar,
) -> Result<(), GpuError>
where
    T: RngFloatSupported,
    T::Scalar: Into<f64> + Copy,
    sys::curandGenerator_t: UniformFill<T>,
{
    let lo_f: f64 = lo.into();
    let hi_f: f64 = hi.into();
    let trivial = lo_f == 0.0 && hi_f == 1.0;

    // Reject unsupported bounds first: this check is cheap and must not have any effect on
    // the destination buffer.
    if !trivial {
        let _ = reply.send(Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!(
                "Uniform({lo_f},{hi_f}): non-(0,1] bounds need an affine transform kernel (Phase 1: not wired)"
            ),
        }));
        return Ok(());
    }

    let dst_arc = match dst.access() {
        Ok(s) => s.clone(),
        Err(e) => {
            let _ = reply.send(Err(e));
            return Ok(());
        }
    };
    // The fill needs `&mut` access to the slice, so the request is refused when the slice is
    // still shared.
    let mut owned = match Arc::try_unwrap(dst_arc) {
        Ok(s) => s,
        Err(_) => {
            let _ = reply.send(Err(GpuError::Unrecoverable(
                "RNG dst has multiple live references".into(),
            )));
            return Ok(());
        }
    };

    dst.record_write(stream);
    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
        let n = owned.len();
        // SAFETY: `owned` is held exclusively by this closure and `n` is its own length, so
        // the generator is asked to write exactly the elements of this slice. `_rec` stays
        // bound until the end of the block, i.e. past the cuRAND call.
        let res = unsafe {
            let (ptr, _rec) = cudarc::driver::DevicePtrMut::device_ptr_mut(&mut owned, stream);
            UniformFill::fill(gen, ptr as *mut T, n)
        };
        // On success the slice is returned to `run_kernel` as part of the closure's result.
        res.map(|_| (owned,)).map_err(|e| GpuError::LibraryError {
            lib: LIB,
            msg: format!("fill_uniform: {e}"),
        })
    });
    Ok(())
}
/// Enqueues a cuRAND normal fill of `dst` on `stream`, using `mean` and `std` as parameters.
///
/// `dst` must be accessible and have no other live references; otherwise the corresponding
/// error is sent on `reply` and nothing is enqueued. The function itself always returns
/// `Ok(())`: every failure is reported through `reply`.
fn enqueue_normal<T>(
    gen: sys::curandGenerator_t,
    stream: &Arc<cudarc::driver::CudaStream>,
    completion: &Arc<dyn CompletionStrategy>,
    dst: GpuRef<T>,
    mean: T::Scalar,
    std: T::Scalar,
    reply: oneshot::Sender<Result<(), GpuError>>,
) -> Result<(), GpuError>
where
    T: RngFloatSupported + NormalParam<T::Scalar>,
    sys::curandGenerator_t: NormalFill<T>,
{
    let shared = match dst.access() {
        Err(err) => {
            let _ = reply.send(Err(err));
            return Ok(());
        }
        Ok(slice) => slice.clone(),
    };

    // The fill needs `&mut` access, so a slice that is still shared is refused.
    let mut exclusive = match Arc::try_unwrap(shared) {
        Err(_) => {
            let _ = reply.send(Err(GpuError::Unrecoverable(
                "RNG dst has multiple live references".into(),
            )));
            return Ok(());
        }
        Ok(slice) => slice,
    };

    // Convert the scalar parameters into the element type expected by the cuRAND call.
    let (mu, sigma) = (T::from_scalar(mean), T::from_scalar(std));

    dst.record_write(stream);
    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
        let count = exclusive.len();
        // SAFETY: `exclusive` is held solely by this closure and `count` is its own length,
        // so the generator is asked to write exactly the elements of this slice. `_guard`
        // stays bound until the end of the block, i.e. past the cuRAND call.
        let status = unsafe {
            let (dev_ptr, _guard) =
                cudarc::driver::DevicePtrMut::device_ptr_mut(&mut exclusive, stream);
            NormalFill::fill(gen, dev_ptr as *mut T, count, mu, sigma)
        };
        match status {
            // On success the slice is handed back as part of the closure's result.
            Ok(_) => Ok((exclusive,)),
            Err(e) => Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("fill_normal: {e}"),
            }),
        }
    });
    Ok(())
}
/// Enqueues a cuRAND log-normal fill of `dst` on `stream`, using `mean` and `std` as parameters.
///
/// `dst` must be accessible and have no other live references; otherwise the corresponding
/// error is sent on `reply` and nothing is enqueued. The function itself always returns
/// `Ok(())`: every failure is reported through `reply`.
fn enqueue_log_normal<T>(
    gen: sys::curandGenerator_t,
    stream: &Arc<cudarc::driver::CudaStream>,
    completion: &Arc<dyn CompletionStrategy>,
    dst: GpuRef<T>,
    mean: T::Scalar,
    std: T::Scalar,
    reply: oneshot::Sender<Result<(), GpuError>>,
) -> Result<(), GpuError>
where
    T: RngFloatSupported + NormalParam<T::Scalar>,
    sys::curandGenerator_t: LogNormalFill<T>,
{
    let shared = match dst.access() {
        Err(err) => {
            let _ = reply.send(Err(err));
            return Ok(());
        }
        Ok(slice) => slice.clone(),
    };

    // The fill needs `&mut` access, so a slice that is still shared is refused.
    let mut exclusive = match Arc::try_unwrap(shared) {
        Err(_) => {
            let _ = reply.send(Err(GpuError::Unrecoverable(
                "RNG dst has multiple live references".into(),
            )));
            return Ok(());
        }
        Ok(slice) => slice,
    };

    // Convert the scalar parameters into the element type expected by the cuRAND call.
    let (mu, sigma) = (T::from_scalar(mean), T::from_scalar(std));

    dst.record_write(stream);
    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
        let count = exclusive.len();
        // SAFETY: `exclusive` is held solely by this closure and `count` is its own length,
        // so the generator is asked to write exactly the elements of this slice. `_guard`
        // stays bound until the end of the block, i.e. past the cuRAND call.
        let status = unsafe {
            let (dev_ptr, _guard) =
                cudarc::driver::DevicePtrMut::device_ptr_mut(&mut exclusive, stream);
            LogNormalFill::fill(gen, dev_ptr as *mut T, count, mu, sigma)
        };
        match status {
            // On success the slice is handed back as part of the closure's result.
            Ok(_) => Ok((exclusive,)),
            Err(e) => Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("fill_log_normal: {e}"),
            }),
        }
    });
    Ok(())
}
/// Conversion from a scalar parameter of type `S` into the implementing type.
///
/// The normal and log-normal enqueue helpers use it to turn `mean` / `std` into the element
/// type handed to the cuRAND fill call.
pub trait NormalParam<S>: Sized {
    /// Builds `Self` from the scalar `scalar`.
    fn from_scalar(scalar: S) -> Self;
}
impl NormalParam<f32> for f32 {
    /// Identity conversion: an `f32` parameter is already the element type.
    fn from_scalar(value: f32) -> Self {
        value
    }
}
impl NormalParam<f64> for f64 {
    /// Identity conversion: an `f64` parameter is already the element type.
    fn from_scalar(value: f64) -> Self {
        value
    }
}
/// Enqueues a cuRAND uniform fill of the `u32` buffer `dst` on `stream`.
///
/// `dst` must be accessible and have no other live references; otherwise the corresponding
/// error is sent on `reply` and nothing is enqueued. All outcomes are reported through `reply`.
pub(crate) fn fill_uniform_u32(
    gen: sys::curandGenerator_t,
    stream: &Arc<cudarc::driver::CudaStream>,
    completion: &Arc<dyn CompletionStrategy>,
    dst: GpuRef<u32>,
    reply: oneshot::Sender<Result<(), GpuError>>,
) {
    let shared = match dst.access() {
        Err(err) => {
            let _ = reply.send(Err(err));
            return;
        }
        Ok(slice) => slice.clone(),
    };

    // The fill needs `&mut` access, so a slice that is still shared is refused.
    let mut exclusive = match Arc::try_unwrap(shared) {
        Err(_) => {
            let _ = reply.send(Err(GpuError::Unrecoverable(
                "RNG dst has multiple live references".into(),
            )));
            return;
        }
        Ok(slice) => slice,
    };

    dst.record_write(stream);
    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
        let count = exclusive.len();
        // SAFETY: `exclusive` is held solely by this closure and `count` is its own length,
        // so the generator is asked to write exactly the elements of this slice. `_guard`
        // stays bound until the end of the block, i.e. past the cuRAND call.
        let status = unsafe {
            let (dev_ptr, _guard) =
                cudarc::driver::DevicePtrMut::device_ptr_mut(&mut exclusive, stream);
            UniformFill::fill(gen, dev_ptr as *mut u32, count)
        };
        match status {
            // On success the slice is handed back as part of the closure's result.
            Ok(_) => Ok((exclusive,)),
            Err(e) => Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("fill_uniform_u32: {e}"),
            }),
        }
    });
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-level check that the scalar-parameterised variants can be constructed for both
    /// `f32` and `f64`. `Discrete` is not constructed here (it carries a `GpuRef<f32>`).
    #[test]
    fn distribution_round_trip_f32_f64() {
        // Builds every covered variant once for the given element type.
        macro_rules! construct_all {
            ($elem:ty) => {{
                let _: Distribution<$elem> = Distribution::Uniform { lo: 0.0, hi: 1.0 };
                let _: Distribution<$elem> = Distribution::Normal {
                    mean: 0.0,
                    std: 1.0,
                };
                let _: Distribution<$elem> = Distribution::LogNormal {
                    mean: 0.0,
                    std: 1.0,
                };
                let _: Distribution<$elem> = Distribution::Poisson { lambda: 1.0 };
                let _: Distribution<$elem> = Distribution::Exponential { lambda: 1.0 };
                let _: Distribution<$elem> = Distribution::Beta {
                    alpha: 1.0,
                    beta: 1.0,
                };
                let _: Distribution<$elem> = Distribution::Cauchy {
                    loc: 0.0,
                    scale: 1.0,
                };
                let _: Distribution<$elem> = Distribution::Gamma {
                    shape: 1.0,
                    scale: 1.0,
                };
            }};
        }

        construct_all!(f32);
        construct_all!(f64);
    }

    /// Type-level check that `RngMsg::FillUniformF32 { dst, reply }` can still be built from a
    /// `GpuRef<f32>` and a reply sender. The constructor closure is never invoked.
    #[test]
    // NOTE(review): presumably `RngMsg::FillUniformF32` is marked deprecated — confirm.
    #[allow(deprecated)]
    fn deprecated_fill_uniform_f32_still_works() {
        // Accepts any closure with the expected constructor shape; the body is intentionally empty.
        fn _assert<F>(_f: F)
        where
            F: FnOnce(GpuRef<f32>, oneshot::Sender<Result<(), GpuError>>) -> super::super::RngMsg,
        {
        }

        let (tx, _rx) = tokio::sync::oneshot::channel::<Result<(), GpuError>>();
        // The sender is wrapped in `ManuallyDrop`, so its destructor never runs.
        let _ = std::mem::ManuallyDrop::new(tx);

        _assert(|dst, reply| super::super::RngMsg::FillUniformF32 { dst, reply });
    }
}