use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
use crate::builder;
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use crate::symbols::Symbols;
use crate::target::{SpirvTarget, SpirvTargetVariant, SpirvVersion};
use crate::target_feature::TargetFeature;
use rspirv::dr::{Builder, Instruction, Module, Operand};
use rspirv::spirv::{
AddressingModel, Capability, MemoryModel, Op, SourceLanguage, StorageClass, Word,
};
use rspirv::{binary::Assemble, binary::Disassemble};
use rustc_abi::Size;
use rustc_arena::DroplessArena;
use rustc_codegen_ssa::traits::ConstCodegenMethods as _;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_middle::bug;
use rustc_middle::mir::interpret::ConstAllocation;
use rustc_middle::ty::TyCtxt;
use rustc_span::source_map::SourceMap;
use rustc_span::symbol::Symbol;
use rustc_span::{DUMMY_SP, SourceFile, Span};
use std::assert_matches;
use std::cell::{RefCell, RefMut};
use std::hash::{Hash, Hasher};
use std::iter;
use std::ops::Range;
use std::sync::Arc;
use std::{fs::File, io::Write, path::Path};
/// How a [`SpirvValue`] is represented, and whether turning it into a plain
/// SPIR-V ID needs special handling (see `SpirvValue::def_with_span`).
///
/// NOTE: variant order matters for the derived `Ord`/`PartialOrd`.
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub enum SpirvValueKind {
    /// A plain SPIR-V result ID that can be used directly.
    Def(Word),
    /// A constant that failed its legality check; using it zombies the ID
    /// with the recorded cause.
    IllegalConst(Word),
    /// A type ID used where a value was expected; using it emits a
    /// "Can't use type as a value" error.
    IllegalTypeUsed(Word),
    /// The address of a function; resolved to the `OpUndef` registered as
    /// `SpirvConst::ZombieUndefForFnAddr` for the same type.
    FnAddr {
        function: Word,
    },
    /// A pointer cast that couldn't be expressed directly: keeps the original
    /// pointer (recoverable via `SpirvValue::strip_ptrcasts`) and the ID of
    /// the emitted bitcast, which is zombied if it is actually used.
    LogicalPtrCast {
        original_ptr: Word,
        original_ptr_ty: Word,
        bitcast_result_id: Word,
    },
}
/// A value tracked by the SPIR-V codegen: how it is represented (`kind`) and
/// the SPIR-V type ID it has (`ty`).
///
/// NOTE: field order matters for the derived `Ord`/`PartialOrd`.
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct SpirvValue {
    pub kind: SpirvValueKind,
    /// SPIR-V type ID of this value.
    pub ty: Word,
}
impl SpirvValue {
    /// Undoes a `LogicalPtrCast`, giving back the original (pre-cast) pointer
    /// with its original type. Any other kind of value is returned unchanged.
    pub fn strip_ptrcasts(self) -> Self {
        if let SpirvValueKind::LogicalPtrCast {
            original_ptr,
            original_ptr_ty,
            ..
        } = self.kind
        {
            return original_ptr.with_type(original_ptr_ty);
        }
        self
    }

    /// Tries to fold a load through a constant pointer.
    ///
    /// Succeeds only when `self` is a (legal or illegal) constant registered
    /// as `SpirvConst::PtrTo`. In that case the result is the pointee's ID,
    /// typed with the pointee type of `self.ty`. The pointer constant's
    /// legality carries over to the loaded value.
    ///
    /// Reports a compiler bug if `self.ty` is not a pointer type.
    pub fn const_fold_load(self, cx: &CodegenCx<'_>) -> Option<Self> {
        let (SpirvValueKind::Def(ptr_id) | SpirvValueKind::IllegalConst(ptr_id)) = self.kind
        else {
            return None;
        };
        let &entry = cx.builder.id_to_const.borrow().get(&ptr_id)?;
        let SpirvConst::PtrTo { pointee } = entry.val else {
            return None;
        };

        let pointee_ty = match cx.lookup_type(self.ty) {
            SpirvType::Pointer {
                pointee: pointee_ty,
            } => pointee_ty,
            other => bug!("load called on value that wasn't a pointer: {:?}", other),
        };
        let kind = match entry.legal {
            Ok(()) => SpirvValueKind::Def(pointee),
            Err(_) => SpirvValueKind::IllegalConst(pointee),
        };
        Some(SpirvValue {
            kind,
            ty: pointee_ty,
        })
    }

    /// Returns the SPIR-V ID of this value. Any problem with using it is
    /// reported at the builder's current span.
    pub fn def(self, bx: &builder::Builder<'_, '_>) -> Word {
        self.def_with_span(bx, bx.span())
    }

