Skip to main content

miden_assembly_syntax/ast/
type.rs

1use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
2
3use miden_debug_types::{SourceManager, SourceSpan, Span, Spanned};
4pub use midenc_hir_type as types;
5use midenc_hir_type::{AddressSpace, Type, TypeRepr};
6
7use super::{
8    ConstantExpr, DocString, GlobalItemIndex, Ident, ItemIndex, Path, SymbolResolution,
9    SymbolResolutionError, Visibility,
10};
11
/// Upper bound on how deeply type expressions may nest while being resolved.
///
/// Resolution descends one level for every nested array element, pointee, or struct field type.
/// Capping that depth keeps adversarially deep expressions from overflowing the stack, while real
/// programs stay comfortably below this limit.
const MAX_TYPE_EXPR_NESTING: usize = 1 << 8;
17
18/// Abstracts over resolving an item to a concrete [Type], using one of:
19///
20/// * A [GlobalItemIndex]
21/// * An [ItemIndex]
22/// * A [Path]
23/// * A [TypeExpr]
24///
25/// Since type resolution happens in two different contexts during assembly, this abstraction allows
26/// us to share more of the resolution logic in both places.
27///
28/// NOTE: Most methods of this trait take a mutable reference to the resolver, so that the resolver
29/// can mutate its own state as necessary during resolution (e.g. to manage a cache, or other side
30/// table-like data structures).
31pub trait TypeResolver<E> {
32    fn source_manager(&self) -> Arc<dyn SourceManager>;
33    /// Should be called by consumers of this resolver to convert a [SymbolResolutionError] to the
34    /// error type used by the [TypeResolver] implementation.
35    fn resolve_local_failed(&self, err: SymbolResolutionError) -> E;
36    /// Get the [Type] corresponding to the item given by `gid`
37    fn get_type(&mut self, context: SourceSpan, gid: GlobalItemIndex) -> Result<Type, E>;
38    /// Get the [Type] corresponding to the item in the current module given by `id`
39    fn get_local_type(&mut self, context: SourceSpan, id: ItemIndex) -> Result<Option<Type>, E>;
40    /// Attempt to resolve a symbol path, given by a `TypeExpr::Ref`, to an item
41    fn resolve_type_ref(&mut self, ty: Span<&Path>) -> Result<SymbolResolution, E>;
42    /// Resolve a [TypeExpr] to a concrete [Type]
43    fn resolve(&mut self, ty: &TypeExpr) -> Result<Option<Type>, E> {
44        ty.resolve_type(self)
45    }
46}
47
48// TYPE DECLARATION
49// ================================================================================================
50
51/// An abstraction over the different types of type declarations allowed in Miden Assembly
52#[derive(Debug, Clone, PartialEq, Eq)]
53pub enum TypeDecl {
54    /// A named type, i.e. a type alias
55    Alias(TypeAlias),
56    /// A C-like enumeration type with associated constants
57    Enum(EnumType),
58}
59
60impl TypeDecl {
61    /// Adds documentation to this type alias
62    pub fn with_docs(self, docs: Option<Span<String>>) -> Self {
63        match self {
64            Self::Alias(ty) => Self::Alias(ty.with_docs(docs)),
65            Self::Enum(ty) => Self::Enum(ty.with_docs(docs)),
66        }
67    }
68
69    /// Get the name assigned to this type declaration
70    pub fn name(&self) -> &Ident {
71        match self {
72            Self::Alias(ty) => &ty.name,
73            Self::Enum(ty) => &ty.name,
74        }
75    }
76
77    /// Get the visibility of this type declaration
78    pub const fn visibility(&self) -> Visibility {
79        match self {
80            Self::Alias(ty) => ty.visibility,
81            Self::Enum(ty) => ty.visibility,
82        }
83    }
84
85    /// Get the documentation of this enum type
86    pub fn docs(&self) -> Option<Span<&str>> {
87        match self {
88            Self::Alias(ty) => ty.docs(),
89            Self::Enum(ty) => ty.docs(),
90        }
91    }
92
93    /// Get the type expression associated with this declaration
94    pub fn ty(&self) -> TypeExpr {
95        match self {
96            Self::Alias(ty) => ty.ty.clone(),
97            Self::Enum(ty) => TypeExpr::Primitive(Span::new(ty.span, ty.ty.clone())),
98        }
99    }
100}
101
102impl Spanned for TypeDecl {
103    fn span(&self) -> SourceSpan {
104        match self {
105            Self::Alias(spanned) => spanned.span,
106            Self::Enum(spanned) => spanned.span,
107        }
108    }
109}
110
111impl From<TypeAlias> for TypeDecl {
112    fn from(value: TypeAlias) -> Self {
113        Self::Alias(value)
114    }
115}
116
117impl From<EnumType> for TypeDecl {
118    fn from(value: EnumType) -> Self {
119        Self::Enum(value)
120    }
121}
122
123impl crate::prettier::PrettyPrint for TypeDecl {
124    fn render(&self) -> crate::prettier::Document {
125        match self {
126            Self::Alias(ty) => ty.render(),
127            Self::Enum(ty) => ty.render(),
128        }
129    }
130}
131
132// FUNCTION TYPE
133// ================================================================================================
134
135/// A procedure type signature
136#[derive(Debug, Clone)]
137pub struct FunctionType {
138    pub span: SourceSpan,
139    pub cc: types::CallConv,
140    pub args: Vec<TypeExpr>,
141    pub results: Vec<TypeExpr>,
142}
143
144impl Eq for FunctionType {}
145
146impl PartialEq for FunctionType {
147    fn eq(&self, other: &Self) -> bool {
148        self.cc == other.cc && self.args == other.args && self.results == other.results
149    }
150}
151
152impl core::hash::Hash for FunctionType {
153    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
154        self.cc.hash(state);
155        self.args.hash(state);
156        self.results.hash(state);
157    }
158}
159
160impl Spanned for FunctionType {
161    fn span(&self) -> SourceSpan {
162        self.span
163    }
164}
165
166impl FunctionType {
167    pub fn new(cc: types::CallConv, args: Vec<TypeExpr>, results: Vec<TypeExpr>) -> Self {
168        Self {
169            span: SourceSpan::UNKNOWN,
170            cc,
171            args,
172            results,
173        }
174    }
175
176    /// Override the default source span
177    #[inline]
178    pub fn with_span(mut self, span: SourceSpan) -> Self {
179        self.span = span;
180        self
181    }
182}
183
184impl crate::prettier::PrettyPrint for FunctionType {
185    fn render(&self) -> crate::prettier::Document {
186        use crate::prettier::*;
187
188        let singleline_args = self
189            .args
190            .iter()
191            .map(PrettyPrint::render)
192            .reduce(|acc, arg| acc + const_text(", ") + arg)
193            .unwrap_or(Document::Empty);
194        let multiline_args = indent(
195            4,
196            nl() + self
197                .args
198                .iter()
199                .map(PrettyPrint::render)
200                .reduce(|acc, arg| acc + const_text(",") + nl() + arg)
201                .unwrap_or(Document::Empty),
202        ) + nl();
203        let args = singleline_args | multiline_args;
204        let args = const_text("(") + args + const_text(")");
205
206        match self.results.len() {
207            0 => args,
208            1 => args + const_text(" -> ") + self.results[0].render(),
209            _ => {
210                let results = self
211                    .results
212                    .iter()
213                    .map(PrettyPrint::render)
214                    .reduce(|acc, r| acc + const_text(", ") + r)
215                    .unwrap_or(Document::Empty);
216                args + const_text(" -> ") + const_text("(") + results + const_text(")")
217            },
218        }
219    }
220}
221
222// TYPE EXPRESSION
223// ================================================================================================
224
225/// A syntax-level type expression (i.e. primitive type, reference to nominal type, etc.)
226#[derive(Debug, Clone, Eq, PartialEq, Hash)]
227pub enum TypeExpr {
228    /// A primitive integral type, e.g. `i1`, `u16`
229    Primitive(Span<Type>),
230    /// A pointer type expression, e.g. `*u8`
231    Ptr(PointerType),
232    /// An array type expression, e.g. `[u8; 32]`
233    Array(ArrayType),
234    /// A struct type expression, e.g. `struct { a: u32 }`
235    Struct(StructType),
236    /// A reference to a type aliased by name, e.g. `Foo`
237    Ref(Span<Arc<Path>>),
238}
239
240impl TypeExpr {
241    /// Set the name associated with this type expression, if applicable.
242    ///
243    /// Currently this just sets the name of struct types, but if we add other types with names in
244    /// the future, we can support them here.
245    pub fn set_name(&mut self, name: Ident) {
246        match self {
247            Self::Struct(struct_ty) => {
248                struct_ty.name = Some(name);
249            },
250            Self::Primitive(_) | Self::Ptr(_) | Self::Array(_) | Self::Ref(_) => (),
251        }
252    }
253
254    /// Get any references to other types present in this expression
255    pub fn references(&self) -> Vec<Span<Arc<Path>>> {
256        use alloc::collections::BTreeSet;
257
258        let mut worklist = smallvec::SmallVec::<[_; 4]>::from_slice(&[self]);
259        let mut references = BTreeSet::new();
260
261        while let Some(ty) = worklist.pop() {
262            match ty {
263                Self::Primitive(_) => {},
264                Self::Ptr(ty) => {
265                    worklist.push(&ty.pointee);
266                },
267                Self::Array(ty) => {
268                    worklist.push(&ty.elem);
269                },
270                Self::Struct(ty) => {
271                    for field in ty.fields.iter() {
272                        worklist.push(&field.ty);
273                    }
274                },
275                Self::Ref(ty) => {
276                    references.insert(ty.clone());
277                },
278            }
279        }
280
281        references.into_iter().collect()
282    }
283
284    /// Resolve this type expression to a concrete type, using `resolver`
285    pub fn resolve_type<E, R>(&self, resolver: &mut R) -> Result<Option<Type>, E>
286    where
287        R: ?Sized + TypeResolver<E>,
288    {
289        self.resolve_type_with_depth(resolver, 0)
290    }
291
292    fn resolve_type_with_depth<E, R>(
293        &self,
294        resolver: &mut R,
295        depth: usize,
296    ) -> Result<Option<Type>, E>
297    where
298        R: ?Sized + TypeResolver<E>,
299    {
300        if depth > MAX_TYPE_EXPR_NESTING {
301            let source_manager = resolver.source_manager();
302            return Err(resolver.resolve_local_failed(
303                SymbolResolutionError::type_expression_depth_exceeded(
304                    self.span(),
305                    MAX_TYPE_EXPR_NESTING,
306                    source_manager.as_ref(),
307                ),
308            ));
309        }
310
311        match self {
312            TypeExpr::Ref(path) => {
313                let mut current_path = path.clone();
314                loop {
315                    match resolver.resolve_type_ref(current_path.as_deref())? {
316                        SymbolResolution::Local(item) => {
317                            return resolver.get_local_type(current_path.span(), item.into_inner());
318                        },
319                        SymbolResolution::External(path) => {
320                            // We don't have a definition for this type yet
321                            if path == current_path {
322                                break Ok(None);
323                            }
324                            current_path = path;
325                        },
326                        SymbolResolution::Exact { gid, .. } => {
327                            return resolver.get_type(current_path.span(), gid).map(Some);
328                        },
329                        SymbolResolution::Module { path: module_path, .. } => {
330                            break Err(resolver.resolve_local_failed(
331                                SymbolResolutionError::invalid_symbol_type(
332                                    path.span(),
333                                    "type",
334                                    module_path.span(),
335                                    &resolver.source_manager(),
336                                ),
337                            ));
338                        },
339                        SymbolResolution::MastRoot(item) => {
340                            break Err(resolver.resolve_local_failed(
341                                SymbolResolutionError::invalid_symbol_type(
342                                    path.span(),
343                                    "type",
344                                    item.span(),
345                                    &resolver.source_manager(),
346                                ),
347                            ));
348                        },
349                    }
350                }
351            },
352            TypeExpr::Primitive(t) => Ok(Some(t.inner().clone())),
353            TypeExpr::Array(t) => Ok(t
354                .elem
355                .resolve_type_with_depth(resolver, depth + 1)?
356                .map(|elem| Type::Array(Arc::new(types::ArrayType::new(elem, t.arity))))),
357            TypeExpr::Ptr(ty) => Ok(ty
358                .pointee
359                .resolve_type_with_depth(resolver, depth + 1)?
360                .map(|pointee| Type::Ptr(Arc::new(types::PointerType::new(pointee))))),
361            TypeExpr::Struct(t) => {
362                let mut fields = Vec::with_capacity(t.fields.len());
363                for field in t.fields.iter() {
364                    let field_ty = field.ty.resolve_type_with_depth(resolver, depth + 1)?;
365                    if let Some(field_ty) = field_ty {
366                        fields.push((field.name.clone().into_inner(), field_ty));
367                    } else {
368                        return Ok(None);
369                    }
370                }
371                Ok(Some(Type::Struct(Arc::new(types::StructType::from_parts(
372                    t.name.clone().map(Ident::into_inner),
373                    t.repr.into_inner(),
374                    fields,
375                )))))
376            },
377        }
378    }
379}
380
381impl From<Type> for TypeExpr {
382    fn from(ty: Type) -> Self {
383        match ty {
384            Type::Array(t) => Self::Array(ArrayType::new(t.element_type().clone().into(), t.len())),
385            Type::Struct(t) => {
386                let name = t.name().and_then(|name| Ident::new(name.as_ref()).ok());
387                let fields = t.fields().iter().enumerate().map(|(i, ft)| {
388                    let name = ft
389                        .name
390                        .as_deref()
391                        .map(Ident::new)
392                        .and_then(Result::ok)
393                        .unwrap_or_else(|| Ident::new(format!("field{i}")).unwrap());
394                    StructField {
395                        span: SourceSpan::UNKNOWN,
396                        name,
397                        ty: ft.ty.clone().into(),
398                    }
399                });
400                Self::Struct(
401                    StructType::new(name, fields)
402                        .with_repr(Span::unknown(t.repr()))
403                        .with_span(SourceSpan::UNKNOWN),
404                )
405            },
406            Type::Ptr(t) => Self::Ptr((*t).clone().into()),
407            Type::Function(_) => {
408                Self::Ptr(PointerType::new(TypeExpr::Primitive(Span::unknown(Type::Felt))))
409            },
410            Type::List(t) => Self::Ptr(
411                PointerType::new((*t).clone().into()).with_address_space(AddressSpace::Byte),
412            ),
413            Type::Unknown | Type::Never | Type::F64 => panic!("unrepresentable type value: {ty}"),
414            ty => Self::Primitive(Span::unknown(ty)),
415        }
416    }
417}
418
419impl Spanned for TypeExpr {
420    fn span(&self) -> SourceSpan {
421        match self {
422            Self::Primitive(spanned) => spanned.span(),
423            Self::Ptr(spanned) => spanned.span(),
424            Self::Array(spanned) => spanned.span(),
425            Self::Struct(spanned) => spanned.span(),
426            Self::Ref(spanned) => spanned.span(),
427        }
428    }
429}
430
431impl crate::prettier::PrettyPrint for TypeExpr {
432    fn render(&self) -> crate::prettier::Document {
433        use crate::prettier::*;
434
435        match self {
436            Self::Primitive(ty) => display(ty),
437            Self::Ptr(ty) => ty.render(),
438            Self::Array(ty) => ty.render(),
439            Self::Struct(ty) => ty.render(),
440            Self::Ref(ty) => display(ty),
441        }
442    }
443}
444
445// POINTER TYPE
446// ================================================================================================
447
448#[derive(Debug, Clone)]
449pub struct PointerType {
450    pub span: SourceSpan,
451    pub pointee: Box<TypeExpr>,
452    addrspace: Option<AddressSpace>,
453}
454
455impl From<types::PointerType> for PointerType {
456    fn from(ty: types::PointerType) -> Self {
457        let types::PointerType { addrspace, pointee } = ty;
458        let pointee = Box::new(TypeExpr::from(pointee));
459        Self {
460            span: SourceSpan::UNKNOWN,
461            pointee,
462            addrspace: Some(addrspace),
463        }
464    }
465}
466
467impl Eq for PointerType {}
468
469impl PartialEq for PointerType {
470    fn eq(&self, other: &Self) -> bool {
471        self.address_space() == other.address_space() && self.pointee == other.pointee
472    }
473}
474
475impl core::hash::Hash for PointerType {
476    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
477        self.pointee.hash(state);
478        self.address_space().hash(state);
479    }
480}
481
482impl Spanned for PointerType {
483    fn span(&self) -> SourceSpan {
484        self.span
485    }
486}
487
488impl PointerType {
489    pub fn new(pointee: TypeExpr) -> Self {
490        Self {
491            span: SourceSpan::UNKNOWN,
492            pointee: Box::new(pointee),
493            addrspace: None,
494        }
495    }
496
497    /// Override the default source span
498    #[inline]
499    pub fn with_span(mut self, span: SourceSpan) -> Self {
500        self.span = span;
501        self
502    }
503
504    /// Override the default address space
505    #[inline]
506    pub fn with_address_space(mut self, addrspace: AddressSpace) -> Self {
507        self.addrspace = Some(addrspace);
508        self
509    }
510
511    /// Get the address space of this pointer type
512    #[inline]
513    pub fn address_space(&self) -> AddressSpace {
514        self.addrspace.unwrap_or(AddressSpace::Element)
515    }
516}
517
518impl crate::prettier::PrettyPrint for PointerType {
519    fn render(&self) -> crate::prettier::Document {
520        use crate::prettier::*;
521
522        let doc = const_text("ptr<") + self.pointee.render();
523        if let Some(addrspace) = self.addrspace.as_ref() {
524            doc + const_text(", ") + text(format!("addrspace({addrspace})")) + const_text(">")
525        } else {
526            doc + const_text(">")
527        }
528    }
529}
530
531// ARRAY TYPE
532// ================================================================================================
533
534#[derive(Debug, Clone)]
535pub struct ArrayType {
536    pub span: SourceSpan,
537    pub elem: Box<TypeExpr>,
538    pub arity: usize,
539}
540
541impl Eq for ArrayType {}
542
543impl PartialEq for ArrayType {
544    fn eq(&self, other: &Self) -> bool {
545        self.arity == other.arity && self.elem == other.elem
546    }
547}
548
549impl core::hash::Hash for ArrayType {
550    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
551        self.elem.hash(state);
552        self.arity.hash(state);
553    }
554}
555
556impl Spanned for ArrayType {
557    fn span(&self) -> SourceSpan {
558        self.span
559    }
560}
561
562impl ArrayType {
563    pub fn new(elem: TypeExpr, arity: usize) -> Self {
564        Self {
565            span: SourceSpan::UNKNOWN,
566            elem: Box::new(elem),
567            arity,
568        }
569    }
570
571    /// Override the default source span
572    #[inline]
573    pub fn with_span(mut self, span: SourceSpan) -> Self {
574        self.span = span;
575        self
576    }
577}
578
579impl crate::prettier::PrettyPrint for ArrayType {
580    fn render(&self) -> crate::prettier::Document {
581        use crate::prettier::*;
582
583        const_text("[")
584            + self.elem.render()
585            + const_text("; ")
586            + display(self.arity)
587            + const_text("]")
588    }
589}
590
591// STRUCT TYPE
592// ================================================================================================
593
594#[derive(Debug, Clone)]
595pub struct StructType {
596    pub span: SourceSpan,
597    pub name: Option<Ident>,
598    pub repr: Span<TypeRepr>,
599    pub fields: Vec<StructField>,
600}
601
602impl Eq for StructType {}
603
604impl PartialEq for StructType {
605    fn eq(&self, other: &Self) -> bool {
606        self.name == other.name && self.repr == other.repr && self.fields == other.fields
607    }
608}
609
610impl core::hash::Hash for StructType {
611    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
612        self.name.hash(state);
613        self.repr.hash(state);
614        self.fields.hash(state);
615    }
616}
617
618impl Spanned for StructType {
619    fn span(&self) -> SourceSpan {
620        self.span
621    }
622}
623
624impl StructType {
625    pub fn new(name: Option<Ident>, fields: impl IntoIterator<Item = StructField>) -> Self {
626        Self {
627            span: SourceSpan::UNKNOWN,
628            name,
629            repr: Span::unknown(TypeRepr::Default),
630            fields: fields.into_iter().collect(),
631        }
632    }
633
634    /// Override the default struct representation
635    #[inline]
636    pub fn with_repr(mut self, repr: Span<TypeRepr>) -> Self {
637        self.repr = repr;
638        self
639    }
640
641    /// Override the default source span
642    #[inline]
643    pub fn with_span(mut self, span: SourceSpan) -> Self {
644        self.span = span;
645        self
646    }
647}
648
649impl crate::prettier::PrettyPrint for StructType {
650    fn render(&self) -> crate::prettier::Document {
651        use crate::prettier::*;
652
653        let repr = match &*self.repr {
654            TypeRepr::Default => Document::Empty,
655            TypeRepr::BigEndian => const_text(" @bigendian"),
656            repr @ (TypeRepr::Align(_) | TypeRepr::Packed(_) | TypeRepr::Transparent) => {
657                text(format!(" @{repr}"))
658            },
659        };
660
661        let singleline_body = self
662            .fields
663            .iter()
664            .map(PrettyPrint::render)
665            .reduce(|acc, field| acc + const_text(", ") + field)
666            .unwrap_or(Document::Empty);
667        let multiline_body = indent(
668            4,
669            nl() + self
670                .fields
671                .iter()
672                .map(PrettyPrint::render)
673                .reduce(|acc, field| acc + const_text(",") + nl() + field)
674                .unwrap_or(Document::Empty),
675        ) + nl();
676        let body = singleline_body | multiline_body;
677
678        const_text("struct") + repr + const_text(" { ") + body + const_text(" }")
679    }
680}
681
682// STRUCT FIELD
683// ================================================================================================
684
685#[derive(Debug, Clone)]
686pub struct StructField {
687    pub span: SourceSpan,
688    pub name: Ident,
689    pub ty: TypeExpr,
690}
691
692impl Eq for StructField {}
693
694impl PartialEq for StructField {
695    fn eq(&self, other: &Self) -> bool {
696        self.name == other.name && self.ty == other.ty
697    }
698}
699
700impl core::hash::Hash for StructField {
701    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
702        self.name.hash(state);
703        self.ty.hash(state);
704    }
705}
706
707impl Spanned for StructField {
708    fn span(&self) -> SourceSpan {
709        self.span
710    }
711}
712
713impl crate::prettier::PrettyPrint for StructField {
714    fn render(&self) -> crate::prettier::Document {
715        use crate::prettier::*;
716
717        display(&self.name) + const_text(": ") + self.ty.render()
718    }
719}
720
721// TYPE ALIAS
722// ================================================================================================
723
724/// A [TypeAlias] represents a named [Type].
725///
726/// Type aliases correspond to type declarations in Miden Assembly source files. They are called
727/// aliases, rather than declarations, as the type system for Miden Assembly is structural, rather
728/// than nominal, and so two aliases with the same underlying type are considered equivalent.
729#[derive(Debug, Clone)]
730pub struct TypeAlias {
731    span: SourceSpan,
732    /// The documentation string attached to this definition.
733    docs: Option<DocString>,
734    /// The visibility of this type alias
735    pub visibility: Visibility,
736    /// The name of this type alias
737    pub name: Ident,
738    /// The concrete underlying type
739    pub ty: TypeExpr,
740}
741
742impl TypeAlias {
743    /// Create a new type alias from a name and type
744    pub fn new(visibility: Visibility, name: Ident, ty: TypeExpr) -> Self {
745        Self {
746            span: name.span(),
747            docs: None,
748            visibility,
749            name,
750            ty,
751        }
752    }
753
754    /// Adds documentation to this type alias
755    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
756        self.docs = docs.map(DocString::new);
757        self
758    }
759
760    /// Override the default source span
761    #[inline]
762    pub fn with_span(mut self, span: SourceSpan) -> Self {
763        self.span = span;
764        self
765    }
766
767    /// Set the source span
768    #[inline]
769    pub fn set_span(&mut self, span: SourceSpan) {
770        self.span = span;
771    }
772
773    /// Returns the documentation associated with this item.
774    pub fn docs(&self) -> Option<Span<&str>> {
775        self.docs.as_ref().map(|docstring| docstring.as_spanned_str())
776    }
777
778    /// Get the name of this type alias
779    pub fn name(&self) -> &Ident {
780        &self.name
781    }
782
783    /// Get the visibility of this type alias
784    #[inline]
785    pub const fn visibility(&self) -> Visibility {
786        self.visibility
787    }
788}
789
790impl Eq for TypeAlias {}
791
792impl PartialEq for TypeAlias {
793    fn eq(&self, other: &Self) -> bool {
794        self.visibility == other.visibility
795            && self.name == other.name
796            && self.docs == other.docs
797            && self.ty == other.ty
798    }
799}
800
801impl core::hash::Hash for TypeAlias {
802    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
803        let Self { span: _, docs, visibility, name, ty } = self;
804        docs.hash(state);
805        visibility.hash(state);
806        name.hash(state);
807        ty.hash(state);
808    }
809}
810
811impl Spanned for TypeAlias {
812    fn span(&self) -> SourceSpan {
813        self.span
814    }
815}
816
817impl crate::prettier::PrettyPrint for TypeAlias {
818    fn render(&self) -> crate::prettier::Document {
819        use crate::prettier::*;
820
821        let mut doc = self.docs.as_ref().map(PrettyPrint::render).unwrap_or(Document::Empty);
822
823        if self.visibility.is_public() {
824            doc += display(self.visibility) + const_text(" ");
825        }
826
827        doc + const_text("type")
828            + const_text(" ")
829            + display(&self.name)
830            + const_text(" = ")
831            + self.ty.render()
832    }
833}
834
835// ENUM TYPE
836// ================================================================================================
837
838/// A combined type alias and constant declaration corresponding to a C-like enumeration.
839///
840/// C-style enumerations are effectively a type alias for an integer type with a limited set of
841/// valid values with associated names (referred to as _variants_ of the enum type).
842///
843/// In Miden Assembly, these provide a means for a procedure to declare that it expects an argument
844/// of the underlying integral type, but that values other than those of the declared variants are
845/// illegal/invalid. Currently, these are unchecked, and are only used to convey semantic
846/// information. In the future, we may perform static analysis to try and identify invalid instances
847/// of the enumeration when derived from a constant.
848#[derive(Debug, Clone)]
849pub struct EnumType {
850    span: SourceSpan,
851    /// The documentation string attached to this definition.
852    docs: Option<DocString>,
853    /// The visibility of this enum type
854    visibility: Visibility,
855    /// The enum name
856    name: Ident,
857    /// The type of the discriminant value used for this enum's variants
858    ///
859    /// NOTE: The type must be an integral value, and this is enforced by [`Self::new`].
860    ty: Type,
861    /// The enum variants
862    variants: Vec<Variant>,
863}
864
865impl EnumType {
866    /// Construct a new enum type with the given name and variants
867    ///
868    /// The caller is assumed to have already validated that `ty` is an integral type, and this
869    /// function will assert that this is the case.
870    pub fn new(
871        visibility: Visibility,
872        name: Ident,
873        ty: Type,
874        variants: impl IntoIterator<Item = Variant>,
875    ) -> Self {
876        assert!(ty.is_integer(), "only integer types are allowed in enum type definitions");
877        Self {
878            span: name.span(),
879            docs: None,
880            visibility,
881            name,
882            ty,
883            variants: Vec::from_iter(variants),
884        }
885    }
886
887    /// Adds documentation to this enum declaration.
888    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
889        self.docs = docs.map(DocString::new);
890        self
891    }
892
893    /// Override the default source span
894    pub fn with_span(mut self, span: SourceSpan) -> Self {
895        self.span = span;
896        self
897    }
898
899    /// Returns true if this is a C-style enum where the discriminant is the value
900    pub fn is_c_like(&self) -> bool {
901        !self.variants.is_empty() && self.variants.iter().all(|v| v.value_ty.is_none())
902    }
903
904    /// Set the source span
905    pub fn set_span(&mut self, span: SourceSpan) {
906        self.span = span;
907    }
908
909    /// Get the name of this enum type
910    pub fn name(&self) -> &Ident {
911        &self.name
912    }
913
914    /// Get the visibility of this enum type
915    pub const fn visibility(&self) -> Visibility {
916        self.visibility
917    }
918
919    /// Returns the documentation associated with this item.
920    pub fn docs(&self) -> Option<Span<&str>> {
921        self.docs.as_ref().map(|docstring| docstring.as_spanned_str())
922    }
923
924    /// Get the concrete type of this enum's variants
925    pub fn ty(&self) -> &Type {
926        &self.ty
927    }
928
929    /// Get the variants of this enum type
930    pub fn variants(&self) -> &[Variant] {
931        &self.variants
932    }
933
934    /// Get the variants of this enum type, mutably
935    pub fn variants_mut(&mut self) -> &mut Vec<Variant> {
936        &mut self.variants
937    }
938
939    /// Split this definition into its type alias and variant parts
940    pub fn into_parts(self) -> (TypeAlias, Vec<Variant>) {
941        let Self {
942            span,
943            docs,
944            visibility,
945            name,
946            ty,
947            variants,
948        } = self;
949        let alias = TypeAlias {
950            span,
951            docs,
952            visibility,
953            name,
954            ty: TypeExpr::Primitive(Span::new(span, ty)),
955        };
956        (alias, variants)
957    }
958}
959
960impl Spanned for EnumType {
961    fn span(&self) -> SourceSpan {
962        self.span
963    }
964}
965
966impl Eq for EnumType {}
967
968impl PartialEq for EnumType {
969    fn eq(&self, other: &Self) -> bool {
970        self.visibility == other.visibility
971            && self.name == other.name
972            && self.docs == other.docs
973            && self.ty == other.ty
974            && self.variants == other.variants
975    }
976}
977
978impl core::hash::Hash for EnumType {
979    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
980        let Self {
981            span: _,
982            docs,
983            visibility,
984            name,
985            ty,
986            variants,
987        } = self;
988        docs.hash(state);
989        visibility.hash(state);
990        name.hash(state);
991        ty.hash(state);
992        variants.hash(state);
993    }
994}
995
996impl crate::prettier::PrettyPrint for EnumType {
997    fn render(&self) -> crate::prettier::Document {
998        use crate::prettier::*;
999
1000        let mut doc = self.docs.as_ref().map(PrettyPrint::render).unwrap_or(Document::Empty);
1001
1002        let variants = self
1003            .variants
1004            .iter()
1005            .map(PrettyPrint::render)
1006            .reduce(|acc, v| acc + const_text(",") + nl() + v)
1007            .unwrap_or(Document::Empty);
1008
1009        if self.visibility.is_public() {
1010            doc += display(self.visibility) + const_text(" ");
1011        }
1012
1013        doc + const_text("enum")
1014            + const_text(" ")
1015            + display(&self.name)
1016            + const_text(" : ")
1017            + self.ty.render()
1018            + const_text(" {")
1019            + nl()
1020            + variants
1021            + const_text("}")
1022    }
1023}
1024
1025// ENUM VARIANT
1026// ================================================================================================
1027
/// A variant of an [EnumType].
///
/// Each variant has a name and a discriminant expression, and may optionally carry a payload
/// type. The source span is ignored by this type's `PartialEq` and `Hash` impls.
///
/// See the [EnumType] docs for more information.
#[derive(Debug, Clone)]
pub struct Variant {
    /// The source span of this variant. [Variant::new] initializes it to the span of `name`.
    pub span: SourceSpan,
    /// The documentation string attached to the constant derived from this variant.
    pub docs: Option<DocString>,
    /// The name of this enum variant
    pub name: Ident,
    /// The payload value type of this variant, if any
    ///
    /// NOTE: This is not supported in Miden Assembly text format yet, but can be set when lowering
    /// directly to the AST.
    pub value_ty: Option<TypeExpr>,
    /// The discriminant value associated with this variant
    ///
    /// This is expected to be folded to an integer constant before
    /// [Variant::assert_instance_of] is used to validate it.
    pub discriminant: ConstantExpr,
}
1046
1047impl Variant {
1048    /// Construct a new variant of an [EnumType], with the given name and discriminant value.
1049    pub fn new(name: Ident, discriminant: ConstantExpr, payload: Option<TypeExpr>) -> Self {
1050        Self {
1051            span: name.span(),
1052            docs: None,
1053            name,
1054            value_ty: payload,
1055            discriminant,
1056        }
1057    }
1058
1059    /// Override the span for this variant
1060    pub fn with_span(mut self, span: SourceSpan) -> Self {
1061        self.span = span;
1062        self
1063    }
1064
1065    /// Adds documentation to this variant
1066    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
1067        self.docs = docs.map(DocString::new);
1068        self
1069    }
1070
1071    /// Used to validate that this variant's discriminant value is an instance of `ty`,
1072    /// which must be a type valid for use as the underlying representation for an enum, i.e. an
1073    /// integer type up to 64 bits in size.
1074    ///
1075    /// It is expected that the discriminant expression has been folded to an integer value by the
1076    /// time this is called. If the discriminant has not been fully folded, then an error will be
1077    /// returned.
1078    pub fn assert_instance_of(&self, ty: &Type) -> Result<(), crate::SemanticAnalysisError> {
1079        use crate::{FIELD_MODULUS, SemanticAnalysisError};
1080
1081        let value = match &self.discriminant {
1082            ConstantExpr::Int(value) => value.as_int(),
1083            _ => {
1084                return Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1085                    span: self.discriminant.span(),
1086                    repr: ty.clone(),
1087                });
1088            },
1089        };
1090
1091        match ty {
1092            Type::Felt if value >= FIELD_MODULUS => {
1093                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1094                    span: self.discriminant.span(),
1095                    repr: ty.clone(),
1096                })
1097            },
1098            // IntValue is represented as an unsigned integer, so negative discriminants
1099            // are rejected during constant evaluation.
1100            Type::Felt => Ok(()),
1101            Type::I1 if value > 1 => Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1102                span: self.discriminant.span(),
1103                repr: ty.clone(),
1104            }),
1105            Type::I1 => Ok(()),
1106            Type::I8 | Type::U8 if value > u8::MAX as u64 => {
1107                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1108                    span: self.discriminant.span(),
1109                    repr: ty.clone(),
1110                })
1111            },
1112            Type::I8 | Type::U8 => Ok(()),
1113            Type::I16 | Type::U16 if value > u16::MAX as u64 => {
1114                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1115                    span: self.discriminant.span(),
1116                    repr: ty.clone(),
1117                })
1118            },
1119            Type::I16 | Type::U16 => Ok(()),
1120            Type::I32 | Type::U32 if value > u32::MAX as u64 => {
1121                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1122                    span: self.discriminant.span(),
1123                    repr: ty.clone(),
1124                })
1125            },
1126            Type::I32 | Type::U32 => Ok(()),
1127            Type::I64 | Type::U64 if value >= FIELD_MODULUS => {
1128                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1129                    span: self.discriminant.span(),
1130                    repr: ty.clone(),
1131                })
1132            },
1133            _ => Err(SemanticAnalysisError::InvalidEnumRepr { span: self.span }),
1134        }
1135    }
1136}
1137
1138impl Spanned for Variant {
1139    fn span(&self) -> SourceSpan {
1140        self.span
1141    }
1142}
1143
1144impl Eq for Variant {}
1145
1146impl PartialEq for Variant {
1147    fn eq(&self, other: &Self) -> bool {
1148        self.name == other.name
1149            && self.value_ty == other.value_ty
1150            && self.discriminant == other.discriminant
1151            && self.docs == other.docs
1152    }
1153}
1154
1155impl core::hash::Hash for Variant {
1156    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1157        let Self {
1158            span: _,
1159            docs,
1160            name,
1161            value_ty,
1162            discriminant,
1163        } = self;
1164        docs.hash(state);
1165        name.hash(state);
1166        value_ty.hash(state);
1167        discriminant.hash(state);
1168    }
1169}
1170
1171impl crate::prettier::PrettyPrint for Variant {
1172    fn render(&self) -> crate::prettier::Document {
1173        use crate::prettier::*;
1174
1175        let doc = self.docs.as_ref().map(PrettyPrint::render).unwrap_or(Document::Empty);
1176
1177        let name = display(&self.name);
1178        let name_and_payload = if let Some(value_ty) = self.value_ty.as_ref() {
1179            name + const_text("(") + value_ty.render() + const_text(")")
1180        } else {
1181            name
1182        };
1183        doc + name_and_payload + const_text(" = ") + self.discriminant.render()
1184    }
1185}
1186
#[cfg(test)]
mod tests {
    use alloc::{string::ToString, sync::Arc};
    use core::str::FromStr;

