use crate::linker::ipo::CallGraph;
use crate::spirv_type_constraints::{self, InstSig, StorageClassPat, TyListPat, TyPat};
use indexmap::{IndexMap, IndexSet};
use rspirv::dr::{Builder, Function, Instruction, Module, Operand};
use rspirv::spirv::{Op, StorageClass, Word};
use rustc_data_structures::captures::Captures;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use smallvec::SmallVec;
use std::collections::{BTreeMap, VecDeque};
use std::convert::{TryFrom, TryInto};
use std::ops::{Range, RangeTo};
use std::{fmt, io, iter, mem, slice};
/// Adapter that turns a formatting closure into a value implementing both
/// `Debug` and `Display` (both render exactly what the closure writes).
struct FmtBy<F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result>(F);

impl<F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result> fmt::Debug for FmtBy<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let FmtBy(write_fn) = self;
        write_fn(f)
    }
}

impl<F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result> fmt::Display for FmtBy<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same rendering as `Debug`: defer to the wrapped closure.
        fmt::Debug::fmt(self, f)
    }
}
pub trait Specialization {
fn specialize_operand(&self, operand: &Operand) -> bool;
fn concrete_fallback(&self) -> Operand;
}
pub struct SimpleSpecialization<SO: Fn(&Operand) -> bool> {
pub specialize_operand: SO,
pub concrete_fallback: Operand,
}
impl<SO: Fn(&Operand) -> bool> Specialization for SimpleSpecialization<SO> {
fn specialize_operand(&self, operand: &Operand) -> bool {
(self.specialize_operand)(operand)
}
fn concrete_fallback(&self) -> Operand {
self.concrete_fallback.clone()
}
}
/// Runs the specializer over `module` and returns the expanded module.
///
/// Steps (in order): collect debug names when debugging or dumping, collect
/// generic definitions, record concrete instances needed by entry-point
/// interfaces, run inference over every function (call-graph post-order),
/// apply replacements to non-generic functions, propagate instances,
/// optionally dump them to `opts.specializer_dump_instances`, then expand.
pub fn specialize(
opts: &super::Options,
module: Module,
specialization: impl Specialization,
) -> Module {
let debug = opts.specializer_debug;
let dump_instances = &opts.specializer_dump_instances;
// `OpName` ID -> name map, only populated when it will be used for output.
let mut debug_names = FxHashMap::default();
if debug || dump_instances.is_some() {
debug_names = module
.debug_names
.iter()
.filter(|inst| inst.class.opcode == Op::Name)
.map(|inst| {
(
inst.operands[0].unwrap_id_ref(),
inst.operands[1].unwrap_literal_string().to_string(),
)
})
.collect();
}
let mut specializer = Specializer {
specialization,
debug,
debug_names,
generics: IndexMap::new(),
int_consts: FxHashMap::default(),
};
specializer.collect_generics(&module);
// Entry-point interface operands (operands from index 3 on) that refer to a
// generic whose parameters were all inferred as `Known` get a concrete
// instance up front.
let mut interface_concrete_instances = IndexSet::new();
for inst in &module.entry_points {
for interface_operand in &inst.operands[3..] {
let interface_id = interface_operand.unwrap_id_ref();
if let Some(generic) = specializer.generics.get(&interface_id) {
if let Some(param_values) = &generic.param_values {
if param_values.iter().all(|v| matches!(v, Value::Known(_))) {
interface_concrete_instances.insert(Instance {
generic_id: interface_id,
generic_args: param_values
.iter()
.copied()
.map(|v| match v {
Value::Known(v) => v,
// Excluded by the `all(...)` check above.
_ => unreachable!(),
})
.collect(),
});
}
}
}
}
}
// Infer functions in call-graph post-order (presumably callees before
// callers, so callee generics are complete when callers are inferred).
// Non-generic functions yield replacements to apply below.
let call_graph = CallGraph::collect(&module);
let mut non_generic_replacements = vec![];
for func_idx in call_graph.post_order() {
if let Some(replacements) = specializer.infer_function(&module.functions[func_idx]) {
non_generic_replacements.push((func_idx, replacements));
}
}
let mut expander = Expander::new(&specializer, module);
for interface_instance in interface_concrete_instances {
expander.alloc_instance_id(interface_instance);
}
if debug {
eprintln!("non-generic replacements:");
}
for (func_idx, replacements) in non_generic_replacements {
// Temporarily take the function out of the module so `expander` can be
// borrowed mutably (for `alloc_instance_id`) while `func` is edited.
let mut func = mem::replace(
&mut expander.builder.module_mut().functions[func_idx],
Function::new(),
);
if debug {
let empty = replacements.with_instance.is_empty()
&& replacements.with_concrete_or_param.is_empty();
if !empty {
eprintln!(" in %{}:", func.def_id().unwrap());
}
}
// Non-generic functions have no parameters, hence the empty args slice.
for (loc, operand) in
replacements.to_concrete(&[], |instance| expander.alloc_instance_id(instance))
{
if debug {
eprintln!(" {operand} -> {loc:?}");
}
func.index_set(loc, operand.into());
}
expander.builder.module_mut().functions[func_idx] = func;
}
expander.propagate_instances();
// NOTE(review): both `unwrap()`s below panic on I/O failure without naming
// the path; consider `expect`/`unwrap_or_else` with the path in the message.
if let Some(path) = dump_instances {
expander
.dump_instances(&mut std::fs::File::create(path).unwrap())
.unwrap();
}
expander.expand_module()
}
/// The subset of SPIR-V operands the specializer handles as `Copy` values:
/// ID references and storage classes.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum CopyOperand {
    IdRef(Word),
    StorageClass(StorageClass),
}

/// Conversion error carrying the operand that has no `CopyOperand` form.
#[derive(Debug)]
struct NotSupportedAsCopyOperand(Operand);

impl TryFrom<&Operand> for CopyOperand {
    type Error = NotSupportedAsCopyOperand;
    fn try_from(operand: &Operand) -> Result<Self, Self::Error> {
        match operand {
            Operand::IdRef(id) => Ok(Self::IdRef(*id)),
            Operand::StorageClass(s) => Ok(Self::StorageClass(*s)),
            unsupported => Err(NotSupportedAsCopyOperand(unsupported.clone())),
        }
    }
}

impl From<CopyOperand> for Operand {
    fn from(op: CopyOperand) -> Self {
        match op {
            CopyOperand::StorageClass(storage_class) => Operand::StorageClass(storage_class),
            CopyOperand::IdRef(id) => Operand::IdRef(id),
        }
    }
}

impl fmt::Display for CopyOperand {
    /// IDs render as `%<id>`, storage classes via their `Debug` name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::IdRef(id) => write!(f, "%{}", id),
            Self::StorageClass(s) => write!(f, "{:?}", s),
        }
    }
}
/// What is known about a parameter/variable: nothing, a concrete operand,
/// or that it is equal to another variable of type `T`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
enum Value<T> {
    Unknown,
    Known(CopyOperand),
    SameAs(T),
}

impl<T> Value<T> {
    /// Maps the variable inside `SameAs`, leaving other variants untouched.
    fn map_var<U>(self, f: impl FnOnce(T) -> U) -> Value<U> {
        match self {
            Value::SameAs(other_var) => Value::SameAs(f(other_var)),
            Value::Known(operand) => Value::Known(operand),
            Value::Unknown => Value::Unknown,
        }
    }
}
/// Index of a specialization parameter of a generic definition (shown as `$N`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Param(u32);

impl fmt::Display for Param {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Param(index) = *self;
        write!(f, "${index}")
    }
}

impl Param {
    /// Iterates over every `Param` in the half-open `range`.
    fn range_iter(range: &Range<Self>) -> impl Iterator<Item = Self> + Clone {
        let (lo, hi) = (range.start.0, range.end.0);
        (lo..hi).map(Param)
    }
}
/// A generic definition (by ID) applied to some generic arguments `GA`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Instance<GA> {
    generic_id: Word,
    generic_args: GA,
}

impl<GA> Instance<GA> {
    /// Borrows the generic arguments, keeping the same `generic_id`.
    fn as_ref(&self) -> Instance<&GA> {
        let Self {
            generic_id,
            generic_args,
        } = self;
        Instance {
            generic_id: *generic_id,
            generic_args,
        }
    }

    /// Maps each generic argument through `f`, collecting into `GA2`.
    fn map_generic_args<T, U, GA2>(self, f: impl FnMut(T) -> U) -> Instance<GA2>
    where
        GA: IntoIterator<Item = T>,
        GA2: std::iter::FromIterator<U>,
    {
        let Self {
            generic_id,
            generic_args,
        } = self;
        let generic_args = generic_args.into_iter().map(f).collect();
        Instance {
            generic_id,
            generic_args,
        }
    }

    /// Formats as `%<generic_id><arg, arg, ...>`, where `f` turns the generic
    /// arguments into a (cloneable) iterator of displayable items.
    fn display<'a, T: fmt::Display, GAI: Iterator<Item = T> + Clone>(
        &'a self,
        f: impl FnOnce(&'a GA) -> GAI,
    ) -> impl fmt::Display {
        let generic_id = self.generic_id;
        let args = f(&self.generic_args);
        FmtBy(move |out| {
            write!(out, "%{generic_id}<")?;
            let mut remaining = args.clone();
            if let Some(first) = remaining.next() {
                write!(out, "{first}")?;
                for arg in remaining {
                    write!(out, ", {arg}")?;
                }
            }
            write!(out, ">")
        })
    }
}
/// Position of an instruction: a module-level definition, a function
/// parameter, or an instruction inside a function body block.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum InstructionLocation {
    Module,
    FnParam(usize),
    FnBody {
        block_idx: usize,
        inst_idx: usize,
    },
}

/// Uniform read/write access to an operand addressed by an index of type `I`.
trait OperandIndexGetSet<I> {
    fn index_get(&self, index: I) -> Operand;
    fn index_set(&mut self, index: I, operand: Operand);
}

/// Operand position within one instruction: its result type, or the i-th
/// entry of `operands`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum OperandIdx {
    ResultType,
    Input(usize),
}