    /// Like `def`, but without a function builder. Problems are reported
    /// without a source location (`DUMMY_SP`).
    pub fn def_cx(self, cx: &CodegenCx<'_>) -> Word {
        self.def_with_span(cx, DUMMY_SP)
    }

    /// Returns the SPIR-V ID of this value, reporting any problem with using
    /// it at `span`:
    /// - `Def`: the ID, as-is.
    /// - `IllegalConst`: zombies the ID with the legality failure's message.
    /// - `IllegalTypeUsed`: emits a "Can't use type as a value" error.
    /// - `FnAddr`: the ID of the `ZombieUndefForFnAddr` constant registered
    ///   for `self.ty` (panics if none was registered).
    /// - `LogicalPtrCast`: zombies the bitcast result ID and returns it.
    pub fn def_with_span(self, cx: &CodegenCx<'_>, span: Span) -> Word {
        match self.kind {
            SpirvValueKind::Def(id) => id,

            SpirvValueKind::IllegalConst(id) => {
                let id_to_const = cx.builder.id_to_const.borrow();
                let entry = &id_to_const[&id];
                // A `Composite` that is illegal only because it contains
                // `PtrTo` fields gets no special treatment (yet); it is
                // zombied like any other illegal constant.
                let (IllegalConst::Shallow(cause) | IllegalConst::Indirect(cause)) =
                    entry.legal.unwrap_err();
                cx.zombie_with_span(id, span, cause.message());
                id
            }

            SpirvValueKind::IllegalTypeUsed(id) => {
                cx.tcx
                    .dcx()
                    .struct_span_err(span, "Can't use type as a value")
                    .with_note(format!("Type: *{}", cx.debug_type(id)))
                    .emit();
                id
            }

            SpirvValueKind::FnAddr { .. } => {
                let undef_key = WithType {
                    ty: self.ty,
                    val: SpirvConst::ZombieUndefForFnAddr,
                };
                let registered = *cx
                    .builder
                    .const_to_id
                    .borrow()
                    .get(&undef_key)
                    .expect("FnAddr didn't go through proper undef registration");
                registered.val
            }

            SpirvValueKind::LogicalPtrCast {
                original_ptr_ty,
                bitcast_result_id,
                ..
            } => {
                let msg = format!(
                    "cannot cast between pointer types\nfrom `{}`\n to `{}`",
                    cx.debug_type(original_ptr_ty),
                    cx.debug_type(self.ty)
                );
                cx.zombie_with_span(bitcast_result_id, span, &msg);
                bitcast_result_id
            }
        }
    }
}
/// Extension trait for attaching a SPIR-V type to a raw ID.
pub trait SpirvValueExt {
    /// Wraps `self` as a [`SpirvValue`] of type `ty`.
    fn with_type(self, ty: Word) -> SpirvValue;
}
impl SpirvValueExt for Word {
    /// Treats `self` as a plain result ID (`SpirvValueKind::Def`) of type `ty`.
    fn with_type(self, ty: Word) -> SpirvValue {
        let kind = SpirvValueKind::Def(self);
        SpirvValue { ty, kind }
    }
}
/// A constant value. It is paired with its type to deduplicate SPIR-V
/// constant definitions (see `BuilderSpirv::def_constant_cx`).
///
/// `'a` is the lifetime of borrowed slices (`Composite`) until they are copied
/// into the `'tcx` arena by `tcx_arena_alloc_slices`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum SpirvConst<'a, 'tcx> {
    /// Raw scalar bits. Integer scalars are sign-extended or truncated to
    /// their type's width by `def_constant_cx` before deduplication.
    Scalar(u128),
    /// Defined via `constant_null`.
    Null,
    /// Defined via `undef`.
    Undef,
    /// An `undef` standing in for a function address (`SpirvValueKind::FnAddr`).
    ZombieUndefForFnAddr,
    /// Defined via `constant_composite` from the given field IDs.
    Composite(&'a [Word]),
    /// A pointer to a constant: defined as a `Private` global variable
    /// initialized with `pointee`.
    PtrTo {
        pointee: Word,
    },
    /// Raw constant data from an allocation. It is defined via `undef` and is
    /// always marked illegal (`UntypedConstDataFromAlloc`).
    ConstDataFromAlloc(ConstAllocation<'tcx>),
}
impl<'tcx> SpirvConst<'_, 'tcx> {
fn tcx_arena_alloc_slices(self, cx: &CodegenCx<'tcx>) -> SpirvConst<'tcx, 'tcx> {
fn arena_alloc_slice<'tcx, T: Copy>(cx: &CodegenCx<'tcx>, xs: &[T]) -> &'tcx [T] {
if xs.is_empty() {
&[]
} else {
cx.tcx.arena.dropless.alloc_slice(xs)
}
}
match self {
SpirvConst::Scalar(v) => SpirvConst::Scalar(v),
SpirvConst::Null => SpirvConst::Null,
SpirvConst::Undef => SpirvConst::Undef,
SpirvConst::ZombieUndefForFnAddr => SpirvConst::ZombieUndefForFnAddr,
SpirvConst::PtrTo { pointee } => SpirvConst::PtrTo { pointee },
SpirvConst::Composite(fields) => SpirvConst::Composite(arena_alloc_slice(cx, fields)),
SpirvConst::ConstDataFromAlloc(alloc) => SpirvConst::ConstDataFromAlloc(alloc),
}
}
}
/// A value paired with its SPIR-V type ID. Used as the deduplication key of
/// `BuilderSpirv::const_to_id`.
///
/// NOTE: field order matters for the derived `Ord`/`PartialOrd`.
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
struct WithType<V> {
    ty: Word,
    val: V,
}
/// The root cause of a constant being illegal to use directly.
#[derive(Copy, Clone, Debug)]
enum LeafIllegalConst {
    /// A `Composite` constant has a field that is a `PtrTo` constant.
    CompositeContainsPtrTo,
    /// A `ConstDataFromAlloc` constant, which has no meaningful type.
    UntypedConstDataFromAlloc,
}