    use miden_debug_types::{DefaultSourceManager, SourceFile, SourceId, SourceLanguage, Uri};

    use super::*;
    use crate::{ast::Form, prettier::PrettyPrint};

    /// A minimal [TypeResolver] for these tests that resolves no symbols.
    ///
    /// Global item lookups and type-reference resolution always fail with an "undefined" error,
    /// and local item lookups always report that no type is known. It is therefore only suitable
    /// for type expressions that do not refer to named types.
    struct DummyResolver {
        source_manager: Arc<dyn SourceManager>,
    }

    impl DummyResolver {
        fn new() -> Self {
            Self {
                source_manager: Arc::new(DefaultSourceManager::default()),
            }
        }
    }

    impl TypeResolver<SymbolResolutionError> for DummyResolver {
        fn source_manager(&self) -> Arc<dyn SourceManager> {
            self.source_manager.clone()
        }

        // Errors are already of the resolver's error type, so pass them through unchanged.
        fn resolve_local_failed(&self, err: SymbolResolutionError) -> SymbolResolutionError {
            err
        }

        fn get_type(
            &mut self,
            context: SourceSpan,
            _gid: GlobalItemIndex,
        ) -> Result<Type, SymbolResolutionError> {
            Err(SymbolResolutionError::undefined(context, self.source_manager.as_ref()))
        }

