use svod_dtype::{AddrSpace, DType, ScalarDType};
use svod_ir::ConstValue;
/// Returns the LLVM IR type spelling for `dtype`.
///
/// * Scalars map through [`ldt_scalar`].
/// * Vectors become `<count x elem>`.
/// * A pointer with `vcount > 1` becomes a vector of opaque pointers, `<vcount x ptr>`.
/// * Every other pointer, and every image, is the opaque `ptr`.
pub fn ldt(dtype: &DType) -> String {
    match dtype {
        DType::Scalar(scalar) => String::from(ldt_scalar(*scalar)),
        DType::Vector { scalar, count } => {
            let elem = ldt_scalar(*scalar);
            format!("<{} x {}>", count, elem)
        }
        DType::Ptr { vcount, .. } if *vcount > 1 => format!("<{} x ptr>", vcount),
        DType::Ptr { .. } | DType::Image { .. } => String::from("ptr"),
    }
}
/// Returns the LLVM IR spelling of a scalar type.
///
/// LLVM integer types are signless, so signed and unsigned types of equal width share one
/// name, and `Index` is rendered as `i64`. The FP8 formats have no LLVM type and are carried
/// as raw `i8` bytes.
fn ldt_scalar(s: ScalarDType) -> &'static str {
    match s {
        ScalarDType::Void => "void",
        ScalarDType::Bool => "i1",
        ScalarDType::Int8
        | ScalarDType::UInt8
        | ScalarDType::FP8E4M3
        | ScalarDType::FP8E5M2 => "i8",
        ScalarDType::Int16 | ScalarDType::UInt16 => "i16",
        ScalarDType::Int32 | ScalarDType::UInt32 => "i32",
        ScalarDType::Index | ScalarDType::Int64 | ScalarDType::UInt64 => "i64",
        ScalarDType::BFloat16 => "bfloat",
        ScalarDType::Float16 => "half",
        ScalarDType::Float32 => "float",
        ScalarDType::Float64 => "double",
    }
}
/// Renders a constant as an LLVM IR literal.
///
/// Signed integers print in decimal. Unsigned integers are cast with `as i64` first, so a
/// 64-bit value above `i64::MAX` prints as its two's-complement negative. Booleans print as
/// `1` / `0`. Floats defer to [`format_float`], which uses `dtype` to pick the literal format.
pub fn lconst(val: &ConstValue, dtype: &DType) -> String {
    match val {
        ConstValue::Bool(true) => "1".to_string(),
        ConstValue::Bool(false) => "0".to_string(),
        ConstValue::Float(f) => format_float(*f, dtype),
        ConstValue::UInt(u) => {
            let wrapped = *u as i64;
            wrapped.to_string()
        }
        ConstValue::Int(i) => i.to_string(),
    }
}
/// Formats the floating-point constant `f` as an LLVM IR literal for `dtype`'s base scalar.
///
/// * `Float64`: `0x` followed by the 16 hex digits of the IEEE-754 double bit pattern.
/// * `Float32`: `f` is first rounded to `f32`, then printed in the same double hex form.
///   LLVM only accepts a hex `float` literal whose value is exactly representable as `float`.
/// * `Float16`: `0xH` followed by the 4 hex digits of the binary16 pattern.
/// * `BFloat16`: `0xR` followed by the 4 hex digits of the bfloat16 pattern.
/// * Any other scalar: Rust's `{:e}` text, or `nan` / `inf` / `-inf`.
///
/// NaN is emitted as the canonical quiet NaN of the target format (payload and sign are not
/// preserved). The half and bfloat paths both round to nearest-even, but they go through
/// `f32` first. An `f64` lying within half an `f32` ULP of a half/bfloat tie can therefore
/// be double-rounded.
fn format_float(f: f64, dtype: &DType) -> String {
    let scalar = dtype.base();
    if f.is_nan() {
        return match scalar {
            ScalarDType::Float64 | ScalarDType::Float32 => "0x7FF8000000000000".to_string(),
            ScalarDType::Float16 => "0xH7E00".to_string(),
            ScalarDType::BFloat16 => "0xR7FC0".to_string(),
            _ => "nan".to_string(),
        };
    }
    if f.is_infinite() {
        return match scalar {
            // The f64 infinity pattern is also an exact (hence valid) hex literal for `float`.
            ScalarDType::Float64 | ScalarDType::Float32 => {
                format!("0x{:016X}", f.to_bits())
            }
            ScalarDType::Float16 => {
                if f.is_sign_positive() { "0xH7C00".to_string() } else { "0xHFC00".to_string() }
            }
            ScalarDType::BFloat16 => {
                if f.is_sign_positive() { "0xR7F80".to_string() } else { "0xRFF80".to_string() }
            }
            _ => {
                if f.is_sign_positive() {
                    "inf".to_string()
                } else {
                    "-inf".to_string()
                }
            }
        };
    }
    match scalar {
        ScalarDType::Float64 => format!("0x{:016X}", f.to_bits()),
        ScalarDType::Float32 => {
            // Round to f32 so the literal is exactly representable; widening back is exact.
            let widened = (f as f32) as f64;
            format!("0x{:016X}", widened.to_bits())
        }
        ScalarDType::Float16 => format!("0xH{:04X}", f32_to_f16_bits(f as f32)),
        ScalarDType::BFloat16 => format!("0xR{:04X}", f32_to_bf16_bits(f as f32)),
        _ => format!("{:e}", f),
    }
}

