use super::*;
/// A resource whose `Binding` / `DescriptorSet` decorations must be copied
/// onto one or more replacement ids by `decorate`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AffectedDecoration {
/// Target id of the existing `OpDecorate` instructions to copy from.
pub original_res_id: u32,
/// Ids that each receive a copy of the original's `Binding` and
/// `DescriptorSet` decorations.
pub new_res_ids: Vec<u32>,
/// Value recorded in the correction map for each of `new_res_ids`.
pub correction_type: CorrectionType,
}
/// Inputs to `decorate`.
pub struct DecorateIn<'a> {
/// The module being patched, as 32-bit words.
pub spv: &'a [u32],
/// Output: `decorate` appends the new `OpDecorate` instructions here.
pub instruction_inserts: &'a mut Vec<InstructionInsert>,
/// Used as `previous_spv_idx` for every inserted decoration; it is
/// unwrapped, so it must be `Some` whenever any decoration is emitted.
/// NOTE(review): presumably a typo for `first_op_decorate_idx`; left as-is
/// because it is a public field.
pub first_op_deocrate_idx: Option<usize>,
/// Word index of each `OpDecorate` instruction in `spv`; the target id,
/// decoration and first literal are read at offsets 1, 2 and 3.
pub op_decorate_idxs: &'a [usize],
/// Resources whose decorations are duplicated onto replacement ids.
pub affected_decorations: &'a [AffectedDecoration],
/// Correction map to update. If `None`, `decorate` first builds one from
/// every `Binding` / `DescriptorSet` decoration listed in `op_decorate_idxs`.
pub corrections: &'a mut Option<CorrectionMap>,
}
/// Result of `decorate`.
pub struct DecorateOut {
/// Descriptor-set numbers taken from the `DescriptorSet` decorations of all
/// affected resources.
pub descriptor_sets_to_correct: HashSet<u32>,
}
pub fn decorate(d_in: DecorateIn) -> DecorateOut {
let DecorateIn {
spv,
instruction_inserts,
first_op_deocrate_idx,
op_decorate_idxs,
affected_decorations: affected_variables,
corrections,
} = d_in;
let mut new_variable_id_to_decorations = HashMap::new();
let mut descriptor_sets_to_correct = HashSet::new();
let mut all_descriptor_sets = corrections.is_none().then_some(HashMap::new());
op_decorate_idxs.iter().for_each(|&d_idx| {
let target_id = spv[d_idx + 1];
let decoration_id = spv[d_idx + 2];
let decoration_value = spv[d_idx + 3];
if decoration_id == SPV_DECORATION_BINDING
&& let Some(all_descriptor_sets) = all_descriptor_sets.as_mut()
{
all_descriptor_sets
.entry(target_id)
.or_insert((None, None))
.0 = Some(decoration_value);
}
if decoration_id == SPV_DECORATION_DESCRIPTOR_SET
&& let Some(all_descriptor_sets) = all_descriptor_sets.as_mut()
{
all_descriptor_sets
.entry(target_id)
.or_insert((None, None))
.1 = Some(decoration_value);
}
affected_variables.iter().for_each(
|AffectedDecoration {
original_res_id,
new_res_ids,
correction_type,
}| {
if *original_res_id == target_id {
for new_res_id in new_res_ids {
if decoration_id == SPV_DECORATION_BINDING {
new_variable_id_to_decorations
.entry((new_res_id, correction_type))
.or_insert((None, None))
.0 = Some((d_idx, decoration_value));
} else if decoration_id == SPV_DECORATION_DESCRIPTOR_SET {
new_variable_id_to_decorations
.entry((new_res_id, correction_type))
.or_insert((None, None))
.1 = Some((d_idx, decoration_value));
descriptor_sets_to_correct.insert(decoration_value);
}
}
}
},
);
});
let mut new_variable_id_to_decorations = new_variable_id_to_decorations
.into_iter()
.collect::<Vec<_>>();
new_variable_id_to_decorations.sort_by_key(|(_, (maybe_binding, _))| {
let (_, binding) = maybe_binding.unwrap();
binding
});
let new_variable_id_to_decorations = new_variable_id_to_decorations
.into_iter()
.map(|(new_res_id, (maybe_binding, maybe_descriptor_set))| {
let (binding_idx, binding) = maybe_binding.unwrap();
let (descriptor_set_idx, descriptor_set) = maybe_descriptor_set.unwrap();
(
new_res_id,
((binding_idx, binding), (descriptor_set_idx, descriptor_set)),
)
})
.collect::<HashMap<_, _>>();
if let Some(all_descriptor_sets) = all_descriptor_sets {
let mut new_corrections = CorrectionMap::default();
let mut all_descriptor_sets = all_descriptor_sets.into_iter().collect::<Vec<_>>();
all_descriptor_sets.sort_by_key(|(_, (maybe_binding, _))| maybe_binding.unwrap());
for (_, (binding, set)) in all_descriptor_sets {
let set = set.unwrap();
let binding = binding.unwrap();
new_corrections
.sets
.entry(set)
.or_insert(CorrectionSet::default())
.bindings
.insert(
binding,
CorrectionBinding {
corrections: vec![],
},
);
}
*corrections = Some(new_corrections);
}
let old_corrections = corrections.clone();
new_variable_id_to_decorations.iter().for_each(
|(
(new_res_id, correction_type),
((_binding_idx, binding), (_descriptor_set_idx, descriptor_set)),
)| {
instruction_inserts.push(InstructionInsert {
previous_spv_idx: first_op_deocrate_idx.unwrap(),
instruction: vec![
encode_word(4, SPV_INSTRUCTION_OP_DECORATE),
**new_res_id,
SPV_DECORATION_DESCRIPTOR_SET,
*descriptor_set,
encode_word(4, SPV_INSTRUCTION_OP_DECORATE),
**new_res_id,
SPV_DECORATION_BINDING,
*binding,
],
});
if let Some(bindings) = corrections.as_mut().unwrap().sets.get_mut(descriptor_set) {
let mut input_bindings = old_corrections
.as_ref()
.unwrap()
.sets
.get(descriptor_set)
.unwrap()
.bindings
.iter()
.collect::<Vec<_>>();
input_bindings.sort_by_key(|(k, _)| **k);
let input_bindings = input_bindings
.iter()
.map(|(binding, correction)| (binding, correction.corrections.len()))
.collect::<Vec<_>>();
let mut my_binding = *binding as isize;
let mut last_binding = 0;
for &(binding, binding_count) in input_bindings.iter() {
my_binding -=
binding_count as isize + **binding as isize - last_binding as isize;
last_binding = **binding;
if my_binding <= 0 {
bindings
.bindings
.get_mut(binding)
.unwrap()
.corrections
.insert(my_binding.unsigned_abs(), **correction_type);
break;
}
}
}
},
);
DecorateOut {
descriptor_sets_to_correct,
}
}