        fn get_local_type(
            &mut self,
            _context: SourceSpan,
            _id: ItemIndex,
        ) -> Result<Option<Type>, SymbolResolutionError> {
            Ok(None)
        }

        fn resolve_type_ref(
            &mut self,
            ty: Span<&Path>,
        ) -> Result<SymbolResolution, SymbolResolutionError> {
            Err(SymbolResolutionError::undefined(ty.span(), self.source_manager.as_ref()))
        }
    }

    /// Builds a type expression nested `depth` levels deep around a `felt` primitive.
    ///
    /// It cycles through pointer, single-element array, and single-field struct wrappers, so each
    /// kind of nesting is exercised by the depth-limit checks.
    fn nested_type_expr(depth: usize) -> TypeExpr {
        let mut expr = TypeExpr::Primitive(Span::unknown(Type::Felt));
        for i in 0..depth {
            expr = match i % 3 {
                0 => TypeExpr::Ptr(PointerType::new(expr)),
                1 => TypeExpr::Array(ArrayType::new(expr, 1)),
                _ => {
                    let field = StructField {
                        span: SourceSpan::UNKNOWN,
                        name: Ident::from_str("field").expect("valid ident"),
                        ty: expr,
                    };
                    TypeExpr::Struct(StructType::new(None, [field]))
                },
            };
        }
        expr
    }

