Skip to main content

miden_assembly_syntax/ast/
type.rs

1use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
2
3use miden_debug_types::{SourceManager, SourceSpan, Span, Spanned};
4pub use midenc_hir_type as types;
5use midenc_hir_type::{AddressSpace, Type, TypeRepr};
6
7use super::{
8    ConstantExpr, DocString, GlobalItemIndex, Ident, ItemIndex, Path, SymbolResolution,
9    SymbolResolutionError, Visibility,
10};
11
/// Upper bound on how deeply type expressions may nest while they are being resolved.
///
/// Resolving a [TypeExpr] recurses once per level of nesting (pointee, array element, struct
/// field), so this cap guards against stack exhaustion on adversarially deep input. The value
/// (256) is far above the nesting found in realistic programs.
const MAX_TYPE_EXPR_NESTING: usize = 1 << 8;
17
18/// Abstracts over resolving an item to a concrete [Type], using one of:
19///
20/// * A [GlobalItemIndex]
21/// * An [ItemIndex]
22/// * A [Path]
23/// * A [TypeExpr]
24///
25/// Since type resolution happens in two different contexts during assembly, this abstraction allows
26/// us to share more of the resolution logic in both places.
27///
28/// NOTE: Most methods of this trait take a mutable reference to the resolver, so that the resolver
29/// can mutate its own state as necessary during resolution (e.g. to manage a cache, or other side
30/// table-like data structures).
31pub trait TypeResolver<E> {
32    fn source_manager(&self) -> Arc<dyn SourceManager>;
33    /// Should be called by consumers of this resolver to convert a [SymbolResolutionError] to the
34    /// error type used by the [TypeResolver] implementation.
35    fn resolve_local_failed(&self, err: SymbolResolutionError) -> E;
36    /// Get the [Type] corresponding to the item given by `gid`
37    fn get_type(&mut self, context: SourceSpan, gid: GlobalItemIndex) -> Result<Type, E>;
38    /// Get the [Type] corresponding to the item in the current module given by `id`
39    fn get_local_type(&mut self, context: SourceSpan, id: ItemIndex) -> Result<Option<Type>, E>;
40    /// Attempt to resolve a symbol path, given by a `TypeExpr::Ref`, to an item
41    fn resolve_type_ref(&mut self, ty: Span<&Path>) -> Result<SymbolResolution, E>;
42    /// Resolve a [TypeExpr] to a concrete [Type]
43    fn resolve(&mut self, ty: &TypeExpr) -> Result<Option<Type>, E> {
44        ty.resolve_type(self)
45    }
46}
47
48// TYPE DECLARATION
49// ================================================================================================
50
51/// An abstraction over the different types of type declarations allowed in Miden Assembly
52#[derive(Debug, Clone, PartialEq, Eq)]
53pub enum TypeDecl {
54    /// A named type, i.e. a type alias
55    Alias(TypeAlias),
56    /// A C-like enumeration type with associated constants
57    Enum(EnumType),
58}
59
60impl TypeDecl {
61    /// Adds documentation to this type alias
62    pub fn with_docs(self, docs: Option<Span<String>>) -> Self {
63        match self {
64            Self::Alias(ty) => Self::Alias(ty.with_docs(docs)),
65            Self::Enum(ty) => Self::Enum(ty.with_docs(docs)),
66        }
67    }
68
69    /// Get the name assigned to this type declaration
70    pub fn name(&self) -> &Ident {
71        match self {
72            Self::Alias(ty) => &ty.name,
73            Self::Enum(ty) => &ty.name,
74        }
75    }
76
77    /// Get the visibility of this type declaration
78    pub const fn visibility(&self) -> Visibility {
79        match self {
80            Self::Alias(ty) => ty.visibility,
81            Self::Enum(ty) => ty.visibility,
82        }
83    }
84
85    /// Get the documentation of this enum type
86    pub fn docs(&self) -> Option<Span<&str>> {
87        match self {
88            Self::Alias(ty) => ty.docs(),
89            Self::Enum(ty) => ty.docs(),
90        }
91    }
92
93    /// Get the type expression associated with this declaration
94    pub fn ty(&self) -> TypeExpr {
95        match self {
96            Self::Alias(ty) => ty.ty.clone(),
97            Self::Enum(ty) => TypeExpr::Primitive(Span::new(ty.span, ty.ty.clone())),
98        }
99    }
100}
101
102impl Spanned for TypeDecl {
103    fn span(&self) -> SourceSpan {
104        match self {
105            Self::Alias(spanned) => spanned.span,
106            Self::Enum(spanned) => spanned.span,
107        }
108    }
109}
110
111impl From<TypeAlias> for TypeDecl {
112    fn from(value: TypeAlias) -> Self {
113        Self::Alias(value)
114    }
115}
116
117impl From<EnumType> for TypeDecl {
118    fn from(value: EnumType) -> Self {
119        Self::Enum(value)
120    }
121}
122
123impl crate::prettier::PrettyPrint for TypeDecl {
124    fn render(&self) -> crate::prettier::Document {
125        match self {
126            Self::Alias(ty) => ty.render(),
127            Self::Enum(ty) => ty.render(),
128        }
129    }
130}
131
132// FUNCTION TYPE
133// ================================================================================================
134
135/// A procedure type signature
136#[derive(Debug, Clone)]
137pub struct FunctionType {
138    pub span: SourceSpan,
139    pub cc: types::CallConv,
140    pub args: Vec<TypeExpr>,
141    pub results: Vec<TypeExpr>,
142}
143
144impl Eq for FunctionType {}
145
146impl PartialEq for FunctionType {
147    fn eq(&self, other: &Self) -> bool {
148        self.cc == other.cc && self.args == other.args && self.results == other.results
149    }
150}
151
152impl core::hash::Hash for FunctionType {
153    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
154        self.cc.hash(state);
155        self.args.hash(state);
156        self.results.hash(state);
157    }
158}
159
160impl Spanned for FunctionType {
161    fn span(&self) -> SourceSpan {
162        self.span
163    }
164}
165
166impl FunctionType {
167    pub fn new(cc: types::CallConv, args: Vec<TypeExpr>, results: Vec<TypeExpr>) -> Self {
168        Self {
169            span: SourceSpan::UNKNOWN,
170            cc,
171            args,
172            results,
173        }
174    }
175
176    /// Override the default source span
177    #[inline]
178    pub fn with_span(mut self, span: SourceSpan) -> Self {
179        self.span = span;
180        self
181    }
182}
183
184impl crate::prettier::PrettyPrint for FunctionType {
185    fn render(&self) -> crate::prettier::Document {
186        use crate::prettier::*;
187
188        let singleline_args = self
189            .args
190            .iter()
191            .map(PrettyPrint::render)
192            .reduce(|acc, arg| acc + const_text(", ") + arg)
193            .unwrap_or(Document::Empty);
194        let multiline_args = indent(
195            4,
196            nl() + self
197                .args
198                .iter()
199                .map(PrettyPrint::render)
200                .reduce(|acc, arg| acc + const_text(",") + nl() + arg)
201                .unwrap_or(Document::Empty),
202        ) + nl();
203        let args = singleline_args | multiline_args;
204        let args = const_text("(") + args + const_text(")");
205
206        match self.results.len() {
207            0 => args,
208            1 => args + const_text(" -> ") + self.results[0].render(),
209            _ => {
210                let results = self
211                    .results
212                    .iter()
213                    .map(PrettyPrint::render)
214                    .reduce(|acc, r| acc + const_text(", ") + r)
215                    .unwrap_or(Document::Empty);
216                args + const_text(" -> ") + const_text("(") + results + const_text(")")
217            },
218        }
219    }
220}
221
222// TYPE EXPRESSION
223// ================================================================================================
224
225/// A syntax-level type expression (i.e. primitive type, reference to nominal type, etc.)
226#[derive(Debug, Clone, Eq, PartialEq, Hash)]
227pub enum TypeExpr {
228    /// A primitive integral type, e.g. `i1`, `u16`
229    Primitive(Span<Type>),
230    /// A pointer type expression, e.g. `*u8`
231    Ptr(PointerType),
232    /// An array type expression, e.g. `[u8; 32]`
233    Array(ArrayType),
234    /// A struct type expression, e.g. `struct { a: u32 }`
235    Struct(StructType),
236    /// A reference to a type aliased by name, e.g. `Foo`
237    Ref(Span<Arc<Path>>),
238}
239
240impl TypeExpr {
241    /// Set the name associated with this type expression, if applicable.
242    ///
243    /// Currently this just sets the name of struct types, but if we add other types with names in
244    /// the future, we can support them here.
245    pub fn set_name(&mut self, name: Ident) {
246        match self {
247            Self::Struct(struct_ty) => {
248                struct_ty.name = Some(name);
249            },
250            Self::Primitive(_) | Self::Ptr(_) | Self::Array(_) | Self::Ref(_) => (),
251        }
252    }
253
254    /// Get any references to other types present in this expression
255    pub fn references(&self) -> Vec<Span<Arc<Path>>> {
256        use alloc::collections::BTreeSet;
257
258        let mut worklist = smallvec::SmallVec::<[_; 4]>::from_slice(&[self]);
259        let mut references = BTreeSet::new();
260
261        while let Some(ty) = worklist.pop() {
262            match ty {
263                Self::Primitive(_) => {},
264                Self::Ptr(ty) => {
265                    worklist.push(&ty.pointee);
266                },
267                Self::Array(ty) => {
268                    worklist.push(&ty.elem);
269                },
270                Self::Struct(ty) => {
271                    for field in ty.fields.iter() {
272                        worklist.push(&field.ty);
273                    }
274                },
275                Self::Ref(ty) => {
276                    references.insert(ty.clone());
277                },
278            }
279        }
280
281        references.into_iter().collect()
282    }
283
284    /// Resolve this type expression to a concrete type, using `resolver`
285    pub fn resolve_type<E, R>(&self, resolver: &mut R) -> Result<Option<Type>, E>
286    where
287        R: ?Sized + TypeResolver<E>,
288    {
289        self.resolve_type_with_depth(resolver, 0)
290    }
291
292    fn resolve_type_with_depth<E, R>(
293        &self,
294        resolver: &mut R,
295        depth: usize,
296    ) -> Result<Option<Type>, E>
297    where
298        R: ?Sized + TypeResolver<E>,
299    {
300        if depth > MAX_TYPE_EXPR_NESTING {
301            let source_manager = resolver.source_manager();
302            return Err(resolver.resolve_local_failed(
303                SymbolResolutionError::type_expression_depth_exceeded(
304                    self.span(),
305                    MAX_TYPE_EXPR_NESTING,
306                    source_manager.as_ref(),
307                ),
308            ));
309        }
310
311        match self {
312            TypeExpr::Ref(path) => {
313                let mut current_path = path.clone();
314                loop {
315                    match resolver.resolve_type_ref(current_path.as_deref())? {
316                        SymbolResolution::Local(item) => {
317                            return resolver.get_local_type(current_path.span(), item.into_inner());
318                        },
319                        SymbolResolution::External(path) => {
320                            // We don't have a definition for this type yet
321                            if path == current_path {
322                                break Ok(None);
323                            }
324                            current_path = path;
325                        },
326                        SymbolResolution::Exact { gid, .. } => {
327                            return resolver.get_type(current_path.span(), gid).map(Some);
328                        },
329                        SymbolResolution::Module { path: module_path, .. } => {
330                            break Err(resolver.resolve_local_failed(
331                                SymbolResolutionError::invalid_symbol_type(
332                                    path.span(),
333                                    "type",
334                                    module_path.span(),
335                                    &resolver.source_manager(),
336                                ),
337                            ));
338                        },
339                        SymbolResolution::MastRoot(item) => {
340                            break Err(resolver.resolve_local_failed(
341                                SymbolResolutionError::invalid_symbol_type(
342                                    path.span(),
343                                    "type",
344                                    item.span(),
345                                    &resolver.source_manager(),
346                                ),
347                            ));
348                        },
349                    }
350                }
351            },
352            TypeExpr::Primitive(t) => Ok(Some(t.inner().clone())),
353            TypeExpr::Array(t) => Ok(t
354                .elem
355                .resolve_type_with_depth(resolver, depth + 1)?
356                .map(|elem| Type::Array(Arc::new(types::ArrayType::new(elem, t.arity))))),
357            TypeExpr::Ptr(ty) => Ok(ty
358                .pointee
359                .resolve_type_with_depth(resolver, depth + 1)?
360                .map(|pointee| Type::Ptr(Arc::new(types::PointerType::new(pointee))))),
361            TypeExpr::Struct(t) => {
362                let mut fields = Vec::with_capacity(t.fields.len());
363                for field in t.fields.iter() {
364                    let field_ty = field.ty.resolve_type_with_depth(resolver, depth + 1)?;
365                    if let Some(field_ty) = field_ty {
366                        fields.push(field_ty);
367                    } else {
368                        return Ok(None);
369                    }
370                }
371                Ok(Some(Type::Struct(Arc::new(types::StructType::from_parts(
372                    t.name.clone().map(Ident::into_inner),
373                    t.repr.into_inner(),
374                    fields,
375                )))))
376            },
377        }
378    }
379}
380
381impl From<Type> for TypeExpr {
382    fn from(ty: Type) -> Self {
383        match ty {
384            Type::Array(t) => Self::Array(ArrayType::new(t.element_type().clone().into(), t.len())),
385            Type::Struct(t) => Self::Struct(StructType::new(
386                None,
387                t.fields().iter().enumerate().map(|(i, ft)| {
388                    let name = Ident::new(format!("field{i}")).unwrap();
389                    StructField {
390                        span: SourceSpan::UNKNOWN,
391                        name,
392                        ty: ft.ty.clone().into(),
393                    }
394                }),
395            )),
396            Type::Ptr(t) => Self::Ptr((*t).clone().into()),
397            Type::Function(_) => {
398                Self::Ptr(PointerType::new(TypeExpr::Primitive(Span::unknown(Type::Felt))))
399            },
400            Type::List(t) => Self::Ptr(
401                PointerType::new((*t).clone().into()).with_address_space(AddressSpace::Byte),
402            ),
403            Type::I128 | Type::U128 => Self::Array(ArrayType::new(Type::U32.into(), 4)),
404            Type::I64 | Type::U64 => Self::Array(ArrayType::new(Type::U32.into(), 2)),
405            Type::Unknown | Type::Never | Type::F64 => panic!("unrepresentable type value: {ty}"),
406            ty => Self::Primitive(Span::unknown(ty)),
407        }
408    }
409}
410
411impl Spanned for TypeExpr {
412    fn span(&self) -> SourceSpan {
413        match self {
414            Self::Primitive(spanned) => spanned.span(),
415            Self::Ptr(spanned) => spanned.span(),
416            Self::Array(spanned) => spanned.span(),
417            Self::Struct(spanned) => spanned.span(),
418            Self::Ref(spanned) => spanned.span(),
419        }
420    }
421}
422
423impl crate::prettier::PrettyPrint for TypeExpr {
424    fn render(&self) -> crate::prettier::Document {
425        use crate::prettier::*;
426
427        match self {
428            Self::Primitive(ty) => display(ty),
429            Self::Ptr(ty) => ty.render(),
430            Self::Array(ty) => ty.render(),
431            Self::Struct(ty) => ty.render(),
432            Self::Ref(ty) => display(ty),
433        }
434    }
435}
436
437// POINTER TYPE
438// ================================================================================================
439
440#[derive(Debug, Clone)]
441pub struct PointerType {
442    pub span: SourceSpan,
443    pub pointee: Box<TypeExpr>,
444    addrspace: Option<AddressSpace>,
445}
446
447impl From<types::PointerType> for PointerType {
448    fn from(ty: types::PointerType) -> Self {
449        let types::PointerType { addrspace, pointee } = ty;
450        let pointee = Box::new(TypeExpr::from(pointee));
451        Self {
452            span: SourceSpan::UNKNOWN,
453            pointee,
454            addrspace: Some(addrspace),
455        }
456    }
457}
458
459impl Eq for PointerType {}
460
461impl PartialEq for PointerType {
462    fn eq(&self, other: &Self) -> bool {
463        self.address_space() == other.address_space() && self.pointee == other.pointee
464    }
465}
466
467impl core::hash::Hash for PointerType {
468    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
469        self.pointee.hash(state);
470        self.address_space().hash(state);
471    }
472}
473
474impl Spanned for PointerType {
475    fn span(&self) -> SourceSpan {
476        self.span
477    }
478}
479
480impl PointerType {
481    pub fn new(pointee: TypeExpr) -> Self {
482        Self {
483            span: SourceSpan::UNKNOWN,
484            pointee: Box::new(pointee),
485            addrspace: None,
486        }
487    }
488
489    /// Override the default source span
490    #[inline]
491    pub fn with_span(mut self, span: SourceSpan) -> Self {
492        self.span = span;
493        self
494    }
495
496    /// Override the default address space
497    #[inline]
498    pub fn with_address_space(mut self, addrspace: AddressSpace) -> Self {
499        self.addrspace = Some(addrspace);
500        self
501    }
502
503    /// Get the address space of this pointer type
504    #[inline]
505    pub fn address_space(&self) -> AddressSpace {
506        self.addrspace.unwrap_or(AddressSpace::Element)
507    }
508}
509
510impl crate::prettier::PrettyPrint for PointerType {
511    fn render(&self) -> crate::prettier::Document {
512        use crate::prettier::*;
513
514        let doc = const_text("ptr<") + self.pointee.render();
515        if let Some(addrspace) = self.addrspace.as_ref() {
516            doc + const_text(", ") + text(format!("addrspace({addrspace})")) + const_text(">")
517        } else {
518            doc + const_text(">")
519        }
520    }
521}
522
523// ARRAY TYPE
524// ================================================================================================
525
526#[derive(Debug, Clone)]
527pub struct ArrayType {
528    pub span: SourceSpan,
529    pub elem: Box<TypeExpr>,
530    pub arity: usize,
531}
532
533impl Eq for ArrayType {}
534
535impl PartialEq for ArrayType {
536    fn eq(&self, other: &Self) -> bool {
537        self.arity == other.arity && self.elem == other.elem
538    }
539}
540
541impl core::hash::Hash for ArrayType {
542    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
543        self.elem.hash(state);
544        self.arity.hash(state);
545    }
546}
547
548impl Spanned for ArrayType {
549    fn span(&self) -> SourceSpan {
550        self.span
551    }
552}
553
554impl ArrayType {
555    pub fn new(elem: TypeExpr, arity: usize) -> Self {
556        Self {
557            span: SourceSpan::UNKNOWN,
558            elem: Box::new(elem),
559            arity,
560        }
561    }
562
563    /// Override the default source span
564    #[inline]
565    pub fn with_span(mut self, span: SourceSpan) -> Self {
566        self.span = span;
567        self
568    }
569}
570
571impl crate::prettier::PrettyPrint for ArrayType {
572    fn render(&self) -> crate::prettier::Document {
573        use crate::prettier::*;
574
575        const_text("[")
576            + self.elem.render()
577            + const_text("; ")
578            + display(self.arity)
579            + const_text("]")
580    }
581}
582
583// STRUCT TYPE
584// ================================================================================================
585
586#[derive(Debug, Clone)]
587pub struct StructType {
588    pub span: SourceSpan,
589    pub name: Option<Ident>,
590    pub repr: Span<TypeRepr>,
591    pub fields: Vec<StructField>,
592}
593
594impl Eq for StructType {}
595
596impl PartialEq for StructType {
597    fn eq(&self, other: &Self) -> bool {
598        self.name == other.name && self.repr == other.repr && self.fields == other.fields
599    }
600}
601
602impl core::hash::Hash for StructType {
603    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
604        self.name.hash(state);
605        self.repr.hash(state);
606        self.fields.hash(state);
607    }
608}
609
610impl Spanned for StructType {
611    fn span(&self) -> SourceSpan {
612        self.span
613    }
614}
615
616impl StructType {
617    pub fn new(name: Option<Ident>, fields: impl IntoIterator<Item = StructField>) -> Self {
618        Self {
619            span: SourceSpan::UNKNOWN,
620            name,
621            repr: Span::unknown(TypeRepr::Default),
622            fields: fields.into_iter().collect(),
623        }
624    }
625
626    /// Override the default struct representation
627    #[inline]
628    pub fn with_repr(mut self, repr: Span<TypeRepr>) -> Self {
629        self.repr = repr;
630        self
631    }
632
633    /// Override the default source span
634    #[inline]
635    pub fn with_span(mut self, span: SourceSpan) -> Self {
636        self.span = span;
637        self
638    }
639}
640
641impl crate::prettier::PrettyPrint for StructType {
642    fn render(&self) -> crate::prettier::Document {
643        use crate::prettier::*;
644
645        let repr = match &*self.repr {
646            TypeRepr::Default => Document::Empty,
647            TypeRepr::BigEndian => const_text("@bigendian "),
648            repr @ (TypeRepr::Align(_) | TypeRepr::Packed(_) | TypeRepr::Transparent) => {
649                text(format!("@{repr} "))
650            },
651        };
652
653        let singleline_body = self
654            .fields
655            .iter()
656            .map(PrettyPrint::render)
657            .reduce(|acc, field| acc + const_text(", ") + field)
658            .unwrap_or(Document::Empty);
659        let multiline_body = indent(
660            4,
661            nl() + self
662                .fields
663                .iter()
664                .map(PrettyPrint::render)
665                .reduce(|acc, field| acc + const_text(",") + nl() + field)
666                .unwrap_or(Document::Empty),
667        ) + nl();
668        let body = singleline_body | multiline_body;
669
670        repr + const_text("struct") + const_text(" { ") + body + const_text(" }")
671    }
672}
673
674// STRUCT FIELD
675// ================================================================================================
676
677#[derive(Debug, Clone)]
678pub struct StructField {
679    pub span: SourceSpan,
680    pub name: Ident,
681    pub ty: TypeExpr,
682}
683
684impl Eq for StructField {}
685
686impl PartialEq for StructField {
687    fn eq(&self, other: &Self) -> bool {
688        self.name == other.name && self.ty == other.ty
689    }
690}
691
692impl core::hash::Hash for StructField {
693    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
694        self.name.hash(state);
695        self.ty.hash(state);
696    }
697}
698
699impl Spanned for StructField {
700    fn span(&self) -> SourceSpan {
701        self.span
702    }
703}
704
705impl crate::prettier::PrettyPrint for StructField {
706    fn render(&self) -> crate::prettier::Document {
707        use crate::prettier::*;
708
709        display(&self.name) + const_text(": ") + self.ty.render()
710    }
711}
712
713// TYPE ALIAS
714// ================================================================================================
715
716/// A [TypeAlias] represents a named [Type].
717///
718/// Type aliases correspond to type declarations in Miden Assembly source files. They are called
719/// aliases, rather than declarations, as the type system for Miden Assembly is structural, rather
720/// than nominal, and so two aliases with the same underlying type are considered equivalent.
721#[derive(Debug, Clone)]
722pub struct TypeAlias {
723    span: SourceSpan,
724    /// The documentation string attached to this definition.
725    docs: Option<DocString>,
726    /// The visibility of this type alias
727    pub visibility: Visibility,
728    /// The name of this type alias
729    pub name: Ident,
730    /// The concrete underlying type
731    pub ty: TypeExpr,
732}
733
734impl TypeAlias {
735    /// Create a new type alias from a name and type
736    pub fn new(visibility: Visibility, name: Ident, ty: TypeExpr) -> Self {
737        Self {
738            span: name.span(),
739            docs: None,
740            visibility,
741            name,
742            ty,
743        }
744    }
745
746    /// Adds documentation to this type alias
747    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
748        self.docs = docs.map(DocString::new);
749        self
750    }
751
752    /// Override the default source span
753    #[inline]
754    pub fn with_span(mut self, span: SourceSpan) -> Self {
755        self.span = span;
756        self
757    }
758
759    /// Set the source span
760    #[inline]
761    pub fn set_span(&mut self, span: SourceSpan) {
762        self.span = span;
763    }
764
765    /// Returns the documentation associated with this item.
766    pub fn docs(&self) -> Option<Span<&str>> {
767        self.docs.as_ref().map(|docstring| docstring.as_spanned_str())
768    }
769
770    /// Get the name of this type alias
771    pub fn name(&self) -> &Ident {
772        &self.name
773    }
774
775    /// Get the visibility of this type alias
776    #[inline]
777    pub const fn visibility(&self) -> Visibility {
778        self.visibility
779    }
780}
781
782impl Eq for TypeAlias {}
783
784impl PartialEq for TypeAlias {
785    fn eq(&self, other: &Self) -> bool {
786        self.visibility == other.visibility
787            && self.name == other.name
788            && self.docs == other.docs
789            && self.ty == other.ty
790    }
791}
792
793impl core::hash::Hash for TypeAlias {
794    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
795        let Self { span: _, docs, visibility, name, ty } = self;
796        docs.hash(state);
797        visibility.hash(state);
798        name.hash(state);
799        ty.hash(state);
800    }
801}
802
803impl Spanned for TypeAlias {
804    fn span(&self) -> SourceSpan {
805        self.span
806    }
807}
808
809impl crate::prettier::PrettyPrint for TypeAlias {
810    fn render(&self) -> crate::prettier::Document {
811        use crate::prettier::*;
812
813        let mut doc = self.docs.as_ref().map(PrettyPrint::render).unwrap_or(Document::Empty);
814
815        if self.visibility.is_public() {
816            doc += display(self.visibility) + const_text(" ");
817        }
818
819        doc + const_text("type")
820            + const_text(" ")
821            + display(&self.name)
822            + const_text(" = ")
823            + self.ty.render()
824    }
825}
826
827// ENUM TYPE
828// ================================================================================================
829
830/// A combined type alias and constant declaration corresponding to a C-like enumeration.
831///
832/// C-style enumerations are effectively a type alias for an integer type with a limited set of
833/// valid values with associated names (referred to as _variants_ of the enum type).
834///
835/// In Miden Assembly, these provide a means for a procedure to declare that it expects an argument
836/// of the underlying integral type, but that values other than those of the declared variants are
837/// illegal/invalid. Currently, these are unchecked, and are only used to convey semantic
838/// information. In the future, we may perform static analysis to try and identify invalid instances
839/// of the enumeration when derived from a constant.
840#[derive(Debug, Clone)]
841pub struct EnumType {
842    span: SourceSpan,
843    /// The documentation string attached to this definition.
844    docs: Option<DocString>,
845    /// The visibility of this enum type
846    visibility: Visibility,
847    /// The enum name
848    name: Ident,
849    /// The type of the discriminant value used for this enum's variants
850    ///
851    /// NOTE: The type must be an integral value, and this is enforced by [`Self::new`].
852    ty: Type,
853    /// The enum variants
854    variants: Vec<Variant>,
855}
856
857impl EnumType {
858    /// Construct a new enum type with the given name and variants
859    ///
860    /// The caller is assumed to have already validated that `ty` is an integral type, and this
861    /// function will assert that this is the case.
862    pub fn new(
863        visibility: Visibility,
864        name: Ident,
865        ty: Type,
866        variants: impl IntoIterator<Item = Variant>,
867    ) -> Self {
868        assert!(ty.is_integer(), "only integer types are allowed in enum type definitions");
869        Self {
870            span: name.span(),
871            docs: None,
872            visibility,
873            name,
874            ty,
875            variants: Vec::from_iter(variants),
876        }
877    }
878
879    /// Adds documentation to this enum declaration.
880    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
881        self.docs = docs.map(DocString::new);
882        self
883    }
884
885    /// Override the default source span
886    pub fn with_span(mut self, span: SourceSpan) -> Self {
887        self.span = span;
888        self
889    }
890
891    /// Returns true if this is a C-style enum where the discriminant is the value
892    pub fn is_c_like(&self) -> bool {
893        !self.variants.is_empty() && self.variants.iter().all(|v| v.value_ty.is_none())
894    }
895
896    /// Set the source span
897    pub fn set_span(&mut self, span: SourceSpan) {
898        self.span = span;
899    }
900
901    /// Get the name of this enum type
902    pub fn name(&self) -> &Ident {
903        &self.name
904    }
905
906    /// Get the visibility of this enum type
907    pub const fn visibility(&self) -> Visibility {
908        self.visibility
909    }
910
911    /// Returns the documentation associated with this item.
912    pub fn docs(&self) -> Option<Span<&str>> {
913        self.docs.as_ref().map(|docstring| docstring.as_spanned_str())
914    }
915
916    /// Get the concrete type of this enum's variants
917    pub fn ty(&self) -> &Type {
918        &self.ty
919    }
920
921    /// Get the variants of this enum type
922    pub fn variants(&self) -> &[Variant] {
923        &self.variants
924    }
925
926    /// Get the variants of this enum type, mutably
927    pub fn variants_mut(&mut self) -> &mut Vec<Variant> {
928        &mut self.variants
929    }
930
931    /// Split this definition into its type alias and variant parts
932    pub fn into_parts(self) -> (TypeAlias, Vec<Variant>) {
933        let Self {
934            span,
935            docs,
936            visibility,
937            name,
938            ty,
939            variants,
940        } = self;
941        let alias = TypeAlias {
942            span,
943            docs,
944            visibility,
945            name,
946            ty: TypeExpr::Primitive(Span::new(span, ty)),
947        };
948        (alias, variants)
949    }
950}
951
952impl Spanned for EnumType {
953    fn span(&self) -> SourceSpan {
954        self.span
955    }
956}
957
958impl Eq for EnumType {}
959
960impl PartialEq for EnumType {
961    fn eq(&self, other: &Self) -> bool {
962        self.visibility == other.visibility
963            && self.name == other.name
964            && self.docs == other.docs
965            && self.ty == other.ty
966            && self.variants == other.variants
967    }
968}
969
970impl core::hash::Hash for EnumType {
971    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
972        let Self {
973            span: _,
974            docs,
975            visibility,
976            name,
977            ty,
978            variants,
979        } = self;
980        docs.hash(state);
981        visibility.hash(state);
982        name.hash(state);
983        ty.hash(state);
984        variants.hash(state);
985    }
986}
987
988impl crate::prettier::PrettyPrint for EnumType {
989    fn render(&self) -> crate::prettier::Document {
990        use crate::prettier::*;
991
992        let mut doc = self.docs.as_ref().map(PrettyPrint::render).unwrap_or(Document::Empty);
993
994        let variants = self
995            .variants
996            .iter()
997            .map(PrettyPrint::render)
998            .reduce(|acc, v| acc + const_text(",") + nl() + v)
999            .unwrap_or(Document::Empty);
1000
1001        if self.visibility.is_public() {
1002            doc += display(self.visibility) + const_text(" ");
1003        }
1004
1005        doc + const_text("enum")
1006            + const_text(" ")
1007            + display(&self.name)
1008            + const_text(" : ")
1009            + self.ty.render()
1010            + const_text(" {")
1011            + nl()
1012            + variants
1013            + const_text("}")
1014    }
1015}
1016
1017// ENUM VARIANT
1018// ================================================================================================
1019
/// A variant of an [EnumType].
///
/// See the [EnumType] docs for more information.
///
/// Equality and hashing for this type ignore `span` (see its `PartialEq` and `Hash` impls).
#[derive(Debug, Clone)]
pub struct Variant {
    /// The source span of this variant; defaults to the span of `name` when built via
    /// `Variant::new`.
    pub span: SourceSpan,
    /// The documentation string attached to the constant derived from this variant.
    pub docs: Option<DocString>,
    /// The name of this enum variant
    pub name: Ident,
    /// The payload value type of this variant
    ///
    /// NOTE: This is not supported in Miden Assembly text format yet, but can be set when lowering
    /// directly to the AST.
    pub value_ty: Option<TypeExpr>,
    /// The discriminant value associated with this variant
    pub discriminant: ConstantExpr,
}
1038
1039impl Variant {
1040    /// Construct a new variant of an [EnumType], with the given name and discriminant value.
1041    pub fn new(name: Ident, discriminant: ConstantExpr, payload: Option<TypeExpr>) -> Self {
1042        Self {
1043            span: name.span(),
1044            docs: None,
1045            name,
1046            value_ty: payload,
1047            discriminant,
1048        }
1049    }
1050
1051    /// Override the span for this variant
1052    pub fn with_span(mut self, span: SourceSpan) -> Self {
1053        self.span = span;
1054        self
1055    }
1056
1057    /// Adds documentation to this variant
1058    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
1059        self.docs = docs.map(DocString::new);
1060        self
1061    }
1062
1063    /// Used to validate that this variant's discriminant value is an instance of `ty`,
1064    /// which must be a type valid for use as the underlying representation for an enum, i.e. an
1065    /// integer type up to 64 bits in size.
1066    ///
1067    /// It is expected that the discriminant expression has been folded to an integer value by the
1068    /// time this is called. If the discriminant has not been fully folded, then an error will be
1069    /// returned.
1070    pub fn assert_instance_of(&self, ty: &Type) -> Result<(), crate::SemanticAnalysisError> {
1071        use crate::{FIELD_MODULUS, SemanticAnalysisError};
1072
1073        let value = match &self.discriminant {
1074            ConstantExpr::Int(value) => value.as_int(),
1075            _ => {
1076                return Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1077                    span: self.discriminant.span(),
1078                    repr: ty.clone(),
1079                });
1080            },
1081        };
1082
1083        match ty {
1084            Type::Felt if value >= FIELD_MODULUS => {
1085                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1086                    span: self.discriminant.span(),
1087                    repr: ty.clone(),
1088                })
1089            },
1090            // IntValue is represented as an unsigned integer, so negative discriminants
1091            // are rejected during constant evaluation.
1092            Type::Felt => Ok(()),
1093            Type::I1 if value > 1 => Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1094                span: self.discriminant.span(),
1095                repr: ty.clone(),
1096            }),
1097            Type::I1 => Ok(()),
1098            Type::I8 | Type::U8 if value > u8::MAX as u64 => {
1099                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1100                    span: self.discriminant.span(),
1101                    repr: ty.clone(),
1102                })
1103            },
1104            Type::I8 | Type::U8 => Ok(()),
1105            Type::I16 | Type::U16 if value > u16::MAX as u64 => {
1106                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1107                    span: self.discriminant.span(),
1108                    repr: ty.clone(),
1109                })
1110            },
1111            Type::I16 | Type::U16 => Ok(()),
1112            Type::I32 | Type::U32 if value > u32::MAX as u64 => {
1113                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1114                    span: self.discriminant.span(),
1115                    repr: ty.clone(),
1116                })
1117            },
1118            Type::I32 | Type::U32 => Ok(()),
1119            Type::I64 | Type::U64 if value >= FIELD_MODULUS => {
1120                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1121                    span: self.discriminant.span(),
1122                    repr: ty.clone(),
1123                })
1124            },
1125            _ => Err(SemanticAnalysisError::InvalidEnumRepr { span: self.span }),
1126        }
1127    }
1128}
1129
1130impl Spanned for Variant {
1131    fn span(&self) -> SourceSpan {
1132        self.span
1133    }
1134}
1135
1136impl Eq for Variant {}
1137
1138impl PartialEq for Variant {
1139    fn eq(&self, other: &Self) -> bool {
1140        self.name == other.name
1141            && self.value_ty == other.value_ty
1142            && self.discriminant == other.discriminant
1143            && self.docs == other.docs
1144    }
1145}
1146
1147impl core::hash::Hash for Variant {
1148    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1149        let Self {
1150            span: _,
1151            docs,
1152            name,
1153            value_ty,
1154            discriminant,
1155        } = self;
1156        docs.hash(state);
1157        name.hash(state);
1158        value_ty.hash(state);
1159        discriminant.hash(state);
1160    }
1161}
1162
1163impl crate::prettier::PrettyPrint for Variant {
1164    fn render(&self) -> crate::prettier::Document {
1165        use crate::prettier::*;
1166
1167        let doc = self.docs.as_ref().map(PrettyPrint::render).unwrap_or(Document::Empty);
1168
1169        let name = display(&self.name);
1170        let name_and_payload = if let Some(value_ty) = self.value_ty.as_ref() {
1171            name + const_text("(") + value_ty.render() + const_text(")")
1172        } else {
1173            name
1174        };
1175        doc + name_and_payload + const_text(" = ") + self.discriminant.render()
1176    }
1177}
1178
#[cfg(test)]
mod tests {
    use alloc::sync::Arc;
    use core::str::FromStr;