impl OperandIndexGetSet<OperandIdx> for Instruction {
    /// Panics if `ResultType` is requested from an instruction without one.
    fn index_get(&self, idx: OperandIdx) -> Operand {
        if let OperandIdx::Input(i) = idx {
            self.operands[i].clone()
        } else {
            Operand::IdRef(self.result_type.unwrap())
        }
    }
    /// Panics if a non-ID operand is written into the result type slot.
    fn index_set(&mut self, idx: OperandIdx, operand: Operand) {
        match idx {
            OperandIdx::Input(i) => self.operands[i] = operand,
            OperandIdx::ResultType => {
                let id = operand.unwrap_id_ref();
                self.result_type = Some(id);
            }
        }
    }
}
/// Full address of an operand: which instruction, and which operand in it.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
struct OperandLocation {
    inst_loc: InstructionLocation,
    operand_idx: OperandIdx,
}

// A standalone instruction can only be addressed as `InstructionLocation::Module`.
impl OperandIndexGetSet<OperandLocation> for Instruction {
    fn index_get(&self, loc: OperandLocation) -> Operand {
        let OperandLocation {
            inst_loc,
            operand_idx,
        } = loc;
        assert_eq!(inst_loc, InstructionLocation::Module);
        self.index_get(operand_idx)
    }
    fn index_set(&mut self, loc: OperandLocation, operand: Operand) {
        let OperandLocation {
            inst_loc,
            operand_idx,
        } = loc;
        assert_eq!(inst_loc, InstructionLocation::Module);
        self.index_set(operand_idx, operand);
    }
}

// For a function, `Module` addresses its `OpFunction` definition (`def`).
impl OperandIndexGetSet<OperandLocation> for Function {
    fn index_get(&self, loc: OperandLocation) -> Operand {
        let idx = loc.operand_idx;
        match loc.inst_loc {
            InstructionLocation::Module => self.def.as_ref().unwrap().index_get(idx),
            InstructionLocation::FnParam(param) => self.parameters[param].index_get(idx),
            InstructionLocation::FnBody {
                block_idx,
                inst_idx,
            } => self.blocks[block_idx].instructions[inst_idx].index_get(idx),
        }
    }
    fn index_set(&mut self, loc: OperandLocation, operand: Operand) {
        let idx = loc.operand_idx;
        match loc.inst_loc {
            InstructionLocation::Module => self.def.as_mut().unwrap().index_set(idx, operand),
            InstructionLocation::FnParam(param) => {
                self.parameters[param].index_set(idx, operand);
            }
            InstructionLocation::FnBody {
                block_idx,
                inst_idx,
            } => self.blocks[block_idx].instructions[inst_idx].index_set(idx, operand),
        }
    }
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
enum ConcreteOrParam {
Concrete(CopyOperand),
Param(Param),
}
impl ConcreteOrParam {
fn apply_generic_args(self, generic_args: &[CopyOperand]) -> CopyOperand {
match self {
Self::Concrete(x) => x,
Self::Param(Param(i)) => generic_args[i as usize],
}
}
}
/// Operand rewrites computed by inference for one definition/function.
#[derive(Debug)]
struct Replacements {
/// Instances (whose args may still mention parameters) mapped to every
/// operand location that must be replaced by that instance's ID.
with_instance: IndexMap<Instance<SmallVec<[ConcreteOrParam; 4]>>, Vec<OperandLocation>>,
/// Operand locations replaced directly by a concrete operand or parameter.
with_concrete_or_param: Vec<(OperandLocation, ConcreteOrParam)>,
}
impl Replacements {
/// Yields `(location, concrete operand)` pairs after substituting
/// `generic_args` for parameters. `concrete_instance_id` is called once per
/// `with_instance` entry (lazily, as the iterator is driven) to obtain the
/// ID used for all of that entry's locations.
fn to_concrete<'a>(
&'a self,
generic_args: &'a [CopyOperand],
mut concrete_instance_id: impl FnMut(Instance<SmallVec<[CopyOperand; 4]>>) -> Word + 'a,
) -> impl Iterator<Item = (OperandLocation, CopyOperand)> + 'a {
self.with_instance
.iter()
.flat_map(move |(instance, locations)| {
let concrete = CopyOperand::IdRef(concrete_instance_id(
instance
.as_ref()
.map_generic_args(|x| x.apply_generic_args(generic_args)),
));
locations.iter().map(move |&loc| (loc, concrete))
})
.chain(
self.with_concrete_or_param
.iter()
.map(move |&(loc, x)| (loc, x.apply_generic_args(generic_args))),
)
}
}
/// A definition whose operands depend on specialization parameters.
struct Generic {
    /// Number of specialization parameters this definition is generic over.
    param_count: u32,
    /// The defining instruction. For functions this is the `OpFunction`
    /// (`Function::def`); for forward-declared pointers it is the
    /// `OpTypeForwardPointer`.
    def: Instruction,
    /// Per-parameter inference results, or `None` if nothing at all was
    /// learned about any parameter (every value was `Unknown`).
    param_values: Option<Vec<Value<Param>>>,
    /// Operand rewrites to apply when instantiating this definition.
    replacements: Replacements,
}

/// Shared state of the specializer: the policy, debug options, all generic
/// definitions (keyed by result ID) and known 32-bit integer constants.
struct Specializer<S: Specialization> {
    specialization: S,
    debug: bool,
    debug_names: FxHashMap<Word, String>,
    generics: IndexMap<Word, Generic>,
    int_consts: FxHashMap<Word, u32>,
}

impl<S: Specialization> Specializer<S> {
    /// How many parameters `operand` consumes, and the generic it refers to.
    ///
    /// An operand selected by the specialization policy consumes exactly one
    /// parameter. An ID of a known generic consumes that generic's
    /// `param_count`. Anything else consumes none.
    fn params_needed_by(&self, operand: &Operand) -> (u32, Option<&Generic>) {
        if self.specialization.specialize_operand(operand) {
            (1, None)
        } else if let Operand::IdRef(id) = operand {
            self.generics
                .get(id)
                .map_or((0, None), |generic| (generic.param_count, Some(generic)))
        } else {
            (0, None)
        }
    }

    /// Scans module-level types/global values and function definitions (in
    /// that order), recording every definition that needs parameters as a
    /// `Generic`, and every `OpConstant` with a 32-bit literal in `int_consts`.
    fn collect_generics(&mut self, module: &Module) {
        let types_global_values_and_functions = module
            .types_global_values
            .iter()
            .chain(module.functions.iter().filter_map(|f| f.def.as_ref()));

        // A forward-declared pointer is registered under its
        // `OpTypeForwardPointer`; the later `OpTypePointer` with the same ID
        // is then skipped.
        let mut forward_declared_pointers = FxHashSet::default();
        for inst in types_global_values_and_functions {
            let result_id = if inst.class.opcode == Op::TypeForwardPointer {
                forward_declared_pointers.insert(inst.operands[0].unwrap_id_ref());
                inst.operands[0].unwrap_id_ref()
            } else {
                let result_id = inst.result_id.unwrap_or_else(|| {
                    unreachable!(
                        "Op{:?} is in `types_global_values` but does not have a result ID",
                        inst.class.opcode
                    );
                });
                if forward_declared_pointers.remove(&result_id) {
                    assert_eq!(inst.class.opcode, Op::TypePointer);
                    continue;
                }
                result_id
            };

            if inst.class.opcode == Op::Constant {
                if let Operand::LiteralInt32(x) = inst.operands[0] {
                    self.int_consts.insert(result_id, x);
                }
            }

            let (param_count, param_values, replacements) = {
                let mut infer_cx = InferCx::new(self);
                infer_cx.instantiate_instruction(inst, InstructionLocation::Module);

                // Every inference variable created for a module-level
                // definition becomes one of its parameters.
                let param_count = infer_cx.infer_var_values.len() as u32;
                let param_values = infer_cx
                    .infer_var_values
                    .iter()
                    .map(|v| v.map_var(|InferVar(i)| Param(i)));
                let param_values = if param_values.clone().any(|v| v != Value::Unknown) {
                    Some(param_values.collect())
                } else {
                    None
                };
                (
                    param_count,
                    param_values,
                    infer_cx.into_replacements(..Param(param_count)),
                )
            };

            if param_count > 0 {
                self.generics.insert(
                    result_id,
                    Generic {
                        param_count,
                        def: inst.clone(),
                        param_values,
                        replacements,
                    },
                );
            }
        }
    }

    /// Runs inference over the whole body of `func`.
    ///
    /// For a generic function the results are stored into its `Generic` and
    /// `None` is returned; for a non-generic function the computed
    /// replacements are returned to the caller.
    fn infer_function(&mut self, func: &Function) -> Option<Replacements> {
        let func_id = func.def_id().unwrap();

        let param_count = self
            .generics
            .get(&func_id)
            .map_or(0, |generic| generic.param_count);

        let (param_values, replacements) = {
            let mut infer_cx = InferCx::new(self);
            infer_cx.instantiate_function(func);

            // Only the first `param_count` inference variables correspond to
            // the function's own parameters.
            let param_values = infer_cx.infer_var_values[..param_count as usize]
                .iter()
                .map(|v| v.map_var(|InferVar(i)| Param(i)));
            let param_values = if param_values.clone().any(|v| v != Value::Unknown) {
                Some(param_values.collect())
            } else {
                None
            };

            (
                param_values,
                infer_cx.into_replacements(..Param(param_count)),
            )
        };

        if let Some(generic) = self.generics.get_mut(&func_id) {
            assert!(generic.param_values.is_none());

            generic.param_values = param_values;
            generic.replacements = replacements;

            None
        } else {
            Some(replacements)
        }
    }
}
/// Inference variable, indexing `InferCx::infer_var_values` (shown as `?N`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct InferVar(u32);

impl fmt::Display for InferVar {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let InferVar(index) = *self;
        write!(f, "?{index}")
    }
}

impl InferVar {
    /// Iterates over every `InferVar` in the half-open `range`.
    fn range_iter(range: &Range<Self>) -> impl Iterator<Item = Self> + Clone {
        let (lo, hi) = (range.start.0, range.end.0);
        (lo..hi).map(InferVar)
    }
}
/// Inference state for one definition or function body.
struct InferCx<'a, S: Specialization> {
    specializer: &'a Specializer<S>,
    /// Current value of each inference variable, indexed by `InferVar.0`.
    infer_var_values: Vec<Value<InferVar>>,
    /// Type (as an `InferOperand`) recorded per result ID, consulted by
    /// `TypeOfId` operand-list lookups.
    type_of_result: IndexMap<Word, InferOperand>,
    // NOTE(review): the two lists below are filled outside this chunk; from
    // their types they pair operand locations with a generic instance, or
    // with a single inference variable, respectively.
    instantiated_operands: Vec<(OperandLocation, Instance<Range<InferVar>>)>,
    inferred_operands: Vec<(OperandLocation, InferVar)>,
}