impl LeafIllegalConst {
    /// Human-readable explanation, used when zombieing an offending value.
    fn message(&self) -> &'static str {
        match self {
            LeafIllegalConst::CompositeContainsPtrTo => {
                "constant arrays/structs cannot contain pointers to other constants"
            }
            LeafIllegalConst::UntypedConstDataFromAlloc => {
                "`const_data_from_alloc` result wasn't passed through `static_addr_of`, \
                 then `const_bitcast` (which would've given it a type)"
            }
        }
    }
}
/// Why a constant is illegal to use directly.
#[derive(Copy, Clone, Debug)]
enum IllegalConst {
    /// The constant itself is illegal.
    Shallow(LeafIllegalConst),
    /// The constant is illegal because of another constant. This is created
    /// for a `PtrTo` whose pointee is illegal, carrying the pointee's leaf
    /// cause; composites can inherit it from their fields.
    Indirect(LeafIllegalConst),
}
/// A value paired with the legality of the constant it describes.
#[derive(Copy, Clone, Debug)]
struct WithConstLegality<V> {
    val: V,
    /// `Ok(())` when the constant can be used directly.
    legal: Result<(), IllegalConst>,
}
/// Hash-map key identifying a `SourceFile` by `Arc` pointer identity.
///
/// Equality is pointer identity. Hashing uses the file's `stable_id` and
/// `src_hash`, which pointer-equal keys necessarily share, so `Hash` agrees
/// with `Eq`.
struct DebugFileKey(Arc<SourceFile>);

impl PartialEq for DebugFileKey {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.0, &other.0)
    }
}

impl Eq for DebugFileKey {}

