sway_ir/
function.rs

1//! A typical function data type.
2//!
3//! [`Function`] is named, takes zero or more arguments and has an optional return value.  It
4//! contains a collection of [`Block`]s.
5//!
6//! It also maintains a collection of local values which can be typically regarded as variables
7//! existing in the function scope.
8
9use std::collections::{BTreeMap, HashMap};
10use std::fmt::Write;
11
12use rustc_hash::{FxHashMap, FxHashSet};
13
14use crate::{
15    block::{Block, BlockIterator, Label},
16    context::Context,
17    error::IrError,
18    irtype::Type,
19    metadata::MetadataIndex,
20    module::Module,
21    value::{Value, ValueDatum},
22    variable::{LocalVar, LocalVarContent},
23    BlockArgument, BranchToWithArgs,
24};
25use crate::{Constant, InstOp};
26
/// A wrapper around an [ECS](https://github.com/orlp/slotmap) handle into the
/// [`Context`].
///
/// The handle is `Copy` and only identifies the function; the function's data
/// (a `FunctionContent`) is stored in the context's `functions` map and is read
/// or modified through the methods on this type, all of which take the
/// [`Context`] explicitly.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct Function(pub slotmap::DefaultKey);
31
/// The data backing a [`Function`] handle, stored in the [`Context`].
#[doc(hidden)]
pub struct FunctionContent {
    /// The function name, as returned by [`Function::get_name`].
    pub name: String,
    /// Display string representing the function in the ABI errors
    /// related context (in "errorCodes" and "panickingCalls" sections).
    // TODO: Explore how and if we should lazy evaluate `abi_errors_display`,
    //       only for functions that are actually used in ABI errors context.
    //       Having it precomputed for every function is a simple design.
    //       Lazy evaluation might be much more complex to implement and
    //       a premature optimization, considering that even for large
    //       project we compile <1500 functions.
    pub abi_errors_display: String,
    /// The function arguments as `(name, value)` pairs, in declaration order.
    /// [`Function::new`] creates these values as the arguments of the entry block.
    pub arguments: Vec<(String, Value)>,
    /// The function return type.
    pub return_type: Type,
    /// The blocks of this function, in order.  The first block is the entry block
    /// (see [`Function::get_entry_block`]).
    pub blocks: Vec<Block>,
    /// The module this function belongs to.
    pub module: Module,
    /// Whether the function is public.
    pub is_public: bool,
    /// Whether the function is a program entry point, i.e. `main`, `#[test]` fns or
    /// abi methods.
    pub is_entry: bool,
    /// True if the function was an entry, before getting wrapped
    /// by the `__entry` function. E.g, a script `main` function.
    pub is_original_entry: bool,
    /// Whether the function is a contract fallback function.
    pub is_fallback: bool,
    /// The 4-byte selector used for Sway contract calls, if this function has one.
    pub selector: Option<[u8; 4]>,
    /// Metadata attached to the function, if any.
    pub metadata: Option<MetadataIndex>,

    /// Function-local variables keyed by their name, which is unique within the function.
    pub local_storage: BTreeMap<String, LocalVar>, // BTree rather than Hash for deterministic ordering.

