// pliron/basic_block.rs

1//! A [BasicBlock] is a list of [Operation]s.
2
3use combine::{
4    parser::{Parser, char::spaces},
5    sep_by, token,
6};
7use thiserror::Error;
8
9use crate::{
10    attribute::AttributeDict,
11    builtin::op_interfaces::{IsTerminatorInterface, NoTerminatorInterface},
12    common_traits::{Named, Verify},
13    context::{Arena, Context, Ptr, private::ArenaObj},
14    debug_info::{get_block_arg_name, set_block_arg_name},
15    identifier::Identifier,
16    indented_block,
17    irfmt::{
18        parsers::{delimited_list_parser, location, spaced, type_parser},
19        printers::{iter_with_sep, list_with_sep},
20    },
21    linked_list::{ContainsLinkedList, LinkedList, private},
22    location::{Located, Location},
23    op::op_impls,
24    operation::{DefUseVerifyErr, Operation, OperationParserConfig},
25    parsable::{self, IntoParseResult, Parsable, ParseResult},
26    printable::{self, ListSeparator, Printable, indented_nl},
27    region::Region,
28    result::Result,
29    r#type::{TypeObj, Typed},
30    utils::vec_exns::VecExtns,
31    value::{DefNode, Value},
32    verify_err, verify_error,
33};
34
35/// Argument to a [BasicBlock]
36pub(crate) struct BlockArgument {
37    /// The def containing the list of this argument's uses.
38    pub(crate) def: DefNode<Value>,
39    /// A [Ptr] to the [BasicBlock] of which this is an argument.
40    pub(crate) def_block: Ptr<BasicBlock>,
41    /// Index of this argument in the block's list of arguments.
42    pub(crate) arg_idx: usize,
43    /// The [Type](crate::type::Type) of this argument.
44    pub(crate) ty: Ptr<TypeObj>,
45}
46
47impl BlockArgument {
48    /// Get the [Type](crate::type::Type) of this argument.
49    pub fn get_type(&self, _ctx: &Context) -> Ptr<TypeObj> {
50        self.ty
51    }
52
53    /// Set the [Type](crate::type::Type) of this argument.
54    pub fn set_type(&mut self, _ctx: &Context, ty: Ptr<TypeObj>) {
55        self.ty = ty;
56    }
57}
58
59impl Typed for BlockArgument {
60    fn get_type(&self, ctx: &Context) -> Ptr<TypeObj> {
61        self.get_type(ctx)
62    }
63}
64
65impl Named for BlockArgument {
66    fn given_name(&self, ctx: &Context) -> Option<Identifier> {
67        get_block_arg_name(ctx, self.def_block, self.arg_idx)
68    }
69    fn id(&self, ctx: &Context) -> Identifier {
70        format!("{}_arg{}", self.def_block.deref(ctx).id(ctx), self.arg_idx)
71            .try_into()
72            .unwrap()
73    }
74}
75
76impl From<&BlockArgument> for Value {
77    fn from(value: &BlockArgument) -> Self {
78        Value::BlockArgument {
79            block: value.def_block,
80            arg_idx: value.arg_idx,
81        }
82    }
83}
84
85impl Printable for BlockArgument {
86    fn fmt(
87        &self,
88        ctx: &Context,
89        _state: &printable::State,
90        f: &mut core::fmt::Formatter<'_>,
91    ) -> core::fmt::Result {
92        write!(f, "{}: {}", self.unique_name(ctx), self.ty.disp(ctx))
93    }
94}
95
96impl Verify for BlockArgument {
97    fn verify(&self, ctx: &Context) -> Result<()> {
98        Into::<Value>::into(self).verify(ctx)
99    }
100}
101
102/// [Operation]s contained in this [BasicBlock]
103#[derive(Default)]
104pub struct OpsInBlock {
105    first: Option<Ptr<Operation>>,
106    last: Option<Ptr<Operation>>,
107}
108
109/// Links a [BasicBlock] with other blocks and the container [Region].
110#[derive(Default)]
111struct RegionLinks {
112    /// Parent region of this block.
113    parent_region: Option<Ptr<Region>>,
114    /// The next block in the region's list of block.
115    next_block: Option<Ptr<BasicBlock>>,
116    /// The previous block in the region's list of blocks.
117    prev_block: Option<Ptr<BasicBlock>>,
118}
119
120/// A basic block contains a list of [Operation]s. It may have [arguments](Value::BlockArgument).
121pub struct BasicBlock {
122    pub(crate) self_ptr: Ptr<BasicBlock>,
123    pub(crate) label: Option<Identifier>,
124    pub(crate) ops_list: OpsInBlock,
125    pub(crate) args: Vec<BlockArgument>,
126    pub(crate) preds: DefNode<Ptr<BasicBlock>>,
127    /// Links to the parent [Region] and
128    /// previous and next [BasicBlock]s in the block.
129    region_links: RegionLinks,
130    /// A dictionary of attributes.
131    pub attributes: AttributeDict,
132    loc: Location,
133}
134
135impl Named for BasicBlock {
136    fn given_name(&self, _ctx: &Context) -> Option<Identifier> {
137        self.label.clone()
138    }
139    fn id(&self, _ctx: &Context) -> Identifier {
140        self.self_ptr.make_name("block")
141    }
142}
143
144impl BasicBlock {
145    /// Create a new Basic Block.
146    pub fn new(
147        ctx: &mut Context,
148        label: Option<Identifier>,
149        arg_types: Vec<Ptr<TypeObj>>,
150    ) -> Ptr<BasicBlock> {
151        let f = |self_ptr: Ptr<BasicBlock>| BasicBlock {
152            self_ptr,
153            label,
154            args: vec![],
155            ops_list: OpsInBlock::default(),
156            preds: DefNode::new(),
157            region_links: RegionLinks::default(),
158            attributes: AttributeDict::default(),
159            loc: Location::Unknown,
160        };
161        let newblock = Self::alloc(ctx, f);
162        // Let's update the args of the new block. Easier to do it here than during creation.
163        let args = arg_types
164            .into_iter()
165            .enumerate()
166            .map(|(arg_idx, ty)| BlockArgument {
167                def: DefNode::new(),
168                def_block: newblock,
169                arg_idx,
170                ty,
171            })
172            .collect();
173        newblock.deref_mut(ctx).args = args;
174        // We're done.
175        newblock
176    }
177
178    /// Get parent region.
179    pub fn get_parent_region(&self) -> Option<Ptr<Region>> {
180        self.region_links.parent_region
181    }
182
183    /// Get parent operation.
184    pub fn get_parent_op(&self, ctx: &Context) -> Option<Ptr<Operation>> {
185        self.get_parent_region()
186            .map(|region| region.deref(ctx).get_parent_op())
187    }
188
189    /// Get idx'th argument as a Value.
190    pub fn get_argument(&self, arg_idx: usize) -> Value {
191        self.args
192            .get(arg_idx)
193            .map(|arg| arg.into())
194            .unwrap_or_else(|| panic!("Block argument index {arg_idx} out of bounds"))
195    }
196
197    /// Get an iterator over the arguments
198    pub fn arguments(&self) -> impl Iterator<Item = Value> + '_ {
199        self.args.iter().map(Into::into)
200    }
201
202    /// Add a new argument with specified type. Returns idx at which it was added.
203    pub fn add_argument(&mut self, ty: Ptr<TypeObj>) -> usize {
204        self.args.push_back_with(|arg_idx| BlockArgument {
205            def: DefNode::new(),
206            def_block: self.self_ptr,
207            arg_idx,
208            ty,
209        })
210    }
211
212    /// Get a reference to the idx'th argument.
213    pub(crate) fn get_argument_ref(&self, arg_idx: usize) -> &BlockArgument {
214        self.args
215            .get(arg_idx)
216            .unwrap_or_else(|| panic!("Block argument index {arg_idx} out of bounds"))
217    }
218
219    /// Get a mutable reference to the idx'th argument.
220    pub(crate) fn get_argument_mut(&mut self, arg_idx: usize) -> &mut BlockArgument {
221        self.args
222            .get_mut(arg_idx)
223            .unwrap_or_else(|| panic!("Block argument index {arg_idx} out of bounds"))
224    }
225
226    /// Get the number of arguments.
227    pub fn get_num_arguments(&self) -> usize {
228        self.args.len()
229    }
230
231    /// Get all successors of this block.
232    pub fn succs(&self, ctx: &Context) -> Vec<Ptr<BasicBlock>> {
233        self.get_terminator(ctx)
234            .map(|term| term.deref(ctx).successors().collect())
235            .unwrap_or_default()
236    }
237
238    /// Is `succ` a successor of this block?
239    pub fn has_succ(&self, ctx: &Context, succ: Ptr<BasicBlock>) -> bool {
240        self.succs(ctx).contains(&succ)
241    }
242
243    /// Get the block terminator, if one exists.
244    pub fn get_terminator(&self, ctx: &Context) -> Option<Ptr<Operation>> {
245        let last_opr = self.get_tail()?;
246        let last_op = Operation::get_op_dyn(last_opr, ctx);
247        op_impls::<dyn IsTerminatorInterface>(last_op.as_ref()).then_some(last_opr)
248    }
249
250    /// Drop all uses that this block holds.
251    pub fn drop_all_uses(ptr: Ptr<Self>, ctx: &Context) {
252        let ops: Vec<_> = ptr.deref(ctx).iter(ctx).collect();
253        for op in ops {
254            Operation::drop_all_uses(op, ctx);
255        }
256    }
257
258    /// Unlink and deallocate this block and everything that it contains.
259    /// There must not be any uses outside the block.
260    pub fn erase(ptr: Ptr<Self>, ctx: &mut Context) {
261        Self::drop_all_uses(ptr, ctx);
262        assert!(
263            !ptr.has_pred(ctx),
264            "BasicBlock with predecessor(s) being erased"
265        );
266
267        if ptr.deref(ctx).iter(ctx).any(|op| op.deref(ctx).has_use()) {
268            panic!("Attemping to erase block which has a use outside the block")
269        }
270        if ptr.is_linked(ctx) {
271            ptr.unlink(ctx);
272        }
273        ArenaObj::dealloc(ptr, ctx);
274    }
275}
276
277impl Located for BasicBlock {
278    fn loc(&self) -> Location {
279        self.loc.clone()
280    }
281
282    fn set_loc(&mut self, loc: Location) {
283        self.loc = loc;
284    }
285}
286
287impl private::ContainsLinkedList<Operation> for BasicBlock {
288    fn set_head(&mut self, head: Option<Ptr<Operation>>) {
289        self.ops_list.first = head;
290    }
291
292    fn set_tail(&mut self, tail: Option<Ptr<Operation>>) {
293        self.ops_list.last = tail;
294    }
295}
296
297impl ContainsLinkedList<Operation> for BasicBlock {
298    fn get_head(&self) -> Option<Ptr<Operation>> {
299        self.ops_list.first
300    }
301
302    fn get_tail(&self) -> Option<Ptr<Operation>> {
303        self.ops_list.last
304    }
305}
306
307impl PartialEq for BasicBlock {
308    fn eq(&self, other: &Self) -> bool {
309        self.self_ptr == other.self_ptr
310    }
311}
312
313impl private::LinkedList for BasicBlock {
314    type ContainerType = Region;
315    fn set_next(&mut self, next: Option<Ptr<Self>>) {
316        self.region_links.next_block = next;
317    }
318    fn set_prev(&mut self, prev: Option<Ptr<Self>>) {
319        self.region_links.prev_block = prev;
320    }
321    fn set_container(&mut self, container: Option<Ptr<Self::ContainerType>>) {
322        self.region_links.parent_region = container;
323    }
324}
325
326impl LinkedList for BasicBlock {
327    fn get_next(&self) -> Option<Ptr<Self>> {
328        self.region_links.next_block
329    }
330    fn get_prev(&self) -> Option<Ptr<Self>> {
331        self.region_links.prev_block
332    }
333    fn get_container(&self) -> Option<Ptr<Self::ContainerType>> {
334        self.region_links.parent_region
335    }
336}
337
338impl ArenaObj for BasicBlock {
339    fn get_arena(ctx: &Context) -> &Arena<Self> {
340        &ctx.basic_blocks
341    }
342    fn get_arena_mut(ctx: &mut Context) -> &mut Arena<Self> {
343        &mut ctx.basic_blocks
344    }
345    fn dealloc_sub_objects(ptr: Ptr<Self>, ctx: &mut Context) {
346        let ops: Vec<_> = ptr.deref_mut(ctx).iter(ctx).collect();
347        for op in ops {
348            ArenaObj::dealloc(op, ctx);
349        }
350    }
351    fn get_self_ptr(&self, _ctx: &Context) -> Ptr<Self> {
352        self.self_ptr
353    }
354}
355
356impl Verify for BasicBlock {
357    fn verify(&self, ctx: &Context) -> Result<()> {
358        // Ensure that the block has a terminator
359        // (unless the enclosing [Op] is marked [NoTerminatorInterface].
360        let label: String = self.unique_name(ctx).into();
361        let parent_op = self.get_parent_op(ctx).ok_or_else(|| {
362            verify_error!(self.loc(), BasicBlockVerifyErr::NoParent(label.clone()))
363        })?;
364        let parent_op = Operation::get_op_dyn(parent_op, ctx);
365        if !op_impls::<dyn NoTerminatorInterface>(parent_op.as_ref())
366            && self.get_terminator(ctx).is_none()
367        {
368            let loc = self.loc();
369            verify_err!(loc, BasicBlockVerifyErr::MissingTerminator(label))?;
370        }
371        // Check that every predecessor points back to this block.
372        for pred in self.self_ptr.preds(ctx) {
373            if !pred.deref(ctx).has_succ(ctx, self.self_ptr) {
374                let loc = self.loc();
375                verify_err!(loc, DefUseVerifyErr)?;
376            }
377        }
378        self.args.iter().try_for_each(|arg| arg.verify(ctx))?;
379        self.iter(ctx).try_for_each(|op| op.deref(ctx).verify(ctx))
380    }
381}
382
383/// Error indicating that a basic block is missing a terminator.
384#[derive(Debug, Error)]
385
386pub enum BasicBlockVerifyErr {
387    #[error("Basic block \"{0}\" is missing a terminator")]
388    MissingTerminator(String),
389    #[error("Basic block \"{0}\" has a terminator that is not the last operation in the block")]
390    TerminatorNotLast(String),
391    #[error("Basic block \"{0}\" has no parent operation")]
392    NoParent(String),
393}
394
395impl Printable for BasicBlock {
396    fn fmt(
397        &self,
398        ctx: &Context,
399        state: &printable::State,
400        f: &mut core::fmt::Formatter<'_>,
401    ) -> core::fmt::Result {
402        write!(
403            f,
404            "^{}({}):",
405            self.unique_name(ctx),
406            list_with_sep(&self.args, ListSeparator::CharSpace(',')).print(ctx, state),
407        )?;
408
409        indented_block!(state, {
410            write!(
411                f,
412                "{}{}",
413                indented_nl(state),
414                iter_with_sep(self.iter(ctx), ListSeparator::CharNewline(';')).print(ctx, state),
415            )?;
416        });
417
418        Ok(())
419    }
420}
421
422impl Parsable for BasicBlock {
423    type Arg = ();
424    type Parsed = Ptr<BasicBlock>;
425
426    ///  A basic block is
427    ///  label(arg_1: type_1, ..., arg_n: type_n):
428    ///    op_1;
429    ///    ... ;
430    ///    op_n
431    fn parse<'a>(
432        state_stream: &mut parsable::StateStream<'a>,
433        _arg: Self::Arg,
434    ) -> ParseResult<'a, Self::Parsed> {
435        let loc = state_stream.loc();
436
437        let arg = (
438            (location(), Identifier::parser(())).skip(spaced(token(':'))),
439            type_parser().skip(spaces()),
440        );
441        let args = spaced(delimited_list_parser('(', ')', ',', arg)).skip(token(':'));
442        let ops = spaces().with(sep_by::<Vec<_>, _, _, _>(
443            Operation::parser(OperationParserConfig {
444                look_for_outlined_attrs: false,
445            })
446            .skip(spaces()),
447            token(';').skip(spaces()),
448        ));
449
450        let label = spaced(token('^').with(Identifier::parser(())));
451        let (label, args, ops) = (label, args, ops)
452            .parse_stream(state_stream)
453            .into_result()?
454            .0;
455
456        // We've parsed the components. Now construct the result.
457        let (arg_names, arg_types): (Vec<_>, Vec<_>) = args.into_iter().unzip();
458        let block = BasicBlock::new(state_stream.state.ctx, Some(label.clone()), arg_types);
459        for (arg_idx, (loc, name)) in arg_names.into_iter().enumerate() {
460            let def: Value = (&block.deref(state_stream.state.ctx).args[arg_idx]).into();
461            state_stream.state.name_tracker.ssa_def(
462                state_stream.state.ctx,
463                &(name.clone(), loc),
464                def,
465            )?;
466            set_block_arg_name(state_stream.state.ctx, block, arg_idx, name);
467        }
468        for op in ops {
469            op.insert_at_back(block, state_stream.state.ctx);
470        }
471        state_stream
472            .state
473            .name_tracker
474            .block_def(state_stream.state.ctx, &(label, loc), block)?;
475        Ok(block).into_parse_result()
476    }
477}