use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
use super::{CodegenCx, FmtArgsCtor};
use crate::abi::ConvSpirvType;
use crate::attr::AggregatedSpirvAttributes;
use crate::builder_spirv::{SpirvConst, SpirvFunctionCursor, SpirvValue, SpirvValueExt};
use crate::custom_decorations::{CustomDecoration, SrcLocDecoration};
use crate::spirv_type::SpirvType;
use itertools::Itertools;
use rspirv::spirv::{FunctionControl, LinkageType, StorageClass, Word};
use rustc_codegen_ssa::traits::{PreDefineCodegenMethods, StaticCodegenMethods};
use rustc_hir::attrs::{InlineAttr, Linkage};
use rustc_middle::bug;
use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs};
use rustc_middle::mir::interpret::ConstAllocation;
use rustc_middle::mono::{MonoItem, Visibility};
use rustc_middle::ty::layout::{FnAbiOf, LayoutOf};
use rustc_middle::ty::{self, Instance, TypeVisitableExt, TypingEnv};
use rustc_span::Span;
use rustc_span::def_id::DefId;
/// Translates rustc's codegen attributes for a function into the matching
/// SPIR-V `FunctionControl` mask.
///
/// Inline requests (`Hint`, `Always`, `Force`) become `INLINE`, `Never`
/// becomes `DONT_INLINE`, and the `FFI_PURE` / `FFI_CONST` flags add the
/// `PURE` / `CONST` bits respectively.
fn attrs_to_spirv(attrs: &CodegenFnAttrs) -> FunctionControl {
    // Seed the mask from the inlining request; the remaining bits are layered on top.
    let mut control = match attrs.inline {
        InlineAttr::None => FunctionControl::NONE,
        InlineAttr::Never => FunctionControl::DONT_INLINE,
        InlineAttr::Hint | InlineAttr::Always | InlineAttr::Force { .. } => {
            FunctionControl::INLINE
        }
    };

    // Each rustc flag on the left switches on the SPIR-V bit on the right.
    let flag_to_bit = [
        (CodegenFnAttrFlags::FFI_PURE, FunctionControl::PURE),
        (CodegenFnAttrFlags::FFI_CONST, FunctionControl::CONST),
    ];
    for (flag, bit) in flag_to_bit {
        if attrs.flags.contains(flag) {
            control.insert(bit);
        }
    }

    control
}
impl<'tcx> CodegenCx<'tcx> {
/// Defines a `SpirvConst::PtrTo` constant pointing at `cv`.
///
/// The resulting value has the pointer type produced by `type_ptr_to` for
/// `cv`'s own type.
pub(crate) fn static_addr_of_constant(&self, cv: SpirvValue) -> SpirvValue {
    // The pointer type is requested before the pointee id, as in the
    // argument order of the `def_constant` call.
    let ptr_ty = self.type_ptr_to(cv.ty);
    let pointee = cv.def_cx(self);
    self.def_constant(ptr_ty, SpirvConst::PtrTo { pointee })
}
/// Returns the function registered for `instance`, declaring it first if it
/// is not yet present in `fn_instances`.
///
/// A function that has to be declared here is given `LinkageType::Import`
/// linkage and is cached so later lookups return the same cursor.
///
/// # Panics
///
/// Panics if `instance.args` still contains inference variables or escaping
/// bound variables.
pub fn get_fn_ext(&self, instance: Instance<'tcx>) -> SpirvFunctionCursor {
    assert!(!instance.args.has_infer());
    assert!(!instance.args.has_escaping_bound_vars());

    // Copy the cached cursor out so the `RefCell` borrow ends before anything
    // below needs to borrow `fn_instances` mutably.
    let cached = self.fn_instances.borrow().get(&instance).copied();
    match cached {
        Some(existing) => existing,
        None => {
            let imported = self.declare_fn_ext(instance, Some(LinkageType::Import));
            self.fn_instances.borrow_mut().insert(instance, imported);
            imported
        }
    }
}
/// Declares a SPIR-V function for `instance` and records metadata about it.
///
/// Emits a function header through `emit_global()` (with parameters only
/// when `linkage` is not `Import`), attaches a source-location annotation
/// and a demangled debug name, applies `linkage` under the mangled symbol
/// name when one is given, and then registers the function in the lookup
/// tables kept on `self`: entry points, buffer load/store intrinsics, libm
/// and num_traits intrinsics, `From` impls, panic entry points and the
/// `core::fmt` constructor tables.
///
/// Returns the cursor of the newly declared function. The callers in this
/// file (`get_fn_ext`, `predefine_fn`) insert the result into `fn_instances`
/// themselves; this function only does so early, for entry points.
///
/// # Panics
///
/// Panics (via `bug!`) if the ABI-derived type is not a function type, and
/// via `unwrap` if the emitter calls fail or if the generic arguments parsed
/// out of a recognized `core::fmt::Arguments` constructor name do not parse.
fn declare_fn_ext(
&self,
instance: Instance<'tcx>,
linkage: Option<LinkageType>,
) -> SpirvFunctionCursor {
let def_id = instance.def_id();
let control = attrs_to_spirv(self.tcx.codegen_fn_attrs(def_id));
let fn_abi = self.fn_abi_of_instance(instance, ty::List::empty());
let span = self.tcx.def_span(def_id);
let function_type = fn_abi.spirv_type(span, self);
// The ABI-derived type must be a function type; anything else is treated
// as a compiler bug.
let (return_type, argument_types) = match self.lookup_type(function_type) {
SpirvType::Function {
return_type,
arguments,
} => (return_type, arguments),
other => bug!("fn_abi type {}", other.debug(function_type, self)),
};
// Emit the function header. `emit` only lives for this block.
// NOTE(review): presumably it must be released before the later
// `self.emit_global()` calls below — confirm.
let declared = {
let mut emit = self.emit_global();
let id = emit
.begin_function(return_type, None, control, function_type)
.unwrap();
let index_in_builder = emit.selected_function().unwrap();
// Parameters are emitted only for functions that are not imports.
// NOTE(review): confirm that omitting them on imports is valid SPIR-V.
if linkage != Some(LinkageType::Import) {
for &ty in argument_types {
emit.function_parameter(ty).unwrap();
}
}
emit.end_function().unwrap();
SpirvFunctionCursor {
ty: function_type,
id,
index_in_builder,
}
};
let fn_id = declared.id;
// Record the source location as a custom decoration on the function id,
// preferring the identifier's span and falling back to the definition span.
let src_loc_inst = SrcLocDecoration::from_rustc_span(
self.tcx.def_ident_span(def_id).unwrap_or(span),
&self.builder,
)
.map(|src_loc| src_loc.encode_to_inst(fn_id));
self.emit_global()
.module_mut()
.annotations
.extend(src_loc_inst);
// The demangled symbol is used as the debug name and for the name-based
// matching further down; the mangled symbol is what linkage is set under.
let symbol_name = self.tcx.symbol_name(instance).name;
let demangled_symbol_name = format!("{:#}", rustc_demangle::demangle(symbol_name));
self.emit_global().name(fn_id, &demangled_symbol_name);
if let Some(linkage) = linkage {
self.set_linkage(fn_id, symbol_name.to_owned(), linkage);
}
// Parse the attributes found on this definition under the
// `rust_gpu` / `spirv_attr_with_version` symbol path.
let attrs = AggregatedSpirvAttributes::parse(
self,
self.tcx.get_attrs_by_path(
def_id,
&[self.sym.rust_gpu, self.sym.spirv_attr_with_version],
),
);
if let Some(entry) = attrs.entry.map(|attr| attr.value) {
// Register the instance before building the entry stub.
// NOTE(review): presumably `entry_stub` looks this function up again
// through `fn_instances` — confirm.
self.fn_instances.borrow_mut().insert(instance, declared);
// The entry point name defaults to the instance's string form when
// the attribute does not supply one.
let entry_name = entry
.name
.as_ref()
.map_or_else(|| instance.to_string(), ToString::to_string);
self.entry_stub(instance, fn_abi, entry_name, entry);
}
if attrs.buffer_load_intrinsic.is_some() {
self.buffer_load_intrinsics.borrow_mut().insert(def_id);
}
if attrs.buffer_store_intrinsic.is_some() {
self.buffer_store_intrinsics.borrow_mut().insert(def_id);
}
// Non-local functions from the crate named by `self.sym.libm` must map to
// a known intrinsic; an unknown item is reported as an error.
if self.tcx.crate_name(def_id.krate) == self.sym.libm && !def_id.is_local() {
let item_name = self.tcx.item_name(def_id);
if let Some(&intrinsic) = self.sym.libm_intrinsics.get(&item_name) {
self.libm_intrinsics.borrow_mut().insert(def_id, intrinsic);
} else {
let message = format!("missing libm intrinsic {symbol_name}, which is {instance}");
self.tcx.dcx().err(message);
}
}
// Same lookup for the crate named by `self.sym.num_traits`, except that
// items without a known intrinsic are silently left alone.
if self.tcx.crate_name(def_id.krate) == self.sym.num_traits && !def_id.is_local() {
let item_name = self.tcx.item_name(def_id);
if let Some(&intrinsic) = self.sym.num_traits_intrinsics.get(&item_name) {
self.num_traits_intrinsics
.borrow_mut()
.insert(def_id, intrinsic);
}
}
// Record associated functions of `From` trait impls, identified by the
// trait's printed path, as `def_id -> (source_ty, target_ty)`.
if let Some(impl_def_id) = self.tcx.impl_of_assoc(def_id)
&& let Some(trait_ref) = self.tcx.impl_opt_trait_ref(impl_def_id)
{
let trait_def_id = trait_ref.skip_binder().def_id;
let trait_path = self.tcx.def_path_str(trait_def_id);
if matches!(
trait_path.as_str(),
"core::convert::From" | "std::convert::From"
) {
let trait_args = trait_ref.skip_binder().args;
// NOTE(review): assumes the trait ref's first type argument is `Self`
// (the conversion target) and the second is the source — confirm.
if let (Some(target_ty), Some(source_ty)) =
(trait_args.types().nth(0), trait_args.types().nth(1))
{
self.from_trait_impls
.borrow_mut()
.insert(def_id, (source_ty, target_ty));
}
}
}
// Panic entry points known through lang items.
if [
self.tcx.lang_items().panic_fn(),
self.tcx.lang_items().panic_fmt(),
self.tcx.lang_items().panic_nounwind(),
]
.contains(&Some(def_id))
{
self.panic_entry_points.borrow_mut().insert(def_id);
}
// Panic entry points recognized by demangled name instead. The explicit
// `"panic_explicit"` comparison is already covered by `starts_with("panic_")`.
if let Some(name) = demangled_symbol_name.strip_prefix("core::panicking::")
&& (name == "panic_explicit" || name.starts_with("panic_"))
{
self.panic_entry_points.borrow_mut().insert(def_id);
}
// `core::fmt::Arguments` constructors are recognized by demangled name and
// their generic arguments are parsed out of the text between `::<` and `>`.
// `Arguments::new::<A, B>`: recorded only when exactly two comma-separated
// generic arguments are present.
if let Some(generics) = demangled_symbol_name
.strip_prefix("<core::fmt::Arguments>::new::<")
.and_then(|s| s.strip_suffix(">"))
{
let mut generics = generics.split(',').map(str::trim);
if let (Some(template_len), Some(rt_args_count), None) =
(generics.next(), generics.next(), generics.next())
{
self.fmt_args_new_fn_ids.borrow_mut().insert(
fn_id,
FmtArgsCtor::NewTemplate {
template_len: template_len.parse().unwrap(),
rt_args_count: rt_args_count.parse().unwrap(),
},
);
}
}
// `Arguments::new_const::<N>`: a single generic argument.
if let Some(pieces_len) = demangled_symbol_name
.strip_prefix("<core::fmt::Arguments>::new_const::<")
.and_then(|s| s.strip_suffix(">"))
{
self.fmt_args_new_fn_ids.borrow_mut().insert(
fn_id,
FmtArgsCtor::NewConst {
pieces_len: pieces_len.parse().unwrap(),
},
);
}
// `Arguments::new_v1::<P, A>`: two generic arguments separated by `", "`;
// the `unwrap` panics if that separator is missing.
if let Some(generics) = demangled_symbol_name
.strip_prefix("<core::fmt::Arguments>::new_v1::<")
.and_then(|s| s.strip_suffix(">"))
{
let (pieces_len, rt_args_len) = generics.split_once(", ").unwrap();
self.fmt_args_new_fn_ids.borrow_mut().insert(
fn_id,
FmtArgsCtor::NewV1 {
pieces_len: pieces_len.parse().unwrap(),
rt_args_count: rt_args_len.parse().unwrap(),
},
);
}
// Constructors without generic arguments are matched by their exact name.
if demangled_symbol_name == "<core::fmt::Arguments>::new_v1_formatted" {
self.fmt_args_new_fn_ids
.borrow_mut()
.insert(fn_id, FmtArgsCtor::NewV1FormattedDynamic);
}
if demangled_symbol_name == "<core::fmt::Arguments>::from_str"
|| demangled_symbol_name == "<core::fmt::Arguments>::from_str_nonconst"
{
self.fmt_args_new_fn_ids
.borrow_mut()
.insert(fn_id, FmtArgsCtor::FromStr);
}
// `core::fmt::rt::Argument::new_<kind>::<T>`: map the method suffix to a
// format-spec character and remember it together with `T`.
if let Some(suffix) = demangled_symbol_name.strip_prefix("<core::fmt::rt::Argument>::new_")
{
let spec = suffix.split_once("::<").and_then(|(method_suffix, _)| {
Some(match method_suffix {
"display" => ' ',
"debug" => '?',
"octal" => 'o',
"lower_hex" => 'x',
"upper_hex" => 'X',
"pointer" => 'p',
"binary" => 'b',
"lower_exp" => 'e',
"upper_exp" => 'E',
_ => return None,
})
});
// `collect_tuple` into a 1-tuple only succeeds when the instance has
// exactly one type argument.
if let Some(spec) = spec
&& let Some((ty,)) = instance.args.types().collect_tuple()
{
self.fmt_rt_arg_new_fn_ids_to_ty_and_spec
.borrow_mut()
.insert(fn_id, (ty, spec));
}
}
// `Argument::from_usize` is recorded with the `usize` type and the `'?'` spec.
if demangled_symbol_name == "<core::fmt::rt::Argument>::from_usize" {
self.fmt_rt_arg_new_fn_ids_to_ty_and_spec
.borrow_mut()
.insert(fn_id, (self.tcx.types.usize, '?'));
}
declared
}
/// Returns the global registered for the static `def_id`, declaring it as an
/// imported global on a cache miss.
///
/// On a miss the static's monomorphic type is lowered to a SPIR-V type, a
/// global is declared for it, the global is cached in `statics`, and it is
/// given `LinkageType::Import` linkage under the static's symbol name.
///
/// # Panics
///
/// Panics if the static is an item of the current codegen unit but is not in
/// the cache.
pub fn get_static(&self, def_id: DefId) -> SpirvValue {
    // Copy the cached value out so the `RefCell` borrow ends right away.
    let cached = self.statics.borrow().get(&def_id).copied();
    if let Some(known) = cached {
        return known;
    }

    let is_local_to_cgu = self
        .codegen_unit
        .items()
        .contains_key(&MonoItem::Static(def_id));
    assert!(
        !is_local_to_cgu,
        "get_static() should always hit the cache for statics defined in the same CGU, but did not for `{def_id:?}`"
    );

    let instance = Instance::mono(self.tcx, def_id);
    let static_ty = instance.ty(self.tcx, TypingEnv::fully_monomorphized());
    let symbol = self.tcx.symbol_name(instance).name;
    let span = self.tcx.def_span(def_id);
    let spirv_ty = self.layout_of(static_ty).spirv_type(span, self);

    let global = self.declare_global(span, spirv_ty);
    self.statics.borrow_mut().insert(def_id, global);
    self.set_linkage(global.def_cx(self), symbol.to_string(), LinkageType::Import);
    global
}
/// Emits a `StorageClass::Private` global variable whose pointee type is
/// `ty` and returns it typed as a pointer to `ty`.
///
/// The new variable is immediately passed to `zombie_with_span` with the
/// message "globals are not supported yet".
fn declare_global(&self, span: Span, ty: Word) -> SpirvValue {
    let ptr_ty = SpirvType::Pointer { pointee: ty }.def(span, self);

    // No initializer is supplied here; `codegen_static` sets one later for
    // statics defined in this unit.
    let variable_id = self
        .emit_global()
        .variable(ptr_ty, None, StorageClass::Private, None);
    let global = variable_id.with_type(ptr_ty);

    self.zombie_with_span(global.def_cx(self), span, "globals are not supported yet");
    global
}
}
impl<'tcx> PreDefineCodegenMethods<'tcx> for CodegenCx<'tcx> {
/// Declares the global for the static `def_id` ahead of codegen and caches
/// it in `statics`.
///
/// `Linkage::External` exports the global under `symbol_name`,
/// `Linkage::Internal` applies no linkage, and any other linkage kind is
/// reported as an error and then treated as if no linkage were requested.
fn predefine_static(
    &mut self,
    def_id: DefId,
    linkage: Linkage,
    _visibility: Visibility,
    symbol_name: &str,
) {
    let instance = Instance::mono(self.tcx, def_id);
    let static_ty = instance.ty(self.tcx, TypingEnv::fully_monomorphized());
    let span = self.tcx.def_span(def_id);
    let spirv_ty = self.layout_of(static_ty).spirv_type(span, self);

    let export_linkage = match linkage {
        Linkage::Internal => None,
        Linkage::External => Some(LinkageType::Export),
        other => {
            self.tcx.dcx().err(format!(
                "TODO: Linkage type {other:?} not supported yet for static var symbol {symbol_name}"
            ));
            None
        }
    };

    let global = self.declare_global(span, spirv_ty);
    self.statics.borrow_mut().insert(def_id, global);
    if let Some(spirv_linkage) = export_linkage {
        self.set_linkage(global.def_cx(self), symbol_name.to_string(), spirv_linkage);
    }
}
/// Declares the function for `instance` ahead of codegen and records it in
/// `fn_instances`.
///
/// `Linkage::External` and `Linkage::WeakAny` are both exported,
/// `Linkage::Internal` gets no linkage, and any other linkage kind is
/// reported as an error and then treated as if no linkage were requested.
fn predefine_fn(
    &mut self,
    instance: Instance<'tcx>,
    linkage: Linkage,
    _visibility: Visibility,
    symbol_name: &str,
) {
    let spirv_linkage = match linkage {
        Linkage::Internal => None,
        Linkage::External | Linkage::WeakAny => Some(LinkageType::Export),
        other => {
            self.tcx.dcx().err(format!(
                "TODO: Linkage type {other:?} not supported yet for function symbol {symbol_name}"
            ));
            None
        }
    };

    let cursor = self.declare_fn_ext(instance, spirv_linkage);
    self.fn_instances.borrow_mut().insert(instance, cursor);
}
}
impl<'tcx> StaticCodegenMethods for CodegenCx<'tcx> {
/// Converts `alloc` into constant data and returns a constant pointer to it.
///
/// The `_kind` argument is ignored.
fn static_addr_of(&self, alloc: ConstAllocation<'_>, _kind: Option<&str>) -> Self::Value {
    let constant = self.const_data_from_alloc(alloc);
    self.static_addr_of_constant(constant)
}
/// Sets the initializer of the global for the static `def_id` from its
/// evaluated initializer allocation.
///
/// Returns without doing anything further if evaluating the initializer
/// fails. Alignment overrides, `THREAD_LOCAL` and `link_section` are detected
/// but not acted upon; the `USED_COMPILER` / `USED_LINKER` flags are
/// forwarded to `add_compiler_used_global` / `add_used_global`.
///
/// # Panics
///
/// Panics if the allocation cannot be read as the global's pointee type, if
/// the resulting value's type differs from that pointee type, or if both
/// `USED_COMPILER` and `USED_LINKER` are set.
fn codegen_static(&mut self, def_id: DefId) {
    let global = self.get_static(def_id);

    // NOTE(review): the error is dropped here, presumably because the query
    // has already reported it — confirm.
    let Ok(alloc) = self.tcx.eval_static_initializer(def_id) else {
        return;
    };

    // The global is expected to be pointer-typed; its pointee is the value type.
    let value_ty = match self.lookup_type(global.ty) {
        SpirvType::Pointer { pointee } => pointee,
        other => self.tcx.dcx().fatal(format!(
            "global had non-pointer type {}",
            other.debug(global.ty, self)
        )),
    };

    let initializer = self.try_read_from_const_alloc(alloc, value_ty).unwrap();
    assert_ty_eq!(self, value_ty, initializer.ty);
    self.builder
        .set_global_initializer(global.def_cx(self), initializer.def_cx(self));

    let attrs = self.tcx.codegen_fn_attrs(def_id);
    let alloc = alloc.inner();

    // An override exists only when the allocation's alignment differs from
    // the alignment reported for the value type.
    let type_align = self.lookup_type(value_ty).alignof(self);
    let align_override = (alloc.align != type_align).then_some(alloc.align);
    if let Some(_align) = align_override {
        // Nothing is done with an alignment override.
    }
    if attrs.flags.contains(CodegenFnAttrFlags::THREAD_LOCAL) {
        // Nothing is done for thread-local statics.
    }
    if let Some(_section) = attrs.link_section {
        // Nothing is done with a link section.
    }

    let used_by_compiler = attrs.flags.contains(CodegenFnAttrFlags::USED_COMPILER);
    let used_by_linker = attrs.flags.contains(CodegenFnAttrFlags::USED_LINKER);
    if used_by_compiler {
        assert!(!attrs.flags.contains(CodegenFnAttrFlags::USED_LINKER));
        self.add_compiler_used_global(global);
    }
    if used_by_linker {
        assert!(!attrs.flags.contains(CodegenFnAttrFlags::USED_COMPILER));
        self.add_used_global(global);
    }
}
}
impl CodegenCx<'_> {
    /// Called by `codegen_static` for statics flagged `USED_LINKER`.
    ///
    /// Currently a no-op: the arguments are only discarded.
    fn add_used_global(&self, global: SpirvValue) {
        let _ = (self, global);
    }

    /// Called by `codegen_static` for statics flagged `USED_COMPILER`.
    ///
    /// Currently a no-op: the arguments are only discarded.
    fn add_compiler_used_global(&self, global: SpirvValue) {
        let _ = (self, global);
    }
}