use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::graph::{GraphOp, GraphRecordCtx};
use crate::kernel::record::{RecordMode, RngFillUniformOp as InnerRngFillUniformOp, RngRecorder};
/// Graph-level wrapper around the kernel-record [`InnerRngFillUniformOp`].
///
/// Implements [`GraphOp`] so the fill can be recorded through a
/// [`GraphRecordCtx`], which must supply a stream and a cuRAND handle.
pub struct RngFillUniformOp {
    // `Some` until the op is successfully handed to the recorder: `record`
    // moves it out with `Option::take`, so each instance is recordable at most once.
    inner: Option<InnerRngFillUniformOp>,
}
impl RngFillUniformOp {
    /// Creates a graph op whose pending kernel-record op writes into `dst`.
    ///
    /// The returned op is unconsumed; a later call to `GraphOp::record`
    /// moves the pending op out, so it can be recorded only once.
    pub fn new(dst: GpuRef<f32>) -> Self {
        let pending = InnerRngFillUniformOp { dst };
        Self {
            inner: Some(pending),
        }
    }
}
impl GraphOp for RngFillUniformOp {
    /// Records the pending uniform-fill op onto the context's stream using
    /// the context's cuRAND handle.
    ///
    /// The preconditions are checked in this order:
    /// 1. The stream comes from `ctx.require_stream()`.
    /// 2. The RNG handle comes from `ctx.rng`.
    /// 3. The pending op is moved out of `self`.
    ///
    /// A failure in step 1 or 2 therefore leaves the op un-consumed. Once
    /// step 3 succeeds, the op is gone even if `enqueue_record` later fails.
    ///
    /// # Errors
    /// - Whatever `ctx.require_stream()` returns, propagated unchanged.
    /// - `GpuError::Unrecoverable` if `ctx.rng` is `None`.
    /// - `GpuError::Unrecoverable` if the pending op was already taken by a
    ///   previous `record` call.
    /// - Any error returned by the recorder's `enqueue_record`.
    fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
        let stream = ctx.require_stream()?;

        let handle = match ctx.rng {
            Some(h) => h,
            None => {
                return Err(GpuError::Unrecoverable(
                    "RngFillUniformOp::record: cuRAND handle not available in ctx".into(),
                ));
            }
        };

        // Taken last so that a missing stream or RNG handle does not burn the op.
        let pending = match self.inner.take() {
            Some(op) => op,
            None => {
                return Err(GpuError::Unrecoverable(
                    "RngFillUniformOp::record: already consumed".into(),
                ));
            }
        };

        RngRecorder { rng: handle }.enqueue_record(stream, pending)
    }

    /// Stable identifier for this op kind.
    fn op_name(&self) -> &'static str {
        "graph::rng_fill_uniform"
    }
}