spirv_linker/
lib.rs

1#[cfg(test)]
2mod test;
3
4use rspirv::binary::Consumer;
5use rspirv::binary::Disassemble;
6use rspirv::spirv;
7use std::collections::{HashMap, HashSet};
8use thiserror::Error;
9use topological_sort::TopologicalSort;
10
11#[derive(Error, Debug, PartialEq)]
12pub enum LinkerError {
13    #[error("Unresolved symbol {:?}", .0)]
14    UnresolvedSymbol(String),
15    #[error("Multiple exports found for {:?}", .0)]
16    MultipleExports(String),
17    #[error("Types mismatch for {:?}, imported with type {:?}, exported with type {:?}", .name, .import_type, .export_type)]
18    TypeMismatch {
19        name: String,
20        import_type: String,
21        export_type: String,
22    },
23    #[error("unknown data store error")]
24    Unknown,
25}
26
27pub type Result<T> = std::result::Result<T, LinkerError>;
28
29pub fn load(bytes: &[u8]) -> rspirv::dr::Module {
30    let mut loader = rspirv::dr::Loader::new();
31    rspirv::binary::parse_bytes(&bytes, &mut loader).unwrap();
32    let module = loader.module();
33    module
34}
35
36fn shift_ids(module: &mut rspirv::dr::Module, add: u32) {
37    module.all_inst_iter_mut().for_each(|inst| {
38        if let Some(ref mut result_id) = &mut inst.result_id {
39            *result_id += add;
40        }
41
42        if let Some(ref mut result_type) = &mut inst.result_type {
43            *result_type += add;
44        }
45
46        inst.operands.iter_mut().for_each(|op| match op {
47            rspirv::dr::Operand::IdMemorySemantics(w)
48            | rspirv::dr::Operand::IdScope(w)
49            | rspirv::dr::Operand::IdRef(w) => *w += add,
50            _ => {}
51        })
52    });
53}
54
55fn replace_all_uses_with(module: &mut rspirv::dr::Module, before: u32, after: u32) {
56    module.all_inst_iter_mut().for_each(|inst| {
57        if let Some(ref mut result_type) = &mut inst.result_type {
58            if *result_type == before {
59                *result_type = after;
60            }
61        }
62
63        inst.operands.iter_mut().for_each(|op| match op {
64            rspirv::dr::Operand::IdMemorySemantics(w)
65            | rspirv::dr::Operand::IdScope(w)
66            | rspirv::dr::Operand::IdRef(w) => {
67                if *w == before {
68                    *w = after
69                }
70            }
71            _ => {}
72        })
73    });
74}
75
76fn remove_duplicate_capablities(module: &mut rspirv::dr::Module) {
77    let mut set = HashSet::new();
78    let mut caps = vec![];
79
80    for c in &module.capabilities {
81        let keep = match c.operands[0] {
82            rspirv::dr::Operand::Capability(cap) => set.insert(cap),
83            _ => true,
84        };
85
86        if keep {
87            caps.push(c.clone());
88        }
89    }
90
91    module.capabilities = caps;
92}
93
94fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) {
95    let mut set = HashSet::new();
96    let mut caps = vec![];
97
98    for c in &module.ext_inst_imports {
99        let keep = match &c.operands[0] {
100            rspirv::dr::Operand::LiteralString(ext_inst_import) => set.insert(ext_inst_import),
101            _ => true,
102        };
103
104        if keep {
105            caps.push(c.clone());
106        }
107    }
108
109    module.ext_inst_imports = caps;
110}
111
112fn kill_with_id(insts: &mut Vec<rspirv::dr::Instruction>, id: u32) {
113    kill_with(insts, |inst| {
114        if inst.operands.is_empty() {
115            return false;
116        }
117
118        match inst.operands[0] {
119            rspirv::dr::Operand::IdMemorySemantics(w)
120            | rspirv::dr::Operand::IdScope(w)
121            | rspirv::dr::Operand::IdRef(w)
122                if w == id =>
123            {
124                true
125            }
126            _ => false,
127        }
128    })
129}
130
/// Removes every element of `insts` for which `f` returns `true`.
///
/// The remaining elements keep their relative order. Order matters in the SPIR-V sections this
/// is used on: for example, a decoration group must be declared before an `OpGroupDecorate`
/// that uses it. Removing with `swap_remove` could reorder such instructions.
fn kill_with<T, F>(insts: &mut Vec<T>, mut f: F)
where
    F: FnMut(&T) -> bool,
{
    insts.retain(|inst| !f(inst));
}
153
154fn kill_annotations_and_debug(module: &mut rspirv::dr::Module, id: u32) {
155    kill_with_id(&mut module.annotations, id);
156
157    // need to remove OpGroupDecorate members that mention this id
158    module.annotations.iter_mut().for_each(|inst| {
159        if inst.class.opcode == spirv::Op::GroupDecorate {
160            inst.operands.retain(|op| match op {
161                rspirv::dr::Operand::IdRef(w) if *w != id => return true,
162                _ => return false,
163            });
164        }
165    });
166
167    kill_with_id(&mut module.debug_string_source, id);
168    kill_with_id(&mut module.debug_names, id);
169    kill_with_id(&mut module.debug_module_processed, id);
170}
171
172fn remove_duplicate_types(module: rspirv::dr::Module) -> rspirv::dr::Module {
173    use rspirv::binary::Assemble;
174
175    // jb-todo: spirv-tools's linker has special case handling for SpvOpTypeForwardPointer,
176    // not sure if we need that; see https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp for reference
177
178    let mut instructions = module
179        .all_inst_iter()
180        .cloned()
181        .collect::<Vec<_>>()
182        .into_boxed_slice(); // force boxed slice so we don't accidentally grow or shrink it later
183
184    let mut def_use_analyzer = DefUseAnalyzer::new(&mut instructions);
185
186    let mut kill_annotations = vec![];
187    let mut continue_from_idx = 0;
188
189    // need to do this process iteratively because types can reference each other
190    loop {
191        let mut dedup = std::collections::HashMap::new();
192        let mut duplicate = None;
193
194        for (iterator_idx, module_inst) in module
195            .types_global_values
196            .iter()
197            .enumerate()
198            .skip(continue_from_idx)
199        {
200            let (inst_idx, inst) = def_use_analyzer.def(module_inst.result_id.unwrap());
201
202            if inst.class.opcode == spirv::Op::Nop {
203                continue;
204            }
205
206            // partially assemble only the opcode and operands to be used as a key
207            // maybe this should also include the result_type
208            let data = {
209                let mut data = vec![];
210
211                data.push(inst.class.opcode as u32);
212                for op in &inst.operands {
213                    op.assemble_into(&mut data);
214                }
215
216                data
217            };
218
219            // dedup contains a tuple of three indices;
220            // the first two point into our `def_use_analyzer.instructions` map
221            // the last one points into the `module.types_global_values` iterator so we can resume iteration
222            dedup
223                .entry(data)
224                .and_modify(|(identical_idx, backtrack_idx)| {
225                    duplicate = Some((inst_idx, *identical_idx, *backtrack_idx));
226                })
227                .or_insert((inst_idx, iterator_idx)); // store the index that we encountered an instruction
228                                                      // for the first time so we can backtrack later
229
230            if let Some((_, _, backtrack_idx)) = duplicate {
231                continue_from_idx = backtrack_idx;
232                break;
233            }
234        }
235
236        if let Some((before_idx, after_idx, _)) = duplicate {
237            let before_id = def_use_analyzer.instructions[before_idx].result_id.unwrap();
238            let after_id = def_use_analyzer.instructions[after_idx].result_id.unwrap();
239
240            // remove annotations later
241            kill_annotations.push(before_id);
242
243            def_use_analyzer.for_each_use(before_id, |inst| {
244                if inst.result_type == Some(before_id) {
245                    inst.result_type = Some(after_id);
246                }
247
248                for op in inst.operands.iter_mut() {
249                    match op {
250                        rspirv::dr::Operand::IdMemorySemantics(w)
251                        | rspirv::dr::Operand::IdScope(w)
252                        | rspirv::dr::Operand::IdRef(w) => {
253                            if *w == before_id {
254                                *w = after_id
255                            }
256                        }
257                        _ => {}
258                    }
259                }
260            });
261
262            // this loop / system works on the assumption that all indices remain valid,
263            // so instead of removing the instruction we just nop it out - `consume_instruction` will then
264            // skip all OpNops and they won't appear in the newly constructed module
265            def_use_analyzer.instructions[before_idx] =
266                rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
267        } else {
268            break;
269        }
270    }
271
272    let mut loader = rspirv::dr::Loader::new();
273
274    for inst in def_use_analyzer.instructions.iter() {
275        loader.consume_instruction(inst.clone());
276    }
277
278    let mut module = loader.module();
279
280    for remove in kill_annotations {
281        kill_annotations_and_debug(&mut module, remove);
282    }
283
284    module
285}
286
287#[derive(Clone, Debug)]
288struct LinkSymbol {
289    name: String,
290    id: u32,
291    type_id: u32,
292    parameters: Vec<rspirv::dr::Instruction>,
293}
294
295#[derive(Debug)]
296struct ImportExportPair {
297    import: LinkSymbol,
298    export: LinkSymbol,
299}
300
301#[derive(Debug)]
302struct LinkInfo {
303    imports: Vec<LinkSymbol>,
304    exports: HashMap<String, Vec<LinkSymbol>>,
305    potential_pairs: Vec<ImportExportPair>,
306}
307
308fn inst_fully_eq(a: &rspirv::dr::Instruction, b: &rspirv::dr::Instruction) -> bool {
309    // both function instructions need to be 100% identical so check all members
310    // jb-todo: derive(PartialEq) on Instruction?
311    a.result_id == b.result_id
312        && a.class == b.class
313        && a.result_type == b.result_type
314        && a.operands == b.operands
315}
316
317fn find_import_export_pairs(module: &rspirv::dr::Module, defs: &DefAnalyzer) -> Result<LinkInfo> {
318    let mut imports = vec![];
319    let mut exports: HashMap<String, Vec<LinkSymbol>> = HashMap::new();
320
321    for annotation in &module.annotations {
322        if annotation.class.opcode == spirv::Op::Decorate
323            && annotation.operands[1]
324                == rspirv::dr::Operand::Decoration(spirv::Decoration::LinkageAttributes)
325        {
326            let id = match annotation.operands[0] {
327                rspirv::dr::Operand::IdRef(i) => i,
328                _ => panic!("Expected IdRef"),
329            };
330
331            let name = match &annotation.operands[2] {
332                rspirv::dr::Operand::LiteralString(s) => s,
333                _ => panic!("Expected LiteralString"),
334            };
335
336            let ty = &annotation.operands[3];
337
338            let def_inst = defs
339                .def(id)
340                .expect(&format!("Need a matching op for ID {}", id));
341
342            let (type_id, parameters) = match def_inst.class.opcode {
343                spirv::Op::Variable => (def_inst.result_type.unwrap(), vec![]),
344                spirv::Op::Function => {
345                    let type_id = if let rspirv::dr::Operand::IdRef(id) = &def_inst.operands[1] {
346                        *id
347                    } else {
348                        panic!("Expected IdRef");
349                    };
350
351                    let def_fn = module
352                        .functions
353                        .iter()
354                        .find(|f| inst_fully_eq(f.def.as_ref().unwrap(), def_inst))
355                        .unwrap();
356
357                    (type_id, def_fn.parameters.clone())
358                }
359                _ => panic!("Unexpected op"),
360            };
361
362            let symbol = LinkSymbol {
363                name: name.to_string(),
364                id,
365                type_id,
366                parameters,
367            };
368
369            if ty == &rspirv::dr::Operand::LinkageType(spirv::LinkageType::Import) {
370                imports.push(symbol);
371            } else {
372                exports
373                    .entry(symbol.name.clone())
374                    .and_modify(|v| v.push(symbol.clone()))
375                    .or_insert_with(|| vec![symbol.clone()]);
376            }
377        }
378    }
379
380    LinkInfo {
381        imports,
382        exports,
383        potential_pairs: vec![],
384    }
385    .find_potential_pairs()
386}
387
388fn cleanup_type(mut ty: rspirv::dr::Instruction) -> String {
389    ty.result_id = None;
390    ty.disassemble()
391}
392
393impl LinkInfo {
394    fn find_potential_pairs(mut self) -> Result<Self> {
395        for import in &self.imports {
396            let potential_matching_exports = self.exports.get(&import.name);
397            if let Some(potential_matching_exports) = potential_matching_exports {
398                if potential_matching_exports.len() > 1 {
399                    return Err(LinkerError::MultipleExports(import.name.clone()));
400                }
401
402                self.potential_pairs.push(ImportExportPair {
403                    import: import.clone(),
404                    export: potential_matching_exports.first().unwrap().clone(),
405                });
406            } else {
407                return Err(LinkerError::UnresolvedSymbol(import.name.clone()));
408            }
409        }
410
411        Ok(self)
412    }
413
414    /// returns the list of matching import / export pairs after validation the list of potential pairs
415    fn ensure_matching_import_export_pairs(
416        &self,
417        defs: &DefAnalyzer,
418    ) -> Result<&Vec<ImportExportPair>> {
419        for pair in &self.potential_pairs {
420            let import_result_type = defs.def(pair.import.type_id).unwrap();
421            let export_result_type = defs.def(pair.export.type_id).unwrap();
422
423            let imp = trans_aggregate_type(defs, import_result_type);
424            let exp = trans_aggregate_type(defs, export_result_type);
425
426            if imp != exp {
427                return Err(LinkerError::TypeMismatch {
428                    name: pair.import.name.clone(),
429                    import_type: cleanup_type(import_result_type.clone()),
430                    export_type: cleanup_type(export_result_type.clone()),
431                });
432            }
433
434            for (import_param, export_param) in pair
435                .import
436                .parameters
437                .iter()
438                .zip(pair.export.parameters.iter())
439            {
440                if !import_param.is_type_identical(export_param) {
441                    panic!("Type error in signatures")
442                }
443
444                // jb-todo: validate that OpDecoration is identical too
445            }
446        }
447
448        Ok(&self.potential_pairs)
449    }
450}
451
452struct DefAnalyzer {
453    def_ids: HashMap<u32, rspirv::dr::Instruction>,
454}
455
456impl DefAnalyzer {
457    fn new(module: &rspirv::dr::Module) -> Self {
458        let mut def_ids = HashMap::new();
459
460        module.all_inst_iter().for_each(|inst| {
461            if let Some(def_id) = inst.result_id {
462                def_ids
463                    .entry(def_id)
464                    .and_modify(|stored_inst| {
465                        *stored_inst = inst.clone();
466                    })
467                    .or_insert(inst.clone());
468            }
469        });
470
471        Self { def_ids }
472    }
473
474    fn def(&self, id: u32) -> Option<&rspirv::dr::Instruction> {
475        self.def_ids.get(&id)
476    }
477}
478
479struct DefUseAnalyzer<'a> {
480    def_ids: HashMap<u32, usize>,
481    use_ids: HashMap<u32, Vec<usize>>,
482    use_result_type_ids: HashMap<u32, Vec<usize>>,
483    instructions: &'a mut [rspirv::dr::Instruction]
484}
485
486impl<'a> DefUseAnalyzer<'a> {
487    fn new(instructions: &'a mut [rspirv::dr::Instruction]) -> Self{
488        let mut def_ids = HashMap::new();
489        let mut use_ids: HashMap<u32, Vec<usize>> = HashMap::new();
490        let mut use_result_type_ids: HashMap<u32, Vec<usize>> = HashMap::new();
491
492        instructions
493            .iter()
494            .enumerate()
495            .for_each(|(inst_idx, inst)| {
496                if let Some(def_id) = inst.result_id {
497                    def_ids
498                        .entry(def_id)
499                        .and_modify(|stored_inst| {
500                            *stored_inst = inst_idx;
501                        })
502                        .or_insert(inst_idx);
503                }
504
505                if let Some(result_type) = inst.result_type {
506                    use_result_type_ids
507                        .entry(result_type)
508                        .and_modify(|v| v.push(inst_idx))
509                        .or_insert(vec![inst_idx]);
510                }
511
512                for op in inst.operands.iter() {
513                    match op {
514                        rspirv::dr::Operand::IdMemorySemantics(w)
515                        | rspirv::dr::Operand::IdScope(w)
516                        | rspirv::dr::Operand::IdRef(w) => {
517                            use_ids
518                                .entry(*w)
519                                .and_modify(|v| v.push(inst_idx))
520                                .or_insert(vec![inst_idx]);
521                        }
522                        _ => {}
523                    }
524                }
525            });
526
527        Self {
528            def_ids,
529            use_ids,
530            use_result_type_ids,
531            instructions
532        }
533    }
534
535    fn def_idx(&self, id: u32) -> usize {
536        self.def_ids[&id]
537    }
538
539    fn def(&self, id: u32) -> (usize, &rspirv::dr::Instruction) {
540        let idx = self.def_idx(id);
541        (idx, &self.instructions[idx])
542    }
543
544    fn for_each_use<F>(&mut self, id: u32, mut f: F) 
545    where F: FnMut(&mut rspirv::dr::Instruction) {
546        // find by `result_type`
547        if let Some(use_result_type_id) = self.use_result_type_ids.get(&id) {
548            for inst_idx in use_result_type_id {
549                f(&mut self.instructions[*inst_idx])
550            }
551        }
552
553        // find by operand
554        if let Some(use_id) = self.use_ids.get(&id) {
555            for inst_idx in use_id {
556                f(&mut self.instructions[*inst_idx]);
557            }
558        }
559    }
560}
561
562fn import_kill_annotations_and_debug(module: &mut rspirv::dr::Module, info: &LinkInfo) {
563    for import in &info.imports {
564        kill_annotations_and_debug(module, import.id);
565        for param in &import.parameters {
566            kill_annotations_and_debug(module, param.result_id.unwrap())
567        }
568    }
569}
570
/// Options controlling how `link` produces its output module.
///
/// The default has both flags set to `false`.
#[derive(Clone, Debug, Default)]
pub struct Options {
    /// `true` if we're creating a library
    pub lib: bool,

