Skip to main content

cairo_lang_semantic/expr/
pattern.rs

1use cairo_lang_debug::DebugWithDb;
2use cairo_lang_defs::ids::FunctionWithBodyId;
3use cairo_lang_diagnostics::DiagnosticAdded;
4use cairo_lang_filesystem::ids::SmolStrId;
5use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
6use cairo_lang_syntax::node::ast;
7use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
8use salsa::Database;
9
10use super::fmt::ExprFormatter;
11use crate::corelib::core_box_ty;
12use crate::items::function_with_body::FunctionWithBodySemantic;
13use crate::types::wrap_in_snapshots;
14use crate::{
15    ConcreteStructId, ExprNumericLiteral, ExprStringLiteral, LocalVariable, PatternArena,
16    PatternId, semantic,
17};
18
/// Semantic representation of a Pattern.
///
/// A pattern is a way to "destructure" values. A pattern may introduce new variables that are bound
/// to inner values of a specific value. For example, a tuple pattern destructures a tuple
/// and may result in new variables for the elements of that tuple.
/// This is used both in let statements and match statements.
// TODO(spapini): Replace this doc with a reference to the language documentation about patterns,
// once it is available.
#[derive(Clone, Debug, PartialEq, Eq, DebugWithDb, SemanticObject)]
#[debug_db(ExprFormatter<'db>)]
pub enum Pattern<'db> {
    /// A numeric literal pattern.
    Literal(PatternLiteral<'db>),
    /// A string literal pattern.
    StringLiteral(PatternStringLiteral<'db>),
    /// A pattern that binds the matched value to a variable.
    Variable(PatternVariable<'db>),
    /// A pattern that destructures a struct to its fields.
    Struct(PatternStruct<'db>),
    /// A pattern that destructures a tuple to its fields.
    Tuple(PatternTuple<'db>),
    /// A pattern that destructures a fixed size array into its elements.
    FixedSizeArray(PatternFixedSizeArray<'db>),
    /// A pattern that destructures a specific variant of an enum to its inner value.
    EnumVariant(PatternEnumVariant<'db>),
    /// The `_` pattern; it introduces no variables.
    Otherwise(PatternOtherwise<'db>),
    /// A placeholder carrying a `DiagnosticAdded` marker; it introduces no variables.
    // NOTE(review): presumably produced when the pattern could not be computed and a diagnostic
    // was already reported - confirm at the construction site.
    Missing(PatternMissing<'db>),
}
40
41impl<'db> Pattern<'db> {
42    pub fn ty(&self) -> semantic::TypeId<'db> {
43        match self {
44            Pattern::Literal(literal) => literal.literal.ty,
45            Pattern::StringLiteral(string_literal) => string_literal.string_literal.ty,
46            Pattern::Variable(variable) => variable.var.ty,
47            Pattern::Struct(pattern_struct) => pattern_struct.ty,
48            Pattern::Tuple(pattern_tuple) => pattern_tuple.ty,
49            Pattern::FixedSizeArray(pattern_fixed_size_array) => pattern_fixed_size_array.ty,
50            Pattern::EnumVariant(pattern_enum_variant) => pattern_enum_variant.ty,
51            Pattern::Otherwise(pattern_otherwise) => pattern_otherwise.ty,
52            Pattern::Missing(pattern_missing) => pattern_missing.ty,
53        }
54    }
55
56    pub fn variables(
57        &self,
58        queryable: &dyn PatternVariablesQueryable<'db>,
59    ) -> Vec<PatternVariable<'db>> {
60        match self {
61            Pattern::Variable(variable) => vec![variable.clone()],
62            Pattern::Struct(pattern_struct) => pattern_struct
63                .field_patterns
64                .iter()
65                .flat_map(|(pattern, _member)| queryable.query(*pattern))
66                .collect(),
67            Pattern::Tuple(pattern_tuple) => pattern_tuple
68                .field_patterns
69                .iter()
70                .flat_map(|pattern| queryable.query(*pattern))
71                .collect(),
72            Pattern::FixedSizeArray(pattern_fixed_size_array) => pattern_fixed_size_array
73                .elements_patterns
74                .iter()
75                .flat_map(|pattern| queryable.query(*pattern))
76                .collect(),
77            Pattern::EnumVariant(pattern_enum_variant) => {
78                match &pattern_enum_variant.inner_pattern {
79                    Some(pattern) => queryable.query(*pattern),
80                    None => vec![],
81                }
82            }
83            Pattern::Literal(_)
84            | Pattern::StringLiteral(_)
85            | Pattern::Otherwise(_)
86            | Pattern::Missing(_) => vec![],
87        }
88    }
89
90    pub fn stable_ptr(&self) -> ast::PatternPtr<'db> {
91        match self {
92            Pattern::Literal(pattern) => pattern.stable_ptr,
93            Pattern::StringLiteral(pattern) => pattern.stable_ptr,
94            Pattern::Variable(pattern) => pattern.stable_ptr,
95            Pattern::Struct(pattern) => pattern.stable_ptr.into(),
96            Pattern::Tuple(pattern) => pattern.stable_ptr.into(),
97            Pattern::FixedSizeArray(pattern) => pattern.stable_ptr.into(),
98            Pattern::EnumVariant(pattern) => pattern.stable_ptr,
99            Pattern::Otherwise(pattern) => pattern.stable_ptr.into(),
100            Pattern::Missing(pattern) => pattern.stable_ptr,
101        }
102    }
103}
104
105impl<'db> From<&Pattern<'db>> for SyntaxStablePtrId<'db> {
106    fn from(pattern: &Pattern<'db>) -> Self {
107        pattern.stable_ptr().into()
108    }
109}
110
/// Information about how a type is wrapped for pattern matching.
///
/// The two counters describe the layers that [`PatternWrappingInfo::wrap`] applies around a
/// type: inner snapshots, an optional `Box`, and outer snapshots.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PatternWrappingInfo {
    /// Number of outer snapshot wrappers (e.g., `@Box<Enum>` has 1 outer snapshot).
    pub n_outer_snapshots: usize,
    /// If the type is wrapped in a Box, contains the number of inner snapshots
    /// (e.g., `Box<@Enum>` has `Some(1)`, `Box<Enum>` has `Some(0)`, non-boxed has `None`).
    pub n_boxed_inner_snapshots: Option<usize>,
}
120impl PatternWrappingInfo {
121    /// Wraps a type according to the wrapping information.
122    /// First wraps with inner snapshots (if boxed), then wraps in a Box (if boxed),
123    /// and finally wraps with outer snapshots.
124    pub fn wrap<'db>(&self, db: &'db dyn Database, ty: crate::TypeId<'db>) -> crate::TypeId<'db> {
125        wrap_in_snapshots(
126            db,
127            if let Some(n_inner_snapshots) = self.n_boxed_inner_snapshots {
128                core_box_ty(db, wrap_in_snapshots(db, ty, n_inner_snapshots))
129            } else {
130                ty
131            },
132            self.n_outer_snapshots,
133        )
134    }
135}
136
/// Polymorphic container of [`Pattern`] objects used for querying pattern variables.
///
/// The implementations in this file resolve a [`PatternId`] to its [`Pattern`] and delegate to
/// [`Pattern::variables`], passing themselves along so nested sub-patterns can be resolved too.
pub trait PatternVariablesQueryable<'a> {
    /// Looks up the pattern in this container and then gets [`Pattern::variables`] from it.
    fn query(&self, id: PatternId) -> Vec<PatternVariable<'a>>;
}
142
143impl<'a> PatternVariablesQueryable<'a> for PatternArena<'a> {
144    fn query(&self, id: PatternId) -> Vec<PatternVariable<'a>> {
145        self[id].variables(self)
146    }
147}
148
/// Query a function for variables of patterns defined within it.
///
/// This is a wrapper over [`Database`] that takes [`FunctionWithBodyId`]
/// and relays queries to [`FunctionWithBodySemantic::pattern_semantic`].
pub struct QueryPatternVariablesFromDb<'a>(
    /// The database the pattern lookups are relayed to.
    pub &'a (dyn Database + 'static),
    /// The function whose patterns are queried.
    pub FunctionWithBodyId<'a>,
);
157
158impl<'a> PatternVariablesQueryable<'a> for QueryPatternVariablesFromDb<'a> {
159    fn query(&self, id: PatternId) -> Vec<PatternVariable<'a>> {
160        let pattern: Pattern<'a> = self.0.pattern_semantic(self.1, id);
161        pattern.variables(self)
162    }
163}
164
/// A pattern that consists of a numeric literal.
#[derive(Clone, Debug, Hash, PartialEq, Eq, DebugWithDb, SemanticObject)]
#[debug_db(ExprFormatter<'db>)]
pub struct PatternLiteral<'db> {
    /// The numeric literal expression; its `ty` is reported as the type of the pattern.
    pub literal: ExprNumericLiteral<'db>,
    /// Pointer to the pattern in the syntax tree.
    #[hide_field_debug_with_db]
    #[dont_rewrite]
    pub stable_ptr: ast::PatternPtr<'db>,
}
173
/// A pattern that consists of a string literal.
#[derive(Clone, Debug, Hash, PartialEq, Eq, DebugWithDb, SemanticObject)]
#[debug_db(ExprFormatter<'db>)]
pub struct PatternStringLiteral<'db> {
    /// The string literal expression; its `ty` is reported as the type of the pattern.
    pub string_literal: ExprStringLiteral<'db>,
    /// Pointer to the pattern in the syntax tree.
    #[hide_field_debug_with_db]
    #[dont_rewrite]
    pub stable_ptr: ast::PatternPtr<'db>,
}
182
/// A pattern that binds the matched value to a variable.
///
/// `DebugWithDb` is implemented by hand for this type and prints only the variable name.
#[derive(Clone, Debug, Hash, PartialEq, Eq, SemanticObject)]
pub struct PatternVariable<'db> {
    /// The name of the bound variable.
    #[dont_rewrite]
    pub name: SmolStrId<'db>,
    /// The local variable introduced by the pattern; its `ty` is reported as the pattern type.
    pub var: LocalVariable<'db>,
    /// Pointer to the pattern in the syntax tree.
    #[dont_rewrite]
    pub stable_ptr: ast::PatternPtr<'db>,
}
192impl<'db> DebugWithDb<'db> for PatternVariable<'db> {
193    type Db = ExprFormatter<'db>;
194
195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'db ExprFormatter<'_>) -> std::fmt::Result {
196        write!(f, "{}", self.name.long(db.db))
197    }
198}
199
/// A pattern that destructures a struct to its fields.
#[derive(Clone, Debug, PartialEq, Eq, DebugWithDb, SemanticObject)]
#[debug_db(ExprFormatter<'db>)]
pub struct PatternStruct<'db> {
    /// The concrete struct being destructured.
    pub concrete_struct_id: ConcreteStructId<'db>,
    /// Each sub-pattern paired with the struct member it is attached to.
    // TODO(spapini): This should be ConcreteMember, when available.
    pub field_patterns: Vec<(PatternId, semantic::Member<'db>)>,
    /// The type of the pattern.
    pub ty: semantic::TypeId<'db>,
    /// Snapshot/Box wrapping information associated with this pattern.
    // NOTE(review): presumably describes how the matched value's type is wrapped around the
    // struct type - confirm where this field is populated.
    #[dont_rewrite]
    pub wrapping_info: PatternWrappingInfo,
    /// Pointer to the struct pattern in the syntax tree.
    #[hide_field_debug_with_db]
    #[dont_rewrite]
    pub stable_ptr: ast::PatternStructPtr<'db>,
}
214
/// A pattern that destructures a tuple to its fields.
#[derive(Clone, Debug, Hash, PartialEq, Eq, DebugWithDb, SemanticObject)]
#[debug_db(ExprFormatter<'db>)]
pub struct PatternTuple<'db> {
    /// Ids of the sub-patterns, one per destructured tuple field.
    pub field_patterns: Vec<PatternId>,
    /// The type of the pattern.
    pub ty: semantic::TypeId<'db>,
    /// Pointer to the tuple pattern in the syntax tree.
    #[hide_field_debug_with_db]
    #[dont_rewrite]
    pub stable_ptr: ast::PatternTuplePtr<'db>,
}
225
/// A pattern that destructures a fixed size array into its elements.
#[derive(Clone, Debug, Hash, PartialEq, Eq, DebugWithDb, SemanticObject)]
#[debug_db(ExprFormatter<'db>)]
pub struct PatternFixedSizeArray<'db> {
    /// Ids of the sub-patterns, one per destructured array element.
    pub elements_patterns: Vec<PatternId>,
    /// The type of the pattern.
    pub ty: semantic::TypeId<'db>,
    /// Pointer to the fixed size array pattern in the syntax tree.
    #[hide_field_debug_with_db]
    #[dont_rewrite]
    pub stable_ptr: ast::PatternFixedSizeArrayPtr<'db>,
}
236
/// A pattern that destructures a specific variant of an enum to its inner value.
#[derive(Clone, Debug, Hash, PartialEq, Eq, DebugWithDb, SemanticObject)]
#[debug_db(ExprFormatter<'db>)]
pub struct PatternEnumVariant<'db> {
    /// The concrete enum variant being matched.
    pub variant: semantic::ConcreteVariant<'db>,
    /// Id of the sub-pattern applied to the variant's inner value, if there is one.
    pub inner_pattern: Option<PatternId>,
    /// The type of the pattern.
    pub ty: semantic::TypeId<'db>,
    /// Pointer to the pattern in the syntax tree.
    #[hide_field_debug_with_db]
    #[dont_rewrite]
    pub stable_ptr: ast::PatternPtr<'db>,
}
248
/// The `_` pattern. It introduces no variables.
#[derive(Clone, Debug, Hash, PartialEq, Eq, DebugWithDb, SemanticObject)]
#[debug_db(ExprFormatter<'db>)]
pub struct PatternOtherwise<'db> {
    /// The type of the pattern.
    pub ty: semantic::TypeId<'db>,
    /// Pointer to the underscore terminal in the syntax tree.
    #[hide_field_debug_with_db]
    #[dont_rewrite]
    pub stable_ptr: ast::TerminalUnderscorePtr<'db>,
}
257
/// A placeholder pattern that carries a `DiagnosticAdded` marker. It introduces no variables.
// NOTE(review): presumably stands in for a pattern that could not be computed after a diagnostic
// was reported - confirm at the construction site.
#[derive(Clone, Debug, Hash, PartialEq, Eq, DebugWithDb, SemanticObject)]
#[debug_db(ExprFormatter<'db>)]
pub struct PatternMissing<'db> {
    /// The type of the pattern.
    pub ty: semantic::TypeId<'db>,
    /// Pointer to the pattern in the syntax tree.
    #[hide_field_debug_with_db]
    #[dont_rewrite]
    pub stable_ptr: ast::PatternPtr<'db>,
    /// Marker value of type `DiagnosticAdded` stored alongside the placeholder.
    #[hide_field_debug_with_db]
    #[dont_rewrite]
    pub diag_added: DiagnosticAdded,
}