midenc_hir/ir/
callable.rs

1use alloc::{format, vec::Vec};
2use core::fmt;
3
4use super::SymbolPathAttr;
5use crate::{
6    formatter, CallConv, EntityRef, Op, OpOperandRange, OpOperandRangeMut, RegionRef, Symbol,
7    SymbolPath, SymbolRef, Type, UnsafeIntrusiveEntityRef, Value, ValueRef, Visibility,
8};
9
10/// A call-like operation is one that transfers control from one function to another.
11///
12/// These operations may be traditional static calls, e.g. `call @foo`, or indirect calls, e.g.
13/// `call_indirect v1`. An operation that uses this interface cannot _also_ implement the
14/// `CallableOpInterface`.
15pub trait CallOpInterface: Op {
16    /// Get the callee of this operation.
17    ///
18    /// A callee is either a symbol, or a reference to an SSA value.
19    fn callable_for_callee(&self) -> Callable;
20    /// Sets the callee for this operation.
21    fn set_callee(&mut self, callable: Callable);
22    /// Get the operands of this operation that are used as arguments for the callee
23    fn arguments(&self) -> OpOperandRange<'_>;
24    /// Get a mutable reference to the operands of this operation that are used as arguments for the
25    /// callee
26    fn arguments_mut(&mut self) -> OpOperandRangeMut<'_>;
27    /// Resolve the callable operation for the current callee to a `CallableOpInterface`, or `None`
28    /// if a valid callable was not resolved, using the provided symbol table.
29    ///
30    /// This method is used to perform callee resolution using a cached symbol table, rather than
31    /// traversing the operation hierarchy looking for symbol tables to try resolving with.
32    fn resolve_in_symbol_table(&self, symbols: &dyn crate::SymbolTable) -> Option<SymbolRef>;
33    /// Resolve the callable operation for the current callee to a `CallableOpInterface`, or `None`
34    /// if a valid callable was not resolved.
35    fn resolve(&self) -> Option<SymbolRef>;
36}
37
38/// A callable operation is one who represents a potential function, and may be a target for a call-
39/// like operation (i.e. implementations of `CallOpInterface`). These operations may be traditional
40/// function ops (i.e. `Function`), as well as function reference-producing operations, such as an
41/// op that creates closures, or captures a function by reference.
42///
43/// These operations may only contain a single region.
44pub trait CallableOpInterface: Op {
45    /// Returns the region on the current operation that is callable.
46    ///
47    /// This may return `None` in the case of an external callable object, e.g. an externally-
48    /// defined function reference.
49    fn get_callable_region(&self) -> Option<RegionRef>;
50    /// Returns the signature of the callable
51    fn signature(&self) -> &Signature;
52}
53
54#[doc(hidden)]
55pub trait AsCallableSymbolRef {
56    fn as_callable_symbol_ref(&self) -> SymbolRef;
57}
58impl<T: Symbol + CallableOpInterface> AsCallableSymbolRef for T {
59    #[inline(always)]
60    fn as_callable_symbol_ref(&self) -> SymbolRef {
61        unsafe { SymbolRef::from_raw(self as &dyn Symbol) }
62    }
63}
64impl<T: Symbol + CallableOpInterface> AsCallableSymbolRef for UnsafeIntrusiveEntityRef<T> {
65    #[inline(always)]
66    fn as_callable_symbol_ref(&self) -> SymbolRef {
67        let t_ptr = Self::as_ptr(self);
68        unsafe { SymbolRef::from_raw(t_ptr as *const dyn Symbol) }
69    }
70}
71
72/// A [Callable] represents a symbol or a value which can be used as a valid _callee_ for a
73/// [CallOpInterface] implementation.
74///
75/// Symbols are not SSA values, but there are situations where we want to treat them as one, such
76/// as indirect calls. Abstracting over whether the callable is a symbol or an SSA value allows us
77/// to focus on the call semantics, rather than the difference between the type types of value.
78#[derive(Debug, Clone)]
79pub enum Callable {
80    Symbol(SymbolPath),
81    Value(ValueRef),
82}
83impl fmt::Display for Callable {
84    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
85        match self {
86            Self::Symbol(path) => fmt::Display::fmt(path, f),
87            Self::Value(value) => fmt::Display::fmt(value, f),
88        }
89    }
90}
91impl From<&SymbolPathAttr> for Callable {
92    fn from(value: &SymbolPathAttr) -> Self {
93        Self::Symbol(value.path.clone())
94    }
95}
96impl From<&SymbolPath> for Callable {
97    fn from(value: &SymbolPath) -> Self {
98        Self::Symbol(value.clone())
99    }
100}
101impl From<SymbolPath> for Callable {
102    fn from(value: SymbolPath) -> Self {
103        Self::Symbol(value)
104    }
105}
106impl From<ValueRef> for Callable {
107    fn from(value: ValueRef) -> Self {
108        Self::Value(value)
109    }
110}
111impl Callable {
112    #[inline(always)]
113    pub fn new(callable: impl Into<Self>) -> Self {
114        callable.into()
115    }
116
117    pub fn is_symbol(&self) -> bool {
118        matches!(self, Self::Symbol(_))
119    }
120
121    pub fn is_value(&self) -> bool {
122        matches!(self, Self::Value(_))
123    }
124
125    pub fn as_symbol_path(&self) -> Option<&SymbolPath> {
126        match self {
127            Self::Symbol(ref name) => Some(name),
128            _ => None,
129        }
130    }
131
132    pub fn as_value(&self) -> Option<EntityRef<'_, dyn Value>> {
133        match self {
134            Self::Value(ref value_ref) => Some(value_ref.borrow()),
135            _ => None,
136        }
137    }
138
139    pub fn unwrap_symbol_path(self) -> SymbolPath {
140        match self {
141            Self::Symbol(name) => name,
142            Self::Value(value_ref) => panic!("expected symbol, got {}", value_ref.borrow().id()),
143        }
144    }
145
146    pub fn unwrap_value_ref(self) -> ValueRef {
147        match self {
148            Self::Value(value) => value,
149            Self::Symbol(ref name) => panic!("expected value, got {name}"),
150        }
151    }
152}
153
/// Describes whether an argument or return value plays a special role in the calling
/// convention of a function.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)]
#[repr(u8)]
pub enum ArgumentPurpose {
    /// An ordinary value with no special role: it is passed or returned by value
    #[default]
    Default,
    /// For platforms whose calling convention expects return values of a certain size to be
    /// written through a pointer supplied by the caller.
    StructReturn,
}
impl fmt::Display for ArgumentPurpose {
    /// Writes the short textual name of the purpose (`default` or `sret`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Self::Default => "default",
            Self::StructReturn => "sret",
        };
        f.write_str(label)
    }
}
174
/// Describes how a small integer value is widened to the native machine integer width.
///
/// On Miden, native integral values are unsigned 64-bit field elements. Typically, though, we
/// target the subset of Miden Assembly in which integral values are unsigned 32-bit integers
/// with a standard two's-complement binary representation.
///
/// Argument extension is mainly relevant in that latter setting.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)]
#[repr(u8)]
pub enum ArgumentExtension {
    /// Perform no extension; the contents of the high bits are undefined
    #[default]
    None,
    /// Extend the value with zeroes
    Zext,
    /// Extend the value by replicating its sign bit
    Sext,
}
impl fmt::Display for ArgumentExtension {
    /// Writes the short textual name of the extension (`none`, `zext`, or `sext`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match self {
            Self::None => "none",
            Self::Zext => "zext",
            Self::Sext => "sext",
        })
    }
}
202
203/// Describes a function parameter or result.
204#[derive(Debug, Clone, PartialEq, Eq, Hash)]
205pub struct AbiParam {
206    /// The type associated with this value
207    pub ty: Type,
208    /// The special purpose, if any, of this parameter or result
209    pub purpose: ArgumentPurpose,
210    /// The desired approach to extending the size of this value to
211    /// a larger bit width, if applicable.
212    pub extension: ArgumentExtension,
213}
214impl AbiParam {
215    pub fn new(ty: Type) -> Self {
216        Self {
217            ty,
218            purpose: ArgumentPurpose::default(),
219            extension: ArgumentExtension::default(),
220        }
221    }
222
223    pub fn sret(ty: Type) -> Self {
224        assert!(ty.is_pointer(), "sret parameters must be pointers");
225        Self {
226            ty,
227            purpose: ArgumentPurpose::StructReturn,
228            extension: ArgumentExtension::default(),
229        }
230    }
231}
232impl formatter::PrettyPrint for AbiParam {
233    fn render(&self) -> formatter::Document {
234        use crate::formatter::*;
235
236        let mut doc = const_text("(") + const_text("param") + const_text(" ");
237        if !matches!(self.purpose, ArgumentPurpose::Default) {
238            doc += const_text("(") + display(self.purpose) + const_text(")") + const_text(" ");
239        }
240        if !matches!(self.extension, ArgumentExtension::None) {
241            doc += const_text("(") + display(self.extension) + const_text(")") + const_text(" ");
242        }
243        doc + text(format!("{}", &self.ty)) + const_text(")")
244    }
245}
246
247impl fmt::Display for AbiParam {
248    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249        let mut builder = f.debug_map();
250        builder.entry(&"ty", &format_args!("{}", &self.ty));
251        if !matches!(self.purpose, ArgumentPurpose::Default) {
252            builder.entry(&"purpose", &format_args!("{}", &self.purpose));
253        }
254        if !matches!(self.extension, ArgumentExtension::None) {
255            builder.entry(&"extension", &format_args!("{}", &self.extension));
256        }
257        builder.finish()
258    }
259}
260
261/// A [Signature] represents the type, ABI, and linkage of a function.
262///
263/// A function signature provides us with all of the necessary detail to correctly
264/// validate and emit code for a function, whether from the perspective of a caller,
265/// or the callee.
266#[derive(Debug, Clone, PartialEq, Eq, Hash)]
267pub struct Signature {
268    /// The arguments expected by this function
269    pub params: Vec<AbiParam>,
270    /// The results returned by this function
271    pub results: Vec<AbiParam>,
272    /// The calling convention that applies to this function
273    pub cc: CallConv,
274    /// The linkage/visibility that should be used for this function
275    pub visibility: Visibility,
276}
277
278crate::define_attr_type!(Signature);
279
280impl fmt::Display for Signature {
281    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282        f.debug_map()
283            .key(&"params")
284            .value_with(|f| {
285                let mut builder = f.debug_list();
286                for param in self.params.iter() {
287                    builder.entry(&format_args!("{param}"));
288                }
289                builder.finish()
290            })
291            .key(&"results")
292            .value_with(|f| {
293                let mut builder = f.debug_list();
294                for param in self.params.iter() {
295                    builder.entry(&format_args!("{param}"));
296                }
297                builder.finish()
298            })
299            .entry(&"cc", &format_args!("{}", &self.cc))
300            .entry(&"visibility", &format_args!("{}", &self.visibility))
301            .finish()
302    }
303}
304
305impl Signature {
306    /// Create a new signature with the given parameter and result types,
307    /// for a public function using the `SystemV` calling convention
308    pub fn new<P: IntoIterator<Item = AbiParam>, R: IntoIterator<Item = AbiParam>>(
309        params: P,
310        results: R,
311    ) -> Self {
312        Self {
313            params: params.into_iter().collect(),
314            results: results.into_iter().collect(),
315            cc: CallConv::SystemV,
316            visibility: Visibility::Public,
317        }
318    }
319
320    /// Returns true if this function is externally visible
321    pub fn is_public(&self) -> bool {
322        matches!(self.visibility, Visibility::Public)
323    }
324
325    /// Returns true if this function is only visible within it's containing module
326    pub fn is_private(&self) -> bool {
327        matches!(self.visibility, Visibility::Public)
328    }
329
330    /// Returns true if this function is a kernel function
331    pub fn is_kernel(&self) -> bool {
332        matches!(self.cc, CallConv::Kernel)
333    }
334
335    /// Returns the number of arguments expected by this function
336    pub fn arity(&self) -> usize {
337        self.params().len()
338    }
339
340    /// Returns a slice containing the parameters for this function
341    pub fn params(&self) -> &[AbiParam] {
342        self.params.as_slice()
343    }
344
345    /// Returns the parameter at `index`, if present
346    #[inline]
347    pub fn param(&self, index: usize) -> Option<&AbiParam> {
348        self.params.get(index)
349    }
350
351    /// Returns a slice containing the results of this function
352    pub fn results(&self) -> &[AbiParam] {
353        match self.results.as_slice() {
354            [AbiParam {
355                ty: Type::Never, ..
356            }] => &[],
357            results => results,
358        }
359    }
360}
361impl formatter::PrettyPrint for Signature {
362    fn render(&self) -> formatter::Document {
363        use crate::formatter::*;
364
365        let cc = if matches!(self.cc, CallConv::SystemV) {
366            None
367        } else {
368            Some(
369                const_text("(")
370                    + const_text("cc")
371                    + const_text(" ")
372                    + display(self.cc)
373                    + const_text(")"),
374            )
375        };
376
377        let params = self.params.iter().fold(cc.unwrap_or(Document::Empty), |acc, param| {
378            if acc.is_empty() {
379                param.render()
380            } else {
381                acc + const_text(" ") + param.render()
382            }
383        });
384
385        if self.results.is_empty() {
386            params
387        } else {
388            let open = const_text("(") + const_text("result");
389            let results = self
390                .results
391                .iter()
392                .fold(open, |acc, e| acc + const_text(" ") + text(format!("{}", &e.ty)))
393                + const_text(")");
394            if matches!(params, Document::Empty) {
395                results
396            } else {
397                params + const_text(" ") + results
398            }
399        }
400    }
401}