Skip to main content

miden_assembly_syntax/ast/
type.rs

1use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
2
3use miden_debug_types::{SourceManager, SourceSpan, Span, Spanned};
4pub use midenc_hir_type as types;
5use midenc_hir_type::{AddressSpace, Type, TypeRepr};
6
7use super::{
8    ConstantExpr, DocString, GlobalItemIndex, Ident, ItemIndex, Path, SymbolResolution,
9    SymbolResolutionError, Visibility,
10};
11
/// Upper bound on how deeply type expressions may nest before resolution gives up.
///
/// Resolving a nested type expression recurses once per level, so an adversarially deep input
/// could otherwise overflow the stack. The bound is meant to sit far above the nesting depth that
/// real programs use.
const MAX_TYPE_EXPR_NESTING: usize = 1 << 8;
17
18/// Abstracts over resolving an item to a concrete [Type], using one of:
19///
20/// * A [GlobalItemIndex]
21/// * An [ItemIndex]
22/// * A [Path]
23/// * A [TypeExpr]
24///
25/// Since type resolution happens in two different contexts during assembly, this abstraction allows
26/// us to share more of the resolution logic in both places.
27///
28/// NOTE: Most methods of this trait take a mutable reference to the resolver, so that the resolver
29/// can mutate its own state as necessary during resolution (e.g. to manage a cache, or other side
30/// table-like data structures).
31pub trait TypeResolver<E> {
32    fn source_manager(&self) -> Arc<dyn SourceManager>;
33    /// Should be called by consumers of this resolver to convert a [SymbolResolutionError] to the
34    /// error type used by the [TypeResolver] implementation.
35    fn resolve_local_failed(&self, err: SymbolResolutionError) -> E;
36    /// Get the [Type] corresponding to the item given by `gid`
37    fn get_type(&mut self, context: SourceSpan, gid: GlobalItemIndex) -> Result<Type, E>;
38    /// Get the [Type] corresponding to the item in the current module given by `id`
39    fn get_local_type(&mut self, context: SourceSpan, id: ItemIndex) -> Result<Option<Type>, E>;
40    /// Attempt to resolve a symbol path, given by a `TypeExpr::Ref`, to an item
41    fn resolve_type_ref(&mut self, ty: Span<&Path>) -> Result<SymbolResolution, E>;
42    /// Resolve a [TypeExpr] to a concrete [Type]
43    fn resolve(&mut self, ty: &TypeExpr) -> Result<Option<Type>, E> {
44        ty.resolve_type(self)
45    }
46}
47
48// TYPE DECLARATION
49// ================================================================================================
50
51/// An abstraction over the different types of type declarations allowed in Miden Assembly
52#[derive(Debug, Clone, PartialEq, Eq)]
53pub enum TypeDecl {
54    /// A named type, i.e. a type alias
55    Alias(TypeAlias),
56    /// A C-like enumeration type with associated constants
57    Enum(EnumType),
58}
59
60impl TypeDecl {
61    /// Adds documentation to this type alias
62    pub fn with_docs(self, docs: Option<Span<String>>) -> Self {
63        match self {
64            Self::Alias(ty) => Self::Alias(ty.with_docs(docs)),
65            Self::Enum(ty) => Self::Enum(ty.with_docs(docs)),
66        }
67    }
68
69    /// Get the name assigned to this type declaration
70    pub fn name(&self) -> &Ident {
71        match self {
72            Self::Alias(ty) => &ty.name,
73            Self::Enum(ty) => &ty.name,
74        }
75    }
76
77    /// Get the visibility of this type declaration
78    pub const fn visibility(&self) -> Visibility {
79        match self {
80            Self::Alias(ty) => ty.visibility,
81            Self::Enum(ty) => ty.visibility,
82        }
83    }
84
85    /// Get the documentation of this enum type
86    pub fn docs(&self) -> Option<Span<&str>> {
87        match self {
88            Self::Alias(ty) => ty.docs(),
89            Self::Enum(ty) => ty.docs(),
90        }
91    }
92
93    /// Get the type expression associated with this declaration
94    pub fn ty(&self) -> TypeExpr {
95        match self {
96            Self::Alias(ty) => ty.ty.clone(),
97            Self::Enum(ty) => TypeExpr::Primitive(Span::new(ty.span, ty.ty.clone())),
98        }
99    }
100}
101
102impl Spanned for TypeDecl {
103    fn span(&self) -> SourceSpan {
104        match self {
105            Self::Alias(spanned) => spanned.span,
106            Self::Enum(spanned) => spanned.span,
107        }
108    }
109}
110
111impl From<TypeAlias> for TypeDecl {
112    fn from(value: TypeAlias) -> Self {
113        Self::Alias(value)
114    }
115}
116
117impl From<EnumType> for TypeDecl {
118    fn from(value: EnumType) -> Self {
119        Self::Enum(value)
120    }
121}
122
123impl crate::prettier::PrettyPrint for TypeDecl {
124    fn render(&self) -> crate::prettier::Document {
125        match self {
126            Self::Alias(ty) => ty.render(),
127            Self::Enum(ty) => ty.render(),
128        }
129    }
130}
131
132// FUNCTION TYPE
133// ================================================================================================
134
135/// A procedure type signature
136#[derive(Debug, Clone)]
137pub struct FunctionType {
138    pub span: SourceSpan,
139    pub cc: types::CallConv,
140    pub args: Vec<TypeExpr>,
141    pub results: Vec<TypeExpr>,
142}
143
144impl Eq for FunctionType {}
145
146impl PartialEq for FunctionType {
147    fn eq(&self, other: &Self) -> bool {
148        self.cc == other.cc && self.args == other.args && self.results == other.results
149    }
150}
151
152impl core::hash::Hash for FunctionType {
153    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
154        self.cc.hash(state);
155        self.args.hash(state);
156        self.results.hash(state);
157    }
158}
159
160impl Spanned for FunctionType {
161    fn span(&self) -> SourceSpan {
162        self.span
163    }
164}
165
166impl FunctionType {
167    pub fn new(cc: types::CallConv, args: Vec<TypeExpr>, results: Vec<TypeExpr>) -> Self {
168        Self {
169            span: SourceSpan::UNKNOWN,
170            cc,
171            args,
172            results,
173        }
174    }
175
176    /// Override the default source span
177    #[inline]
178    pub fn with_span(mut self, span: SourceSpan) -> Self {
179        self.span = span;
180        self
181    }
182}
183
184impl crate::prettier::PrettyPrint for FunctionType {
185    fn render(&self) -> crate::prettier::Document {
186        use crate::prettier::*;
187
188        let singleline_args = self
189            .args
190            .iter()
191            .map(|arg| arg.render())
192            .reduce(|acc, arg| acc + const_text(", ") + arg)
193            .unwrap_or(Document::Empty);
194        let multiline_args = indent(
195            4,
196            nl() + self
197                .args
198                .iter()
199                .map(|arg| arg.render())
200                .reduce(|acc, arg| acc + const_text(",") + nl() + arg)
201                .unwrap_or(Document::Empty),
202        ) + nl();
203        let args = singleline_args | multiline_args;
204        let args = const_text("(") + args + const_text(")");
205
206        match self.results.len() {
207            0 => args,
208            1 => args + const_text(" -> ") + self.results[0].render(),
209            _ => {
210                let results = self
211                    .results
212                    .iter()
213                    .map(|r| r.render())
214                    .reduce(|acc, r| acc + const_text(", ") + r)
215                    .unwrap_or(Document::Empty);
216                args + const_text(" -> ") + const_text("(") + results + const_text(")")
217            },
218        }
219    }
220}
221
222// TYPE EXPRESSION
223// ================================================================================================
224
225/// A syntax-level type expression (i.e. primitive type, reference to nominal type, etc.)
226#[derive(Debug, Clone, Eq, PartialEq, Hash)]
227pub enum TypeExpr {
228    /// A primitive integral type, e.g. `i1`, `u16`
229    Primitive(Span<Type>),
230    /// A pointer type expression, e.g. `*u8`
231    Ptr(PointerType),
232    /// An array type expression, e.g. `[u8; 32]`
233    Array(ArrayType),
234    /// A struct type expression, e.g. `struct { a: u32 }`
235    Struct(StructType),
236    /// A reference to a type aliased by name, e.g. `Foo`
237    Ref(Span<Arc<Path>>),
238}
239
240impl TypeExpr {
241    /// Set the name associated with this type expression, if applicable.
242    ///
243    /// Currently this just sets the name of struct types, but if we add other types with names in
244    /// the future, we can support them here.
245    pub fn set_name(&mut self, name: Ident) {
246        match self {
247            Self::Struct(struct_ty) => {
248                struct_ty.name = Some(name);
249            },
250            Self::Primitive(_) | Self::Ptr(_) | Self::Array(_) | Self::Ref(_) => (),
251        }
252    }
253
254    /// Get any references to other types present in this expression
255    pub fn references(&self) -> Vec<Span<Arc<Path>>> {
256        use alloc::collections::BTreeSet;
257
258        let mut worklist = smallvec::SmallVec::<[_; 4]>::from_slice(&[self]);
259        let mut references = BTreeSet::new();
260
261        while let Some(ty) = worklist.pop() {
262            match ty {
263                Self::Primitive(_) => {},
264                Self::Ptr(ty) => {
265                    worklist.push(&ty.pointee);
266                },
267                Self::Array(ty) => {
268                    worklist.push(&ty.elem);
269                },
270                Self::Struct(ty) => {
271                    for field in ty.fields.iter() {
272                        worklist.push(&field.ty);
273                    }
274                },
275                Self::Ref(ty) => {
276                    references.insert(ty.clone());
277                },
278            }
279        }
280
281        references.into_iter().collect()
282    }
283
284    /// Resolve this type expression to a concrete type, using `resolver`
285    pub fn resolve_type<E, R>(&self, resolver: &mut R) -> Result<Option<Type>, E>
286    where
287        R: ?Sized + TypeResolver<E>,
288    {
289        self.resolve_type_with_depth(resolver, 0)
290    }
291
292    fn resolve_type_with_depth<E, R>(
293        &self,
294        resolver: &mut R,
295        depth: usize,
296    ) -> Result<Option<Type>, E>
297    where
298        R: ?Sized + TypeResolver<E>,
299    {
300        if depth > MAX_TYPE_EXPR_NESTING {
301            let source_manager = resolver.source_manager();
302            return Err(resolver.resolve_local_failed(
303                SymbolResolutionError::type_expression_depth_exceeded(
304                    self.span(),
305                    MAX_TYPE_EXPR_NESTING,
306                    source_manager.as_ref(),
307                ),
308            ));
309        }
310
311        match self {
312            TypeExpr::Ref(path) => {
313                let mut current_path = path.clone();
314                loop {
315                    match resolver.resolve_type_ref(current_path.as_deref())? {
316                        SymbolResolution::Local(item) => {
317                            return resolver.get_local_type(current_path.span(), item.into_inner());
318                        },
319                        SymbolResolution::External(path) => {
320                            // We don't have a definition for this type yet
321                            if path == current_path {
322                                break Ok(None);
323                            }
324                            current_path = path;
325                        },
326                        SymbolResolution::Exact { gid, .. } => {
327                            return resolver.get_type(current_path.span(), gid).map(Some);
328                        },
329                        SymbolResolution::Module { path: module_path, .. } => {
330                            break Err(resolver.resolve_local_failed(
331                                SymbolResolutionError::invalid_symbol_type(
332                                    path.span(),
333                                    "type",
334                                    module_path.span(),
335                                    &resolver.source_manager(),
336                                ),
337                            ));
338                        },
339                        SymbolResolution::MastRoot(item) => {
340                            break Err(resolver.resolve_local_failed(
341                                SymbolResolutionError::invalid_symbol_type(
342                                    path.span(),
343                                    "type",
344                                    item.span(),
345                                    &resolver.source_manager(),
346                                ),
347                            ));
348                        },
349                    }
350                }
351            },
352            TypeExpr::Primitive(t) => Ok(Some(t.inner().clone())),
353            TypeExpr::Array(t) => Ok(t
354                .elem
355                .resolve_type_with_depth(resolver, depth + 1)?
356                .map(|elem| types::Type::Array(Arc::new(types::ArrayType::new(elem, t.arity))))),
357            TypeExpr::Ptr(ty) => Ok(ty
358                .pointee
359                .resolve_type_with_depth(resolver, depth + 1)?
360                .map(|pointee| types::Type::Ptr(Arc::new(types::PointerType::new(pointee))))),
361            TypeExpr::Struct(t) => {
362                let mut fields = Vec::with_capacity(t.fields.len());
363                for field in t.fields.iter() {
364                    let field_ty = field.ty.resolve_type_with_depth(resolver, depth + 1)?;
365                    if let Some(field_ty) = field_ty {
366                        fields.push(field_ty);
367                    } else {
368                        return Ok(None);
369                    }
370                }
371                Ok(Some(Type::Struct(Arc::new(types::StructType::from_parts(
372                    t.name.clone().map(|id| id.into_inner()),
373                    t.repr.into_inner(),
374                    fields,
375                )))))
376            },
377        }
378    }
379}
380
381impl From<Type> for TypeExpr {
382    fn from(ty: Type) -> Self {
383        match ty {
384            Type::Array(t) => Self::Array(ArrayType::new(t.element_type().clone().into(), t.len())),
385            Type::Struct(t) => Self::Struct(StructType::new(
386                None,
387                t.fields().iter().enumerate().map(|(i, ft)| {
388                    let name = Ident::new(format!("field{i}")).unwrap();
389                    StructField {
390                        span: SourceSpan::UNKNOWN,
391                        name,
392                        ty: ft.ty.clone().into(),
393                    }
394                }),
395            )),
396            Type::Ptr(t) => Self::Ptr((*t).clone().into()),
397            Type::Function(_) => {
398                Self::Ptr(PointerType::new(TypeExpr::Primitive(Span::unknown(Type::Felt))))
399            },
400            Type::List(t) => Self::Ptr(
401                PointerType::new((*t).clone().into()).with_address_space(AddressSpace::Byte),
402            ),
403            Type::I128 | Type::U128 => Self::Array(ArrayType::new(Type::U32.into(), 4)),
404            Type::I64 | Type::U64 => Self::Array(ArrayType::new(Type::U32.into(), 2)),
405            Type::Unknown | Type::Never | Type::F64 => panic!("unrepresentable type value: {ty}"),
406            ty => Self::Primitive(Span::unknown(ty)),
407        }
408    }
409}
410
411impl Spanned for TypeExpr {
412    fn span(&self) -> SourceSpan {
413        match self {
414            Self::Primitive(spanned) => spanned.span(),
415            Self::Ptr(spanned) => spanned.span(),
416            Self::Array(spanned) => spanned.span(),
417            Self::Struct(spanned) => spanned.span(),
418            Self::Ref(spanned) => spanned.span(),
419        }
420    }
421}
422
423impl crate::prettier::PrettyPrint for TypeExpr {
424    fn render(&self) -> crate::prettier::Document {
425        use crate::prettier::*;
426
427        match self {
428            Self::Primitive(ty) => display(ty),
429            Self::Ptr(ty) => ty.render(),
430            Self::Array(ty) => ty.render(),
431            Self::Struct(ty) => ty.render(),
432            Self::Ref(ty) => display(ty),
433        }
434    }
435}
436
437// POINTER TYPE
438// ================================================================================================
439
440#[derive(Debug, Clone)]
441pub struct PointerType {
442    pub span: SourceSpan,
443    pub pointee: Box<TypeExpr>,
444    addrspace: Option<AddressSpace>,
445}
446
447impl From<types::PointerType> for PointerType {
448    fn from(ty: types::PointerType) -> Self {
449        let types::PointerType { addrspace, pointee } = ty;
450        let pointee = Box::new(TypeExpr::from(pointee));
451        Self {
452            span: SourceSpan::UNKNOWN,
453            pointee,
454            addrspace: Some(addrspace),
455        }
456    }
457}
458
459impl Eq for PointerType {}
460
461impl PartialEq for PointerType {
462    fn eq(&self, other: &Self) -> bool {
463        self.address_space() == other.address_space() && self.pointee == other.pointee
464    }
465}
466
467impl core::hash::Hash for PointerType {
468    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
469        self.pointee.hash(state);
470        self.address_space().hash(state);
471    }
472}
473
474impl Spanned for PointerType {
475    fn span(&self) -> SourceSpan {
476        self.span
477    }
478}
479
480impl PointerType {
481    pub fn new(pointee: TypeExpr) -> Self {
482        Self {
483            span: SourceSpan::UNKNOWN,
484            pointee: Box::new(pointee),
485            addrspace: None,
486        }
487    }
488
489    /// Override the default source span
490    #[inline]
491    pub fn with_span(mut self, span: SourceSpan) -> Self {
492        self.span = span;
493        self
494    }
495
496    /// Override the default address space
497    #[inline]
498    pub fn with_address_space(mut self, addrspace: AddressSpace) -> Self {
499        self.addrspace = Some(addrspace);
500        self
501    }
502
503    /// Get the address space of this pointer type
504    #[inline]
505    pub fn address_space(&self) -> AddressSpace {
506        self.addrspace.unwrap_or(AddressSpace::Element)
507    }
508}
509
510impl crate::prettier::PrettyPrint for PointerType {
511    fn render(&self) -> crate::prettier::Document {
512        use crate::prettier::*;
513
514        let doc = const_text("ptr<") + self.pointee.render();
515        if let Some(addrspace) = self.addrspace.as_ref() {
516            doc + const_text(", ") + text(format!("addrspace({})", addrspace)) + const_text(">")
517        } else {
518            doc + const_text(">")
519        }
520    }
521}
522
523// ARRAY TYPE
524// ================================================================================================
525
526#[derive(Debug, Clone)]
527pub struct ArrayType {
528    pub span: SourceSpan,
529    pub elem: Box<TypeExpr>,
530    pub arity: usize,
531}
532
533impl Eq for ArrayType {}
534
535impl PartialEq for ArrayType {
536    fn eq(&self, other: &Self) -> bool {
537        self.arity == other.arity && self.elem == other.elem
538    }
539}
540
541impl core::hash::Hash for ArrayType {
542    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
543        self.elem.hash(state);
544        self.arity.hash(state);
545    }
546}
547
548impl Spanned for ArrayType {
549    fn span(&self) -> SourceSpan {
550        self.span
551    }
552}
553
554impl ArrayType {
555    pub fn new(elem: TypeExpr, arity: usize) -> Self {
556        Self {
557            span: SourceSpan::UNKNOWN,
558            elem: Box::new(elem),
559            arity,
560        }
561    }
562
563    /// Override the default source span
564    #[inline]
565    pub fn with_span(mut self, span: SourceSpan) -> Self {
566        self.span = span;
567        self
568    }
569}
570
571impl crate::prettier::PrettyPrint for ArrayType {
572    fn render(&self) -> crate::prettier::Document {
573        use crate::prettier::*;
574
575        const_text("[")
576            + self.elem.render()
577            + const_text("; ")
578            + display(self.arity)
579            + const_text("]")
580    }
581}
582
583// STRUCT TYPE
584// ================================================================================================
585
586#[derive(Debug, Clone)]
587pub struct StructType {
588    pub span: SourceSpan,
589    pub name: Option<Ident>,
590    pub repr: Span<TypeRepr>,
591    pub fields: Vec<StructField>,
592}
593
594impl Eq for StructType {}
595
596impl PartialEq for StructType {
597    fn eq(&self, other: &Self) -> bool {
598        self.name == other.name && self.repr == other.repr && self.fields == other.fields
599    }
600}
601
602impl core::hash::Hash for StructType {
603    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
604        self.name.hash(state);
605        self.repr.hash(state);
606        self.fields.hash(state);
607    }
608}
609
610impl Spanned for StructType {
611    fn span(&self) -> SourceSpan {
612        self.span
613    }
614}
615
616impl StructType {
617    pub fn new(name: Option<Ident>, fields: impl IntoIterator<Item = StructField>) -> Self {
618        Self {
619            span: SourceSpan::UNKNOWN,
620            name,
621            repr: Span::unknown(TypeRepr::Default),
622            fields: fields.into_iter().collect(),
623        }
624    }
625
626    /// Override the default struct representation
627    #[inline]
628    pub fn with_repr(mut self, repr: Span<TypeRepr>) -> Self {
629        self.repr = repr;
630        self
631    }
632
633    /// Override the default source span
634    #[inline]
635    pub fn with_span(mut self, span: SourceSpan) -> Self {
636        self.span = span;
637        self
638    }
639}
640
641impl crate::prettier::PrettyPrint for StructType {
642    fn render(&self) -> crate::prettier::Document {
643        use crate::prettier::*;
644
645        let repr = match &*self.repr {
646            TypeRepr::Default => Document::Empty,
647            TypeRepr::BigEndian => const_text("@bigendian "),
648            repr @ (TypeRepr::Align(_) | TypeRepr::Packed(_) | TypeRepr::Transparent) => {
649                text(format!("@{repr} "))
650            },
651        };
652
653        let singleline_body = self
654            .fields
655            .iter()
656            .map(|field| field.render())
657            .reduce(|acc, field| acc + const_text(", ") + field)
658            .unwrap_or(Document::Empty);
659        let multiline_body = indent(
660            4,
661            nl() + self
662                .fields
663                .iter()
664                .map(|field| field.render())
665                .reduce(|acc, field| acc + const_text(",") + nl() + field)
666                .unwrap_or(Document::Empty),
667        ) + nl();
668        let body = singleline_body | multiline_body;
669
670        repr + const_text("struct") + const_text(" { ") + body + const_text(" }")
671    }
672}
673
674// STRUCT FIELD
675// ================================================================================================
676
677#[derive(Debug, Clone)]
678pub struct StructField {
679    pub span: SourceSpan,
680    pub name: Ident,
681    pub ty: TypeExpr,
682}
683
684impl Eq for StructField {}
685
686impl PartialEq for StructField {
687    fn eq(&self, other: &Self) -> bool {
688        self.name == other.name && self.ty == other.ty
689    }
690}
691
692impl core::hash::Hash for StructField {
693    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
694        self.name.hash(state);
695        self.ty.hash(state);
696    }
697}
698
699impl Spanned for StructField {
700    fn span(&self) -> SourceSpan {
701        self.span
702    }
703}
704
705impl crate::prettier::PrettyPrint for StructField {
706    fn render(&self) -> crate::prettier::Document {
707        use crate::prettier::*;
708
709        display(&self.name) + const_text(": ") + self.ty.render()
710    }
711}
712
713// TYPE ALIAS
714// ================================================================================================
715
716/// A [TypeAlias] represents a named [Type].
717///
718/// Type aliases correspond to type declarations in Miden Assembly source files. They are called
719/// aliases, rather than declarations, as the type system for Miden Assembly is structural, rather
720/// than nominal, and so two aliases with the same underlying type are considered equivalent.
721#[derive(Debug, Clone)]
722pub struct TypeAlias {
723    span: SourceSpan,
724    /// The documentation string attached to this definition.
725    docs: Option<DocString>,
726    /// The visibility of this type alias
727    pub visibility: Visibility,
728    /// The name of this type alias
729    pub name: Ident,
730    /// The concrete underlying type
731    pub ty: TypeExpr,
732}
733
734impl TypeAlias {
735    /// Create a new type alias from a name and type
736    pub fn new(visibility: Visibility, name: Ident, ty: TypeExpr) -> Self {
737        Self {
738            span: name.span(),
739            docs: None,
740            visibility,
741            name,
742            ty,
743        }
744    }
745
746    /// Adds documentation to this type alias
747    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
748        self.docs = docs.map(DocString::new);
749        self
750    }
751
752    /// Override the default source span
753    #[inline]
754    pub fn with_span(mut self, span: SourceSpan) -> Self {
755        self.span = span;
756        self
757    }
758
759    /// Set the source span
760    #[inline]
761    pub fn set_span(&mut self, span: SourceSpan) {
762        self.span = span;
763    }
764
765    /// Returns the documentation associated with this item.
766    pub fn docs(&self) -> Option<Span<&str>> {
767        self.docs.as_ref().map(|docstring| docstring.as_spanned_str())
768    }
769
770    /// Get the name of this type alias
771    pub fn name(&self) -> &Ident {
772        &self.name
773    }
774
775    /// Get the visibility of this type alias
776    #[inline]
777    pub const fn visibility(&self) -> Visibility {
778        self.visibility
779    }
780}
781
782impl Eq for TypeAlias {}
783
784impl PartialEq for TypeAlias {
785    fn eq(&self, other: &Self) -> bool {
786        self.visibility == other.visibility
787            && self.name == other.name
788            && self.docs == other.docs
789            && self.ty == other.ty
790    }
791}
792
793impl core::hash::Hash for TypeAlias {
794    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
795        let Self { span: _, docs, visibility, name, ty } = self;
796        docs.hash(state);
797        visibility.hash(state);
798        name.hash(state);
799        ty.hash(state);
800    }
801}
802
803impl Spanned for TypeAlias {
804    fn span(&self) -> SourceSpan {
805        self.span
806    }
807}
808
809impl crate::prettier::PrettyPrint for TypeAlias {
810    fn render(&self) -> crate::prettier::Document {
811        use crate::prettier::*;
812
813        let mut doc = self
814            .docs
815            .as_ref()
816            .map(|docstring| docstring.render())
817            .unwrap_or(Document::Empty);
818
819        if self.visibility.is_public() {
820            doc += display(self.visibility) + const_text(" ");
821        }
822
823        doc + const_text("type")
824            + const_text(" ")
825            + display(&self.name)
826            + const_text(" = ")
827            + self.ty.render()
828    }
829}
830
831// ENUM TYPE
832// ================================================================================================
833
834/// A combined type alias and constant declaration corresponding to a C-like enumeration.
835///
836/// C-style enumerations are effectively a type alias for an integer type with a limited set of
837/// valid values with associated names (referred to as _variants_ of the enum type).
838///
839/// In Miden Assembly, these provide a means for a procedure to declare that it expects an argument
840/// of the underlying integral type, but that values other than those of the declared variants are
841/// illegal/invalid. Currently, these are unchecked, and are only used to convey semantic
842/// information. In the future, we may perform static analysis to try and identify invalid instances
843/// of the enumeration when derived from a constant.
844#[derive(Debug, Clone)]
845pub struct EnumType {
846    span: SourceSpan,
847    /// The documentation string attached to this definition.
848    docs: Option<DocString>,
849    /// The visibility of this enum type
850    visibility: Visibility,
851    /// The enum name
852    name: Ident,
853    /// The type of the discriminant value used for this enum's variants
854    ///
855    /// NOTE: The type must be an integral value, and this is enforced by [`Self::new`].
856    ty: Type,
857    /// The enum variants
858    variants: Vec<Variant>,
859}
860
861impl EnumType {
862    /// Construct a new enum type with the given name and variants
863    ///
864    /// The caller is assumed to have already validated that `ty` is an integral type, and this
865    /// function will assert that this is the case.
866    pub fn new(
867        visibility: Visibility,
868        name: Ident,
869        ty: Type,
870        variants: impl IntoIterator<Item = Variant>,
871    ) -> Self {
872        assert!(ty.is_integer(), "only integer types are allowed in enum type definitions");
873        Self {
874            span: name.span(),
875            docs: None,
876            visibility,
877            name,
878            ty,
879            variants: Vec::from_iter(variants),
880        }
881    }
882
883    /// Adds documentation to this enum declaration.
884    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
885        self.docs = docs.map(DocString::new);
886        self
887    }
888
889    /// Override the default source span
890    pub fn with_span(mut self, span: SourceSpan) -> Self {
891        self.span = span;
892        self
893    }
894
895    /// Returns true if this is a C-style enum where the discriminant is the value
896    pub fn is_c_like(&self) -> bool {
897        !self.variants.is_empty() && self.variants.iter().all(|v| v.value_ty.is_none())
898    }
899
900    /// Set the source span
901    pub fn set_span(&mut self, span: SourceSpan) {
902        self.span = span;
903    }
904
905    /// Get the name of this enum type
906    pub fn name(&self) -> &Ident {
907        &self.name
908    }
909
910    /// Get the visibility of this enum type
911    pub const fn visibility(&self) -> Visibility {
912        self.visibility
913    }
914
915    /// Returns the documentation associated with this item.
916    pub fn docs(&self) -> Option<Span<&str>> {
917        self.docs.as_ref().map(|docstring| docstring.as_spanned_str())
918    }
919
920    /// Get the concrete type of this enum's variants
921    pub fn ty(&self) -> &Type {
922        &self.ty
923    }
924
925    /// Get the variants of this enum type
926    pub fn variants(&self) -> &[Variant] {
927        &self.variants
928    }
929
930    /// Get the variants of this enum type, mutably
931    pub fn variants_mut(&mut self) -> &mut Vec<Variant> {
932        &mut self.variants
933    }
934
935    /// Split this definition into its type alias and variant parts
936    pub fn into_parts(self) -> (TypeAlias, Vec<Variant>) {
937        let Self {
938            span,
939            docs,
940            visibility,
941            name,
942            ty,
943            variants,
944        } = self;
945        let alias = TypeAlias {
946            span,
947            docs,
948            visibility,
949            name,
950            ty: TypeExpr::Primitive(Span::new(span, ty)),
951        };
952        (alias, variants)
953    }
954}
955
956impl Spanned for EnumType {
957    fn span(&self) -> SourceSpan {
958        self.span
959    }
960}
961
962impl Eq for EnumType {}
963
964impl PartialEq for EnumType {
965    fn eq(&self, other: &Self) -> bool {
966        self.visibility == other.visibility
967            && self.name == other.name
968            && self.docs == other.docs
969            && self.ty == other.ty
970            && self.variants == other.variants
971    }
972}
973
974impl core::hash::Hash for EnumType {
975    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
976        let Self {
977            span: _,
978            docs,
979            visibility,
980            name,
981            ty,
982            variants,
983        } = self;
984        docs.hash(state);
985        visibility.hash(state);
986        name.hash(state);
987        ty.hash(state);
988        variants.hash(state);
989    }
990}
991
992impl crate::prettier::PrettyPrint for EnumType {
993    fn render(&self) -> crate::prettier::Document {
994        use crate::prettier::*;
995
996        let mut doc = self
997            .docs
998            .as_ref()
999            .map(|docstring| docstring.render())
1000            .unwrap_or(Document::Empty);
1001
1002        let variants = self
1003            .variants
1004            .iter()
1005            .map(|v| v.render())
1006            .reduce(|acc, v| acc + const_text(",") + nl() + v)
1007            .unwrap_or(Document::Empty);
1008
1009        if self.visibility.is_public() {
1010            doc += display(self.visibility) + const_text(" ");
1011        }
1012
1013        doc + const_text("enum")
1014            + const_text(" ")
1015            + display(&self.name)
1016            + const_text(" : ")
1017            + self.ty.render()
1018            + const_text(" {")
1019            + nl()
1020            + variants
1021            + const_text("}")
1022    }
1023}
1024
1025// ENUM VARIANT
1026// ================================================================================================
1027
1028/// A variant of an [EnumType].
1029///
1030/// See the [EnumType] docs for more information.
1031#[derive(Debug, Clone)]
1032pub struct Variant {
1033    pub span: SourceSpan,
1034    /// The documentation string attached to the constant derived from this variant.
1035    pub docs: Option<DocString>,
1036    /// The name of this enum variant
1037    pub name: Ident,
1038    /// The payload value type of this variant
1039    ///
1040    /// NOTE: This is not supported in Miden Assembly text format yet, but can be set when lowering
1041    /// directly to the AST.
1042    pub value_ty: Option<TypeExpr>,
1043    /// The discriminant value associated with this variant
1044    pub discriminant: ConstantExpr,
1045}
1046
1047impl Variant {
1048    /// Construct a new variant of an [EnumType], with the given name and discriminant value.
1049    pub fn new(name: Ident, discriminant: ConstantExpr, payload: Option<TypeExpr>) -> Self {
1050        Self {
1051            span: name.span(),
1052            docs: None,
1053            name,
1054            value_ty: payload,
1055            discriminant,
1056        }
1057    }
1058
1059    /// Override the span for this variant
1060    pub fn with_span(mut self, span: SourceSpan) -> Self {
1061        self.span = span;
1062        self
1063    }
1064
1065    /// Adds documentation to this variant
1066    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
1067        self.docs = docs.map(DocString::new);
1068        self
1069    }
1070
1071    /// Used to validate that this variant's discriminant value is an instance of `ty`,
1072    /// which must be a type valid for use as the underlying representation for an enum, i.e. an
1073    /// integer type up to 64 bits in size.
1074    ///
1075    /// It is expected that the discriminant expression has been folded to an integer value by the
1076    /// time this is called. If the discriminant has not been fully folded, then an error will be
1077    /// returned.
1078    pub fn assert_instance_of(&self, ty: &Type) -> Result<(), crate::SemanticAnalysisError> {
1079        use crate::{FIELD_MODULUS, SemanticAnalysisError};
1080
1081        let value = match &self.discriminant {
1082            ConstantExpr::Int(value) => value.as_int(),
1083            _ => {
1084                return Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1085                    span: self.discriminant.span(),
1086                    repr: ty.clone(),
1087                });
1088            },
1089        };
1090
1091        match ty {
1092            Type::Felt if value >= FIELD_MODULUS => {
1093                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1094                    span: self.discriminant.span(),
1095                    repr: ty.clone(),
1096                })
1097            },
1098            // IntValue is represented as an unsigned integer, so negative discriminants
1099            // are rejected during constant evaluation.
1100            Type::Felt => Ok(()),
1101            Type::I1 if value > 1 => Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1102                span: self.discriminant.span(),
1103                repr: ty.clone(),
1104            }),
1105            Type::I1 => Ok(()),
1106            Type::I8 | Type::U8 if value > u8::MAX as u64 => {
1107                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1108                    span: self.discriminant.span(),
1109                    repr: ty.clone(),
1110                })
1111            },
1112            Type::I8 | Type::U8 => Ok(()),
1113            Type::I16 | Type::U16 if value > u16::MAX as u64 => {
1114                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1115                    span: self.discriminant.span(),
1116                    repr: ty.clone(),
1117                })
1118            },
1119            Type::I16 | Type::U16 => Ok(()),
1120            Type::I32 | Type::U32 if value > u32::MAX as u64 => {
1121                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1122                    span: self.discriminant.span(),
1123                    repr: ty.clone(),
1124                })
1125            },
1126            Type::I32 | Type::U32 => Ok(()),
1127            Type::I64 | Type::U64 if value >= FIELD_MODULUS => {
1128                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1129                    span: self.discriminant.span(),
1130                    repr: ty.clone(),
1131                })
1132            },
1133            _ => Err(SemanticAnalysisError::InvalidEnumRepr { span: self.span }),
1134        }
1135    }
1136}
1137
1138impl Spanned for Variant {
1139    fn span(&self) -> SourceSpan {
1140        self.span
1141    }
1142}
1143
1144impl Eq for Variant {}
1145
1146impl PartialEq for Variant {
1147    fn eq(&self, other: &Self) -> bool {
1148        self.name == other.name
1149            && self.value_ty == other.value_ty
1150            && self.discriminant == other.discriminant
1151            && self.docs == other.docs
1152    }
1153}
1154
1155impl core::hash::Hash for Variant {
1156    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1157        let Self {
1158            span: _,
1159            docs,
1160            name,
1161            value_ty,
1162            discriminant,
1163        } = self;
1164        docs.hash(state);
1165        name.hash(state);
1166        value_ty.hash(state);
1167        discriminant.hash(state);
1168    }
1169}
1170
1171impl crate::prettier::PrettyPrint for Variant {
1172    fn render(&self) -> crate::prettier::Document {
1173        use crate::prettier::*;
1174
1175        let doc = self
1176            .docs
1177            .as_ref()
1178            .map(|docstring| docstring.render())
1179            .unwrap_or(Document::Empty);
1180
1181        let name = display(&self.name);
1182        let name_and_payload = if let Some(value_ty) = self.value_ty.as_ref() {
1183            name + const_text("(") + value_ty.render() + const_text(")")
1184        } else {
1185            name
1186        };
1187        doc + name_and_payload + const_text(" = ") + self.discriminant.render()
1188    }
1189}
1190
#[cfg(test)]
mod tests {
    use alloc::sync::Arc;
    use core::str::FromStr;

