use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
use super::Builder;
use crate::abi::ConvSpirvType;
use crate::builder::format_args_decompiler::{CodegenPanic, DecodedFormatArgs};
use crate::builder_spirv::{
SpirvBlockCursor, SpirvConst, SpirvValue, SpirvValueExt, SpirvValueKind,
};
use crate::codegen_cx::CodegenCx;
use crate::custom_insts::CustomInst;
use crate::spirv_type::SpirvType;
use itertools::Itertools;
use rspirv::dr::{InsertPoint, Instruction, Operand};
use rspirv::spirv::{Capability, MemoryModel, MemorySemantics, Op, Scope, StorageClass, Word};
use rustc_abi::{Align, BackendRepr, Scalar, Size, WrappingRange};
use rustc_apfloat::{Float, Round, Status, ieee};
use rustc_codegen_ssa::MemFlags;
use rustc_codegen_ssa::common::{
AtomicRmwBinOp, IntPredicate, RealPredicate, SynchronizationScope, TypeKind,
};
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::{
BackendTypes, BaseTypeCodegenMethods, BuilderMethods, ConstCodegenMethods,
LayoutTypeCodegenMethods, OverflowOp,
};
use rustc_middle::bug;
use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::{self, AtomicOrdering, Ty};
use rustc_span::Span;
use rustc_target::callconv::FnAbi;
use smallvec::SmallVec;
use std::iter::{self, empty};
use std::ops::{BitAnd, BitOr, BitXor, Not, RangeInclusive};
use tracing::{Level, instrument, span};
use tracing::{trace, warn};
/// A scalar constant recovered from a `SpirvConst`, normalized by signedness.
///
/// Used by the `simple_op!`/`simple_shift_op!`/`simple_uni_op!` macros below
/// to constant-fold arithmetic at codegen time.
enum ConstValue {
// Unsigned integer constant (value fits in the type's bit width).
Unsigned(u128),
// Signed integer constant, sign-extended from its SPIR-V bit width.
Signed(i128),
// Boolean constant.
Bool(bool),
}
impl<'a, 'tcx> Builder<'a, 'tcx> {
    /// Attempts to recover a foldable scalar constant from `val`.
    ///
    /// Returns `None` unless `val` is a known scalar constant of integer or
    /// bool type whose bit pattern is valid for that type (no stray high
    /// bits, and only 0/1 for bools).
    fn try_get_const_value(&self, val: SpirvValue) -> Option<ConstValue> {
        // Only plain scalar constants participate in folding.
        let scalar = match self.builder.lookup_const(val)? {
            SpirvConst::Scalar(x) => x,
            _ => return None,
        };
        match self.lookup_type(val.ty) {
            SpirvType::Integer(bits, signed) => {
                let size = Size::from_bits(bits);
                // Reject values carrying bits above the type's width.
                (scalar == size.truncate(scalar)).then(|| {
                    if signed {
                        ConstValue::Signed(size.sign_extend(scalar))
                    } else {
                        ConstValue::Unsigned(size.truncate(scalar))
                    }
                })
            }
            SpirvType::Bool => match scalar {
                0 => Some(ConstValue::Bool(false)),
                1 => Some(ConstValue::Bool(true)),
                // Any other bit pattern is not a valid bool constant.
                _ => None,
            },
            _ => None,
        }
    }
}
/// Generates a binary-operator method for the `BuilderMethods` impl below.
///
/// The generated method asserts both operands share a type, optionally
/// const-folds when both operands are recoverable scalar constants (via the
/// `fold_const { ... }` section), and otherwise dispatches on the result type
/// to the matching SPIR-V instruction. `int` covers both signednesses, while
/// `uint`/`sint`/`float`/`bool` match narrower classes.
macro_rules! simple_op {
(
$func_name:ident
$(, int: $inst_int:ident)?
$(, uint: $inst_uint:ident)?
$(, sint: $inst_sint:ident)?
$(, float: $inst_float:ident)?
$(, bool: $inst_bool:ident)?
$(, fold_const {
$(int($int_lhs:ident, $int_rhs:ident) => $fold_int:expr;)?
$(uint($uint_lhs:ident, $uint_rhs:ident) => $fold_uint:expr;)?
$(sint($sint_lhs:ident, $sint_rhs:ident) => $fold_sint:expr;)?
$(bool($bool_lhs:ident, $bool_rhs:ident) => $fold_bool:expr;)?
})?
) => {
fn $func_name(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
assert_ty_eq!(self, lhs.ty, rhs.ty);
let result_type = lhs.ty;
// Constant-folding fast path: only present when `fold_const` was given.
$(
#[allow(unreachable_patterns, clippy::collapsible_match)]
if let Some(const_lhs) = self.try_get_const_value(lhs)
&& let Some(const_rhs) = self.try_get_const_value(rhs)
{
// The immediately-invoked closure lets fold arms bail with `return None`.
let result = (|| Some(match (const_lhs, const_rhs) {
$(
(ConstValue::Unsigned($int_lhs), ConstValue::Unsigned($int_rhs)) => $fold_int,
(ConstValue::Signed($int_lhs), ConstValue::Signed($int_rhs)) => $fold_int as u128,
)?
$((ConstValue::Unsigned($uint_lhs), ConstValue::Unsigned($uint_rhs)) => $fold_uint,)?
$((ConstValue::Signed($sint_lhs), ConstValue::Signed($sint_rhs)) => $fold_sint as u128, )?
$((ConstValue::Bool($bool_lhs), ConstValue::Bool($bool_rhs)) => ($fold_bool).into(), )?
_ => return None,
}))();
if let Some(result) = result {
// Folded results are materialized as unsigned constants of the
// result type (the bit pattern carries signedness, if any).
return self.const_uint_big(result_type, result);
}
}
)?
// No fold: emit the SPIR-V instruction matching the operand class.
match self.lookup_type(result_type) {
$(SpirvType::Integer(_, _) => {
self.emit()
.$inst_int(result_type, None, lhs.def(self), rhs.def(self))
})?
$(SpirvType::Integer(_, false) => {
self.emit()
.$inst_uint(result_type, None, lhs.def(self), rhs.def(self))
})?
$(SpirvType::Integer(_, true) => {
self.emit()
.$inst_sint(result_type, None, lhs.def(self), rhs.def(self))
})?
$(SpirvType::Float(_) => {
self.emit()
.$inst_float(result_type, None, lhs.def(self), rhs.def(self))
})?
$(SpirvType::Bool => {
self.emit()
.$inst_bool(result_type, None, lhs.def(self), rhs.def(self))
})?
o => self.fatal(format!(
concat!(stringify!($func_name), "() not implemented for type {}"),
o.debug(result_type, self)
)),
}
.unwrap()
.with_type(result_type)
}
};
}
/// Generates a shift-operator method for the `BuilderMethods` impl below.
///
/// Like `simple_op!`, but tailored for shifts: the operand types are NOT
/// required to match (the shift amount may have a different type), the
/// result type is the left-hand side's type, and the fold arms accept
/// mixed signedness for the shift amount.
macro_rules! simple_shift_op {
(
$func_name:ident
$(, int: $inst_int:ident)?
$(, uint: $inst_uint:ident)?
$(, sint: $inst_sint:ident)?
$(, fold_const {
$(int($int_lhs:ident, $int_rhs:ident) => $fold_int:expr;)?
$(uint($uint_lhs:ident, $uint_rhs:ident) => $fold_uint:expr;)?
$(sint($sint_lhs:ident, $sint_rhs:ident) => $fold_sint:expr;)?
})?
) => {
fn $func_name(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
let result_type = lhs.ty;
// Constant-folding fast path (mirrors `simple_op!`, but pairs each
// lhs signedness with either rhs signedness).
$(
#[allow(unreachable_patterns, clippy::collapsible_match)]
if let Some(const_lhs) = self.try_get_const_value(lhs)
&& let Some(const_rhs) = self.try_get_const_value(rhs)
{
let result = (|| Some(match (const_lhs, const_rhs) {
$(
(ConstValue::Unsigned($int_lhs), ConstValue::Unsigned($int_rhs)) => $fold_int,
(ConstValue::Unsigned($int_lhs), ConstValue::Signed($int_rhs)) => $fold_int,
(ConstValue::Signed($int_lhs), ConstValue::Unsigned($int_rhs)) => $fold_int as u128,
(ConstValue::Signed($int_lhs), ConstValue::Signed($int_rhs)) => $fold_int as u128,
)?
$(
(ConstValue::Unsigned($uint_lhs), ConstValue::Unsigned($uint_rhs)) => $fold_uint,
(ConstValue::Unsigned($uint_lhs), ConstValue::Signed($uint_rhs)) => $fold_uint,
)?
$(
(ConstValue::Signed($sint_lhs), ConstValue::Unsigned($sint_rhs)) => $fold_sint as u128,
(ConstValue::Signed($sint_lhs), ConstValue::Signed($sint_rhs)) => $fold_sint as u128,
)?
_ => return None,
}))();
if let Some(result) = result {
return self.const_uint_big(result_type, result);
}
}
)?
// No fold: dispatch on the (lhs) result type.
match self.lookup_type(result_type) {
$(SpirvType::Integer(_, _) => {
self.emit()
.$inst_int(result_type, None, lhs.def(self), rhs.def(self))
})?
$(SpirvType::Integer(_, false) => {
self.emit()
.$inst_uint(result_type, None, lhs.def(self), rhs.def(self))
})?
$(SpirvType::Integer(_, true) => {
self.emit()
.$inst_sint(result_type, None, lhs.def(self), rhs.def(self))
})?
o => self.fatal(format!(
concat!(stringify!($func_name), "() not implemented for type {}"),
o.debug(result_type, self)
)),
}
.unwrap()
.with_type(result_type)
}
};
}
/// Generates a unary-operator method for the `BuilderMethods` impl below.
///
/// Unary sibling of `simple_op!`: optionally const-folds a single scalar
/// constant operand, otherwise dispatches on the operand type to the given
/// SPIR-V instruction per numeric class.
macro_rules! simple_uni_op {
(
$func_name:ident
$(, int: $inst_int:ident)?
$(, uint: $inst_uint:ident)?
$(, sint: $inst_sint:ident)?
$(, float: $inst_float:ident)?
$(, bool: $inst_bool:ident)?
$(, fold_const {
$(int($int_val:ident) => $fold_int:expr;)?
$(uint($uint_val:ident) => $fold_uint:expr;)?
$(sint($sint_val:ident) => $fold_sint:expr;)?
$(bool($bool_val:ident) => $fold_bool:expr;)?
})?
) => {
fn $func_name(&mut self, val: Self::Value) -> Self::Value {
let result_type = val.ty;
// Constant-folding fast path: only present when `fold_const` was given.
$(
#[allow(unreachable_patterns, clippy::collapsible_match)]
if let Some(const_val) = self.try_get_const_value(val) {
let result = (|| Some(match (const_val) {
$(
ConstValue::Unsigned($int_val) => $fold_int,
ConstValue::Signed($int_val) => $fold_int as u128,
)?
$(ConstValue::Unsigned($uint_val) => $fold_uint, )?
$(ConstValue::Signed($sint_val) => $fold_sint as u128, )?
$(ConstValue::Bool($bool_val) => ($fold_bool).into(), )?
_ => return None,
}))();
if let Some(result) = result {
return self.const_uint_big(result_type, result);
}
}
)?
// No fold: emit the SPIR-V instruction matching the operand class.
match self.lookup_type(result_type) {
$(SpirvType::Integer(_, _) => {
self.emit()
.$inst_int(result_type, None, val.def(self))
})?
$(SpirvType::Integer(_, false) => {
self.emit()
.$inst_uint(result_type, None, val.def(self))
})?
$(SpirvType::Integer(_, true) => {
self.emit()
.$inst_sint(result_type, None, val.def(self))
})?
$(SpirvType::Float(_) => {
self.emit()
.$inst_float(result_type, None, val.def(self))
})?
$(SpirvType::Bool => {
self.emit()
.$inst_bool(result_type, None, val.def(self))
})?
o => self.fatal(format!(
concat!(stringify!($func_name), "() not implemented for type {}"),
o.debug(result_type, self)
)),
}
.unwrap()
.with_type(result_type)
}
};
}
/// Broadcasts the memset fill byte `b` into both bytes of a `u16`.
fn memset_fill_u16(b: u8) -> u16 {
    // Every byte is identical, so native endianness cannot matter.
    u16::from_ne_bytes([b; 2])
}
/// Broadcasts the memset fill byte `b` into all four bytes of a `u32`.
fn memset_fill_u32(b: u8) -> u32 {
    // Every byte is identical, so native endianness cannot matter.
    u32::from_ne_bytes([b; 4])
}
/// Broadcasts the memset fill byte `b` into all eight bytes of a `u64`.
fn memset_fill_u64(b: u8) -> u64 {
    // Every byte is identical, so native endianness cannot matter.
    u64::from_ne_bytes([b; 8])
}
/// Builds, at runtime, a `byte_width * 8`-bit scalar whose every byte is the
/// (dynamic) fill byte `fill_var`.
///
/// Constructs a `byte_width`-element vector of u8 copies of `fill_var`, then
/// bitcasts it to an integer (or float, if `is_float`) of the same total
/// width. Returns the result id of the bitcast.
fn memset_dynamic_scalar(
builder: &mut Builder<'_, '_>,
fill_var: Word,
byte_width: usize,
is_float: bool,
) -> Word {
// <byte_width x u8> vector type holding the repeated fill byte.
let composite_type = SpirvType::simd_vector(
builder,
builder.span(),
SpirvType::Integer(8, false),
byte_width as u32,
)
.def(builder.span(), builder);
let composite = builder
.emit()
.composite_construct(composite_type, None, iter::repeat_n(fill_var, byte_width))
.unwrap();
// Reinterpret the byte vector as one scalar of the requested class.
let result_type = if is_float {
SpirvType::Float(byte_width as u32 * 8)
} else {
SpirvType::Integer(byte_width as u32 * 8, false)
};
builder
.emit()
.bitcast(result_type.def(builder.span(), builder), None, composite)
.unwrap()
}
impl<'a, 'tcx> Builder<'a, 'tcx> {
#[instrument(level = "trace", skip(self))]
/// Lowers a Rust `AtomicOrdering` to a SPIR-V `MemorySemantics` constant.
///
/// Under the Vulkan memory model, `SeqCst` has no valid encoding, so the
/// resulting constant is marked as a zombie (deferred error) in that case.
fn ordering_to_semantics_def(&mut self, ordering: AtomicOrdering) -> SpirvValue {
let mut invalid_seq_cst = false;
let semantics = match ordering {
AtomicOrdering::Relaxed => MemorySemantics::NONE,
// Acquire/Release also need MAKE_VISIBLE/MAKE_AVAILABLE for the
// Vulkan memory model's availability/visibility operations.
AtomicOrdering::Acquire => MemorySemantics::MAKE_VISIBLE | MemorySemantics::ACQUIRE,
AtomicOrdering::Release => MemorySemantics::MAKE_AVAILABLE | MemorySemantics::RELEASE,
AtomicOrdering::AcqRel => {
MemorySemantics::MAKE_AVAILABLE
| MemorySemantics::MAKE_VISIBLE
| MemorySemantics::ACQUIRE_RELEASE
}
AtomicOrdering::SeqCst => {
let builder = self.emit();
// Operand 1 of OpMemoryModel is the memory model enum.
let memory_model = builder.module_ref().memory_model.as_ref().unwrap();
if memory_model.operands[1].unwrap_memory_model() == MemoryModel::Vulkan {
invalid_seq_cst = true;
}
MemorySemantics::MAKE_AVAILABLE
| MemorySemantics::MAKE_VISIBLE
| MemorySemantics::SEQUENTIALLY_CONSISTENT
}
};
let semantics = self.constant_u32(self.span(), semantics.bits());
if invalid_seq_cst {
self.zombie(
semantics.def(self),
"cannot use `AtomicOrdering::SeqCst` on Vulkan memory model \
(check if `AcqRel` fits your needs)",
);
}
semantics
}
#[instrument(level = "trace", skip(self))]
/// Builds a *constant* of type `ty` whose every byte equals `fill_byte`,
/// for lowering `memset` with a compile-time-known fill value.
///
/// Recurses into vectors/matrices/arrays; scalar leaves are built from the
/// `memset_fill_*` byte-broadcast helpers. Unsupported types are fatal
/// errors. Note this takes `&self`: it only creates constants, emitting no
/// function-body instructions.
fn memset_const_pattern(&self, ty: &SpirvType<'tcx>, fill_byte: u8) -> Word {
match *ty {
SpirvType::Void => self.fatal("memset invalid on void pattern"),
SpirvType::Bool => self.fatal("memset invalid on bool pattern"),
SpirvType::Integer(width, false) => match width {
8 => self.constant_u8(self.span(), fill_byte).def(self),
16 => self
.constant_u16(self.span(), memset_fill_u16(fill_byte))
.def(self),
32 => self
.constant_u32(self.span(), memset_fill_u32(fill_byte))
.def(self),
64 => self
.constant_u64(self.span(), memset_fill_u64(fill_byte))
.def(self),
_ => self.fatal(format!(
"memset on integer width {width} not implemented yet"
)),
},
SpirvType::Integer(width, true) => match width {
// Same bit patterns as the unsigned cases, reinterpreted as signed.
8 => self.constant_i8(self.span(), fill_byte as i8).def(self),
16 => self
.constant_i16(self.span(), memset_fill_u16(fill_byte) as i16)
.def(self),
32 => self
.constant_i32(self.span(), memset_fill_u32(fill_byte) as i32)
.def(self),
64 => self
.constant_i64(self.span(), memset_fill_u64(fill_byte) as i64)
.def(self),
_ => self.fatal(format!(
"memset on integer width {width} not implemented yet"
)),
},
SpirvType::Float(width) => match width {
// Bit-level broadcast reinterpreted as a float constant.
32 => self
.constant_f32(self.span(), f32::from_bits(memset_fill_u32(fill_byte)))
.def(self),
64 => self
.constant_f64(self.span(), f64::from_bits(memset_fill_u64(fill_byte)))
.def(self),
_ => self.fatal(format!("memset on float width {width} not implemented yet")),
},
SpirvType::Adt { .. } => self.fatal("memset on structs not implemented yet"),
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => {
// Fill each element recursively, then build a constant composite.
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
self.constant_composite(
ty.def(self.span(), self),
iter::repeat_n(elem_pat, count as usize),
)
.def(self)
}
SpirvType::Array { element, count } => {
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
// Array lengths are stored as constant ids; recover the scalar.
let count = self.builder.lookup_const_scalar(count).unwrap() as usize;
self.constant_composite(ty.def(self.span(), self), iter::repeat_n(elem_pat, count))
.def(self)
}
SpirvType::RuntimeArray { .. } => {
self.fatal("memset on runtime arrays not implemented yet")
}
SpirvType::Pointer { .. } => self.fatal("memset on pointers not implemented yet"),
SpirvType::Function { .. } => self.fatal("memset on functions not implemented yet"),
SpirvType::Image { .. } => self.fatal("cannot memset image"),
SpirvType::Sampler => self.fatal("cannot memset sampler"),
SpirvType::SampledImage { .. } => self.fatal("cannot memset sampled image"),
SpirvType::InterfaceBlock { .. } => self.fatal("cannot memset interface block"),
SpirvType::AccelerationStructureKhr => {
self.fatal("cannot memset acceleration structure")
}
SpirvType::RayQueryKhr => self.fatal("cannot memset ray query"),
}
}
#[instrument(level = "trace", skip(self))]
/// Builds, at runtime, a value of type `ty` whose every byte is the
/// (dynamic) fill byte `fill_var` — the runtime counterpart of
/// `memset_const_pattern`.
///
/// Scalar leaves go through `memset_dynamic_scalar` (byte-vector +
/// bitcast); aggregates are assembled with `OpCompositeConstruct`.
fn memset_dynamic_pattern(&mut self, ty: &SpirvType<'tcx>, fill_var: Word) -> Word {
match *ty {
SpirvType::Void => self.fatal("memset invalid on void pattern"),
SpirvType::Bool => self.fatal("memset invalid on bool pattern"),
SpirvType::Integer(width, _signedness) => match width {
// An 8-bit integer *is* the fill byte already.
8 => fill_var,
16 => memset_dynamic_scalar(self, fill_var, 2, false),
32 => memset_dynamic_scalar(self, fill_var, 4, false),
64 => memset_dynamic_scalar(self, fill_var, 8, false),
_ => self.fatal(format!(
"memset on integer width {width} not implemented yet"
)),
},
SpirvType::Float(width) => match width {
32 => memset_dynamic_scalar(self, fill_var, 4, true),
64 => memset_dynamic_scalar(self, fill_var, 8, true),
_ => self.fatal(format!("memset on float width {width} not implemented yet")),
},
SpirvType::Adt { .. } => self.fatal("memset on structs not implemented yet"),
SpirvType::Array { element, count } => {
let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var);
// Array lengths are stored as constant ids; recover the scalar.
let count = self.builder.lookup_const_scalar(count).unwrap() as usize;
self.emit()
.composite_construct(
ty.def(self.span(), self),
None,
iter::repeat_n(elem_pat, count),
)
.unwrap()
}
SpirvType::Vector { element, count, .. } | SpirvType::Matrix { element, count } => {
let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var);
self.emit()
.composite_construct(
ty.def(self.span(), self),
None,
iter::repeat_n(elem_pat, count as usize),
)
.unwrap()
}
SpirvType::RuntimeArray { .. } => {
self.fatal("memset on runtime arrays not implemented yet")
}
SpirvType::Pointer { .. } => self.fatal("memset on pointers not implemented yet"),
SpirvType::Function { .. } => self.fatal("memset on functions not implemented yet"),
SpirvType::Image { .. } => self.fatal("cannot memset image"),
SpirvType::Sampler => self.fatal("cannot memset sampler"),
SpirvType::SampledImage { .. } => self.fatal("cannot memset sampled image"),
SpirvType::InterfaceBlock { .. } => self.fatal("cannot memset interface block"),
SpirvType::AccelerationStructureKhr => {
self.fatal("cannot memset acceleration structure")
}
SpirvType::RayQueryKhr => self.fatal("cannot memset ray query"),
}
}
#[instrument(level = "trace", skip(self))]
/// Lowers a memset with a compile-time-known byte count by storing the
/// pattern value `pat` element-by-element through `ptr`.
///
/// NOTE(review): `size_bytes` is divided by the element size and any
/// remainder bytes are not written — presumably callers guarantee
/// `size_bytes` is a multiple of the pattern size; confirm at call sites.
fn memset_constant_size(&mut self, ptr: SpirvValue, pat: SpirvValue, size_bytes: u64) {
let size_elem = self
.lookup_type(pat.ty)
.sizeof(self)
.expect("Memset on unsized values not supported");
let count = size_bytes / size_elem.bytes();
if count == 1 {
// Single element: plain store, no GEP needed.
self.store(pat, ptr, Align::from_bytes(0).unwrap());
} else {
// Unrolled loop of stores, one per element.
for index in 0..count {
let const_index = self.constant_u32(self.span(), index as u32);
let gep_ptr = self.inbounds_gep(pat.ty, ptr, &[const_index]);
self.store(pat, gep_ptr, Align::from_bytes(0).unwrap());
}
}
}
#[instrument(level = "trace", skip(self))]
/// Lowers a memset whose byte count is only known at runtime, by emitting
/// an explicit counted store loop:
///
/// ```text
/// index = 0;
/// header: if index < size_bytes / sizeof(pat) goto body else exit
/// body:   *(ptr + index) = pat; index += 1; goto header
/// exit:
/// ```
fn memset_dynamic_size(&mut self, ptr: SpirvValue, pat: SpirvValue, size_bytes: SpirvValue) {
let size_elem = self
.lookup_type(pat.ty)
.sizeof(self)
.expect("Unable to memset a dynamic sized object");
// Loop constants, all in the same integer type as `size_bytes`.
let size_elem_const = self.constant_int(size_bytes.ty, size_elem.bytes().into());
let zero = self.constant_int(size_bytes.ty, 0);
let one = self.constant_int(size_bytes.ty, 1);
let zero_align = Align::from_bytes(0).unwrap();
let header_bb = self.append_sibling_block("memset_header");
let body_bb = self.append_sibling_block("memset_body");
let exit_bb = self.append_sibling_block("memset_exit");
// Element count, and a stack slot holding the loop induction variable.
let count = self.udiv(size_bytes, size_elem_const);
let index = self.alloca(self.lookup_type(count.ty).sizeof(self).unwrap(), zero_align);
self.store(zero, index, zero_align);
self.br(header_bb);
// Loop header: re-load the index and test the loop condition.
self.switch_to_block(header_bb);
let current_index = self.load(count.ty, index, zero_align);
let cond = self.icmp(IntPredicate::IntULT, current_index, count);
self.cond_br(cond, body_bb, exit_bb);
// Loop body: store the pattern at the current element, bump the index.
self.switch_to_block(body_bb);
let gep_ptr = self.gep(pat.ty, ptr, &[current_index]);
self.store(pat, gep_ptr, zero_align);
let current_index_plus_1 = self.add(current_index, one);
self.store(current_index_plus_1, index, zero_align);
self.br(header_bb);
// Continue emitting after the loop.
self.switch_to_block(exit_bb);
}
#[instrument(level = "trace", skip(self))]
/// Marks `def` as a zombie (deferred error): pointer-to-integer casts are
/// not representable in SPIR-V for the targets this backend supports.
fn zombie_convert_ptr_to_u(&self, def: Word) {
self.zombie(def, "cannot convert pointers to integers");
}
#[instrument(level = "trace", skip(self))]
/// Marks `def` as a zombie (deferred error): integer-to-pointer casts are
/// not representable in SPIR-V for the targets this backend supports.
fn zombie_convert_u_to_ptr(&self, def: Word) {
self.zombie(def, "cannot convert integers to pointers");
}
#[instrument(level = "trace", skip(self))]
/// Marks a pointer-comparison result `def` as a zombie (deferred error)
/// unless the module has the `VariablePointers` capability, which is what
/// makes pointer equality (`inst`) legal in SPIR-V.
fn zombie_ptr_equal(&self, def: Word, inst: &str) {
    // With the capability present, the comparison is legal — nothing to do.
    if self.builder.has_capability(Capability::VariablePointers) {
        return;
    }
    self.zombie(
        def,
        &format!("{inst} without OpCapability VariablePointers"),
    );
}
#[instrument(level = "trace", skip(self), fields(ptr, ty = ?self.debug_type(ty)))]
fn adjust_pointer_for_typed_access(
&mut self,
ptr: SpirvValue,
ty: <Self as BackendTypes>::Type,
) -> (SpirvValue, <Self as BackendTypes>::Type) {
self.lookup_type(ty)
.sizeof(self)
.and_then(|size| self.adjust_pointer_for_sized_access(ptr, size))
.unwrap_or_else(|| (self.pointercast(ptr, self.type_ptr_to(ty)), ty))
}
#[instrument(level = "trace", skip(self))]
/// Tries to turn `ptr` into a pointer to a leaf of exactly `size` bytes,
/// by digging through the pointee's aggregate structure at offset 0.
///
/// Returns the (possibly access-chained) pointer and the leaf type, or
/// `None` if no leaf of exactly `size` bytes is reachable.
fn adjust_pointer_for_sized_access(
&mut self,
ptr: SpirvValue,
size: Size,
) -> Option<(SpirvValue, <Self as BackendTypes>::Type)> {
// Work from the original (un-cast) pointer so the pointee type is real.
let ptr = ptr.strip_ptrcasts();
let mut leaf_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
other => self.fatal(format!("`ptr` is non-pointer type: {other:?}")),
};
trace!(
"before nested adjust_pointer_for_sized_access. `leaf_ty`: {}",
self.debug_type(leaf_ty)
);
// Repeatedly descend (at offset 0) into nested aggregates whose leaf
// is exactly `size` bytes, accumulating the access-chain indices.
let mut indices = SmallVec::<[_; 8]>::new();
while let Some((inner_indices, inner_ty)) = self.recover_access_chain_from_offset(
leaf_ty,
Size::ZERO,
Some(size)..=Some(size),
None,
) {
indices.extend(inner_indices);
leaf_ty = inner_ty;
}
trace!(
"after nested adjust_pointer_for_sized_access. `leaf_ty`: {}",
self.debug_type(leaf_ty)
);
// Bail out unless the final leaf is exactly the requested size.
let leaf_ptr_ty = (self.lookup_type(leaf_ty).sizeof(self) == Some(size))
.then(|| self.type_ptr_to(leaf_ty))?;
let leaf_ptr = if indices.is_empty() {
// No descent needed: `ptr` already points at the right type.
assert_ty_eq!(self, ptr.ty, leaf_ptr_ty);
ptr
} else {
let indices = indices
.into_iter()
.map(|idx| self.constant_u32(self.span(), idx).def(self))
.collect::<Vec<_>>();
self.emit()
.in_bounds_access_chain(leaf_ptr_ty, None, ptr.def(self), indices)
.unwrap()
.with_type(leaf_ptr_ty)
};
trace!(
"adjust_pointer_for_sized_access returning {} {}",
self.debug_type(leaf_ptr.ty),
self.debug_type(leaf_ty)
);
Some((leaf_ptr, leaf_ty))
}
#[instrument(level = "trace", skip(self), fields(ty = ?self.debug_type(ty), leaf_size_or_unsized_range, leaf_ty = ?leaf_ty))]
/// Walks `ty` downward to find a leaf located at byte `offset`, returning
/// the access-chain indices that reach it plus the leaf's type.
///
/// The leaf must have a size within `leaf_size_or_unsized_range` (where
/// `None` means "unsized", and the range end being `None` means "no upper
/// bound"), and — if `leaf_ty` is given — must be exactly that type.
/// Returns `None` when no such leaf exists.
fn recover_access_chain_from_offset(
&self,
mut ty: <Self as BackendTypes>::Type,
mut offset: Size,
leaf_size_or_unsized_range: RangeInclusive<Option<Size>>,
leaf_ty: Option<<Self as BackendTypes>::Type>,
) -> Option<(SmallVec<[u32; 8]>, <Self as BackendTypes>::Type)> {
// A zero-length chain (leaf == root) is never returned.
assert_ne!(Some(ty), leaf_ty);
if let Some(leaf_ty) = leaf_ty {
trace!(
"recovering access chain: leaf_ty: {:?}",
self.debug_type(leaf_ty)
);
} else {
trace!("recovering access chain: leaf_ty: None");
}
// Ordered size domain where `Unsized` compares greater than any size,
// so range checks treat unsized leaves as "largest".
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum MaybeSized {
Sized(Size),
Unsized,
}
let leaf_size_range = {
let r = leaf_size_or_unsized_range;
let [start, end] =
[r.start(), r.end()].map(|x| x.map_or(MaybeSized::Unsized, MaybeSized::Sized));
start..=end
};
trace!("leaf_size_range: {:?}", leaf_size_range);
let mut ty_kind = self.lookup_type(ty);
let mut indices = SmallVec::new();
loop {
let ty_size;
match ty_kind {
SpirvType::Adt {
field_types,
field_offsets,
..
} => {
trace!("recovering access chain from ADT");
// Pick the field that contains `offset` (or a zero-sized
// field exactly at `offset` when that is what's wanted).
let (i, field_ty, field_ty_kind, field_ty_size, offset_in_field) = field_offsets
.iter()
.enumerate()
.find_map(|(i, &field_offset)| {
if field_offset > offset {
return None;
}
let field_ty = field_types[i];
let field_ty_kind = self.lookup_type(field_ty);
let field_ty_size = field_ty_kind
.sizeof(self).map_or(MaybeSized::Unsized, MaybeSized::Sized);
let offset_in_field = offset - field_offset;
if MaybeSized::Sized(offset_in_field) < field_ty_size
|| offset_in_field == Size::ZERO
&& leaf_size_range.contains(&MaybeSized::Sized(Size::ZERO)) && leaf_ty == Some(field_ty)
{
Some((i, field_ty, field_ty_kind, field_ty_size, offset_in_field))
} else {
None
}
})?;
// Descend into the chosen field.
ty = field_ty;
trace!("setting ty = field_ty: {:?}", self.debug_type(field_ty));
ty_kind = field_ty_kind;
trace!("setting ty_kind = field_ty_kind: {:?}", field_ty_kind);
ty_size = field_ty_size;
trace!("setting ty_size = field_ty_size: {:?}", field_ty_size);
indices.push(i as u32);
offset = offset_in_field;
trace!("setting offset = offset_in_field: {:?}", offset_in_field);
}
SpirvType::Vector { element, .. }
| SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element }
| SpirvType::Matrix { element, .. } => {
trace!("recovering access chain from Vector, Array, RuntimeArray, or Matrix");
// Homogeneous element types: index = offset / stride.
ty = element;
trace!("setting ty = element: {:?}", self.debug_type(element));
ty_kind = self.lookup_type(ty);
trace!("looked up ty kind: {:?}", ty_kind);
let stride = ty_kind.sizeof(self)?;
ty_size = MaybeSized::Sized(stride);
if stride == Size::ZERO {
// Zero-sized elements: only offset 0 is addressable.
if offset != Size::ZERO {
trace!("zero-sized element with non-zero offset: {:?}", offset);
return None;
}
indices.push(0);
offset = Size::ZERO;
} else {
indices.push((offset.bytes() / stride.bytes()).try_into().ok()?);
offset = Size::from_bytes(offset.bytes() % stride.bytes());
}
}
_ => {
// Scalars etc.: can't descend further.
trace!("recovering access chain from SOMETHING ELSE, RETURNING NONE");
return None;
}
}
if ty_size < *leaf_size_range.start() {
trace!("avoiding digging beyond the point the leaf could actually fit");
return None;
}
// Success: landed exactly on an acceptable leaf.
if offset == Size::ZERO
&& leaf_size_range.contains(&ty_size)
&& leaf_ty.is_none_or(|leaf_ty| leaf_ty == ty)
{
trace!("successful recovery leaf type: {:?}", self.debug_type(ty));
trace!("successful recovery indices: {:?}", indices);
return Some((indices, ty));
}
}
}
#[instrument(level = "trace", skip(self), fields(ty = ?self.debug_type(ty), ptr, combined_indices = ?combined_indices.iter().map(|x| (self.debug_type(x.ty), x.kind)).collect::<Vec<_>>(), is_inbounds))]
/// Common implementation of `gep`/`inbounds_gep`.
///
/// Only single-index GEPs (the pointer-base index) are supported; any
/// trailing struct/array indices are a backend bug and abort compilation.
/// The actual offsetting is delegated to `ptr_offset_strided`.
fn maybe_inbounds_gep(
&mut self,
ty: Word,
ptr: SpirvValue,
combined_indices: &[SpirvValue],
is_inbounds: bool,
) -> SpirvValue {
// First index scales by `sizeof(ty)`; the rest would index into `ty`.
let (&ptr_base_index, indices) = combined_indices.split_first().unwrap();
if !indices.is_empty() {
self.fatal(format!(
"[RUST-GPU BUG] `inbounds_gep` or `gep` called \
with {} combined indices (expected only 1)",
combined_indices.len(),
));
}
self.ptr_offset_strided(ptr, ty, ptr_base_index, is_inbounds)
}
#[instrument(level = "trace", skip(self), fields(ptr, stride_elem_ty = ?self.debug_type(stride_elem_ty), index, is_inbounds))]
/// Offsets `ptr` by `index * sizeof(stride_elem_ty)`, trying a sequence of
/// strategies to stay within legal (logical-addressing) SPIR-V:
///
/// 1. constant offset 0 — no-op, return `ptr` unchanged;
/// 2. constant offset — recover a typed access chain into the original
///    pointee at that offset;
/// 3. `ptr` already is an access chain — fold the index into its last
///    component;
/// 4. fallback — emit `OpPtrAccessChain`, which is illegal here, and mark
///    the result as a zombie (deferred error).
fn ptr_offset_strided(
&mut self,
ptr: SpirvValue,
stride_elem_ty: Word,
index: SpirvValue,
is_inbounds: bool,
) -> SpirvValue {
// Byte offset if both the index and the stride are statically known.
let const_offset = self.builder.lookup_const_scalar(index).and_then(|idx| {
let idx_u64 = u64::try_from(idx).ok()?;
let stride = self.lookup_type(stride_elem_ty).sizeof(self)?;
Some(idx_u64 * stride)
});
if const_offset == Some(Size::ZERO) {
trace!("ptr_offset_strided: strategy 1 picked: offset 0 => noop");
return ptr;
}
if let Some(const_offset) = const_offset {
// Strategy 2: dig into the original (un-cast) pointee type.
let original_ptr = ptr.strip_ptrcasts();
let original_pointee_ty = match self.lookup_type(original_ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
other => self.fatal(format!("pointer arithmetic on non-pointer type {other:?}")),
};
if let Some((const_indices, leaf_pointee_ty)) = self.recover_access_chain_from_offset(
original_pointee_ty,
const_offset,
// Any leaf size is acceptable here.
Some(Size::ZERO)..=None,
None,
) {
trace!(
"ptr_offset_strided: strategy 2 picked: offset {const_offset:?} \
=> access chain w/ {const_indices:?}"
);
let leaf_ptr_ty = self.type_ptr_to(leaf_pointee_ty);
let original_ptr_id = original_ptr.def(self);
let const_indices_ids = const_indices
.into_iter()
.map(|idx| self.constant_u32(self.span(), idx).def(self))
.collect();
return self.emit_access_chain(
leaf_ptr_ty,
original_ptr_id,
None,
const_indices_ids,
is_inbounds,
);
}
}
// From here on we need `ptr` typed as `*stride_elem_ty`.
let ptr = self.pointercast(ptr, self.type_ptr_to(stride_elem_ty));
let ptr_id = ptr.def(self);
// Strategy 3: find the instruction that defined `ptr_id` in the current
// function; if it was itself an access chain, we can merge into it.
let maybe_original_access_chain = {
let builder = self.emit();
let module = builder.module_ref();
let current_func_blocks = builder
.selected_function()
.and_then(|func_idx| Some(&module.functions.get(func_idx)?.blocks[..]))
.unwrap_or_default();
current_func_blocks
.iter()
.flat_map(|b| &b.instructions)
.rfind(|inst| inst.result_id == Some(ptr_id))
.filter(|inst| {
matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain)
})
.map(|inst| {
let base_ptr = inst.operands[0].unwrap_id_ref();
let indices = inst.operands[1..]
.iter()
.map(|op| op.unwrap_id_ref())
.collect::<Vec<_>>();
(base_ptr, indices)
})
};
if let Some((original_ptr, original_indices)) = maybe_original_access_chain {
trace!("ptr_offset_strided: strategy 3 picked: merging access chains");
// Add `index` onto the chain's last component.
let mut merged_indices = original_indices;
let last_index_id = merged_indices.last_mut().unwrap();
*last_index_id = self.add(last_index_id.with_type(index.ty), index).def(self);
return self.emit_access_chain(ptr.ty, original_ptr, None, merged_indices, is_inbounds);
}
trace!("ptr_offset_strided: falling back to (illegal) `OpPtrAccessChain`");
let result_ptr = if is_inbounds {
self.emit()
.in_bounds_ptr_access_chain(ptr.ty, None, ptr_id, index.def(self), vec![])
} else {
self.emit()
.ptr_access_chain(ptr.ty, None, ptr_id, index.def(self), vec![])
}
.unwrap();
// Deferred error: only reported if this value survives to the output.
self.zombie(
result_ptr,
"cannot offset a pointer to an arbitrary element",
);
result_ptr.with_type(ptr.ty)
}
#[instrument(
level = "trace",
skip(self),
fields(
result_type = ?self.debug_type(result_type),
pointer,
ptr_base_index,
indices,
is_inbounds
)
)]
/// Emits an access-chain instruction for `pointer` with `indices`.
///
/// If `ptr_base_index` is present and not a constant zero, an
/// `OpPtrAccessChain` variant is emitted instead and the result is marked
/// as a zombie (pointer-base indexing is illegal under logical
/// addressing). Otherwise a plain `Op(InBounds)AccessChain` is used.
fn emit_access_chain(
&mut self,
result_type: <Self as BackendTypes>::Type,
pointer: Word,
ptr_base_index: Option<SpirvValue>,
indices: Vec<Word>,
is_inbounds: bool,
) -> SpirvValue {
let mut builder = self.emit();
// A constant-zero base index degenerates to a plain access chain.
let non_zero_ptr_base_index =
ptr_base_index.filter(|&idx| self.builder.lookup_const_scalar(idx) != Some(0));
if let Some(ptr_base_index) = non_zero_ptr_base_index {
let result = if is_inbounds {
builder.in_bounds_ptr_access_chain(
result_type,
None,
pointer,
ptr_base_index.def(self),
indices,
)
} else {
builder.ptr_access_chain(
result_type,
None,
pointer,
ptr_base_index.def(self),
indices,
)
}
.unwrap();
// Deferred error: only reported if this value survives to output.
self.zombie(result, "cannot offset a pointer to an arbitrary element");
result
} else {
if is_inbounds {
builder.in_bounds_access_chain(result_type, None, pointer, indices)
} else {
builder.access_chain(result_type, None, pointer, indices)
}
.unwrap()
}
.with_type(result_type)
}
#[instrument(level = "trace", skip(self))]
/// Saturating float-to-integer conversion (scalar or vector).
///
/// Clamps values below the integer minimum to `int_min`, above the maximum
/// to `int_max`, and — for signed targets — maps NaN to zero, wrapping the
/// raw `fptosi`/`fptoui` result in `select`s. Mirrors the generic
/// `fptoint_sat` lowering used by `rustc_codegen_ssa` backends.
fn fptoint_sat(
&mut self,
signed: bool,
val: SpirvValue,
dest_ty: <Self as BackendTypes>::Type,
) -> SpirvValue {
let src_ty = self.cx.val_ty(val);
// For vectors, clamp bounds are computed on the element types.
let (float_ty, int_ty) = if self.cx.type_kind(src_ty) == TypeKind::Vector {
assert_eq!(
self.cx.vector_length(src_ty),
self.cx.vector_length(dest_ty)
);
(self.cx.element_type(src_ty), self.cx.element_type(dest_ty))
} else {
(src_ty, dest_ty)
};
let int_width = self.cx().int_width(int_ty);
let float_width = self.cx().float_width(float_ty);
// Largest representable value of the target integer type (as bits).
let int_max = |signed: bool, int_width: u64| -> u128 {
let shift_amount = 128 - int_width;
if signed {
i128::MAX as u128 >> shift_amount
} else {
u128::MAX >> shift_amount
}
};
// Smallest representable value of the target integer type.
let int_min = |signed: bool, int_width: u64| -> i128 {
if signed {
i128::MIN >> (128 - int_width)
} else {
0
}
};
// Round the integer bounds into f32 space (toward zero, so the bounds
// stay within range); returns the raw IEEE bit patterns.
let compute_clamp_bounds_single = |signed: bool, int_width: u64| -> (u128, u128) {
let rounded_min =
ieee::Single::from_i128_r(int_min(signed, int_width), Round::TowardZero);
assert_eq!(rounded_min.status, Status::OK);
let rounded_max =
ieee::Single::from_u128_r(int_max(signed, int_width), Round::TowardZero);
assert!(rounded_max.value.is_finite());
(rounded_min.value.to_bits(), rounded_max.value.to_bits())
};
// Same as above, in f64 space.
let compute_clamp_bounds_double = |signed: bool, int_width: u64| -> (u128, u128) {
let rounded_min =
ieee::Double::from_i128_r(int_min(signed, int_width), Round::TowardZero);
assert_eq!(rounded_min.status, Status::OK);
let rounded_max =
ieee::Double::from_u128_r(int_max(signed, int_width), Round::TowardZero);
assert!(rounded_max.value.is_finite());
(rounded_min.value.to_bits(), rounded_max.value.to_bits())
};
// Materialize an IEEE bit pattern as a float constant via bitcast.
let float_bits_to_llval = |bx: &mut Self, bits| {
let bits_llval = match float_width {
32 => bx.cx().const_u32(bits as u32),
64 => bx.cx().const_u64(bits as u64),
n => bug!("unsupported float width {}", n),
};
bx.bitcast(bits_llval, float_ty)
};
let (f_min, f_max) = match float_width {
32 => compute_clamp_bounds_single(signed, int_width),
64 => compute_clamp_bounds_double(signed, int_width),
n => bug!("unsupported float width {}", n),
};
let f_min = float_bits_to_llval(self, f_min);
let f_max = float_bits_to_llval(self, f_max);
let int_max = self.cx().const_uint_big(int_ty, int_max(signed, int_width));
let int_min = self
.cx()
.const_uint_big(int_ty, int_min(signed, int_width) as u128);
let zero = self.cx().const_uint(int_ty, 0);
// Broadcast scalar constants across lanes for vector conversions.
let maybe_splat = |bx: &mut Self, val| {
if bx.cx().type_kind(dest_ty) == TypeKind::Vector {
bx.vector_splat(bx.vector_length(dest_ty), val)
} else {
val
}
};
let f_min = maybe_splat(self, f_min);
let f_max = maybe_splat(self, f_max);
let int_max = maybe_splat(self, int_max);
let int_min = maybe_splat(self, int_min);
let zero = maybe_splat(self, zero);
let fptosui_result = if signed {
self.fptosi(val, dest_ty)
} else {
self.fptoui(val, dest_ty)
};
// Unordered-less-than: true for NaN too, so NaN initially maps to
// int_min here (fixed up below for the signed case).
let less_or_nan = self.fcmp(RealPredicate::RealULT, val, f_min);
let greater = self.fcmp(RealPredicate::RealOGT, val, f_max);
let s0 = self.select(less_or_nan, int_min, fptosui_result);
let s1 = self.select(greater, int_max, s0);
if signed {
// NaN != NaN, so an ordered-equal self-compare isolates NaN lanes
// and maps them to zero.
let cmp = self.fcmp(RealPredicate::RealOEQ, val, val);
self.select(cmp, s1, zero)
} else {
s1
}
}
#[instrument(level = "trace", skip(self), fields(ty = ?self.debug_type(ty)))]
/// Declares a function-local variable of type `ty` (an `OpVariable` with
/// `Function` storage class) and returns a pointer to it.
///
/// SPIR-V requires all `OpVariable`s to appear at the start of a
/// function's entry block, so the instruction is inserted there — before
/// the first non-variable instruction — rather than at the cursor.
/// `_align` is ignored; presumably SPIR-V variable alignment is handled
/// elsewhere (or not needed) — confirm if alignment bugs appear.
fn declare_func_local_var(
&mut self,
ty: <Self as BackendTypes>::Type,
_align: Align,
) -> SpirvValue {
let ptr_ty = self.type_ptr_to(ty);
let mut builder = self.emit();
// Block 0 is the function's entry block.
builder.select_block(Some(0)).unwrap();
let index = {
let block = &builder.module_ref().functions[builder.selected_function().unwrap()]
.blocks[builder.selected_block().unwrap()];
// Insert just before the first non-OpVariable instruction, keeping
// all variables contiguous at the top of the entry block.
block
.instructions
.iter()
.enumerate()
.find_map(|(index, inst)| {
if inst.class.opcode != Op::Variable {
Some(InsertPoint::FromBegin(index))
} else {
None
}
})
.unwrap_or(InsertPoint::End)
};
let result_id = builder.id();
let inst = Instruction::new(
Op::Variable,
Some(ptr_ty),
Some(result_id),
vec![Operand::StorageClass(StorageClass::Function)],
);
builder.insert_into_block(index, inst).unwrap();
result_id.with_type(ptr_ty)
}
}
impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
type CodegenCx = CodegenCx<'tcx>;
#[instrument(level = "trace", skip(cx))]
/// Creates a `Builder` positioned at block `llbb`, with no span set yet.
fn build(cx: &'a Self::CodegenCx, llbb: Self::BasicBlock) -> Self {
Self {
cx,
current_block: llbb,
current_span: Default::default(),
}
}
/// Returns the codegen context this builder was created from.
fn cx(&self) -> &Self::CodegenCx {
self.cx
}
/// Not used by the paths `rustc_codegen_ssa` exercises for this backend.
fn llbb(&self) -> Self::BasicBlock {
unreachable!("dead code within `rustc_codegen_ssa`")
}
/// Records `span` as the current source location and emits matching
/// debug source-location information into the current block.
fn set_span(&mut self, span: Span) {
    // Attribute macro-expanded code to its expansion cause (call site).
    let span = span.ctxt().outer_expn().expansion_cause().unwrap_or(span);
    let old_span = self.current_span.replace(span);
    // NOTE(review): the early-out on an unchanged span is deliberately
    // compiled out (`if false`) — confirm the reason before re-enabling.
    if false {
        if old_span == Some(span) {
            return;
        }
    }
    let use_custom_insts = true;
    if use_custom_insts {
        let void_ty = SpirvType::Void.def(rustc_span::DUMMY_SP, self);
        if span.is_dummy() {
            // No usable location: emit an explicit "clear" marker.
            self.custom_inst(void_ty, CustomInst::ClearDebugSrcLoc);
        } else {
            // Encode file plus the full line/column range in a custom
            // instruction (richer than plain `OpLine`).
            let (file, line_col_range) = self.builder.file_line_col_range_for_debuginfo(span);
            let ((line_start, col_start), (line_end, col_end)) =
                (line_col_range.start, line_col_range.end);
            self.custom_inst(
                void_ty,
                CustomInst::SetDebugSrcLoc {
                    file: Operand::IdRef(file.file_name_op_string_id),
                    line_start: Operand::IdRef(self.const_u32(line_start).def(self)),
                    line_end: Operand::IdRef(self.const_u32(line_end).def(self)),
                    col_start: Operand::IdRef(self.const_u32(col_start).def(self)),
                    col_end: Operand::IdRef(self.const_u32(col_end).def(self)),
                },
            );
        }
        // If the two newest instructions in the block look like the same
        // kind of debug-location instruction (same opcode and same first
        // two operands), drop the older one so only the latest survives.
        let mut builder = self.emit();
        if let (Some(func_idx), Some(block_idx)) =
            (builder.selected_function(), builder.selected_block())
        {
            let block = &mut builder.module_mut().functions[func_idx].blocks[block_idx];
            match &block.instructions[..] {
                [.., a, b]
                    if a.class.opcode == b.class.opcode
                        && a.operands[..2] == b.operands[..2] =>
                {
                    block.instructions.remove(block.instructions.len() - 2);
                }
                _ => {}
            }
        }
    } else {
        // Legacy path: standard `OpLine`/`OpNoLine` debug info.
        if span.is_dummy() {
            self.emit().no_line();
        } else {
            let (file, line_col_range) = self.builder.file_line_col_range_for_debuginfo(span);
            let (line, col) = line_col_range.start;
            self.emit().line(file.file_name_op_string_id, line, col);
        }
    }
}
/// Creates a new block in `llfn` and returns a cursor to it (SPIR-V has no
/// block names, so `_name` is ignored).
fn append_block(
    cx: &'a Self::CodegenCx,
    llfn: Self::Function,
    _name: &str,
) -> Self::BasicBlock {
    let mut builder = cx.builder.builder_for_fn(llfn);
    // `begin_block(None)` allocates a fresh id for the new block.
    let id = builder.begin_block(None).unwrap();
    let index_in_builder = builder.selected_block().unwrap();
    SpirvBlockCursor {
        parent_fn: llfn,
        id,
        index_in_builder,
    }
}
fn append_sibling_block(&mut self, name: &str) -> Self::BasicBlock {
    // A sibling block belongs to the current block's parent function.
    Self::append_block(self.cx, self.current_block.parent_fn, name)
}
fn switch_to_block(&mut self, llbb: Self::BasicBlock) {
    // Builders are cheap cursors; switching blocks just rebuilds one.
    *self = Self::build(self.cx, llbb);
}
fn ret_void(&mut self) {
    self.emit().ret().unwrap();
}
fn ret(&mut self, value: Self::Value) {
    // Look up the function's declared SPIR-V return type and coerce the
    // value to it, since it may differ from `value.ty`.
    let func_ret_ty = {
        let builder = self.emit();
        let func = &builder.module_ref().functions[builder.selected_function().unwrap()];
        func.def.as_ref().unwrap().result_type.unwrap()
    };
    let value = self.bitcast(value, func_ret_ty);
    self.emit().ret_value(value.def(self)).unwrap();
}
fn br(&mut self, dest: Self::BasicBlock) {
    self.emit().branch(dest.id).unwrap();
}
fn cond_br(
    &mut self,
    cond: Self::Value,
    then_llbb: Self::BasicBlock,
    else_llbb: Self::BasicBlock,
) {
    let cond = cond.def(self);
    // Fold branches on a known boolean constant (scalar 0 or 1) into an
    // unconditional branch to the appropriate target.
    let known_cond = match self.builder.lookup_const_by_id(cond) {
        Some(SpirvConst::Scalar(v @ (0 | 1))) => Some(v == 1),
        _ => None,
    };
    match known_cond {
        Some(true) => self.br(then_llbb),
        Some(false) => self.br(else_llbb),
        None => {
            self.emit()
                .branch_conditional(cond, then_llbb.id, else_llbb.id, empty())
                .unwrap();
        }
    }
}
/// Lowers a multi-way branch to `OpSwitch`.
fn switch(
    &mut self,
    v: Self::Value,
    else_llbb: Self::BasicBlock,
    cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock)>,
) {
    // One case-literal constructor per selector width: the literal operand
    // must match the selector type, and signed selectors need their case
    // value sign-extended into the literal.
    fn construct_8(self_: &Builder<'_, '_>, signed: bool, v: u128) -> Operand {
        if v > u8::MAX as u128 {
            self_.fatal(format!(
                "Switches to values above u8::MAX not supported: {v:?}"
            ))
        } else if signed {
            // Sign-extend the 8-bit case value into the 32-bit literal.
            Operand::LiteralBit32(v as u8 as i8 as i32 as u32)
        } else {
            Operand::LiteralBit32(v as u8 as u32)
        }
    }
    fn construct_16(self_: &Builder<'_, '_>, signed: bool, v: u128) -> Operand {
        if v > u16::MAX as u128 {
            self_.fatal(format!(
                "Switches to values above u16::MAX not supported: {v:?}"
            ))
        } else if signed {
            // Sign-extend the 16-bit case value into the 32-bit literal.
            Operand::LiteralBit32(v as u16 as i16 as i32 as u32)
        } else {
            Operand::LiteralBit32(v as u16 as u32)
        }
    }
    fn construct_32(self_: &Builder<'_, '_>, _signed: bool, v: u128) -> Operand {
        if v > u32::MAX as u128 {
            self_.fatal(format!(
                "Switches to values above u32::MAX not supported: {v:?}"
            ))
        } else {
            Operand::LiteralBit32(v as u32)
        }
    }
    fn construct_64(self_: &Builder<'_, '_>, _signed: bool, v: u128) -> Operand {
        if v > u64::MAX as u128 {
            self_.fatal(format!(
                "Switches to values above u64::MAX not supported: {v:?}"
            ))
        } else {
            Operand::LiteralBit64(v as u64)
        }
    }
    // Pick the constructor matching the selector's integer width.
    let (signed, construct_case) = match self.lookup_type(v.ty) {
        SpirvType::Integer(width, signed) => {
            let construct_case = match width {
                8 => construct_8,
                16 => construct_16,
                32 => construct_32,
                64 => construct_64,
                other => self.fatal(format!(
                    "switch selector cannot have width {other} (only 8, 16, 32, and 64 bits allowed)"
                )),
            };
            (signed, construct_case)
        }
        other => self.fatal(format!(
            "switch selector cannot have non-integer type {}",
            other.debug(v.ty, self)
        )),
    };
    let cases = cases
        .map(|(i, b)| (construct_case(self, signed, i), b.id))
        .collect::<Vec<_>>();
    self.emit()
        .switch(v.def(self), else_llbb.id, cases)
        .unwrap();
}
/// `invoke` is a call with an unwind edge; SPIR-V has no unwinding, so this
/// lowers to a plain call followed by a branch to `then` (`_catch` is
/// ignored).
fn invoke(
    &mut self,
    llty: Self::Type,
    fn_attrs: Option<&CodegenFnAttrs>,
    fn_abi: Option<&FnAbi<'tcx, Ty<'tcx>>>,
    llfn: Self::Value,
    args: &[Self::Value],
    then: Self::BasicBlock,
    _catch: Self::BasicBlock,
    funclet: Option<&Self::Funclet>,
    instance: Option<ty::Instance<'tcx>>,
) -> Self::Value {
    let result = self.call(llty, fn_attrs, fn_abi, llfn, args, funclet, instance);
    self.emit().branch(then.id).unwrap();
    result
}
fn unreachable(&mut self) {
    self.emit().unreachable().unwrap();
}
// Arithmetic ops generated by the `simple_op!`/`simple_shift_op!`/
// `simple_uni_op!` macros (defined elsewhere in this file): each expands to
// the named `BuilderMethods` method, mapping to the given SPIR-V
// instruction, with optional `fold_const` constant folding on known scalar
// operands (`?` appears to abandon folding when the checked op fails).
simple_op! {
    add,
    int: i_add,
    fold_const {
        int(a, b) => a.checked_add(b)?;
    }
}
simple_op! {fadd, float: f_add}
// The `_fast`/`_algebraic` float variants lower to the same SPIR-V op as
// the strict version.
simple_op! {fadd_fast, float: f_add} simple_op! {fadd_algebraic, float: f_add} simple_op! {
    sub,
    int: i_sub,
    fold_const {
        int(a, b) => a.checked_sub(b)?;
    }
}
simple_op! {fsub, float: f_sub}
simple_op! {fsub_fast, float: f_sub} simple_op! {fsub_algebraic, float: f_sub} simple_op! {
    mul,
    int: i_mul,
    fold_const {
        int(a, b) => a.checked_mul(b)?;
    }
}
simple_op! {fmul, float: f_mul}
simple_op! {fmul_fast, float: f_mul} simple_op! {fmul_algebraic, float: f_mul} simple_op! {
    udiv,
    uint: u_div,
    fold_const {
        uint(a, b) => a.checked_div(b)?;
    }
}
// `exactudiv`/`exactsdiv` (exact division, no remainder) are lowered
// identically to the plain division ops.
simple_op! {
    exactudiv,
    uint: u_div,
    fold_const {
        uint(a, b) => a.checked_div(b)?;
    }
}
simple_op! {
    sdiv,
    sint: s_div,
    fold_const {
        sint(a, b) => a.checked_div(b)?;
    }
}
simple_op! {
    exactsdiv,
    sint: s_div,
    fold_const {
        sint(a, b) => a.checked_div(b)?;
    }
}
simple_op! {fdiv, float: f_div}
simple_op! {fdiv_fast, float: f_div} simple_op! {fdiv_algebraic, float: f_div} simple_op! {
    urem,
    uint: u_mod,
    fold_const {
        uint(a, b) => a.checked_rem(b)?;
    }
}
simple_op! {
    srem,
    sint: s_rem,
    fold_const {
        sint(a, b) => a.checked_rem(b)?;
    }
}
simple_op! {frem, float: f_rem}
// Shift ops, via `simple_shift_op!`.
simple_op! {frem_fast, float: f_rem} simple_op! {frem_algebraic, float: f_rem} simple_shift_op! {
    shl,
    int: shift_left_logical,
    fold_const {
        int(a, b) => a.checked_shl(b as u32)?;
    }
}
simple_shift_op! {
    lshr,
    uint: shift_right_logical,
    fold_const {
        uint(a, b) => a.checked_shr(b as u32)?;
    }
}
simple_shift_op! {
    ashr,
    sint: shift_right_arithmetic,
    fold_const {
        sint(a, b) => a.checked_shr(b as u32)?;
    }
}
// Unary negation, via `simple_uni_op!`.
simple_uni_op! {
    neg,
    sint: s_negate,
    fold_const {
        sint(a) => a.checked_neg()?;
    }
}
simple_uni_op! {fneg, float: f_negate}
// "Unchecked" (UB-on-overflow) arithmetic has no distinct SPIR-V form, so
// these all reuse the plain `add`/`sub`/`mul` lowerings.
fn unchecked_sadd(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
    self.add(lhs, rhs)
}
fn unchecked_uadd(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
    self.add(lhs, rhs)
}
fn unchecked_ssub(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
    self.sub(lhs, rhs)
}
fn unchecked_usub(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
    self.sub(lhs, rhs)
}
fn unchecked_smul(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
    self.mul(lhs, rhs)
}
fn unchecked_umul(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
    self.mul(lhs, rhs)
}
// Bitwise/logical ops: integer operands use the `bitwise_*` SPIR-V
// instructions, booleans the `logical_*` ones (boolean `xor` is
// `logical_not_equal`, boolean `not` is `logical_not`).
simple_op! {
    and,
    int: bitwise_and,
    bool: logical_and,
    fold_const {
        int(a, b) => a.bitand(b);
        bool(a, b) => a.bitand(b);
    }
}
simple_op! {
    or,
    int: bitwise_or,
    bool: logical_or,
    fold_const {
        int(a, b) => a.bitor(b);
        bool(a, b) => a.bitor(b);
    }
}
simple_op! {
    xor,
    int: bitwise_xor,
    bool: logical_not_equal,
    fold_const {
        int(a, b) => a.bitxor(b);
        bool(a, b) => a.bitxor(b);
    }
}
simple_uni_op! {
    not,
    int: not,
    bool: logical_not,
    fold_const {
        int(a) => a.not();
        bool(a) => a.not();
    }
}
#[instrument(level = "trace", skip(self))]
/// Computes `lhs <op> rhs` and an overflow flag, as `(result, overflowed)`.
///
/// Only add/sub get real overflow detection; checked mul produces an undef
/// flag that is zombied (reported only if it survives to final codegen).
fn checked_binop(
    &mut self,
    oop: OverflowOp,
    ty: Ty<'_>,
    lhs: Self::Value,
    rhs: Self::Value,
) -> (Self::Value, Self::Value) {
    let is_add = match oop {
        OverflowOp::Add => true,
        OverflowOp::Sub => false,
        OverflowOp::Mul => {
            let bool = SpirvType::Bool.def(self.span(), self);
            let overflowed = self.undef(bool);
            let result = (self.mul(lhs, rhs), overflowed);
            self.zombie(result.1.def(self), "checked mul is not supported yet");
            return result;
        }
    };
    let signed = match ty.kind() {
        ty::Int(_) => true,
        ty::Uint(_) => false,
        other => self.fatal(format!(
            "Unexpected {} type: {other:#?}",
            match oop {
                OverflowOp::Add => "checked add",
                OverflowOp::Sub => "checked sub",
                OverflowOp::Mul => "checked mul",
            }
        )),
    };
    // The wrapping result is always computed; only the flag differs.
    let result = if is_add {
        self.add(lhs, rhs)
    } else {
        self.sub(lhs, rhs)
    };
    let overflowed = if signed {
        // Signed overflow detection (cf. rustc_codegen_cranelift):
        //   add: overflow <=> (rhs < 0) != (result < lhs)
        //   sub: overflow <=> (rhs < 0) != (result > lhs)
        // Note the `!=` with these strict comparisons is required: the
        // previous `==` with the opposite strict comparison mis-reported
        // overflow whenever `rhs == 0` (where `result == lhs` makes both
        // sides false, and `==` yields `true`).
        let rhs_lt_zero = self.icmp(IntPredicate::IntSLT, rhs, self.constant_int(rhs.ty, 0));
        let result_cmp_lhs = self.icmp(
            if is_add {
                IntPredicate::IntSLT
            } else {
                IntPredicate::IntSGT
            },
            result,
            lhs,
        );
        self.icmp(IntPredicate::IntNE, rhs_lt_zero, result_cmp_lhs)
    } else {
        // Unsigned: add overflows iff the result wrapped below `lhs`; sub
        // overflows (underflows) iff it wrapped above `lhs`.
        self.icmp(
            if is_add {
                IntPredicate::IntULT
            } else {
                IntPredicate::IntUGT
            },
            result,
            lhs,
        )
    };
    (result, overflowed)
}
// This backend's immediates are ordinary SPIR-V values, so no conversion
// is needed in either direction.
fn from_immediate(&mut self, val: Self::Value) -> Self::Value {
    val
}
fn to_immediate_scalar(&mut self, val: Self::Value, _scalar: Scalar) -> Self::Value {
    val
}
// Typed alloca (only when the `pqp_cg_ssa` fork is enabled, per the cfg):
// a `Function` storage-class `OpVariable` of the requested type.
#[cfg(not(rustc_codegen_spirv_disable_pqp_cg_ssa))]
fn typed_alloca(&mut self, ty: Self::Type, align: Align) -> Self::Value {
    self.declare_func_local_var(ty, align)
}
fn alloca(&mut self, size: Size, align: Align) -> Self::Value {
    // Untyped allocas become a local byte array of the requested size.
    self.declare_func_local_var(self.type_array(self.type_i8(), size.bytes()), align)
}
fn load(&mut self, ty: Self::Type, ptr: Self::Value, _align: Align) -> Self::Value {
    // Retype the pointer so the access matches the requested type.
    let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, ty);
    // Prefer folding loads of known constants; otherwise emit `OpLoad`.
    let loaded_val = match ptr.const_fold_load(self) {
        Some(folded) => folded,
        None => self
            .emit()
            .load(access_ty, None, ptr.def(self), None, empty())
            .unwrap()
            .with_type(access_ty),
    };
    // Cast back in case `access_ty` differs from the requested `ty`.
    self.bitcast(loaded_val, ty)
}
fn volatile_load(&mut self, ty: Self::Type, ptr: Self::Value) -> Self::Value {
    // Lowered as a plain load; the result is zombied so actual volatile
    // uses are reported if they survive to final codegen.
    let result = self.load(ty, ptr, Align::from_bytes(0).unwrap());
    self.zombie(result.def(self), "volatile load is not supported yet");
    result
}
fn atomic_load(
    &mut self,
    ty: Self::Type,
    ptr: Self::Value,
    order: AtomicOrdering,
    _size: Size,
) -> Self::Value {
    let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, ty);
    // `OpAtomicLoad` takes explicit scope and memory-semantics id operands.
    let memory = self.constant_u32(self.span(), Scope::Device as u32);
    let semantics = self.ordering_to_semantics_def(order);
    let result = self
        .emit()
        .atomic_load(
            access_ty,
            None,
            ptr.def(self),
            memory.def(self),
            semantics.def(self),
        )
        .unwrap()
        .with_type(access_ty);
    // Presumably zombies atomics on unsupported types — see `validate_atomic`.
    self.validate_atomic(access_ty, result.def(self));
    self.bitcast(result, ty)
}
/// Loads a `PlaceRef` into an `OperandRef`, picking the representation
/// (ZST / by-ref / immediate / scalar pair) from the place's layout.
fn load_operand(
    &mut self,
    place: PlaceRef<'tcx, Self::Value>,
) -> OperandRef<'tcx, Self::Value> {
    if place.layout.is_zst() {
        return OperandRef::zero_sized(place.layout);
    }
    let val = if place.val.llextra.is_some() {
        // Unsized places (with extra metadata) stay by-reference.
        OperandValue::Ref(place.val)
    } else if self.cx.is_backend_immediate(place.layout) {
        let llval = self.load(
            place.layout.spirv_type(self.span(), self),
            place.val.llval,
            place.val.align,
        );
        OperandValue::Immediate(llval)
    } else if let BackendRepr::ScalarPair(a, b) = place.layout.backend_repr {
        // The second element lives at the first element's size, rounded up
        // to the second element's alignment.
        let b_offset = a
            .primitive()
            .size(self)
            .align_to(b.primitive().align(self).abi);
        let mut load = |i, scalar: Scalar, align| {
            let llptr = if i == 0 {
                place.val.llval
            } else {
                self.inbounds_ptradd(place.val.llval, self.const_usize(b_offset.bytes()))
            };
            let load = self.load(
                self.scalar_pair_element_backend_type(place.layout, i, false),
                llptr,
                align,
            );
            self.to_immediate_scalar(load, scalar)
        };
        OperandValue::Pair(
            load(0, a, place.val.align),
            load(1, b, place.val.align.restrict_for_offset(b_offset)),
        )
    } else {
        // Anything else is passed around by reference.
        OperandValue::Ref(place.val)
    };
    OperandRef {
        val,
        layout: place.layout,
        move_annotation: None,
    }
}
/// Stores `cg_elem` into each of the `count` elements of the array at
/// `dest` (e.g. for `[elem; N]` repeats); fully unrolled, no runtime loop.
fn write_operand_repeatedly(
    &mut self,
    cg_elem: OperandRef<'tcx, Self::Value>,
    count: u64,
    dest: PlaceRef<'tcx, Self::Value>,
) {
    let zero = self.const_usize(0);
    // Pointer to element 0; subsequent elements are reached by GEP.
    let start = dest.project_index(self, zero).val.llval;
    let elem_layout = dest.layout.field(self.cx(), 0);
    let elem_ty = elem_layout.spirv_type(self.span(), self);
    let align = dest.val.align.restrict_for_offset(elem_layout.size);
    for i in 0..count {
        let current = self.inbounds_gep(elem_ty, start, &[self.const_usize(i)]);
        cg_elem.val.store(
            self,
            PlaceRef::new_sized_aligned(current, cg_elem.layout, align),
        );
    }
}
// Value-range / non-null optimization hints have no SPIR-V encoding here,
// so all four hooks are deliberate no-ops.
fn assume_integer_range(&mut self, _imm: Self::Value, _ty: Self::Type, _range: WrappingRange) {}
fn assume_nonnull(&mut self, _val: Self::Value) {}
fn range_metadata(&mut self, _load: Self::Value, _range: WrappingRange) {
}
fn nonnull_metadata(&mut self, _load: Self::Value) {
}
fn store(&mut self, val: Self::Value, ptr: Self::Value, _align: Align) -> Self::Value {
    // Retype the destination pointer for the value being stored, then
    // coerce the value to the access type before emitting `OpStore`.
    let (dst, access_ty) = self.adjust_pointer_for_typed_access(ptr, val.ty);
    let stored_val = self.bitcast(val, access_ty);
    let dst_id = dst.def(self);
    let val_id = stored_val.def(self);
    self.emit().store(dst_id, val_id, None, empty()).unwrap();
    stored_val
}
fn store_with_flags(
    &mut self,
    val: Self::Value,
    ptr: Self::Value,
    align: Align,
    flags: MemFlags,
) -> Self::Value {
    // Only the empty flag set is supported; report anything else.
    if !flags.is_empty() {
        self.err(format!("store_with_flags is not supported yet: {flags:?}"));
    }
    self.store(val, ptr, align)
}
fn atomic_store(
    &mut self,
    val: Self::Value,
    ptr: Self::Value,
    order: AtomicOrdering,
    _size: Size,
) {
    let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, val.ty);
    let val = self.bitcast(val, access_ty);
    // `OpAtomicStore` takes explicit scope and memory-semantics id operands.
    let memory = self.constant_u32(self.span(), Scope::Device as u32);
    let semantics = self.ordering_to_semantics_def(order);
    // Presumably zombies atomics on unsupported types — see `validate_atomic`.
    self.validate_atomic(val.ty, ptr.def(self));
    self.emit()
        .atomic_store(
            ptr.def(self),
            memory.def(self),
            semantics.def(self),
            val.def(self),
        )
        .unwrap();
}
fn gep(&mut self, ty: Self::Type, ptr: Self::Value, indices: &[Self::Value]) -> Self::Value {
    self.maybe_inbounds_gep(ty, ptr, indices, false)
}
fn inbounds_gep(
    &mut self,
    ty: Self::Type,
    ptr: Self::Value,
    indices: &[Self::Value],
) -> Self::Value {
    // A two-index GEP of `[0, i]` into an array type is really an index
    // into the array's elements; dropping the constant-zero base index
    // keeps the resulting access chain structured.
    if let &[ptr_base_index, structured_index] = indices
        && self.builder.lookup_const_scalar(ptr_base_index) == Some(0)
        && let SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element, .. } =
            self.lookup_type(ty)
    {
        return self.maybe_inbounds_gep(element, ptr, &[structured_index], true);
    }
    self.maybe_inbounds_gep(ty, ptr, indices, true)
}
// Integer truncation/extension both go through `intcast`, differing only
// in whether the extension is sign-extending.
fn trunc(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
    self.intcast(val, dest_ty, false)
}
fn sext(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
    self.intcast(val, dest_ty, true)
}
// Saturating float->int casts share one helper, parameterized on signedness.
fn fptoui_sat(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
    self.fptoint_sat(false, val, dest_ty)
}
fn fptosi_sat(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
    self.fptoint_sat(true, val, dest_ty)
}
// Scalar conversion ops: each is a no-op when source and destination types
// already match, and otherwise a single SPIR-V conversion instruction.
fn fptoui(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
    if val.ty == dest_ty {
        return val;
    }
    self.emit()
        .convert_f_to_u(dest_ty, None, val.def(self))
        .unwrap()
        .with_type(dest_ty)
}
fn fptosi(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
    if val.ty == dest_ty {
        return val;
    }
    self.emit()
        .convert_f_to_s(dest_ty, None, val.def(self))
        .unwrap()
        .with_type(dest_ty)
}
fn uitofp(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
    if val.ty == dest_ty {
        return val;
    }
    self.emit()
        .convert_u_to_f(dest_ty, None, val.def(self))
        .unwrap()
        .with_type(dest_ty)
}
fn sitofp(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
    if val.ty == dest_ty {
        return val;
    }
    self.emit()
        .convert_s_to_f(dest_ty, None, val.def(self))
        .unwrap()
        .with_type(dest_ty)
}
fn fptrunc(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
    if val.ty == dest_ty {
        return val;
    }
    self.emit()
        .f_convert(dest_ty, None, val.def(self))
        .unwrap()
        .with_type(dest_ty)
}
fn fpext(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
    if val.ty == dest_ty {
        val
    } else {
        // Fold widening casts of known float constants into a constant of
        // the destination type (f64 represents every f32 exactly).
        if let Some(const_val) = self.builder.lookup_const_scalar(val)
            && let (SpirvType::Float(src_width), SpirvType::Float(dst_width)) =
                (self.lookup_type(val.ty), self.lookup_type(dest_ty))
            && src_width < dst_width
        {
            let float_val = match src_width {
                32 => Some(f32::from_bits(const_val as u32) as f64),
                64 => Some(f64::from_bits(const_val as u64)),
                // Other widths (e.g. 16-bit) fall through to `OpFConvert`.
                _ => None,
            };
            if let Some(val) = float_val {
                return self.constant_float(dest_ty, val);
            }
        }
        self.emit()
            .f_convert(dest_ty, None, val.def(self))
            .unwrap()
            .with_type(dest_ty)
    }
}
fn ptrtoint(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
    // Reject non-pointer sources up front.
    let other = self.lookup_type(val.ty);
    if !matches!(other, SpirvType::Pointer { .. }) {
        self.fatal(format!(
            "ptrtoint called on non-pointer source type: {other:?}"
        ))
    }
    if val.ty == dest_ty {
        return val;
    }
    // Emit `OpConvertPtrToU`, but zombie the result — see
    // `zombie_convert_ptr_to_u` — so it is only reported if kept.
    let result = self
        .emit()
        .convert_ptr_to_u(dest_ty, None, val.def(self))
        .unwrap()
        .with_type(dest_ty);
    self.zombie_convert_ptr_to_u(result.def(self));
    result
}
fn inttoptr(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
    // Reject non-pointer destinations up front.
    let other = self.lookup_type(dest_ty);
    if !matches!(other, SpirvType::Pointer { .. }) {
        self.fatal(format!(
            "inttoptr called on non-pointer dest type: {other:?}"
        ))
    }
    if val.ty == dest_ty {
        return val;
    }
    // Emit `OpConvertUToPtr`, zombied like `ptrtoint` above.
    let result = self
        .emit()
        .convert_u_to_ptr(dest_ty, None, val.def(self))
        .unwrap()
        .with_type(dest_ty);
    self.zombie_convert_u_to_ptr(result.def(self));
    result
}
#[instrument(level = "trace", skip(self), fields(val_type = ?self.debug_type(val.ty), dest_ty = ?self.debug_type(dest_ty)))]
/// Bit-level reinterpretation cast, with extra handling to keep SPIR-V
/// legal: newtype wrappers are unpacked/repacked structurally, and
/// pointer-to-pointer casts are routed through `pointercast`.
fn bitcast(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
    if val.ty == dest_ty {
        val
    } else {
        let val_ty_kind = self.lookup_type(val.ty);
        let dest_ty_kind = self.lookup_type(dest_ty);
        // If `ty` is a newtype-like wrapper (`Adt`/`Array`) around a single
        // leaf of the same total size, return the composite-access indices
        // down to that leaf plus the leaf type; `None` otherwise.
        let unpack_newtype = |ty, kind| {
            let span = span!(Level::DEBUG, "unpack_newtype");
            let _guard = span.enter();
            if !matches!(kind, SpirvType::Adt { .. } | SpirvType::Array { .. }) {
                return None;
            }
            let size = kind.sizeof(self)?;
            let mut leaf_ty = ty;
            let mut indices = SmallVec::<[_; 8]>::new();
            // Keep descending while a same-size inner leaf is recoverable.
            while let Some((inner_indices, inner_ty)) = self.recover_access_chain_from_offset(
                leaf_ty,
                Size::ZERO,
                Some(size)..=Some(size),
                None,
            ) {
                indices.extend(inner_indices);
                leaf_ty = inner_ty;
            }
            (!indices.is_empty()).then_some((indices, leaf_ty))
        };
        // Newtype source: extract its leaf, then bitcast the leaf.
        if let Some((indices, in_leaf_ty)) = unpack_newtype(val.ty, val_ty_kind) {
            let in_leaf = self
                .emit()
                .composite_extract(in_leaf_ty, None, val.def(self), indices)
                .unwrap()
                .with_type(in_leaf_ty);
            trace!(
                "unpacked newtype. val: {} -> in_leaf_ty: {}",
                self.debug_type(val.ty),
                self.debug_type(in_leaf_ty),
            );
            return self.bitcast(in_leaf, dest_ty);
        }
        // Newtype destination: bitcast to its leaf, then insert the leaf
        // into an undef aggregate of the destination type.
        if let Some((indices, out_leaf_ty)) = unpack_newtype(dest_ty, dest_ty_kind) {
            trace!(
                "unpacked newtype: dest: {} -> out_leaf_ty: {}",
                self.debug_type(dest_ty),
                self.debug_type(out_leaf_ty),
            );
            let out_leaf = self.bitcast(val, out_leaf_ty);
            let out_agg_undef = self.undef(dest_ty);
            trace!("returning composite insert");
            return self
                .emit()
                .composite_insert(
                    dest_ty,
                    None,
                    out_leaf.def(self),
                    out_agg_undef.def(self),
                    indices,
                )
                .unwrap()
                .with_type(dest_ty);
        }
        let val_is_ptr = matches!(val_ty_kind, SpirvType::Pointer { .. });
        let dest_is_ptr = matches!(dest_ty_kind, SpirvType::Pointer { .. });
        // Pointer-to-pointer casts go through the logical pointercast path.
        if val_is_ptr && dest_is_ptr {
            trace!("val and dest are both pointers");
            return self.pointercast(val, dest_ty);
        }
        trace!(
            "before emitting: val ty: {} -> dest ty: {}",
            self.debug_type(val.ty),
            self.debug_type(dest_ty)
        );
        let result = self
            .emit()
            .bitcast(dest_ty, None, val.def(self))
            .unwrap()
            .with_type(dest_ty);
        // Pointer <-> non-pointer bitcasts are not legal here; zombie them
        // so they are reported only if they survive to final codegen.
        if val_is_ptr || dest_is_ptr {
            self.zombie(
                result.def(self),
                &format!(
                    "cannot cast between pointer and non-pointer types\
                     \nfrom `{}`\
                     \n to `{}`",
                    self.debug_type(val.ty),
                    self.debug_type(dest_ty)
                ),
            );
        }
        result
    }
}
/// Integer-to-integer (and bool<->integer) cast; `is_signed` selects
/// sign- vs zero-extension for widening.
fn intcast(&mut self, val: Self::Value, dest_ty: Self::Type, is_signed: bool) -> Self::Value {
    if val.ty == dest_ty {
        return val;
    }
    // Constant-fold casts of known scalars where possible.
    if let Some(const_val) = self.builder.lookup_const_scalar(val) {
        let src_ty = self.lookup_type(val.ty);
        let dst_ty_spv = self.lookup_type(dest_ty);
        let optimized_result = match (src_ty, dst_ty_spv) {
            (SpirvType::Integer(src_width, _), SpirvType::Integer(dst_width, _)) => {
                // NOTE(review): this widens the raw scalar bits; confirm
                // that sign-extension of signed sources is handled by
                // `constant_int` (or is otherwise unnecessary here).
                if src_width < dst_width {
                    Some(self.constant_int(dest_ty, const_val))
                } else {
                    None
                }
            }
            (SpirvType::Bool, SpirvType::Integer(_, _)) => {
                Some(self.constant_int(dest_ty, const_val))
            }
            (SpirvType::Integer(_, _), SpirvType::Bool) => {
                Some(self.constant_bool(self.span(), const_val != 0))
            }
            _ => None,
        };
        if let Some(result) = optimized_result {
            return result;
        }
    }
    match (self.lookup_type(val.ty), self.lookup_type(dest_ty)) {
        // Same width, different signedness: pure reinterpretation.
        (
            SpirvType::Integer(val_width, val_signedness),
            SpirvType::Integer(dest_width, dest_signedness),
        ) if val_width == dest_width && val_signedness != dest_signedness => self
            .emit()
            .bitcast(dest_ty, None, val.def(self))
            .unwrap()
            .with_type(dest_ty),
        // Width change: `OpSConvert`/`OpUConvert` per destination signedness.
        (SpirvType::Integer(_, _), SpirvType::Integer(_, dest_signedness)) => {
            if dest_signedness {
                self.emit().s_convert(dest_ty, None, val.def(self))
            } else {
                self.emit().u_convert(dest_ty, None, val.def(self))
            }
            .unwrap()
            .with_type(dest_ty)
        }
        // bool -> int: select between constant 1 and 0.
        (SpirvType::Bool, SpirvType::Integer(_, _)) => {
            let if_true = self.constant_int(dest_ty, 1);
            let if_false = self.constant_int(dest_ty, 0);
            self.emit()
                .select(
                    dest_ty,
                    None,
                    val.def(self),
                    if_true.def(self),
                    if_false.def(self),
                )
                .unwrap()
                .with_type(dest_ty)
        }
        // int -> bool: compare against zero.
        (SpirvType::Integer(_, _), SpirvType::Bool) => {
            let zero = self.constant_int(val.ty, 0);
            self.emit()
                .i_not_equal(dest_ty, None, val.def(self), zero.def(self))
                .unwrap()
                .with_type(dest_ty)
        }
        (val_ty, dest_ty_spv) => self.fatal(format!(
            "TODO: intcast not implemented yet: val={val:?} val.ty={val_ty:?} dest_ty={dest_ty_spv:?} is_signed={is_signed}"
        )),
    }
}
#[instrument(level = "trace", skip(self), fields(ptr, ptr_ty = ?self.debug_type(ptr.ty), dest_ty = ?self.debug_type(dest_ty)))]
/// "Casts" `ptr` to a different pointer type, preferring to recover a
/// structured `OpInBoundsAccessChain` into the original pointee; only when
/// no such chain exists does it fall back to a (logically illegal)
/// `OpBitcast`, recorded as a `LogicalPtrCast` so later passes can undo or
/// report it.
fn pointercast(&mut self, ptr: Self::Value, dest_ty: Self::Type) -> Self::Value {
    // Illegal-constant pointers go through the constant-level cast.
    if let SpirvValueKind::IllegalConst(_) = ptr.kind {
        trace!("illegal const");
        return self.const_bitcast(ptr, dest_ty);
    }
    if ptr.ty == dest_ty {
        trace!("ptr.ty == dest_ty");
        return ptr;
    }
    // Strip any previous pointer casts to get to the original pointer.
    let ptr = ptr.strip_ptrcasts();
    // NOTE: message previously read "after strippint pointer cases" (typo).
    trace!(
        "ptr type after stripping pointer casts: {}",
        self.debug_type(ptr.ty),
    );
    if ptr.ty == dest_ty {
        return ptr;
    }
    let ptr_pointee = match self.lookup_type(ptr.ty) {
        SpirvType::Pointer { pointee } => pointee,
        other => self.fatal(format!(
            "pointercast called on non-pointer source type: {other:?}"
        )),
    };
    let dest_pointee = match self.lookup_type(dest_ty) {
        SpirvType::Pointer { pointee } => pointee,
        other => self.fatal(format!(
            "pointercast called on non-pointer dest type: {other:?}"
        )),
    };
    let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self);
    // Try to express the cast as an access chain from the source pointee's
    // offset 0 down to a field/element of the destination pointee type.
    if let Some((indices, _)) = self.recover_access_chain_from_offset(
        ptr_pointee,
        Size::ZERO,
        dest_pointee_size..=dest_pointee_size,
        Some(dest_pointee),
    ) {
        trace!("`recover_access_chain_from_offset` returned something");
        trace!(
            "ptr_pointee: {}, dest_pointee {}",
            self.debug_type(ptr_pointee),
            self.debug_type(dest_pointee),
        );
        let indices = indices
            .into_iter()
            .map(|idx| self.constant_u32(self.span(), idx).def(self))
            .collect::<Vec<_>>();
        self.emit()
            .in_bounds_access_chain(dest_ty, None, ptr.def(self), indices)
            .unwrap()
            .with_type(dest_ty)
    } else {
        trace!("`recover_access_chain_from_offset` returned `None`");
        trace!(
            "ptr_pointee: {}, dest_pointee {}",
            self.debug_type(ptr_pointee),
            self.debug_type(dest_pointee),
        );
        // No structured path: record the cast so it can be stripped later.
        let original_ptr = ptr.def(self);
        SpirvValue {
            kind: SpirvValueKind::LogicalPtrCast {
                original_ptr,
                original_ptr_ty: ptr.ty,
                bitcast_result_id: self.emit().bitcast(dest_ty, None, original_ptr).unwrap(),
            },
            ty: dest_ty,
        }
    }
}
/// Integer/pointer/boolean comparison, returning a `Bool`.
fn icmp(&mut self, op: IntPredicate, lhs: Self::Value, mut rhs: Self::Value) -> Self::Value {
    use IntPredicate::*;
    // Coerce mismatched pointer operands to a common pointer type first.
    if lhs.ty != rhs.ty
        && [lhs, rhs].map(|v| matches!(self.lookup_type(v.ty), SpirvType::Pointer { .. }))
            == [true, true]
    {
        rhs = self.pointercast(rhs, lhs.ty);
    }
    assert_ty_eq!(self, lhs.ty, rhs.ty);
    let b = SpirvType::Bool.def(self.span(), self);
    // Constant-fold comparisons when both operands are known scalars.
    if let Some(const_lhs) = self.try_get_const_value(lhs)
        && let Some(const_rhs) = self.try_get_const_value(rhs)
    {
        // NOTE(review): `Signed` operands under the unsigned `IntU*`
        // predicates are folded with *signed* ordering below — confirm
        // this matches how rustc actually uses those predicates.
        let const_result = match self.lookup_type(lhs.ty) {
            SpirvType::Integer(_, _) => match (const_lhs, const_rhs, op) {
                (ConstValue::Unsigned(lhs), ConstValue::Unsigned(rhs), IntEQ) => {
                    Some(lhs.eq(&rhs))
                }
                (ConstValue::Signed(lhs), ConstValue::Signed(rhs), IntEQ) => Some(lhs.eq(&rhs)),
                (ConstValue::Unsigned(lhs), ConstValue::Unsigned(rhs), IntNE) => {
                    Some(lhs.ne(&rhs))
                }
                (ConstValue::Signed(lhs), ConstValue::Signed(rhs), IntNE) => Some(lhs.ne(&rhs)),
                (ConstValue::Unsigned(lhs), ConstValue::Unsigned(rhs), IntUGT) => {
                    Some(lhs.gt(&rhs))
                }
                (ConstValue::Unsigned(lhs), ConstValue::Unsigned(rhs), IntUGE) => {
                    Some(lhs.ge(&rhs))
                }
                (ConstValue::Unsigned(lhs), ConstValue::Unsigned(rhs), IntULT) => {
                    Some(lhs.lt(&rhs))
                }
                (ConstValue::Unsigned(lhs), ConstValue::Unsigned(rhs), IntULE) => {
                    Some(lhs.le(&rhs))
                }
                (ConstValue::Signed(lhs), ConstValue::Signed(rhs), IntUGT) => {
                    Some(lhs.gt(&rhs))
                }
                (ConstValue::Signed(lhs), ConstValue::Signed(rhs), IntUGE) => {
                    Some(lhs.ge(&rhs))
                }
                (ConstValue::Signed(lhs), ConstValue::Signed(rhs), IntULT) => {
                    Some(lhs.lt(&rhs))
                }
                (ConstValue::Signed(lhs), ConstValue::Signed(rhs), IntULE) => {
                    Some(lhs.le(&rhs))
                }
                (_, _, _) => None,
            },
            // Booleans fold with `false < true` ordering.
            SpirvType::Bool => match (const_lhs, const_rhs, op) {
                (ConstValue::Bool(lhs), ConstValue::Bool(rhs), IntEQ) => Some(lhs.eq(&rhs)),
                (ConstValue::Bool(lhs), ConstValue::Bool(rhs), IntNE) => Some(lhs.ne(&rhs)),
                (ConstValue::Bool(lhs), ConstValue::Bool(rhs), IntUGT) => Some(lhs.gt(&rhs)),
                (ConstValue::Bool(lhs), ConstValue::Bool(rhs), IntUGE) => Some(lhs.ge(&rhs)),
                (ConstValue::Bool(lhs), ConstValue::Bool(rhs), IntULT) => Some(lhs.lt(&rhs)),
                (ConstValue::Bool(lhs), ConstValue::Bool(rhs), IntULE) => Some(lhs.le(&rhs)),
                (_, _, _) => None,
            },
            _ => None,
        };
        if let Some(result) = const_result {
            return self.const_bool(result);
        }
    }
    match self.lookup_type(lhs.ty) {
        // Integers: direct mapping to the SPIR-V comparison instructions.
        SpirvType::Integer(_, _) => match op {
            IntEQ => self.emit().i_equal(b, None, lhs.def(self), rhs.def(self)),
            IntNE => self
                .emit()
                .i_not_equal(b, None, lhs.def(self), rhs.def(self)),
            IntUGT => self
                .emit()
                .u_greater_than(b, None, lhs.def(self), rhs.def(self)),
            IntUGE => self
                .emit()
                .u_greater_than_equal(b, None, lhs.def(self), rhs.def(self)),
            IntULT => self
                .emit()
                .u_less_than(b, None, lhs.def(self), rhs.def(self)),
            IntULE => self
                .emit()
                .u_less_than_equal(b, None, lhs.def(self), rhs.def(self)),
            IntSGT => self
                .emit()
                .s_greater_than(b, None, lhs.def(self), rhs.def(self)),
            IntSGE => self
                .emit()
                .s_greater_than_equal(b, None, lhs.def(self), rhs.def(self)),
            IntSLT => self
                .emit()
                .s_less_than(b, None, lhs.def(self), rhs.def(self)),
            IntSLE => self
                .emit()
                .s_less_than_equal(b, None, lhs.def(self), rhs.def(self)),
        },
        // Pointers: (in)equality uses `OpPtrEqual`/`OpPtrNotEqual` on
        // SPIR-V versions above 1.3 (zombied — see `zombie_ptr_equal`);
        // everything else converts both sides to integers first (those
        // conversions are zombied too — see `zombie_convert_ptr_to_u`).
        SpirvType::Pointer { .. } => match op {
            IntEQ => {
                if self.emit().version().unwrap() > (1, 3) {
                    self.emit()
                        .ptr_equal(b, None, lhs.def(self), rhs.def(self))
                        .inspect(|&result| {
                            self.zombie_ptr_equal(result, "OpPtrEqual");
                        })
                } else {
                    let int_ty = self.type_usize();
                    let lhs = self
                        .emit()
                        .convert_ptr_to_u(int_ty, None, lhs.def(self))
                        .unwrap();
                    self.zombie_convert_ptr_to_u(lhs);
                    let rhs = self
                        .emit()
                        .convert_ptr_to_u(int_ty, None, rhs.def(self))
                        .unwrap();
                    self.zombie_convert_ptr_to_u(rhs);
                    self.emit().i_equal(b, None, lhs, rhs)
                }
            }
            IntNE => {
                if self.emit().version().unwrap() > (1, 3) {
                    self.emit()
                        .ptr_not_equal(b, None, lhs.def(self), rhs.def(self))
                        .inspect(|&result| {
                            self.zombie_ptr_equal(result, "OpPtrNotEqual");
                        })
                } else {
                    let int_ty = self.type_usize();
                    let lhs = self
                        .emit()
                        .convert_ptr_to_u(int_ty, None, lhs.def(self))
                        .unwrap();
                    self.zombie_convert_ptr_to_u(lhs);
                    let rhs = self
                        .emit()
                        .convert_ptr_to_u(int_ty, None, rhs.def(self))
                        .unwrap();
                    self.zombie_convert_ptr_to_u(rhs);
                    self.emit().i_not_equal(b, None, lhs, rhs)
                }
            }
            IntUGT => {
                let int_ty = self.type_usize();
                let lhs = self
                    .emit()
                    .convert_ptr_to_u(int_ty, None, lhs.def(self))
                    .unwrap();
                self.zombie_convert_ptr_to_u(lhs);
                let rhs = self
                    .emit()
                    .convert_ptr_to_u(int_ty, None, rhs.def(self))
                    .unwrap();
                self.zombie_convert_ptr_to_u(rhs);
                self.emit().u_greater_than(b, None, lhs, rhs)
            }
            IntUGE => {
                let int_ty = self.type_usize();
                let lhs = self
                    .emit()
                    .convert_ptr_to_u(int_ty, None, lhs.def(self))
                    .unwrap();
                self.zombie_convert_ptr_to_u(lhs);
                let rhs = self
                    .emit()
                    .convert_ptr_to_u(int_ty, None, rhs.def(self))
                    .unwrap();
                self.zombie_convert_ptr_to_u(rhs);
                self.emit().u_greater_than_equal(b, None, lhs, rhs)
            }
            IntULT => {
                let int_ty = self.type_usize();
                let lhs = self
                    .emit()
                    .convert_ptr_to_u(int_ty, None, lhs.def(self))
                    .unwrap();
                self.zombie_convert_ptr_to_u(lhs);
                let rhs = self
                    .emit()
                    .convert_ptr_to_u(int_ty, None, rhs.def(self))
                    .unwrap();
                self.zombie_convert_ptr_to_u(rhs);
                self.emit().u_less_than(b, None, lhs, rhs)
            }
            IntULE => {
                let int_ty = self.type_usize();
                let lhs = self
                    .emit()
                    .convert_ptr_to_u(int_ty, None, lhs.def(self))
                    .unwrap();
                self.zombie_convert_ptr_to_u(lhs);
                let rhs = self
                    .emit()
                    .convert_ptr_to_u(int_ty, None, rhs.def(self))
                    .unwrap();
                self.zombie_convert_ptr_to_u(rhs);
                self.emit().u_less_than_equal(b, None, lhs, rhs)
            }
            IntSGT => self.fatal("TODO: pointer operator IntSGT not implemented yet"),
            IntSGE => self.fatal("TODO: pointer operator IntSGE not implemented yet"),
            IntSLT => self.fatal("TODO: pointer operator IntSLT not implemented yet"),
            IntSLE => self.fatal("TODO: pointer operator IntSLE not implemented yet"),
        },
        // Booleans: built from the `logical_*` ops, treating false < true
        // (e.g. `lhs > rhs` is `lhs && !rhs`).
        SpirvType::Bool => match op {
            IntEQ => self
                .emit()
                .logical_equal(b, None, lhs.def(self), rhs.def(self)),
            IntNE => self
                .emit()
                .logical_not_equal(b, None, lhs.def(self), rhs.def(self)),
            IntUGT => {
                let true_ = self.constant_bool(self.span(), true);
                let rhs = self
                    .emit()
                    .logical_not_equal(b, None, rhs.def(self), true_.def(self))
                    .unwrap();
                self.emit().logical_and(b, None, lhs.def(self), rhs)
            }
            IntUGE => {
                // lhs >= rhs  <=>  lhs || !rhs
                let true_ = self.constant_bool(self.span(), true);
                let rhs = self
                    .emit()
                    .logical_not_equal(b, None, rhs.def(self), true_.def(self))
                    .unwrap();
                self.emit().logical_or(b, None, lhs.def(self), rhs)
            }
            IntULT => {
                // lhs < rhs  <=>  !lhs && rhs
                let true_ = self.constant_bool(self.span(), true);
                let lhs = self
                    .emit()
                    .logical_not_equal(b, None, lhs.def(self), true_.def(self))
                    .unwrap();
                self.emit().logical_and(b, None, lhs, rhs.def(self))
            }
            IntULE => {
                // lhs <= rhs  <=>  !lhs || rhs
                let true_ = self.constant_bool(self.span(), true);
                let lhs = self
                    .emit()
                    .logical_not_equal(b, None, lhs.def(self), true_.def(self))
                    .unwrap();
                self.emit().logical_or(b, None, lhs, rhs.def(self))
            }
            IntSGT => self.fatal("TODO: boolean operator IntSGT not implemented yet"),
            IntSGE => self.fatal("TODO: boolean operator IntSGE not implemented yet"),
            IntSLT => self.fatal("TODO: boolean operator IntSLT not implemented yet"),
            IntSLE => self.fatal("TODO: boolean operator IntSLE not implemented yet"),
        },
        other => self.fatal(format!(
            "Int comparison not implemented on {}",
            other.debug(lhs.ty, self)
        )),
    }
    .unwrap()
    .with_type(b)
}
/// Floating-point comparison, returning a `Bool`. `O*` predicates are the
/// "ordered" SPIR-V ops (false if either operand is NaN), `U*` the
/// "unordered" ones (true if either operand is NaN).
fn fcmp(&mut self, op: RealPredicate, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
    use RealPredicate::*;
    assert_ty_eq!(self, lhs.ty, rhs.ty);
    let b = SpirvType::Bool.def(self.span(), self);
    match op {
        // Always-false / always-true predicates fold to constants.
        RealPredicateFalse => return self.cx.constant_bool(self.span(), false),
        RealPredicateTrue => return self.cx.constant_bool(self.span(), true),
        RealOEQ => self
            .emit()
            .f_ord_equal(b, None, lhs.def(self), rhs.def(self)),
        RealOGT => self
            .emit()
            .f_ord_greater_than(b, None, lhs.def(self), rhs.def(self)),
        RealOGE => self
            .emit()
            .f_ord_greater_than_equal(b, None, lhs.def(self), rhs.def(self)),
        RealOLT => self
            .emit()
            .f_ord_less_than(b, None, lhs.def(self), rhs.def(self)),
        RealOLE => self
            .emit()
            .f_ord_less_than_equal(b, None, lhs.def(self), rhs.def(self)),
        RealONE => self
            .emit()
            .f_ord_not_equal(b, None, lhs.def(self), rhs.def(self)),
        RealORD => self.emit().ordered(b, None, lhs.def(self), rhs.def(self)),
        RealUNO => self.emit().unordered(b, None, lhs.def(self), rhs.def(self)),
        RealUEQ => self
            .emit()
            .f_unord_equal(b, None, lhs.def(self), rhs.def(self)),
        RealUGT => self
            .emit()
            .f_unord_greater_than(b, None, lhs.def(self), rhs.def(self)),
        RealUGE => {
            self.emit()
                .f_unord_greater_than_equal(b, None, lhs.def(self), rhs.def(self))
        }
        RealULT => self
            .emit()
            .f_unord_less_than(b, None, lhs.def(self), rhs.def(self)),
        RealULE => self
            .emit()
            .f_unord_less_than_equal(b, None, lhs.def(self), rhs.def(self)),
        RealUNE => self
            .emit()
            .f_unord_not_equal(b, None, lhs.def(self), rhs.def(self)),
    }
    .unwrap()
    .with_type(b)
}
/// Lowers `memcpy` for the SPIR-V backend.
///
/// Alignment arguments are ignored: this backend derives layout from the
/// pointee types instead. When `size` is a compile-time constant, the copy
/// is lowered as a *typed* copy (a plain store when the source
/// constant-folds, `OpCopyMemory` otherwise); a dynamically-sized copy
/// falls back to `OpCopyMemorySized` and is then zombie'd as unsupported.
#[instrument(level = "trace", skip(self))]
fn memcpy(
&mut self,
dst: Self::Value,
_dst_align: Align,
src: Self::Value,
_src_align: Align,
size: Self::Value,
flags: MemFlags,
_tt: Option<rustc_ast::expand::typetree::FncTree>,
) {
// Volatile/nontemporal/unaligned mem flags have no lowering yet.
if flags != MemFlags::empty() {
self.err(format!(
"memcpy with mem flags is not supported yet: {flags:?}"
));
}
// Byte count, if it is a compile-time constant that fits in a u64.
let const_size = self
.builder
.lookup_const_scalar(size)
.and_then(|size| Some(Size::from_bytes(u64::try_from(size).ok()?)));
// Copying zero bytes is a no-op.
if const_size == Some(Size::ZERO) {
return;
}
// For a known size, try to rewrite the untyped byte copy as a typed copy
// by adjusting one (or both) pointers to a common access type; `None`
// means no common typed view was found.
let typed_copy_dst_src = const_size.and_then(|const_size| {
trace!(
"adjusting pointers: src: {} -> dst: {}",
self.debug_type(src.ty),
self.debug_type(dst.ty),
);
let dst_adj = self.adjust_pointer_for_sized_access(dst, const_size);
let src_adj = self.adjust_pointer_for_sized_access(src, const_size);
match (dst_adj, src_adj) {
// Only one side adjusted: cast the other side's pointer to match.
(Some((dst, access_ty)), None) => {
trace!(
"DESTINATION adjusted memcpy calling pointercast: dst ty: {}, access ty: {}",
self.debug_type(dst.ty),
self.debug_type(access_ty)
);
Some((dst, self.pointercast(src, self.type_ptr_to(access_ty))))
}
(None, Some((src, access_ty))) => {
trace!(
"SOURCE adjusted memcpy calling pointercast: dst ty: {} -> access ty: {}, src ty: {}",
self.debug_type(dst.ty),
self.debug_type(access_ty),
self.debug_type(src.ty)
);
Some((self.pointercast(dst, self.type_ptr_to(access_ty)), src))
}
// Both adjusted: only usable when they agreed on the access type.
(Some((dst, dst_access_ty)), Some((src, src_access_ty)))
if dst_access_ty == src_access_ty =>
{
trace!("BOTH adjusted memcpy calling pointercast");
Some((dst, src))
}
(None, None) | (Some(_), Some(_)) => None,
}
});
if let Some((dst, src)) = typed_copy_dst_src {
if let Some(const_value) = src.const_fold_load(self) {
trace!("storing const value");
// NOTE(review): relies on `Align::from_bytes(0)` treating 0 as
// 1-byte alignment (per rustc_abi); alignment is otherwise
// irrelevant to this backend — confirm it cannot panic here.
self.store(const_value, dst, Align::from_bytes(0).unwrap());
} else {
trace!("copying memory using OpCopyMemory");
self.emit()
.copy_memory(dst.def(self), src.def(self), None, None, empty())
.unwrap();
}
} else {
// Dynamically-sized copy: emit OpCopyMemorySized, but mark the
// result as a zombie since it is not supported.
self.emit()
.copy_memory_sized(
dst.def(self),
src.def(self),
size.def(self),
None,
None,
empty(),
)
.unwrap();
self.zombie(dst.def(self), "cannot memcpy dynamically sized data");
}
}
/// Lowers `memmove` by delegating to `memcpy`.
///
/// NOTE(review): `memmove` permits overlapping ranges, while the copy
/// instructions `memcpy` emits do not obviously handle overlap — confirm
/// overlapping copies cannot reach this backend.
fn memmove(
&mut self,
dst: Self::Value,
dst_align: Align,
src: Self::Value,
src_align: Align,
size: Self::Value,
flags: MemFlags,
) {
self.memcpy(dst, dst_align, src, src_align, size, flags, None);
}
/// Lowers `memset`: fills `size` bytes behind `ptr` with `fill_byte`.
fn memset(
    &mut self,
    ptr: Self::Value,
    fill_byte: Self::Value,
    size: Self::Value,
    _align: Align,
    flags: MemFlags,
) {
    // Volatile/nontemporal/unaligned mem flags have no lowering yet.
    if !flags.is_empty() {
        self.err(format!(
            "memset with mem flags is not supported yet: {flags:?}"
        ));
    }
    // Byte count, if it is a compile-time constant that fits in a u64.
    let const_size = self
        .builder
        .lookup_const_scalar(size)
        .and_then(|bytes| u64::try_from(bytes).ok())
        .map(Size::from_bytes);
    // The fill is performed through the pointee type, so `ptr` must be a pointer.
    let SpirvType::Pointer { pointee } = self.lookup_type(ptr.ty) else {
        self.fatal(format!(
            "memset called on non-pointer type: {}",
            self.debug_type(ptr.ty)
        ))
    };
    let pointee_kind = self.lookup_type(pointee);
    // Build the fill pattern at the pointee type, from either a constant
    // fill byte or a dynamically computed one.
    let pattern = match self.builder.lookup_const_scalar(fill_byte) {
        Some(byte) => self.memset_const_pattern(&pointee_kind, byte as u8),
        None => self.memset_dynamic_pattern(&pointee_kind, fill_byte.def(self)),
    }
    .with_type(pointee);
    // Constant sizes can be fully unrolled; dynamic sizes take a loop path.
    if let Some(size) = const_size {
        self.memset_constant_size(ptr, pattern, size.bytes());
    } else {
        self.memset_dynamic_size(ptr, pattern, size);
    }
}
/// Lowers `select(cond, a, b)`, folding the select away when `cond` is a
/// known boolean constant.
fn select(
    &mut self,
    cond: Self::Value,
    then_val: Self::Value,
    else_val: Self::Value,
) -> Self::Value {
    assert_ty_eq!(self, then_val.ty, else_val.ty);
    let result_type = then_val.ty;
    match self.try_get_const_value(cond) {
        // Constant condition: pick the branch at compile time.
        Some(ConstValue::Bool(true)) => then_val,
        Some(ConstValue::Bool(false)) => else_val,
        // Otherwise emit an OpSelect.
        _ => {
            let (c, t, e) = (cond.def(self), then_val.def(self), else_val.def(self));
            self.emit()
                .select(result_type, None, c, t, e)
                .unwrap()
                .with_type(result_type)
        }
    }
}
// C-style varargs are not implemented for this backend.
fn va_arg(&mut self, _list: Self::Value, _ty: Self::Type) -> Self::Value {
todo!()
}
fn extract_element(&mut self, vec: Self::Value, idx: Self::Value) -> Self::Value {
let result_type = match self.lookup_type(vec.ty) {
SpirvType::Vector { element, .. } => element,
other => self.fatal(format!("extract_element not implemented on type {other:?}")),
};
match self.builder.lookup_const_scalar(idx) {
Some(const_index) => self.emit().composite_extract(
result_type,
None,
vec.def(self),
[const_index as u32].iter().cloned(),
),
None => {
self.emit()
.vector_extract_dynamic(result_type, None, vec.def(self), idx.def(self))
}
}
.unwrap()
.with_type(result_type)
}
/// Broadcasts `elt` into a `num_elts`-wide vector.
fn vector_splat(&mut self, num_elts: usize, elt: Self::Value) -> Self::Value {
    // Intern the vector type of `elt`'s type.
    let elt_kind = self.lookup_type(elt.ty);
    let result_type = SpirvType::simd_vector(self, self.span(), elt_kind, num_elts as u32)
        .def(self.span(), self);
    // A splat of a constant can itself be a constant composite; anything
    // else needs a runtime OpCompositeConstruct.
    let is_const = self.builder.lookup_const(elt).is_some();
    let elt_id = elt.def(self);
    let lanes = iter::repeat_n(elt_id, num_elts);
    if is_const {
        self.constant_composite(result_type, lanes)
    } else {
        self.emit()
            .composite_construct(result_type, None, lanes)
            .unwrap()
            .with_type(result_type)
    }
}
fn extract_value(&mut self, agg_val: Self::Value, idx: u64) -> Self::Value {
let result_type = match self.lookup_type(agg_val.ty) {
SpirvType::Adt { field_types, .. } => field_types[idx as usize],
SpirvType::Array { element, .. }
| SpirvType::Vector { element, .. }
| SpirvType::Matrix { element, .. } => element,
other => self.fatal(format!(
"extract_value not implemented on type {}",
other.debug(agg_val.ty, self)
)),
};
self.emit()
.composite_extract(
result_type,
None,
agg_val.def(self),
[idx as u32].iter().cloned(),
)
.unwrap()
.with_type(result_type)
}
#[instrument(level = "trace", skip(self))]
fn insert_value(&mut self, agg_val: Self::Value, elt: Self::Value, idx: u64) -> Self::Value {
let field_type = match self.lookup_type(agg_val.ty) {
SpirvType::Adt { field_types, .. } => field_types[idx as usize],
SpirvType::Array { element, .. }
| SpirvType::Vector { element, .. }
| SpirvType::Matrix { element, .. } => element,
other => self.fatal(format!("insert_value not implemented on type {other:?}")),
};
let elt = self.bitcast(elt, field_type);
self.emit()
.composite_insert(
agg_val.ty,
None,
elt.def(self),
agg_val.def(self),
[idx as u32].iter().cloned(),
)
.unwrap()
.with_type(agg_val.ty)
}
// Unwinding personality functions are not implemented for this backend.
fn set_personality_fn(&mut self, _personality: Self::Function) {
todo!()
}
// Landing pads (unwind cleanup) are not implemented for this backend.
fn cleanup_landing_pad(&mut self, _pers_fn: Self::Function) -> (Self::Value, Self::Value) {
todo!()
}
// Landing-pad filters are not implemented for this backend.
fn filter_landing_pad(&mut self, _pers_fn: Self::Function) {
todo!()
}
// Resuming unwinding is not implemented for this backend.
fn resume(&mut self, _exn0: Self::Value, _exn1: Self::Value) {
todo!()
}
// Funclet-based exception handling is unsupported; reaching this is a
// compiler bug.
fn cleanup_pad(
&mut self,
_parent: Option<Self::Value>,
_args: &[Self::Value],
) -> Self::Funclet {
bug!("Funclets are not supported")
}
// Funclet-based exception handling is unsupported; reaching this is a
// compiler bug.
fn cleanup_ret(&mut self, _funclet: &Self::Funclet, _unwind: Option<Self::BasicBlock>) {
bug!("Funclets are not supported")
}
// Funclet-based exception handling is unsupported; reaching this is a
// compiler bug.
fn catch_pad(&mut self, _parent: Self::Value, _args: &[Self::Value]) -> Self::Funclet {
bug!("Funclets are not supported")
}
// Funclet-based exception handling is unsupported; reaching this is a
// compiler bug.
fn catch_switch(
&mut self,
_parent: Option<Self::Value>,
_unwind: Option<Self::BasicBlock>,
_handlers: &[Self::BasicBlock],
) -> Self::Value {
bug!("Funclets are not supported")
}
/// Lowers `cmpxchg` to `OpAtomicCompareExchange` at Device scope.
///
/// `_weak` is ignored: the emitted compare-exchange is always strong.
/// Returns `(old_value, success)`.
fn atomic_cmpxchg(
&mut self,
dst: Self::Value,
cmp: Self::Value,
src: Self::Value,
order: AtomicOrdering,
failure_order: AtomicOrdering,
_weak: bool,
) -> (Self::Value, Self::Value) {
assert_ty_eq!(self, cmp.ty, src.ty);
let ty = src.ty;
// Re-point `dst` so its pointee matches the operand type.
let (dst, access_ty) = self.adjust_pointer_for_typed_access(dst, ty);
let cmp = self.bitcast(cmp, access_ty);
let src = self.bitcast(src, access_ty);
self.validate_atomic(access_ty, dst.def(self));
let memory = self.constant_u32(self.span(), Scope::Device as u32);
// Distinct memory semantics for the success and failure orderings.
let semantics_equal = self.ordering_to_semantics_def(order);
let semantics_unequal = self.ordering_to_semantics_def(failure_order);
// Operand order at the end is: Value (new), then Comparator (expected).
let result = self
.emit()
.atomic_compare_exchange(
access_ty,
None,
dst.def(self),
memory.def(self),
semantics_equal.def(self),
semantics_unequal.def(self),
src.def(self),
cmp.def(self),
)
.unwrap()
.with_type(access_ty);
let val = self.bitcast(result, ty);
// SPIR-V only returns the original value; the success flag is
// recomputed as `original == expected`.
let success = self.icmp(IntPredicate::IntEQ, val, cmp);
(val, success)
}
/// Lowers an atomic read-modify-write to the matching SPIR-V atomic
/// instruction, returning the value that was previously stored.
fn atomic_rmw(
    &mut self,
    op: AtomicRmwBinOp,
    dst: Self::Value,
    src: Self::Value,
    order: AtomicOrdering,
    _ret_ptr: bool,
) -> Self::Value {
    let ty = src.ty;
    // Re-point `dst` so its pointee matches the operand type.
    let (dst, access_ty) = self.adjust_pointer_for_typed_access(dst, ty);
    let src = self.bitcast(src, access_ty);
    self.validate_atomic(access_ty, dst.def(self));
    // All RMW atomics use Device scope plus the semantics implied by `order`.
    let memory = self
        .constant_u32(self.span(), Scope::Device as u32)
        .def(self);
    let semantics = self.ordering_to_semantics_def(order).def(self);
    let (ptr, value) = (dst.def(self), src.def(self));
    use AtomicRmwBinOp::*;
    // Each rustc RMW op maps onto exactly one SPIR-V atomic instruction;
    // NAND has no SPIR-V counterpart.
    let result = match op {
        AtomicXchg => self.emit().atomic_exchange(access_ty, None, ptr, memory, semantics, value),
        AtomicAdd => self.emit().atomic_i_add(access_ty, None, ptr, memory, semantics, value),
        AtomicSub => self.emit().atomic_i_sub(access_ty, None, ptr, memory, semantics, value),
        AtomicAnd => self.emit().atomic_and(access_ty, None, ptr, memory, semantics, value),
        AtomicNand => self.fatal("atomic nand is not supported"),
        AtomicOr => self.emit().atomic_or(access_ty, None, ptr, memory, semantics, value),
        AtomicXor => self.emit().atomic_xor(access_ty, None, ptr, memory, semantics, value),
        AtomicMax => self.emit().atomic_s_max(access_ty, None, ptr, memory, semantics, value),
        AtomicMin => self.emit().atomic_s_min(access_ty, None, ptr, memory, semantics, value),
        AtomicUMax => self.emit().atomic_u_max(access_ty, None, ptr, memory, semantics, value),
        AtomicUMin => self.emit().atomic_u_min(access_ty, None, ptr, memory, semantics, value),
    }
    .unwrap()
    .with_type(access_ty);
    // Cast the previous value back to the caller's expected type.
    self.bitcast(result, ty)
}
/// Lowers an atomic fence to `OpMemoryBarrier` at Device scope.
///
/// The requested `SynchronizationScope` is ignored; only the ordering is
/// translated into memory semantics.
fn atomic_fence(&mut self, order: AtomicOrdering, _scope: SynchronizationScope) {
    let scope_id = self
        .constant_u32(self.span(), Scope::Device as u32)
        .def(self);
    let semantics_id = self.ordering_to_semantics_def(order).def(self);
    self.emit().memory_barrier(scope_id, semantics_id).unwrap();
}
// Intentionally a no-op: invariant-load metadata is not tracked here.
fn set_invariant_load(&mut self, _load: Self::Value) {
}
// Intentionally a no-op: lifetime markers are not emitted by this backend.
fn lifetime_start(&mut self, _ptr: Self::Value, _size: Size) {
}
// Intentionally a no-op: lifetime markers are not emitted by this backend.
fn lifetime_end(&mut self, _ptr: Self::Value, _size: Size) {
}
/// Lowers a function call.
///
/// Several callee `DefId`s registered elsewhere in the codegen context get
/// special lowerings instead of a plain `OpFunctionCall`: libm/num-traits
/// math intrinsics, panic entry points, buffer load/store intrinsics,
/// constant-foldable `From` impls, and `MaybeUninit::uninit`.
#[tracing::instrument(
level = "debug",
skip(self, callee_ty, _fn_attrs, fn_abi, callee, args, funclet)
)]
fn call(
&mut self,
callee_ty: Self::Type,
_fn_attrs: Option<&CodegenFnAttrs>,
fn_abi: Option<&FnAbi<'tcx, Ty<'tcx>>>,
callee: Self::Value,
args: &[Self::Value],
funclet: Option<&Self::Funclet>,
instance: Option<ty::Instance<'tcx>>,
) -> Self::Value {
if funclet.is_some() {
self.fatal("TODO: Funclets are not supported");
}
// Resolve the callee to a function id plus its declared signature.
// `callee` is always a pointer; what it points at decides the cases.
let (callee_val, result_type, argument_types) = match self.lookup_type(callee.ty) {
SpirvType::Pointer { pointee } => {
let (pointee_is_function, result_type, argument_types) = match self
.lookup_type(pointee)
{
SpirvType::Function {
return_type,
arguments,
} => (true, return_type, arguments),
_ => {
// Pointee is not a function type: fall back to the caller's
// `callee_ty`, then to a signature rebuilt from `fn_abi`.
if let SpirvType::Function {
return_type,
arguments,
} = self.lookup_type(callee_ty)
{
(false, return_type, arguments)
} else {
let Some(fn_abi) = fn_abi else {
bug!(
"call expected `fn` pointer to point to function type, got `{}`",
self.debug_type(pointee)
);
};
let fn_ty = fn_abi.spirv_type(self.span(), self);
match self.lookup_type(fn_ty) {
SpirvType::Function {
return_type,
arguments,
} => (false, return_type, arguments),
_ => bug!("call expected function ABI to lower to function type"),
}
}
}
};
// A known function address is a direct call; anything else is an
// indirect call, which SPIR-V cannot express (zombie'd).
let callee_val = if let SpirvValueKind::FnAddr { function } = callee.kind {
if pointee_is_function {
assert_ty_eq!(self, callee_ty, pointee);
}
function
}
else {
let fn_ptr_val = callee.def(self);
self.zombie(fn_ptr_val, "indirect calls are not supported in SPIR-V");
fn_ptr_val
};
(callee_val, result_type, argument_types)
}
_ => bug!(
"call expected `fn` pointer type, got `{}`",
self.debug_type(callee.ty)
),
};
// Coerce every argument to its exact declared parameter type.
let args: SmallVec<[_; 8]> = args
.iter()
.zip_eq(argument_types)
.map(|(&arg, &expected_type)| self.bitcast(arg, expected_type))
.collect();
let args = &args[..];
// Look up which (if any) special intrinsic set this callee belongs to.
let instance_def_id = instance.map(|instance| instance.def_id());
let libm_intrinsic =
instance_def_id.and_then(|def_id| self.libm_intrinsics.borrow().get(&def_id).copied());
let num_traits_intrinsics = instance_def_id
.and_then(|def_id| self.num_traits_intrinsics.borrow().get(&def_id).copied());
let buffer_load_intrinsic = instance_def_id
.is_some_and(|def_id| self.buffer_load_intrinsics.borrow().contains(&def_id));
let buffer_store_intrinsic = instance_def_id
.is_some_and(|def_id| self.buffer_store_intrinsics.borrow().contains(&def_id));
let is_panic_entry_point = instance_def_id
.is_some_and(|def_id| self.panic_entry_points.borrow().contains(&def_id));
let from_trait_impl =
instance_def_id.and_then(|def_id| self.from_trait_impls.borrow().get(&def_id).copied());
// Math intrinsics bypass the normal call path entirely.
if let Some(libm_intrinsic) = libm_intrinsic.or(num_traits_intrinsics) {
let result = self.call_libm_intrinsic(libm_intrinsic, result_type, args);
if result_type != result.ty {
bug!(
"Mismatched libm result type for {:?}: expected {}, got {}",
libm_intrinsic,
self.debug_type(result_type),
self.debug_type(result.ty),
);
}
return result;
}
// Panic entry points get their format args decoded/removed and are
// replaced by the custom panic lowering.
if is_panic_entry_point {
return DecodedFormatArgs::try_decode_and_remove_format_args(
self,
args,
instance_def_id,
)
.codegen_panic(self, result_type);
}
if buffer_load_intrinsic {
return self.codegen_buffer_load_intrinsic(fn_abi, result_type, args);
}
if buffer_store_intrinsic {
self.codegen_buffer_store_intrinsic(fn_abi, args);
// The store intrinsic returns `()`: synthesize a void value.
let void_ty = SpirvType::Void.def(rustc_span::DUMMY_SP, self);
return SpirvValue {
kind: SpirvValueKind::IllegalTypeUsed(void_ty),
ty: void_ty,
};
}
// Constant-fold `From::from` for same-signedness integer widenings and
// f32 -> f64 when the argument is a known scalar constant.
if let Some((source_ty, target_ty)) = from_trait_impl {
if let [arg] = args
&& let Some(const_val) = self.builder.lookup_const_scalar(*arg)
{
use rustc_middle::ty::FloatTy;
let optimized_result = match (source_ty.kind(), target_ty.kind()) {
(ty::Uint(_), ty::Uint(_)) | (ty::Int(_), ty::Int(_)) => {
Some(self.constant_int(result_type, const_val))
}
(ty::Float(FloatTy::F32), ty::Float(FloatTy::F64)) => {
let float_val = f32::from_bits(const_val as u32) as f64;
Some(self.constant_float(result_type, float_val))
}
_ => None,
};
if let Some(result) = optimized_result {
return result;
}
}
}
// `MaybeUninit::uninit` lowers to an undef of the result type.
if instance_def_id.is_some_and(|did| {
self.tcx
.is_diagnostic_item(rustc_span::sym::maybe_uninit_uninit, did)
}) {
return self.undef(result_type);
}
// Plain direct call.
let args = args.iter().map(|arg| arg.def(self)).collect::<Vec<_>>();
self.emit()
.function_call(result_type, None, callee_val, args)
.unwrap()
.with_type(result_type)
}
// Guaranteed tail calls are not implemented for this backend.
fn tail_call(
&mut self,
_llty: Self::Type,
_fn_attrs: Option<&CodegenFnAttrs>,
_fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
_llfn: Self::Value,
_args: &[Self::Value],
_funclet: Option<&Self::Funclet>,
_instance: Option<ty::Instance<'tcx>>,
) {
todo!()
}
// Zero-extension is an unsigned (`signed = false`) integer cast.
fn zext(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
self.intcast(val, dest_ty, false)
}
// Intentionally a no-op: there are no callsite attributes to apply here.
fn apply_attrs_to_cleanup_callsite(&mut self, _llret: Self::Value) {
}
// Scalable (runtime-sized) allocas cannot be expressed; reaching this is
// a compiler bug.
fn alloca_with_ty(&mut self, _layout: TyAndLayout<'tcx>) -> Self::Value {
bug!("scalable alloca is not supported in SPIR-V backend")
}
}