use proc_macro2::Span;
/// A type that may appear in a kernel signature: either a scalar or a
/// (shared or mutable) slice whose element type is boxed inside the variant.
#[derive(Debug, Clone, PartialEq, Eq)]
#[allow(dead_code)]
pub enum KernelType {
    /// 32-bit floating point (`f32`).
    F32,
    /// 64-bit floating point (`f64`).
    F64,
    /// Signed 32-bit integer (`i32`).
    I32,
    /// Unsigned 32-bit integer (`u32`).
    U32,
    /// Signed 64-bit integer (`i64`).
    I64,
    /// Unsigned 64-bit integer (`u64`).
    U64,
    /// Boolean (`bool`).
    Bool,
    /// Shared slice `&[T]`; the boxed value is the element type `T`.
    SliceRef(Box<KernelType>),
    /// Mutable slice `&mut [T]`; the boxed value is the element type `T`.
    SliceMutRef(Box<KernelType>),
}
#[allow(dead_code)]
impl KernelType {
    /// Returns the element type of a slice type, or `None` for any scalar type.
    pub fn elem_type(&self) -> Option<&KernelType> {
        match self {
            KernelType::SliceRef(inner) | KernelType::SliceMutRef(inner) => Some(inner),
            // Scalars are listed explicitly instead of using `_` so that adding a
            // new variant forces this match to be revisited at compile time.
            KernelType::F32
            | KernelType::F64
            | KernelType::I32
            | KernelType::U32
            | KernelType::I64
            | KernelType::U64
            | KernelType::Bool => None,
        }
    }

    /// Returns `true` for both shared (`&[T]`) and mutable (`&mut [T]`) slices.
    pub fn is_slice(&self) -> bool {
        matches!(self, KernelType::SliceRef(_) | KernelType::SliceMutRef(_))
    }

    /// Returns `true` for every non-slice variant (numeric types and `Bool`).
    pub fn is_scalar(&self) -> bool {
        !self.is_slice()
    }

    /// Returns `true` for the signed and unsigned 32/64-bit integer variants.
    ///
    /// `Bool` and the floating-point variants are not integers.
    pub fn is_integer(&self) -> bool {
        matches!(
            self,
            KernelType::I32 | KernelType::U32 | KernelType::I64 | KernelType::U64
        )
    }

    /// Returns `true` only for mutable slices (`&mut [T]`).
    pub fn is_mut_slice(&self) -> bool {
        matches!(self, KernelType::SliceMutRef(_))
    }

    /// Size in bytes of a scalar value of this type.
    ///
    /// # Panics
    ///
    /// Panics if called on a slice type; slices have no fixed scalar size.
    pub fn size_bytes(&self) -> usize {
        match self {
            KernelType::F32 | KernelType::I32 | KernelType::U32 => 4,
            KernelType::F64 | KernelType::I64 | KernelType::U64 => 8,
            KernelType::Bool => 1,
            KernelType::SliceRef(_) | KernelType::SliceMutRef(_) => {
                panic!("size_bytes() called on slice type")
            }
        }
    }

