use std::sync::Arc;
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::prelude::*;
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::{GpuFloat, MatrixDesc, Transpose};
/// Threads per block for GEMV launches; one thread computes one element of `y`.
const GEMV_BLOCK_SIZE: u32 = 256;
/// Generalized matrix-vector multiply: `y = alpha * op(A) * x + beta * y`.
///
/// `op(A)` is `A` for [`Transpose::NoTrans`] and the transpose for
/// `Trans`/`ConjTrans` (both transposed variants are dispatched identically
/// here — no conjugation is applied for the float types in view).
///
/// Quick-returns `Ok(())` when `m == 0 || n == 0`, matching BLAS convention.
///
/// # Errors
/// * [`BlasError::InvalidArgument`] if `incx` or `incy` is not positive
///   (the generated kernel indexes with unsigned `k * inc` arithmetic).
/// * Dimension/buffer errors from [`validate_gemv_dimensions`].
/// * PTX generation, module load, or launch failures from the driver layer.
#[allow(clippy::too_many_arguments)]
pub fn gemv<T: GpuFloat>(
    handle: &BlasHandle,
    trans: Transpose,
    m: u32,
    n: u32,
    alpha: T,
    a: &MatrixDesc<T>,
    x: &DeviceBuffer<T>,
    incx: i32,
    beta: T,
    y: &mut DeviceBuffer<T>,
    incy: i32,
) -> BlasResult<()> {
    // Degenerate problem: nothing to compute.
    if m == 0 || n == 0 {
        return Ok(());
    }
    if incx <= 0 {
        return Err(BlasError::InvalidArgument(
            "incx must be positive".to_string(),
        ));
    }
    if incy <= 0 {
        return Err(BlasError::InvalidArgument(
            "incy must be positive".to_string(),
        ));
    }
    validate_gemv_dimensions(trans, m, n, a, x, incx, y, incy)?;

    // One thread per output element: `y` has `output_len` entries and each
    // thread reduces over `inner_len` multiply-adds.
    let (output_len, inner_len) = match trans {
        Transpose::NoTrans => (m, n),
        Transpose::Trans | Transpose::ConjTrans => (n, m),
    };

    // NOTE(review): the PTX is regenerated and the module reloaded on every
    // call; a per-handle kernel cache would avoid the repeated JIT cost.
    let ptx = generate_gemv_ptx::<T>(handle.sm_version(), trans)?;
    let module = Arc::new(Module::from_ptx(&ptx)?);
    let kernel = Kernel::from_module(module, "gemv")?;

    let block_size = GEMV_BLOCK_SIZE;
    let grid_size = grid_size_for(output_len, block_size);
    let params = LaunchParams::new(grid_size, block_size);
    kernel.launch(
        &params,
        handle.stream(),
        &(
            a.ptr,
            x.as_device_ptr(),
            y.as_device_ptr(),
            // Scalars travel as raw bit patterns; the kernel re-types them.
            alpha.to_bits_u64(),
            beta.to_bits_u64(),
            m,
            n,
            a.ld,
            // Safe narrowing: both strides were validated positive above.
            incx as u32,
            incy as u32,
            output_len,
            inner_len,
        ),
    )?;
    Ok(())
}
/// Check that `A` covers an `m x n` sub-matrix and that the strided vector
/// buffers `x` and `y` are long enough for the chosen `trans` mode.
#[allow(clippy::too_many_arguments)]
fn validate_gemv_dimensions<T: GpuFloat>(
    trans: Transpose,
    m: u32,
    n: u32,
    a: &MatrixDesc<T>,
    x: &DeviceBuffer<T>,
    incx: i32,
    y: &DeviceBuffer<T>,
    incy: i32,
) -> BlasResult<()> {
    // The descriptor must describe at least an m x n matrix.
    if a.rows < m {
        return Err(BlasError::InvalidDimension(format!(
            "A.rows ({}) < m ({})",
            a.rows, m
        )));
    }
    if a.cols < n {
        return Err(BlasError::InvalidDimension(format!(
            "A.cols ({}) < n ({})",
            a.cols, n
        )));
    }

    // Transposing swaps which dimension each vector must span.
    let (x_len, y_len) = if matches!(trans, Transpose::NoTrans) {
        (n, m)
    } else {
        (m, n)
    };

    // Shared capacity check for both strided vectors.
    let ensure_capacity = |needed: usize, actual: usize| -> BlasResult<()> {
        if actual < needed {
            Err(BlasError::BufferTooSmall {
                expected: needed,
                actual,
            })
        } else {
            Ok(())
        }
    };
    ensure_capacity(required_elements(x_len, incx), x.len())?;
    ensure_capacity(required_elements(y_len, incy), y.len())?;
    Ok(())
}
/// Generate PTX source for a GEMV kernel: one thread per output element of
/// `y`, each performing a serial reduction of `inner_len` fused multiply-adds.
///
/// The entry point is always registered under the fixed name `"gemv"`, which
/// is the name the launcher looks up; the type- and transpose-specific
/// `_kernel_name` below is computed but currently unused.
///
/// `alpha`/`beta` arrive as raw u64 bit patterns and are re-typed in-kernel
/// (all 64 bits for f64, the low 32 bits for f32).
///
/// # Errors
/// Returns [`BlasError::PtxGeneration`] if the kernel builder fails.
fn generate_gemv_ptx<T: GpuFloat>(sm: SmVersion, trans: Transpose) -> BlasResult<String> {
    let suffix = T::NAME;
    let ptx_ty = T::PTX_TYPE;
    let elem_bytes = T::size_u32();
    // An 8-byte element selects the f64 instruction variants throughout.
    let is_f64 = elem_bytes == 8;
    // Unused for now — see the note in the doc comment above.
    let _kernel_name = format!("gemv_{suffix}_{}", trans_label(trans));
    KernelBuilder::new("gemv")
        .target(sm)
        .param("a_ptr", PtxType::U64)
        .param("x_ptr", PtxType::U64)
        .param("y_ptr", PtxType::U64)
        .param("alpha_bits", PtxType::U64)
        .param("beta_bits", PtxType::U64)
        // `m` and `n` are declared (and passed at launch) but never loaded
        // in the body below; indexing uses output_len/inner_len instead.
        .param("m", PtxType::U32)
        .param("n", PtxType::U32)
        .param("lda", PtxType::U32)
        .param("incx", PtxType::U32)
        .param("incy", PtxType::U32)
        .param("output_len", PtxType::U32)
        .param("inner_len", PtxType::U32)
        .body(move |b| {
            let gid = b.global_thread_id_x();
            let output_len = b.load_param_u32("output_len");
            // Clone handed to the guard closure; out-of-range threads skip
            // straight to `ret` without touching memory.
            let gid_inner = gid.clone();
            b.if_lt_u32(gid, output_len, move |b| {
                let gid = gid_inner;
                let a_ptr = b.load_param_u64("a_ptr");
                let x_ptr = b.load_param_u64("x_ptr");
                let y_ptr = b.load_param_u64("y_ptr");
                let inner_len = b.load_param_u32("inner_len");
                let lda = b.load_param_u32("lda");
                let incx = b.load_param_u32("incx");
                let incy = b.load_param_u32("incy");
                let alpha_bits = b.load_param_u64("alpha_bits");
                let beta_bits = b.load_param_u64("beta_bits");
                // Bit-cast the u64 scalar payloads back to the element type.
                let alpha = if is_f64 {
                    let r = b.alloc_reg(PtxType::F64);
                    b.raw_ptx(&format!("mov.b64 {r}, {alpha_bits};"));
                    r
                } else {
                    // f32: truncate to the low 32 bits, then bit-cast.
                    let lo32 = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("cvt.u32.u64 {lo32}, {alpha_bits};"));
                    let r = b.alloc_reg(PtxType::F32);
                    b.raw_ptx(&format!("mov.b32 {r}, {lo32};"));
                    r
                };
                let beta = if is_f64 {
                    let r = b.alloc_reg(PtxType::F64);
                    b.raw_ptx(&format!("mov.b64 {r}, {beta_bits};"));
                    r
                } else {
                    let lo32 = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("cvt.u32.u64 {lo32}, {beta_bits};"));
                    let r = b.alloc_reg(PtxType::F32);
                    b.raw_ptx(&format!("mov.b32 {r}, {lo32};"));
                    r
                };
                // Accumulator initialized to +0.0 via PTX hex float literals.
                let acc = b.alloc_reg(ptx_ty);
                if is_f64 {
                    b.raw_ptx(&format!("mov.b64 {acc}, 0d0000000000000000;"));
                } else {
                    b.raw_ptx(&format!("mov.b32 {acc}, 0f00000000;"));
                }
                let use_trans = matches!(trans, Transpose::Trans | Transpose::ConjTrans);
                // Base address of this thread's slice of A:
                //  - NoTrans: start of row `gid` (assumes `lda` is the element
                //    stride between consecutive rows — TODO confirm against
                //    MatrixDesc's layout convention).
                //  - Trans: address of element A[0][gid], i.e. column `gid`.
                let row_base = if !use_trans {
                    let stride = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mul.lo.u32 {stride}, {lda}, {};", elem_bytes));
                    // NOTE(review): gid * row-stride is computed in 32 bits
                    // before widening; very large matrices could overflow.
                    let row_idx = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mul.lo.u32 {row_idx}, {gid}, {stride};"));
                    let row_off64 = b.cvt_u32_to_u64(row_idx);
                    b.add_u64(a_ptr, row_off64)
                } else {
                    b.byte_offset_addr(a_ptr, gid.clone(), elem_bytes)
                };
                // Serial reduction: for k in 0..inner_len { acc += A_elem * x[k*incx] }
                let loop_label = b.fresh_label("gemv_loop");
                let done_label = b.fresh_label("gemv_done");
                let k = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("mov.u32 {k}, 0;"));
                b.label(&loop_label);
                // Exit the loop once k >= inner_len (unsigned compare).
                let pred_loop = b.alloc_reg(PtxType::Pred);
                b.raw_ptx(&format!("setp.lo.u32 {pred_loop}, {k}, {inner_len};"));
                b.raw_ptx(&format!("@!{pred_loop} bra {done_label};"));
                // Address of the k-th element along this thread's row/column.
                let a_addr = if !use_trans {
                    b.byte_offset_addr(row_base.clone(), k.clone(), elem_bytes)
                } else {
                    // Transposed walk: step down the column, k rows at a time
                    // (offset widened to 64 bits before the multiply).
                    let stride_reg = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mul.lo.u32 {stride_reg}, {lda}, {};", elem_bytes));
                    let k64 = b.cvt_u32_to_u64(k.clone());
                    let stride64 = b.cvt_u32_to_u64(stride_reg);
                    let off = b.alloc_reg(PtxType::U64);
                    b.raw_ptx(&format!("mul.lo.u64 {off}, {k64}, {stride64};"));
                    b.add_u64(row_base.clone(), off)
                };
                // x is strided: element index k * incx.
                let x_elem_idx = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("mul.lo.u32 {x_elem_idx}, {k}, {incx};"));
                let x_addr = b.byte_offset_addr(x_ptr.clone(), x_elem_idx, elem_bytes);
                let a_val = if is_f64 {
                    b.load_global_f64(a_addr)
                } else {
                    b.load_global_f32(a_addr)
                };
                let x_val = if is_f64 {
                    b.load_global_f64(x_addr)
                } else {
                    b.load_global_f32(x_addr)
                };
                // acc = a_val * x_val + acc, then copy back into the
                // accumulator register so the loop reuses the same name.
                let new_acc = if is_f64 {
                    b.fma_f64(a_val, x_val, acc.clone())
                } else {
                    b.fma_f32(a_val, x_val, acc.clone())
                };
                b.raw_ptx(&format!(
                    "mov.{} {acc}, {new_acc};",
                    if is_f64 { "f64" } else { "f32" }
                ));
                b.raw_ptx(&format!("add.u32 {k}, {k}, 1;"));
                b.branch(&loop_label);
                b.label(&done_label);
                // Epilogue: y[gid*incy] = alpha * acc + beta * y[gid*incy].
                let y_elem_idx = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("mul.lo.u32 {y_elem_idx}, {gid}, {incy};"));
                let y_addr = b.byte_offset_addr(y_ptr, y_elem_idx, elem_bytes);
                // NOTE(review): y is always read, even when beta == 0; an
                // uninitialized y (NaN) would then poison the result, which
                // deviates from BLAS's beta==0 "y need not be set" contract.
                let y_cur = if is_f64 {
                    b.load_global_f64(y_addr.clone())
                } else {
                    b.load_global_f32(y_addr.clone())
                };
                let alpha_acc = if is_f64 {
                    let tmp = b.alloc_reg(PtxType::F64);
                    b.raw_ptx(&format!("mul.rn.f64 {tmp}, {alpha}, {acc};"));
                    tmp
                } else {
                    let tmp = b.alloc_reg(PtxType::F32);
                    b.raw_ptx(&format!("mul.rn.f32 {tmp}, {alpha}, {acc};"));
                    tmp
                };
                let beta_y = if is_f64 {
                    let tmp = b.alloc_reg(PtxType::F64);
                    b.raw_ptx(&format!("mul.rn.f64 {tmp}, {beta}, {y_cur};"));
                    tmp
                } else {
                    let tmp = b.alloc_reg(PtxType::F32);
                    b.raw_ptx(&format!("mul.rn.f32 {tmp}, {beta}, {y_cur};"));
                    tmp
                };
                let result = if is_f64 {
                    b.add_f64(alpha_acc, beta_y)
                } else {
                    b.add_f32(alpha_acc, beta_y)
                };
                if is_f64 {
                    b.store_global_f64(y_addr, result);
                } else {
                    b.store_global_f32(y_addr, result);
                }
            });
            b.ret();
        })
        .build()
        .map_err(|e| BlasError::PtxGeneration(e.to_string()))
}
fn trans_label(t: Transpose) -> &'static str {
match t {
Transpose::NoTrans => "n",
Transpose::Trans => "t",
Transpose::ConjTrans => "c",
}
}
/// Minimum element capacity a strided vector buffer needs to hold `n`
/// logical elements at stride `inc` (BLAS rule: `1 + (n - 1) * |inc|`,
/// and 0 when the vector is empty).
fn required_elements(n: u32, inc: i32) -> usize {
    let stride = inc.unsigned_abs() as usize;
    match n {
        0 => 0,
        len => (len as usize - 1) * stride + 1,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every transpose mode maps to its expected single-letter tag.
    #[test]
    fn trans_label_values() {
        let cases = [
            (Transpose::NoTrans, "n"),
            (Transpose::Trans, "t"),
            (Transpose::ConjTrans, "c"),
        ];
        for (mode, expected) in cases {
            assert_eq!(trans_label(mode), expected);
        }
    }

    /// Strided-capacity formula: 0 for empty, 1 + (n-1)*inc otherwise.
    #[test]
    fn required_elements_basic() {
        assert_eq!(required_elements(0, 1), 0);
        assert_eq!(required_elements(1, 1), 1);
        assert_eq!(required_elements(5, 1), 5);
        assert_eq!(required_elements(5, 2), 9);
    }

    /// f32 / NoTrans PTX generates and targets the requested SM version.
    #[test]
    fn gemv_ptx_generation_f32() {
        let ptx = generate_gemv_ptx::<f32>(SmVersion::Sm80, Transpose::NoTrans);
        assert!(ptx.is_ok());
        let ptx = ptx.expect("test: PTX generation should succeed");
        assert!(ptx.contains(".entry gemv"));
        assert!(ptx.contains(".target sm_80"));
    }

    /// f64 / Trans PTX generates with the shared entry-point name.
    #[test]
    fn gemv_ptx_generation_f64() {
        let ptx = generate_gemv_ptx::<f64>(SmVersion::Sm80, Transpose::Trans);
        assert!(ptx.is_ok());
        let ptx = ptx.expect("test: PTX generation should succeed");
        assert!(ptx.contains(".entry gemv"));
    }
}