    /// `true` if partial linking is allowed
    pub partial: bool,
}
587
588fn kill_linkage_instructions(
589    pairs: &Vec<ImportExportPair>,
590    module: &mut rspirv::dr::Module,
591    opts: &Options,
592) {
593    // drop imported functions
594    for pair in pairs.iter() {
595        module
596            .functions
597            .retain(|f| pair.import.id != f.def.as_ref().unwrap().result_id.unwrap());
598    }
599
600    // drop imported variables
601    for pair in pairs.iter() {
602        module
603            .types_global_values
604            .retain(|v| pair.import.id != v.result_id.unwrap());
605    }
606
607    // drop linkage attributes (both import and export)
608    kill_with(&mut module.annotations, |inst| {
609        let eq = pairs
610            .iter()
611            .find(|p| {
612                if inst.operands.is_empty() {
613                    return false;
614                }
615
616                if let rspirv::dr::Operand::IdRef(id) = inst.operands[0] {
617                    id == p.import.id || id == p.export.id
618                } else {
619                    false
620                }
621            })
622            .is_some();
623
624        eq && inst.class.opcode == spirv::Op::Decorate
625            && inst.operands[1]
626                == rspirv::dr::Operand::Decoration(spirv::Decoration::LinkageAttributes)
627    });
628
629    if !opts.lib {
630        kill_with(&mut module.annotations, |inst| {
631            inst.class.opcode == spirv::Op::Decorate
632                && inst.operands[1]
633                    == rspirv::dr::Operand::Decoration(spirv::Decoration::LinkageAttributes)
634                && inst.operands[3] == rspirv::dr::Operand::LinkageType(spirv::LinkageType::Export)
635        });
636    }
637
638    // drop OpCapability Linkage
639    kill_with(&mut module.capabilities, |inst| {
640        inst.class.opcode == spirv::Op::Capability
641            && inst.operands[0] == rspirv::dr::Operand::Capability(spirv::Capability::Linkage)
642    })
643}
644
645fn compact_ids(module: &mut rspirv::dr::Module) -> u32 {
646    let mut remap = HashMap::new();
647
648    let mut insert = |current_id: u32| -> u32 {
649        if remap.contains_key(&current_id) {
650            remap[&current_id]
651        } else {
652            let new_id = remap.len() as u32 + 1;
653            remap.insert(current_id, new_id);
654            new_id
655        }
656    };
657
658    module.all_inst_iter_mut().for_each(|inst| {
659        if let Some(ref mut result_id) = &mut inst.result_id {
660            *result_id = insert(*result_id);
661        }
662
663        if let Some(ref mut result_type) = &mut inst.result_type {
664            *result_type = insert(*result_type);
665        }
666
667        inst.operands.iter_mut().for_each(|op| match op {
668            rspirv::dr::Operand::IdMemorySemantics(w)
669            | rspirv::dr::Operand::IdScope(w)
670            | rspirv::dr::Operand::IdRef(w) => {
671                *w = insert(*w);
672            }
673            _ => {}
674        })
675    });
676
677    remap.len() as u32 + 1
678}
679
680fn sort_globals(module: &mut rspirv::dr::Module) {
681    let mut ts = TopologicalSort::<u32>::new();
682
683    for t in module.types_global_values.iter() {
684        if let Some(result_id) = t.result_id {
685            if let Some(result_type) = t.result_type {
686                ts.add_dependency(result_type, result_id);
687            }
688
689            for op in &t.operands {
690                match op {
691                    rspirv::dr::Operand::IdMemorySemantics(w)
692                    | rspirv::dr::Operand::IdScope(w)
693                    | rspirv::dr::Operand::IdRef(w) => {
694                        ts.add_dependency(*w, result_id); // the op defining the IdRef should come before our op / result_id
695                    }
696                    _ => {}
697                }
698            }
699        }
700    }
701
702    let defs = DefAnalyzer::new(&module);
703
704    let mut new_types_global_values = vec![];
705
706    loop {
707        if ts.is_empty() {
708            break;
709        }
710
711        let mut v = ts.pop_all();
712        v.sort();
713
714        for result_id in v {
715            new_types_global_values.push(defs.def(result_id).unwrap().clone());
716        }
717    }
718
719    assert!(module.types_global_values.len() == new_types_global_values.len());
720
721    module.types_global_values = new_types_global_values;
722}
723
724#[derive(PartialEq, Debug)]
725enum ScalarType {
726    Void,
727    Bool,
728    Int { width: u32, signed: bool },
729    Float { width: u32 },
730    Opaque { name: String },
731    Event,
732    DeviceEvent,
733    ReserveId,
734    Queue,
735    Pipe,
736    ForwardPointer { storage_class: spirv::StorageClass },
737    PipeStorage,
738    NamedBarrier,
739    Sampler,
740}
741
742fn trans_scalar_type(inst: &rspirv::dr::Instruction) -> Option<ScalarType> {
743    Some(match inst.class.opcode {
744        spirv::Op::TypeVoid => ScalarType::Void,
745        spirv::Op::TypeBool => ScalarType::Bool,
746        spirv::Op::TypeEvent => ScalarType::Event,
747        spirv::Op::TypeDeviceEvent => ScalarType::DeviceEvent,
748        spirv::Op::TypeReserveId => ScalarType::ReserveId,
749        spirv::Op::TypeQueue => ScalarType::Queue,
750        spirv::Op::TypePipe => ScalarType::Pipe,
751        spirv::Op::TypePipeStorage => ScalarType::PipeStorage,
752        spirv::Op::TypeNamedBarrier => ScalarType::NamedBarrier,
753        spirv::Op::TypeSampler => ScalarType::Sampler,
754        spirv::Op::TypeForwardPointer => ScalarType::ForwardPointer {
755            storage_class: match inst.operands[0] {
756                rspirv::dr::Operand::StorageClass(s) => s,
757                _ => panic!("Unexpected operand while parsing type"),
758            },
759        },
760        spirv::Op::TypeInt => ScalarType::Int {
761            width: match inst.operands[0] {
762                rspirv::dr::Operand::LiteralInt32(w) => w,
763                _ => panic!("Unexpected operand while parsing type"),
764            },
765            signed: match inst.operands[1] {
766                rspirv::dr::Operand::LiteralInt32(s) => {
767                    if s == 0 {
768                        false
769                    } else {
770                        true
771                    }
772                }
773                _ => panic!("Unexpected operand while parsing type"),
774            },
775        },
776        spirv::Op::TypeFloat => ScalarType::Float {
777            width: match inst.operands[0] {
778                rspirv::dr::Operand::LiteralInt32(w) => w,
779                _ => panic!("Unexpected operand while parsing type"),
780            },
781        },
782        spirv::Op::TypeOpaque => ScalarType::Opaque {
783            name: match &inst.operands[0] {
784                rspirv::dr::Operand::LiteralString(s) => s.clone(),
785                _ => panic!("Unexpected operand while parsing type"),
786            },
787        },
788        _ => return None,
789    })
790}
791
792#[derive(PartialEq, Debug)]
793enum AggregateType {
794    Scalar(ScalarType),
795    Array {
796        ty: Box<AggregateType>,
797        len: u64,
798    },
799    Pointer {
800        ty: Box<AggregateType>,
801        storage_class: spirv::StorageClass,
802    },
803    Image {
804        ty: Box<AggregateType>,
805        dim: spirv::Dim,
806        depth: u32,
807        arrayed: u32,
808        multi_sampled: u32,
809        sampled: u32,
810        format: spirv::ImageFormat,
811        access: Option<spirv::AccessQualifier>,
812    },
813    SampledImage {
814        ty: Box<AggregateType>,
815    },
816    Aggregate(Vec<AggregateType>),
817}
818
819fn op_def(def: &DefAnalyzer, operand: &rspirv::dr::Operand) -> rspirv::dr::Instruction {
820    def.def(match operand {
821        rspirv::dr::Operand::IdMemorySemantics(w)
822        | rspirv::dr::Operand::IdScope(w)
823        | rspirv::dr::Operand::IdRef(w) => *w,
824        _ => panic!("Expected ID"),
825    })
826    .unwrap()
827    .clone()
828}
829
830fn extract_literal_int_as_u64(op: &rspirv::dr::Operand) -> u64 {
831    match op {
832        rspirv::dr::Operand::LiteralInt32(v) => (*v).into(),
833        rspirv::dr::Operand::LiteralInt64(v) => *v,
834        _ => panic!("Unexpected literal int"),
835    }
836}
837
838fn extract_literal_u32(op: &rspirv::dr::Operand) -> u32 {
839    match op {
840        rspirv::dr::Operand::LiteralInt32(v) => *v,
841        _ => panic!("Unexpected literal u32"),
842    }
843}
844
845fn trans_aggregate_type(
846    def: &DefAnalyzer,
847    inst: &rspirv::dr::Instruction,
848) -> Option<AggregateType> {
849    Some(match inst.class.opcode {
850        spirv::Op::TypeArray => {
851            let len_def = op_def(def, &inst.operands[1]);
852            assert!(len_def.class.opcode == spirv::Op::Constant); // don't support spec constants yet
853
854            let len_value = extract_literal_int_as_u64(&len_def.operands[1]);
855
856            AggregateType::Array {
857                ty: Box::new(
858                    trans_aggregate_type(def, &op_def(def, &inst.operands[0]))
859                        .expect("Expect base type for OpTypeArray"),
860                ),
861                len: len_value,
862            }
863        }
864        spirv::Op::TypePointer => AggregateType::Pointer {
865            storage_class: match inst.operands[0] {
866                rspirv::dr::Operand::StorageClass(s) => s,
867                _ => panic!("Unexpected operand while parsing type"),
868            },
869            ty: Box::new(
870                trans_aggregate_type(def, &op_def(def, &inst.operands[1]))
871                    .expect("Expect base type for OpTypePointer"),
872            ),
873        },
874        spirv::Op::TypeRuntimeArray
875        | spirv::Op::TypeVector
876        | spirv::Op::TypeMatrix
877        | spirv::Op::TypeSampledImage => AggregateType::Aggregate(
878            trans_aggregate_type(def, &op_def(def, &inst.operands[0]))
879                .map_or_else(|| vec![], |v| vec![v]),
880        ),
881        spirv::Op::TypeStruct | spirv::Op::TypeFunction => {
882            let mut types = vec![];
883            for operand in inst.operands.iter() {
884                let op_def = op_def(def, operand);
885
886                match trans_aggregate_type(def, &op_def) {
887                    Some(ty) => types.push(ty),
888                    None => panic!("Expected type"),
889                }
890            }
891
892            AggregateType::Aggregate(types)
893        }
894        spirv::Op::TypeImage => AggregateType::Image {
895            ty: Box::new(
896                trans_aggregate_type(def, &op_def(def, &inst.operands[0]))
897                    .expect("Expect base type for OpTypeImage"),
898            ),
899            dim: match inst.operands[1] {
900                rspirv::dr::Operand::Dim(d) => d,
901                _ => panic!("Invalid dim"),
902            },
903            depth: extract_literal_u32(&inst.operands[2]),
904            arrayed: extract_literal_u32(&inst.operands[3]),
905            multi_sampled: extract_literal_u32(&inst.operands[4]),
906            sampled: extract_literal_u32(&inst.operands[5]),
907            format: match inst.operands[6] {
908                rspirv::dr::Operand::ImageFormat(f) => f,
909                _ => panic!("Invalid image format"),
910            },
911            access: inst
912                .operands
913                .get(7)
914                .map(|op| match op {
915                    rspirv::dr::Operand::AccessQualifier(a) => Some(a.clone()),
916                    _ => None,
917                })
918                .flatten(),
919        },
920        _ => {
921            if let Some(ty) = trans_scalar_type(inst) {
922                AggregateType::Scalar(ty)
923            } else {
924                return None;
925            }
926        }
927    })
928}
929
930pub fn link(inputs: &mut [&mut rspirv::dr::Module], opts: &Options) -> Result<rspirv::dr::Module> {
931    // shift all the ids
932    let mut bound = inputs[0].header.as_ref().unwrap().bound - 1;
933
934    for mut module in inputs.iter_mut().skip(1) {
935        shift_ids(&mut module, bound);
936        bound += module.header.as_ref().unwrap().bound - 1;
937    }
938
939    // merge the binaries
940    let mut loader = rspirv::dr::Loader::new();
941
942    for module in inputs.iter() {
943        module.all_inst_iter().for_each(|inst| {
944            loader.consume_instruction(inst.clone());
945        });
946    }
947
948    let mut output = loader.module();
949
950    // find import / export pairs
951    let defs = DefAnalyzer::new(&output);
952    let info = find_import_export_pairs(&output, &defs)?;
953
954    // ensure import / export pairs have matching types and defintions
955    let matching_pairs = info.ensure_matching_import_export_pairs(&defs)?;
956
957    // remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
958    remove_duplicate_capablities(&mut output);
959    remove_duplicate_ext_inst_imports(&mut output);
960    let mut output = remove_duplicate_types(output);
961    // jb-todo: strip identical OpDecoration / OpDecorationGroups
962
963    // remove names and decorations of import variables / functions https://github.com/KhronosGroup/SPIRV-Tools/blob/8a0ebd40f86d1f18ad42ea96c6ac53915076c3c7/source/opt/ir_context.cpp#L404
964    import_kill_annotations_and_debug(&mut output, &info);
965
966    // rematch import variables and functions to export variables / functions https://github.com/KhronosGroup/SPIRV-Tools/blob/8a0ebd40f86d1f18ad42ea96c6ac53915076c3c7/source/opt/ir_context.cpp#L255
967    for pair in matching_pairs {
968        replace_all_uses_with(&mut output, pair.import.id, pair.export.id);
969    }
970
971    // remove linkage specific instructions
972    kill_linkage_instructions(&matching_pairs, &mut output, &opts);
973
974    sort_globals(&mut output);
975
976    // compact the ids https://github.com/KhronosGroup/SPIRV-Tools/blob/e02f178a716b0c3c803ce31b9df4088596537872/source/opt/compact_ids_pass.cpp#L43
977    let bound = compact_ids(&mut output);
978    output.header = Some(rspirv::dr::ModuleHeader::new(bound));
979
980    output.debug_module_processed.push(rspirv::dr::Instruction::new(
981        spirv::Op::ModuleProcessed,
982        None,
983        None,
984        vec![rspirv::dr::Operand::LiteralString(
985            "Linked by rspirv-linker".to_string(),
986        )],
987    ));
988
989    // output the module
990    Ok(output)
991}