impl Hash for DebugFileKey {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let source_file = &*self.0;
        source_file.stable_id.hash(state);
        source_file.src_hash.hash(state);
    }
}
/// Cached debuginfo for one source file: its (arena-allocated) name and the
/// ID of the `OpString` holding that name.
#[derive(Copy, Clone)]
pub struct DebugFileSpirv<'tcx> {
    pub file_name: &'tcx str,
    pub file_name_op_string_id: Word,
}
/// Identifies a function in the module being built. `BuilderSpirv::builder_for_fn`
/// uses it to select the function in the shared builder.
#[derive(Copy, Clone, Debug)]
pub struct SpirvFunctionCursor {
    /// Presumably the function's type ID (TODO confirm with callers).
    pub ty: Word,
    /// The function's result ID. It is checked against the function found
    /// at `index_in_builder`.
    pub id: Word,
    /// Index of the function in the `rspirv` builder's module.
    pub index_in_builder: usize,
}
/// Identifies a basic block in the module being built.
/// `BuilderSpirv::builder_for_block` uses it to select the block (and its
/// parent function) in the shared builder.
#[derive(Copy, Clone, Debug)]
pub struct SpirvBlockCursor {
    pub parent_fn: SpirvFunctionCursor,
    /// The block's label ID. It is checked against the block found at
    /// `index_in_builder`.
    pub id: Word,
    /// Index of the block within its parent function.
    pub index_in_builder: usize,
}
/// Which function/block (if any) the shared `rspirv` builder should have
/// selected, consumed by `BuilderSpirv::builder`.
///
/// Each entry is `(result ID, index in builder)`. The index does the
/// selecting and the ID is used to sanity-check it. The `Default` (both
/// `None`) selects the module-level (global) scope.
#[derive(Debug, Default, Copy, Clone)]
#[must_use = "BuilderCursor should usually be assigned to the Builder.cursor field"]
struct BuilderCursor {
    fn_id_and_idx: Option<(Word, usize)>,
    block_id_and_idx: Option<(Word, usize)>,
}
/// Wraps the `rspirv` module builder, together with caches that deduplicate
/// constants and per-source-file debuginfo.
pub struct BuilderSpirv<'tcx> {
    source_map: &'tcx SourceMap,
    /// Arena for data that must live for `'tcx` (e.g. debug file names).
    dropless_arena: &'tcx DroplessArena,
    /// The underlying builder. Access it through `BuilderSpirv::builder`
    /// (or `global_builder`/`builder_for_fn`/`builder_for_block`) so the
    /// right function/block is selected.
    builder: RefCell<Builder>,
    /// Constant (with its type) -> ID of its definition, plus its legality.
    const_to_id: RefCell<FxHashMap<WithType<SpirvConst<'tcx, 'tcx>>, WithConstLegality<Word>>>,
    /// Reverse of `const_to_id`: definition ID -> constant, plus its legality.
    id_to_const: RefCell<FxHashMap<Word, WithConstLegality<SpirvConst<'tcx, 'tcx>>>>,
    /// Source file -> its already-emitted debuginfo.
    debug_file_cache: RefCell<FxHashMap<DebugFileKey, DebugFileSpirv<'tcx>>>,
    /// Capabilities declared in the module (queried by `has_capability`).
    enabled_capabilities: FxHashSet<Capability>,
}
impl<'tcx> BuilderSpirv<'tcx> {
    /// Creates the SPIR-V module builder for `target`.
    ///
    /// This does the following, in order:
    /// - Sets the module's SPIR-V version and header generator number.
    /// - Declares every capability/extension in `features`.
    /// - Declares the capabilities this backend always needs: `Shader` and
    ///   `Linkage`. For the Vulkan memory model it also declares
    ///   `VulkanMemoryModel`, plus its extension before SPIR-V 1.5.
    /// - Sets the `Logical` addressing model and the target's memory model.
    pub fn new(
        tcx: TyCtxt<'tcx>,
        sym: &Symbols,
        target: &SpirvTarget,
        features: &[TargetFeature],
    ) -> Self {
        let version = target.spirv_version();
        let memory_model = target.memory_model();

        let mut builder = Builder::new();
        builder.set_version(version.0, version.1);
        // Generator magic number recorded in the SPIR-V module header.
        builder.module_mut().header.as_mut().unwrap().generator = 0x001B_0000;

        let mut enabled_capabilities = FxHashSet::default();

        // Declares `cap` in the module and records it for `has_capability`.
        fn add_cap(
            builder: &mut Builder,
            enabled_capabilities: &mut FxHashSet<Capability>,
            cap: Capability,
        ) {
            builder.capability(cap);
            enabled_capabilities.insert(cap);
        }
        // Declares extension `ext` in the module.
        fn add_ext(builder: &mut Builder, ext: Symbol) {
            builder.extension(ext.as_str());
        }

        for feature in features {
            match *feature {
                TargetFeature::Capability(cap) => {
                    add_cap(&mut builder, &mut enabled_capabilities, cap);
                }
                TargetFeature::Extension(ext) => {
                    add_ext(&mut builder, ext);
                }
            }
        }

        add_cap(&mut builder, &mut enabled_capabilities, Capability::Shader);
        if memory_model == MemoryModel::Vulkan {
            // The extension only needs to be declared before SPIR-V 1.5.
            if version < SpirvVersion::V1_5 {
                add_ext(&mut builder, sym.spv_khr_vulkan_memory_model);
            }
            add_cap(
                &mut builder,
                &mut enabled_capabilities,
                Capability::VulkanMemoryModel,
            );
        }
        add_cap(&mut builder, &mut enabled_capabilities, Capability::Linkage);

        builder.memory_model(AddressingModel::Logical, memory_model);

        Self {
            source_map: tcx.sess.source_map(),
            dropless_arena: &tcx.arena.dropless,
            builder: RefCell::new(builder),
            const_to_id: Default::default(),
            id_to_const: Default::default(),
            debug_file_cache: Default::default(),
            enabled_capabilities,
        }
    }