    /// Token naming this scalar's PTX-style type.
    ///
    /// Signed integers use the `S` prefix (`S32`, `S64`), unsigned use `U`, floats
    /// use `F`, and `Bool` maps to `"Pred"`. NOTE(review): these tokens presumably
    /// name variants of a type enum in the generated code — confirm with the caller.
    ///
    /// # Panics
    ///
    /// Panics if called on a slice type.
    pub fn ptx_type_token(&self) -> &'static str {
        match self {
            KernelType::F32 => "F32",
            KernelType::F64 => "F64",
            KernelType::I32 => "S32",
            KernelType::U32 => "U32",
            KernelType::I64 => "S64",
            KernelType::U64 => "U64",
            KernelType::Bool => "Pred",
            KernelType::SliceRef(_) | KernelType::SliceMutRef(_) => {
                panic!("ptx_type_token() called on slice type")
            }
        }
    }

    /// Rust-source spelling of this type, e.g. `"f32"`, `"&[f32]"`, `"&mut [u32]"`.
    ///
    /// Slice element types are rendered recursively, so nested slices are supported.
    pub fn display_name(&self) -> String {
        match self {
            KernelType::F32 => "f32".to_string(),
            KernelType::F64 => "f64".to_string(),
            KernelType::I32 => "i32".to_string(),
            KernelType::U32 => "u32".to_string(),
            KernelType::I64 => "i64".to_string(),
            KernelType::U64 => "u64".to_string(),
            KernelType::Bool => "bool".to_string(),
            KernelType::SliceRef(inner) => format!("&[{}]", inner.display_name()),
            KernelType::SliceMutRef(inner) => format!("&mut [{}]", inner.display_name()),
        }
    }
}
/// One parameter of a kernel function: its name, declared type, and source span.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct KernelParam {
    /// Parameter identifier as written in the signature.
    pub name: String,
    /// Declared type of the parameter.
    pub ty: KernelType,
    /// Source location of the parameter (presumably used for diagnostics — confirm).
    pub span: Span,
}
/// Block-size configuration attached to a kernel.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct KernelConfig {
    /// Primary block size (presumably the x dimension, given `block_size_y` — confirm).
    pub block_size: u32,
    /// Optional second block dimension; `None` when it was not supplied.
    pub block_size_y: Option<u32>,
    /// Source location of the block-size setting, e.g. for error reporting.
    pub block_size_span: Span,
}
/// Describes a kernel: its name, parameter list, configuration, and name span.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct KernelSignature {
    /// Kernel function name.
    pub name: String,
    /// Kernel parameters (presumably in declaration order — confirm with the parser).
    pub params: Vec<KernelParam>,
    /// Block-size configuration for this kernel.
    pub config: KernelConfig,
    /// Source location of the kernel name.
    pub name_span: Span,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every non-slice variant, used to sweep scalar-only behavior.
    fn all_scalars() -> Vec<KernelType> {
        vec![
            KernelType::F32,
            KernelType::F64,
            KernelType::I32,
            KernelType::U32,
            KernelType::I64,
            KernelType::U64,
            KernelType::Bool,
        ]
    }

    #[test]
    fn kernel_type_size_bytes() {
        assert_eq!(KernelType::F32.size_bytes(), 4);
        assert_eq!(KernelType::F64.size_bytes(), 8);
        assert_eq!(KernelType::I32.size_bytes(), 4);
        assert_eq!(KernelType::U32.size_bytes(), 4);
        assert_eq!(KernelType::I64.size_bytes(), 8);
        assert_eq!(KernelType::U64.size_bytes(), 8);
        assert_eq!(KernelType::Bool.size_bytes(), 1);
    }

    #[test]
    #[should_panic(expected = "size_bytes() called on slice type")]
    fn size_bytes_panics_on_slice() {
        let _ = KernelType::SliceRef(Box::new(KernelType::F32)).size_bytes();
    }

    #[test]
    fn kernel_type_ptx_token() {
        assert_eq!(KernelType::F32.ptx_type_token(), "F32");
        assert_eq!(KernelType::F64.ptx_type_token(), "F64");
        assert_eq!(KernelType::I32.ptx_type_token(), "S32");
        assert_eq!(KernelType::U32.ptx_type_token(), "U32");
        assert_eq!(KernelType::I64.ptx_type_token(), "S64");
        assert_eq!(KernelType::U64.ptx_type_token(), "U64");
        assert_eq!(KernelType::Bool.ptx_type_token(), "Pred");
    }

    #[test]
    #[should_panic(expected = "ptx_type_token() called on slice type")]
    fn ptx_token_panics_on_slice() {
        let _ = KernelType::SliceMutRef(Box::new(KernelType::F32)).ptx_type_token();
    }

    #[test]
    fn integer_classification() {
        assert!(KernelType::I32.is_integer());
        assert!(KernelType::U32.is_integer());
        assert!(KernelType::I64.is_integer());
        assert!(KernelType::U64.is_integer());
        assert!(!KernelType::F32.is_integer());
        assert!(!KernelType::F64.is_integer());
        assert!(!KernelType::Bool.is_integer());
        assert!(!KernelType::SliceRef(Box::new(KernelType::I32)).is_integer());
    }

    #[test]
    fn scalar_type_properties() {
        for ty in all_scalars() {
            assert!(ty.is_scalar(), "{:?} should be scalar", ty);
            assert!(!ty.is_slice(), "{:?} should not be a slice", ty);
            assert!(!ty.is_mut_slice(), "{:?} should not be a mut slice", ty);
            assert_eq!(ty.elem_type(), None);
        }
    }

    #[test]
    fn slice_type_properties() {
        let slice = KernelType::SliceRef(Box::new(KernelType::F32));
        assert!(slice.is_slice());
        assert!(!slice.is_scalar());
        assert!(!slice.is_mut_slice());
        assert_eq!(slice.elem_type(), Some(&KernelType::F32));
        let mut_slice = KernelType::SliceMutRef(Box::new(KernelType::F64));
        assert!(mut_slice.is_slice());
        assert!(!mut_slice.is_scalar());
        assert!(mut_slice.is_mut_slice());
        assert_eq!(mut_slice.elem_type(), Some(&KernelType::F64));
    }

    #[test]
    fn display_names() {
        assert_eq!(KernelType::F32.display_name(), "f32");
        assert_eq!(KernelType::F64.display_name(), "f64");
        assert_eq!(KernelType::I32.display_name(), "i32");
        assert_eq!(KernelType::U32.display_name(), "u32");
        assert_eq!(KernelType::I64.display_name(), "i64");
        assert_eq!(KernelType::U64.display_name(), "u64");
        assert_eq!(KernelType::Bool.display_name(), "bool");
        assert_eq!(
            KernelType::SliceRef(Box::new(KernelType::F32)).display_name(),
            "&[f32]"
        );
        assert_eq!(
            KernelType::SliceMutRef(Box::new(KernelType::U32)).display_name(),
            "&mut [u32]"
        );
        // Element types are rendered recursively, so nesting composes.
        assert_eq!(
            KernelType::SliceRef(Box::new(KernelType::SliceMutRef(Box::new(KernelType::I64))))
                .display_name(),
            "&[&mut [i64]]"
        );
    }
}