use std::fmt::Write as FmtWrite;
use std::sync::Arc;
use oxicuda_driver::Module;
use oxicuda_launch::{Dim3, Kernel, LaunchParams, grid_size_for};
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::{GpuFloat, Transpose};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::arch::SmVersion;
/// Thread-block extent in x for the complex GEMM launch; the grid's x size is derived from `n`
/// with this value.
const CGEMM_BLOCK_X: u32 = 16;
/// Thread-block extent in y for the complex GEMM launch; the grid's y size is derived from `m`
/// with this value.
const CGEMM_BLOCK_Y: u32 = 16;
/// Threads per block for the complex GEMV launch; the grid size is derived from the output
/// vector length with this value.
const CGEMV_BLOCK: u32 = 256;
/// Computes `C = alpha * op(A) * op(B) + beta * C` on the GPU for interleaved complex matrices.
///
/// Every complex element occupies two consecutive `T` scalars (real part, then imaginary part),
/// so buffer lengths are counted in scalars while leading dimensions are counted in complex
/// elements. `op(X)` is `X`, its transpose, or its conjugate transpose as selected by
/// `transa` / `transb`.
///
/// On every call the PTX for the requested variant is generated, loaded as a module and
/// launched on the handle's stream.
///
/// # Arguments
/// * `handle` - supplies the target SM version and the stream used for the launch.
/// * `transa`, `transb` - operation applied to A and B.
/// * `m`, `n`, `k` - `op(A)` is `m x k`, `op(B)` is `k x n`, C is `m x n`.
/// * `alpha_re`, `alpha_im`, `beta_re`, `beta_im` - real and imaginary parts of the scalars.
/// * `a`, `b` - input device buffers; `c` - device buffer that is read and overwritten.
/// * `lda`, `ldb`, `ldc` - leading dimensions of A, B and C.
///
/// # Errors
/// * `BlasError::InvalidDimension` if any of `m`, `n`, `k` is zero, a leading dimension is too
///   small, a size computation overflows, a dimension does not fit in `u32`, or a matrix extent
///   exceeds what the kernel's 32-bit index arithmetic can address.
/// * `BlasError::BufferTooSmall` if a buffer is shorter than required.
/// * `BlasError::Cuda` if loading the module, resolving the kernel or launching it fails.
/// * Any error produced while generating the PTX.
///
/// NOTE(review): the shape checks kept from the original implementation size each operand as
/// `ld * columns` with `ld >= rows`, whereas the generated kernel addresses element `(r, c)` at
/// `r * ld + c`. These differ for non-square operands; confirm the intended storage order. Until
/// that is settled, an additional extent check guarantees the kernel stays inside every buffer.
#[allow(clippy::too_many_arguments)] // the parameter list mirrors the BLAS-style GEMM signature
pub fn complex_gemm<T: GpuFloat>(
    handle: &BlasHandle,
    transa: Transpose,
    transb: Transpose,
    m: usize,
    n: usize,
    k: usize,
    alpha_re: T,
    alpha_im: T,
    a: &DeviceBuffer<T>,
    lda: usize,
    b: &DeviceBuffer<T>,
    ldb: usize,
    beta_re: T,
    beta_im: T,
    c: &mut DeviceBuffer<T>,
    ldc: usize,
) -> BlasResult<()> {
    /// Narrows `value` to the `u32` a kernel parameter expects instead of truncating it.
    fn to_u32(value: usize, what: &str) -> BlasResult<u32> {
        u32::try_from(value).map_err(|_| {
            BlasError::InvalidDimension(format!(
                "complex GEMM: {what} ({value}) does not fit in 32 bits"
            ))
        })
    }

    /// Scalars in `count` lines of `ld` complex elements (`2 * ld * count`), overflow-checked.
    fn dense_len(ld: usize, count: usize) -> BlasResult<usize> {
        ld.checked_mul(count)
            .and_then(|elems| elems.checked_mul(2))
            .ok_or_else(|| {
                BlasError::InvalidDimension(
                    "complex GEMM: matrix size overflows the address space".into(),
                )
            })
    }

    /// Verifies that a buffer of `len` scalars covers every address the kernel forms when it
    /// walks `rows x cols` elements at `2 * (r * ld + c)`, and that those 32-bit indices
    /// cannot wrap.
    fn check_kernel_extent(
        name: &str,
        len: usize,
        rows: usize,
        cols: usize,
        ld: usize,
    ) -> BlasResult<()> {
        let extent = rows
            .checked_sub(1)
            .and_then(|last_row| last_row.checked_mul(ld))
            .and_then(|offset| offset.checked_add(cols))
            .and_then(|elems| elems.checked_mul(2))
            .filter(|&scalars| u32::try_from(scalars).is_ok())
            .ok_or_else(|| {
                BlasError::InvalidDimension(format!(
                    "complex GEMM: matrix {name} exceeds the kernel's 32-bit index range"
                ))
            })?;
        if len < extent {
            return Err(BlasError::BufferTooSmall {
                expected: extent,
                actual: len,
            });
        }
        Ok(())
    }

    if m == 0 || n == 0 || k == 0 {
        return Err(BlasError::InvalidDimension(
            "complex GEMM: all dimensions must be non-zero".into(),
        ));
    }

    // Original shape checks, now with overflow-checked arithmetic.
    let a_required = match transa {
        Transpose::NoTrans => dense_len(lda, k)?,
        Transpose::Trans | Transpose::ConjTrans => dense_len(lda, m)?,
    };
    if a.len() < a_required {
        return Err(BlasError::BufferTooSmall {
            expected: a_required,
            actual: a.len(),
        });
    }
    let b_required = match transb {
        Transpose::NoTrans => dense_len(ldb, n)?,
        Transpose::Trans | Transpose::ConjTrans => dense_len(ldb, k)?,
    };
    if b.len() < b_required {
        return Err(BlasError::BufferTooSmall {
            expected: b_required,
            actual: b.len(),
        });
    }
    let c_required = dense_len(ldc, n)?;
    if c.len() < c_required {
        return Err(BlasError::BufferTooSmall {
            expected: c_required,
            actual: c.len(),
        });
    }
    validate_complex_ld(transa, m, k, lda, "A")?;
    validate_complex_ld(transb, k, n, ldb, "B")?;
    if ldc < m {
        return Err(BlasError::InvalidDimension(format!(
            "complex GEMM: ldc ({ldc}) < m ({m})"
        )));
    }

    // Memory-safety guard: (rows, cols) are the index ranges the generated kernel actually
    // iterates for each operand, with `ld` used as the stride of the first index.
    let (a_rows, a_cols) = match transa {
        Transpose::NoTrans => (m, k),
        Transpose::Trans | Transpose::ConjTrans => (k, m),
    };
    let (b_rows, b_cols) = match transb {
        Transpose::NoTrans => (k, n),
        Transpose::Trans | Transpose::ConjTrans => (n, k),
    };
    check_kernel_extent("A", a.len(), a_rows, a_cols, lda)?;
    check_kernel_extent("B", b.len(), b_rows, b_cols, ldb)?;
    check_kernel_extent("C", c.len(), m, n, ldc)?;

    // The kernel takes all sizes as u32; refuse anything that would be truncated.
    let m32 = to_u32(m, "m")?;
    let n32 = to_u32(n, "n")?;
    let k32 = to_u32(k, "k")?;
    let lda32 = to_u32(lda, "lda")?;
    let ldb32 = to_u32(ldb, "ldb")?;
    let ldc32 = to_u32(ldc, "ldc")?;

    let ptx = generate_complex_gemm_ptx::<T>(handle.sm_version(), transa, transb)?;
    let module = Arc::new(Module::from_ptx(&ptx).map_err(BlasError::Cuda)?);
    let kernel_name = complex_gemm_kernel_name::<T>(transa, transb);
    let kernel = Kernel::from_module(module, &kernel_name).map_err(BlasError::Cuda)?;

    // x covers the columns of C, y covers its rows.
    let grid = Dim3::xy(
        grid_size_for(n32, CGEMM_BLOCK_X),
        grid_size_for(m32, CGEMM_BLOCK_Y),
    );
    let block = Dim3::xy(CGEMM_BLOCK_X, CGEMM_BLOCK_Y);
    let params = LaunchParams::new(grid, block);

    // Argument order must match the `.param` list emitted by `generate_complex_gemm_ptx`.
    let args = (
        a.as_device_ptr(),
        b.as_device_ptr(),
        c.as_device_ptr(),
        m32,
        n32,
        k32,
        lda32,
        ldb32,
        ldc32,
        alpha_re,
        alpha_im,
        beta_re,
        beta_im,
    );
    kernel
        .launch(&params, handle.stream(), &args)
        .map_err(BlasError::Cuda)?;
    Ok(())
}
/// Computes `y = alpha * op(A) * x + beta * y` on the GPU for interleaved complex data.
///
/// Every complex element occupies two consecutive `T` scalars (real part, then imaginary part).
/// `x` holds `n` elements and `y` holds `m` elements when `trans` is `NoTrans`; the lengths are
/// swapped for `Trans` and `ConjTrans`. `incx` and `incy` are strides in complex elements.
///
/// On every call the PTX for the requested variant is generated, loaded as a module and
/// launched on the handle's stream. When `m` or `n` is zero the call returns `Ok(())` without
/// validating anything else or launching a kernel.
///
/// # Arguments
/// * `handle` - supplies the target SM version and the stream used for the launch.
/// * `trans` - operation applied to A.
/// * `m`, `n` - dimensions of A.
/// * `alpha_re`, `alpha_im`, `beta_re`, `beta_im` - real and imaginary parts of the scalars.
/// * `a`, `x` - input device buffers; `y` - device buffer that is read and overwritten.
/// * `lda` - leading dimension of A; `incx`, `incy` - element strides of `x` and `y`.
///
/// # Errors
/// * `BlasError::InvalidArgument` if `incx` or `incy` is zero.
/// * `BlasError::InvalidDimension` if `lda` is too small, a size computation overflows, a value
///   does not fit in `u32`, or an operand exceeds what the kernel's 32-bit index arithmetic can
///   address.
/// * `BlasError::BufferTooSmall` if a buffer is shorter than required.
/// * `BlasError::Cuda` if loading the module, resolving the kernel or launching it fails.
/// * Any error produced while generating the PTX.
///
/// NOTE(review): the shape checks kept from the original implementation require `lda >= m` and
/// `2 * lda * n` scalars, whereas the generated kernel addresses A element `(i, j)` at
/// `i * lda + j` with `i < m`, `j < n`. These differ for non-square A; confirm the intended
/// storage order. Until that is settled, an additional extent check guarantees the kernel stays
/// inside the buffer.
#[allow(clippy::too_many_arguments)] // the parameter list mirrors the BLAS-style GEMV signature
pub fn complex_gemv<T: GpuFloat>(
    handle: &BlasHandle,
    trans: Transpose,
    m: usize,
    n: usize,
    alpha_re: T,
    alpha_im: T,
    a: &DeviceBuffer<T>,
    lda: usize,
    x: &DeviceBuffer<T>,
    incx: usize,
    beta_re: T,
    beta_im: T,
    y: &mut DeviceBuffer<T>,
    incy: usize,
) -> BlasResult<()> {
    /// Narrows `value` to the `u32` a kernel parameter expects instead of truncating it.
    fn to_u32(value: usize, what: &str) -> BlasResult<u32> {
        u32::try_from(value).map_err(|_| {
            BlasError::InvalidDimension(format!(
                "complex GEMV: {what} ({value}) does not fit in 32 bits"
            ))
        })
    }

    /// Scalars spanned by `len` complex elements spaced `inc` apart,
    /// `2 * (1 + (len - 1) * inc)`, rejected when it overflows or leaves the 32-bit range the
    /// kernel indexes with.
    fn strided_len(len: usize, inc: usize, name: &str) -> BlasResult<usize> {
        len.saturating_sub(1)
            .checked_mul(inc)
            .and_then(|offset| offset.checked_add(1))
            .and_then(|elems| elems.checked_mul(2))
            .filter(|&scalars| u32::try_from(scalars).is_ok())
            .ok_or_else(|| {
                BlasError::InvalidDimension(format!(
                    "complex GEMV: vector {name} exceeds the kernel's 32-bit index range"
                ))
            })
    }

    if m == 0 || n == 0 {
        return Ok(());
    }
    if incx == 0 {
        return Err(BlasError::InvalidArgument(
            "complex GEMV: incx must be positive".into(),
        ));
    }
    if incy == 0 {
        return Err(BlasError::InvalidArgument(
            "complex GEMV: incy must be positive".into(),
        ));
    }

    // Original shape check for A, now with overflow-checked arithmetic.
    let a_required = lda
        .checked_mul(n)
        .and_then(|elems| elems.checked_mul(2))
        .ok_or_else(|| {
            BlasError::InvalidDimension(
                "complex GEMV: matrix size overflows the address space".into(),
            )
        })?;
    if a.len() < a_required {
        return Err(BlasError::BufferTooSmall {
            expected: a_required,
            actual: a.len(),
        });
    }

    let (x_len, y_len) = match trans {
        Transpose::NoTrans => (n, m),
        Transpose::Trans | Transpose::ConjTrans => (m, n),
    };
    let x_required = strided_len(x_len, incx, "x")?;
    if x.len() < x_required {
        return Err(BlasError::BufferTooSmall {
            expected: x_required,
            actual: x.len(),
        });
    }
    let y_required = strided_len(y_len, incy, "y")?;
    if y.len() < y_required {
        return Err(BlasError::BufferTooSmall {
            expected: y_required,
            actual: y.len(),
        });
    }
    if lda < m {
        return Err(BlasError::InvalidDimension(format!(
            "complex GEMV: lda ({lda}) < m ({m})"
        )));
    }

    // Memory-safety guard: for every `trans` the generated kernel forms A addresses
    // `2 * (i * lda + j)` with `i < m` and `j < n`, using 32-bit arithmetic.
    let a_extent = (m - 1)
        .checked_mul(lda)
        .and_then(|offset| offset.checked_add(n))
        .and_then(|elems| elems.checked_mul(2))
        .filter(|&scalars| u32::try_from(scalars).is_ok())
        .ok_or_else(|| {
            BlasError::InvalidDimension(
                "complex GEMV: matrix A exceeds the kernel's 32-bit index range".into(),
            )
        })?;
    if a.len() < a_extent {
        return Err(BlasError::BufferTooSmall {
            expected: a_extent,
            actual: a.len(),
        });
    }

    // The kernel takes all sizes as u32; refuse anything that would be truncated.
    let m32 = to_u32(m, "m")?;
    let n32 = to_u32(n, "n")?;
    let lda32 = to_u32(lda, "lda")?;
    let incx32 = to_u32(incx, "incx")?;
    let incy32 = to_u32(incy, "incy")?;
    // One thread per output element; each thread loops over `inner_len` products.
    let (output_len, inner_len) = match trans {
        Transpose::NoTrans => (m32, n32),
        Transpose::Trans | Transpose::ConjTrans => (n32, m32),
    };

    let ptx = generate_complex_gemv_ptx::<T>(handle.sm_version(), trans)?;
    let ta = trans_label(trans);
    let gemv_kernel_name = format!("complex_gemv_{}_{ta}", T::NAME);
    let module = Arc::new(Module::from_ptx(&ptx).map_err(BlasError::Cuda)?);
    let kernel = Kernel::from_module(module, &gemv_kernel_name).map_err(BlasError::Cuda)?;

    let grid = grid_size_for(output_len, CGEMV_BLOCK);
    let params = LaunchParams::new(grid, CGEMV_BLOCK);

    // Argument order must match the `.param` list emitted by `generate_complex_gemv_ptx`.
    let args = (
        a.as_device_ptr(),
        x.as_device_ptr(),
        y.as_device_ptr(),
        m32,
        n32,
        lda32,
        incx32,
        incy32,
        alpha_re,
        alpha_im,
        beta_re,
        beta_im,
        output_len,
        inner_len,
    );
    kernel
        .launch(&params, handle.stream(), &args)
        .map_err(BlasError::Cuda)?;
    Ok(())
}
fn generate_complex_gemm_ptx<T: GpuFloat>(
sm: SmVersion,
transa: Transpose,
transb: Transpose,
) -> BlasResult<String> {
let byte_size = T::PTX_TYPE.size_bytes();
let kernel_name = complex_gemm_kernel_name::<T>(transa, transb);
let is_f64 = byte_size == 8;
let (fr, ld_ty) = if is_f64 { ("fd", "f64") } else { ("f", "f32") };
let zero_lit = if is_f64 {
"0d0000000000000000"
} else {
"0f00000000"
};
let mut p = String::with_capacity(8192);
wl(&mut p, &format!(".version {}", sm.ptx_version()))?;
wl(&mut p, &format!(".target {}", sm.as_ptx_str()))?;
wl(&mut p, ".address_size 64")?;
wl(&mut p, "")?;
wl(&mut p, &format!(".visible .entry {kernel_name}("))?;
wl(&mut p, " .param .u64 %param_a,")?;
wl(&mut p, " .param .u64 %param_b,")?;
wl(&mut p, " .param .u64 %param_c,")?;
wl(&mut p, " .param .u32 %param_m,")?;
wl(&mut p, " .param .u32 %param_n,")?;
wl(&mut p, " .param .u32 %param_k,")?;
wl(&mut p, " .param .u32 %param_lda,")?;
wl(&mut p, " .param .u32 %param_ldb,")?;
wl(&mut p, " .param .u32 %param_ldc,")?;
wl(&mut p, &format!(" .param .{ld_ty} %param_alpha_re,"))?;
wl(&mut p, &format!(" .param .{ld_ty} %param_alpha_im,"))?;
wl(&mut p, &format!(" .param .{ld_ty} %param_beta_re,"))?;
wl(&mut p, &format!(" .param .{ld_ty} %param_beta_im"))?;
wl(&mut p, ")")?;
wl(&mut p, "{")?;
wl(&mut p, " .reg .b32 %r<48>;")?;
wl(&mut p, " .reg .b64 %rd<24>;")?;
if is_f64 {
wl(&mut p, " .reg .f64 %fd<32>;")?;
} else {
wl(&mut p, " .reg .f32 %f<32>;")?;
}
wl(&mut p, " .reg .pred %p<8>;")?;
wl(&mut p, "")?;
wl(&mut p, " mov.u32 %r0, %tid.x;")?;
wl(&mut p, " mov.u32 %r1, %tid.y;")?;
wl(&mut p, " mov.u32 %r2, %ctaid.x;")?;
wl(&mut p, " mov.u32 %r3, %ctaid.y;")?;
wl(&mut p, " mov.u32 %r4, %ntid.x;")?;
wl(&mut p, " mov.u32 %r5, %ntid.y;")?;
wl(&mut p, " mad.lo.u32 %r6, %r2, %r4, %r0; // col")?;
wl(&mut p, " mad.lo.u32 %r7, %r3, %r5, %r1; // row")?;
wl(&mut p, "")?;
wl(&mut p, " ld.param.u64 %rd0, [%param_a];")?;
wl(&mut p, " ld.param.u64 %rd1, [%param_b];")?;
wl(&mut p, " ld.param.u64 %rd2, [%param_c];")?;
wl(&mut p, " ld.param.u32 %r8, [%param_m];")?;
wl(&mut p, " ld.param.u32 %r9, [%param_n];")?;
wl(&mut p, " ld.param.u32 %r10, [%param_k];")?;
wl(&mut p, " ld.param.u32 %r30, [%param_lda];")?;
wl(&mut p, " ld.param.u32 %r31, [%param_ldb];")?;
wl(&mut p, " ld.param.u32 %r32, [%param_ldc];")?;
wl(
&mut p,
&format!(" ld.param.{ld_ty} %{fr}20, [%param_alpha_re];"),
)?;
wl(
&mut p,
&format!(" ld.param.{ld_ty} %{fr}21, [%param_alpha_im];"),
)?;
wl(
&mut p,
&format!(" ld.param.{ld_ty} %{fr}22, [%param_beta_re];"),
)?;
wl(
&mut p,
&format!(" ld.param.{ld_ty} %{fr}23, [%param_beta_im];"),
)?;
wl(&mut p, "")?;
wl(&mut p, " setp.ge.u32 %p0, %r7, %r8;")?;
wl(&mut p, " setp.ge.u32 %p1, %r6, %r9;")?;
wl(&mut p, " @%p0 bra $CGEMM_DONE;")?;
wl(&mut p, " @%p1 bra $CGEMM_DONE;")?;
wl(&mut p, "")?;
wl(
&mut p,
&format!(" mov.{ld_ty} %{fr}0, {zero_lit}; // acc_re"),
)?;
wl(
&mut p,
&format!(" mov.{ld_ty} %{fr}1, {zero_lit}; // acc_im"),
)?;
wl(&mut p, " mov.u32 %r11, 0;")?;
wl(&mut p, "")?;
wl(&mut p, "$CGEMM_K_LOOP:")?;
wl(&mut p, " setp.ge.u32 %p2, %r11, %r10;")?;
wl(&mut p, " @%p2 bra $CGEMM_K_DONE;")?;
let (a_maj, a_min) = match transa {
Transpose::NoTrans => ("%r7", "%r11"),
Transpose::Trans | Transpose::ConjTrans => ("%r11", "%r7"),
};
wl(
&mut p,
&format!(" mad.lo.u32 %r12, {a_maj}, %r30, {a_min};"),
)?;
wl(&mut p, " shl.b32 %r12, %r12, 1; // *2 for complex")?;
wl(&mut p, " cvt.u64.u32 %rd3, %r12;")?;
wl(&mut p, &format!(" mul.lo.u64 %rd3, %rd3, {byte_size};"))?;
wl(&mut p, " add.u64 %rd4, %rd0, %rd3;")?;
wl(&mut p, &format!(" ld.global.{ld_ty} %{fr}2, [%rd4];"))?;
wl(
&mut p,
&format!(" ld.global.{ld_ty} %{fr}3, [%rd4+{byte_size}];"),
)?;
if transa == Transpose::ConjTrans {
wl(&mut p, &format!(" neg.{ld_ty} %{fr}3, %{fr}3;"))?;
}
let (b_maj, b_min) = match transb {
Transpose::NoTrans => ("%r11", "%r6"),
Transpose::Trans | Transpose::ConjTrans => ("%r6", "%r11"),
};
wl(
&mut p,
&format!(" mad.lo.u32 %r13, {b_maj}, %r31, {b_min};"),
)?;
wl(&mut p, " shl.b32 %r13, %r13, 1;")?;
wl(&mut p, " cvt.u64.u32 %rd5, %r13;")?;
wl(&mut p, &format!(" mul.lo.u64 %rd5, %rd5, {byte_size};"))?;
wl(&mut p, " add.u64 %rd6, %rd1, %rd5;")?;
wl(&mut p, &format!(" ld.global.{ld_ty} %{fr}4, [%rd6];"))?;
wl(
&mut p,
&format!(" ld.global.{ld_ty} %{fr}5, [%rd6+{byte_size}];"),
)?;
if transb == Transpose::ConjTrans {
wl(&mut p, &format!(" neg.{ld_ty} %{fr}5, %{fr}5;"))?;
}
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}0, %{fr}2, %{fr}4, %{fr}0; // acc_re += a_re*b_re"),
)?;
wl(&mut p, &format!(" neg.{ld_ty} %{fr}6, %{fr}3;"))?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}0, %{fr}6, %{fr}5, %{fr}0; // acc_re -= a_im*b_im"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}1, %{fr}2, %{fr}5, %{fr}1; // acc_im += a_re*b_im"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}1, %{fr}3, %{fr}4, %{fr}1; // acc_im += a_im*b_re"),
)?;
wl(&mut p, " add.u32 %r11, %r11, 1;")?;
wl(&mut p, " bra $CGEMM_K_LOOP;")?;
wl(&mut p, "$CGEMM_K_DONE:")?;
wl(&mut p, "")?;
wl(&mut p, " mad.lo.u32 %r14, %r7, %r32, %r6;")?;
wl(&mut p, " shl.b32 %r14, %r14, 1; // *2 for complex")?;
wl(&mut p, " cvt.u64.u32 %rd7, %r14;")?;
wl(&mut p, &format!(" mul.lo.u64 %rd7, %rd7, {byte_size};"))?;
wl(&mut p, " add.u64 %rd8, %rd2, %rd7;")?;
wl(
&mut p,
&format!(" ld.global.{ld_ty} %{fr}10, [%rd8]; // c_re"),
)?;
wl(
&mut p,
&format!(" ld.global.{ld_ty} %{fr}11, [%rd8+{byte_size}]; // c_im"),
)?;
wl(
&mut p,
&format!(" mul.rn.{ld_ty} %{fr}12, %{fr}20, %{fr}0; // alpha_re * acc_re"),
)?;
wl(&mut p, &format!(" neg.{ld_ty} %{fr}15, %{fr}21;"))?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}12, %{fr}15, %{fr}1, %{fr}12; // - alpha_im * acc_im"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}12, %{fr}22, %{fr}10, %{fr}12; // + beta_re * c_re"),
)?;
wl(&mut p, &format!(" neg.{ld_ty} %{fr}16, %{fr}23;"))?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}12, %{fr}16, %{fr}11, %{fr}12; // - beta_im * c_im"),
)?;
wl(
&mut p,
&format!(" mul.rn.{ld_ty} %{fr}13, %{fr}20, %{fr}1; // alpha_re * acc_im"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}13, %{fr}21, %{fr}0, %{fr}13; // + alpha_im * acc_re"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}13, %{fr}22, %{fr}11, %{fr}13; // + beta_re * c_im"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}13, %{fr}23, %{fr}10, %{fr}13; // + beta_im * c_re"),
)?;
wl(&mut p, &format!(" st.global.{ld_ty} [%rd8], %{fr}12;"))?;
wl(
&mut p,
&format!(" st.global.{ld_ty} [%rd8+{byte_size}], %{fr}13;"),
)?;
wl(&mut p, "")?;
wl(&mut p, "$CGEMM_DONE:")?;
wl(&mut p, " ret;")?;
wl(&mut p, "}")?;
Ok(p)
}
fn generate_complex_gemv_ptx<T: GpuFloat>(sm: SmVersion, trans: Transpose) -> BlasResult<String> {
let byte_size = T::PTX_TYPE.size_bytes();
let is_f64 = byte_size == 8;
let (fr, ld_ty) = if is_f64 { ("fd", "f64") } else { ("f", "f32") };
let zero_lit = if is_f64 {
"0d0000000000000000"
} else {
"0f00000000"
};
let ta = trans_label(trans);
let kernel_name = format!("complex_gemv_{}_{ta}", T::NAME);
let mut p = String::with_capacity(4096);
wl(&mut p, &format!(".version {}", sm.ptx_version()))?;
wl(&mut p, &format!(".target {}", sm.as_ptx_str()))?;
wl(&mut p, ".address_size 64")?;
wl(&mut p, "")?;
wl(&mut p, &format!(".visible .entry {kernel_name}("))?;
wl(&mut p, " .param .u64 %param_a,")?;
wl(&mut p, " .param .u64 %param_x,")?;
wl(&mut p, " .param .u64 %param_y,")?;
wl(&mut p, " .param .u32 %param_m,")?;
wl(&mut p, " .param .u32 %param_n,")?;
wl(&mut p, " .param .u32 %param_lda,")?;
wl(&mut p, " .param .u32 %param_incx,")?;
wl(&mut p, " .param .u32 %param_incy,")?;
wl(&mut p, &format!(" .param .{ld_ty} %param_alpha_re,"))?;
wl(&mut p, &format!(" .param .{ld_ty} %param_alpha_im,"))?;
wl(&mut p, &format!(" .param .{ld_ty} %param_beta_re,"))?;
wl(&mut p, &format!(" .param .{ld_ty} %param_beta_im,"))?;
wl(&mut p, " .param .u32 %param_output_len,")?;
wl(&mut p, " .param .u32 %param_inner_len")?;
wl(&mut p, ")")?;
wl(&mut p, "{")?;
wl(&mut p, " .reg .b32 %r<32>;")?;
wl(&mut p, " .reg .b64 %rd<16>;")?;
if is_f64 {
wl(&mut p, " .reg .f64 %fd<24>;")?;
} else {
wl(&mut p, " .reg .f32 %f<24>;")?;
}
wl(&mut p, " .reg .pred %p<4>;")?;
wl(&mut p, "")?;
wl(&mut p, " mov.u32 %r0, %tid.x;")?;
wl(&mut p, " mov.u32 %r1, %ctaid.x;")?;
wl(&mut p, " mov.u32 %r2, %ntid.x;")?;
wl(&mut p, " mad.lo.u32 %r3, %r1, %r2, %r0; // gid")?;
wl(&mut p, " ld.param.u32 %r4, [%param_output_len];")?;
wl(&mut p, " setp.ge.u32 %p0, %r3, %r4;")?;
wl(&mut p, " @%p0 bra $CGEMV_DONE;")?;
wl(&mut p, "")?;
wl(&mut p, " ld.param.u64 %rd0, [%param_a];")?;
wl(&mut p, " ld.param.u64 %rd1, [%param_x];")?;
wl(&mut p, " ld.param.u64 %rd2, [%param_y];")?;
wl(&mut p, " ld.param.u32 %r5, [%param_lda];")?;
wl(&mut p, " ld.param.u32 %r6, [%param_incx];")?;
wl(&mut p, " ld.param.u32 %r7, [%param_incy];")?;
wl(&mut p, " ld.param.u32 %r8, [%param_inner_len];")?;
wl(
&mut p,
&format!(" ld.param.{ld_ty} %{fr}16, [%param_alpha_re];"),
)?;
wl(
&mut p,
&format!(" ld.param.{ld_ty} %{fr}17, [%param_alpha_im];"),
)?;
wl(
&mut p,
&format!(" ld.param.{ld_ty} %{fr}18, [%param_beta_re];"),
)?;
wl(
&mut p,
&format!(" ld.param.{ld_ty} %{fr}19, [%param_beta_im];"),
)?;
wl(&mut p, "")?;
wl(
&mut p,
&format!(" mov.{ld_ty} %{fr}0, {zero_lit}; // acc_re"),
)?;
wl(
&mut p,
&format!(" mov.{ld_ty} %{fr}1, {zero_lit}; // acc_im"),
)?;
wl(&mut p, " mov.u32 %r9, 0; // k")?;
wl(&mut p, "")?;
wl(&mut p, "$CGEMV_LOOP:")?;
wl(&mut p, " setp.ge.u32 %p1, %r9, %r8;")?;
wl(&mut p, " @%p1 bra $CGEMV_LOOP_DONE;")?;
let use_trans = matches!(trans, Transpose::Trans | Transpose::ConjTrans);
if !use_trans {
wl(&mut p, " mad.lo.u32 %r10, %r3, %r5, %r9;")?;
} else {
wl(&mut p, " mad.lo.u32 %r10, %r9, %r5, %r3;")?;
}
wl(&mut p, " shl.b32 %r10, %r10, 1;")?;
wl(&mut p, " cvt.u64.u32 %rd3, %r10;")?;
wl(&mut p, &format!(" mul.lo.u64 %rd3, %rd3, {byte_size};"))?;
wl(&mut p, " add.u64 %rd4, %rd0, %rd3;")?;
wl(&mut p, &format!(" ld.global.{ld_ty} %{fr}2, [%rd4];"))?;
wl(
&mut p,
&format!(" ld.global.{ld_ty} %{fr}3, [%rd4+{byte_size}];"),
)?;
if trans == Transpose::ConjTrans {
wl(&mut p, &format!(" neg.{ld_ty} %{fr}3, %{fr}3;"))?;
}
wl(&mut p, " mul.lo.u32 %r11, %r9, %r6;")?;
wl(&mut p, " shl.b32 %r11, %r11, 1;")?;
wl(&mut p, " cvt.u64.u32 %rd5, %r11;")?;
wl(&mut p, &format!(" mul.lo.u64 %rd5, %rd5, {byte_size};"))?;
wl(&mut p, " add.u64 %rd6, %rd1, %rd5;")?;
wl(&mut p, &format!(" ld.global.{ld_ty} %{fr}4, [%rd6];"))?;
wl(
&mut p,
&format!(" ld.global.{ld_ty} %{fr}5, [%rd6+{byte_size}];"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}0, %{fr}2, %{fr}4, %{fr}0;"),
)?;
wl(&mut p, &format!(" neg.{ld_ty} %{fr}6, %{fr}3;"))?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}0, %{fr}6, %{fr}5, %{fr}0;"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}1, %{fr}2, %{fr}5, %{fr}1;"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}1, %{fr}3, %{fr}4, %{fr}1;"),
)?;
wl(&mut p, " add.u32 %r9, %r9, 1;")?;
wl(&mut p, " bra $CGEMV_LOOP;")?;
wl(&mut p, "$CGEMV_LOOP_DONE:")?;
wl(&mut p, "")?;
wl(&mut p, " mul.lo.u32 %r12, %r3, %r7;")?;
wl(&mut p, " shl.b32 %r12, %r12, 1;")?;
wl(&mut p, " cvt.u64.u32 %rd7, %r12;")?;
wl(&mut p, &format!(" mul.lo.u64 %rd7, %rd7, {byte_size};"))?;
wl(&mut p, " add.u64 %rd8, %rd2, %rd7;")?;
wl(&mut p, &format!(" ld.global.{ld_ty} %{fr}10, [%rd8];"))?;
wl(
&mut p,
&format!(" ld.global.{ld_ty} %{fr}11, [%rd8+{byte_size}];"),
)?;
wl(
&mut p,
&format!(" mul.rn.{ld_ty} %{fr}12, %{fr}16, %{fr}0;"),
)?;
wl(&mut p, &format!(" neg.{ld_ty} %{fr}14, %{fr}17;"))?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}12, %{fr}14, %{fr}1, %{fr}12;"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}12, %{fr}18, %{fr}10, %{fr}12;"),
)?;
wl(&mut p, &format!(" neg.{ld_ty} %{fr}15, %{fr}19;"))?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}12, %{fr}15, %{fr}11, %{fr}12;"),
)?;
wl(
&mut p,
&format!(" mul.rn.{ld_ty} %{fr}13, %{fr}16, %{fr}1;"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}13, %{fr}17, %{fr}0, %{fr}13;"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}13, %{fr}18, %{fr}11, %{fr}13;"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}13, %{fr}19, %{fr}10, %{fr}13;"),
)?;
wl(&mut p, &format!(" st.global.{ld_ty} [%rd8], %{fr}12;"))?;
wl(
&mut p,
&format!(" st.global.{ld_ty} [%rd8+{byte_size}], %{fr}13;"),
)?;
wl(&mut p, "")?;
wl(&mut p, "$CGEMV_DONE:")?;
wl(&mut p, " ret;")?;
wl(&mut p, "}")?;
Ok(p)
}
/// Checks that the leading dimension `ld` of matrix `name` is large enough.
///
/// The minimum is `rows` when `trans` is `NoTrans` and `cols` for `Trans` / `ConjTrans`.
///
/// # Errors
/// Returns `BlasError::InvalidDimension` naming the matrix, the supplied `ld` and the required
/// minimum when `ld` is smaller than that minimum.
fn validate_complex_ld(
    trans: Transpose,
    rows: usize,
    cols: usize,
    ld: usize,
    name: &str,
) -> BlasResult<()> {
    let required = if matches!(trans, Transpose::NoTrans) {
        rows
    } else {
        cols
    };
    if ld >= required {
        return Ok(());
    }
    Err(BlasError::InvalidDimension(format!(
        "complex GEMM: ld{name} ({ld}) < required ({required})"
    )))
}
/// Returns the entry-point name of the complex GEMM kernel for element type `T` and the given
/// transpose modes: `complex_gemm_<T::NAME>_<label of A>_<label of B>`.
fn complex_gemm_kernel_name<T: GpuFloat>(transa: Transpose, transb: Transpose) -> String {
    format!(
        "complex_gemm_{}_{}_{}",
        T::NAME,
        trans_label(transa),
        trans_label(transb)
    )
}
/// Maps a transpose mode to the one-letter tag used in kernel names:
/// `"n"` for `NoTrans`, `"t"` for `Trans` and `"c"` for `ConjTrans`.
fn trans_label(t: Transpose) -> &'static str {
    let label = match t {
        Transpose::ConjTrans => "c",
        Transpose::Trans => "t",
        Transpose::NoTrans => "n",
    };
    label
}
fn wl(ptx: &mut String, line: &str) -> BlasResult<()> {
writeln!(ptx, "{line}").map_err(|e| BlasError::PtxGeneration(format!("fmt error: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the calling test unless every fragment occurs in the generated PTX.
    fn assert_ptx_contains(ptx: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(ptx.contains(fragment), "PTX is missing `{fragment}`");
        }
    }

    // Kernel naming.

    #[test]
    fn complex_gemm_kernel_name_f32_nn() {
        assert_eq!(
            complex_gemm_kernel_name::<f32>(Transpose::NoTrans, Transpose::NoTrans),
            "complex_gemm_f32_n_n"
        );
    }

    #[test]
    fn complex_gemm_kernel_name_f64_tc() {
        assert_eq!(
            complex_gemm_kernel_name::<f64>(Transpose::Trans, Transpose::ConjTrans),
            "complex_gemm_f64_t_c"
        );
    }

    // GEMM PTX generation.

    #[test]
    fn generate_complex_gemm_ptx_f32_nn() {
        let ptx = generate_complex_gemm_ptx::<f32>(
            SmVersion::Sm80,
            Transpose::NoTrans,
            Transpose::NoTrans,
        )
        .expect("PTX generation should succeed");
        assert_ptx_contains(
            &ptx,
            &[
                ".entry complex_gemm_f32_n_n",
                "fma.rn.f32",
                "$CGEMM_K_LOOP",
                "$CGEMM_K_DONE",
                "neg.f32",
            ],
        );
    }

    #[test]
    fn generate_complex_gemm_ptx_f64_tt() {
        let ptx =
            generate_complex_gemm_ptx::<f64>(SmVersion::Sm80, Transpose::Trans, Transpose::Trans)
                .expect("PTX generation should succeed");
        assert_ptx_contains(
            &ptx,
            &[".entry complex_gemm_f64_t_t", "fma.rn.f64", ".target sm_80"],
        );
    }

    #[test]
    fn generate_complex_gemm_ptx_conj_trans() {
        let ptx = generate_complex_gemm_ptx::<f32>(
            SmVersion::Sm75,
            Transpose::ConjTrans,
            Transpose::NoTrans,
        )
        .expect("PTX generation should succeed");
        assert_ptx_contains(&ptx, &["complex_gemm_f32_c_n", "neg.f32"]);
    }

    // GEMV PTX generation.

    #[test]
    fn generate_complex_gemv_ptx_f32() {
        let ptx = generate_complex_gemv_ptx::<f32>(SmVersion::Sm80, Transpose::NoTrans)
            .expect("PTX generation should succeed");
        assert_ptx_contains(&ptx, &[".entry complex_gemv_f32_n", "$CGEMV_LOOP"]);
    }

    #[test]
    fn generate_complex_gemv_ptx_f64_trans() {
        let ptx = generate_complex_gemv_ptx::<f64>(SmVersion::Sm80, Transpose::Trans)
            .expect("PTX generation should succeed");
        assert_ptx_contains(&ptx, &[".entry complex_gemv_f64_t", "fma.rn.f64"]);
    }

    // Leading-dimension validation.

    #[test]
    fn validate_complex_ld_ok() {
        let accepted = [(Transpose::NoTrans, 64), (Transpose::Trans, 32)];
        for (trans, ld) in accepted {
            assert!(validate_complex_ld(trans, 64, 32, ld, "A").is_ok());
        }
    }

    #[test]
    fn validate_complex_ld_error() {
        assert!(validate_complex_ld(Transpose::NoTrans, 64, 32, 32, "A").is_err());
    }

    // Error text and labels.

    #[test]
    fn complex_gemm_zero_dim_error() {
        let message =
            BlasError::InvalidDimension("complex GEMM: all dimensions must be non-zero".into())
                .to_string();
        assert!(message.contains("non-zero"));
    }

    #[test]
    fn trans_label_values() {
        let expected = [
            (Transpose::NoTrans, "n"),
            (Transpose::Trans, "t"),
            (Transpose::ConjTrans, "c"),
        ];
        for (mode, label) in expected {
            assert_eq!(trans_label(mode), label);
        }
    }
}