impl<'a, S: Specialization> InferCx<'a, S> {
    /// Creates an empty context (no variables, no recorded types).
    fn new(specializer: &'a Specializer<S>) -> Self {
        Self {
            specializer,
            infer_var_values: Vec::new(),
            type_of_result: IndexMap::new(),
            instantiated_operands: Vec::new(),
            inferred_operands: Vec::new(),
        }
    }
}
/// An operand as seen during inference.
#[derive(Clone, Debug, PartialEq, Eq)]
enum InferOperand {
/// Not representable as a `CopyOperand` (e.g. a literal).
Unknown,
/// A single inference variable (an operand the policy specializes).
Var(InferVar),
/// A fixed, non-generic operand.
Concrete(CopyOperand),
/// A reference to a generic definition, with its arguments as variables.
Instance(Instance<Range<InferVar>>),
}
impl InferOperand {
/// Converts `operand`, consuming from the front of `generic_args` as many
/// variables as `params_needed_by` reports; returns the converted operand
/// and the remaining (unconsumed) variables.
fn from_operand_and_generic_args(
operand: &Operand,
generic_args: Range<InferVar>,
cx: &InferCx<'_, impl Specialization>,
) -> (Self, Range<InferVar>) {
let (needed, generic) = cx.specializer.params_needed_by(operand);
let split = InferVar(generic_args.start.0 + needed);
let (generic_args, rest) = (generic_args.start..split, split..generic_args.end);
(
if generic.is_some() {
Self::Instance(Instance {
generic_id: operand.unwrap_id_ref(),
generic_args,
})
} else if needed == 0 {
CopyOperand::try_from(operand).map_or(Self::Unknown, Self::Concrete)
} else {
// Only policy-selected operands reach here, and they need exactly one.
assert_eq!(needed, 1);
Self::Var(generic_args.start)
},
rest,
)
}
/// Display helper: every variable is printed as `?N`, followed by
/// ` = <value>` when `infer_var_value` reports a known or `SameAs` value.
fn display_with_infer_var_values<'a>(
&'a self,
infer_var_value: impl Fn(InferVar) -> Value<InferVar> + Copy + 'a,
) -> impl fmt::Display + '_ {
FmtBy(move |f| {
let var_with_value = |v| {
FmtBy(move |f| {
write!(f, "{v}")?;
match infer_var_value(v) {
Value::Unknown => Ok(()),
Value::Known(o) => write!(f, " = {o}"),
Value::SameAs(v) => write!(f, " = {v}"),
}
})
};
match self {
Self::Unknown => write!(f, "_"),
Self::Var(v) => write!(f, "{}", var_with_value(*v)),
Self::Concrete(o) => write!(f, "{o}"),
Self::Instance(instance) => write!(
f,
"{}",
instance.display(|generic_args| {
InferVar::range_iter(generic_args).map(var_with_value)
})
),
}
})
}
/// Display helper using `cx`'s current variable values. `SameAs` chains are
/// followed, but the walk stops before a variable whose value is `Unknown`,
/// so such a variable is shown as `= ?N` rather than omitted.
fn display_with_infer_cx<'a>(
&'a self,
cx: &'a InferCx<'_, impl Specialization>,
) -> impl fmt::Display + '_ {
self.display_with_infer_var_values(move |v| {
let get = |v: InferVar| cx.infer_var_values[v.0 as usize];
let mut value = get(v);
while let Value::SameAs(v) = value {
let next = get(v);
if next == Value::Unknown {
break;
}
value = next;
}
value
})
}
}
// Plain display, without any knowledge of inference variable values.
impl fmt::Display for InferOperand {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.display_with_infer_var_values(|_| Value::Unknown)
.fmt(f)
}
}
/// Optional per-element transformation applied by `InferOperandList`.
#[derive(Copy, Clone, PartialEq, Eq)]
enum InferOperandListTransform {
/// Replace each ID operand with the type of the value it refers to; non-ID
/// operands are skipped (their generic args are still consumed).
TypeOfId,
}
/// Lazily-converted list of operands sharing one contiguous range of
/// inference variables (consumed front-to-back as operands are converted).
#[derive(Clone, PartialEq)]
struct InferOperandList<'a> {
operands: &'a [Operand],
all_generic_args: Range<InferVar>,
transform: Option<InferOperandListTransform>,
}
impl<'a> InferOperandList<'a> {
/// Converts the first (non-skipped) operand, returning it together with the
/// list of remaining operands; `None` once the operands are exhausted.
fn split_first(
&self,
cx: &InferCx<'_, impl Specialization>,
) -> Option<(InferOperand, InferOperandList<'a>)> {
let mut list = self.clone();
loop {
let (first_operand, rest) = list.operands.split_first()?;
list.operands = rest;
let (first, rest_args) = InferOperand::from_operand_and_generic_args(
first_operand,
list.all_generic_args.clone(),
cx,
);
// Args are consumed even when the operand is skipped below.
list.all_generic_args = rest_args;
match self.transform {
None => {}
Some(InferOperandListTransform::TypeOfId) => {
if first_operand.id_ref_any().is_none() {
continue;
}
}
}
let first = match self.transform {
None => first,
Some(InferOperandListTransform::TypeOfId) => match first {
InferOperand::Concrete(CopyOperand::IdRef(id)) => cx
.type_of_result
.get(&id)
.cloned()
.unwrap_or(InferOperand::Unknown),
InferOperand::Unknown | InferOperand::Var(_) | InferOperand::Concrete(_) => {
InferOperand::Unknown
}
InferOperand::Instance(instance) => {
let generic = &cx.specializer.generics[&instance.generic_id];
// For `OpFunction`, the "type of the result" used here is the
// function type (operand 1), not the return type in `result_type`.
let type_of_result = match generic.def.class.opcode {
Op::Function => Some(generic.def.operands[1].unwrap_id_ref()),
_ => generic.def.result_type,
};
match type_of_result {
Some(type_of_result) => {
// The type takes its args from the front of the
// instance's args.
InferOperand::from_operand_and_generic_args(
&Operand::IdRef(type_of_result),
instance.generic_args,
cx,
)
.0
}
None => InferOperand::Unknown,
}
}
},
};
return Some((first, list));
}
}
/// Iterates over all (non-skipped) converted operands via `split_first`.
fn iter<'b>(
&self,
cx: &'b InferCx<'_, impl Specialization>,
) -> impl Iterator<Item = InferOperand> + 'b
where
'a: 'b,
{
let mut list = self.clone();
iter::from_fn(move || {
let (next, rest) = list.split_first(cx)?;
list = rest;
Some(next)
})
}
/// Debug-list display of the converted operands, using `cx`'s variable values.
fn display_with_infer_cx<'b>(
&'b self,
cx: &'b InferCx<'a, impl Specialization>,
) -> impl fmt::Display + '_ {
FmtBy(move |f| {
f.debug_list()
.entries(self.iter(cx).map(|operand| {
FmtBy(move |f| write!(f, "{}", operand.display_with_infer_cx(cx)))
}))
.finish()
})
}
}
/// Dense map from small `usize` keys to values, backed by a `SmallVec`
/// (key `i` is stored at index `i`).
#[derive(Default)]
struct SmallIntMap<A: smallvec::Array>(SmallVec<A>);

impl<A: smallvec::Array> SmallIntMap<A> {
    /// Value at key `i`, if the map has grown that far.
    fn get(&self, i: usize) -> Option<&A::Item> {
        let SmallIntMap(entries) = self;
        entries.get(i)
    }

    /// Mutable value at key `i`, first growing the map with default values
    /// so that key `i` exists.
    fn get_mut_or_default(&mut self, i: usize) -> &mut A::Item
    where
        A::Item: Default,
    {
        if i >= self.0.len() {
            self.0.resize_with(i + 1, <A::Item as Default>::default);
        }
        &mut self.0[i]
    }
}

impl<A: smallvec::Array> IntoIterator for SmallIntMap<A> {
    type Item = (usize, A::Item);
    type IntoIter = iter::Enumerate<smallvec::IntoIter<A>>;
    fn into_iter(self) -> Self::IntoIter {
        let SmallIntMap(entries) = self;
        entries.into_iter().enumerate()
    }
}