/// Rounds an `f32` to the nearest bfloat16 (ties to even) and returns its bit pattern.
///
/// bfloat16 is the upper half of an `f32`. Rounding is therefore "add just under half of the
/// dropped 16 bits, plus one when the kept half is odd, then truncate". Finite values past
/// the bfloat16 range correctly round up to infinity. NaN keeps its sign and is forced quiet.
fn f32_to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Rounding could carry a NaN into infinity (or overflow `u32`); just quiet it.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // For every non-NaN input `bits <= 0xFF80_0000`, so this addition cannot overflow.
    let round_up = 0x7FFF + ((bits >> 16) & 1);
    ((bits + round_up) >> 16) as u16
}
/// Converts an `f32` to IEEE-754 binary16 and returns its bit pattern. Rounds to nearest,
/// ties to even.
///
/// * Infinities keep their sign. Every NaN becomes the quiet NaN `0x7E00` with the input's
///   sign (the payload is discarded).
/// * Finite values whose magnitude rounds past the largest half (65504) become infinity;
///   that is every magnitude at or above 65520.
/// * Values below the normal half range become subnormals.
/// * Magnitudes at or below 2^-25 (half the smallest subnormal) become a signed zero. The
///   exact 2^-25 tie goes to the even result, zero.
fn f32_to_f16_bits(f: f32) -> u16 {
    let bits = f.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mant = bits & 0x007F_FFFF;

    if exp == 0xFF {
        return sign | if mant == 0 { 0x7C00 } else { 0x7E00 };
    }

    // Re-bias the exponent from f32 (bias 127) to f16 (bias 15).
    let half_exp = exp - 127 + 15;
    if half_exp >= 0x1F {
        // Too large for any finite half, even before rounding.
        return sign | 0x7C00;
    }
    if half_exp <= 0 {
        // Subnormal half (or zero). Below 2^-25 nothing can round up to the smallest
        // subnormal. This also covers f32 zeros/subnormals (exp == 0), which lack the
        // implicit bit.
        if half_exp < -10 {
            return sign;
        }
        // Restore the implicit bit and shift so the result is in units of 2^-24.
        let significand = mant | 0x0080_0000;
        let shift = (14 - half_exp) as u32;
        // A round-up carry out of the largest subnormal (0x3FF -> 0x400) is exactly the
        // smallest normal encoding, so it needs no special case.
        return sign | round_shift_right_even(significand, shift) as u16;
    }
    // Normal half: drop 13 mantissa bits with rounding. Adding (rather than OR-ing) the
    // rounded mantissa lets a carry bump the exponent. A carry out of exponent 30 yields
    // exactly the infinity encoding 0x7C00.
    let magnitude = ((half_exp as u32) << 10) + round_shift_right_even(mant, 13);
    sign | magnitude as u16
}