    /// Wraps `source` in an in-memory MASM source file with a fixed test URI.
    fn test_source_file(source: &str) -> Arc<SourceFile> {
        Arc::new(SourceFile::new(
            SourceId::default(),
            SourceLanguage::Masm,
            Uri::new("memory:///type-expr-test.masm"),
            source.to_string().into_boxed_str(),
        ))
    }

    /// Parses `source`, which must contain exactly one `type` alias declaration, and returns the
    /// aliased type expression. Panics if parsing fails or any other form is produced.
    fn parse_type_alias_expr(source: &str) -> TypeExpr {
        let mut forms =
            crate::parser::parse_forms(test_source_file(source)).expect("type alias should parse");
        assert_eq!(forms.len(), 1, "expected exactly one parsed form");
        match forms.pop().expect("expected parsed form") {
            Form::Type(alias) => alias.ty,
            form => panic!("expected type alias form, got {form:?}"),
        }
    }

    /// Builds an anonymous two-field struct type expression (`prefix: felt`, `suffix: u32`) with
    /// the given representation.
    fn repr_round_trip_struct(repr: TypeRepr) -> TypeExpr {
        TypeExpr::Struct(
            StructType::new(
                None,
                [
                    StructField {
                        span: SourceSpan::UNKNOWN,
                        name: Ident::from_str("prefix").expect("valid ident"),
                        ty: TypeExpr::Primitive(Span::unknown(Type::Felt)),
                    },
                    StructField {
                        span: SourceSpan::UNKNOWN,
                        name: Ident::from_str("suffix").expect("valid ident"),
                        ty: TypeExpr::Primitive(Span::unknown(Type::U32)),
                    },
                ],
            )
            .with_repr(Span::unknown(repr)),
        )
    }