impl<'a, A: smallvec::Array> IntoIterator for &'a mut SmallIntMap<A> {
    type Item = (usize, &'a mut A::Item);
    type IntoIter = iter::Enumerate<slice::IterMut<'a, A::Item>>;
    fn into_iter(self) -> Self::IntoIter {
        let SmallIntMap(entries) = self;
        entries.iter_mut().enumerate()
    }
}
/// An `IndexComposite(T)` finding: indexing `T` by `indices` must yield `leaf`.
#[derive(PartialEq)]
struct IndexCompositeMatch<'a> {
indices: &'a [Operand],
leaf: InferOperand,
}
/// Result of matching patterns: per pattern variable (by index), the list of
/// operands found for it, which must all end up equal.
#[must_use]
#[derive(Default)]
struct Match<'a> {
/// Set when the match could not be fully decided.
ambiguous: bool,
storage_class_var_found: SmallIntMap<[SmallVec<[InferOperand; 2]>; 1]>,
ty_var_found: SmallIntMap<[SmallVec<[InferOperand; 4]>; 1]>,
index_composite_ty_var_found: SmallIntMap<[SmallVec<[IndexCompositeMatch<'a>; 1]>; 1]>,
ty_list_var_found: SmallIntMap<[SmallVec<[InferOperandList<'a>; 2]>; 1]>,
}
impl<'a> Match<'a> {
/// Combines two matches that must both hold: findings are concatenated
/// per variable, and ambiguity is OR-ed.
fn and(mut self, other: Self) -> Self {
let Match {
ambiguous,
storage_class_var_found,
ty_var_found,
index_composite_ty_var_found,
ty_list_var_found,
} = &mut self;
*ambiguous |= other.ambiguous;
for (i, other_found) in other.storage_class_var_found {
storage_class_var_found
.get_mut_or_default(i)
.extend(other_found);
}
for (i, other_found) in other.ty_var_found {
ty_var_found.get_mut_or_default(i).extend(other_found);
}
for (i, other_found) in other.index_composite_ty_var_found {
index_composite_ty_var_found
.get_mut_or_default(i)
.extend(other_found);
}
for (i, other_found) in other.ty_list_var_found {
ty_list_var_found.get_mut_or_default(i).extend(other_found);
}
self
}
/// Combines two alternative matches: only findings present in both are
/// kept (per-variable intersection), and ambiguity is OR-ed.
fn or(mut self, other: Self) -> Self {
let Match {
ambiguous,
storage_class_var_found,
ty_var_found,
index_composite_ty_var_found,
ty_list_var_found,
} = &mut self;
*ambiguous |= other.ambiguous;
for (i, self_found) in storage_class_var_found {
let other_found = other
.storage_class_var_found
.get(i)
.map_or(&[][..], |xs| &xs[..]);
self_found.retain(|x| other_found.contains(x));
}
for (i, self_found) in ty_var_found {
let other_found = other.ty_var_found.get(i).map_or(&[][..], |xs| &xs[..]);
self_found.retain(|x| other_found.contains(x));
}
for (i, self_found) in index_composite_ty_var_found {
let other_found = other
.index_composite_ty_var_found
.get(i)
.map_or(&[][..], |xs| &xs[..]);
self_found.retain(|x| other_found.contains(x));
}
for (i, self_found) in ty_list_var_found {
let other_found = other.ty_list_var_found.get(i).map_or(&[][..], |xs| &xs[..]);
self_found.retain(|x| other_found.contains(x));
}
self
}
/// Debug rendering of all non-empty findings, using `cx`'s variable values.
fn debug_with_infer_cx<'b>(
&'b self,
cx: &'b InferCx<'a, impl Specialization>,
) -> impl fmt::Debug + Captures<'a> + '_ {
// Renders each non-empty finding list as `a = b = c`.
fn debug_var_found<'a, A: smallvec::Array<Item = T> + 'a, T: 'a, TD: fmt::Display>(
var_found: &'a SmallIntMap<impl smallvec::Array<Item = SmallVec<A>>>,
display: &'a impl Fn(&'a T) -> TD,
) -> impl Iterator<Item = impl fmt::Debug + 'a> + 'a {
var_found
.0
.iter()
.filter(|found| !found.is_empty())
.map(move |found| {
FmtBy(move |f| {
let mut found = found.iter().map(display);
write!(f, "{}", found.next().unwrap())?;
for x in found {
write!(f, " = {x}")?;
}
Ok(())
})
})
}
FmtBy(move |f| {
let Self {
ambiguous,
storage_class_var_found,
ty_var_found,
index_composite_ty_var_found,
ty_list_var_found,
} = self;
write!(f, "Match{} ", if *ambiguous { " (ambiguous)" } else { "" })?;
let mut list = f.debug_list();
list.entries(debug_var_found(storage_class_var_found, &move |operand| {
operand.display_with_infer_cx(cx)
}));
list.entries(debug_var_found(ty_var_found, &move |operand| {
operand.display_with_infer_cx(cx)
}));
// `IndexComposite` findings are shown as `T.idx... = leaf`, printing the
// candidates of the same-index type variable as `T`.
list.entries(
index_composite_ty_var_found
.0
.iter()
.enumerate()
.filter(|(_, found)| !found.is_empty())
.flat_map(|(i, found)| found.iter().map(move |x| (i, x)))
.map(move |(i, IndexCompositeMatch { indices, leaf })| {
FmtBy(move |f| {
match ty_var_found.get(i) {
Some(found) if found.len() == 1 => {
write!(f, "{}", found[0].display_with_infer_cx(cx))?;
}
found => {
let found = found.map_or(&[][..], |xs| &xs[..]);
write!(f, "(")?;
for (j, operand) in found.iter().enumerate() {
if j != 0 {
write!(f, " = ")?;
}
write!(f, "{}", operand.display_with_infer_cx(cx))?;
}
write!(f, ")")?;
}
}
// Constant indices print as `.N`, anything else as `[operand]`.
for operand in &indices[..] {
let maybe_idx = match operand {
Operand::IdRef(id) => cx.specializer.int_consts.get(id),
Operand::LiteralInt32(idx) => Some(idx),
_ => None,
};
match maybe_idx {
Some(idx) => write!(f, ".{idx}")?,
None => write!(f, "[{operand}]")?,
}
}
write!(f, " = {}", leaf.display_with_infer_cx(cx))
})
}),
);
list.entries(debug_var_found(ty_list_var_found, &move |list| {
list.display_with_infer_cx(cx)
}));
list.finish()
})
}
}
/// Marker error: a pattern or signature does not apply to the given operands.
struct Unapplicable;

