use crate::builder;
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use crate::symbols::Symbols;
use crate::target::SpirvTarget;
use crate::target_feature::TargetFeature;
use rspirv::dr::{Block, Builder, Module, Operand};
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, StorageClass, Word};
use rspirv::{binary::Assemble, binary::Disassemble};
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_middle::bug;
use rustc_span::symbol::Symbol;
use rustc_span::{Span, DUMMY_SP};
use std::assert_matches::assert_matches;
use std::cell::{RefCell, RefMut};
use std::{fs::File, io::Write, path::Path};
/// How the result ID inside a [`SpirvValue`] may be used.
///
/// Only `Def` is returned unconditionally by `SpirvValue::def_with_span`; every other kind
/// makes it report an error or zombie an ID first.
///
/// NOTE: `Ord`/`PartialOrd` are derived, so variant declaration order is significant.
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub enum SpirvValueKind {
/// A plain SPIR-V result ID, returned as-is by `def_with_span`.
Def(Word),
/// A registered constant whose legality is an error; `def_with_span` zombies it
/// (even in user code) using the cause's message.
IllegalConst(Word),
/// A type ID used where a value was expected; `def_with_span` emits
/// "Can't use type as a value".
IllegalTypeUsed(Word),
/// A function address. The error message in `def_with_span` states it cannot be used
/// for anything other than calls.
FnAddr {
// NOTE(review): presumably the result ID of the referenced function — not read in this file.
function: Word,
},
/// A pointer-to-pointer cast; using the result is an error in user code and a zombie in
/// system crates (see `def_with_span`).
LogicalPtrCast {
// The pointer before the cast (not read by `def_with_span`).
original_ptr: Word,
// Pointee type of the original pointer, shown in the diagnostic.
original_pointee_ty: Word,
// ID returned by `def_with_span`; zombied when in a system crate.
zombie_target_undef: Word,
},
}
/// A SPIR-V value: a result ID tagged with how it may be used, plus the ID of its type.
///
/// NOTE: `Ord`/`PartialOrd` are derived, so field declaration order is significant.
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct SpirvValue {
/// The value's result ID together with its usage restrictions.
pub kind: SpirvValueKind,
/// Result ID of the value's SPIR-V type (looked up via `cx.lookup_type` / `cx.debug_type`).
pub ty: Word,
}
impl SpirvValue {
    /// Tries to fold a load through a constant pointer.
    ///
    /// When `self` refers to a registered `SpirvConst::PtrTo { pointee }` constant, returns
    /// `pointee` typed with the pointee type of `self.ty`: as `Def` if the pointer constant is
    /// legal, as `IllegalConst` otherwise. Returns `None` for any other value.
    ///
    /// Calls `bug!` if a `PtrTo` constant's type is not a pointer type.
    pub fn const_fold_load(self, cx: &CodegenCx<'_>) -> Option<Self> {
        match self.kind {
            SpirvValueKind::Def(id) | SpirvValueKind::IllegalConst(id) => {
                let &entry = cx.builder.id_to_const.borrow().get(&id)?;
                match entry.val {
                    SpirvConst::PtrTo { pointee } => {
                        let ty = match cx.lookup_type(self.ty) {
                            SpirvType::Pointer { pointee } => pointee,
                            ty => bug!("load called on value that wasn't a pointer: {:?}", ty),
                        };
                        let kind = if entry.legal.is_ok() {
                            SpirvValueKind::Def(pointee)
                        } else {
                            SpirvValueKind::IllegalConst(pointee)
                        };
                        Some(SpirvValue { kind, ty })
                    }
                    _ => None,
                }
            }
            _ => None,
        }
    }

    /// Returns the usable result ID, reporting problems at the builder's current span.
    pub fn def(self, bx: &builder::Builder<'_, '_>) -> Word {
        self.def_with_span(bx, bx.span())
    }

