use crate::types::PtxType;
/// A single kernel parameter as it appears in a PTX entry signature.
#[derive(Debug, Clone)]
pub enum PtxParam {
    /// A by-value parameter declared with its own PTX type.
    Scalar {
        /// Identifier emitted in the `.param` declaration.
        name: String,
        /// Type whose `ptx_suffix()` is used in the declaration.
        ptx_type: PtxType,
    },
    /// A pointer parameter; `ptx_decl` always declares it as `.u64`.
    Pointer {
        /// Identifier emitted in the `.param` declaration.
        name: String,
        /// Element type associated with the pointer (not part of the emitted declaration).
        elem_type: PtxType,
    },
}
impl PtxParam {
    /// Creates a [`PtxParam::Scalar`] called `name` with type `ptx_type`.
    pub fn scalar(name: &str, ptx_type: PtxType) -> Self {
        let name = String::from(name);
        Self::Scalar { name, ptx_type }
    }

    /// Creates a [`PtxParam::Pointer`] called `name` with element type `elem_type`.
    pub fn pointer(name: &str, elem_type: PtxType) -> Self {
        let name = name.to_owned();
        Self::Pointer { name, elem_type }
    }

    /// Returns the parameter's identifier, whichever variant it is.
    pub fn name(&self) -> &str {
        match self {
            Self::Scalar { name, .. } => name.as_str(),
            Self::Pointer { name, .. } => name.as_str(),
        }
    }

    /// Renders the `.param` declaration for this parameter.
    ///
    /// Scalars are declared with their type's `ptx_suffix()`. Pointers are
    /// always declared as `.u64`, and their element type is not emitted.
    pub fn ptx_decl(&self) -> String {
        match self {
            Self::Pointer { name, .. } => format!(".param .u64 {}", name),
            Self::Scalar { name, ptx_type } => {
                let suffix = ptx_type.ptx_suffix();
                format!(".param {} {}", suffix, name)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A scalar keeps its name and its declared type.
    #[test]
    fn scalar_param() {
        let param = PtxParam::scalar("n", PtxType::U32);
        assert_eq!(param.name(), "n");
        let is_u32_scalar = matches!(
            param,
            PtxParam::Scalar {
                ptx_type: PtxType::U32,
                ..
            }
        );
        assert!(is_u32_scalar, "expected a U32 scalar, got {:?}", param);
    }

    /// A pointer keeps its name and its element type.
    #[test]
    fn pointer_param() {
        let param = PtxParam::pointer("a_ptr", PtxType::F32);
        assert_eq!(param.name(), "a_ptr");
        let is_f32_pointer = matches!(
            param,
            PtxParam::Pointer {
                elem_type: PtxType::F32,
                ..
            }
        );
        assert!(is_f32_pointer, "expected an F32 pointer, got {:?}", param);
    }

    /// Scalar declarations use the type suffix.
    #[test]
    fn scalar_ptx_decl() {
        let decl = PtxParam::scalar("n", PtxType::U32).ptx_decl();
        assert_eq!(decl, ".param .u32 n");
    }

    /// Pointer declarations are `.u64` regardless of element type.
    #[test]
    fn pointer_ptx_decl() {
        let decl = PtxParam::pointer("a_ptr", PtxType::F32).ptx_decl();
        assert_eq!(decl, ".param .u64 a_ptr");
    }
}