    use miden_debug_types::DefaultSourceManager;

    use super::*;

    /// A resolver that knows no symbols: every reference fails to resolve, and local lookups
    /// report nothing. Sufficient for exercising purely structural type expressions.
    struct DummyResolver {
        source_manager: Arc<dyn SourceManager>,
    }

    impl DummyResolver {
        fn new() -> Self {
            let source_manager: Arc<dyn SourceManager> = Arc::new(DefaultSourceManager::default());
            Self { source_manager }
        }
    }

    impl TypeResolver<SymbolResolutionError> for DummyResolver {
        fn source_manager(&self) -> Arc<dyn SourceManager> {
            Arc::clone(&self.source_manager)
        }

        fn resolve_local_failed(&self, err: SymbolResolutionError) -> SymbolResolutionError {
            err
        }

        fn get_type(
            &mut self,
            context: SourceSpan,
            _gid: GlobalItemIndex,
        ) -> Result<Type, SymbolResolutionError> {
            Err(SymbolResolutionError::undefined(context, self.source_manager.as_ref()))
        }

        fn get_local_type(
            &mut self,
            _context: SourceSpan,
            _id: ItemIndex,
        ) -> Result<Option<Type>, SymbolResolutionError> {
            Ok(None)
        }

        fn resolve_type_ref(
            &mut self,
            ty: Span<&Path>,
        ) -> Result<SymbolResolution, SymbolResolutionError> {
            Err(SymbolResolutionError::undefined(ty.span(), self.source_manager.as_ref()))
        }
    }