    /// Returns the usable result ID, reporting problems at `DUMMY_SP`.
    pub fn def_cx(self, cx: &CodegenCx<'_>) -> Word {
        self.def_with_span(cx, DUMMY_SP)
    }

    /// Returns the result ID to use for this value, diagnosing illegal uses at `span`.
    ///
    /// - `Def`: returned unchanged.
    /// - `IllegalConst`: zombied (even in user code) with its illegality message.
    /// - `IllegalTypeUsed`: emits an error and returns the type ID.
    /// - `FnAddr`: in system crates, returns the `ZombieUndefForFnAddr` constant registered
    ///   for `self.ty` (panics if none was registered); otherwise emits an error and
    ///   returns `0`.
    /// - `LogicalPtrCast`: zombies (system crates) or errors (user code), then returns
    ///   `zombie_target_undef`.
    pub fn def_with_span(self, cx: &CodegenCx<'_>, span: Span) -> Word {
        match self.kind {
            SpirvValueKind::Def(id) => id,
            SpirvValueKind::IllegalConst(id) => {
                // Copy the legality out so the `id_to_const` borrow is released before
                // calling back into `cx` (same approach as `const_fold_load`).
                let legal = cx.builder.id_to_const.borrow()[&id].legal;
                // An `IllegalConst` value is only created for `Err` legality.
                let cause = match legal.unwrap_err() {
                    IllegalConst::Shallow(cause) | IllegalConst::Indirect(cause) => cause,
                };
                // NOTE(review): a `Shallow` `CompositeContainsPtrTo` composite could perhaps be
                // materialized at runtime instead of zombied; not implemented, so every
                // illegal constant is zombied here.
                cx.zombie_even_in_user_code(id, span, cause.message());
                id
            }
            SpirvValueKind::IllegalTypeUsed(id) => {
                cx.tcx
                    .sess
                    .struct_span_err(span, "Can't use type as a value")
                    .note(&format!("Type: *{}", cx.debug_type(id)))
                    .emit();
                id
            }
            SpirvValueKind::FnAddr { .. } => {
                if cx.is_system_crate(span) {
                    cx.builder
                        .const_to_id
                        .borrow()
                        .get(&WithType {
                            ty: self.ty,
                            val: SpirvConst::ZombieUndefForFnAddr,
                        })
                        .expect("FnAddr didn't go through proper undef registration")
                        .val
                } else {
                    cx.tcx.sess.span_err(
                        span,
                        "Cannot use this function pointer for anything other than calls",
                    );
                    // Placeholder ID; compilation already failed with the error above.
                    0
                }
            }
            SpirvValueKind::LogicalPtrCast {
                original_ptr: _,
                original_pointee_ty,
                zombie_target_undef,
            } => {
                if cx.is_system_crate(span) {
                    cx.zombie_with_span(
                        zombie_target_undef,
                        span,
                        &format!(
                            "Cannot cast between pointer types. From: {}. To: {}.",
                            cx.debug_type(original_pointee_ty),
                            cx.debug_type(self.ty)
                        ),
                    );
                } else {
                    cx.tcx
                        .sess
                        .struct_span_err(span, "Cannot cast between pointer types")
                        .note(&format!("from: *{}", cx.debug_type(original_pointee_ty)))
                        .note(&format!("to: {}", cx.debug_type(self.ty)))
                        .emit();
                }
                zombie_target_undef
            }
        }
    }
}
/// Extension for turning a raw SPIR-V result ID into a [`SpirvValue`].
pub trait SpirvValueExt {
/// Wraps `self` as a value of SPIR-V type `ty`.
fn with_type(self, ty: Word) -> SpirvValue;
}
impl SpirvValueExt for Word {
    /// Tags this raw result ID as an unrestricted (`Def`) value of type `ty`.
    fn with_type(self, ty: Word) -> SpirvValue {
        let kind = SpirvValueKind::Def(self);
        SpirvValue { kind, ty }
    }
}
/// A SPIR-V constant as registered (and deduplicated) by `BuilderSpirv::def_constant_cx`.
///
/// Floats are stored as bit patterns (`f32::from_bits`/`f64::from_bits` convert them back
/// when emitting); floats themselves cannot derive `Eq`/`Ord`/`Hash`, which this type needs
/// as a cache key.
///
/// NOTE: `Ord`/`PartialOrd` are derived, so variant declaration order is significant.
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub enum SpirvConst<'tcx> {
/// Emitted with `constant_u32`.
U32(u32),
/// Emitted with `constant_u64`.
U64(u64),
/// `f32` bit pattern, emitted with `constant_f32`.
F32(u32),
/// `f64` bit pattern, emitted with `constant_f64`.
F64(u64),
/// Emitted with `constant_true`/`constant_false`.
Bool(bool),
/// Emitted with `constant_null`.
Null,
/// Emitted with `undef`.
Undef,
/// Also emitted with `undef`, but kept distinct from `Undef` so that
/// `SpirvValueKind::FnAddr` values can look up their own placeholder.
ZombieUndefForFnAddr,
/// Emitted with `constant_composite` from these constituent constant IDs.
Composite(&'tcx [Word]),
/// Emitted as a `Private` variable initialized with the constant `pointee`.
PtrTo {
pointee: Word,
},
}
impl SpirvConst<'_> {
    /// Rebuilds this constant with lifetime `'tcx`, copying any borrowed slice (only
    /// `Composite` carries one) into `cx.tcx`'s dropless arena.
    fn tcx_arena_alloc_slices<'tcx>(self, cx: &CodegenCx<'tcx>) -> SpirvConst<'tcx> {
        // Copies `slice` into the arena; an empty slice is answered with `&[]` without
        // touching the arena at all.
        fn arena_alloc_slice<'tcx, T: Copy>(cx: &CodegenCx<'tcx>, slice: &[T]) -> &'tcx [T] {
            if !slice.is_empty() {
                return cx.tcx.arena.dropless.alloc_slice(slice);
            }
            &[]
        }

        match self {
            SpirvConst::Composite(fields) => SpirvConst::Composite(arena_alloc_slice(cx, fields)),
            SpirvConst::PtrTo { pointee } => SpirvConst::PtrTo { pointee },
            SpirvConst::U32(bits) => SpirvConst::U32(bits),
            SpirvConst::U64(bits) => SpirvConst::U64(bits),
            SpirvConst::F32(bits) => SpirvConst::F32(bits),
            SpirvConst::F64(bits) => SpirvConst::F64(bits),
            SpirvConst::Bool(flag) => SpirvConst::Bool(flag),
            SpirvConst::Null => SpirvConst::Null,
            SpirvConst::Undef => SpirvConst::Undef,
            SpirvConst::ZombieUndefForFnAddr => SpirvConst::ZombieUndefForFnAddr,
        }
    }
}
/// A value paired with the SPIR-V type ID it is defined at; used as the key of
/// `BuilderSpirv::const_to_id`.
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
struct WithType<V> {
ty: Word,
val: V,
}
/// The root reason a constant cannot be used as a legal SPIR-V constant.
#[derive(Copy, Clone, Debug)]
enum LeafIllegalConst {
/// A `Composite` constant has a `SpirvConst::PtrTo` constituent.
CompositeContainsPtrTo,
}
impl LeafIllegalConst {
fn message(&self) -> &'static str {
match *self {
Self::CompositeContainsPtrTo => {
"constant arrays/structs cannot contain pointers to other constants"
}
}
}
}
/// Why a constant is illegal, distinguishing direct problems from ones reached through a
/// pointer.
#[derive(Copy, Clone, Debug)]
enum IllegalConst {
/// The constant itself is illegal (e.g. a composite with a `PtrTo` field), or a composite
/// inherited this cause unchanged from one of its fields.
Shallow(LeafIllegalConst),
/// A `PtrTo` constant whose pointee is illegal. When a composite's field causes are
/// combined, `Indirect` takes precedence over `Shallow`.
Indirect(LeafIllegalConst),
}
/// A value paired with the legality of the constant it belongs to.
#[derive(Copy, Clone, Debug)]
struct WithConstLegality<V> {
val: V,
// `Ok(())` when the constant may be used directly as a `SpirvValueKind::Def`.
legal: Result<(), IllegalConst>,
}
/// A selection within the module being built: an optional function index and, inside it,
/// an optional block index. The `Default` cursor selects no function (module level).
#[derive(Debug, Default, Copy, Clone)]
#[must_use = "BuilderCursor should usually be assigned to the Builder.cursor field"]
pub struct BuilderCursor {
/// Index into the module's function list, or `None` for module level.
pub function: Option<usize>,
/// Index into the selected function's block list; only applied when `function` is `Some`.
pub block: Option<usize>,
}
/// SPIR-V module builder shared through `&self` (via `RefCell`), with deduplication of
/// constants and strings.
pub struct BuilderSpirv<'tcx> {
// The underlying rspirv builder; access goes through `BuilderSpirv::builder(cursor)`.
builder: RefCell<Builder>,
// Cache from typed constant to its result ID and legality.
const_to_id: RefCell<FxHashMap<WithType<SpirvConst<'tcx>>, WithConstLegality<Word>>>,
// Reverse map from a constant's result ID to the constant and its legality.
id_to_const: RefCell<FxHashMap<Word, WithConstLegality<SpirvConst<'tcx>>>>,
// IDs returned by `builder.string(..)`, keyed by string contents (see `def_string`).
string_cache: RefCell<FxHashMap<String, Word>>,
// Capabilities emitted in `new`.
enabled_capabilities: FxHashSet<Capability>,
// Extensions emitted in `new`.
enabled_extensions: FxHashSet<Symbol>,
}
impl<'tcx> BuilderSpirv<'tcx> {
pub fn new(sym: &Symbols, target: &SpirvTarget, features: &[TargetFeature]) -> Self {
let version = target.spirv_version();
let memory_model = target.memory_model();
let mut builder = Builder::new();
builder.set_version(version.0, version.1);
builder.module_mut().header.as_mut().unwrap().generator = 0x001B_0000;
let mut enabled_capabilities = FxHashSet::default();
let mut enabled_extensions = FxHashSet::default();
fn add_cap(
builder: &mut Builder,
enabled_capabilities: &mut FxHashSet<Capability>,
cap: Capability,
) {
builder.capability(cap);
enabled_capabilities.insert(cap);
}
fn add_ext(builder: &mut Builder, enabled_extensions: &mut FxHashSet<Symbol>, ext: Symbol) {
builder.extension(ext.as_str());
enabled_extensions.insert(ext);
}
for feature in features {
match *feature {
TargetFeature::Capability(cap) => {
add_cap(&mut builder, &mut enabled_capabilities, cap);
}
TargetFeature::Extension(ext) => {
add_ext(&mut builder, &mut enabled_extensions, ext);
}
}
}
add_cap(&mut builder, &mut enabled_capabilities, Capability::Shader);
if memory_model == MemoryModel::Vulkan {
if version < (1, 5) {
add_ext(
&mut builder,
&mut enabled_extensions,
sym.spv_khr_vulkan_memory_model,
);
}
add_cap(
&mut builder,
&mut enabled_capabilities,
Capability::VulkanMemoryModel,
);
}
add_cap(&mut builder, &mut enabled_capabilities, Capability::Linkage);
builder.memory_model(AddressingModel::Logical, memory_model);
Self {
builder: RefCell::new(builder),
const_to_id: Default::default(),
id_to_const: Default::default(),
string_cache: Default::default(),
enabled_capabilities,
enabled_extensions,
}
}
pub fn finalize(self) -> Module {
self.builder.into_inner().module()
}
pub fn dump_module_str(&self) -> String {
self.builder.borrow().module_ref().disassemble()
}
pub fn dump_module(&self, path: impl AsRef<Path>) {
let module = self.builder.borrow().module_ref().assemble();
File::create(path)
.unwrap()
.write_all(spirv_tools::binary::from_binary(&module))
.unwrap();
}
pub fn builder(&self, cursor: BuilderCursor) -> RefMut<'_, Builder> {
let mut builder = self.builder.borrow_mut();
if builder.selected_function() != cursor.function {
builder.select_function(cursor.function).unwrap();
}
if cursor.function.is_some() && builder.selected_block() != cursor.block {
builder.select_block(cursor.block).unwrap();
}
builder
}
pub fn has_capability(&self, capability: Capability) -> bool {
self.enabled_capabilities.contains(&capability)
}
pub fn has_extension(&self, extension: Symbol) -> bool {
self.enabled_extensions.contains(&extension)
}
pub fn select_function_by_id(&self, id: Word) -> BuilderCursor {
let mut builder = self.builder.borrow_mut();
for (index, func) in builder.module_ref().functions.iter().enumerate() {
if func.def.as_ref().and_then(|i| i.result_id) == Some(id) {
builder.select_function(Some(index)).unwrap();
return BuilderCursor {
function: Some(index),
block: None,
};
}
}
bug!("Function not found: {}", id);
}
pub(crate) fn def_constant_cx(
&self,
ty: Word,
val: SpirvConst<'_>,
cx: &CodegenCx<'tcx>,
) -> SpirvValue {
let val_with_type = WithType { ty, val };
let mut builder = self.builder(BuilderCursor::default());
if let Some(entry) = self.const_to_id.borrow().get(&val_with_type) {
let kind = if entry.legal.is_ok() {
SpirvValueKind::Def(entry.val)
} else {
SpirvValueKind::IllegalConst(entry.val)
};
return SpirvValue { kind, ty };
}
let val = val_with_type.val;
let id = match val {
SpirvConst::U32(v) => builder.constant_u32(ty, v),
SpirvConst::U64(v) => builder.constant_u64(ty, v),
SpirvConst::F32(v) => builder.constant_f32(ty, f32::from_bits(v)),
SpirvConst::F64(v) => builder.constant_f64(ty, f64::from_bits(v)),
SpirvConst::Bool(v) => {
if v {
builder.constant_true(ty)
} else {
builder.constant_false(ty)
}
}
SpirvConst::Null => builder.constant_null(ty),
SpirvConst::Undef | SpirvConst::ZombieUndefForFnAddr => builder.undef(ty, None),
SpirvConst::Composite(v) => builder.constant_composite(ty, v.iter().copied()),
SpirvConst::PtrTo { pointee } => {
builder.variable(ty, None, StorageClass::Private, Some(pointee))
}
};
#[allow(clippy::match_same_arms)]
let legal = match val {
SpirvConst::U32(_)
| SpirvConst::U64(_)
| SpirvConst::F32(_)
| SpirvConst::F64(_)
| SpirvConst::Bool(_) => Ok(()),
SpirvConst::Null => {
Ok(())
}
SpirvConst::Undef => {
Ok(())
}
SpirvConst::ZombieUndefForFnAddr => {
Ok(())
}
SpirvConst::Composite(v) => v.iter().fold(Ok(()), |composite_legal, field| {
let field_entry = &self.id_to_const.borrow()[field];
let field_legal_in_composite = field_entry.legal.and(
match field_entry.val {
SpirvConst::PtrTo { .. } => Err(IllegalConst::Shallow(
LeafIllegalConst::CompositeContainsPtrTo,
)),
_ => Ok(()),
},
);
match (composite_legal, field_legal_in_composite) {
(Ok(()), Ok(())) => Ok(()),
(Err(illegal), Ok(())) | (Ok(()), Err(illegal)) => Err(illegal),
(Err(illegal @ IllegalConst::Indirect(_)), Err(_))
| (Err(_), Err(illegal @ IllegalConst::Indirect(_)))
| (Err(illegal @ IllegalConst::Shallow(_)), Err(IllegalConst::Shallow(_))) => {
Err(illegal)
}
}
}),
SpirvConst::PtrTo { pointee } => match self.id_to_const.borrow()[&pointee].legal {
Ok(()) => Ok(()),
Err(IllegalConst::Shallow(cause) | IllegalConst::Indirect(cause)) => {
Err(IllegalConst::Indirect(cause))
}
},
};
let val = val.tcx_arena_alloc_slices(cx);
assert_matches!(
self.const_to_id
.borrow_mut()
.insert(WithType { ty, val }, WithConstLegality { val: id, legal }),
None
);
assert_matches!(
self.id_to_const
.borrow_mut()
.insert(id, WithConstLegality { val, legal }),
None
);
let kind = if legal.is_ok() {
SpirvValueKind::Def(id)
} else {
SpirvValueKind::IllegalConst(id)
};
SpirvValue { kind, ty }
}
pub fn lookup_const_by_id(&self, id: Word) -> Option<SpirvConst<'tcx>> {
Some(self.id_to_const.borrow().get(&id)?.val)
}
pub fn lookup_const(&self, def: SpirvValue) -> Option<SpirvConst<'tcx>> {
match def.kind {
SpirvValueKind::Def(id) | SpirvValueKind::IllegalConst(id) => {
self.lookup_const_by_id(id)
}
_ => None,
}
}
pub fn lookup_const_u64(&self, def: SpirvValue) -> Option<u64> {
match self.lookup_const(def)? {
SpirvConst::U32(v) => Some(v as u64),
SpirvConst::U64(v) => Some(v),
_ => None,
}
}
pub fn def_string(&self, s: String) -> Word {
use std::collections::hash_map::Entry;
match self.string_cache.borrow_mut().entry(s) {
Entry::Occupied(entry) => *entry.get(),
Entry::Vacant(entry) => {
let key = entry.key().clone();
*entry.insert(self.builder(Default::default()).string(key))
}
}
}
pub fn set_global_initializer(&self, global: Word, initializer: Word) {
let mut builder = self.builder.borrow_mut();
let module = builder.module_mut();
let index = module
.types_global_values
.iter()
.enumerate()
.find_map(|(index, inst)| {
if inst.result_id == Some(global) {
Some(index)
} else {
None
}
})
.expect("set_global_initializer global not found");
let mut inst = module.types_global_values.remove(index);
assert_eq!(inst.class.opcode, Op::Variable);
assert_eq!(
inst.operands.len(),
1,
"global already has initializer defined: {global}"
);
inst.operands.push(Operand::IdRef(initializer));
module.types_global_values.push(inst);
}
pub fn select_block_by_id(&self, id: Word) -> BuilderCursor {
fn block_matches(block: &Block, id: Word) -> bool {
block.label.as_ref().and_then(|b| b.result_id) == Some(id)
}
let mut builder = self.builder.borrow_mut();
let module = builder.module_ref();
if let Some(selected_function) = builder.selected_function() {
if let Some(selected_block) = builder.selected_block() {
let block = &module.functions[selected_function].blocks[selected_block];
if block_matches(block, id) {
return BuilderCursor {
function: Some(selected_function),
block: Some(selected_block),
};
}
}
for (index, block) in module.functions[selected_function]
.blocks
.iter()
.enumerate()
{
if block_matches(block, id) {
builder.select_block(Some(index)).unwrap();
return BuilderCursor {
function: Some(selected_function),
block: Some(index),
};
}
}
}
for (function_index, function) in module.functions.iter().enumerate() {
for (block_index, block) in function.blocks.iter().enumerate() {
if block_matches(block, id) {
builder.select_function(Some(function_index)).unwrap();
builder.select_block(Some(block_index)).unwrap();
return BuilderCursor {
function: Some(function_index),
block: Some(block_index),
};
}
}
}
bug!("Block not found: {}", id);
}
}