use crate::shared::Dialect;
/// Zero-sized marker type that selects the CUDA flavour of [`Dialect`].
///
/// It carries no data, so it is freely copyable and every value compares equal.
#[derive(Debug, Default, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub struct Cuda;
/// CUDA implementation of [`Dialect`].
///
/// Each method writes a fixed snippet into the given formatter: either an
/// `#include` line (terminated by `\n`) or a bare type name (no trailing newline).
impl Dialect for Cuda {
    /// Writes the `#include` line for CUDA's fp16 header (`cuda_fp16.h`).
    fn include_f16(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("#include <cuda_fp16.h>\n")
    }

    /// Writes the `#include` line for CUDA's bfloat16 header (`cuda_bf16.h`).
    fn include_bf16(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The toolkit header is `cuda_bf16.h` (it declares `__nv_bfloat16` and
        // `__nv_bfloat162`, the names emitted below). This previously emitted
        // the misspelled `cuda_bfp16.h`, which does not exist and breaks nvcc/nvrtc.
        f.write_str("#include <cuda_bf16.h>\n")
    }

    /// Writes the `#include` line for the WMMA / tensor-core header (`mma.h`).
    fn include_wmma(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("#include <mma.h>\n")
    }

    /// Writes the `#include` line for the CUDA runtime header (`cuda_runtime.h`).
    fn include_runtime(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("#include <cuda_runtime.h>\n")
    }

    /// Writes the scalar bfloat16 type name, `__nv_bfloat16` (no newline).
    fn bfloat16_type_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("__nv_bfloat16")
    }

    /// Writes the two-lane bfloat16 vector type name, `__nv_bfloat162` (no newline).
    fn bfloat162_type_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("__nv_bfloat162")
    }
}