Skip to main content

miden_assembly_syntax/ast/
type.rs

1use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
2
3use miden_debug_types::{SourceManager, SourceSpan, Span, Spanned};
4pub use midenc_hir_type as types;
5use midenc_hir_type::{AddressSpace, Type, TypeRepr};
6
7use super::{
8    ConstantExpr, DocString, GlobalItemIndex, Ident, ItemIndex, Path, SymbolResolution,
9    SymbolResolutionError, Visibility,
10};
11
/// Upper bound on how deeply type expressions may nest while they are being resolved.
///
/// Resolution recurses once per level of nesting, so without a bound a maliciously deep type
/// expression could overflow the stack. The limit is set well above anything real programs need.
const MAX_TYPE_EXPR_NESTING: usize = 1 << 8;
17
18/// Abstracts over resolving an item to a concrete [Type], using one of:
19///
20/// * A [GlobalItemIndex]
21/// * An [ItemIndex]
22/// * A [Path]
23/// * A [TypeExpr]
24///
25/// Since type resolution happens in two different contexts during assembly, this abstraction allows
26/// us to share more of the resolution logic in both places.
27pub trait TypeResolver<E> {
28    fn source_manager(&self) -> Arc<dyn SourceManager>;
29    /// Should be called by consumers of this resolver to convert a [SymbolResolutionError] to the
30    /// error type used by the [TypeResolver] implementation.
31    fn resolve_local_failed(&self, err: SymbolResolutionError) -> E;
32    /// Get the [Type] corresponding to the item given by `gid`
33    fn get_type(&self, context: SourceSpan, gid: GlobalItemIndex) -> Result<Type, E>;
34    /// Get the [Type] corresponding to the item in the current module given by `id`
35    fn get_local_type(&self, context: SourceSpan, id: ItemIndex) -> Result<Option<Type>, E>;
36    /// Attempt to resolve a symbol path, given by a `TypeExpr::Ref`, to an item
37    fn resolve_type_ref(&self, ty: Span<&Path>) -> Result<SymbolResolution, E>;
38    /// Resolve a [TypeExpr] to a concrete [Type]
39    fn resolve(&self, ty: &TypeExpr) -> Result<Option<Type>, E> {
40        ty.resolve_type(self)
41    }
42}
43
44// TYPE DECLARATION
45// ================================================================================================
46
47/// An abstraction over the different types of type declarations allowed in Miden Assembly
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub enum TypeDecl {
50    /// A named type, i.e. a type alias
51    Alias(TypeAlias),
52    /// A C-like enumeration type with associated constants
53    Enum(EnumType),
54}
55
56impl TypeDecl {
57    /// Adds documentation to this type alias
58    pub fn with_docs(self, docs: Option<Span<String>>) -> Self {
59        match self {
60            Self::Alias(ty) => Self::Alias(ty.with_docs(docs)),
61            Self::Enum(ty) => Self::Enum(ty.with_docs(docs)),
62        }
63    }
64
65    /// Get the name assigned to this type declaration
66    pub fn name(&self) -> &Ident {
67        match self {
68            Self::Alias(ty) => &ty.name,
69            Self::Enum(ty) => &ty.name,
70        }
71    }
72
73    /// Get the visibility of this type declaration
74    pub const fn visibility(&self) -> Visibility {
75        match self {
76            Self::Alias(ty) => ty.visibility,
77            Self::Enum(ty) => ty.visibility,
78        }
79    }
80
81    /// Get the documentation of this enum type
82    pub fn docs(&self) -> Option<Span<&str>> {
83        match self {
84            Self::Alias(ty) => ty.docs(),
85            Self::Enum(ty) => ty.docs(),
86        }
87    }
88
89    /// Get the type expression associated with this declaration
90    pub fn ty(&self) -> TypeExpr {
91        match self {
92            Self::Alias(ty) => ty.ty.clone(),
93            Self::Enum(ty) => TypeExpr::Primitive(Span::new(ty.span, ty.ty.clone())),
94        }
95    }
96}
97
98impl Spanned for TypeDecl {
99    fn span(&self) -> SourceSpan {
100        match self {
101            Self::Alias(spanned) => spanned.span,
102            Self::Enum(spanned) => spanned.span,
103        }
104    }
105}
106
107impl From<TypeAlias> for TypeDecl {
108    fn from(value: TypeAlias) -> Self {
109        Self::Alias(value)
110    }
111}
112
113impl From<EnumType> for TypeDecl {
114    fn from(value: EnumType) -> Self {
115        Self::Enum(value)
116    }
117}
118
119impl crate::prettier::PrettyPrint for TypeDecl {
120    fn render(&self) -> crate::prettier::Document {
121        match self {
122            Self::Alias(ty) => ty.render(),
123            Self::Enum(ty) => ty.render(),
124        }
125    }
126}
127
128// FUNCTION TYPE
129// ================================================================================================
130
131/// A procedure type signature
132#[derive(Debug, Clone)]
133pub struct FunctionType {
134    pub span: SourceSpan,
135    pub cc: types::CallConv,
136    pub args: Vec<TypeExpr>,
137    pub results: Vec<TypeExpr>,
138}
139
140impl Eq for FunctionType {}
141
142impl PartialEq for FunctionType {
143    fn eq(&self, other: &Self) -> bool {
144        self.cc == other.cc && self.args == other.args && self.results == other.results
145    }
146}
147
148impl core::hash::Hash for FunctionType {
149    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
150        self.cc.hash(state);
151        self.args.hash(state);
152        self.results.hash(state);
153    }
154}
155
156impl Spanned for FunctionType {
157    fn span(&self) -> SourceSpan {
158        self.span
159    }
160}
161
162impl FunctionType {
163    pub fn new(cc: types::CallConv, args: Vec<TypeExpr>, results: Vec<TypeExpr>) -> Self {
164        Self {
165            span: SourceSpan::UNKNOWN,
166            cc,
167            args,
168            results,
169        }
170    }
171
172    /// Override the default source span
173    #[inline]
174    pub fn with_span(mut self, span: SourceSpan) -> Self {
175        self.span = span;
176        self
177    }
178}
179
180impl crate::prettier::PrettyPrint for FunctionType {
181    fn render(&self) -> crate::prettier::Document {
182        use crate::prettier::*;
183
184        let singleline_args = self
185            .args
186            .iter()
187            .map(|arg| arg.render())
188            .reduce(|acc, arg| acc + const_text(", ") + arg)
189            .unwrap_or(Document::Empty);
190        let multiline_args = indent(
191            4,
192            nl() + self
193                .args
194                .iter()
195                .map(|arg| arg.render())
196                .reduce(|acc, arg| acc + const_text(",") + nl() + arg)
197                .unwrap_or(Document::Empty),
198        ) + nl();
199        let args = singleline_args | multiline_args;
200        let args = const_text("(") + args + const_text(")");
201
202        match self.results.len() {
203            0 => args,
204            1 => args + const_text(" -> ") + self.results[0].render(),
205            _ => {
206                let results = self
207                    .results
208                    .iter()
209                    .map(|r| r.render())
210                    .reduce(|acc, r| acc + const_text(", ") + r)
211                    .unwrap_or(Document::Empty);
212                args + const_text(" -> ") + const_text("(") + results + const_text(")")
213            },
214        }
215    }
216}
217
218// TYPE EXPRESSION
219// ================================================================================================
220
221/// A syntax-level type expression (i.e. primitive type, reference to nominal type, etc.)
222#[derive(Debug, Clone, Eq, PartialEq, Hash)]
223pub enum TypeExpr {
224    /// A primitive integral type, e.g. `i1`, `u16`
225    Primitive(Span<Type>),
226    /// A pointer type expression, e.g. `*u8`
227    Ptr(PointerType),
228    /// An array type expression, e.g. `[u8; 32]`
229    Array(ArrayType),
230    /// A struct type expression, e.g. `struct { a: u32 }`
231    Struct(StructType),
232    /// A reference to a type aliased by name, e.g. `Foo`
233    Ref(Span<Arc<Path>>),
234}
235
236impl TypeExpr {
237    /// Get any references to other types present in this expression
238    pub fn references(&self) -> Vec<Span<Arc<Path>>> {
239        use alloc::collections::BTreeSet;
240
241        let mut worklist = smallvec::SmallVec::<[_; 4]>::from_slice(&[self]);
242        let mut references = BTreeSet::new();
243
244        while let Some(ty) = worklist.pop() {
245            match ty {
246                Self::Primitive(_) => continue,
247                Self::Ptr(ty) => {
248                    worklist.push(&ty.pointee);
249                },
250                Self::Array(ty) => {
251                    worklist.push(&ty.elem);
252                },
253                Self::Struct(ty) => {
254                    for field in ty.fields.iter() {
255                        worklist.push(&field.ty);
256                    }
257                },
258                Self::Ref(ty) => {
259                    references.insert(ty.clone());
260                },
261            }
262        }
263
264        references.into_iter().collect()
265    }
266
267    /// Resolve this type expression to a concrete type, using `resolver`
268    pub fn resolve_type<E, R>(&self, resolver: &R) -> Result<Option<Type>, E>
269    where
270        R: ?Sized + TypeResolver<E>,
271    {
272        self.resolve_type_with_depth(resolver, 0)
273    }
274
275    fn resolve_type_with_depth<E, R>(&self, resolver: &R, depth: usize) -> Result<Option<Type>, E>
276    where
277        R: ?Sized + TypeResolver<E>,
278    {
279        if depth > MAX_TYPE_EXPR_NESTING {
280            let source_manager = resolver.source_manager();
281            return Err(resolver.resolve_local_failed(
282                SymbolResolutionError::type_expression_depth_exceeded(
283                    self.span(),
284                    MAX_TYPE_EXPR_NESTING,
285                    source_manager.as_ref(),
286                ),
287            ));
288        }
289
290        match self {
291            TypeExpr::Ref(path) => {
292                let mut current_path = path.clone();
293                loop {
294                    match resolver.resolve_type_ref(current_path.as_deref())? {
295                        SymbolResolution::Local(item) => {
296                            return resolver.get_local_type(current_path.span(), item.into_inner());
297                        },
298                        SymbolResolution::External(path) => {
299                            // We don't have a definition for this type yet
300                            if path == current_path {
301                                break Ok(None);
302                            }
303                            current_path = path;
304                        },
305                        SymbolResolution::Exact { gid, .. } => {
306                            return resolver.get_type(current_path.span(), gid).map(Some);
307                        },
308                        SymbolResolution::Module { path: module_path, .. } => {
309                            break Err(resolver.resolve_local_failed(
310                                SymbolResolutionError::invalid_symbol_type(
311                                    path.span(),
312                                    "type",
313                                    module_path.span(),
314                                    &resolver.source_manager(),
315                                ),
316                            ));
317                        },
318                        SymbolResolution::MastRoot(item) => {
319                            break Err(resolver.resolve_local_failed(
320                                SymbolResolutionError::invalid_symbol_type(
321                                    path.span(),
322                                    "type",
323                                    item.span(),
324                                    &resolver.source_manager(),
325                                ),
326                            ));
327                        },
328                    }
329                }
330            },
331            TypeExpr::Primitive(t) => Ok(Some(t.inner().clone())),
332            TypeExpr::Array(t) => Ok(t
333                .elem
334                .resolve_type_with_depth(resolver, depth + 1)?
335                .map(|elem| types::Type::Array(Arc::new(types::ArrayType::new(elem, t.arity))))),
336            TypeExpr::Ptr(ty) => Ok(ty
337                .pointee
338                .resolve_type_with_depth(resolver, depth + 1)?
339                .map(|pointee| types::Type::Ptr(Arc::new(types::PointerType::new(pointee))))),
340            TypeExpr::Struct(t) => {
341                let mut fields = Vec::with_capacity(t.fields.len());
342                for field in t.fields.iter() {
343                    let field_ty = field.ty.resolve_type_with_depth(resolver, depth + 1)?;
344                    if let Some(field_ty) = field_ty {
345                        fields.push(field_ty);
346                    } else {
347                        return Ok(None);
348                    }
349                }
350                Ok(Some(Type::Struct(Arc::new(types::StructType::new(fields)))))
351            },
352        }
353    }
354}
355
356impl From<Type> for TypeExpr {
357    fn from(ty: Type) -> Self {
358        match ty {
359            Type::Array(t) => Self::Array(ArrayType::new(t.element_type().clone().into(), t.len())),
360            Type::Struct(t) => {
361                Self::Struct(StructType::new(t.fields().iter().enumerate().map(|(i, ft)| {
362                    let name = Ident::new(format!("field{i}")).unwrap();
363                    StructField {
364                        span: SourceSpan::UNKNOWN,
365                        name,
366                        ty: ft.ty.clone().into(),
367                    }
368                })))
369            },
370            Type::Ptr(t) => Self::Ptr((*t).clone().into()),
371            Type::Function(_) => {
372                Self::Ptr(PointerType::new(TypeExpr::Primitive(Span::unknown(Type::Felt))))
373            },
374            Type::List(t) => Self::Ptr(
375                PointerType::new((*t).clone().into()).with_address_space(AddressSpace::Byte),
376            ),
377            Type::I128 | Type::U128 => Self::Array(ArrayType::new(Type::U32.into(), 4)),
378            Type::I64 | Type::U64 => Self::Array(ArrayType::new(Type::U32.into(), 2)),
379            Type::Unknown | Type::Never | Type::F64 => panic!("unrepresentable type value: {ty}"),
380            ty => Self::Primitive(Span::unknown(ty)),
381        }
382    }
383}
384
385impl Spanned for TypeExpr {
386    fn span(&self) -> SourceSpan {
387        match self {
388            Self::Primitive(spanned) => spanned.span(),
389            Self::Ptr(spanned) => spanned.span(),
390            Self::Array(spanned) => spanned.span(),
391            Self::Struct(spanned) => spanned.span(),
392            Self::Ref(spanned) => spanned.span(),
393        }
394    }
395}
396
397impl crate::prettier::PrettyPrint for TypeExpr {
398    fn render(&self) -> crate::prettier::Document {
399        use crate::prettier::*;
400
401        match self {
402            Self::Primitive(ty) => display(ty),
403            Self::Ptr(ty) => ty.render(),
404            Self::Array(ty) => ty.render(),
405            Self::Struct(ty) => ty.render(),
406            Self::Ref(ty) => display(ty),
407        }
408    }
409}
410
411// POINTER TYPE
412// ================================================================================================
413
414#[derive(Debug, Clone)]
415pub struct PointerType {
416    pub span: SourceSpan,
417    pub pointee: Box<TypeExpr>,
418    addrspace: Option<AddressSpace>,
419}
420
421impl From<types::PointerType> for PointerType {
422    fn from(ty: types::PointerType) -> Self {
423        let types::PointerType { addrspace, pointee } = ty;
424        let pointee = Box::new(TypeExpr::from(pointee));
425        Self {
426            span: SourceSpan::UNKNOWN,
427            pointee,
428            addrspace: Some(addrspace),
429        }
430    }
431}
432
433impl Eq for PointerType {}
434
435impl PartialEq for PointerType {
436    fn eq(&self, other: &Self) -> bool {
437        self.address_space() == other.address_space() && self.pointee == other.pointee
438    }
439}
440
441impl core::hash::Hash for PointerType {
442    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
443        self.pointee.hash(state);
444        self.address_space().hash(state);
445    }
446}
447
448impl Spanned for PointerType {
449    fn span(&self) -> SourceSpan {
450        self.span
451    }
452}
453
454impl PointerType {
455    pub fn new(pointee: TypeExpr) -> Self {
456        Self {
457            span: SourceSpan::UNKNOWN,
458            pointee: Box::new(pointee),
459            addrspace: None,
460        }
461    }
462
463    /// Override the default source span
464    #[inline]
465    pub fn with_span(mut self, span: SourceSpan) -> Self {
466        self.span = span;
467        self
468    }
469
470    /// Override the default address space
471    #[inline]
472    pub fn with_address_space(mut self, addrspace: AddressSpace) -> Self {
473        self.addrspace = Some(addrspace);
474        self
475    }
476
477    /// Get the address space of this pointer type
478    #[inline]
479    pub fn address_space(&self) -> AddressSpace {
480        self.addrspace.unwrap_or(AddressSpace::Element)
481    }
482}
483
484impl crate::prettier::PrettyPrint for PointerType {
485    fn render(&self) -> crate::prettier::Document {
486        use crate::prettier::*;
487
488        let doc = const_text("ptr<") + self.pointee.render();
489        if let Some(addrspace) = self.addrspace.as_ref() {
490            doc + const_text(", ") + text(format!("addrspace({})", addrspace)) + const_text(">")
491        } else {
492            doc + const_text(">")
493        }
494    }
495}
496
497// ARRAY TYPE
498// ================================================================================================
499
500#[derive(Debug, Clone)]
501pub struct ArrayType {
502    pub span: SourceSpan,
503    pub elem: Box<TypeExpr>,
504    pub arity: usize,
505}
506
507impl Eq for ArrayType {}
508
509impl PartialEq for ArrayType {
510    fn eq(&self, other: &Self) -> bool {
511        self.arity == other.arity && self.elem == other.elem
512    }
513}
514
515impl core::hash::Hash for ArrayType {
516    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
517        self.elem.hash(state);
518        self.arity.hash(state);
519    }
520}
521
522impl Spanned for ArrayType {
523    fn span(&self) -> SourceSpan {
524        self.span
525    }
526}
527
528impl ArrayType {
529    pub fn new(elem: TypeExpr, arity: usize) -> Self {
530        Self {
531            span: SourceSpan::UNKNOWN,
532            elem: Box::new(elem),
533            arity,
534        }
535    }
536
537    /// Override the default source span
538    #[inline]
539    pub fn with_span(mut self, span: SourceSpan) -> Self {
540        self.span = span;
541        self
542    }
543}
544
545impl crate::prettier::PrettyPrint for ArrayType {
546    fn render(&self) -> crate::prettier::Document {
547        use crate::prettier::*;
548
549        const_text("[")
550            + self.elem.render()
551            + const_text("; ")
552            + display(self.arity)
553            + const_text("]")
554    }
555}
556
557// STRUCT TYPE
558// ================================================================================================
559
560#[derive(Debug, Clone)]
561pub struct StructType {
562    pub span: SourceSpan,
563    pub repr: Span<TypeRepr>,
564    pub fields: Vec<StructField>,
565}
566
567impl Eq for StructType {}
568
569impl PartialEq for StructType {
570    fn eq(&self, other: &Self) -> bool {
571        self.repr == other.repr && self.fields == other.fields
572    }
573}
574
575impl core::hash::Hash for StructType {
576    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
577        self.repr.hash(state);
578        self.fields.hash(state);
579    }
580}
581
582impl Spanned for StructType {
583    fn span(&self) -> SourceSpan {
584        self.span
585    }
586}
587
588impl StructType {
589    pub fn new(fields: impl IntoIterator<Item = StructField>) -> Self {
590        Self {
591            span: SourceSpan::UNKNOWN,
592            repr: Span::unknown(TypeRepr::Default),
593            fields: fields.into_iter().collect(),
594        }
595    }
596
597    /// Override the default struct representation
598    #[inline]
599    pub fn with_repr(mut self, repr: Span<TypeRepr>) -> Self {
600        self.repr = repr;
601        self
602    }
603
604    /// Override the default source span
605    #[inline]
606    pub fn with_span(mut self, span: SourceSpan) -> Self {
607        self.span = span;
608        self
609    }
610}
611
612impl crate::prettier::PrettyPrint for StructType {
613    fn render(&self) -> crate::prettier::Document {
614        use crate::prettier::*;
615
616        let repr = match &*self.repr {
617            TypeRepr::Default => Document::Empty,
618            TypeRepr::BigEndian => const_text("@bigendian "),
619            repr @ (TypeRepr::Align(_) | TypeRepr::Packed(_) | TypeRepr::Transparent) => {
620                text(format!("@{repr} "))
621            },
622        };
623
624        let singleline_body = self
625            .fields
626            .iter()
627            .map(|field| field.render())
628            .reduce(|acc, field| acc + const_text(", ") + field)
629            .unwrap_or(Document::Empty);
630        let multiline_body = indent(
631            4,
632            nl() + self
633                .fields
634                .iter()
635                .map(|field| field.render())
636                .reduce(|acc, field| acc + const_text(",") + nl() + field)
637                .unwrap_or(Document::Empty),
638        ) + nl();
639        let body = singleline_body | multiline_body;
640
641        repr + const_text("struct") + const_text(" { ") + body + const_text(" }")
642    }
643}
644
645// STRUCT FIELD
646// ================================================================================================
647
648#[derive(Debug, Clone)]
649pub struct StructField {
650    pub span: SourceSpan,
651    pub name: Ident,
652    pub ty: TypeExpr,
653}
654
655impl Eq for StructField {}
656
657impl PartialEq for StructField {
658    fn eq(&self, other: &Self) -> bool {
659        self.name == other.name && self.ty == other.ty
660    }
661}
662
663impl core::hash::Hash for StructField {
664    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
665        self.name.hash(state);
666        self.ty.hash(state);
667    }
668}
669
670impl Spanned for StructField {
671    fn span(&self) -> SourceSpan {
672        self.span
673    }
674}
675
676impl crate::prettier::PrettyPrint for StructField {
677    fn render(&self) -> crate::prettier::Document {
678        use crate::prettier::*;
679
680        display(&self.name) + const_text(": ") + self.ty.render()
681    }
682}
683
684// TYPE ALIAS
685// ================================================================================================
686
687/// A [TypeAlias] represents a named [Type].
688///
689/// Type aliases correspond to type declarations in Miden Assembly source files. They are called
690/// aliases, rather than declarations, as the type system for Miden Assembly is structural, rather
691/// than nominal, and so two aliases with the same underlying type are considered equivalent.
692#[derive(Debug, Clone)]
693pub struct TypeAlias {
694    span: SourceSpan,
695    /// The documentation string attached to this definition.
696    docs: Option<DocString>,
697    /// The visibility of this type alias
698    pub visibility: Visibility,
699    /// The name of this type alias
700    pub name: Ident,
701    /// The concrete underlying type
702    pub ty: TypeExpr,
703}
704
705impl TypeAlias {
706    /// Create a new type alias from a name and type
707    pub fn new(visibility: Visibility, name: Ident, ty: TypeExpr) -> Self {
708        Self {
709            span: name.span(),
710            docs: None,
711            visibility,
712            name,
713            ty,
714        }
715    }
716
717    /// Adds documentation to this type alias
718    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
719        self.docs = docs.map(DocString::new);
720        self
721    }
722
723    /// Override the default source span
724    #[inline]
725    pub fn with_span(mut self, span: SourceSpan) -> Self {
726        self.span = span;
727        self
728    }
729
730    /// Set the source span
731    #[inline]
732    pub fn set_span(&mut self, span: SourceSpan) {
733        self.span = span;
734    }
735
736    /// Returns the documentation associated with this item.
737    pub fn docs(&self) -> Option<Span<&str>> {
738        self.docs.as_ref().map(|docstring| docstring.as_spanned_str())
739    }
740
741    /// Get the name of this type alias
742    pub fn name(&self) -> &Ident {
743        &self.name
744    }
745
746    /// Get the visibility of this type alias
747    #[inline]
748    pub const fn visibility(&self) -> Visibility {
749        self.visibility
750    }
751}
752
753impl Eq for TypeAlias {}
754
755impl PartialEq for TypeAlias {
756    fn eq(&self, other: &Self) -> bool {
757        self.visibility == other.visibility
758            && self.name == other.name
759            && self.docs == other.docs
760            && self.ty == other.ty
761    }
762}
763
764impl core::hash::Hash for TypeAlias {
765    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
766        let Self { span: _, docs, visibility, name, ty } = self;
767        docs.hash(state);
768        visibility.hash(state);
769        name.hash(state);
770        ty.hash(state);
771    }
772}
773
774impl Spanned for TypeAlias {
775    fn span(&self) -> SourceSpan {
776        self.span
777    }
778}
779
780impl crate::prettier::PrettyPrint for TypeAlias {
781    fn render(&self) -> crate::prettier::Document {
782        use crate::prettier::*;
783
784        let mut doc = self
785            .docs
786            .as_ref()
787            .map(|docstring| docstring.render())
788            .unwrap_or(Document::Empty);
789
790        if self.visibility.is_public() {
791            doc += display(self.visibility) + const_text(" ");
792        }
793
794        doc + const_text("type")
795            + const_text(" ")
796            + display(&self.name)
797            + const_text(" = ")
798            + self.ty.render()
799    }
800}
801
802// ENUM TYPE
803// ================================================================================================
804
805/// A combined type alias and constant declaration corresponding to a C-like enumeration.
806///
807/// C-style enumerations are effectively a type alias for an integer type with a limited set of
808/// valid values with associated names (referred to as _variants_ of the enum type).
809///
810/// In Miden Assembly, these provide a means for a procedure to declare that it expects an argument
811/// of the underlying integral type, but that values other than those of the declared variants are
812/// illegal/invalid. Currently, these are unchecked, and are only used to convey semantic
813/// information. In the future, we may perform static analysis to try and identify invalid instances
814/// of the enumeration when derived from a constant.
815#[derive(Debug, Clone)]
816pub struct EnumType {
817    span: SourceSpan,
818    /// The documentation string attached to this definition.
819    docs: Option<DocString>,
820    /// The visibility of this enum type
821    visibility: Visibility,
822    /// The enum name
823    name: Ident,
824    /// The type of the discriminant value used for this enum's variants
825    ///
826    /// NOTE: The type must be an integral value, and this is enforced by [`Self::new`].
827    ty: Type,
828    /// The enum variants
829    variants: Vec<Variant>,
830}
831
832impl EnumType {
833    /// Construct a new enum type with the given name and variants
834    ///
835    /// The caller is assumed to have already validated that `ty` is an integral type, and this
836    /// function will assert that this is the case.
837    pub fn new(
838        visibility: Visibility,
839        name: Ident,
840        ty: Type,
841        variants: impl IntoIterator<Item = Variant>,
842    ) -> Self {
843        assert!(ty.is_integer(), "only integer types are allowed in enum type definitions");
844        Self {
845            span: name.span(),
846            docs: None,
847            visibility,
848            name,
849            ty,
850            variants: Vec::from_iter(variants),
851        }
852    }
853
854    /// Adds documentation to this enum declaration.
855    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
856        self.docs = docs.map(DocString::new);
857        self
858    }
859
860    /// Override the default source span
861    pub fn with_span(mut self, span: SourceSpan) -> Self {
862        self.span = span;
863        self
864    }
865
866    /// Set the source span
867    pub fn set_span(&mut self, span: SourceSpan) {
868        self.span = span;
869    }
870
871    /// Get the name of this enum type
872    pub fn name(&self) -> &Ident {
873        &self.name
874    }
875
876    /// Get the visibility of this enum type
877    pub const fn visibility(&self) -> Visibility {
878        self.visibility
879    }
880
881    /// Returns the documentation associated with this item.
882    pub fn docs(&self) -> Option<Span<&str>> {
883        self.docs.as_ref().map(|docstring| docstring.as_spanned_str())
884    }
885
886    /// Get the concrete type of this enum's variants
887    pub fn ty(&self) -> &Type {
888        &self.ty
889    }
890
891    /// Get the variants of this enum type
892    pub fn variants(&self) -> &[Variant] {
893        &self.variants
894    }
895
896    /// Get the variants of this enum type, mutably
897    pub fn variants_mut(&mut self) -> &mut Vec<Variant> {
898        &mut self.variants
899    }
900
901    /// Split this definition into its type alias and variant parts
902    pub fn into_parts(self) -> (TypeAlias, Vec<Variant>) {
903        let Self {
904            span,
905            docs,
906            visibility,
907            name,
908            ty,
909            variants,
910        } = self;
911        let alias = TypeAlias {
912            span,
913            docs,
914            visibility,
915            name,
916            ty: TypeExpr::Primitive(Span::new(span, ty)),
917        };
918        (alias, variants)
919    }
920}
921
922impl Spanned for EnumType {
923    fn span(&self) -> SourceSpan {
924        self.span
925    }
926}
927
928impl Eq for EnumType {}
929
930impl PartialEq for EnumType {
931    fn eq(&self, other: &Self) -> bool {
932        self.visibility == other.visibility
933            && self.name == other.name
934            && self.docs == other.docs
935            && self.ty == other.ty
936            && self.variants == other.variants
937    }
938}
939
940impl core::hash::Hash for EnumType {
941    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
942        let Self {
943            span: _,
944            docs,
945            visibility,
946            name,
947            ty,
948            variants,
949        } = self;
950        docs.hash(state);
951        visibility.hash(state);
952        name.hash(state);
953        ty.hash(state);
954        variants.hash(state);
955    }
956}
957
958impl crate::prettier::PrettyPrint for EnumType {
959    fn render(&self) -> crate::prettier::Document {
960        use crate::prettier::*;
961
962        let mut doc = self
963            .docs
964            .as_ref()
965            .map(|docstring| docstring.render())
966            .unwrap_or(Document::Empty);
967
968        let variants = self
969            .variants
970            .iter()
971            .map(|v| v.render())
972            .reduce(|acc, v| acc + const_text(",") + nl() + v)
973            .unwrap_or(Document::Empty);
974
975        if self.visibility.is_public() {
976            doc += display(self.visibility) + const_text(" ");
977        }
978
979        doc + const_text("enum")
980            + const_text(" ")
981            + display(&self.name)
982            + const_text(" : ")
983            + self.ty.render()
984            + const_text(" {")
985            + nl()
986            + variants
987            + const_text("}")
988    }
989}
990
991// ENUM VARIANT
992// ================================================================================================
993
/// A variant of an [EnumType].
///
/// Equality and hashing of a variant consider only its docs, name, and discriminant; the span
/// is ignored.
///
/// See the [EnumType] docs for more information.
#[derive(Debug, Clone)]
pub struct Variant {
    /// The source span of this variant. [Variant::new] initializes it from the span of the
    /// variant name; it can be overridden with [Variant::with_span].
    pub span: SourceSpan,
    /// The documentation string attached to the constant derived from this variant.
    pub docs: Option<DocString>,
    /// The name of this enum variant
    pub name: Ident,
    /// The discriminant value associated with this variant.
    ///
    /// This must be folded to an integer constant before it is validated with
    /// [Variant::assert_instance_of].
    pub discriminant: ConstantExpr,
}
1007
1008impl Variant {
1009    /// Construct a new variant of an [EnumType], with the given name and discriminant value.
1010    pub fn new(name: Ident, discriminant: ConstantExpr) -> Self {
1011        Self {
1012            span: name.span(),
1013            docs: None,
1014            name,
1015            discriminant,
1016        }
1017    }
1018
1019    /// Override the span for this variant
1020    pub fn with_span(mut self, span: SourceSpan) -> Self {
1021        self.span = span;
1022        self
1023    }
1024
1025    /// Adds documentation to this variant
1026    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
1027        self.docs = docs.map(DocString::new);
1028        self
1029    }
1030
1031    /// Used to validate that this variant's discriminant value is an instance of `ty`,
1032    /// which must be a type valid for use as the underlying representation for an enum, i.e. an
1033    /// integer type up to 64 bits in size.
1034    ///
1035    /// It is expected that the discriminant expression has been folded to an integer value by the
1036    /// time this is called. If the discriminant has not been fully folded, then an error will be
1037    /// returned.
1038    pub fn assert_instance_of(&self, ty: &Type) -> Result<(), crate::SemanticAnalysisError> {
1039        use crate::{FIELD_MODULUS, SemanticAnalysisError};
1040
1041        let value = match &self.discriminant {
1042            ConstantExpr::Int(value) => value.as_int(),
1043            _ => {
1044                return Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1045                    span: self.discriminant.span(),
1046                    repr: ty.clone(),
1047                });
1048            },
1049        };
1050
1051        match ty {
1052            Type::Felt if value >= FIELD_MODULUS => {
1053                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1054                    span: self.discriminant.span(),
1055                    repr: ty.clone(),
1056                })
1057            },
1058            // IntValue is represented as an unsigned integer, so negative discriminants
1059            // are rejected during constant evaluation.
1060            Type::Felt => Ok(()),
1061            Type::I1 if value > 1 => Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1062                span: self.discriminant.span(),
1063                repr: ty.clone(),
1064            }),
1065            Type::I1 => Ok(()),
1066            Type::I8 | Type::U8 if value > u8::MAX as u64 => {
1067                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1068                    span: self.discriminant.span(),
1069                    repr: ty.clone(),
1070                })
1071            },
1072            Type::I8 | Type::U8 => Ok(()),
1073            Type::I16 | Type::U16 if value > u16::MAX as u64 => {
1074                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1075                    span: self.discriminant.span(),
1076                    repr: ty.clone(),
1077                })
1078            },
1079            Type::I16 | Type::U16 => Ok(()),
1080            Type::I32 | Type::U32 if value > u32::MAX as u64 => {
1081                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1082                    span: self.discriminant.span(),
1083                    repr: ty.clone(),
1084                })
1085            },
1086            Type::I32 | Type::U32 => Ok(()),
1087            Type::I64 | Type::U64 if value >= FIELD_MODULUS => {
1088                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1089                    span: self.discriminant.span(),
1090                    repr: ty.clone(),
1091                })
1092            },
1093            _ => Err(SemanticAnalysisError::InvalidEnumRepr { span: self.span }),
1094        }
1095    }
1096}
1097
1098impl Spanned for Variant {
1099    fn span(&self) -> SourceSpan {
1100        self.span
1101    }
1102}
1103
1104impl Eq for Variant {}
1105
1106impl PartialEq for Variant {
1107    fn eq(&self, other: &Self) -> bool {
1108        self.name == other.name
1109            && self.discriminant == other.discriminant
1110            && self.docs == other.docs
1111    }
1112}
1113
1114impl core::hash::Hash for Variant {
1115    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1116        let Self { span: _, docs, name, discriminant } = self;
1117        docs.hash(state);
1118        name.hash(state);
1119        discriminant.hash(state);
1120    }
1121}
1122
1123impl crate::prettier::PrettyPrint for Variant {
1124    fn render(&self) -> crate::prettier::Document {
1125        use crate::prettier::*;
1126
1127        let doc = self
1128            .docs
1129            .as_ref()
1130            .map(|docstring| docstring.render())
1131            .unwrap_or(Document::Empty);
1132
1133        doc + display(&self.name) + const_text(" = ") + self.discriminant.render()
1134    }
1135}
1136
#[cfg(test)]
mod tests {
    use alloc::sync::Arc;
    use core::str::FromStr;

