use std::ffi::c_void;
use std::ptr;
use cudarc::cublaslt::sys::{cublasLtMatmulDescAttributes_t, cublasLtMatmulDesc_t};
use crate::sys::cublaslt::set_desc_pointer_attr;
/// Optional scale-factor pointers for the A, B, C and D operands of a cuBLASLt
/// matmul descriptor.
///
/// When a field is `Some`, [`ScaleSet::apply`] writes it to the matching
/// `CUBLASLT_MATMUL_DESC_{A,B,C,D}_SCALE_POINTER` attribute. When it is `None`,
/// that attribute is not touched. Only the addresses are stored; `ScaleSet`
/// itself never reads through them.
/// NOTE(review): cuBLASLt documents these attributes as device pointers to a
/// single scale value — confirm callers always pass device memory.
#[derive(Debug, Clone, Copy, Default)]
pub struct ScaleSet {
/// Scale pointer for matrix A (`CUBLASLT_MATMUL_DESC_A_SCALE_POINTER`).
pub a: Option<*const f32>,
/// Scale pointer for matrix B (`CUBLASLT_MATMUL_DESC_B_SCALE_POINTER`).
pub b: Option<*const f32>,
/// Scale pointer for matrix C (`CUBLASLT_MATMUL_DESC_C_SCALE_POINTER`).
pub c: Option<*const f32>,
/// Scale pointer for matrix D (`CUBLASLT_MATMUL_DESC_D_SCALE_POINTER`).
pub d: Option<*const f32>,
}
// SAFETY: `ScaleSet` only carries raw pointer *values* and never dereferences
// them, so moving or sharing those addresses across threads is harmless in
// itself. Whoever eventually reads through them (cuBLASLt, after `apply`) is
// responsible for the pointees' lifetime and synchronization.
unsafe impl Send for ScaleSet {}
// SAFETY: see the `Send` impl above — `&ScaleSet` exposes only `Copy` pointer
// values and no interior mutability.
unsafe impl Sync for ScaleSet {}
impl ScaleSet {
    /// Returns a set with every scale pointer unset.
    ///
    /// Field-for-field identical to `ScaleSet::default()`, but usable in `const` contexts.
    #[must_use]
    pub const fn empty() -> Self {
        Self {
            a: None,
            b: None,
            c: None,
            d: None,
        }
    }

    /// Returns `self` with the A-operand scale pointer set to `ptr`.
    ///
    /// The pointer is stored as given: a null `ptr` still yields `Some(null)`.
    /// It is never dereferenced by `ScaleSet`.
    #[must_use]
    pub fn with_a(mut self, ptr: *const f32) -> Self {
        self.a = Some(ptr);
        self
    }

    /// Returns `self` with the B-operand scale pointer set to `ptr` (stored as given).
    #[must_use]
    pub fn with_b(mut self, ptr: *const f32) -> Self {
        self.b = Some(ptr);
        self
    }

    /// Returns `self` with the C-operand scale pointer set to `ptr` (stored as given).
    #[must_use]
    pub fn with_c(mut self, ptr: *const f32) -> Self {
        self.c = Some(ptr);
        self
    }

    /// Returns `self` with the D-operand scale pointer set to `ptr` (stored as given).
    #[must_use]
    pub fn with_d(mut self, ptr: *const f32) -> Self {
        self.d = Some(ptr);
        self
    }