    /// Builds a type expression nested `depth` levels deep around a `felt` leaf. The wrappers
    /// cycle through pointer, single-element array, and single-field struct.
    fn nested_type_expr(depth: usize) -> TypeExpr {
        let leaf = TypeExpr::Primitive(Span::unknown(Type::Felt));
        (0..depth).fold(leaf, |inner, level| match level % 3 {
            0 => TypeExpr::Ptr(PointerType::new(inner)),
            1 => TypeExpr::Array(ArrayType::new(inner, 1)),
            _ => TypeExpr::Struct(StructType::new(
                None,
                [StructField {
                    span: SourceSpan::UNKNOWN,
                    name: Ident::from_str("field").expect("valid ident"),
                    ty: inner,
                }],
            )),
        })
    }

    #[test]
    fn type_expr_depth_boundary() {
        let mut resolver = DummyResolver::new();

        // Exactly at the nesting limit: resolution must succeed.
        let at_limit = nested_type_expr(MAX_TYPE_EXPR_NESTING);
        assert!(at_limit.resolve_type(&mut resolver).is_ok());

        // One level past the limit: resolution must fail with the depth-exceeded error.
        let past_limit = nested_type_expr(MAX_TYPE_EXPR_NESTING + 1);
        let err = past_limit
            .resolve_type(&mut resolver)
            .expect_err("expected depth-exceeded error");
        assert!(matches!(
            err,
            SymbolResolutionError::TypeExpressionDepthExceeded { max_depth, .. }
                if max_depth == MAX_TYPE_EXPR_NESTING
        ));
    }
}