    use miden_debug_types::DefaultSourceManager;

    use super::*;

    /// A resolver that knows no items at all. Only primitive and structural type expressions
    /// can be resolved through it.
    struct DummyResolver {
        source_manager: Arc<dyn SourceManager>,
    }

    impl DummyResolver {
        fn new() -> Self {
            let source_manager: Arc<dyn SourceManager> =
                Arc::new(DefaultSourceManager::default());
            Self { source_manager }
        }
    }

    impl TypeResolver<SymbolResolutionError> for DummyResolver {
        fn source_manager(&self) -> Arc<dyn SourceManager> {
            Arc::clone(&self.source_manager)
        }

        fn resolve_local_failed(&self, err: SymbolResolutionError) -> SymbolResolutionError {
            err
        }

        fn get_type(
            &mut self,
            context: SourceSpan,
            _gid: GlobalItemIndex,
        ) -> Result<Type, SymbolResolutionError> {
            Err(SymbolResolutionError::undefined(context, self.source_manager.as_ref()))
        }

        fn get_local_type(
            &mut self,
            _context: SourceSpan,
            _id: ItemIndex,
        ) -> Result<Option<Type>, SymbolResolutionError> {
            Ok(None)
        }

        fn resolve_type_ref(
            &mut self,
            ty: Span<&Path>,
        ) -> Result<SymbolResolution, SymbolResolutionError> {
            Err(SymbolResolutionError::undefined(ty.span(), self.source_manager.as_ref()))
        }
    }

