miden_assembly_syntax/ast/type.rs

1use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
2
3use miden_debug_types::{SourceManager, SourceSpan, Span, Spanned};
4pub use midenc_hir_type as types;
5use midenc_hir_type::{AddressSpace, Type, TypeRepr};
6
7use super::{
8    ConstantExpr, DocString, GlobalItemIndex, Ident, ItemIndex, Path, SymbolResolution,
9    SymbolResolutionError, Visibility,
10};
11
12/// Abstracts over resolving an item to a concrete [Type], using one of:
13///
14/// * A [GlobalItemIndex]
15/// * An [ItemIndex]
16/// * A [Path]
17/// * A [TypeExpr]
18///
19/// Since type resolution happens in two different contexts during assembly, this abstraction allows
20/// us to share more of the resolution logic in both places.
21pub trait TypeResolver<E> {
22    fn source_manager(&self) -> Arc<dyn SourceManager>;
23    /// Should be called by consumers of this resolver to convert a [SymbolResolutionError] to the
24    /// error type used by the [TypeResolver] implementation.
25    fn resolve_local_failed(&self, err: SymbolResolutionError) -> E;
26    /// Get the [Type] corresponding to the item given by `gid`
27    fn get_type(&self, context: SourceSpan, gid: GlobalItemIndex) -> Result<Type, E>;
28    /// Get the [Type] corresponding to the item in the current module given by `id`
29    fn get_local_type(&self, context: SourceSpan, id: ItemIndex) -> Result<Option<Type>, E>;
30    /// Attempt to resolve a symbol path, given by a `TypeExpr::Ref`, to an item
31    fn resolve_type_ref(&self, ty: Span<&Path>) -> Result<SymbolResolution, E>;
32    /// Resolve a [TypeExpr] to a concrete [Type]
33    fn resolve(&self, ty: &TypeExpr) -> Result<Option<Type>, E> {
34        ty.resolve_type(self)
35    }
36}
37
38// TYPE DECLARATION
39// ================================================================================================
40
41/// An abstraction over the different types of type declarations allowed in Miden Assembly
42#[derive(Debug, Clone, PartialEq, Eq)]
43pub enum TypeDecl {
44    /// A named type, i.e. a type alias
45    Alias(TypeAlias),
46    /// A C-like enumeration type with associated constants
47    Enum(EnumType),
48}
49
50impl TypeDecl {
51    /// Adds documentation to this type alias
52    pub fn with_docs(self, docs: Option<Span<String>>) -> Self {
53        match self {
54            Self::Alias(ty) => Self::Alias(ty.with_docs(docs)),
55            Self::Enum(ty) => Self::Enum(ty.with_docs(docs)),
56        }
57    }
58
59    /// Get the name assigned to this type declaration
60    pub fn name(&self) -> &Ident {
61        match self {
62            Self::Alias(ty) => &ty.name,
63            Self::Enum(ty) => &ty.name,
64        }
65    }
66
67    /// Get the visibility of this type declaration
68    pub const fn visibility(&self) -> Visibility {
69        match self {
70            Self::Alias(ty) => ty.visibility,
71            Self::Enum(ty) => ty.visibility,
72        }
73    }
74
75    /// Get the documentation of this enum type
76    pub fn docs(&self) -> Option<Span<&str>> {
77        match self {
78            Self::Alias(ty) => ty.docs(),
79            Self::Enum(ty) => ty.docs(),
80        }
81    }
82
83    /// Get the type expression associated with this declaration
84    pub fn ty(&self) -> TypeExpr {
85        match self {
86            Self::Alias(ty) => ty.ty.clone(),
87            Self::Enum(ty) => TypeExpr::Primitive(Span::new(ty.span, ty.ty.clone())),
88        }
89    }
90}
91
92impl Spanned for TypeDecl {
93    fn span(&self) -> SourceSpan {
94        match self {
95            Self::Alias(spanned) => spanned.span,
96            Self::Enum(spanned) => spanned.span,
97        }
98    }
99}
100
101impl From<TypeAlias> for TypeDecl {
102    fn from(value: TypeAlias) -> Self {
103        Self::Alias(value)
104    }
105}
106
107impl From<EnumType> for TypeDecl {
108    fn from(value: EnumType) -> Self {
109        Self::Enum(value)
110    }
111}
112
113impl crate::prettier::PrettyPrint for TypeDecl {
114    fn render(&self) -> crate::prettier::Document {
115        match self {
116            Self::Alias(ty) => ty.render(),
117            Self::Enum(ty) => ty.render(),
118        }
119    }
120}
121
122// FUNCTION TYPE
123// ================================================================================================
124
125/// A procedure type signature
126#[derive(Debug, Clone)]
127pub struct FunctionType {
128    pub span: SourceSpan,
129    pub cc: types::CallConv,
130    pub args: Vec<TypeExpr>,
131    pub results: Vec<TypeExpr>,
132}
133
134impl Eq for FunctionType {}
135
136impl PartialEq for FunctionType {
137    fn eq(&self, other: &Self) -> bool {
138        self.cc == other.cc && self.args == other.args && self.results == other.results
139    }
140}
141
142impl core::hash::Hash for FunctionType {
143    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
144        self.cc.hash(state);
145        self.args.hash(state);
146        self.results.hash(state);
147    }
148}
149
150impl Spanned for FunctionType {
151    fn span(&self) -> SourceSpan {
152        self.span
153    }
154}
155
156impl FunctionType {
157    pub fn new(cc: types::CallConv, args: Vec<TypeExpr>, results: Vec<TypeExpr>) -> Self {
158        Self {
159            span: SourceSpan::UNKNOWN,
160            cc,
161            args,
162            results,
163        }
164    }
165
166    /// Override the default source span
167    #[inline]
168    pub fn with_span(mut self, span: SourceSpan) -> Self {
169        self.span = span;
170        self
171    }
172}
173
174impl crate::prettier::PrettyPrint for FunctionType {
175    fn render(&self) -> crate::prettier::Document {
176        use crate::prettier::*;
177
178        let singleline_args = self
179            .args
180            .iter()
181            .map(|arg| arg.render())
182            .reduce(|acc, arg| acc + const_text(", ") + arg)
183            .unwrap_or(Document::Empty);
184        let multiline_args = indent(
185            4,
186            nl() + self
187                .args
188                .iter()
189                .map(|arg| arg.render())
190                .reduce(|acc, arg| acc + const_text(",") + nl() + arg)
191                .unwrap_or(Document::Empty),
192        ) + nl();
193        let args = singleline_args | multiline_args;
194        let args = const_text("(") + args + const_text(")");
195
196        match self.results.len() {
197            0 => args,
198            1 => args + const_text(" -> ") + self.results[0].render(),
199            _ => {
200                let results = self
201                    .results
202                    .iter()
203                    .map(|r| r.render())
204                    .reduce(|acc, r| acc + const_text(", ") + r)
205                    .unwrap_or(Document::Empty);
206                args + const_text(" -> ") + const_text("(") + results + const_text(")")
207            },
208        }
209    }
210}
211
212// TYPE EXPRESSION
213// ================================================================================================
214
215/// A syntax-level type expression (i.e. primitive type, reference to nominal type, etc.)
216#[derive(Debug, Clone, Eq, PartialEq, Hash)]
217pub enum TypeExpr {
218    /// A primitive integral type, e.g. `i1`, `u16`
219    Primitive(Span<Type>),
220    /// A pointer type expression, e.g. `*u8`
221    Ptr(PointerType),
222    /// An array type expression, e.g. `[u8; 32]`
223    Array(ArrayType),
224    /// A struct type expression, e.g. `struct { a: u32 }`
225    Struct(StructType),
226    /// A reference to a type aliased by name, e.g. `Foo`
227    Ref(Span<Arc<Path>>),
228}
229
230impl TypeExpr {
231    /// Get any references to other types present in this expression
232    pub fn references(&self) -> Vec<Span<Arc<Path>>> {
233        use alloc::collections::BTreeSet;
234
235        let mut worklist = smallvec::SmallVec::<[_; 4]>::from_slice(&[self]);
236        let mut references = BTreeSet::new();
237
238        while let Some(ty) = worklist.pop() {
239            match ty {
240                Self::Primitive(_) => continue,
241                Self::Ptr(ty) => {
242                    worklist.push(&ty.pointee);
243                },
244                Self::Array(ty) => {
245                    worklist.push(&ty.elem);
246                },
247                Self::Struct(ty) => {
248                    for field in ty.fields.iter() {
249                        worklist.push(&field.ty);
250                    }
251                },
252                Self::Ref(ty) => {
253                    references.insert(ty.clone());
254                },
255            }
256        }
257
258        references.into_iter().collect()
259    }
260
261    /// Resolve this type expression to a concrete type, using `resolver`
262    pub fn resolve_type<E, R>(&self, resolver: &R) -> Result<Option<Type>, E>
263    where
264        R: ?Sized + TypeResolver<E>,
265    {
266        match self {
267            TypeExpr::Ref(path) => {
268                let mut current_path = path.clone();
269                loop {
270                    match resolver.resolve_type_ref(current_path.as_deref())? {
271                        SymbolResolution::Local(item) => {
272                            return resolver.get_local_type(current_path.span(), item.into_inner());
273                        },
274                        SymbolResolution::External(path) => {
275                            // We don't have a definition for this type yet
276                            if path == current_path {
277                                break Ok(None);
278                            }
279                            current_path = path;
280                        },
281                        SymbolResolution::Exact { gid, .. } => {
282                            return resolver.get_type(current_path.span(), gid).map(Some);
283                        },
284                        SymbolResolution::Module { path: module_path, .. } => {
285                            break Err(resolver.resolve_local_failed(
286                                SymbolResolutionError::invalid_symbol_type(
287                                    path.span(),
288                                    "type",
289                                    module_path.span(),
290                                    &resolver.source_manager(),
291                                ),
292                            ));
293                        },
294                        SymbolResolution::MastRoot(item) => {
295                            break Err(resolver.resolve_local_failed(
296                                SymbolResolutionError::invalid_symbol_type(
297                                    path.span(),
298                                    "type",
299                                    item.span(),
300                                    &resolver.source_manager(),
301                                ),
302                            ));
303                        },
304                    }
305                }
306            },
307            TypeExpr::Primitive(t) => Ok(Some(t.inner().clone())),
308            TypeExpr::Array(t) => Ok(t
309                .elem
310                .resolve_type(resolver)?
311                .map(|elem| types::Type::Array(Arc::new(types::ArrayType::new(elem, t.arity))))),
312            TypeExpr::Ptr(ty) => Ok(ty
313                .pointee
314                .resolve_type(resolver)?
315                .map(|pointee| types::Type::Ptr(Arc::new(types::PointerType::new(pointee))))),
316            TypeExpr::Struct(t) => {
317                let mut fields = Vec::with_capacity(t.fields.len());
318                for field in t.fields.iter() {
319                    let field_ty = field.ty.resolve_type(resolver)?;
320                    if let Some(field_ty) = field_ty {
321                        fields.push(field_ty);
322                    } else {
323                        return Ok(None);
324                    }
325                }
326                Ok(Some(Type::Struct(Arc::new(types::StructType::new(fields)))))
327            },
328        }
329    }
330}
331
332impl From<Type> for TypeExpr {
333    fn from(ty: Type) -> Self {
334        match ty {
335            Type::Array(t) => Self::Array(ArrayType::new(t.element_type().clone().into(), t.len())),
336            Type::Struct(t) => {
337                Self::Struct(StructType::new(t.fields().iter().enumerate().map(|(i, ft)| {
338                    let name = Ident::new(format!("field{i}")).unwrap();
339                    StructField {
340                        span: SourceSpan::UNKNOWN,
341                        name,
342                        ty: ft.ty.clone().into(),
343                    }
344                })))
345            },
346            Type::Ptr(t) => Self::Ptr((*t).clone().into()),
347            Type::Function(_) => {
348                Self::Ptr(PointerType::new(TypeExpr::Primitive(Span::unknown(Type::Felt))))
349            },
350            Type::List(t) => Self::Ptr(
351                PointerType::new((*t).clone().into()).with_address_space(AddressSpace::Byte),
352            ),
353            Type::I128 | Type::U128 => Self::Array(ArrayType::new(Type::U32.into(), 4)),
354            Type::I64 | Type::U64 => Self::Array(ArrayType::new(Type::U32.into(), 2)),
355            Type::Unknown | Type::Never | Type::F64 => panic!("unrepresentable type value: {ty}"),
356            ty => Self::Primitive(Span::unknown(ty)),
357        }
358    }
359}
360
361impl Spanned for TypeExpr {
362    fn span(&self) -> SourceSpan {
363        match self {
364            Self::Primitive(spanned) => spanned.span(),
365            Self::Ptr(spanned) => spanned.span(),
366            Self::Array(spanned) => spanned.span(),
367            Self::Struct(spanned) => spanned.span(),
368            Self::Ref(spanned) => spanned.span(),
369        }
370    }
371}
372
373impl crate::prettier::PrettyPrint for TypeExpr {
374    fn render(&self) -> crate::prettier::Document {
375        use crate::prettier::*;
376
377        match self {
378            Self::Primitive(ty) => display(ty),
379            Self::Ptr(ty) => ty.render(),
380            Self::Array(ty) => ty.render(),
381            Self::Struct(ty) => ty.render(),
382            Self::Ref(ty) => display(ty),
383        }
384    }
385}
386
387// POINTER TYPE
388// ================================================================================================
389
390#[derive(Debug, Clone)]
391pub struct PointerType {
392    pub span: SourceSpan,
393    pub pointee: Box<TypeExpr>,
394    addrspace: Option<AddressSpace>,
395}
396
397impl From<types::PointerType> for PointerType {
398    fn from(ty: types::PointerType) -> Self {
399        let types::PointerType { addrspace, pointee } = ty;
400        let pointee = Box::new(TypeExpr::from(pointee));
401        Self {
402            span: SourceSpan::UNKNOWN,
403            pointee,
404            addrspace: Some(addrspace),
405        }
406    }
407}
408
409impl Eq for PointerType {}
410
411impl PartialEq for PointerType {
412    fn eq(&self, other: &Self) -> bool {
413        self.address_space() == other.address_space() && self.pointee == other.pointee
414    }
415}
416
417impl core::hash::Hash for PointerType {
418    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
419        self.pointee.hash(state);
420        self.addrspace.hash(state);
421    }
422}
423
424impl Spanned for PointerType {
425    fn span(&self) -> SourceSpan {
426        self.span
427    }
428}
429
430impl PointerType {
431    pub fn new(pointee: TypeExpr) -> Self {
432        Self {
433            span: SourceSpan::UNKNOWN,
434            pointee: Box::new(pointee),
435            addrspace: None,
436        }
437    }
438
439    /// Override the default source span
440    #[inline]
441    pub fn with_span(mut self, span: SourceSpan) -> Self {
442        self.span = span;
443        self
444    }
445
446    /// Override the default address space
447    #[inline]
448    pub fn with_address_space(mut self, addrspace: AddressSpace) -> Self {
449        self.addrspace = Some(addrspace);
450        self
451    }
452
453    /// Get the address space of this pointer type
454    #[inline]
455    pub fn address_space(&self) -> AddressSpace {
456        self.addrspace.unwrap_or(AddressSpace::Element)
457    }
458}
459
460impl crate::prettier::PrettyPrint for PointerType {
461    fn render(&self) -> crate::prettier::Document {
462        use crate::prettier::*;
463
464        let doc = const_text("ptr<") + self.pointee.render();
465        if let Some(addrspace) = self.addrspace.as_ref() {
466            doc + const_text(", ") + text(format!("addrspace({})", addrspace)) + const_text(">")
467        } else {
468            doc + const_text(">")
469        }
470    }
471}
472
473// ARRAY TYPE
474// ================================================================================================
475
476#[derive(Debug, Clone)]
477pub struct ArrayType {
478    pub span: SourceSpan,
479    pub elem: Box<TypeExpr>,
480    pub arity: usize,
481}
482
483impl Eq for ArrayType {}
484
485impl PartialEq for ArrayType {
486    fn eq(&self, other: &Self) -> bool {
487        self.arity == other.arity && self.elem == other.elem
488    }
489}
490
491impl core::hash::Hash for ArrayType {
492    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
493        self.elem.hash(state);
494        self.arity.hash(state);
495    }
496}
497
498impl Spanned for ArrayType {
499    fn span(&self) -> SourceSpan {
500        self.span
501    }
502}
503
504impl ArrayType {
505    pub fn new(elem: TypeExpr, arity: usize) -> Self {
506        Self {
507            span: SourceSpan::UNKNOWN,
508            elem: Box::new(elem),
509            arity,
510        }
511    }
512
513    /// Override the default source span
514    #[inline]
515    pub fn with_span(mut self, span: SourceSpan) -> Self {
516        self.span = span;
517        self
518    }
519}
520
521impl crate::prettier::PrettyPrint for ArrayType {
522    fn render(&self) -> crate::prettier::Document {
523        use crate::prettier::*;
524
525        const_text("[")
526            + self.elem.render()
527            + const_text("; ")
528            + display(self.arity)
529            + const_text("]")
530    }
531}
532
533// STRUCT TYPE
534// ================================================================================================
535
536#[derive(Debug, Clone)]
537pub struct StructType {
538    pub span: SourceSpan,
539    pub repr: Span<TypeRepr>,
540    pub fields: Vec<StructField>,
541}
542
543impl Eq for StructType {}
544
545impl PartialEq for StructType {
546    fn eq(&self, other: &Self) -> bool {
547        self.repr == other.repr && self.fields == other.fields
548    }
549}
550
551impl core::hash::Hash for StructType {
552    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
553        self.repr.hash(state);
554        self.fields.hash(state);
555    }
556}
557
558impl Spanned for StructType {
559    fn span(&self) -> SourceSpan {
560        self.span
561    }
562}
563
564impl StructType {
565    pub fn new(fields: impl IntoIterator<Item = StructField>) -> Self {
566        Self {
567            span: SourceSpan::UNKNOWN,
568            repr: Span::unknown(TypeRepr::Default),
569            fields: fields.into_iter().collect(),
570        }
571    }
572
573    /// Override the default struct representation
574    #[inline]
575    pub fn with_repr(mut self, repr: Span<TypeRepr>) -> Self {
576        self.repr = repr;
577        self
578    }
579
580    /// Override the default source span
581    #[inline]
582    pub fn with_span(mut self, span: SourceSpan) -> Self {
583        self.span = span;
584        self
585    }
586}
587
588impl crate::prettier::PrettyPrint for StructType {
589    fn render(&self) -> crate::prettier::Document {
590        use crate::prettier::*;
591
592        let repr = match &*self.repr {
593            TypeRepr::Default => Document::Empty,
594            TypeRepr::BigEndian => const_text("@bigendian "),
595            repr @ (TypeRepr::Align(_) | TypeRepr::Packed(_) | TypeRepr::Transparent) => {
596                text(format!("@{repr} "))
597            },
598        };
599
600        let singleline_body = self
601            .fields
602            .iter()
603            .map(|field| field.render())
604            .reduce(|acc, field| acc + const_text(", ") + field)
605            .unwrap_or(Document::Empty);
606        let multiline_body = indent(
607            4,
608            nl() + self
609                .fields
610                .iter()
611                .map(|field| field.render())
612                .reduce(|acc, field| acc + const_text(",") + nl() + field)
613                .unwrap_or(Document::Empty),
614        ) + nl();
615        let body = singleline_body | multiline_body;
616
617        repr + const_text("struct") + const_text(" { ") + body + const_text(" }")
618    }
619}
620
621// STRUCT FIELD
622// ================================================================================================
623
624#[derive(Debug, Clone)]
625pub struct StructField {
626    pub span: SourceSpan,
627    pub name: Ident,
628    pub ty: TypeExpr,
629}
630
631impl Eq for StructField {}
632
633impl PartialEq for StructField {
634    fn eq(&self, other: &Self) -> bool {
635        self.name == other.name && self.ty == other.ty
636    }
637}
638
639impl core::hash::Hash for StructField {
640    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
641        self.name.hash(state);
642        self.ty.hash(state);
643    }
644}
645
646impl Spanned for StructField {
647    fn span(&self) -> SourceSpan {
648        self.span
649    }
650}
651
652impl crate::prettier::PrettyPrint for StructField {
653    fn render(&self) -> crate::prettier::Document {
654        use crate::prettier::*;
655
656        display(&self.name) + const_text(": ") + self.ty.render()
657    }
658}
659
660// TYPE ALIAS
661// ================================================================================================
662
663/// A [TypeAlias] represents a named [Type].
664///
665/// Type aliases correspond to type declarations in Miden Assembly source files. They are called
666/// aliases, rather than declarations, as the type system for Miden Assembly is structural, rather
667/// than nominal, and so two aliases with the same underlying type are considered equivalent.
668#[derive(Debug, Clone)]
669pub struct TypeAlias {
670    span: SourceSpan,
671    /// The documentation string attached to this definition.
672    docs: Option<DocString>,
673    /// The visibility of this type alias
674    pub visibility: Visibility,
675    /// The name of this type alias
676    pub name: Ident,
677    /// The concrete underlying type
678    pub ty: TypeExpr,
679}
680
681impl TypeAlias {
682    /// Create a new type alias from a name and type
683    pub fn new(visibility: Visibility, name: Ident, ty: TypeExpr) -> Self {
684        Self {
685            span: name.span(),
686            docs: None,
687            visibility,
688            name,
689            ty,
690        }
691    }
692
693    /// Adds documentation to this type alias
694    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
695        self.docs = docs.map(DocString::new);
696        self
697    }
698
699    /// Override the default source span
700    #[inline]
701    pub fn with_span(mut self, span: SourceSpan) -> Self {
702        self.span = span;
703        self
704    }
705
706    /// Set the source span
707    #[inline]
708    pub fn set_span(&mut self, span: SourceSpan) {
709        self.span = span;
710    }
711
712    /// Returns the documentation associated with this item.
713    pub fn docs(&self) -> Option<Span<&str>> {
714        self.docs.as_ref().map(|docstring| docstring.as_spanned_str())
715    }
716
717    /// Get the name of this type alias
718    pub fn name(&self) -> &Ident {
719        &self.name
720    }
721
722    /// Get the visibility of this type alias
723    #[inline]
724    pub const fn visibility(&self) -> Visibility {
725        self.visibility
726    }
727}
728
729impl Eq for TypeAlias {}
730
731impl PartialEq for TypeAlias {
732    fn eq(&self, other: &Self) -> bool {
733        self.name == other.name && self.docs == other.docs && self.ty == other.ty
734    }
735}
736
737impl core::hash::Hash for TypeAlias {
738    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
739        let Self { span: _, docs, visibility, name, ty } = self;
740        docs.hash(state);
741        visibility.hash(state);
742        name.hash(state);
743        ty.hash(state);
744    }
745}
746
747impl Spanned for TypeAlias {
748    fn span(&self) -> SourceSpan {
749        self.span
750    }
751}
752
753impl crate::prettier::PrettyPrint for TypeAlias {
754    fn render(&self) -> crate::prettier::Document {
755        use crate::prettier::*;
756
757        let mut doc = self
758            .docs
759            .as_ref()
760            .map(|docstring| docstring.render())
761            .unwrap_or(Document::Empty);
762
763        if self.visibility.is_public() {
764            doc += display(self.visibility) + const_text(" ");
765        }
766
767        doc + const_text("type")
768            + const_text(" ")
769            + display(&self.name)
770            + const_text(" = ")
771            + self.ty.render()
772    }
773}
774
775// ENUM TYPE
776// ================================================================================================
777
778/// A combined type alias and constant declaration corresponding to a C-like enumeration.
779///
780/// C-style enumerations are effectively a type alias for an integer type with a limited set of
781/// valid values with associated names (referred to as _variants_ of the enum type).
782///
783/// In Miden Assembly, these provide a means for a procedure to declare that it expects an argument
784/// of the underlying integral type, but that values other than those of the declared variants are
785/// illegal/invalid. Currently, these are unchecked, and are only used to convey semantic
786/// information. In the future, we may perform static analysis to try and identify invalid instances
787/// of the enumeration when derived from a constant.
788#[derive(Debug, Clone)]
789pub struct EnumType {
790    span: SourceSpan,
791    /// The documentation string attached to this definition.
792    docs: Option<DocString>,
793    /// The visibility of this enum type
794    visibility: Visibility,
795    /// The enum name
796    name: Ident,
797    /// The type of the discriminant value used for this enum's variants
798    ///
799    /// NOTE: The type must be an integral value, and this is enforced by [`Self::new`].
800    ty: Type,
801    /// The enum variants
802    variants: Vec<Variant>,
803}
804
805impl EnumType {
806    /// Construct a new enum type with the given name and variants
807    ///
808    /// The caller is assumed to have already validated that `ty` is an integral type, and this
809    /// function will assert that this is the case.
810    pub fn new(
811        visibility: Visibility,
812        name: Ident,
813        ty: Type,
814        variants: impl IntoIterator<Item = Variant>,
815    ) -> Self {
816        assert!(ty.is_integer(), "only integer types are allowed in enum type definitions");
817        Self {
818            span: name.span(),
819            docs: None,
820            visibility,
821            name,
822            ty,
823            variants: Vec::from_iter(variants),
824        }
825    }
826
827    /// Adds documentation to this enum declaration.
828    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
829        self.docs = docs.map(DocString::new);
830        self
831    }
832
833    /// Override the default source span
834    pub fn with_span(mut self, span: SourceSpan) -> Self {
835        self.span = span;
836        self
837    }
838
839    /// Set the source span
840    pub fn set_span(&mut self, span: SourceSpan) {
841        self.span = span;
842    }
843
844    /// Get the name of this enum type
845    pub fn name(&self) -> &Ident {
846        &self.name
847    }
848
849    /// Get the visibility of this enum type
850    pub const fn visibility(&self) -> Visibility {
851        self.visibility
852    }
853
854    /// Returns the documentation associated with this item.
855    pub fn docs(&self) -> Option<Span<&str>> {
856        self.docs.as_ref().map(|docstring| docstring.as_spanned_str())
857    }
858
859    /// Get the concrete type of this enum's variants
860    pub fn ty(&self) -> &Type {
861        &self.ty
862    }
863
864    /// Get the variants of this enum type
865    pub fn variants(&self) -> &[Variant] {
866        &self.variants
867    }
868
869    /// Get the variants of this enum type, mutably
870    pub fn variants_mut(&mut self) -> &mut Vec<Variant> {
871        &mut self.variants
872    }
873
874    /// Split this definition into its type alias and variant parts
875    pub fn into_parts(self) -> (TypeAlias, Vec<Variant>) {
876        let Self {
877            span,
878            docs,
879            visibility,
880            name,
881            ty,
882            variants,
883        } = self;
884        let alias = TypeAlias {
885            span,
886            docs,
887            visibility,
888            name,
889            ty: TypeExpr::Primitive(Span::new(span, ty)),
890        };
891        (alias, variants)
892    }
893}
894
895impl Spanned for EnumType {
896    fn span(&self) -> SourceSpan {
897        self.span
898    }
899}
900
901impl Eq for EnumType {}
902
903impl PartialEq for EnumType {
904    fn eq(&self, other: &Self) -> bool {
905        self.name == other.name
906            && self.docs == other.docs
907            && self.ty == other.ty
908            && self.variants == other.variants
909    }
910}
911
912impl core::hash::Hash for EnumType {
913    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
914        let Self {
915            span: _,
916            docs,
917            visibility,
918            name,
919            ty,
920            variants,
921        } = self;
922        docs.hash(state);
923        visibility.hash(state);
924        name.hash(state);
925        ty.hash(state);
926        variants.hash(state);
927    }
928}
929
930impl crate::prettier::PrettyPrint for EnumType {
931    fn render(&self) -> crate::prettier::Document {
932        use crate::prettier::*;
933
934        let mut doc = self
935            .docs
936            .as_ref()
937            .map(|docstring| docstring.render())
938            .unwrap_or(Document::Empty);
939
940        let variants = self
941            .variants
942            .iter()
943            .map(|v| v.render())
944            .reduce(|acc, v| acc + const_text(",") + nl() + v)
945            .unwrap_or(Document::Empty);
946
947        if self.visibility.is_public() {
948            doc += display(self.visibility) + const_text(" ");
949        }
950
951        doc + const_text("enum")
952            + const_text(" ")
953            + display(&self.name)
954            + const_text(" : ")
955            + self.ty.render()
956            + const_text(" {")
957            + nl()
958            + variants
959            + const_text("}")
960    }
961}
962
963// ENUM VARIANT
964// ================================================================================================
965
/// A variant of an [EnumType].
///
/// See the [EnumType] docs for more information.
#[derive(Debug, Clone)]
pub struct Variant {
    /// The source span of this variant.
    ///
    /// Defaults to the span of `name` when constructed via [Variant::new], and can be overridden
    /// with [Variant::with_span]. It is not considered when comparing or hashing variants.
    pub span: SourceSpan,
    /// The documentation string attached to the constant derived from this variant.
    pub docs: Option<DocString>,
    /// The name of this enum variant
    pub name: Ident,
    /// The discriminant value associated with this variant.
    ///
    /// This may be any constant expression until it has been folded. Validation via
    /// [Variant::assert_instance_of] requires it to be an integer literal.
    pub discriminant: ConstantExpr,
}
979
980impl Variant {
981    /// Construct a new variant of an [EnumType], with the given name and discriminant value.
982    pub fn new(name: Ident, discriminant: ConstantExpr) -> Self {
983        Self {
984            span: name.span(),
985            docs: None,
986            name,
987            discriminant,
988        }
989    }
990
991    /// Override the span for this variant
992    pub fn with_span(mut self, span: SourceSpan) -> Self {
993        self.span = span;
994        self
995    }
996
997    /// Adds documentation to this variant
998    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
999        self.docs = docs.map(DocString::new);
1000        self
1001    }
1002
1003    /// Used to validate that this variant's discriminant value is an instance of `ty`,
1004    /// which must be a type valid for use as the underlying representation for an enum, i.e. an
1005    /// integer type up to 64 bits in size.
1006    ///
1007    /// It is expected that the discriminant expression has been folded to an integer value by the
1008    /// time this is called. If the discriminant has not been fully folded, then an error will be
1009    /// returned.
1010    pub fn assert_instance_of(&self, ty: &Type) -> Result<(), crate::SemanticAnalysisError> {
1011        use crate::{FIELD_MODULUS, SemanticAnalysisError};
1012
1013        let value = match &self.discriminant {
1014            ConstantExpr::Int(value) => value.as_int(),
1015            _ => {
1016                return Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1017                    span: self.discriminant.span(),
1018                    repr: ty.clone(),
1019                });
1020            },
1021        };
1022
1023        match ty {
1024            Type::I1 if value > 1 => Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1025                span: self.discriminant.span(),
1026                repr: ty.clone(),
1027            }),
1028            Type::I1 => Ok(()),
1029            Type::I8 | Type::U8 if value > u8::MAX as u64 => {
1030                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1031                    span: self.discriminant.span(),
1032                    repr: ty.clone(),
1033                })
1034            },
1035            Type::I8 | Type::U8 => Ok(()),
1036            Type::I16 | Type::U16 if value > u16::MAX as u64 => {
1037                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1038                    span: self.discriminant.span(),
1039                    repr: ty.clone(),
1040                })
1041            },
1042            Type::I16 | Type::U16 => Ok(()),
1043            Type::I32 | Type::U32 if value > u32::MAX as u64 => {
1044                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1045                    span: self.discriminant.span(),
1046                    repr: ty.clone(),
1047                })
1048            },
1049            Type::I32 | Type::U32 => Ok(()),
1050            Type::I64 | Type::U64 if value >= FIELD_MODULUS => {
1051                Err(SemanticAnalysisError::InvalidEnumDiscriminant {
1052                    span: self.discriminant.span(),
1053                    repr: ty.clone(),
1054                })
1055            },
1056            _ => Err(SemanticAnalysisError::InvalidEnumRepr { span: self.span }),
1057        }
1058    }
1059}
1060
1061impl Spanned for Variant {
1062    fn span(&self) -> SourceSpan {
1063        self.span
1064    }
1065}
1066
1067impl Eq for Variant {}
1068
1069impl PartialEq for Variant {
1070    fn eq(&self, other: &Self) -> bool {
1071        self.name == other.name
1072            && self.discriminant == other.discriminant
1073            && self.docs == other.docs
1074    }
1075}
1076
1077impl core::hash::Hash for Variant {
1078    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1079        let Self { span: _, docs, name, discriminant } = self;
1080        docs.hash(state);
1081        name.hash(state);
1082        discriminant.hash(state);
1083    }
1084}
1085
1086impl crate::prettier::PrettyPrint for Variant {
1087    fn render(&self) -> crate::prettier::Document {
1088        use crate::prettier::*;
1089
1090        let doc = self
1091            .docs
1092            .as_ref()
1093            .map(|docstring| docstring.render())
1094            .unwrap_or(Document::Empty);
1095
1096        doc + display(&self.name) + const_text(" = ") + self.discriminant.render()
1097    }
1098}