    /// Returns `true` when none of `a`, `b`, `c`, `d` has been set.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.a.is_none() && self.b.is_none() && self.c.is_none() && self.d.is_none()
    }

    /// Writes every set scale pointer into the matmul descriptor `desc`.
    ///
    /// Fields are processed in the order A, B, C, D. `None` fields are skipped,
    /// leaving that descriptor attribute as it was.
    ///
    /// # Errors
    ///
    /// Returns the first `Err` produced by `set_desc_pointer_attr` and stops
    /// there. Attributes already written before the failure are not rolled back.
    ///
    /// # Safety
    ///
    /// - `desc` must be a valid, live cuBLASLt matmul descriptor handle, and the
    ///   caller must uphold whatever contract `set_desc_pointer_attr` requires.
    /// - Only the pointer addresses are recorded in `desc`, so each stored
    ///   pointer must stay valid for as long as `desc` is used to run a matmul.
    ///   NOTE(review): cuBLASLt expects these to point to device memory — confirm.
    pub unsafe fn apply(&self, desc: cublasLtMatmulDesc_t) -> Result<(), String> {
        // Keep the A, B, C, D order: it decides which error surfaces first when
        // several attributes would fail.
        let entries = [
            (self.a, cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER),
            (self.b, cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER),
            (self.c, cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER),
            (self.d, cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_D_SCALE_POINTER),
        ];
        for (scale, attr) in entries {
            if let Some(scale_ptr) = scale {
                // SAFETY: the caller upholds this function's `# Safety` contract
                // (valid `desc`, pointers outliving its use). The pointer is only
                // passed by address here, never read.
                unsafe { set_desc_pointer_attr(desc, attr, scale_ptr.cast::<c_void>())? };
            }
        }
        Ok(())
    }
}
/// Returns a null `*const f32`.
///
/// This is a `const fn`, like `ScaleSet::empty`, so it can also initialise
/// `const`/`static` items.
///
/// Passing the result to a `ScaleSet::with_*` builder stores `Some(null)`.
/// `ScaleSet::apply` forwards that as an explicit null attribute value, which is
/// different from leaving the field `None` (attribute untouched).
#[must_use]
pub const fn null_scale_ptr() -> *const f32 {
    ptr::null()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `default()` (derived) and `empty()` (const) both produce an all-`None` set.
    #[test]
    fn scale_set_empty_default() {
        let s = ScaleSet::default();
        assert!(s.is_empty());
        assert!(s.a.is_none());
        let e = ScaleSet::empty();
        assert!(e.is_empty());
        assert!(e.a.is_none() && e.b.is_none() && e.c.is_none() && e.d.is_none());
    }

    /// Builders set only the field they name.
    #[test]
    fn scale_set_builders() {
        let a: f32 = 1.5;
        let s = ScaleSet::empty()
            .with_a(&a as *const f32)
            .with_d(&a as *const f32);
        assert!(!s.is_empty());
        assert!(s.a.is_some());
        assert!(s.b.is_none());
        assert!(s.c.is_none());
        assert!(s.d.is_some());
    }

    /// `is_empty` must consult every field. The other tests only ever set A and D,
    /// so a regression dropping B or C from the check would otherwise go unnoticed.
    #[test]
    fn is_empty_detects_each_field() {
        let v: f32 = 0.5;
        let p = &v as *const f32;
        for s in [
            ScaleSet::empty().with_a(p),
            ScaleSet::empty().with_b(p),
            ScaleSet::empty().with_c(p),
            ScaleSet::empty().with_d(p),
        ] {
            assert!(!s.is_empty());
        }
    }

    /// A null pointer passed to a builder is still "set" (`Some(null)`), not `None`.
    #[test]
    fn null_pointer_counts_as_set() {
        let s = ScaleSet::empty().with_b(null_scale_ptr());
        assert!(!s.is_empty());
        assert_eq!(s.b, Some(null_scale_ptr()));
    }

    /// Checks the cuBLASLt attribute discriminants `apply` relies on and that the
    /// builders store exactly the pointers given.
    ///
    /// `apply` itself is not called here: it needs a real `cublasLtMatmulDesc_t`,
    /// which this test does not create.
    #[test]
    fn scale_pointer_attribute_setting() {
        use cudarc::cublaslt::sys::cublasLtMatmulDescAttributes_t as Attr;
        assert_eq!(Attr::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER as u32, 17);
        assert_eq!(Attr::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER as u32, 18);
        assert_eq!(Attr::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER as u32, 19);
        assert_eq!(Attr::CUBLASLT_MATMUL_DESC_D_SCALE_POINTER as u32, 20);
        let a_scale: f32 = 1.0;
        let b_scale: f32 = 2.0;
        let c_scale: f32 = 3.0;
        let d_scale: f32 = 4.0;
        let s = ScaleSet::empty()
            .with_a(&a_scale as *const f32)
            .with_b(&b_scale as *const f32)
            .with_c(&c_scale as *const f32)
            .with_d(&d_scale as *const f32);
        assert_eq!(s.a, Some(&a_scale as *const f32));
        assert_eq!(s.b, Some(&b_scale as *const f32));
        assert_eq!(s.c, Some(&c_scale as *const f32));
        assert_eq!(s.d, Some(&d_scale as *const f32));
        let empty = ScaleSet::empty();
        assert!(empty.is_empty());
    }

    #[test]
    fn null_scale_ptr_is_null() {
        let p = null_scale_ptr();
        assert!(p.is_null());
    }
}