impl<'a, S: Specialization> InferCx<'a, S> {
    /// Records `storage_class` as a finding for the pattern's storage-class
    /// variable, if the pattern has one.
    // `self` is unused, but is kept for symmetry with the other `match_*` methods.
    #[allow(clippy::unused_self)]
    fn match_storage_class_pat(
        &self,
        pat: &StorageClassPat,
        storage_class: InferOperand,
    ) -> Match<'a> {
        match pat {
            StorageClassPat::Any => Match::default(),
            StorageClassPat::Var(i) => {
                let mut m = Match::default();
                m.storage_class_var_found
                    .get_mut_or_default(*i)
                    .push(storage_class);
                m
            }
        }
    }

    /// Matches a type pattern against `ty`.
    ///
    /// Structural patterns can only be checked against generic instances.
    /// Unknown or concrete types yield an ambiguous (empty) match. A lone
    /// inference variable, or a generic with a different opcode, is
    /// `Unapplicable`.
    fn match_ty_pat(&self, pat: &TyPat<'_>, ty: InferOperand) -> Result<Match<'a>, Unapplicable> {
        match pat {
            TyPat::Any => Ok(Match::default()),
            TyPat::Var(i) => {
                let mut m = Match::default();
                m.ty_var_found.get_mut_or_default(*i).push(ty);
                Ok(m)
            }
            // Prefer an unambiguous match of `a`; otherwise combine both sides.
            TyPat::Either(a, b) => match self.match_ty_pat(a, ty.clone()) {
                Ok(m) if !m.ambiguous => Ok(m),
                a_result => match (a_result, self.match_ty_pat(b, ty)) {
                    (Ok(ma), Ok(mb)) => Ok(ma.or(mb)),
                    (Ok(m), _) | (_, Ok(m)) => Ok(m),
                    (Err(Unapplicable), Err(Unapplicable)) => Err(Unapplicable),
                },
            },
            // The indices are not known here; `match_inst_sig` fills them in.
            TyPat::IndexComposite(composite_pat) => match composite_pat {
                TyPat::Var(i) => {
                    let mut m = Match::default();
                    m.index_composite_ty_var_found.get_mut_or_default(*i).push(
                        IndexCompositeMatch {
                            indices: &[],
                            leaf: ty,
                        },
                    );
                    Ok(m)
                }
                _ => unreachable!(
                    "`IndexComposite({:?})` isn't supported, only type variable \
                     patterns are (for the composite type), e.g. `IndexComposite(T)`",
                    composite_pat
                ),
            },
            _ => {
                let instance = match ty {
                    InferOperand::Unknown | InferOperand::Concrete(_) => {
                        return Ok(Match {
                            ambiguous: true,
                            ..Match::default()
                        });
                    }
                    InferOperand::Var(_) => return Err(Unapplicable),
                    InferOperand::Instance(instance) => instance,
                };
                let generic = &self.specializer.generics[&instance.generic_id];

                let ty_operands = InferOperandList {
                    operands: &generic.def.operands,
                    all_generic_args: instance.generic_args,
                    transform: None,
                };
                // Single-inner-type patterns: check the opcode, then match the
                // first operand.
                let simple = |op, inner_pat| {
                    if generic.def.class.opcode == op {
                        self.match_ty_pat(inner_pat, ty_operands.split_first(self).unwrap().0)
                    } else {
                        Err(Unapplicable)
                    }
                };
                match pat {
                    TyPat::Any | TyPat::Var(_) | TyPat::Either(..) | TyPat::IndexComposite(_) => {
                        unreachable!()
                    }

                    // A void type has no parameters, so it is never a generic instance.
                    TyPat::Void => unreachable!(),

                    TyPat::Pointer(storage_class_pat, pointee_pat) => {
                        // NOTE(review): unlike the other structural arms this does
                        // not check for `OpTypePointer`; forward-declared pointers
                        // are registered with their `OpTypeForwardPointer` def,
                        // whose operand order differs — confirm intended.
                        let mut ty_operands = ty_operands.iter(self);
                        let (storage_class, pointee_ty) =
                            (ty_operands.next().unwrap(), ty_operands.next().unwrap());
                        Ok(self
                            .match_storage_class_pat(storage_class_pat, storage_class)
                            .and(self.match_ty_pat(pointee_pat, pointee_ty)?))
                    }
                    TyPat::Array(pat) => simple(Op::TypeArray, pat),
                    TyPat::Vector(pat) => simple(Op::TypeVector, pat),
                    TyPat::Vector4(pat) => match ty_operands.operands {
                        [_, Operand::LiteralInt32(4)] => simple(Op::TypeVector, pat),
                        _ => Err(Unapplicable),
                    },
                    TyPat::Matrix(pat) => simple(Op::TypeMatrix, pat),
                    TyPat::Image(pat) => simple(Op::TypeImage, pat),
                    TyPat::Pipe(_pat) => {
                        if generic.def.class.opcode == Op::TypePipe {
                            Ok(Match::default())
                        } else {
                            Err(Unapplicable)
                        }
                    }
                    TyPat::SampledImage(pat) => simple(Op::TypeSampledImage, pat),
                    TyPat::Struct(fields_pat) => {
                        if generic.def.class.opcode == Op::TypeStruct {
                            self.match_ty_list_pat(fields_pat, ty_operands)
                        } else {
                            Err(Unapplicable)
                        }
                    }
                    TyPat::Function(ret_pat, params_pat) => {
                        // Like the other structural patterns, only the matching
                        // type opcode applies. Without this check, a non-function
                        // generic would be misread as `(ret, params...)`, or would
                        // panic in `unwrap` when it has no operands.
                        if generic.def.class.opcode != Op::TypeFunction {
                            return Err(Unapplicable);
                        }
                        let (ret_ty, params_ty_list) = ty_operands.split_first(self).unwrap();
                        Ok(self
                            .match_ty_pat(ret_pat, ret_ty)?
                            .and(self.match_ty_list_pat(params_pat, params_ty_list)?))
                    }
                }
            }
        }
    }

    /// Matches a type-list pattern against `ty_list`, element by element.
    ///
    /// `Any` accepts any remainder and `Var` records the remainder as a
    /// finding. `Repeat` requires the remainder to be one or more whole
    /// repetitions of its inner list, and `Nil` requires the list to be
    /// exhausted.
    fn match_ty_list_pat(
        &self,
        mut list_pat: &TyListPat<'_>,
        mut ty_list: InferOperandList<'a>,
    ) -> Result<Match<'a>, Unapplicable> {
        let mut m = Match::default();

        while let TyListPat::Cons { first: pat, suffix } = list_pat {
            list_pat = suffix;

            let (ty, rest) = ty_list.split_first(self).ok_or(Unapplicable)?;
            ty_list = rest;

            m = m.and(self.match_ty_pat(pat, ty)?);
        }

        match list_pat {
            TyListPat::Cons { .. } => unreachable!(),

            TyListPat::Any => {}
            TyListPat::Var(i) => {
                m.ty_list_var_found.get_mut_or_default(*i).push(ty_list);
            }
            TyListPat::Repeat(repeat_list_pat) => {
                let mut tys = ty_list.iter(self).peekable();
                loop {
                    let mut list_pat = repeat_list_pat;
                    while let TyListPat::Cons { first: pat, suffix } = list_pat {
                        m = m.and(self.match_ty_pat(pat, tys.next().ok_or(Unapplicable)?)?);
                        list_pat = suffix;
                    }
                    assert!(matches!(list_pat, TyListPat::Nil));
                    if tys.peek().is_none() {
                        break;
                    }
                }
            }
            TyListPat::Nil => {
                if ty_list.split_first(self).is_some() {
                    return Err(Unapplicable);
                }
            }
        }

        Ok(m)
    }

    /// Matches one instruction signature: the storage-class operand (if the
    /// signature constrains one), the types of the ID inputs, and the result
    /// type. It then attaches the trailing index operands to any
    /// `IndexComposite` findings.
    fn match_inst_sig(
        &self,
        sig: &InstSig<'_>,
        inst: &'a Instruction,
        inputs_generic_args: Range<InferVar>,
        result_type: Option<InferOperand>,
    ) -> Result<Match<'a>, Unapplicable> {
        let mut m = Match::default();

        if let Some(pat) = sig.storage_class {
            // The first `StorageClass` operand is the one being constrained.
            let all_operands = InferOperandList {
                operands: &inst.operands,
                all_generic_args: inputs_generic_args.clone(),
                transform: None,
            };
            let storage_class = all_operands
                .iter(self)
                .zip(&inst.operands)
                .filter(|(_, original)| matches!(original, Operand::StorageClass(_)))
                .map(|(operand, _)| operand)
                .next()
                .ok_or(Unapplicable)?;
            m = m.and(self.match_storage_class_pat(pat, storage_class));
        }

        let input_ty_list = InferOperandList {
            operands: &inst.operands,
            all_generic_args: inputs_generic_args,
            transform: Some(InferOperandListTransform::TypeOfId),
        };

        m = m.and(self.match_ty_list_pat(sig.input_types, input_ty_list.clone())?);

        match (sig.output_type, result_type) {
            (Some(pat), Some(result_type)) => {
                m = m.and(self.match_ty_pat(pat, result_type)?);
            }
            (None, None) => {}
            _ => return Err(Unapplicable),
        }

        if !m.index_composite_ty_var_found.0.is_empty() {
            // The indices are whatever operands follow the explicitly listed inputs.
            let composite_indices = {
                let mut ty_list = input_ty_list;
                let mut list_pat = sig.input_types;
                while let TyListPat::Cons { first: _, suffix } = list_pat {
                    list_pat = suffix;
                    ty_list = ty_list.split_first(self).ok_or(Unapplicable)?.1;
                }

                assert_eq!(
                    list_pat,
                    &TyListPat::Any,
                    "`IndexComposite` must have input types end in `..`"
                );

                ty_list.operands
            };

            for (_, found) in &mut m.index_composite_ty_var_found {
                for index_composite_match in found {
                    let empty = mem::replace(&mut index_composite_match.indices, composite_indices);
                    assert_eq!(empty, &[]);
                }
            }
        }

        Ok(m)
    }

    /// Matches every signature in `sigs`. The first unambiguous match seen
    /// while nothing has matched yet is returned immediately. Otherwise the
    /// applicable matches are combined with `or`; the result is `Unapplicable`
    /// only if no signature applies.
    fn match_inst_sigs(
        &self,
        sigs: &[InstSig<'_>],
        inst: &'a Instruction,
        inputs_generic_args: Range<InferVar>,
        result_type: Option<InferOperand>,
    ) -> Result<Match<'a>, Unapplicable> {
        let mut result = Err(Unapplicable);
        for sig in sigs {
            result = match (
                result,
                self.match_inst_sig(sig, inst, inputs_generic_args.clone(), result_type.clone()),
            ) {
                (Err(Unapplicable), Ok(m)) if !m.ambiguous => return Ok(m),
                (Ok(a), Ok(b)) => Ok(a.or(b)),
                (Ok(m), _) | (_, Ok(m)) => Ok(m),
                (Err(Unapplicable), Err(Unapplicable)) => Err(Unapplicable),
            };
        }
        result
    }
}
/// Failure while unifying inference variables or operands.
enum InferError {
    /// Two operands required to be equal were found to differ.
    Conflict(InferOperand, InferOperand),
}

