use std::sync::Arc;
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::templates::elementwise::{ElementwiseOp, ElementwiseTemplate};
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::GpuFloat;
/// Block size handed to `grid_size_for` and `LaunchParams::new` when launching the fill kernel.
///
/// The unit tests in this file require it to be a power of two and at least 32.
const BLOCK_SIZE: u32 = 256;
/// Launches the elementwise `Fill` kernel over `n` elements of `dst` with the scalar `value`.
///
/// The kernel is generated from [`ElementwiseTemplate`] for `T::PTX_TYPE` and the
/// handle's SM version, loaded as a fresh module on every call, and launched on
/// `handle.stream()`.
///
/// Returns `Ok(())` immediately, without generating or launching anything, when
/// `n == 0`.
///
/// # Errors
///
/// * [`BlasError::BufferTooSmall`] if `dst` holds fewer than `n` elements.
/// * [`BlasError::PtxGeneration`] if the PTX template fails to generate.
/// * [`BlasError::LaunchFailed`] if the module cannot be loaded, the kernel cannot
///   be looked up by name, or the launch itself fails.
pub fn fill<T: GpuFloat>(
    handle: &BlasHandle,
    dst: &mut DeviceBuffer<T>,
    value: T,
    n: u32,
) -> BlasResult<()> {
    // Nothing to do for an empty range; skip PTX generation and module loading.
    if n == 0 {
        return Ok(());
    }

    // Reject the request up front rather than launching over a too-short buffer.
    if dst.len() < n as usize {
        return Err(BlasError::BufferTooSmall {
            expected: n as usize,
            actual: dst.len(),
        });
    }

    // Generate PTX specialised for this element type and target architecture.
    let template =
        ElementwiseTemplate::new(ElementwiseOp::Fill, T::PTX_TYPE, handle.sm_version());
    let kernel_name = template.kernel_name();
    let ptx_source = template
        .generate()
        .map_err(|e| BlasError::PtxGeneration(format!("fill: {e}")))?;

    // Load the module and resolve the kernel entry point by its generated name.
    let module = Arc::new(
        Module::from_ptx(&ptx_source)
            .map_err(|e| BlasError::LaunchFailed(format!("fill module load: {e}")))?,
    );
    let kernel = Kernel::from_module(module, &kernel_name)
        .map_err(|e| BlasError::LaunchFailed(format!("fill kernel lookup: {e}")))?;

    let grid = grid_size_for(n, BLOCK_SIZE);
    let params = LaunchParams::new(grid, BLOCK_SIZE);

    // The scalar is passed to the kernel as its `to_bits_u64()` representation.
    let value_bits = value.to_bits_u64();
    let args = (dst.as_device_ptr(), value_bits, n);

    kernel
        .launch(&params, handle.stream(), &args)
        .map_err(|e| BlasError::LaunchFailed(format!("fill launch: {e}")))?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_ptx::arch::SmVersion;
    use oxicuda_ptx::ir::PtxType;

    /// `BLOCK_SIZE` must be a power of two and no smaller than 32.
    #[test]
    fn block_size_is_power_of_two() {
        // The lower bound is checked at compile time, the power-of-two property at run time.
        const { assert!(BLOCK_SIZE >= 32) };
        assert!(BLOCK_SIZE.is_power_of_two());
    }

    /// The generated f32 fill PTX carries the expected kernel name and instructions.
    #[test]
    fn ptx_template_generates_fill_f32() {
        let template =
            ElementwiseTemplate::new(ElementwiseOp::Fill, PtxType::F32, SmVersion::Sm80);
        let ptx = template.generate().expect("fill PTX generation failed");

        // Each entry pairs a required substring with the message shown if it is missing.
        let expectations = [
            ("elementwise_fill_f32", "wrong kernel name"),
            ("st.global.f32", "must contain store instruction"),
            ("ld.param.f32", "must load scalar from param"),
        ];
        for (needle, message) in expectations {
            assert!(ptx.contains(needle), "{message}");
        }
    }

    /// The generated f64 fill PTX carries the expected kernel name and store instruction.
    #[test]
    fn ptx_template_generates_fill_f64() {
        let template =
            ElementwiseTemplate::new(ElementwiseOp::Fill, PtxType::F64, SmVersion::Sm80);
        let ptx = template.generate().expect("fill f64 PTX generation failed");

        let expectations = [
            ("elementwise_fill_f64", "wrong kernel name"),
            ("st.global.f64", "must contain f64 store"),
        ];
        for (needle, message) in expectations {
            assert!(ptx.contains(needle), "{message}");
        }
    }
}