    /// Consumes `self`, returning the finished `rspirv` module.
    pub fn finalize(self) -> Module {
        self.builder.into_inner().module()
    }

    /// Disassembles the module built so far into SPIR-V assembly text.
    pub fn dump_module_str(&self) -> String {
        self.builder.borrow().module_ref().disassemble()
    }

    /// Assembles the module built so far and writes the SPIR-V binary to
    /// `path`.
    ///
    /// Panics if the file can't be created or written.
    pub fn dump_module(&self, path: impl AsRef<Path>) {
        let module = self.builder.borrow().module_ref().assemble();
        File::create(path)
            .unwrap()
            .write_all(spirv_tools::binary::from_binary(&module))
            .unwrap();
    }

    /// Returns whether `capability` was declared when this builder was
    /// created.
    pub fn has_capability(&self, capability: Capability) -> bool {
        self.enabled_capabilities.contains(&capability)
    }

    /// Mutably borrows the `rspirv` builder. If the current selection differs
    /// from the function/block described by `cursor`, it is (re)selected first.
    ///
    /// Whenever a (re)selection happens inside a function, this asserts that
    /// the block at the cursor's index has the expected label ID. When the
    /// function changed, it also asserts that the function at the cursor's
    /// index has the expected ID.
    ///
    /// Panics if the builder is already borrowed.
    fn builder(&self, cursor: BuilderCursor) -> RefMut<'_, Builder> {
        let mut builder = self.builder.borrow_mut();

        // Only the indices are needed for selection; the IDs are for checks.
        let [maybe_fn_idx, maybe_block_idx] = [cursor.fn_id_and_idx, cursor.block_id_and_idx]
            .map(|id_and_idx| id_and_idx.map(|(_, idx)| idx));

        let fn_changed = builder.selected_function() != maybe_fn_idx;
        if fn_changed {
            builder.select_function(maybe_fn_idx).unwrap();
        }

        // Blocks are only (re)selected inside a function.
        if let Some((fn_id, fn_idx)) = cursor.fn_id_and_idx
            && (fn_changed || builder.selected_block() != maybe_block_idx)
        {
            builder.select_block(maybe_block_idx).unwrap();

            let function = &builder.module_ref().functions[fn_idx];
            if fn_changed {
                assert_eq!(function.def_id(), Some(fn_id));
            }
            if let Some((block_id, block_idx)) = cursor.block_id_and_idx {
                assert_eq!(function.blocks[block_idx].label_id(), Some(block_id));
            }
        }