impl InferError {
    /// Prints the error and the offending instruction to stderr, then exits
    /// the process with status 1.
    fn report(self, inst: &Instruction) {
        match self {
            Self::Conflict(a, b) => eprintln!("inference conflict: {a:?} vs {b:?}"),
        }

        // Reconstruct the instruction as `%id = OpName <result type> <operands...>`.
        eprint!(" in ");
        if let Some(result_id) = inst.result_id {
            eprint!("%{result_id} = ");
        }
        eprint!("Op{:?}", inst.class.opcode);
        let result_type = inst.result_type.map(Operand::IdRef);
        for operand in result_type.iter().chain(&inst.operands) {
            eprint!(" {operand}");
        }
        eprintln!();

        std::process::exit(1);
    }
}
impl<'a, S: Specialization> InferCx<'a, S> {
/// Follows `SameAs` links from `v` to the representative variable (one whose
/// value is `Unknown` or `Known`), compressing the path along the way so
/// every visited variable points directly at the representative.
fn resolve_infer_var(&mut self, v: InferVar) -> InferVar {
    // First pass: find the representative.
    let mut root = v;
    while let Value::SameAs(next) = self.infer_var_values[root.0 as usize] {
        root = next;
    }

    // Second pass: repoint every link that doesn't already target the root.
    let mut current = v;
    while let Value::SameAs(next) = self.infer_var_values[current.0 as usize] {
        if next != root {
            self.infer_var_values[current.0 as usize] = Value::SameAs(root);
        }
        current = next;
    }

    root
}
/// Unifies two inference variables and returns the surviving representative
/// (always the lower-numbered one).
///
/// The higher-numbered representative is made `SameAs` the lower one. If only
/// the newer one had a `Known` value, that value moves to the older one.
/// Different `Known` values are a conflict.
fn equate_infer_vars(&mut self, a: InferVar, b: InferVar) -> Result<InferVar, InferError> {
    let (a, b) = (self.resolve_infer_var(a), self.resolve_infer_var(b));
    if a == b {
        return Ok(a);
    }

    let (older, newer) = if a < b { (a, b) } else { (b, a) };
    let newer_value = mem::replace(
        &mut self.infer_var_values[newer.0 as usize],
        Value::SameAs(older),
    );
    let older_slot = &mut self.infer_var_values[older.0 as usize];
    match (*older_slot, newer_value) {
        // Both were just resolved, so neither can still be a link.
        (Value::SameAs(_), _) | (_, Value::SameAs(_)) => unreachable!(),
        (Value::Known(x), Value::Known(y)) if x != y => {
            return Err(InferError::Conflict(
                InferOperand::Concrete(x),
                InferOperand::Concrete(y),
            ));
        }
        (Value::Unknown, Value::Known(_)) => *older_slot = newer_value,
        (Value::Known(_), Value::Known(_)) | (_, Value::Unknown) => {}
    }

    Ok(older)
}
/// Unifies two equally-long ranges of inference variables pairwise, and
/// returns whichever range starts earlier (`b` on ties).
///
/// Panics if the ranges have different lengths.
fn equate_infer_var_ranges(
    &mut self,
    a: Range<InferVar>,
    b: Range<InferVar>,
) -> Result<Range<InferVar>, InferError> {
    if a == b {
        return Ok(a);
    }

    let (a_len, b_len) = (a.end.0 - a.start.0, b.end.0 - b.start.0);
    assert_eq!(a_len, b_len);

    InferVar::range_iter(&a)
        .zip(InferVar::range_iter(&b))
        .try_for_each(|(x, y)| self.equate_infer_vars(x, y).map(|_| ()))?;

    Ok(if b.start <= a.start { b } else { a })
}
/// Unifies two inference operands, returning the unified operand.
///
/// Rules:
/// - Instances must refer to the same generic; their argument ranges are
///   unified.
/// - An instance never unifies with a non-instance.
/// - Variables are unified, and a variable paired with a concrete operand
///   takes that value (or conflicts with a different known one).
/// - Two different concrete operands conflict.
/// - `Unknown` unifies with anything, yielding the other side.
fn equate_infer_operands(
&mut self,
a: InferOperand,
b: InferOperand,
) -> Result<InferOperand, InferError> {
if a == b {
return Ok(a);
}
// Several arms intentionally share the same `Conflict` body; they are kept
// separate so each case is spelled out.
#[allow(clippy::match_same_arms)]
Ok(match (a.clone(), b.clone()) {
(
InferOperand::Instance(Instance {
generic_id: a_id,
generic_args: a_args,
}),
InferOperand::Instance(Instance {
generic_id: b_id,
generic_args: b_args,
}),
) => {
if a_id != b_id {
return Err(InferError::Conflict(a, b));
}
InferOperand::Instance(Instance {
generic_id: a_id,
generic_args: self.equate_infer_var_ranges(a_args, b_args)?,
})
}
(InferOperand::Instance(_), _) | (_, InferOperand::Instance(_)) => {
return Err(InferError::Conflict(a, b));
}
(InferOperand::Var(a), InferOperand::Var(b)) => {
InferOperand::Var(self.equate_infer_vars(a, b)?)
}
(InferOperand::Var(v), InferOperand::Concrete(new))
| (InferOperand::Concrete(new), InferOperand::Var(v)) => {
// Assign/check the value on the representative variable.
let v = self.resolve_infer_var(v);
match &mut self.infer_var_values[v.0 as usize] {
Value::SameAs(_) => unreachable!(),
&mut Value::Known(old) => {
if new != old {
return Err(InferError::Conflict(
InferOperand::Concrete(old),
InferOperand::Concrete(new),
));
}
}
value @ Value::Unknown => *value = Value::Known(new),
}
InferOperand::Var(v)
}
// Equal concrete operands were already handled by the `a == b` check.
(InferOperand::Concrete(_), InferOperand::Concrete(_)) => {
return Err(InferError::Conflict(a, b));
}
(InferOperand::Unknown, x) | (x, InferOperand::Unknown) => x,
})
}
/// Computes the type reached by stepping into `composite_ty` once per entry
/// of `indices`.
///
/// Only instances of "generic" composite types can be inspected. As soon as
/// the current type is `Unknown`, `Concrete` or a `Var`, the result is
/// `InferOperand::Unknown`.
///
/// Array, runtime-array, vector and matrix types always step into their first
/// operand, whatever the index. Struct types step into the member selected by
/// the index, which must be a literal or an id present in the specializer's
/// `int_consts`.
///
/// Panics (`unreachable!`) on a non-constant or malformed struct index, on a
/// non-composite type, or on an out-of-range member index.
fn index_composite(&self, composite_ty: InferOperand, indices: &[Operand]) -> InferOperand {
    let mut current = composite_ty;
    for index in indices {
        let instance = if let InferOperand::Instance(instance) = current {
            instance
        } else {
            return InferOperand::Unknown;
        };

        let generic = &self.specializer.generics[&instance.generic_id];
        let opcode = generic.def.class.opcode;
        let operand_pos = match opcode {
            Op::TypeStruct => match index {
                &Operand::LiteralInt32(field) => field,
                Operand::IdRef(id) => *self.specializer.int_consts.get(id).unwrap_or_else(|| {
                    unreachable!("non-constant `OpTypeStruct` field index {}", id)
                }),
                _ => unreachable!("invalid `OpTypeStruct` field index operand {:?}", index),
            },
            Op::TypeArray | Op::TypeRuntimeArray | Op::TypeVector | Op::TypeMatrix => 0,
            _ => unreachable!("indexing non-composite type `Op{:?}`", opcode),
        };

        let type_operands = InferOperandList {
            operands: &generic.def.operands,
            all_generic_args: instance.generic_args,
            transform: None,
        };
        current = type_operands
            .iter(self)
            .nth(operand_pos as usize)
            .unwrap_or_else(|| {
                unreachable!("out of bounds index {} for `Op{:?}`", operand_pos, opcode)
            });
    }
    current
}
/// Applies the equalities collected by a successful signature `Match`.
///
/// - Each group in `storage_class_var_found` is unified into a single operand.
/// - Each group in `ty_var_found` is unified. Then every `IndexCompositeMatch`
///   recorded under the same key is checked: the unified type is indexed with
///   `index_composite` and the result is unified with the match's `leaf`.
/// - Each group in `ty_list_var_found` is walked in lock-step, unifying the
///   operands found at the same position across all of its lists.
///
/// The `ambiguous` flag is ignored here. The first `InferError` is returned
/// immediately; unifications already applied are not rolled back. Panics
/// (`expect`/`assert!`) if the lists in one group differ in length.
fn equate_match_findings(&mut self, m: Match<'_>) -> Result<(), InferError> {
let Match {
ambiguous: _,
storage_class_var_found,
ty_var_found,
index_composite_ty_var_found,
ty_list_var_found,
} = m;
// Fold each storage-class group into one operand; the result itself is discarded.
for (_, found) in storage_class_var_found {
let mut found = found.into_iter();
if let Some(first) = found.next() {
found.try_fold(first, |a, b| self.equate_infer_operands(a, b))?;
}
}
// NOTE(review): `IndexComposite` constraints are only looked up for keys that
// appear in `ty_var_found` with at least one entry. Confirm none can exist without one.
for (i, found) in ty_var_found {
let mut found = found.into_iter();
if let Some(first) = found.next() {
let equated_ty = found.try_fold(first, |a, b| self.equate_infer_operands(a, b))?;
let index_composite_found = index_composite_ty_var_found
.get(i)
.map_or(&[][..], |xs| &xs[..]);
for IndexCompositeMatch { indices, leaf } in index_composite_found {
let indexing_result_ty = self.index_composite(equated_ty.clone(), indices);
self.equate_infer_operands(indexing_result_ty, leaf.clone())?;
}
}
}
for (_, mut found) in ty_list_var_found {
if let Some((first_list, other_lists)) = found.split_first_mut() {
// Pop one element off every list per iteration and unify that column.
while let Some((first, rest)) = first_list.split_first(self) {
*first_list = rest;
other_lists.iter_mut().try_fold(first, |a, b_list| {
let (b, rest) = b_list
.split_first(self)
.expect("list length mismatch (invalid SPIR-V?)");
*b_list = rest;
self.equate_infer_operands(a, b)
})?;
}
// The first list is exhausted, so every other list must be exhausted too.
for other_list in other_lists {
assert!(
other_list.split_first(self).is_none(),
"list length mismatch (invalid SPIR-V?)"
);
}
}
}
Ok(())
}
/// Records the operand at `loc` so `into_replacements` can later produce its
/// replacement.
///
/// `Instance` operands go to `instantiated_operands` and `Var` operands go to
/// `inferred_operands`. `Concrete` and `Unknown` operands need no replacement
/// and are dropped.
fn record_instantiated_operand(&mut self, loc: OperandLocation, operand: InferOperand) {
    match operand {
        InferOperand::Instance(instance) => self.instantiated_operands.push((loc, instance)),
        InferOperand::Var(var) => self.inferred_operands.push((loc, var)),
        InferOperand::Concrete(_) | InferOperand::Unknown => (),
    }
}
/// Instantiates one instruction for inference.
///
/// Steps:
/// - allocates consecutive fresh inference variables for the generic parameters
///   of its result type (except for `OpFunction`) and of each input operand;
/// - records every operand location that will need a replacement;
/// - if `spirv_type_constraints` has signatures for the opcode, matches them and
///   unifies what they imply;
/// - records the inferred type of the result id in `type_of_result`.
///
/// Unification errors are passed to `InferError::report` with this instruction
/// instead of being returned.
fn instantiate_instruction(&mut self, inst: &'a Instruction, inst_loc: InstructionLocation) {
// Empty range at the first unallocated variable; `.end` grows as operands
// claim variables below.
let mut all_generic_args = {
let next_infer_var = InferVar(self.infer_var_values.len().try_into().unwrap());
next_infer_var..next_infer_var
};
// For `OpFunction`, the result (return) type is not instantiated in the operand
// loop. It is recorded separately from `record_fn_ret_ty`, and the type of the
// result id comes from the function-type operand (`operands[1]`).
let (instantiate_result_type, record_fn_ret_ty, type_of_result) = match inst.class.opcode {
Op::Function => (
None,
inst.result_type,
Some(inst.operands[1].unwrap_id_ref()),
),
_ => (inst.result_type, None, inst.result_type),
};
// Visit the optional result type, then every input operand, in order.
for (operand_idx, operand) in instantiate_result_type
.map(Operand::IdRef)
.iter()
.map(|o| (OperandIdx::ResultType, o))
.chain(
inst.operands
.iter()
.enumerate()
.map(|(i, o)| (OperandIdx::Input(i), o)),
)
{
// Offer an open-ended range; the operand takes only the variables it needs,
// and `rest.start` marks where it stopped.
let (operand, rest) = InferOperand::from_operand_and_generic_args(
operand,
all_generic_args.end..InferVar(u32::MAX),
self,
);
let generic_args = all_generic_args.end..rest.start;
all_generic_args.end = generic_args.end;
let generic = match &operand {
InferOperand::Instance(instance) => {
Some(&self.specializer.generics[&instance.generic_id])
}
_ => None,
};
// Allocate the claimed variables. If the generic has known param values,
// seed them (with `Param(p)` remapped onto this operand's variables);
// otherwise every new variable starts `Unknown`.
match generic {
Some(Generic {
param_values: Some(values),
..
}) => self.infer_var_values.extend(
values
.iter()
.map(|v| v.map_var(|Param(p)| InferVar(generic_args.start.0 + p))),
),
_ => {
self.infer_var_values
.extend(InferVar::range_iter(&generic_args).map(|_| Value::Unknown));
}
}
self.record_instantiated_operand(
OperandLocation {
inst_loc,
operand_idx,
},
operand,
);
}
// `OpFunction` only: record the return type, instantiated from the start of the
// variables allocated above.
// NOTE(review): this assumes the function-type operand's generic args begin with
// the return type's, and that the preceding operand claims none. Confirm
// against `params_needed_by`.
if let Some(ret_ty) = record_fn_ret_ty {
let (ret_ty, _) = InferOperand::from_operand_and_generic_args(
&Operand::IdRef(ret_ty),
all_generic_args.clone(),
self,
);
self.record_instantiated_operand(
OperandLocation {
inst_loc,
operand_idx: OperandIdx::ResultType,
},
ret_ty,
);
}
// Re-derive the result's type from the start of `all_generic_args`. For other
// instructions the result type was instantiated first, so `rest` is exactly the
// inputs' variables. For `OpFunction` every variable belongs to the inputs.
let (type_of_result, inputs_generic_args) = match type_of_result {
Some(type_of_result) => {
let (type_of_result, rest) = InferOperand::from_operand_and_generic_args(
&Operand::IdRef(type_of_result),
all_generic_args.clone(),
self,
);
(
Some(type_of_result),
match inst.class.opcode {
Op::Function => all_generic_args,
_ => rest,
},
)
}
None => (None, all_generic_args),
};
// Debug-only: prints the instruction, rendering operands through the current
// inference state.
let debug_dump_if_enabled = |cx: &Self, prefix| {
if cx.specializer.debug {
let result_type = match inst.class.opcode {
Op::Function => Some(
InferOperand::from_operand_and_generic_args(
&Operand::IdRef(inst.result_type.unwrap()),
inputs_generic_args.clone(),
cx,
)
.0,
),
_ => type_of_result.clone(),
};
let inputs = InferOperandList {
operands: &inst.operands,
all_generic_args: inputs_generic_args.clone(),
transform: None,
};
if inst_loc != InstructionLocation::Module {
eprint!(" ");
}
eprint!("{prefix}");
if let Some(result_id) = inst.result_id {
eprint!("%{result_id} = ");
}
eprint!("Op{:?}", inst.class.opcode);
for operand in result_type.into_iter().chain(inputs.iter(cx)) {
eprint!(" {}", operand.display_with_infer_cx(cx));
}
eprintln!();
}
};
// Apply the opcode's type constraints, if any are known.
if let Some(sigs) = spirv_type_constraints::instruction_signatures(inst.class.opcode) {
assert_ne!(inst.class.opcode, Op::Function);
debug_dump_if_enabled(self, " -> ");
let m = match self.match_inst_sigs(
sigs,
inst,
inputs_generic_args.clone(),
type_of_result.clone(),
) {
Ok(m) => m,
Err(Unapplicable) => unreachable!(
"spirv_type_constraints(Op{:?}) = `{:?}` doesn't match `{:?}`",
inst.class.opcode, sigs, inst
),
};
if self.specializer.debug {
if inst_loc != InstructionLocation::Module {
eprint!(" ");
}
eprintln!(" found {:?}", m.debug_with_infer_cx(self));
}
if let Err(e) = self.equate_match_findings(m) {
e.report(inst);
}
debug_dump_if_enabled(self, " <- ");
} else {
debug_dump_if_enabled(self, "");
}
// Only `Var`/`Instance` result types carry inference state worth remembering.
if let Some(type_of_result) = type_of_result {
match type_of_result {
InferOperand::Var(_) | InferOperand::Instance(_) => {
self.type_of_result
.insert(inst.result_id.unwrap(), type_of_result);
}
InferOperand::Unknown | InferOperand::Concrete(_) => {}
}
}
}
/// Runs inference over a whole function.
///
/// Steps:
/// - instantiates the `OpFunction` definition;
/// - if the function's type is an instance of a generic `OpTypeFunction`, types
///   each `OpFunctionParameter` from it;
/// - instantiates every body instruction except `OpReturn`/`OpReturnValue`;
/// - unifies each `OpReturnValue` operand's recorded type with the return type.
///
/// Must start from a context with no inference variables (asserted). Panics if
/// the function type's parameter count differs from `func.parameters`.
fn instantiate_function(&mut self, func: &'a Function) {
let func_id = func.def_id().unwrap();
if self.specializer.debug {
eprintln!();
eprint!("specializer::instantiate_function(%{func_id}");
if let Some(name) = self.specializer.debug_names.get(&func_id) {
eprint!(" {name}");
}
eprintln!("):");
}
assert!(self.infer_var_values.is_empty());
self.instantiate_instruction(func.def.as_ref().unwrap(), InstructionLocation::Module);
if self.specializer.debug {
eprintln!("infer body {{");
}
// Only a generic function type yields parameter types and a return type to
// check against. Otherwise parameters stay untyped here and `ret_ty` is `None`.
let ret_ty = match self.type_of_result.get(&func_id).cloned() {
Some(InferOperand::Instance(instance)) => {
let generic = &self.specializer.generics[&instance.generic_id];
assert_eq!(generic.def.class.opcode, Op::TypeFunction);
// `OpTypeFunction` operands: return type, then one type per parameter.
let (ret_ty, mut params_ty_list) = InferOperandList {
operands: &generic.def.operands,
all_generic_args: instance.generic_args,
transform: None,
}
.split_first(self)
.unwrap();
// Pair parameter types with `OpFunctionParameter`s in order.
let mut params = func.parameters.iter().enumerate();
while let Some((param_ty, rest)) = params_ty_list.split_first(self) {
params_ty_list = rest;
let (i, param) = params.next().unwrap();
assert_eq!(param.class.opcode, Op::FunctionParameter);
if self.specializer.debug {
eprintln!(
" %{} = Op{:?} {}",
param.result_id.unwrap(),
param.class.opcode,
param_ty.display_with_infer_cx(self)
);
}
self.record_instantiated_operand(
OperandLocation {
inst_loc: InstructionLocation::FnParam(i),
operand_idx: OperandIdx::ResultType,
},
param_ty.clone(),
);
match param_ty {
InferOperand::Var(_) | InferOperand::Instance(_) => {
self.type_of_result
.insert(param.result_id.unwrap(), param_ty);
}
InferOperand::Unknown | InferOperand::Concrete(_) => {}
}
}
// Every declared parameter must have been consumed by the type list.
assert_eq!(params.next(), None);
Some(ret_ty)
}
_ => None,
};
// Block labels are not visited; only each block's instructions are.
for (block_idx, block) in func.blocks.iter().enumerate() {
for (inst_idx, inst) in block.instructions.iter().enumerate() {
match inst.class.opcode {
// Checked against the return type when both sides have a recorded type;
// the instruction itself is not instantiated.
Op::ReturnValue => {
let ret_val_id = inst.operands[0].unwrap_id_ref();
if let (Some(expected), Some(found)) = (
ret_ty.clone(),
self.type_of_result.get(&ret_val_id).cloned(),
) {
if let Err(e) = self.equate_infer_operands(expected, found) {
e.report(inst);
}
}
}
Op::Return => {}
_ => self.instantiate_instruction(
inst,
InstructionLocation::FnBody {
block_idx,
inst_idx,
},
),
}
}
}
if self.specializer.debug {
eprint!("}}");
if let Some(func_ty) = self.type_of_result.get(&func_id) {
eprint!(" -> %{}: {}", func_id, func_ty.display_with_infer_cx(self));
}
eprintln!();
}
}
/// Resolves `v` to its final value once inference is finished.
///
/// The result is chosen as follows:
/// - a root with a known operand gives that operand;
/// - an unknown root whose index is below `generic_params.end` becomes the
///   `Param` with the same index;
/// - any other unknown root gets the specialization's `concrete_fallback()`,
///   converted to a `CopyOperand` (panics if the conversion fails).
fn resolve_infer_var_to_concrete_or_param(
    &mut self,
    v: InferVar,
    generic_params: RangeTo<Param>,
) -> ConcreteOrParam {
    let root = self.resolve_infer_var(v);
    let idx = root.0;
    match self.infer_var_values[idx as usize] {
        Value::Known(operand) => ConcreteOrParam::Concrete(operand),
        Value::Unknown if idx < generic_params.end.0 => ConcreteOrParam::Param(Param(idx)),
        Value::Unknown => {
            let fallback = self.specializer.specialization.concrete_fallback();
            ConcreteOrParam::Concrete(CopyOperand::try_from(&fallback).unwrap())
        }
        // `root` was just resolved, so it cannot be a redirect.
        Value::SameAs(_) => unreachable!(),
    }
}
/// Consumes the inference context and turns every recorded operand location
/// into its final replacement.
///
/// Recorded instances are keyed by their resolved form (same generic id, each
/// argument resolved via `resolve_infer_var_to_concrete_or_param`). All
/// locations sharing a key are grouped together, in first-seen order. Each
/// recorded variable location is paired with its resolved value.
fn into_replacements(mut self, generic_params: RangeTo<Param>) -> Replacements {
    let recorded_instances = mem::take(&mut self.instantiated_operands);
    let recorded_vars = mem::take(&mut self.inferred_operands);

    let mut with_instance: IndexMap<_, Vec<_>> = IndexMap::new();
    for (loc, Instance { generic_id, generic_args }) in recorded_instances {
        let key = Instance {
            generic_id,
            generic_args: InferVar::range_iter(&generic_args)
                .map(|v| self.resolve_infer_var_to_concrete_or_param(v, generic_params))
                .collect(),
        };
        with_instance.entry(key).or_default().push(loc);
    }

    let with_concrete_or_param = recorded_vars
        .into_iter()
        .map(|(loc, v)| (loc, self.resolve_infer_var_to_concrete_or_param(v, generic_params)))
        .collect();

    Replacements {
        with_instance,
        with_concrete_or_param,
    }
}
}
/// Materializes concrete instances of the specializer's "generic" definitions.
///
/// Each instance (generic id plus concrete generic arguments) is mapped to the
/// fresh result id allocated for it. Newly allocated instances wait in a queue
/// until the instances they depend on have been allocated as well.
struct Expander<'a, S: Specialization> {
    /// Builder wrapping the module being expanded; also the source of fresh ids.
    builder: Builder,
    /// Every instance allocated so far, with the result id assigned to it.
    instances: BTreeMap<Instance<SmallVec<[CopyOperand; 4]>>, Word>,
    /// Allocated instances whose own nested instances are not yet allocated.
    propagate_instances_queue: VecDeque<Instance<SmallVec<[CopyOperand; 4]>>>,
    /// Specializer whose generics and replacements drive the expansion.
    specializer: &'a Specializer<S>,
}
impl<'a, S: Specialization> Expander<'a, S> {
/// Creates an expander over `module` with no instances allocated yet.
fn new(specializer: &'a Specializer<S>, module: Module) -> Self {
Expander {
specializer,
builder: Builder::new_from_module(module),
instances: BTreeMap::new(),
propagate_instances_queue: VecDeque::new(),
}
}
/// Returns every allocated instance of `generic_id`, with its id.
///
/// The range runs from the empty-argument instance of `generic_id` (inclusive)
/// to that of `generic_id + 1` (exclusive).
/// NOTE(review): this relies on `Instance`'s ordering comparing `generic_id`
/// first, and on empty argument lists sorting first. Confirm against its `Ord`.
fn all_instances_of(
&self,
generic_id: Word,
) -> std::collections::btree_map::Range<'_, Instance<SmallVec<[CopyOperand; 4]>>, Word> {
let first_instance_of = |generic_id| Instance {
generic_id,
generic_args: SmallVec::new(),
};
self.instances
.range(first_instance_of(generic_id)..first_instance_of(generic_id + 1))
}
/// Returns the id of `instance`.
///
/// If the instance is new, a fresh id is allocated from the builder and the
/// instance is queued for `propagate_instances`.
fn alloc_instance_id(&mut self, instance: Instance<SmallVec<[CopyOperand; 4]>>) -> Word {
use std::collections::btree_map::Entry;
match self.instances.entry(instance) {
Entry::Occupied(entry) => *entry.get(),
Entry::Vacant(entry) => {
let instance = entry.key().clone();
self.propagate_instances_queue.push_back(instance);
*entry.insert(self.builder.id())
}
}
}
/// Drains the queue, allocating ids for every instance transitively referenced
/// by a queued instance's replacements. Entries are popped from the back.
fn propagate_instances(&mut self) {
while let Some(instance) = self.propagate_instances_queue.pop_back() {
// The iterator is consumed only for the side effect of its
// `alloc_instance_id` callback; the yielded replacements are unused here.
for _ in self.specializer.generics[&instance.generic_id]
.replacements
.to_concrete(&instance.generic_args, |i| self.alloc_instance_id(i))
{}
}
}
/// Builds the final module, with every generic replaced by its instances.
///
/// Steps:
/// - entry-point interface operands are retargeted to the single instance of
///   each generic interface variable (panics on zero or several instances);
/// - debug names and annotations targeting a generic are duplicated per
///   instance;
/// - generic types/globals and generic functions are emitted once per instance,
///   with the generic's replacements applied.
fn expand_module(mut self) -> Module {
self.propagate_instances();
let module = self.builder.module_mut();
let mut entry_points = mem::take(&mut module.entry_points);
let debug_names = mem::take(&mut module.debug_names);
let annotations = mem::take(&mut module.annotations);
let types_global_values = mem::take(&mut module.types_global_values);
let functions = mem::take(&mut module.functions);
for inst in &mut entry_points {
// `operands[1]` is the entry-point function, `operands[3..]` its interface ids.
let func_id = inst.operands[1].unwrap_id_ref();
assert!(
!self.specializer.generics.contains_key(&func_id),
"entry-point %{func_id} shouldn't be \"generic\""
);
for interface_operand in &mut inst.operands[3..] {
let interface_id = interface_operand.unwrap_id_ref();
let mut instances = self.all_instances_of(interface_id);
// Exactly one instance is required to pick a replacement id.
match (instances.next(), instances.next()) {
(None, _) => unreachable!(
"entry-point %{} has overly-\"generic\" \
interface variable %{}, with no instances",
func_id, interface_id
),
(Some(_), Some(_)) => unreachable!(
"entry-point %{} has overly-\"generic\" \
interface variable %{}, with too many instances: {:?}",
func_id,
interface_id,
FmtBy(|f| f
.debug_list()
.entries(self.all_instances_of(interface_id).map(
|(instance, _)| FmtBy(move |f| write!(
f,
"{}",
instance.display(|generic_args| generic_args.iter().copied())
))
))
.finish())
),
(Some((_, &instance_id)), None) => {
*interface_operand = Operand::IdRef(instance_id);
}
}
}
}
// Replaces each instruction whose first operand targets a generic with one copy
// per instance, retargeted to the instance id. The original is not kept, so a
// generic with no instances loses these instructions. Everything else passes
// through unchanged.
let expand_debug_or_annotation = |insts: Vec<Instruction>| {
let mut expanded_insts = Vec::with_capacity(insts.len().next_power_of_two());
for inst in insts {
if let [Operand::IdRef(target), ..] = inst.operands[..] {
if self.specializer.generics.contains_key(&target) {
expanded_insts.extend(self.all_instances_of(target).map(
|(_, &instance_id)| {
let mut expanded_inst = inst.clone();
expanded_inst.operands[0] = Operand::IdRef(instance_id);
expanded_inst
},
));
continue;
}
}
expanded_insts.push(inst);
}
expanded_insts
};
let expanded_debug_names = expand_debug_or_annotation(debug_names);
let mut expanded_annotations = expand_debug_or_annotation(annotations);
// Each generic type/global definition becomes one copy per instance, with the
// instance id as result id and the generic's replacements applied.
let mut expanded_types_global_values =
Vec::with_capacity(types_global_values.len().next_power_of_two());
for inst in types_global_values {
if let Some(result_id) = inst.result_id {
if let Some(generic) = self.specializer.generics.get(&result_id) {
expanded_types_global_values.extend(self.all_instances_of(result_id).map(
|(instance, &instance_id)| {
let mut expanded_inst = inst.clone();
expanded_inst.result_id = Some(instance_id);
for (loc, operand) in generic
.replacements
.to_concrete(&instance.generic_args, |i| self.instances[&i])
{
expanded_inst.index_set(loc, operand.into());
}
expanded_inst
},
));
continue;
}
}
expanded_types_global_values.push(inst);
}
// Generic functions are expanded the same way, one copy per instance.
let mut expanded_functions = Vec::with_capacity(functions.len().next_power_of_two());
for func in functions {
let func_id = func.def_id().unwrap();
if let Some(generic) = self.specializer.generics.get(&func_id) {
let old_expanded_functions_len = expanded_functions.len();
expanded_functions.extend(self.all_instances_of(func_id).map(
|(instance, &instance_id)| {
let mut expanded_func = func.clone();
expanded_func.def.as_mut().unwrap().result_id = Some(instance_id);
for (loc, operand) in generic
.replacements
.to_concrete(&instance.generic_args, |i| self.instances[&i])
{
expanded_func.index_set(loc, operand.into());
}
expanded_func
},
));
let newly_expanded_functions =
&mut expanded_functions[old_expanded_functions_len..];
// Several copies of one body would otherwise share parameter, label and
// instruction ids. Give every copy fresh ids.
if newly_expanded_functions.len() > 1 {
// Allocated once and cleared per copy, to reuse its storage.
let mut rewrite_rules = FxHashMap::default();
for func in newly_expanded_functions {
rewrite_rules.clear();
rewrite_rules.extend(func.parameters.iter_mut().map(|param| {
let old_id = param.result_id.unwrap();
let new_id = self.builder.id();
param.result_id = Some(new_id);
(old_id, new_id)
}));
rewrite_rules.extend(
func.blocks
.iter()
.flat_map(|b| b.label.iter().chain(b.instructions.iter()))
.filter_map(|inst| inst.result_id)
.map(|old_id| (old_id, self.builder.id())),
);
super::apply_rewrite_rules(&rewrite_rules, &mut func.blocks);
// Duplicate annotations on any renamed id for this copy. The bound is
// taken before pushing, so copies added here are not rescanned.
// NOTE(review): the annotations on the old ids are kept, and debug names
// are not duplicated. Confirm a later pass drops entries for ids no copy
// still defines.
for annotation_idx in 0..expanded_annotations.len() {
let inst = &expanded_annotations[annotation_idx];
if let [Operand::IdRef(target), ..] = inst.operands[..] {
if let Some(&rewritten_target) = rewrite_rules.get(&target) {
let mut expanded_inst = inst.clone();
expanded_inst.operands[0] = Operand::IdRef(rewritten_target);
expanded_annotations.push(expanded_inst);
}
}
}
}
}
continue;
}
expanded_functions.push(func);
}
// Expansion only looked up already-allocated instances (`self.instances[&i]`),
// so nothing new should have been queued.
assert!(self.propagate_instances_queue.is_empty());
let module = self.builder.module_mut();
module.entry_points = entry_points;
module.debug_names = expanded_debug_names;
module.annotations = expanded_annotations;
module.types_global_values = expanded_types_global_values;
module.functions = expanded_functions;
self.builder.module()
}
/// Writes a human-readable listing of every generic and its instances to `w`.
///
/// Each generic entry contains:
/// - its debug name, if any;
/// - its definition written in terms of `Param`s;
/// - any known param values (the `where` line);
/// - every allocated instance with its id.
fn dump_instances(&self, w: &mut impl io::Write) -> io::Result<()> {
writeln!(w, "; All specializer \"generic\"s and their instances:")?;
writeln!(w)?;
for (&generic_id, generic) in &self.specializer.generics {
if let Some(name) = self.specializer.debug_names.get(&generic_id) {
writeln!(w, "; {name}")?;
}
write!(
w,
"{} = Op{:?}",
Instance {
generic_id,
generic_args: Param(0)..Param(generic.param_count)
}
.display(Param::range_iter),
generic.def.class.opcode
)?;
// Print the result type and operands, each annotated with the params it uses.
let mut next_param = Param(0);
for operand in generic
.def
.result_type
.map(Operand::IdRef)
.iter()
.chain(generic.def.operands.iter())
{
write!(w, " ")?;
let (needed, used_generic) = self.specializer.params_needed_by(operand);
let params = next_param..Param(next_param.0 + needed);
// For `OpFunction`, `next_param` is never advanced, so every operand is
// shown starting from `Param(0)`.
if generic.def.class.opcode != Op::Function {
next_param = params.end;
}
if used_generic.is_some() {
write!(
w,
"{}",
Instance {
generic_id: operand.unwrap_id_ref(),
generic_args: params
}
.display(Param::range_iter)
)?;
} else if needed == 1 {
write!(w, "{}", params.start)?;
} else {
write!(w, "{operand}")?;
}
}
writeln!(w)?;
// Known param values, if any.
if let Some(param_values) = &generic.param_values {
write!(w, " where")?;
for (i, v) in param_values.iter().enumerate() {
let p = Param(i as u32);
match v {
Value::Unknown => {}
Value::Known(o) => write!(w, " {p} = {o},")?,
Value::SameAs(q) => write!(w, " {p} = {q},")?,
}
}
writeln!(w)?;
}
for (instance, instance_id) in self.all_instances_of(generic_id) {
assert_eq!(instance.generic_id, generic_id);
writeln!(
w,
" %{} = {}",
instance_id,
instance.display(|generic_args| generic_args.iter().copied())
)?;
}
writeln!(w)?;
}
Ok(())
}
}