    /// Builds a type expression `depth` levels deep, wrapping a `felt` leaf in a rotating cycle
    /// of pointer, single-element array, and single-field struct layers.
    fn nested_type_expr(depth: usize) -> TypeExpr {
        let leaf = TypeExpr::Primitive(Span::unknown(Type::Felt));
        (0..depth).fold(leaf, |inner, level| match level % 3 {
            0 => TypeExpr::Ptr(PointerType::new(inner)),
            1 => TypeExpr::Array(ArrayType::new(inner, 1)),
            _ => TypeExpr::Struct(StructType::new(
                None,
                [StructField {
                    span: SourceSpan::UNKNOWN,
                    name: Ident::from_str("field").expect("valid ident"),
                    ty: inner,
                }],
            )),
        })
    }

    #[test]
    fn type_expr_depth_boundary() {
        let mut resolver = DummyResolver::new();

        // Exactly at the limit: must resolve.
        let at_limit = nested_type_expr(MAX_TYPE_EXPR_NESTING);
        assert!(at_limit.resolve_type(&mut resolver).is_ok());

        // One level past the limit: must be rejected with the configured maximum.
        let over_limit = nested_type_expr(MAX_TYPE_EXPR_NESTING + 1);
        let err = over_limit
            .resolve_type(&mut resolver)
            .expect_err("expected depth-exceeded error");
        assert!(matches!(
            err,
            SymbolResolutionError::TypeExpressionDepthExceeded { max_depth, .. }
                if max_depth == MAX_TYPE_EXPR_NESTING
        ));
    }
}