        builder
    }

    /// Borrows the builder with no function selected (module-level scope).
    pub fn global_builder(&self) -> RefMut<'_, Builder> {
        self.builder(BuilderCursor::default())
    }

    /// Borrows the builder with `func` selected and no block selected.
    pub fn builder_for_fn(&self, func: SpirvFunctionCursor) -> RefMut<'_, Builder> {
        self.builder(BuilderCursor {
            fn_id_and_idx: Some((func.id, func.index_in_builder)),
            block_id_and_idx: None,
        })
    }

    /// Borrows the builder with `block` (and its parent function) selected.
    pub fn builder_for_block(&self, block: SpirvBlockCursor) -> RefMut<'_, Builder> {
        self.builder(BuilderCursor {
            fn_id_and_idx: Some((block.parent_fn.id, block.parent_fn.index_in_builder)),
            block_id_and_idx: Some((block.id, block.index_in_builder)),
        })
    }

    /// Defines the constant `val` of type `ty`, or reuses an existing
    /// definition, and returns it as a `SpirvValue`.
    ///
    /// Constants are deduplicated by `(ty, val)`. Integer scalars are first
    /// sign-extended or truncated to the width of `ty`. 128-bit scalars
    /// can't currently be expressed, so they end up as zombied `OpUndef`s
    /// (see `const_op`). Illegal constants are still defined, but are
    /// returned as `SpirvValueKind::IllegalConst` so that using them reports
    /// the cause.
    ///
    /// Emits a fatal error for invalid `bool` values or unsupported scalar
    /// types.
    pub(crate) fn def_constant_cx(
        &self,
        ty: Word,
        val: SpirvConst<'_, 'tcx>,
        cx: &CodegenCx<'tcx>,
    ) -> SpirvValue {
        let scalar_ty = match val {
            SpirvConst::Scalar(_) => Some(cx.lookup_type(ty)),
            _ => None,
        };

        // Normalize integer bits to the type's width, so equal values dedup.
        let val = match (val, scalar_ty) {
            (SpirvConst::Scalar(val), Some(SpirvType::Integer(bits, signed))) => {
                let size = Size::from_bits(bits);
                SpirvConst::Scalar(if signed {
                    size.sign_extend(val) as u128
                } else {
                    size.truncate(val)
                })
            }
            _ => val,
        };

        let val_with_type = WithType { ty, val };
        if let Some(entry) = self.const_to_id.borrow().get(&val_with_type) {
            let kind = if entry.legal.is_ok() {
                SpirvValueKind::Def(entry.val)
            } else {
                SpirvValueKind::IllegalConst(entry.val)
            };
            return SpirvValue { kind, ty };
        }
        let val = val_with_type.val;

        // Would emit `OpSpecConstantOp` computing `op` on `lhs` (and
        // `maybe_rhs`). Since `spirt_has_const_op` is hardcoded to `false`,
        // it currently returns a zombied `OpUndef` of type `ty` instead.
        let const_op = |builder: &mut Builder, op, lhs, maybe_rhs: Option<_>| {
            let spirt_has_const_op = false;

            if !spirt_has_const_op {
                let zombie = builder.undef(ty, None);
                cx.zombie_with_span(
                    zombie,
                    DUMMY_SP,
                    &format!("unsupported constant of type `{}`", cx.debug_type(ty)),
                );
                return zombie;
            }

            let id = builder.id();
            builder
                .module_mut()
                .types_global_values
                .push(Instruction::new(
                    Op::SpecConstantOp,
                    Some(ty),
                    Some(id),
                    [
                        Operand::LiteralSpecConstantOpInteger(op),
                        Operand::IdRef(lhs),
                    ]
                    .into_iter()
                    .chain(maybe_rhs.map(Operand::IdRef))
                    .collect(),
                ));
            id
        };

        let mut builder = self.global_builder();
        let id = match val {
            SpirvConst::Scalar(v) => match scalar_ty.unwrap() {
                SpirvType::Integer(..=32, _) | SpirvType::Float(..=32) => {
                    builder.constant_bit32(ty, v as u32)
                }
                SpirvType::Integer(64, _) | SpirvType::Float(64) => {
                    builder.constant_bit64(ty, v as u64)
                }
                SpirvType::Integer(128, false) => {
                    // Release the builder borrow while `cx` defines the helper
                    // constants (which may need to borrow it again).
                    drop(builder);
                    let const_64_u32_id = cx.const_u32(64).def_cx(cx);
                    let [lo_id, hi_id] =
                        [v as u64, (v >> 64) as u64].map(|half| cx.const_u64(half).def_cx(cx));
                    builder = self.global_builder();

                    // Build `zext(lo) | (zext(hi) << 64)`.
                    let mut const_op =
                        |op, lhs, maybe_rhs| const_op(&mut builder, op, lhs, maybe_rhs);
                    let [lo_u128_id, hi_shifted_u128_id] =
                        [(lo_id, None), (hi_id, Some(const_64_u32_id))].map(
                            |(half_u64_id, shift)| {
                                let mut half_u128_id = const_op(Op::UConvert, half_u64_id, None);
                                if let Some(shift_amount_id) = shift {
                                    half_u128_id = const_op(
                                        Op::ShiftLeftLogical,
                                        half_u128_id,
                                        Some(shift_amount_id),
                                    );
                                }
                                half_u128_id
                            },
                        );
                    const_op(Op::BitwiseOr, lo_u128_id, Some(hi_shifted_u128_id))
                }
                SpirvType::Integer(128, true) | SpirvType::Float(128) => {
                    // Same borrow release as above, then bitcast the `u128` bits.
                    drop(builder);
                    let v_u128_id = cx.const_u128(v).def_cx(cx);
                    builder = self.global_builder();
                    const_op(&mut builder, Op::Bitcast, v_u128_id, None)
                }
                SpirvType::Bool => match v {
                    0 => builder.constant_false(ty),
                    1 => builder.constant_true(ty),
                    _ => cx
                        .tcx
                        .dcx()
                        .fatal(format!("invalid constant value for bool: {v}")),
                },
                other => cx.tcx.dcx().fatal(format!(
                    "SpirvConst::Scalar does not support type {}",
                    other.debug(ty, cx)
                )),
            },

            SpirvConst::Null => builder.constant_null(ty),
            SpirvConst::Undef
            | SpirvConst::ZombieUndefForFnAddr
            | SpirvConst::ConstDataFromAlloc(_) => builder.undef(ty, None),

            SpirvConst::Composite(v) => builder.constant_composite(ty, v.iter().copied()),

            SpirvConst::PtrTo { pointee } => {
                builder.variable(ty, None, StorageClass::Private, Some(pointee))
            }
        };

        // Identical arms are kept separate to list each variant's legality.
        #[allow(clippy::match_same_arms)]
        let legal = match val {
            SpirvConst::Scalar(_) => Ok(()),

            SpirvConst::Null => {
                Ok(())
            }
            SpirvConst::Undef => {
                Ok(())
            }
            SpirvConst::ZombieUndefForFnAddr => {
                Ok(())
            }

            // A composite is illegal if any field is illegal, or if any field
            // is a `PtrTo`. Among errors, the first `Indirect` one wins, then
            // the first `Shallow` one. An empty composite is legal.
            SpirvConst::Composite(v) => v
                .iter()
                .map(|field| {
                    let field_entry = &self.id_to_const.borrow()[field];
                    field_entry.legal.and(
                        match field_entry.val {
                            SpirvConst::PtrTo { .. } => Err(IllegalConst::Shallow(
                                LeafIllegalConst::CompositeContainsPtrTo,
                            )),
                            _ => Ok(()),
                        },
                    )
                })
                .reduce(|a, b| {
                    match (a, b) {
                        (Ok(()), Ok(())) => Ok(()),
                        (Err(illegal), Ok(())) | (Ok(()), Err(illegal)) => Err(illegal),

                        (Err(illegal @ IllegalConst::Indirect(_)), Err(_))
                        | (Err(_), Err(illegal @ IllegalConst::Indirect(_)))
                        | (
                            Err(illegal @ IllegalConst::Shallow(_)),
                            Err(IllegalConst::Shallow(_)),
                        ) => Err(illegal),
                    }
                })
                .unwrap_or(Ok(())),

            // A pointer to an illegal constant is indirectly illegal.
            SpirvConst::PtrTo { pointee } => match self.id_to_const.borrow()[&pointee].legal {
                Ok(()) => Ok(()),
                Err(IllegalConst::Shallow(cause) | IllegalConst::Indirect(cause)) => {
                    Err(IllegalConst::Indirect(cause))
                }
            },

            SpirvConst::ConstDataFromAlloc(_) => Err(IllegalConst::Shallow(
                LeafIllegalConst::UntypedConstDataFromAlloc,
            )),
        };

        // Record the new definition in both directions. Neither map may
        // already contain it.
        let val = val.tcx_arena_alloc_slices(cx);
        assert_matches!(
            self.const_to_id
                .borrow_mut()
                .insert(WithType { ty, val }, WithConstLegality { val: id, legal }),
            None
        );
        assert_matches!(
            self.id_to_const
                .borrow_mut()
                .insert(id, WithConstLegality { val, legal }),
            None
        );
        let kind = if legal.is_ok() {
            SpirvValueKind::Def(id)
        } else {
            SpirvValueKind::IllegalConst(id)
        };
        SpirvValue { kind, ty }
    }

    /// Returns the constant defined with ID `id`, if any.
    pub fn lookup_const_by_id(&self, id: Word) -> Option<SpirvConst<'tcx, 'tcx>> {
        Some(self.id_to_const.borrow().get(&id)?.val)
    }

    /// Returns the constant behind `def`, if it is a (legal or illegal)
    /// constant definition.
    pub fn lookup_const(&self, def: SpirvValue) -> Option<SpirvConst<'tcx, 'tcx>> {
        match def.kind {
            SpirvValueKind::Def(id) | SpirvValueKind::IllegalConst(id) => {
                self.lookup_const_by_id(id)
            }
            _ => None,
        }
    }

    /// Returns the raw scalar bits behind `def`, if it is a scalar constant.
    pub fn lookup_const_scalar(&self, def: SpirvValue) -> Option<u128> {
        match self.lookup_const(def)? {
            SpirvConst::Scalar(v) => Some(v),
            _ => None,
        }
    }

    /// Maps `span` to its source file's debuginfo and a `(line, column)` range.
    ///
    /// The span is first replaced by `expansion_cause()` of its outermost
    /// expansion, when there is one. The end of the range falls back to the
    /// start when `hi < lo` or when `hi` lies in a different file.
    pub fn file_line_col_range_for_debuginfo(
        &self,
        span: Span,
    ) -> (DebugFileSpirv<'tcx>, Range<(u32, u32)>) {
        let span = span.ctxt().outer_expn().expansion_cause().unwrap_or(span);
        let (lo, hi) = (span.lo(), span.hi());

        let lo_loc = self.source_map.lookup_char_pos(lo);
        let lo_line_col = (lo_loc.line as u32, lo_loc.col_display as u32);

        let hi_line_col = if lo <= hi {
            let hi_loc = self.source_map.lookup_char_pos(hi);
            if lo_loc.file.start_pos == hi_loc.file.start_pos {
                (hi_loc.line as u32, hi_loc.col_display as u32)
            } else {
                lo_line_col
            }
        } else {
            lo_line_col
        };

        (self.def_debug_file(lo_loc.file), lo_line_col..hi_line_col)
    }

    /// Returns the (cached) debuginfo for source file `sf`.
    ///
    /// The first time a file is seen, this emits an `OpString` with its
    /// (remapped) name. If the file's contents are available, it also emits
    /// an `OpSource` followed by as many `OpSourceContinued` as needed to
    /// embed them.
    fn def_debug_file(&self, sf: Arc<SourceFile>) -> DebugFileSpirv<'tcx> {
        // Splits `s` into a prefix of at most `max_len` bytes and the rest.
        // The split point backs off to the nearest UTF-8 char boundary,
        // because `str::split_at` panics when the index lands inside a
        // multi-byte character.
        fn split_at_char_boundary(s: &str, max_len: usize) -> (&str, &str) {
            let mut mid = s.len().min(max_len);
            while !s.is_char_boundary(mid) {
                mid -= 1;
            }
            s.split_at(mid)
        }

        *self
            .debug_file_cache
            .borrow_mut()
            .entry(DebugFileKey(sf))
            .or_insert_with_key(|DebugFileKey(sf)| {
                let mut builder = self.global_builder();

                let file_name = sf.name.prefer_remapped_unconditionally().to_string_lossy();
                let file_name = self.dropless_arena.alloc_str(&file_name);
                let file_name_op_string_id = builder.string(file_name.to_owned());

                let file_contents = self
                    .source_map
                    .span_to_snippet(Span::with_root_ctxt(sf.start_pos, sf.end_position()))
                    .ok();

                let op_source_and_continued_chunks = file_contents.as_ref().map(|contents| {
                    // An instruction is at most 0xffff words. One word is taken
                    // by the opcode/word count, and the string needs a NUL
                    // terminator, which leaves this many bytes of text.
                    const MAX_OP_SOURCE_CONT_CONTENTS_LEN: usize = (0xffff - 1) * 4 - 1;
                    // `OpSource` has 3 more single-word operands than
                    // `OpSourceContinued`.
                    const MAX_OP_SOURCE_CONTENTS_LEN: usize =
                        MAX_OP_SOURCE_CONT_CONTENTS_LEN - 3 * 4;

                    let (op_source_str, mut all_op_source_continued_str) =
                        split_at_char_boundary(contents, MAX_OP_SOURCE_CONTENTS_LEN);

                    let all_op_source_continued_str_chunks = iter::from_fn(move || {
                        let contents_rest = &mut all_op_source_continued_str;
                        if contents_rest.is_empty() {
                            return None;
                        }
                        // Each chunk is non-empty: the limit far exceeds the
                        // 4-byte maximum UTF-8 sequence length.
                        let (cont_chunk, rest) =
                            split_at_char_boundary(*contents_rest, MAX_OP_SOURCE_CONT_CONTENTS_LEN);
                        *contents_rest = rest;
                        Some(cont_chunk)
                    });

                    (op_source_str, all_op_source_continued_str_chunks)
                });

                if let Some((op_source_str, all_op_source_continued_str_chunks)) =
                    op_source_and_continued_chunks
                {
                    builder.source(
                        SourceLanguage::Unknown,
                        0,
                        Some(file_name_op_string_id),
                        Some(op_source_str),
                    );
                    for cont_chunk in all_op_source_continued_str_chunks {
                        builder.source_continued(cont_chunk);
                    }
                }

                DebugFileSpirv {
                    file_name,
                    file_name_op_string_id,
                }
            })
    }

    /// Adds `initializer` as the initializer of the global `OpVariable`
    /// `global`.
    ///
    /// The variable is moved to the end of the module's types/global values,
    /// presumably so that it follows the initializer's definition.
    ///
    /// Panics if `global` isn't found, isn't an `OpVariable`, or already has
    /// an initializer.
    pub fn set_global_initializer(&self, global: Word, initializer: Word) {
        let mut builder = self.builder.borrow_mut();
        let module = builder.module_mut();
        let index = module
            .types_global_values
            .iter()
            .position(|inst| inst.result_id == Some(global))
            .expect("set_global_initializer global not found");
        let mut inst = module.types_global_values.remove(index);
        assert_eq!(inst.class.opcode, Op::Variable);
        // The only operand so far should be the storage class.
        assert_eq!(
            inst.operands.len(),
            1,
            "global already has initializer defined: {global}"
        );
        inst.operands.push(Operand::IdRef(initializer));
        module.types_global_values.push(inst);
    }
}