    // Nesting exactly at the limit must resolve. One level beyond it must fail with the
    // depth-exceeded error, which reports the configured maximum.
    #[test]
    fn type_expr_depth_boundary() {
        let mut resolver = DummyResolver::new();

        let ok_expr = nested_type_expr(MAX_TYPE_EXPR_NESTING);
        assert!(ok_expr.resolve_type(&mut resolver).is_ok());

        let err_expr = nested_type_expr(MAX_TYPE_EXPR_NESTING + 1);
        let err = err_expr.resolve_type(&mut resolver).expect_err("expected depth-exceeded error");
        assert!(
            matches!(err, SymbolResolutionError::TypeExpressionDepthExceeded { max_depth, .. }
                if max_depth == MAX_TYPE_EXPR_NESTING)
        );
    }

    // Rendering a struct with a non-default repr and parsing it back must preserve both the repr
    // and the field names.
    #[test]
    fn struct_type_expr_render_round_trips_non_default_reprs() {
        for repr in [
            TypeRepr::BigEndian,
            TypeRepr::align(16),
            TypeRepr::packed(1),
            TypeRepr::packed(2),
            TypeRepr::Transparent,
        ] {
            let rendered = repr_round_trip_struct(repr).to_pretty_string();
            assert!(
                rendered.starts_with("struct @"),
                "non-default struct repr should render after `struct`: {rendered}"
            );

            let parsed = parse_type_alias_expr(&format!("type RoundTrip = {rendered}\n"));
            let TypeExpr::Struct(parsed) = parsed else {
                panic!("expected rendered type to parse back as a struct");
            };
            assert_eq!(*parsed.repr, repr);
            assert_eq!(parsed.fields[0].name.as_str(), "prefix");
            assert_eq!(parsed.fields[1].name.as_str(), "suffix");
        }
    }