/// Computes `value >> shift`, rounded to nearest with ties to even. `shift` must be in `1..=31`.
fn round_shift_right_even(value: u32, shift: u32) -> u32 {
    let kept = value >> shift;
    let dropped = value & ((1 << shift) - 1);
    let half = 1 << (shift - 1);
    if dropped > half || (dropped == half && kept & 1 == 1) {
        kept + 1
    } else {
        kept
    }
}
/// Chooses the LLVM IR cast opcode for converting a value of type `from` to type `to`.
///
/// The checks run in a fixed order and the first match wins:
/// 1. Either side is `DType::Ptr`: `bitcast` (ptr -> ptr), `ptrtoint` (ptr -> non-ptr),
///    `inttoptr` (non-ptr -> ptr).
/// 2. Float -> float: `fpext` if `to` has more bytes than `from`, otherwise `fptrunc`.
/// 3. Unsigned or bool -> float: `uitofp`; signed or `Index` -> float: `sitofp`.
/// 4. Float -> unsigned: `fptoui`; float -> signed or `Index`: `fptosi`.
/// 5. Bool -> non-bool: `zext`; non-bool -> bool: `trunc`.
/// 6. Same byte width: `bitcast`; narrower destination: `trunc`.
/// 7. Wider destination: `zext` from unsigned/bool, `sext` from signed/`Index`.
/// 8. Anything left: `bitcast`.
///
/// Apart from the pointer check, only the base scalar types are consulted. Vector lane
/// counts are not compared.
///
/// NOTE(review): two float types of equal byte width fall into step 2 as `fptrunc`. That
/// covers identical types and presumably `Float16` <-> `BFloat16`. LLVM only accepts
/// `fptrunc` for a strictly narrower destination, so confirm callers never request such casts.
/// NOTE(review): `DType::Image` renders as `ptr` in `ldt`, but it is not treated as a pointer
/// here. Confirm images never reach this function.
///
/// # Panics
///
/// In debug builds, panics if either base scalar is an FP8 type. FP8 values are rendered as
/// `i8`, so those casts must be decomposed before they get here.
pub fn lcast(from: &DType, to: &DType) -> &'static str {
    let from_scalar = from.base();
    let to_scalar = to.base();
    debug_assert!(
        !(from_scalar.is_fp8() || to_scalar.is_fp8()),
        "lcast does not support FP8 (mapped to i8); decompose via devectorize fp8 patterns first"
    );
    // Pointer casts are decided by the outer DType variant, not by the base scalar.
    if matches!(from, DType::Ptr { .. }) || matches!(to, DType::Ptr { .. }) {
        return if matches!(from, DType::Ptr { .. }) && matches!(to, DType::Ptr { .. }) {
            "bitcast"
        } else if matches!(from, DType::Ptr { .. }) {
            "ptrtoint"
        } else {
            "inttoptr"
        };
    }
    if from_scalar.is_float() && to_scalar.is_float() {
        return if to_scalar.bytes() > from_scalar.bytes() { "fpext" } else { "fptrunc" };
    }
    // Int -> float: the source's signedness picks the opcode. `Index` counts as signed.
    if (from_scalar.is_unsigned() || from_scalar.is_bool()) && to_scalar.is_float() {
        return "uitofp";
    }
    if (from_scalar.is_signed() || from_scalar == ScalarDType::Index) && to_scalar.is_float() {
        return "sitofp";
    }
    // Float -> int: the destination's signedness picks the opcode.
    if from_scalar.is_float() && to_scalar.is_unsigned() {
        return "fptoui";
    }
    if from_scalar.is_float() && (to_scalar.is_signed() || to_scalar == ScalarDType::Index) {
        return "fptosi";
    }
    let from_bytes = from_scalar.bytes();
    let to_bytes = to_scalar.bytes();
    // Bool is `i1`, so it is handled before the byte-width comparison. Its byte size would
    // not reflect its 1-bit LLVM width.
    if from_scalar.is_bool() && !to_scalar.is_bool() {
        return "zext";
    }
    // NOTE(review): `trunc` to `i1` keeps only the low bit (2 -> false), not `!= 0`.
    // Presumably callers lower value-preserving bool casts to comparisons; verify.
    if !from_scalar.is_bool() && to_scalar.is_bool() {
        return "trunc";
    }
    // LLVM integers are signless, so same-width int <-> int needs no real conversion.
    if from_bytes == to_bytes {
        return "bitcast";
    }
    if to_bytes < from_bytes {
        return "trunc";
    }
    // Widening: extension kind follows the source's signedness.
    if from_scalar.is_unsigned() || from_scalar.is_bool() {
        return "zext";
    }
    if from_scalar.is_signed() || from_scalar == ScalarDType::Index {
        return "sext";
    }
    "bitcast"
}
/// Maps an [`AddrSpace`] to the numeric address space written as `addrspace(N)` in LLVM IR.
///
/// `Global` is LLVM's default address space 0, `Local` is 3 and `Reg` is 5.
/// NOTE(review): 3 and 5 look like the AMDGPU numbering (LDS / private). Confirm before
/// reusing this for other targets.
pub fn addr_space_num(addrspace: AddrSpace) -> u32 {
    match addrspace {
        AddrSpace::Reg => 5,
        AddrSpace::Local => 3,
        AddrSpace::Global => 0,
    }
}
#[cfg(test)]
#[path = "../../test/unit/llvm_common_types.rs"]
mod tests;