    /// Counter used to generate unique block labels (see [`Function::get_unique_label`]).
    next_label_idx: u64,
}
61
62impl Function {
63    /// Return a new [`Function`] handle.
64    ///
65    /// Creates a [`Function`] in the `context` within `module` and returns a handle.
66    ///
67    /// `name`, `args`, `return_type` and `is_public` are the usual suspects.  `selector` is a
68    /// special value used for Sway contract calls; much like `name` is unique and not particularly
69    /// used elsewhere in the IR.
70    #[allow(clippy::too_many_arguments)]
71    pub fn new(
72        context: &mut Context,
73        module: Module,
74        name: String,
75        abi_errors_display: String,
76        args: Vec<(String, Type, Option<MetadataIndex>)>,
77        return_type: Type,
78        selector: Option<[u8; 4]>,
79        is_public: bool,
80        is_entry: bool,
81        is_original_entry: bool,
82        is_fallback: bool,
83        metadata: Option<MetadataIndex>,
84    ) -> Function {
85        let content = FunctionContent {
86            name,
87            abi_errors_display,
88            // Arguments to a function are the arguments to its entry block.
89            // We set it up after creating the entry block below.
90            arguments: Vec::new(),
91            return_type,
92            blocks: Vec::new(),
93            module,
94            is_public,
95            is_entry,
96            is_original_entry,
97            is_fallback,
98            selector,
99            metadata,
100            local_storage: BTreeMap::new(),
101            next_label_idx: 0,
102        };
103        let func = Function(context.functions.insert(content));
104
105        context.modules[module.0].functions.push(func);
106
107        let entry_block = Block::new(context, func, Some("entry".to_owned()));
108        context
109            .functions
110            .get_mut(func.0)
111            .unwrap()
112            .blocks
113            .push(entry_block);
114
115        // Setup the arguments.
116        let arguments: Vec<_> = args
117            .into_iter()
118            .enumerate()
119            .map(|(idx, (name, ty, arg_metadata))| {
120                (
121                    name,
122                    Value::new_argument(
123                        context,
124                        BlockArgument {
125                            block: entry_block,
126                            idx,
127                            ty,
128                            is_immutable: false,
129                        },
130                    )
131                    .add_metadatum(context, arg_metadata),
132                )
133            })
134            .collect();
135        context
136            .functions
137            .get_mut(func.0)
138            .unwrap()
139            .arguments
140            .clone_from(&arguments);
141        let (_, arg_vals): (Vec<_>, Vec<_>) = arguments.iter().cloned().unzip();
142        context.blocks.get_mut(entry_block.0).unwrap().args = arg_vals;
143
144        func
145    }
146
147    /// Create and append a new [`Block`] to this function.
148    pub fn create_block(&self, context: &mut Context, label: Option<Label>) -> Block {
149        let block = Block::new(context, *self, label);
150        let func = context.functions.get_mut(self.0).unwrap();
151        func.blocks.push(block);
152        block
153    }
154
155    /// Create and insert a new [`Block`] into this function.
156    ///
157    /// The new block is inserted before `other`.
158    pub fn create_block_before(
159        &self,
160        context: &mut Context,
161        other: &Block,
162        label: Option<Label>,
163    ) -> Result<Block, IrError> {
164        let block_idx = context.functions[self.0]
165            .blocks
166            .iter()
167            .position(|block| block == other)
168            .ok_or_else(|| {
169                let label = &context.blocks[other.0].label;
170                IrError::MissingBlock(label.clone())
171            })?;
172
173        let new_block = Block::new(context, *self, label);
174        context.functions[self.0]
175            .blocks
176            .insert(block_idx, new_block);
177        Ok(new_block)
178    }
179
180    /// Create and insert a new [`Block`] into this function.
181    ///
182    /// The new block is inserted after `other`.
183    pub fn create_block_after(
184        &self,
185        context: &mut Context,
186        other: &Block,
187        label: Option<Label>,
188    ) -> Result<Block, IrError> {
189        // We need to create the new block first (even though we may not use it on Err below) since
190        // we can't borrow context mutably twice.
191        let new_block = Block::new(context, *self, label);
192        let func = context.functions.get_mut(self.0).unwrap();
193        func.blocks
194            .iter()
195            .position(|block| block == other)
196            .map(|idx| {
197                func.blocks.insert(idx + 1, new_block);
198                new_block
199            })
200            .ok_or_else(|| {
201                let label = &context.blocks[other.0].label;
202                IrError::MissingBlock(label.clone())
203            })
204    }
205
206    /// Remove a [`Block`] from this function.
207    ///
208    /// > Care must be taken to ensure the block has no predecessors otherwise the function will be
209    /// > made invalid.
210    pub fn remove_block(&self, context: &mut Context, block: &Block) -> Result<(), IrError> {
211        let label = block.get_label(context);
212        let func = context.functions.get_mut(self.0).unwrap();
213        let block_idx = func
214            .blocks
215            .iter()
216            .position(|b| b == block)
217            .ok_or(IrError::RemoveMissingBlock(label))?;
218        func.blocks.remove(block_idx);
219        Ok(())
220    }
221
222    /// Get a new unique block label.
223    ///
224    /// If `hint` is `None` then the label will be in the form `"blockN"` where N is an
225    /// incrementing decimal.
226    ///
227    /// Otherwise if the hint is already unique to this function it will be returned.  If not
228    /// already unique it will have N appended to it until it is unique.
229    pub fn get_unique_label(&self, context: &mut Context, hint: Option<String>) -> String {
230        match hint {
231            Some(hint) => {
232                if context.functions[self.0]
233                    .blocks
234                    .iter()
235                    .any(|block| context.blocks[block.0].label == hint)
236                {
237                    let idx = self.get_next_label_idx(context);
238                    self.get_unique_label(context, Some(format!("{hint}{idx}")))
239                } else {
240                    hint
241                }
242            }
243            None => {
244                let idx = self.get_next_label_idx(context);
245                self.get_unique_label(context, Some(format!("block{idx}")))
246            }
247        }
248    }
249
250    fn get_next_label_idx(&self, context: &mut Context) -> u64 {
251        let func = context.functions.get_mut(self.0).unwrap();
252        let idx = func.next_label_idx;
253        func.next_label_idx += 1;
254        idx
255    }
256
257    /// Return the number of blocks in this function.
258    pub fn num_blocks(&self, context: &Context) -> usize {
259        context.functions[self.0].blocks.len()
260    }
261
262    /// Return the number of instructions in this function.
263    ///
264    /// The [crate::InstOp::AsmBlock] is counted as a single instruction,
265    /// regardless of the number of [crate::asm::AsmInstruction]s in the ASM block.
266    /// E.g., even if the ASM block is empty and contains no instructions, it
267    /// will still be counted as a single instruction.
268    ///
269    /// If you want to count every ASM instruction as an instruction, use
270    /// `num_instructions_incl_asm_instructions` instead.
271    pub fn num_instructions(&self, context: &Context) -> usize {
272        self.block_iter(context)
273            .map(|block| block.num_instructions(context))
274            .sum()
275    }
276
277    /// Return the number of instructions in this function, including
278    /// the [crate::asm::AsmInstruction]s found in [crate::InstOp::AsmBlock]s.
279    ///
280    /// Every [crate::asm::AsmInstruction] encountered in any of the ASM blocks
281    /// will be counted as an instruction. The [crate::InstOp::AsmBlock] itself
282    /// is not counted but rather replaced with the number of ASM instructions
283    /// found in the block. In other words, empty ASM blocks do not count as
284    /// instructions.
285    ///
286    /// If you want to count [crate::InstOp::AsmBlock]s as single instructions, use
287    /// `num_instructions` instead.
288    pub fn num_instructions_incl_asm_instructions(&self, context: &Context) -> usize {
289        self.instruction_iter(context).fold(0, |num, (_, value)| {
290            match &value
291                .get_instruction(context)
292                .expect("We are iterating through the instructions.")
293                .op
294            {
295                InstOp::AsmBlock(asm, _) => num + asm.body.len(),
296                _ => num + 1,
297            }
298        })
299    }
300
301    /// Return the function name.
302    pub fn get_name<'a>(&self, context: &'a Context) -> &'a str {
303        &context.functions[self.0].name
304    }
305
306    /// Return the display string representing the function in the ABI errors
307    /// related context, in the "errorCodes" and "panickingCalls" sections.
308    pub fn get_abi_errors_display(&self, context: &Context) -> String {
309        context.functions[self.0].abi_errors_display.clone()
310    }
311
312    /// Return the module that this function belongs to.
313    pub fn get_module(&self, context: &Context) -> Module {
314        context.functions[self.0].module
315    }
316
317    /// Return the function entry (i.e., the first) block.
318    pub fn get_entry_block(&self, context: &Context) -> Block {
319        context.functions[self.0].blocks[0]
320    }
321
322    /// Return the attached metadata.
323    pub fn get_metadata(&self, context: &Context) -> Option<MetadataIndex> {
324        context.functions[self.0].metadata
325    }
326
327    /// Whether this function has a valid selector.
328    pub fn has_selector(&self, context: &Context) -> bool {
329        context.functions[self.0].selector.is_some()
330    }
331
332    /// Return the function selector, if it has one.
333    pub fn get_selector(&self, context: &Context) -> Option<[u8; 4]> {
334        context.functions[self.0].selector
335    }
336
337    /// Whether or not the function is a program entry point, i.e. `main`, `#[test]` fns or abi
338    /// methods.
339    pub fn is_entry(&self, context: &Context) -> bool {
340        context.functions[self.0].is_entry
341    }
342
343    /// Whether or not the function was a program entry point, i.e. `main`, `#[test]` fns or abi
344    /// methods, before it got wrapped within the `__entry` function.
345    pub fn is_original_entry(&self, context: &Context) -> bool {
346        context.functions[self.0].is_original_entry
347    }
348
349    /// Whether or not this function is a contract fallback function
350    pub fn is_fallback(&self, context: &Context) -> bool {
351        context.functions[self.0].is_fallback
352    }
353
354    // Get the function return type.
355    pub fn get_return_type(&self, context: &Context) -> Type {
356        context.functions[self.0].return_type
357    }
358
359    // Set a new function return type.
360    pub fn set_return_type(&self, context: &mut Context, new_ret_type: Type) {
361        context.functions.get_mut(self.0).unwrap().return_type = new_ret_type
362    }
363
364    /// Get the number of args.
365    pub fn num_args(&self, context: &Context) -> usize {
366        context.functions[self.0].arguments.len()
367    }
368
369    /// Get an arg value by name, if found.
370    pub fn get_arg(&self, context: &Context, name: &str) -> Option<Value> {
371        context.functions[self.0]
372            .arguments
373            .iter()
374            .find_map(|(arg_name, val)| (arg_name == name).then_some(val))
375            .copied()
376    }
377
378    /// Append an extra argument to the function signature.
379    ///
380    /// NOTE: `arg` must be a `BlockArgument` value with the correct index otherwise `add_arg` will
381    /// panic.
382    pub fn add_arg<S: Into<String>>(&self, context: &mut Context, name: S, arg: Value) {
383        match context.values[arg.0].value {
384            ValueDatum::Argument(BlockArgument { idx, .. })
385                if idx == context.functions[self.0].arguments.len() =>
386            {
387                context.functions[self.0].arguments.push((name.into(), arg));
388            }
389            _ => panic!("Inconsistent function argument being added"),
390        }
391    }
392
393    /// Find the name of an arg by value.
394    pub fn lookup_arg_name<'a>(&self, context: &'a Context, value: &Value) -> Option<&'a String> {
395        context.functions[self.0]
396            .arguments
397            .iter()
398            .find_map(|(name, arg_val)| (arg_val == value).then_some(name))
399    }
400
401    /// Return an iterator for each of the function arguments.
402    pub fn args_iter<'a>(&self, context: &'a Context) -> impl Iterator<Item = &'a (String, Value)> {
403        context.functions[self.0].arguments.iter()
404    }
405
406    /// Is argument `i` marked immutable?
407    pub fn is_arg_immutable(&self, context: &Context, i: usize) -> bool {
408        if let Some((_, val)) = context.functions[self.0].arguments.get(i) {
409            if let ValueDatum::Argument(arg) = &context.values[val.0].value {
410                return arg.is_immutable;
411            }
412        }
413        false
414    }
415
416    /// Get a pointer to a local value by name, if found.
417    pub fn get_local_var(&self, context: &Context, name: &str) -> Option<LocalVar> {
418        context.functions[self.0].local_storage.get(name).copied()
419    }
420
421    /// Find the name of a local value by pointer.
422    pub fn lookup_local_name<'a>(
423        &self,
424        context: &'a Context,
425        var: &LocalVar,
426    ) -> Option<&'a String> {
427        context.functions[self.0]
428            .local_storage
429            .iter()
430            .find_map(|(name, local_var)| if local_var == var { Some(name) } else { None })
431    }
432
433    /// Add a value to the function local storage.
434    ///
435    /// The name must be unique to this function else an error is returned.
436    pub fn new_local_var(
437        &self,
438        context: &mut Context,
439        name: String,
440        local_type: Type,
441        initializer: Option<Constant>,
442        mutable: bool,
443    ) -> Result<LocalVar, IrError> {
444        let var = LocalVar::new(context, local_type, initializer, mutable);
445        let func = context.functions.get_mut(self.0).unwrap();
446        func.local_storage
447            .insert(name.clone(), var)
448            .map(|_| Err(IrError::FunctionLocalClobbered(func.name.clone(), name)))
449            .unwrap_or(Ok(var))
450    }
451
452    /// Add a value to the function local storage, by forcing the name to be unique if needed.
453    ///
454    /// Will use the provided name as a hint and rename to guarantee insertion.
455    pub fn new_unique_local_var(
456        &self,
457        context: &mut Context,
458        name: String,
459        local_type: Type,
460        initializer: Option<Constant>,
461        mutable: bool,
462    ) -> LocalVar {
463        let func = &context.functions[self.0];
464        let new_name = if func.local_storage.contains_key(&name) {
465            // Assuming that we'll eventually find a unique name by appending numbers to the old
466            // one...
467            (0..)
468                .find_map(|n| {
469                    let candidate = format!("{name}{n}");
470                    if func.local_storage.contains_key(&candidate) {
471                        None
472                    } else {
473                        Some(candidate)
474                    }
475                })
476                .unwrap()
477        } else {
478            name
479        };
480        self.new_local_var(context, new_name, local_type, initializer, mutable)
481            .unwrap()
482    }
483
484    /// Return an iterator to all of the values in this function's local storage.
485    pub fn locals_iter<'a>(
486        &self,
487        context: &'a Context,
488    ) -> impl Iterator<Item = (&'a String, &'a LocalVar)> {
489        context.functions[self.0].local_storage.iter()
490    }
491
492    /// Remove given list of locals
493    pub fn remove_locals(&self, context: &mut Context, removals: &Vec<String>) {
494        for remove in removals {
495            if let Some(local) = context.functions[self.0].local_storage.remove(remove) {
496                context.local_vars.remove(local.0);
497            }
498        }
499    }
500
501    /// Merge values from another [`Function`] into this one.
502    ///
503    /// The names of the merged values are guaranteed to be unique via the use of
504    /// [`Function::new_unique_local_var`].
505    ///
506    /// Returns a map from the original pointers to the newly merged pointers.
507    pub fn merge_locals_from(
508        &self,
509        context: &mut Context,
510        other: Function,
511    ) -> HashMap<LocalVar, LocalVar> {
512        let mut var_map = HashMap::new();
513        let old_vars: Vec<(String, LocalVar, LocalVarContent)> = context.functions[other.0]
514            .local_storage
515            .iter()
516            .map(|(name, var)| (name.clone(), *var, context.local_vars[var.0].clone()))
517            .collect();
518        for (name, old_var, old_var_content) in old_vars {
519            let old_ty = old_var_content
520                .ptr_ty
521                .get_pointee_type(context)
522                .expect("LocalVar types are always pointers.");
523            let new_var = self.new_unique_local_var(
524                context,
525                name.clone(),
526                old_ty,
527                old_var_content.initializer,
528                old_var_content.mutable,
529            );
530            var_map.insert(old_var, new_var);
531        }
532        var_map
533    }
534
535    /// Return an iterator to each block in this function.
536    pub fn block_iter(&self, context: &Context) -> BlockIterator {
537        BlockIterator::new(context, self)
538    }
539
540    /// Return an iterator to each instruction in each block in this function.
541    ///
542    /// This is a convenience method for when all instructions in a function need to be inspected.
543    /// The instruction value is returned from the iterator along with the block it belongs to.
544    pub fn instruction_iter<'a>(
545        &self,
546        context: &'a Context,
547    ) -> impl Iterator<Item = (Block, Value)> + 'a {
548        context.functions[self.0]
549            .blocks
550            .iter()
551            .flat_map(move |block| {
552                block
553                    .instruction_iter(context)
554                    .map(move |ins_val| (*block, ins_val))
555            })
556    }
557
558    /// Replace a value with another within this function.
559    ///
560    /// This is a convenience method which iterates over this function's blocks and calls
561    /// [`Block::replace_values`] in turn.
562    ///
563    /// `starting_block` is an optimisation for when the first possible reference to `old_val` is
564    /// known.
565    pub fn replace_values(
566        &self,
567        context: &mut Context,
568        replace_map: &FxHashMap<Value, Value>,
569        starting_block: Option<Block>,
570    ) {
571        let mut block_iter = self.block_iter(context).peekable();
572
573        if let Some(ref starting_block) = starting_block {
574            // Skip blocks until we hit the starting block.
575            while block_iter
576                .next_if(|block| block != starting_block)
577                .is_some()
578            {}
579        }
580
581        for block in block_iter {
582            block.replace_values(context, replace_map);
583        }
584    }
585
586    pub fn replace_value(
587        &self,
588        context: &mut Context,
589        old_val: Value,
590        new_val: Value,
591        starting_block: Option<Block>,
592    ) {
593        let mut map = FxHashMap::<Value, Value>::default();
594        map.insert(old_val, new_val);
595        self.replace_values(context, &map, starting_block);
596    }
597
598    /// A graphviz dot graph of the control-flow-graph.
599    pub fn dot_cfg(&self, context: &Context) -> String {
600        let mut worklist = Vec::<Block>::new();
601        let mut visited = FxHashSet::<Block>::default();
602        let entry = self.get_entry_block(context);
603        let mut res = format!("digraph {} {{\n", self.get_name(context));
604
605        worklist.push(entry);
606        while let Some(n) = worklist.pop() {
607            visited.insert(n);
608            for BranchToWithArgs { block: n_succ, .. } in n.successors(context) {
609                let _ = writeln!(
610                    res,
611                    "\t{} -> {}\n",
612                    n.get_label(context),
613                    n_succ.get_label(context)
614                );
615                if !visited.contains(&n_succ) {
616                    worklist.push(n_succ);
617                }
618            }
619        }
620
621        res += "}\n";
622        res
623    }
624}
625
/// An iterator over each [`Function`] in a [`Module`].
///
/// The module's function keys are copied when the iterator is created, so the
/// context may be modified during iteration; functions added to the module
/// afterwards are not yielded.
pub struct FunctionIterator {
    // Snapshot of the module's function keys, in the module's order.
    functions: Vec<slotmap::DefaultKey>,
    // Index into `functions` of the next function to yield.
    next: usize,
}
631
632impl FunctionIterator {
633    /// Return a new iterator for the functions in `module`.
634    pub fn new(context: &Context, module: &Module) -> FunctionIterator {
635        // Copy all the current modules indices, so they may be modified in the context during
636        // iteration.
637        FunctionIterator {
638            functions: context.modules[module.0]
639                .functions
640                .iter()
641                .map(|func| func.0)
642                .collect(),
643            next: 0,
644        }
645    }
646}
647
648impl Iterator for FunctionIterator {
649    type Item = Function;
650
651    fn next(&mut self) -> Option<Function> {
652        if self.next < self.functions.len() {
653            let idx = self.next;
654            self.next += 1;
655            Some(Function(self.functions[idx]))
656        } else {
657            None
658        }
659    }
660}