    use miden_debug_types::DefaultSourceManager;

    use super::*;

    /// A [TypeResolver] that knows about no items: global and path lookups fail as undefined,
    /// and local lookups find nothing. This is enough to resolve purely structural type
    /// expressions.
    struct DummyResolver {
        source_manager: Arc<dyn SourceManager>,
    }

    impl DummyResolver {
        fn new() -> Self {
            let source_manager: Arc<dyn SourceManager> =
                Arc::new(DefaultSourceManager::default());
            Self { source_manager }
        }
    }

    impl TypeResolver<SymbolResolutionError> for DummyResolver {
        fn source_manager(&self) -> Arc<dyn SourceManager> {
            Arc::clone(&self.source_manager)
        }

        fn resolve_local_failed(&self, err: SymbolResolutionError) -> SymbolResolutionError {
            err
        }

        fn get_type(
            &self,
            context: SourceSpan,
            _gid: GlobalItemIndex,
        ) -> Result<Type, SymbolResolutionError> {
            Err(SymbolResolutionError::undefined(context, &*self.source_manager))
        }

        fn get_local_type(
            &self,
            _context: SourceSpan,
            _id: ItemIndex,
        ) -> Result<Option<Type>, SymbolResolutionError> {
            Ok(None)
        }

        fn resolve_type_ref(
            &self,
            ty: Span<&Path>,
        ) -> Result<SymbolResolution, SymbolResolutionError> {
            Err(SymbolResolutionError::undefined(ty.span(), &*self.source_manager))
        }
    }