    // Converting 64- and 128-bit integer types to a type expression must keep the exact primitive.
    #[test]
    fn type_expr_from_type_preserves_wide_integer_primitives() {
        for ty in [Type::I64, Type::U64, Type::I128, Type::U128] {
            let expr = TypeExpr::from(ty.clone());
            let TypeExpr::Primitive(actual) = expr else {
                panic!("expected primitive type expression for {ty}, got {expr:?}");
            };
            assert_eq!(actual.into_inner(), ty);
        }
    }

    // Converting a concrete struct type to a type expression must keep its name, repr, and field
    // names.
    #[test]
    fn type_expr_from_type_preserves_struct_metadata() {
        let ty = Type::Struct(Arc::new(types::StructType::from_parts(
            Some(Arc::from("miden:base/core-types@1.0.0/account-id")),
            TypeRepr::BigEndian,
            [
                (Arc::<str>::from("prefix"), Type::Felt),
                (Arc::<str>::from("suffix"), Type::Felt),
            ],
        )));

        let TypeExpr::Struct(actual) = TypeExpr::from(ty) else {
            panic!("expected struct type expression");
        };
        assert_eq!(
            actual.name.as_ref().map(Ident::as_str),
            Some("miden:base/core-types@1.0.0/account-id"),
        );
        assert_eq!(*actual.repr, TypeRepr::BigEndian);
        assert_eq!(actual.fields[0].name.as_str(), "prefix");
        assert_eq!(actual.fields[1].name.as_str(), "suffix");
    }

