Skip to main content

miden_assembly_syntax/ast/
type.rs

1use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
2
3use miden_debug_types::{SourceManager, SourceSpan, Span, Spanned};
4pub use midenc_hir_type as types;
5use midenc_hir_type::{AddressSpace, Type, TypeRepr};
6
7use super::{
8    ConstantExpr, DocString, GlobalItemIndex, Ident, ItemIndex, Path, SymbolResolution,
9    SymbolResolutionError, Visibility,
10};
11
12/// Abstracts over resolving an item to a concrete [Type], using one of:
13///
14/// * A [GlobalItemIndex]
15/// * An [ItemIndex]
16/// * A [Path]
17/// * A [TypeExpr]
18///
19/// Since type resolution happens in two different contexts during assembly, this abstraction allows
20/// us to share more of the resolution logic in both places.
21pub trait TypeResolver<E> {
22    fn source_manager(&self) -> Arc<dyn SourceManager>;
23    /// Should be called by consumers of this resolver to convert a [SymbolResolutionError] to the
24    /// error type used by the [TypeResolver] implementation.
25    fn resolve_local_failed(&self, err: SymbolResolutionError) -> E;
26    /// Get the [Type] corresponding to the item given by `gid`
27    fn get_type(&self, context: SourceSpan, gid: GlobalItemIndex) -> Result<Type, E>;
28    /// Get the [Type] corresponding to the item in the current module given by `id`
29    fn get_local_type(&self, context: SourceSpan, id: ItemIndex) -> Result<Option<Type>, E>;
30    /// Attempt to resolve a symbol path, given by a `TypeExpr::Ref`, to an item
31    fn resolve_type_ref(&self, ty: Span<&Path>) -> Result<SymbolResolution, E>;
32    /// Resolve a [TypeExpr] to a concrete [Type]
33    fn resolve(&self, ty: &TypeExpr) -> Result<Option<Type>, E> {
34        ty.resolve_type(self)
35    }
36}
37
38// TYPE DECLARATION
39// ================================================================================================
40
41/// An abstraction over the different types of type declarations allowed in Miden Assembly
42#[derive(Debug, Clone, PartialEq, Eq)]
43pub enum TypeDecl {
44    /// A named type, i.e. a type alias
45    Alias(TypeAlias),
46    /// A C-like enumeration type with associated constants
47    Enum(EnumType),
48}
49
50impl TypeDecl {
51    /// Adds documentation to this type alias
52    pub fn with_docs(self, docs: Option<Span<String>>) -> Self {
53        match self {
54            Self::Alias(ty) => Self::Alias(ty.with_docs(docs)),
55            Self::Enum(ty) => Self::Enum(ty.with_docs(docs)),
56        }
57    }
58
59    /// Get the name assigned to this type declaration
60    pub fn name(&self) -> &Ident {
61        match self {
62            Self::Alias(ty) => &ty.name,
63            Self::Enum(ty) => &ty.name,
64        }
65    }
66
67    /// Get the visibility of this type declaration
68    pub const fn visibility(&self) -> Visibility {
69        match self {
70            Self::Alias(ty) => ty.visibility,
71            Self::Enum(ty) => ty.visibility,
72        }
73    }
74
75    /// Get the documentation of this enum type
76    pub fn docs(&self) -> Option<Span<&str>> {
77        match self {
78            Self::Alias(ty) => ty.docs(),
79            Self::Enum(ty) => ty.docs(),
80        }
81    }
82
83    /// Get the type expression associated with this declaration
84    pub fn ty(&self) -> TypeExpr {
85        match self {
86            Self::Alias(ty) => ty.ty.clone(),
87            Self::Enum(ty) => TypeExpr::Primitive(Span::new(ty.span, ty.ty.clone())),
88        }
89    }
90}
91
92impl Spanned for TypeDecl {
93    fn span(&self) -> SourceSpan {
94        match self {
95            Self::Alias(spanned) => spanned.span,
96            Self::Enum(spanned) => spanned.span,
97        }
98    }
99}
100
101impl From<TypeAlias> for TypeDecl {
102    fn from(value: TypeAlias) -> Self {
103        Self::Alias(value)
104    }
105}
106
107impl From<EnumType> for TypeDecl {
108    fn from(value: EnumType) -> Self {
109        Self::Enum(value)
110    }
111}
112
113impl crate::prettier::PrettyPrint for TypeDecl {
114    fn render(&self) -> crate::prettier::Document {
115        match self {
116            Self::Alias(ty) => ty.render(),
117            Self::Enum(ty) => ty.render(),
118        }
119    }
120}
121
122// FUNCTION TYPE
123// ================================================================================================
124
125/// A procedure type signature
126#[derive(Debug, Clone)]
127pub struct FunctionType {
128    pub span: SourceSpan,
129    pub cc: types::CallConv,
130    pub args: Vec<TypeExpr>,
131    pub results: Vec<TypeExpr>,
132}
133
134impl Eq for FunctionType {}
135
136impl PartialEq for FunctionType {
137    fn eq(&self, other: &Self) -> bool {
138        self.cc == other.cc && self.args == other.args && self.results == other.results
139    }
140}
141
142impl core::hash::Hash for FunctionType {
143    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
144        self.cc.hash(state);
145        self.args.hash(state);
146        self.results.hash(state);
147    }
148}
149
150impl Spanned for FunctionType {
151    fn span(&self) -> SourceSpan {
152        self.span
153    }
154}
155
156impl FunctionType {
157    pub fn new(cc: types::CallConv, args: Vec<TypeExpr>, results: Vec<TypeExpr>) -> Self {
158        Self {
159            span: SourceSpan::UNKNOWN,
160            cc,
161            args,
162            results,
163        }
164    }
165
166    /// Override the default source span
167    #[inline]
168    pub fn with_span(mut self, span: SourceSpan) -> Self {
169        self.span = span;
170        self
171    }
172}
173
174impl crate::prettier::PrettyPrint for FunctionType {
175    fn render(&self) -> crate::prettier::Document {
176        use crate::prettier::*;
177
178        let singleline_args = self
179            .args
180            .iter()
181            .map(|arg| arg.render())
182            .reduce(|acc, arg| acc + const_text(", ") + arg)
183            .unwrap_or(Document::Empty);
184        let multiline_args = indent(
185            4,
186            nl() + self
187                .args
188                .iter()
189                .map(|arg| arg.render())
190                .reduce(|acc, arg| acc + const_text(",") + nl() + arg)
191                .unwrap_or(Document::Empty),
192        ) + nl();
193        let args = singleline_args | multiline_args;
194        let args = const_text("(") + args + const_text(")");
195
196        match self.results.len() {
197            0 => args,
198            1 => args + const_text(" -> ") + self.results[0].render(),
199            _ => {
200                let results = self
201                    .results
202                    .iter()
203                    .map(|r| r.render())
204                    .reduce(|acc, r| acc + const_text(", ") + r)
205                    .unwrap_or(Document::Empty);
206                args + const_text(" -> ") + const_text("(") + results + const_text(")")
207            },
208        }
209    }
210}
211
212// TYPE EXPRESSION
213// ================================================================================================
214
215/// A syntax-level type expression (i.e. primitive type, reference to nominal type, etc.)
216#[derive(Debug, Clone, Eq, PartialEq, Hash)]
217pub enum TypeExpr {
218    /// A primitive integral type, e.g. `i1`, `u16`
219    Primitive(Span<Type>),
220    /// A pointer type expression, e.g. `*u8`
221    Ptr(PointerType),
222    /// An array type expression, e.g. `[u8; 32]`
223    Array(ArrayType),
224    /// A struct type expression, e.g. `struct { a: u32 }`
225    Struct(StructType),
226    /// A reference to a type aliased by name, e.g. `Foo`
227    Ref(Span<Arc<Path>>),
228}
229
230impl TypeExpr {
231    /// Get any references to other types present in this expression
232    pub fn references(&self) -> Vec<Span<Arc<Path>>> {
233        use alloc::collections::BTreeSet;
234
235        let mut worklist = smallvec::SmallVec::<[_; 4]>::from_slice(&[self]);
236        let mut references = BTreeSet::new();
237
238        while let Some(ty) = worklist.pop() {
239            match ty {
240                Self::Primitive(_) => continue,
241                Self::Ptr(ty) => {
242                    worklist.push(&ty.pointee);
243                },
244                Self::Array(ty) => {
245                    worklist.push(&ty.elem);
246                },
247                Self::Struct(ty) => {
248                    for field in ty.fields.iter() {
249                        worklist.push(&field.ty);
250                    }
251                },
252                Self::Ref(ty) => {
253                    references.insert(ty.clone());
254                },
255            }
256        }
257
258        references.into_iter().collect()
259    }
260
261    /// Resolve this type expression to a concrete type, using `resolver`
262    pub fn resolve_type<E, R>(&self, resolver: &R) -> Result<Option<Type>, E>
263    where
264        R: ?Sized + TypeResolver<E>,
265    {
266        match self {
267            TypeExpr::Ref(path) => {
268                let mut current_path = path.clone();
269                loop {
270                    match resolver.resolve_type_ref(current_path.as_deref())? {
271                        SymbolResolution::Local(item) => {
272                            return resolver.get_local_type(current_path.span(), item.into_inner());
273                        },
274                        SymbolResolution::External(path) => {
275                            // We don't have a definition for this type yet
276                            if path == current_path {
277                                break Ok(None);
278                            }
279                            current_path = path;
280                        },
281                        SymbolResolution::Exact { gid, .. } => {
282                            return resolver.get_type(current_path.span(), gid).map(Some);
283                        },
284                        SymbolResolution::Module { path: module_path, .. } => {
285                            break Err(resolver.resolve_local_failed(
286                                SymbolResolutionError::invalid_symbol_type(
287                                    path.span(),
288                                    "type",
289                                    module_path.span(),
290                                    &resolver.source_manager(),
291                                ),
292                            ));
293                        },
294                        SymbolResolution::MastRoot(item) => {
295                            break Err(resolver.resolve_local_failed(
296                                SymbolResolutionError::invalid_symbol_type(
297                                    path.span(),
298                                    "type",
299                                    item.span(),
300                                    &resolver.source_manager(),
301                                ),
302                            ));
303                        },
304                    }
305                }
306            },
307            TypeExpr::Primitive(t) => Ok(Some(t.inner().clone())),
308            TypeExpr::Array(t) => Ok(t
309                .elem
310                .resolve_type(resolver)?
311                .map(|elem| types::Type::Array(Arc::new(types::ArrayType::new(elem, t.arity))))),
312            TypeExpr::Ptr(ty) => Ok(ty
313                .pointee
314                .resolve_type(resolver)?
315                .map(|pointee| types::Type::Ptr(Arc::new(types::PointerType::new(pointee))))),
316            TypeExpr::Struct(t) => {
317                let mut fields = Vec::with_capacity(t.fields.len());
318                for field in t.fields.iter() {
319                    let field_ty = field.ty.resolve_type(resolver)?;
320                    if let Some(field_ty) = field_ty {
321                        fields.push(field_ty);
322                    } else {
323                        return Ok(None);
324                    }
325                }
326                Ok(Some(Type::Struct(Arc::new(types::StructType::new(fields)))))
327            },
328        }
329    }
330}
331
332impl From<Type> for TypeExpr {
333    fn from(ty: Type) -> Self {
334        match ty {
335            Type::Array(t) => Self::Array(ArrayType::new(t.element_type().clone().into(), t.len())),
336            Type::Struct(t) => {
337                Self::Struct(StructType::new(t.fields().iter().enumerate().map(|(i, ft)| {
338                    let name = Ident::new(format!("field{i}")).unwrap();
339                    StructField {
340                        span: SourceSpan::UNKNOWN,
341                        name,
342                        ty: ft.ty.clone().into(),
343                    }
344                })))
345            },
346            Type::Ptr(t) => Self::Ptr((*t).clone().into()),
347            Type::Function(_) => {
348                Self::Ptr(PointerType::new(TypeExpr::Primitive(Span::unknown(Type::Felt))))
349            },
350            Type::List(t) => Self::Ptr(
351                PointerType::new((*t).clone().into()).with_address_space(AddressSpace::Byte),
352            ),
353            Type::I128 | Type::U128 => Self::Array(ArrayType::new(Type::U32.into(), 4)),
354            Type::I64 | Type::U64 => Self::Array(ArrayType::new(Type::U32.into(), 2)),
355            Type::Unknown | Type::Never | Type::F64 => panic!("unrepresentable type value: {ty}"),
356            ty => Self::Primitive(Span::unknown(ty)),
357        }
358    }
359}
360
361impl Spanned for TypeExpr {
362    fn span(&self) -> SourceSpan {
363        match self {
364            Self::Primitive(spanned) => spanned.span(),
365            Self::Ptr(spanned) => spanned.span(),
366            Self::Array(spanned) => spanned.span(),
367            Self::Struct(spanned) => spanned.span(),
368            Self::Ref(spanned) => spanned.span(),
369        }
370    }
371}
372
373impl crate::prettier::PrettyPrint for TypeExpr {
374    fn render(&self) -> crate::prettier::Document {
375        use crate::prettier::*;
376
377        match self {
378            Self::Primitive(ty) => display(ty),
379            Self::Ptr(ty) => ty.render(),
380            Self::Array(ty) => ty.render(),
381            Self::Struct(ty) => ty.render(),
382            Self::Ref(ty) => display(ty),
383        }
384    }
385}
386
387// POINTER TYPE
388// ================================================================================================
389
390#[derive(Debug, Clone)]
391pub struct PointerType {
392    pub span: SourceSpan,
393    pub pointee: Box<TypeExpr>,
394    addrspace: Option<AddressSpace>,
395}
396
397impl From<types::PointerType> for PointerType {
398    fn from(ty: types::PointerType) -> Self {
399        let types::PointerType { addrspace, pointee } = ty;
400        let pointee = Box::new(TypeExpr::from(pointee));
401        Self {
402            span: SourceSpan::UNKNOWN,
403            pointee,
404            addrspace: Some(addrspace),
405        }
406    }
407}
408
409impl Eq for PointerType {}
410
411impl PartialEq for PointerType {
412    fn eq(&self, other: &Self) -> bool {
413        self.address_space() == other.address_space() && self.pointee == other.pointee
414    }
415}
416
417impl core::hash::Hash for PointerType {
418    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
419        self.pointee.hash(state);
420        self.address_space().hash(state);
421    }
422}
423
424impl Spanned for PointerType {
425    fn span(&self) -> SourceSpan {
426        self.span
427    }
428}
429
430impl PointerType {
431    pub fn new(pointee: TypeExpr) -> Self {
432        Self {
433            span: SourceSpan::UNKNOWN,
434            pointee: Box::new(pointee),
435            addrspace: None,
436        }
437    }
438
439    /// Override the default source span
440    #[inline]
441    pub fn with_span(mut self, span: SourceSpan) -> Self {
442        self.span = span;
443        self
444    }
445
446    /// Override the default address space
447    #[inline]
448    pub fn with_address_space(mut self, addrspace: AddressSpace) -> Self {
449        self.addrspace = Some(addrspace);
450        self
451    }
452
453    /// Get the address space of this pointer type
454    #[inline]
455    pub fn address_space(&self) -> AddressSpace {
456        self.addrspace.unwrap_or(AddressSpace::Element)
457    }
458}
459
460impl crate::prettier::PrettyPrint for PointerType {
461    fn render(&self) -> crate::prettier::Document {
462        use crate::prettier::*;
463
464        let doc = const_text("ptr<") + self.pointee.render();
465        if let Some(addrspace) = self.addrspace.as_ref() {
466            doc + const_text(", ") + text(format!("addrspace({})", addrspace)) + const_text(">")
467        } else {
468            doc + const_text(">")
469        }
470    }
471}
472
473// ARRAY TYPE
474// ================================================================================================
475
476#[derive(Debug, Clone)]
477pub struct ArrayType {
478    pub span: SourceSpan,
479    pub elem: Box<TypeExpr>,
480    pub arity: usize,
481}
482
483impl Eq for ArrayType {}
484
485impl PartialEq for ArrayType {
486    fn eq(&self, other: &Self) -> bool {
487        self.arity == other.arity && self.elem == other.elem
488    }
489}
490
491impl core::hash::Hash for ArrayType {
492    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
493        self.elem.hash(state);
494        self.arity.hash(state);
495    }
496}
497
498impl Spanned for ArrayType {
499    fn span(&self) -> SourceSpan {
500        self.span
501    }
502}
503
504impl ArrayType {
505    pub fn new(elem: TypeExpr, arity: usize) -> Self {
506        Self {
507            span: SourceSpan::UNKNOWN,
508            elem: Box::new(elem),
509            arity,
510        }
511    }
512
513    /// Override the default source span
514    #[inline]
515    pub fn with_span(mut self, span: SourceSpan) -> Self {
516        self.span = span;
517        self
518    }
519}
520
521impl crate::prettier::PrettyPrint for ArrayType {
522    fn render(&self) -> crate::prettier::Document {
523        use crate::prettier::*;
524
525        const_text("[")
526            + self.elem.render()
527            + const_text("; ")
528            + display(self.arity)
529            + const_text("]")
530    }
531}
532
533// STRUCT TYPE
534// ================================================================================================
535
536#[derive(Debug, Clone)]
537pub struct StructType {
538    pub span: SourceSpan,
539    pub repr: Span<TypeRepr>,
540    pub fields: Vec<StructField>,
541}
542
543impl Eq for StructType {}
544
545impl PartialEq for StructType {
546    fn eq(&self, other: &Self) -> bool {
547        self.repr == other.repr && self.fields == other.fields
548    }
549}
550
551impl core::hash::Hash for StructType {
552    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
553        self.repr.hash(state);
554        self.fields.hash(state);
555    }
556}
557
558impl Spanned for StructType {
559    fn span(&self) -> SourceSpan {
560        self.span
561    }
562}
563
564impl StructType {
565    pub fn new(fields: impl IntoIterator<Item = StructField>) -> Self {
566        Self {
567            span: SourceSpan::UNKNOWN,
568            repr: Span::unknown(TypeRepr::Default),
569            fields: fields.into_iter().collect(),
570        }
571    }
572
573    /// Override the default struct representation
574    #[inline]
575    pub fn with_repr(mut self, repr: Span<TypeRepr>) -> Self {
576        self.repr = repr;
577        self
578    }
579
580    /// Override the default source span
581    #[inline]
582    pub fn with_span(mut self, span: SourceSpan) -> Self {
583        self.span = span;
584        self
585    }
586}
587
588impl crate::prettier::PrettyPrint for StructType {
589    fn render(&self) -> crate::prettier::Document {
590        use crate::prettier::*;
591
592        let repr = match &*self.repr {
593            TypeRepr::Default => Document::Empty,
594            TypeRepr::BigEndian => const_text("@bigendian "),
595            repr @ (TypeRepr::Align(_) | TypeRepr::Packed(_) | TypeRepr::Transparent) => {
596                text(format!("@{repr} "))
597            },
598        };
599
600        let singleline_body = self
601            .fields
602            .iter()
603            .map(|field| field.render())
604            .reduce(|acc, field| acc + const_text(", ") + field)
605            .unwrap_or(Document::Empty);
606        let multiline_body = indent(
607            4,
608            nl() + self
609                .fields
610                .iter()
611                .map(|field| field.render())
612                .reduce(|acc, field| acc + const_text(",") + nl() + field)
613                .unwrap_or(Document::Empty),
614        ) + nl();
615        let body = singleline_body | multiline_body;
616
617        repr + const_text("struct") + const_text(" { ") + body + const_text(" }")
618    }
619}
620
621// STRUCT FIELD
622// ================================================================================================
623
624#[derive(Debug, Clone)]
625pub struct StructField {
626    pub span: SourceSpan,
627    pub name: Ident,
628    pub ty: TypeExpr,
629}
630
631impl Eq for StructField {}
632
633impl PartialEq for StructField {
634    fn eq(&self, other: &Self) -> bool {
635        self.name == other.name && self.ty == other.ty
636    }
637}
638
639impl core::hash::Hash for StructField {
640    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
641        self.name.hash(state);
642        self.ty.hash(state);
643    }
644}
645
646impl Spanned for StructField {
647    fn span(&self) -> SourceSpan {
648        self.span
649    }
650}
651
652impl crate::prettier::PrettyPrint for StructField {
653    fn render(&self) -> crate::prettier::Document {
654        use crate::prettier::*;
655
656        display(&self.name) + const_text(": ") + self.ty.render()
657    }
658}
659
660// TYPE ALIAS
661// ================================================================================================
662
663/// A [TypeAlias] represents a named [Type].
664///
665/// Type aliases correspond to type declarations in Miden Assembly source files. They are called
666/// aliases, rather than declarations, as the type system for Miden Assembly is structural, rather
667/// than nominal, and so two aliases with the same underlying type are considered equivalent.
668#[derive(Debug, Clone)]
669pub struct TypeAlias {
670    span: SourceSpan,
671    /// The documentation string attached to this definition.
672    docs: Option<DocString>,
673    /// The visibility of this type alias
674    pub visibility: Visibility,
675    /// The name of this type alias
676    pub name: Ident,
677    /// The concrete underlying type
678    pub ty: TypeExpr,
679}
680
681impl TypeAlias {
682    /// Create a new type alias from a name and type
683    pub fn new(visibility: Visibility, name: Ident, ty: TypeExpr) -> Self {
684        Self {
685            span: name.span(),
686            docs: None,
687            visibility,
688            name,
689            ty,
690        }
691    }
692
693    /// Adds documentation to this type alias
694    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
695        self.docs = docs.map(DocString::new);
696        self
697    }
698
699    /// Override the default source span
700    #[inline]
701    pub fn with_span(mut self, span: SourceSpan) -> Self {
702        self.span = span;
703        self
704    }
705
706    /// Set the source span
707    #[inline]
708    pub fn set_span(&mut self, span: SourceSpan) {
709        self.span = span;
710    }
711
712    /// Returns the documentation associated with this item.
713    pub fn docs(&self) -> Option<Span<&str>> {
714        self.docs.as_ref().map(|docstring| docstring.as_spanned_str())
715    }
716
717    /// Get the name of this type alias
718    pub fn name(&self) -> &Ident {
719        &self.name
720    }
721
722    /// Get the visibility of this type alias
723    #[inline]
724    pub const fn visibility(&self) -> Visibility {
725        self.visibility
726    }
727}
728
729impl Eq for TypeAlias {}
730
731impl PartialEq for TypeAlias {
732    fn eq(&self, other: &Self) -> bool {
733        self.visibility == other.visibility
734            && self.name == other.name
735            && self.docs == other.docs
736            && self.ty == other.ty
737    }
738}
739
740impl core::hash::Hash for TypeAlias {
741    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
742        let Self { span: _, docs, visibility, name, ty } = self;
743        docs.hash(state);
744        visibility.hash(state);
745        name.hash(state);
746        ty.hash(state);
747    }
748}
749
750impl Spanned for TypeAlias {
751    fn span(&self) -> SourceSpan {
752        self.span
753    }
754}
755
756impl crate::prettier::PrettyPrint for TypeAlias {
757    fn render(&self) -> crate::prettier::Document {
758        use crate::prettier::*;
759
760        let mut doc = self
761            .docs
762            .as_ref()
763            .map(|docstring| docstring.render())
764            .unwrap_or(Document::Empty);
765
766        if self.visibility.is_public() {
767            doc += display(self.visibility) + const_text(" ");
768        }
769
770        doc + const_text("type")
771            + const_text(" ")
772            + display(&self.name)
773            + const_text(" = ")
774            + self.ty.render()
775    }
776}
777
778// ENUM TYPE
779// ================================================================================================
780
781/// A combined type alias and constant declaration corresponding to a C-like enumeration.
782///
783/// C-style enumerations are effectively a type alias for an integer type with a limited set of
784/// valid values with associated names (referred to as _variants_ of the enum type).
785///
786/// In Miden Assembly, these provide a means for a procedure to declare that it expects an argument
787/// of the underlying integral type, but that values other than those of the declared variants are
788/// illegal/invalid. Currently, these are unchecked, and are only used to convey semantic
789/// information. In the future, we may perform static analysis to try and identify invalid instances
790/// of the enumeration when derived from a constant.
791#[derive(Debug, Clone)]
792pub struct EnumType {
793    span: SourceSpan,
794    /// The documentation string attached to this definition.
795    docs: Option<DocString>,
796    /// The visibility of this enum type
797    visibility: Visibility,
798    /// The enum name
799    name: Ident,
800    /// The type of the discriminant value used for this enum's variants
801    ///
802    /// NOTE: The type must be an integral value, and this is enforced by [`Self::new`].
803    ty: Type,
804    /// The enum variants
805    variants: Vec<Variant>,
806}
807
808impl EnumType {
809    /// Construct a new enum type with the given name and variants
810    ///
811    /// The caller is assumed to have already validated that `ty` is an integral type, and this
812    /// function will assert that this is the case.
813    pub fn new(
814        visibility: Visibility,
815        name: Ident,
816        ty: Type,
817        variants: impl IntoIterator<Item = Variant>,
818    ) -> Self {
819        assert!(ty.is_integer(), "only integer types are allowed in enum type definitions");
820        Self {
821            span: name.span(),
822            docs: None,
823            visibility,
824            name,
825            ty,
826            variants: Vec::from_iter(variants),
827        }
828    }
829
830    /// Adds documentation to this enum declaration.
831    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
832        self.docs = docs.map(DocString::new);
833        self
834    }
835
836    /// Override the default source span
837    pub fn with_span(mut self, span: SourceSpan) -> Self {
838        self.span = span;
839        self
840    }
841
842    /// Set the source span
843    pub fn set_span(&mut self, span: SourceSpan) {
844        self.span = span;
845    }
846
847    /// Get the name of this enum type
848    pub fn name(&self) -> &Ident {
849        &self.name
850    }
851
852    /// Get the visibility of this enum type
853    pub const fn visibility(&self) -> Visibility {
854        self.visibility
855    }
856
857    /// Returns the documentation associated with this item.
858    pub fn docs(&self) -> Option<Span<&str>> {
859        self.docs.as_ref().map(|docstring| docstring.as_spanned_str())
860    }
861
862    /// Get the concrete type of this enum's variants
863    pub fn ty(&self) -> &Type {
864        &self.ty
865    }
866
867    /// Get the variants of this enum type
868    pub fn variants(&self) -> &[Variant] {
869        &self.variants
870    }
871
872    /// Get the variants of this enum type, mutably
873    pub fn variants_mut(&mut self) -> &mut Vec<Variant> {
874        &mut self.variants
875    }
876
877    /// Split this definition into its type alias and variant parts
878    pub fn into_parts(self) -> (TypeAlias, Vec<Variant>) {
879        let Self {
880            span,
881            docs,
882            visibility,
883            name,
884            ty,
885            variants,
886        } = self;
887        let alias = TypeAlias {
888            span,
889            docs,
890            visibility,
891            name,
892            ty: TypeExpr::Primitive(Span::new(span, ty)),
893        };
894        (alias, variants)
895    }
896}
897
898impl Spanned for EnumType {
899    fn span(&self) -> SourceSpan {
900        self.span
901    }
902}
903
904impl Eq for EnumType {}
905
906impl PartialEq for EnumType {
907    fn eq(&self, other: &Self) -> bool {
908        self.visibility == other.visibility
909            && self.name == other.name
910            && self.docs == other.docs
911            && self.ty == other.ty
912            && self.variants == other.variants
913    }
914}
915
916impl core::hash::Hash for EnumType {
917    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
918        let Self {
919            span: _,
920            docs,
921            visibility,
922            name,
923            ty,
924            variants,
925        } = self;
926        docs.hash(state);
927        visibility.hash(state);
928        name.hash(state);
929        ty.hash(state);
930        variants.hash(state);
931    }
932}
933
934impl crate::prettier::PrettyPrint for EnumType {
935    fn render(&self) -> crate::prettier::Document {
936        use crate::prettier::*;
937
938        let mut doc = self
939            .docs
940            .as_ref()
941            .map(|docstring| docstring.render())
942            .unwrap_or(Document::Empty);
943
944        let variants = self
945            .variants
946            .iter()
947            .map(|v| v.render())
948            .reduce(|acc, v| acc + const_text(",") + nl() + v)
949            .unwrap_or(Document::Empty);
950
951        if self.visibility.is_public() {
952            doc += display(self.visibility) + const_text(" ");
953        }
954
955        doc + const_text("enum")
956            + const_text(" ")
957            + display(&self.name)
958            + const_text(" : ")
959            + self.ty.render()
960            + const_text(" {")
961            + nl()
962            + variants
963            + const_text("}")
964    }
965}
966
967// ENUM VARIANT
968// ================================================================================================
969
/// A variant of an [EnumType].
///
/// See the [EnumType] docs for more information.
#[derive(Debug, Clone)]
pub struct Variant {
    /// The source location of this variant. [Variant::new] initializes it from the span of
    /// `name`; [Variant::with_span] overrides it. It is ignored by equality and hashing.
    pub span: SourceSpan,
    /// The documentation string attached to the constant derived from this variant.
    pub docs: Option<DocString>,
    /// The name of this enum variant
    pub name: Ident,
    /// The discriminant value associated with this variant
    pub discriminant: ConstantExpr,
}
983
984impl Variant {
985    /// Construct a new variant of an [EnumType], with the given name and discriminant value.
986    pub fn new(name: Ident, discriminant: ConstantExpr) -> Self {
987        Self {
988            span: name.span(),
989            docs: None,
990            name,
991            discriminant,
992        }
993    }
994
995    /// Override the span for this variant
996    pub fn with_span(mut self, span: SourceSpan) -> Self {
997        self.span = span;
998        self
999    }
1000
1001    /// Adds documentation to this variant
1002    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
1003        self.docs = docs.map(DocString::new);
1004        self
1005    }
1006
1007    /// Used to validate that this variant's discriminant value is an instance of `ty`,
1008    /// which must be a type valid for use as the underlying representation for an enum, i.e. an
1009    /// integer type up to 64 bits in size.
1010    ///
1011    /// It is expected that the discriminant expression has been folded to an integer value by the
1012    /// time this is called. If the discriminant has not been fully folded, then an error will be
1013    /// returned.
1014    pub fn assert_instance_of(&self, ty: &Type) -> Result<(), crate::SemanticAnalysisError> {
1015        use crate::{FIELD_MODULUS, SemanticAnalysisError};
1016
1017        let value = match &self.discriminant {
1018            ConstantExpr::Int(value) => value.as_int(),
1019            _ => {
1020                return Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1021                    span: self.discriminant.span(),
1022                    repr: ty.clone(),
1023                });
1024            },
1025        };
1026
1027        match ty {
1028            Type::I1 if value > 1 => Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1029                span: self.discriminant.span(),
1030                repr: ty.clone(),
1031            }),
1032            Type::I1 => Ok(()),
1033            Type::I8 | Type::U8 if value > u8::MAX as u64 => {
1034                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1035                    span: self.discriminant.span(),
1036                    repr: ty.clone(),
1037                })
1038            },
1039            Type::I8 | Type::U8 => Ok(()),
1040            Type::I16 | Type::U16 if value > u16::MAX as u64 => {
1041                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1042                    span: self.discriminant.span(),
1043                    repr: ty.clone(),
1044                })
1045            },
1046            Type::I16 | Type::U16 => Ok(()),
1047            Type::I32 | Type::U32 if value > u32::MAX as u64 => {
1048                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1049                    span: self.discriminant.span(),
1050                    repr: ty.clone(),
1051                })
1052            },
1053            Type::I32 | Type::U32 => Ok(()),
1054            Type::I64 | Type::U64 if value >= FIELD_MODULUS => {
1055                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1056                    span: self.discriminant.span(),
1057                    repr: ty.clone(),
1058                })
1059            },
1060            _ => Err(SemanticAnalysisError::InvalidEnumRepr { span: self.span }),
1061        }
1062    }
1063}
1064
1065impl Spanned for Variant {
1066    fn span(&self) -> SourceSpan {
1067        self.span
1068    }
1069}
1070
1071impl Eq for Variant {}
1072
1073impl PartialEq for Variant {
1074    fn eq(&self, other: &Self) -> bool {
1075        self.name == other.name
1076            && self.discriminant == other.discriminant
1077            && self.docs == other.docs
1078    }
1079}
1080
1081impl core::hash::Hash for Variant {
1082    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1083        let Self { span: _, docs, name, discriminant } = self;
1084        docs.hash(state);
1085        name.hash(state);
1086        discriminant.hash(state);
1087    }
1088}
1089
1090impl crate::prettier::PrettyPrint for Variant {
1091    fn render(&self) -> crate::prettier::Document {
1092        use crate::prettier::*;
1093
1094        let doc = self
1095            .docs
1096            .as_ref()
1097            .map(|docstring| docstring.render())
1098            .unwrap_or(Document::Empty);
1099
1100        doc + display(&self.name) + const_text(" = ") + self.discriminant.render()
1101    }
1102}