    /// Builds a `felt` wrapped in `depth` layers of type constructors, cycling through pointer,
    /// single-element array, and single-field struct wrappers (innermost first).
    fn nested_type_expr(depth: usize) -> TypeExpr {
        let leaf = TypeExpr::Primitive(Span::unknown(Type::Felt));
        (0..depth).fold(leaf, |inner, level| match level % 3 {
            0 => TypeExpr::Ptr(PointerType::new(inner)),
            1 => TypeExpr::Array(ArrayType::new(inner, 1)),
            _ => TypeExpr::Struct(StructType::new([StructField {
                span: SourceSpan::UNKNOWN,
                name: Ident::from_str("field").expect("valid ident"),
                ty: inner,
            }])),
        })
    }

    #[test]
    fn type_expr_depth_boundary() {
        let resolver = DummyResolver::new();

        // Exactly at the nesting limit: resolution must succeed.
        let at_limit = nested_type_expr(MAX_TYPE_EXPR_NESTING);
        assert!(at_limit.resolve_type(&resolver).is_ok());

        // One level past the limit: resolution must fail, reporting the configured maximum.
        let past_limit = nested_type_expr(MAX_TYPE_EXPR_NESTING + 1);
        let err = past_limit
            .resolve_type(&resolver)
            .expect_err("expected depth-exceeded error");
        match err {
            SymbolResolutionError::TypeExpressionDepthExceeded { max_depth, .. } => {
                assert_eq!(max_depth, MAX_TYPE_EXPR_NESTING);
            },
            _ => panic!("expected a TypeExpressionDepthExceeded error"),
        }
    }
}