    // Round trip: parse, then resolve to a concrete type, then convert back to a type expression.
    // Field names and repr must survive every step.
    #[test]
    fn parsed_struct_type_preserves_field_names_through_resolution() {
        let expr = parse_type_alias_expr(
            "type AccountId = struct @bigendian { prefix: felt, suffix: felt }\n",
        );

        let mut resolver = DummyResolver::new();
        let resolved = expr
            .resolve_type(&mut resolver)
            .expect("struct type should resolve")
            .expect("struct type should be concrete");
        let Type::Struct(resolved_struct) = &resolved else {
            panic!("expected resolved struct type, got {resolved:?}");
        };
        assert_eq!(resolved_struct.repr(), TypeRepr::BigEndian);
        assert_eq!(resolved_struct.fields()[0].name.as_deref(), Some("prefix"));
        assert_eq!(resolved_struct.fields()[1].name.as_deref(), Some("suffix"));

        let TypeExpr::Struct(converted) = TypeExpr::from(resolved) else {
            panic!("expected concrete struct to convert back to struct type expression");
        };
        assert_eq!(*converted.repr, TypeRepr::BigEndian);
        assert_eq!(converted.fields[0].name.as_str(), "prefix");
        assert_eq!(converted.fields[1].name.as_str(), "suffix");
    }
}