cairo_lang_semantic/items/
constant.rs

1use std::iter::zip;
2use std::sync::Arc;
3
4use cairo_lang_debug::DebugWithDb;
5use cairo_lang_defs::db::DefsGroup;
6use cairo_lang_defs::ids::{
7    ConstantId, ExternFunctionId, GenericParamId, LanguageElementId, LookupItemId, ModuleItemId,
8    NamedLanguageElementId, TopLevelLanguageElementId, TraitConstantId, TraitId, VarId,
9};
10use cairo_lang_diagnostics::{
11    DiagnosticAdded, DiagnosticEntry, DiagnosticNote, Diagnostics, Maybe, MaybeAsRef,
12    skip_diagnostic,
13};
14use cairo_lang_proc_macros::{DebugWithDb, HeapSize, SemanticObject};
15use cairo_lang_syntax::node::ast::ItemConstant;
16use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
17use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode};
18use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
19use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
20use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
21use cairo_lang_utils::{Intern, define_short_id, extract_matches, require, try_extract_matches};
22use itertools::Itertools;
23use num_bigint::BigInt;
24use num_traits::{ToPrimitive, Zero};
25use salsa::Database;
26use starknet_types_core::felt::{CAIRO_PRIME_BIGINT, Felt as Felt252};
27
28use super::functions::{GenericFunctionId, GenericFunctionWithBodyId};
29use super::imp::{ImplId, ImplLongId};
30use crate::corelib::{
31    CoreInfo, CorelibSemantic, core_nonzero_ty, false_variant, true_variant,
32    try_extract_bounded_int_type_ranges, try_extract_nz_wrapped_type, unit_ty, validate_literal,
33};
34use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
35use crate::expr::compute::{ComputationContext, ExprAndId, compute_expr_semantic};
36use crate::expr::inference::conform::InferenceConform;
37use crate::expr::inference::{ConstVar, InferenceId};
38use crate::helper::ModuleHelper;
39use crate::items::enm::SemanticEnumEx;
40use crate::items::extern_function::ExternFunctionSemantic;
41use crate::items::free_function::FreeFunctionSemantic;
42use crate::items::function_with_body::FunctionWithBodySemantic;
43use crate::items::generics::GenericParamSemantic;
44use crate::items::imp::ImplSemantic;
45use crate::items::structure::StructSemantic;
46use crate::items::trt::TraitSemantic;
47use crate::resolve::{Resolver, ResolverData};
48use crate::substitution::{GenericSubstitution, SemanticRewriter};
49use crate::types::resolve_type;
50use crate::{
51    Arenas, ConcreteFunction, ConcreteTypeId, ConcreteVariant, Condition, Expr, ExprBlock,
52    ExprConstant, ExprFunctionCall, ExprFunctionCallArg, ExprId, ExprMemberAccess, ExprStructCtor,
53    FunctionId, GenericParam, LogicalOperator, Pattern, PatternId, SemanticDiagnostic, Statement,
54    TypeId, TypeLongId, semantic_object_for_id,
55};
56
57#[derive(Clone, Debug, PartialEq, Eq, DebugWithDb)]
58#[debug_db(dyn Database)]
59pub struct Constant<'db> {
60    /// The actual id of the const expression value.
61    pub value: ExprId,
62    /// The arena of all the expressions for the const calculation.
63    pub arenas: Arc<Arenas<'db>>,
64}
65
66// TODO: Review this well.
67unsafe impl<'db> salsa::Update for Constant<'db> {
68    unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
69        let old_constant: &mut Constant<'db> = unsafe { &mut *old_pointer };
70
71        if old_constant.value != new_value.value {
72            *old_constant = new_value;
73            return true;
74        }
75
76        false
77    }
78}
79
80impl<'db> Constant<'db> {
81    pub fn ty(&self) -> TypeId<'db> {
82        self.arenas.exprs[self.value].ty()
83    }
84}
85
86/// Information about a constant definition.
87///
88/// Helper struct for the data returned by [ConstantSemantic::constant_semantic_data].
89#[derive(Clone, Debug, PartialEq, Eq, DebugWithDb, salsa::Update)]
90#[debug_db(dyn Database)]
91pub struct ConstantData<'db> {
92    pub diagnostics: Diagnostics<'db, SemanticDiagnostic<'db>>,
93    pub constant: Maybe<Constant<'db>>,
94    pub const_value: ConstValueId<'db>,
95    pub resolver_data: Arc<ResolverData<'db>>,
96}
97
98define_short_id!(ConstValueId, ConstValue<'db>);
99semantic_object_for_id!(ConstValueId, ConstValue<'a>);
100impl<'db> ConstValueId<'db> {
101    /// Creates a new const value from a BigInt.
102    pub fn from_int(db: &'db dyn Database, ty: TypeId<'db>, value: &BigInt) -> Self {
103        let get_basic_const_value = |ty| {
104            let info = db.const_calc_info();
105            if ty != info.u256 {
106                ConstValue::Int(value.clone(), ty).intern(db)
107            } else {
108                let mask128 = BigInt::from(u128::MAX);
109                let low = value & mask128;
110                let high = value >> 128;
111                ConstValue::Struct(
112                    vec![
113                        (ConstValue::Int(low, info.u128).intern(db)),
114                        (ConstValue::Int(high, info.u128).intern(db)),
115                    ],
116                    ty,
117                )
118                .intern(db)
119            }
120        };
121        if let Some(inner) = try_extract_nz_wrapped_type(db, ty) {
122            ConstValue::NonZero(get_basic_const_value(inner)).intern(db)
123        } else {
124            get_basic_const_value(ty)
125        }
126    }
127    /// Returns a formatted string of the const value.
128    pub fn format(&self, db: &dyn Database) -> String {
129        format!("{:?}", self.long(db).debug(db))
130    }
131
132    /// Returns true if the const does not depend on any generics.
133    pub fn is_fully_concrete(&self, db: &dyn Database) -> bool {
134        self.long(db).is_fully_concrete(db)
135    }
136
137    /// Returns true if the const does not contain any inference variables.
138    pub fn is_var_free(&self, db: &dyn Database) -> bool {
139        self.long(db).is_var_free(db)
140    }
141
142    /// Returns the type of the const.
143    pub fn ty(&self, db: &'db dyn Database) -> Maybe<TypeId<'db>> {
144        self.long(db).ty(db)
145    }
146}
147
148/// Moves the value of a felt252, to the range of `[range_min, range_min + PRIME)`.
149pub fn felt252_for_downcast(value: &BigInt, range_min: &BigInt) -> BigInt {
150    Felt252::from(value - range_min).to_bigint() + range_min
151}
152
153/// Canonicalize the value of a felt252, to the range of `(-PRIME, PRIME)`.
154pub fn canonical_felt252(value: &BigInt) -> BigInt {
155    value % &*CAIRO_PRIME_BIGINT
156}
157
158/// A constant value.
159#[derive(Clone, Debug, Hash, PartialEq, Eq, SemanticObject, HeapSize, salsa::Update)]
160pub enum ConstValue<'db> {
161    Int(#[dont_rewrite] BigInt, TypeId<'db>),
162    Struct(Vec<ConstValueId<'db>>, TypeId<'db>),
163    Enum(ConcreteVariant<'db>, ConstValueId<'db>),
164    NonZero(ConstValueId<'db>),
165    Generic(#[dont_rewrite] GenericParamId<'db>),
166    ImplConstant(ImplConstantId<'db>),
167    Var(ConstVar<'db>, TypeId<'db>),
168    /// A missing value, used in cases where the value is not known due to diagnostics.
169    Missing(#[dont_rewrite] DiagnosticAdded),
170}
171impl<'db> ConstValue<'db> {
172    /// Returns true if the const does not depend on any generics.
173    pub fn is_fully_concrete(&self, db: &dyn Database) -> bool {
174        self.ty(db).unwrap().is_fully_concrete(db)
175            && match self {
176                ConstValue::Int(_, _) => true,
177                ConstValue::Struct(members, _) => {
178                    members.iter().all(|member| member.is_fully_concrete(db))
179                }
180                ConstValue::Enum(_, val) | ConstValue::NonZero(val) => val.is_fully_concrete(db),
181                ConstValue::Generic(_)
182                | ConstValue::Var(_, _)
183                | ConstValue::Missing(_)
184                | ConstValue::ImplConstant(_) => false,
185            }
186    }
187
188    /// Returns true if the const does not contain any inference variables.
189    pub fn is_var_free(&self, db: &dyn Database) -> bool {
190        self.ty(db).unwrap().is_var_free(db)
191            && match self {
192                ConstValue::Int(_, _) | ConstValue::Generic(_) | ConstValue::Missing(_) => true,
193                ConstValue::Struct(members, _) => {
194                    members.iter().all(|member| member.is_var_free(db))
195                }
196                ConstValue::Enum(_, val) | ConstValue::NonZero(val) => val.is_var_free(db),
197                ConstValue::Var(_, _) => false,
198                ConstValue::ImplConstant(impl_constant) => impl_constant.impl_id().is_var_free(db),
199            }
200    }
201
202    /// Returns the type of the const.
203    pub fn ty(&self, db: &'db dyn Database) -> Maybe<TypeId<'db>> {
204        Ok(match self {
205            ConstValue::Int(_, ty) => *ty,
206            ConstValue::Struct(_, ty) => *ty,
207            ConstValue::Enum(variant, _) => {
208                TypeLongId::Concrete(ConcreteTypeId::Enum(variant.concrete_enum_id)).intern(db)
209            }
210            ConstValue::NonZero(value) => core_nonzero_ty(db, value.ty(db)?),
211            ConstValue::Generic(param) => {
212                extract_matches!(db.generic_param_semantic(*param)?, GenericParam::Const).ty
213            }
214            ConstValue::Var(_, ty) => *ty,
215            ConstValue::Missing(_) => TypeId::missing(db, skip_diagnostic()),
216            ConstValue::ImplConstant(impl_constant_id) => {
217                db.impl_constant_concrete_implized_type(*impl_constant_id)?
218            }
219        })
220    }
221
222    /// Returns the value of an int const as a BigInt.
223    pub fn to_int(&self) -> Option<&BigInt> {
224        match self {
225            ConstValue::Int(value, _) => Some(value),
226            _ => None,
227        }
228    }
229}
230
231/// An impl item of kind const.
232#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, SemanticObject, HeapSize, salsa::Update)]
233pub struct ImplConstantId<'db> {
234    /// The impl the item const is in.
235    impl_id: ImplId<'db>,
236    /// The trait const this impl const "implements".
237    trait_constant_id: TraitConstantId<'db>,
238}
239
240impl<'db> ImplConstantId<'db> {
241    /// Creates a new impl constant id. For an impl constant of a concrete impl, asserts that the
242    /// trait constant belongs to the same trait that the impl implements (panics if not).
243    pub fn new(
244        impl_id: ImplId<'db>,
245        trait_constant_id: TraitConstantId<'db>,
246        db: &dyn Database,
247    ) -> Self {
248        if let ImplLongId::Concrete(concrete_impl) = impl_id.long(db) {
249            let impl_def_id = concrete_impl.impl_def_id(db);
250            assert_eq!(Ok(trait_constant_id.trait_id(db)), db.impl_def_trait(impl_def_id));
251        }
252
253        ImplConstantId { impl_id, trait_constant_id }
254    }
255    pub fn impl_id(&self) -> ImplId<'db> {
256        self.impl_id
257    }
258    pub fn trait_constant_id(&self) -> TraitConstantId<'db> {
259        self.trait_constant_id
260    }
261
262    pub fn format(&self, db: &dyn Database) -> String {
263        format!("{}::{}", self.impl_id.name(db), self.trait_constant_id.name(db).long(db))
264    }
265}
266impl<'db> DebugWithDb<'db> for ImplConstantId<'db> {
267    type Db = dyn Database;
268
269    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'db dyn Database) -> std::fmt::Result {
270        write!(f, "{}", self.format(db))
271    }
272}
273
274/// Returns the semantic data of a constant.
275#[salsa::tracked(returns(ref), cycle_result=constant_semantic_data_cycle)]
276fn constant_semantic_data<'db>(
277    db: &'db dyn Database,
278    const_id: ConstantId<'db>,
279    in_cycle: bool,
280) -> Maybe<ConstantData<'db>> {
281    let lookup_item_id = LookupItemId::ModuleItem(ModuleItemId::Constant(const_id));
282    if in_cycle {
283        constant_semantic_data_cycle_helper(
284            db,
285            &db.module_constant_by_id(const_id)?,
286            lookup_item_id,
287            None,
288            &const_id,
289        )
290    } else {
291        constant_semantic_data_helper(
292            db,
293            &db.module_constant_by_id(const_id)?,
294            lookup_item_id,
295            None,
296            &const_id,
297        )
298    }
299}
300
301/// Cycle handling for [ConstantSemantic::constant_semantic_data].
302fn constant_semantic_data_cycle<'db>(
303    db: &'db dyn Database,
304    const_id: ConstantId<'db>,
305    _in_cycle: bool,
306) -> Maybe<ConstantData<'db>> {
307    let lookup_item_id = LookupItemId::ModuleItem(ModuleItemId::Constant(const_id));
308    constant_semantic_data_cycle_helper(
309        db,
310        &db.module_constant_by_id(const_id)?,
311        lookup_item_id,
312        None,
313        &const_id,
314    )
315}
316
317/// Returns constant semantic data for the given ItemConstant.
318pub fn constant_semantic_data_helper<'db>(
319    db: &'db dyn Database,
320    constant_ast: &ItemConstant<'db>,
321    lookup_item_id: LookupItemId<'db>,
322    parent_resolver_data: Option<Arc<ResolverData<'db>>>,
323    element_id: &impl LanguageElementId<'db>,
324) -> Maybe<ConstantData<'db>> {
325    let mut diagnostics: SemanticDiagnostics<'_> = SemanticDiagnostics::default();
326    // TODO(spapini): when code changes in a file, all the AST items change (as they contain a path
327    // to the green root that changes. Once ASTs are rooted on items, use a selector that picks only
328    // the item instead of all the module data.
329
330    let inference_id = InferenceId::LookupItemDeclaration(lookup_item_id);
331
332    let mut resolver = match parent_resolver_data {
333        Some(parent_resolver_data) => {
334            Resolver::with_data(db, parent_resolver_data.clone_with_inference_id(db, inference_id))
335        }
336        None => Resolver::new(db, element_id.parent_module(db), inference_id),
337    };
338    resolver.set_feature_config(element_id, constant_ast, &mut diagnostics);
339
340    let ty_syntax = constant_ast.type_clause(db).ty(db);
341    let constant_type = resolve_type(db, &mut diagnostics, &mut resolver, &ty_syntax);
342    // The type of a constant must not rely on its value, to create better stability for the
343    // constant signature.
344    if !constant_type.is_var_free(db) {
345        diagnostics.report(
346            ty_syntax.stable_ptr(db).untyped(),
347            SemanticDiagnosticKind::ConstTypeNotVarFree,
348        );
349    }
350
351    let mut ctx = ComputationContext::new_global(db, &mut diagnostics, &mut resolver);
352
353    let value = compute_expr_semantic(&mut ctx, &constant_ast.value(db));
354    let const_value = resolve_const_expr_and_evaluate(
355        db,
356        &mut ctx,
357        &value,
358        constant_ast.stable_ptr(db).untyped(),
359        constant_type,
360        true,
361    );
362    let constant = Ok(Constant { value: value.id, arenas: Arc::new(ctx.arenas) });
363    let const_value = resolver
364        .inference()
365        .rewrite(const_value)
366        .unwrap_or_else(|_| ConstValue::Missing(skip_diagnostic()).intern(db));
367    let resolver_data = Arc::new(resolver.data);
368    Ok(ConstantData { diagnostics: diagnostics.build(), const_value, constant, resolver_data })
369}
370
371/// Helper for cycle handling of constants.
372pub fn constant_semantic_data_cycle_helper<'db>(
373    db: &'db dyn Database,
374    constant_ast: &ItemConstant<'db>,
375    lookup_item_id: LookupItemId<'db>,
376    parent_resolver_data: Option<Arc<ResolverData<'db>>>,
377    element_id: &impl LanguageElementId<'db>,
378) -> Maybe<ConstantData<'db>> {
379    let mut diagnostics: SemanticDiagnostics<'_> = SemanticDiagnostics::default();
380
381    let inference_id = InferenceId::LookupItemDeclaration(lookup_item_id);
382
383    let resolver = match parent_resolver_data {
384        Some(parent_resolver_data) => {
385            Resolver::with_data(db, parent_resolver_data.clone_with_inference_id(db, inference_id))
386        }
387        None => Resolver::new(db, element_id.parent_module(db), inference_id),
388    };
389
390    let resolver_data = Arc::new(resolver.data);
391
392    let diagnostic_added =
393        diagnostics.report(constant_ast.stable_ptr(db), SemanticDiagnosticKind::ConstCycle);
394    Ok(ConstantData {
395        constant: Err(diagnostic_added),
396        const_value: ConstValue::Missing(diagnostic_added).intern(db),
397        diagnostics: diagnostics.build(),
398        resolver_data,
399    })
400}
401
402/// Checks if the given expression only involved constant calculations.
403pub fn validate_const_expr<'db>(ctx: &mut ComputationContext<'db, '_>, expr_id: ExprId) {
404    let info = ctx.db.const_calc_info();
405    let mut eval_ctx = ConstantEvaluateContext {
406        db: ctx.db,
407        info: info.as_ref(),
408        arenas: &ctx.arenas,
409        vars: Default::default(),
410        generic_substitution: Default::default(),
411        depth: 0,
412        diagnostics: ctx.diagnostics,
413    };
414    eval_ctx.validate(expr_id);
415}
416
417/// Resolves the given const expression and evaluates its value.
418pub fn resolve_const_expr_and_evaluate<'db, 'mt>(
419    db: &'db dyn Database,
420    ctx: &'mt mut ComputationContext<'db, '_>,
421    value: &ExprAndId<'db>,
422    const_stable_ptr: SyntaxStablePtrId<'db>,
423    target_type: TypeId<'db>,
424    finalize: bool,
425) -> ConstValueId<'db> {
426    let prev_err_count = ctx.diagnostics.error_count;
427    let mut_ref = &mut ctx.resolver;
428    let mut inference: crate::expr::inference::Inference<'db, '_> = mut_ref.inference();
429    if let Err(err_set) = inference.conform_ty(value.ty(), target_type) {
430        inference.report_on_pending_error(err_set, ctx.diagnostics, const_stable_ptr);
431    }
432
433    if finalize {
434        // Check fully resolved.
435        inference.finalize(ctx.diagnostics, const_stable_ptr);
436    } else if let Err(err_set) = inference.solve() {
437        inference.report_on_pending_error(err_set, ctx.diagnostics, const_stable_ptr);
438    }
439
440    // TODO(orizi): Consider moving this to be called only upon creating const values, other callees
441    // don't necessarily need it.
442    ctx.apply_inference_rewriter_to_exprs();
443
444    match &value.expr {
445        Expr::Constant(ExprConstant { const_value_id, .. }) => *const_value_id,
446        // Check that the expression is a valid constant.
447        _ if ctx.diagnostics.error_count > prev_err_count => {
448            ConstValue::Missing(skip_diagnostic()).intern(db)
449        }
450        _ => {
451            let info = db.const_calc_info();
452            let mut eval_ctx = ConstantEvaluateContext {
453                db,
454                info: info.as_ref(),
455                arenas: &ctx.arenas,
456                vars: Default::default(),
457                generic_substitution: Default::default(),
458                depth: 0,
459                diagnostics: ctx.diagnostics,
460            };
461            eval_ctx.validate(value.id);
462            if eval_ctx.diagnostics.error_count > prev_err_count {
463                ConstValue::Missing(skip_diagnostic()).intern(db)
464            } else {
465                eval_ctx.evaluate(value.id)
466            }
467        }
468    }
469}
470
471/// A context for evaluating constant expressions.
472struct ConstantEvaluateContext<'a, 'r, 'mt> {
473    db: &'a dyn Database,
474    info: &'r ConstCalcInfo<'a>,
475    arenas: &'r Arenas<'a>,
476    vars: OrderedHashMap<VarId<'a>, ConstValueId<'a>>,
477    generic_substitution: GenericSubstitution<'a>,
478    depth: usize,
479    diagnostics: &'mt mut SemanticDiagnostics<'a>,
480}
481impl<'a, 'r, 'mt> ConstantEvaluateContext<'a, 'r, 'mt> {
482    /// Validate the given expression can be used as constant.
483    fn validate(&mut self, expr_id: ExprId) {
484        match &self.arenas.exprs[expr_id] {
485            Expr::Var(_) | Expr::Constant(_) | Expr::Missing(_) => {}
486            Expr::Block(ExprBlock { statements, tail: Some(inner), .. }) => {
487                for statement_id in statements {
488                    match &self.arenas.statements[*statement_id] {
489                        Statement::Let(statement) => {
490                            self.validate(statement.expr);
491                        }
492                        Statement::Expr(expr) => {
493                            self.validate(expr.expr);
494                        }
495                        other => {
496                            self.diagnostics.report(
497                                other.stable_ptr(),
498                                SemanticDiagnosticKind::UnsupportedConstant,
499                            );
500                        }
501                    }
502                }
503                self.validate(*inner);
504            }
505            Expr::FunctionCall(expr) => {
506                for arg in &expr.args {
507                    match arg {
508                        ExprFunctionCallArg::Value(arg) => self.validate(*arg),
509                        ExprFunctionCallArg::Reference(var) => {
510                            self.diagnostics.report(
511                                var.stable_ptr(),
512                                SemanticDiagnosticKind::UnsupportedConstant,
513                            );
514                        }
515                    }
516                    if let ExprFunctionCallArg::Value(arg) = arg {
517                        self.validate(*arg);
518                    }
519                }
520                if !self.is_function_const(expr.function) {
521                    self.diagnostics.report(
522                        expr.stable_ptr.untyped(),
523                        SemanticDiagnosticKind::UnsupportedConstant,
524                    );
525                }
526            }
527            Expr::Literal(expr) => {
528                if let Err(err) = validate_literal(self.db, expr.ty, &expr.value) {
529                    self.diagnostics.report(
530                        expr.stable_ptr.untyped(),
531                        SemanticDiagnosticKind::LiteralError(err),
532                    );
533                }
534            }
535            Expr::Tuple(expr) => {
536                for item in &expr.items {
537                    self.validate(*item);
538                }
539            }
540            Expr::StructCtor(ExprStructCtor { members, base_struct: None, .. }) => {
541                for (expr_id, _) in members {
542                    self.validate(*expr_id);
543                }
544            }
545            Expr::EnumVariantCtor(expr) => self.validate(expr.value_expr),
546            Expr::MemberAccess(expr) => self.validate(expr.expr),
547            Expr::FixedSizeArray(expr) => match &expr.items {
548                crate::FixedSizeArrayItems::Items(items) => {
549                    for item in items {
550                        self.validate(*item);
551                    }
552                }
553                crate::FixedSizeArrayItems::ValueAndSize(value, _) => {
554                    self.validate(*value);
555                }
556            },
557            Expr::Snapshot(expr) => self.validate(expr.inner),
558            Expr::Desnap(expr) => self.validate(expr.inner),
559            Expr::LogicalOperator(expr) => {
560                self.validate(expr.lhs);
561                self.validate(expr.rhs);
562            }
563            Expr::Match(expr) => {
564                self.validate(expr.matched_expr);
565                for arm in &expr.arms {
566                    self.validate(arm.expression);
567                }
568            }
569            Expr::If(expr) => {
570                for condition in &expr.conditions {
571                    self.validate(match condition {
572                        Condition::BoolExpr(id) | Condition::Let(id, _) => *id,
573                    });
574                }
575                self.validate(expr.if_block);
576                if let Some(else_block) = expr.else_block {
577                    self.validate(else_block);
578                }
579            }
580            other => {
581                self.diagnostics.report(
582                    other.stable_ptr().untyped(),
583                    SemanticDiagnosticKind::UnsupportedConstant,
584                );
585            }
586        }
587    }
588
589    /// Returns true if the given function is allowed to be called in constant context.
590    fn is_function_const(&self, function_id: FunctionId<'a>) -> bool {
591        if function_id == self.panic_with_felt252 {
592            return true;
593        }
594        let db = self.db;
595        let concrete_function = function_id.get_concrete(db);
596        let signature = (|| match concrete_function.generic_function {
597            GenericFunctionId::Free(id) => db.free_function_signature(id),
598            GenericFunctionId::Extern(id) => db.extern_function_signature(id),
599            GenericFunctionId::Impl(id) => {
600                if let ImplLongId::Concrete(impl_id) = id.impl_id.long(db)
601                    && let Ok(Some(impl_function_id)) = impl_id.get_impl_function(db, id.function)
602                {
603                    return self.db.impl_function_signature(impl_function_id);
604                }
605                self.db.trait_function_signature(id.function)
606            }
607        })();
608        if signature.map(|s| s.is_const) == Ok(true) {
609            return true;
610        }
611        let Ok(Some(body)) = concrete_function.body(db) else { return false };
612        let GenericFunctionWithBodyId::Impl(imp) = body.generic_function(db) else {
613            return false;
614        };
615        let impl_def = imp.concrete_impl_id.impl_def_id(db);
616        if impl_def.parent_module(db).owning_crate(db) != db.core_crate() {
617            return false;
618        }
619        let Ok(trait_id) = db.impl_def_trait(impl_def) else {
620            return false;
621        };
622        self.const_traits.contains(&trait_id)
623    }
624
625    /// Evaluate the given const expression value.
626    fn evaluate<'ctx>(&'ctx mut self, expr_id: ExprId) -> ConstValueId<'a> {
627        let expr = &self.arenas.exprs[expr_id];
628        let db = self.db;
629        let to_missing = |diag_added| ConstValue::Missing(diag_added).intern(db);
630        match expr {
631            Expr::Var(expr) => self.vars.get(&expr.var).copied().unwrap_or_else(|| {
632                to_missing(
633                    self.diagnostics
634                        .report(expr.stable_ptr, SemanticDiagnosticKind::UnsupportedConstant),
635                )
636            }),
637            Expr::Constant(expr) => self
638                .generic_substitution
639                .substitute(self.db, expr.const_value_id)
640                .unwrap_or_else(to_missing),
641            Expr::Block(ExprBlock { statements, tail: Some(inner), .. }) => {
642                for statement_id in statements {
643                    match &self.arenas.statements[*statement_id] {
644                        Statement::Let(statement) => {
645                            let value = self.evaluate(statement.expr);
646                            self.destructure_pattern(statement.pattern, value);
647                        }
648                        Statement::Expr(expr) => {
649                            self.evaluate(expr.expr);
650                        }
651                        other => {
652                            self.diagnostics.report(
653                                other.stable_ptr(),
654                                SemanticDiagnosticKind::UnsupportedConstant,
655                            );
656                        }
657                    }
658                }
659                self.evaluate(*inner)
660            }
661            Expr::FunctionCall(expr) => self.evaluate_function_call(expr),
662            Expr::Literal(expr) => ConstValueId::from_int(db, expr.ty, &expr.value),
663            Expr::Tuple(expr) => ConstValue::Struct(
664                expr.items.iter().map(|expr_id| self.evaluate(*expr_id)).collect(),
665                expr.ty,
666            )
667            .intern(db),
668            Expr::StructCtor(ExprStructCtor {
669                members,
670                base_struct: None,
671                ty,
672                concrete_struct_id,
673                ..
674            }) => {
675                let member_order = match db.concrete_struct_members(*concrete_struct_id) {
676                    Ok(member_order) => member_order,
677                    Err(diag_add) => return to_missing(diag_add),
678                };
679                ConstValue::Struct(
680                    member_order
681                        .values()
682                        .map(|m| {
683                            members
684                                .iter()
685                                .find(|(_, member_id)| m.id == *member_id)
686                                .map(|(expr_id, _)| self.evaluate(*expr_id))
687                                .expect("Should have been caught by semantic validation")
688                        })
689                        .collect(),
690                    *ty,
691                )
692                .intern(db)
693            }
694            Expr::EnumVariantCtor(expr) => {
695                ConstValue::Enum(expr.variant, self.evaluate(expr.value_expr)).intern(db)
696            }
697            Expr::MemberAccess(expr) => {
698                self.evaluate_member_access(expr).unwrap_or_else(to_missing)
699            }
700            Expr::FixedSizeArray(expr) => ConstValue::Struct(
701                match &expr.items {
702                    crate::FixedSizeArrayItems::Items(items) => {
703                        items.iter().map(|expr_id| self.evaluate(*expr_id)).collect()
704                    }
705                    crate::FixedSizeArrayItems::ValueAndSize(value, count) => {
706                        let value = self.evaluate(*value);
707                        if let Some(count) = count.long(db).to_int() {
708                            vec![value; count.to_usize().unwrap()]
709                        } else {
710                            self.diagnostics.report(
711                                expr.stable_ptr.untyped(),
712                                SemanticDiagnosticKind::UnsupportedConstant,
713                            );
714                            vec![]
715                        }
716                    }
717                },
718                expr.ty,
719            )
720            .intern(db),
721            Expr::Snapshot(expr) => self.evaluate(expr.inner),
722            Expr::Desnap(expr) => self.evaluate(expr.inner),
723            Expr::LogicalOperator(expr) => {
724                let lhs = self.evaluate(expr.lhs);
725                if let ConstValue::Enum(v, _) = lhs.long(db) {
726                    let early_return_variant = match expr.op {
727                        LogicalOperator::AndAnd => false_variant(self.db),
728                        LogicalOperator::OrOr => true_variant(self.db),
729                    };
730                    if *v == early_return_variant { lhs } else { self.evaluate(expr.rhs) }
731                } else {
732                    to_missing(skip_diagnostic())
733                }
734            }
735            Expr::Match(expr) => {
736                let value = self.evaluate(expr.matched_expr);
737                let ConstValue::Enum(variant, value) = value.long(db) else {
738                    return to_missing(skip_diagnostic());
739                };
740                for arm in &expr.arms {
741                    for pattern_id in &arm.patterns {
742                        let pattern = &self.arenas.patterns[*pattern_id];
743                        if matches!(pattern, Pattern::Otherwise(_)) {
744                            return self.evaluate(arm.expression);
745                        }
746                        let Pattern::EnumVariant(pattern) = pattern else {
747                            continue;
748                        };
749                        if pattern.variant.idx != variant.idx {
750                            continue;
751                        }
752                        if let Some(inner_pattern) = pattern.inner_pattern {
753                            self.destructure_pattern(inner_pattern, *value);
754                        }
755                        return self.evaluate(arm.expression);
756                    }
757                }
758                to_missing(
759                    self.diagnostics.report(
760                        expr.stable_ptr.untyped(),
761                        SemanticDiagnosticKind::UnsupportedConstant,
762                    ),
763                )
764            }
765            Expr::If(expr) => {
766                let mut if_condition: bool = true;
767                for condition in &expr.conditions {
768                    match condition {
769                        crate::Condition::BoolExpr(id) => {
770                            let condition = self.evaluate(*id);
771                            let ConstValue::Enum(variant, _) = condition.long(db) else {
772                                return to_missing(skip_diagnostic());
773                            };
774                            if *variant != true_variant(self.db) {
775                                if_condition = false;
776                                break;
777                            }
778                        }
779                        crate::Condition::Let(id, patterns) => {
780                            let value = self.evaluate(*id);
781                            let ConstValue::Enum(variant, value) = value.long(db) else {
782                                return to_missing(skip_diagnostic());
783                            };
784                            let mut found_pattern = false;
785                            for pattern_id in patterns {
786                                let Pattern::EnumVariant(pattern) =
787                                    &self.arenas.patterns[*pattern_id]
788                                else {
789                                    continue;
790                                };
791                                if pattern.variant != *variant {
792                                    // Continue to the next option in the `|` list.
793                                    continue;
794                                }
795                                if let Some(inner_pattern) = pattern.inner_pattern {
796                                    self.destructure_pattern(inner_pattern, *value);
797                                }
798                                found_pattern = true;
799                                break;
800                            }
801                            if !found_pattern {
802                                if_condition = false;
803                                break;
804                            }
805                        }
806                    }
807                }
808
809                if if_condition {
810                    self.evaluate(expr.if_block)
811                } else if let Some(else_block) = expr.else_block {
812                    self.evaluate(else_block)
813                } else {
814                    self.unit_const
815                }
816            }
817            _ => to_missing(skip_diagnostic()),
818        }
819    }
820
821    /// Attempts to evaluate constants from a const function call.
822    fn evaluate_function_call(&mut self, expr: &ExprFunctionCall<'a>) -> ConstValueId<'a> {
823        let db = self.db;
824        let to_missing = |diag_added| ConstValue::Missing(diag_added).intern(db);
825        let args = expr
826            .args
827            .iter()
828            .filter_map(|arg| try_extract_matches!(arg, ExprFunctionCallArg::Value))
829            .map(|arg| self.evaluate(*arg))
830            .collect_vec();
831        if expr.function == self.panic_with_felt252 {
832            return to_missing(self.diagnostics.report(
833                expr.stable_ptr.untyped(),
834                SemanticDiagnosticKind::FailedConstantCalculation,
835            ));
836        }
837        let concrete_function =
838            match self.generic_substitution.substitute(db, expr.function.get_concrete(db)) {
839                Ok(v) => v,
840                Err(err) => return to_missing(err),
841            };
842        if let Some(calc_result) =
843            self.evaluate_const_function_call(&concrete_function, &args, expr)
844        {
845            return calc_result;
846        }
847
848        let imp = extract_matches!(concrete_function.generic_function, GenericFunctionId::Impl);
849        let bool_value = |condition: bool| {
850            if condition { self.true_const } else { self.false_const }
851        };
852
853        if imp.function == self.eq_fn {
854            return bool_value(args[0] == args[1]);
855        } else if imp.function == self.ne_fn {
856            return bool_value(args[0] != args[1]);
857        } else if imp.function == self.not_fn {
858            return bool_value(args[0] == self.false_const);
859        }
860
861        let args = match args
862            .into_iter()
863            .map(|arg| NumericArg::try_new(db, arg))
864            .collect::<Option<Vec<_>>>()
865        {
866            Some(args) => args,
867            // Diagnostic can be skipped as we would either have a semantic error for a bad arg for
868            // the function, or the arg itself couldn't have been calculated.
869            None => return to_missing(skip_diagnostic()),
870        };
871        let value = match imp.function {
872            id if id == self.neg_fn => -&args[0].v,
873            id if id == self.add_fn => &args[0].v + &args[1].v,
874            id if id == self.sub_fn => &args[0].v - &args[1].v,
875            id if id == self.mul_fn => &args[0].v * &args[1].v,
876            id if (id == self.div_fn || id == self.rem_fn) && args[1].v.is_zero() => {
877                return to_missing(
878                    self.diagnostics
879                        .report(expr.stable_ptr.untyped(), SemanticDiagnosticKind::DivisionByZero),
880                );
881            }
882            id if id == self.div_fn => &args[0].v / &args[1].v,
883            id if id == self.rem_fn => &args[0].v % &args[1].v,
884            id if id == self.bitand_fn => &args[0].v & &args[1].v,
885            id if id == self.bitor_fn => &args[0].v | &args[1].v,
886            id if id == self.bitxor_fn => &args[0].v ^ &args[1].v,
887            id if id == self.lt_fn => return bool_value(args[0].v < args[1].v),
888            id if id == self.le_fn => return bool_value(args[0].v <= args[1].v),
889            id if id == self.gt_fn => return bool_value(args[0].v > args[1].v),
890            id if id == self.ge_fn => return bool_value(args[0].v >= args[1].v),
891            id if id == self.div_rem_fn => {
892                // No need for non-zero check as this is type checked to begin with.
893                // Also results are always in the range of the input type, so `unwrap`s are ok.
894                return ConstValue::Struct(
895                    vec![
896                        ConstValueId::from_int(db, args[0].ty, &(&args[0].v / &args[1].v)),
897                        ConstValueId::from_int(db, args[0].ty, &(&args[0].v % &args[1].v)),
898                    ],
899                    expr.ty,
900                )
901                .intern(db);
902            }
903            _ => {
904                unreachable!("Unexpected function call in constant lowering: {:?}", expr)
905            }
906        };
907        if expr.ty == self.felt252 {
908            ConstValue::Int(canonical_felt252(&value), expr.ty).intern(db)
909        } else if let Err(err) = validate_literal(db, expr.ty, &value) {
910            to_missing(
911                self.diagnostics
912                    .report(expr.stable_ptr.untyped(), SemanticDiagnosticKind::LiteralError(err)),
913            )
914        } else {
915            ConstValueId::from_int(db, expr.ty, &value)
916        }
917    }
918
919    /// Attempts to evaluate a constant function call.
920    fn evaluate_const_function_call(
921        &mut self,
922        concrete_function: &ConcreteFunction<'a>,
923        args: &[ConstValueId<'a>],
924        expr: &ExprFunctionCall<'a>,
925    ) -> Option<ConstValueId<'a>> {
926        let db = self.db;
927        if let GenericFunctionId::Extern(extern_fn) = concrete_function.generic_function {
928            let expr_ty = self.generic_substitution.substitute(db, expr.ty).ok()?;
929            if self.upcast_fns.contains(&extern_fn) {
930                let [arg] = args else { return None };
931                return Some(ConstValueId::from_int(db, expr_ty, arg.long(db).to_int()?));
932            } else if self.unwrap_non_zero == extern_fn {
933                let [arg] = args else { return None };
934                return try_extract_matches!(arg.long(db), ConstValue::NonZero).copied();
935            } else if self.u128s_from_felt252 == extern_fn {
936                let [arg] = args else { return None };
937                let TypeLongId::Concrete(ConcreteTypeId::Enum(enm)) = expr_ty.long(db) else {
938                    return None;
939                };
940                let (narrow, wide) =
941                    db.concrete_enum_variants(*enm).ok()?.into_iter().collect_tuple()?;
942                let value = felt252_for_downcast(arg.long(db).to_int()?, &BigInt::ZERO);
943                let mask128 = BigInt::from(u128::MAX);
944                let low = (&value) & mask128;
945                let high: BigInt = (&value) >> 128;
946                return Some(if high.is_zero() {
947                    ConstValue::Enum(narrow, ConstValue::Int(low, narrow.ty).intern(db)).intern(db)
948                } else {
949                    ConstValue::Enum(
950                        wide,
951                        ConstValue::Struct(
952                            vec![
953                                (ConstValue::Int(high, self.u128).intern(db)),
954                                (ConstValue::Int(low, self.u128).intern(db)),
955                            ],
956                            wide.ty,
957                        )
958                        .intern(db),
959                    )
960                    .intern(db)
961                });
962            } else if let Some(reversed) = self.downcast_fns.get(&extern_fn) {
963                let [arg] = args else { return None };
964                let ConstValue::Int(value, input_ty) = arg.long(db) else { return None };
965                let TypeLongId::Concrete(ConcreteTypeId::Enum(enm)) = expr_ty.long(db) else {
966                    return None;
967                };
968                let (variant0, variant1) =
969                    db.concrete_enum_variants(*enm).ok()?.into_iter().collect_tuple()?;
970                let (some, none) =
971                    if *reversed { (variant1, variant0) } else { (variant0, variant1) };
972                let success_ty = some.ty;
973                let out_range = self
974                    .type_value_ranges
975                    .get(&success_ty)
976                    .cloned()
977                    .or_else(|| {
978                        let (min, max) = try_extract_bounded_int_type_ranges(db, success_ty)?;
979                        Some(TypeRange::new(min, max))
980                    })
981                    .unwrap_or_else(|| {
982                        unreachable!(
983                            "`downcast` is only allowed into types that can be literals. Got `{}`.",
984                            success_ty.format(db)
985                        )
986                    });
987                let value = if *input_ty == self.felt252 {
988                    felt252_for_downcast(value, &out_range.min)
989                } else {
990                    value.clone()
991                };
992                return Some(if value >= out_range.min && value <= out_range.max {
993                    ConstValue::Enum(some, ConstValue::Int(value, success_ty).intern(db)).intern(db)
994                } else {
995                    ConstValue::Enum(none, self.unit_const).intern(db)
996                });
997            } else if self.nz_fns.contains(&extern_fn) {
998                let [arg] = args else { return None };
999                let (ty, is_zero) = match arg.long(db) {
1000                    ConstValue::Int(val, ty) => (ty, val.is_zero()),
1001                    ConstValue::Struct(members, ty) => (
1002                        ty,
1003                        // For u256 struct with (low, high), check if both are zero
1004                        members.iter().all(|member| match member.long(db) {
1005                            ConstValue::Int(val, _) => val.is_zero(),
1006                            _ => false,
1007                        }),
1008                    ),
1009                    _ => unreachable!(
1010                        "`is_zero` is only allowed for integers got `{}`",
1011                        arg.ty(db).unwrap().format(db)
1012                    ),
1013                };
1014
1015                return Some(
1016                    if is_zero {
1017                        ConstValue::Enum(
1018                            crate::corelib::jump_nz_zero_variant(db, *ty),
1019                            self.unit_const,
1020                        )
1021                    } else {
1022                        ConstValue::Enum(
1023                            crate::corelib::jump_nz_nonzero_variant(db, *ty),
1024                            ConstValue::NonZero(*arg).intern(db),
1025                        )
1026                    }
1027                    .intern(db),
1028                );
1029            } else {
1030                unreachable!(
1031                    "Unexpected extern function in constant lowering: `{}`",
1032                    extern_fn.full_path(db)
1033                );
1034            }
1035        }
1036        let body_id = concrete_function.body(db).ok()??;
1037        let concrete_body_id = body_id.function_with_body_id(db);
1038        let signature = db.function_with_body_signature(concrete_body_id).ok()?;
1039        require(signature.is_const)?;
1040        let generic_substitution = body_id.substitution(db).ok()?;
1041        let body = db.function_body(concrete_body_id).ok()?;
1042        const MAX_CONST_EVAL_DEPTH: usize = 100;
1043        if self.depth > MAX_CONST_EVAL_DEPTH {
1044            return Some(
1045                ConstValue::Missing(self.diagnostics.report(
1046                    expr.stable_ptr,
1047                    SemanticDiagnosticKind::ConstantCalculationDepthExceeded,
1048                ))
1049                .intern(db),
1050            );
1051        }
1052        let mut diagnostics = SemanticDiagnostics::default();
1053        let mut inner = ConstantEvaluateContext {
1054            db,
1055            info: self.info,
1056            arenas: &body.arenas,
1057            vars: signature
1058                .params
1059                .iter()
1060                .map(|p| VarId::Param(p.id))
1061                .zip(args.iter().cloned())
1062                .collect(),
1063            generic_substitution,
1064            depth: self.depth + 1,
1065            diagnostics: &mut diagnostics,
1066        };
1067        let value = inner.evaluate(body.body_expr);
1068        for diagnostic in diagnostics.build().get_all() {
1069            let location = diagnostic.location(db);
1070            let (inner_diag, mut notes) = match diagnostic.kind {
1071                SemanticDiagnosticKind::ConstantCalculationDepthExceeded => {
1072                    self.diagnostics.report(
1073                        expr.stable_ptr,
1074                        SemanticDiagnosticKind::ConstantCalculationDepthExceeded,
1075                    );
1076                    continue;
1077                }
1078                SemanticDiagnosticKind::InnerFailedConstantCalculation(inner_diag, notes) => {
1079                    (inner_diag, notes)
1080                }
1081                _ => (diagnostic.into(), vec![]),
1082            };
1083            notes.push(DiagnosticNote::with_location(
1084                format!("In `{}`", concrete_function.full_path(db)),
1085                location,
1086            ));
1087            self.diagnostics.report(
1088                expr.stable_ptr,
1089                SemanticDiagnosticKind::InnerFailedConstantCalculation(inner_diag, notes),
1090            );
1091        }
1092        Some(value)
1093    }
1094
1095    /// Extract const member access from a const value.
1096    fn evaluate_member_access(&mut self, expr: &ExprMemberAccess<'a>) -> Maybe<ConstValueId<'a>> {
1097        let full_struct = self.evaluate(expr.expr);
1098        let ConstValue::Struct(values, _) = full_struct.long(self.db) else {
1099            // A semantic diagnostic should have been reported.
1100            return Err(skip_diagnostic());
1101        };
1102        let members = self.db.concrete_struct_members(expr.concrete_struct_id)?;
1103        let Some(member_idx) = members.iter().position(|(_, member)| member.id == expr.member)
1104        else {
1105            // A semantic diagnostic should have been reported.
1106            return Err(skip_diagnostic());
1107        };
1108        Ok(values[member_idx])
1109    }
1110
1111    /// Destructures the pattern into the const value of the variables in scope.
1112    fn destructure_pattern(&mut self, pattern_id: PatternId, value: ConstValueId<'a>) {
1113        let pattern = &self.arenas.patterns[pattern_id];
1114        let db = self.db;
1115        match pattern {
1116            Pattern::Literal(_)
1117            | Pattern::StringLiteral(_)
1118            | Pattern::Otherwise(_)
1119            | Pattern::Missing(_) => {}
1120            Pattern::Variable(pattern) => {
1121                self.vars.insert(VarId::Local(pattern.var.id), value);
1122            }
1123            Pattern::Struct(pattern) => {
1124                if let ConstValue::Struct(inner_values, _) = value.long(db) {
1125                    let member_order = match db.concrete_struct_members(pattern.concrete_struct_id)
1126                    {
1127                        Ok(member_order) => member_order,
1128                        Err(_) => return,
1129                    };
1130                    for (member, inner_value) in zip(member_order.values(), inner_values) {
1131                        if let Some((inner_pattern, _)) =
1132                            pattern.field_patterns.iter().find(|(_, field)| member.id == field.id)
1133                        {
1134                            self.destructure_pattern(*inner_pattern, *inner_value);
1135                        }
1136                    }
1137                }
1138            }
1139            Pattern::Tuple(pattern) => {
1140                if let ConstValue::Struct(inner_values, _) = value.long(db) {
1141                    for (inner_pattern, inner_value) in zip(&pattern.field_patterns, inner_values) {
1142                        self.destructure_pattern(*inner_pattern, *inner_value);
1143                    }
1144                }
1145            }
1146            Pattern::FixedSizeArray(pattern) => {
1147                if let ConstValue::Struct(inner_values, _) = value.long(db) {
1148                    for (inner_pattern, inner_value) in
1149                        zip(&pattern.elements_patterns, inner_values)
1150                    {
1151                        self.destructure_pattern(*inner_pattern, *inner_value);
1152                    }
1153                }
1154            }
1155            Pattern::EnumVariant(pattern) => {
1156                if let ConstValue::Enum(variant, inner_value) = value.long(db)
1157                    && pattern.variant == *variant
1158                    && let Some(inner_pattern) = pattern.inner_pattern
1159                {
1160                    self.destructure_pattern(inner_pattern, *inner_value);
1161                }
1162            }
1163        }
1164    }
1165}
1166
1167impl<'db, 'r> std::ops::Deref for ConstantEvaluateContext<'db, 'r, '_> {
1168    type Target = ConstCalcInfo<'db>;
1169    fn deref(&self) -> &Self::Target {
1170        self.info
1171    }
1172}
1173
1174/// Helper for the arguments info.
1175struct NumericArg<'db> {
1176    /// The arg's integer value.
1177    v: BigInt,
1178    /// The arg's type.
1179    ty: TypeId<'db>,
1180}
1181impl<'db> NumericArg<'db> {
1182    fn try_new(db: &'db dyn Database, arg: ConstValueId<'db>) -> Option<Self> {
1183        Some(Self { ty: arg.ty(db).ok()?, v: numeric_arg_value(db, arg)? })
1184    }
1185}
1186
1187/// Helper for creating a `NumericArg` value.
1188/// This includes unwrapping of `NonZero` values and struct of 2 values as a `u256`.
1189fn numeric_arg_value<'db>(db: &'db dyn Database, value: ConstValueId<'db>) -> Option<BigInt> {
1190    match value.long(db) {
1191        ConstValue::Int(value, _) => Some(value.clone()),
1192        ConstValue::Struct(v, _) => {
1193            if let [low, high] = &v[..] {
1194                Some(low.long(db).to_int()? + (high.long(db).to_int()? << 128))
1195            } else {
1196                None
1197            }
1198        }
1199        ConstValue::NonZero(const_value) => numeric_arg_value(db, *const_value),
1200        _ => None,
1201    }
1202}
1203
1204/// Query implementation of [ConstantSemantic::const_calc_info].
1205fn const_calc_info<'db>(db: &'db dyn Database) -> Arc<ConstCalcInfo<'db>> {
1206    Arc::new(ConstCalcInfo::new(db))
1207}
1208
1209/// Implementation of [ConstantSemantic::const_calc_info].
1210#[salsa::tracked]
1211fn const_calc_info_tracked<'db>(db: &'db dyn Database) -> Arc<ConstCalcInfo<'db>> {
1212    const_calc_info(db)
1213}
1214
/// Holds static information about extern functions required for const calculations.
///
/// Besides the fields below, the fields of the wrapped [`CoreInfo`] are reachable directly
/// through the `Deref` implementation of this type.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub struct ConstCalcInfo<'db> {
    /// Traits that are allowed in consts if their impls are in the corelib.
    const_traits: UnorderedHashSet<TraitId<'db>>,
    /// The const value for the unit type `()`.
    unit_const: ConstValueId<'db>,
    /// The const value for `true`.
    true_const: ConstValueId<'db>,
    /// The const value for `false`.
    false_const: ConstValueId<'db>,
    /// The function for panicking with a felt252.
    panic_with_felt252: FunctionId<'db>,
    /// The integer `upcast` style functions.
    pub upcast_fns: UnorderedHashSet<ExternFunctionId<'db>>,
    /// The integer `downcast` style functions, mapping to whether it returns a reversed Option
    /// enum.
    pub downcast_fns: UnorderedHashMap<ExternFunctionId<'db>, bool>,
    /// The `felt252` into `u128` words libfunc.
    pub u128s_from_felt252: ExternFunctionId<'db>,
    /// The `unwrap_non_zero` function.
    unwrap_non_zero: ExternFunctionId<'db>,
    /// The `is_zero` style functions.
    pub nz_fns: UnorderedHashSet<ExternFunctionId<'db>>,
    /// The range of values of a numeric type.
    pub type_value_ranges: UnorderedHashMap<TypeId<'db>, TypeRange>,

    /// The shared core library info this type dereferences to.
    core_info: Arc<CoreInfo<'db>>,
}
1244
1245impl<'db> std::ops::Deref for ConstCalcInfo<'db> {
1246    type Target = CoreInfo<'db>;
1247    fn deref(&self) -> &CoreInfo<'db> {
1248        &self.core_info
1249    }
1250}
1251
1252impl<'db> ConstCalcInfo<'db> {
1253    /// Creates a new ConstCalcInfo.
1254    fn new(db: &'db dyn Database) -> Self {
1255        let core_info = db.core_info();
1256        let unit_const = ConstValue::Struct(vec![], unit_ty(db)).intern(db);
1257        let core = ModuleHelper::core(db);
1258        let bounded_int = core.submodule("internal").submodule("bounded_int");
1259        let integer = core.submodule("integer");
1260        let zeroable = core.submodule("zeroable");
1261        let starknet = core.submodule("starknet");
1262        let class_hash_module = starknet.submodule("class_hash");
1263        let class_hash_ty = class_hash_module.ty("ClassHash", vec![]);
1264        let contract_address_module = starknet.submodule("contract_address");
1265        let contract_address_ty = contract_address_module.ty("ContractAddress", vec![]);
1266        Self {
1267            const_traits: FromIterator::from_iter([
1268                core_info.neg_trt,
1269                core_info.add_trt,
1270                core_info.sub_trt,
1271                core_info.mul_trt,
1272                core_info.div_trt,
1273                core_info.rem_trt,
1274                core_info.div_rem_trt,
1275                core_info.bitand_trt,
1276                core_info.bitor_trt,
1277                core_info.bitxor_trt,
1278                core_info.partialeq_trt,
1279                core_info.partialord_trt,
1280                core_info.not_trt,
1281            ]),
1282            true_const: ConstValue::Enum(true_variant(db), unit_const).intern(db),
1283            false_const: ConstValue::Enum(false_variant(db), unit_const).intern(db),
1284            unit_const,
1285            panic_with_felt252: core.function_id("panic_with_felt252", vec![]),
1286            upcast_fns: FromIterator::from_iter([
1287                bounded_int.extern_function_id("upcast"),
1288                integer.extern_function_id("u8_to_felt252"),
1289                integer.extern_function_id("u16_to_felt252"),
1290                integer.extern_function_id("u32_to_felt252"),
1291                integer.extern_function_id("u64_to_felt252"),
1292                integer.extern_function_id("u128_to_felt252"),
1293                integer.extern_function_id("i8_to_felt252"),
1294                integer.extern_function_id("i16_to_felt252"),
1295                integer.extern_function_id("i32_to_felt252"),
1296                integer.extern_function_id("i64_to_felt252"),
1297                integer.extern_function_id("i128_to_felt252"),
1298                class_hash_module.extern_function_id("class_hash_to_felt252"),
1299                contract_address_module.extern_function_id("contract_address_to_felt252"),
1300            ]),
1301            downcast_fns: FromIterator::from_iter([
1302                (bounded_int.extern_function_id("downcast"), false),
1303                (bounded_int.extern_function_id("bounded_int_trim_min"), true),
1304                (bounded_int.extern_function_id("bounded_int_trim_max"), true),
1305                (integer.extern_function_id("u8_try_from_felt252"), false),
1306                (integer.extern_function_id("u16_try_from_felt252"), false),
1307                (integer.extern_function_id("u32_try_from_felt252"), false),
1308                (integer.extern_function_id("u64_try_from_felt252"), false),
1309                (integer.extern_function_id("i8_try_from_felt252"), false),
1310                (integer.extern_function_id("i16_try_from_felt252"), false),
1311                (integer.extern_function_id("i32_try_from_felt252"), false),
1312                (integer.extern_function_id("i64_try_from_felt252"), false),
1313                (integer.extern_function_id("i128_try_from_felt252"), false),
1314                (class_hash_module.extern_function_id("class_hash_try_from_felt252"), false),
1315                (
1316                    contract_address_module.extern_function_id("contract_address_try_from_felt252"),
1317                    false,
1318                ),
1319            ]),
1320            u128s_from_felt252: integer.extern_function_id("u128s_from_felt252"),
1321            unwrap_non_zero: zeroable.extern_function_id("unwrap_non_zero"),
1322            nz_fns: FromIterator::from_iter([
1323                core.extern_function_id("felt252_is_zero"),
1324                bounded_int.extern_function_id("bounded_int_is_zero"),
1325                integer.extern_function_id("u8_is_zero"),
1326                integer.extern_function_id("u16_is_zero"),
1327                integer.extern_function_id("u32_is_zero"),
1328                integer.extern_function_id("u64_is_zero"),
1329                integer.extern_function_id("u128_is_zero"),
1330                integer.extern_function_id("u256_is_zero"),
1331            ]),
1332            type_value_ranges: FromIterator::from_iter([
1333                (core_info.u8, TypeRange::new(u8::MIN, u8::MAX)),
1334                (core_info.u16, TypeRange::new(u16::MIN, u16::MAX)),
1335                (core_info.u32, TypeRange::new(u32::MIN, u32::MAX)),
1336                (core_info.u64, TypeRange::new(u64::MIN, u64::MAX)),
1337                (core_info.u128, TypeRange::new(u128::MIN, u128::MAX)),
1338                (core_info.u256, TypeRange::new(BigInt::ZERO, BigInt::from(1) << 256)),
1339                (core_info.i8, TypeRange::new(i8::MIN, i8::MAX)),
1340                (core_info.i16, TypeRange::new(i16::MIN, i16::MAX)),
1341                (core_info.i32, TypeRange::new(i32::MIN, i32::MAX)),
1342                (core_info.i64, TypeRange::new(i64::MIN, i64::MAX)),
1343                (core_info.i128, TypeRange::new(i128::MIN, i128::MAX)),
1344                (class_hash_ty, TypeRange::new(BigInt::ZERO, BigInt::from(1) << 251)),
1345                (contract_address_ty, TypeRange::new(BigInt::ZERO, BigInt::from(1) << 251)),
1346            ]),
1347            core_info,
1348        }
1349    }
1350}
1351
1352/// Trait for constant-related semantic queries.
1353pub trait ConstantSemantic<'db>: Database {
1354    /// Returns the semantic diagnostics of a constant definition.
1355    fn constant_semantic_diagnostics(
1356        &'db self,
1357        const_id: ConstantId<'db>,
1358    ) -> Diagnostics<'db, SemanticDiagnostic<'db>> {
1359        let db = self.as_dyn_database();
1360        constant_semantic_data(db, const_id, false)
1361            .as_ref()
1362            .map(|data| data.diagnostics.clone())
1363            .unwrap_or_default()
1364    }
1365    /// Returns the semantic data of a constant definition.
1366    fn constant_semantic_data(&'db self, use_id: ConstantId<'db>) -> Maybe<Constant<'db>> {
1367        let db = self.as_dyn_database();
1368        constant_semantic_data(db, use_id, false).maybe_as_ref()?.constant.clone()
1369    }
1370    /// Returns the resolver data of a constant definition.
1371    fn constant_resolver_data(&'db self, use_id: ConstantId<'db>) -> Maybe<Arc<ResolverData<'db>>> {
1372        let db = self.as_dyn_database();
1373        Ok(constant_semantic_data(db, use_id, false).maybe_as_ref()?.resolver_data.clone())
1374    }
1375    /// Returns the const value of a constant definition.
1376    fn constant_const_value(&'db self, const_id: ConstantId<'db>) -> Maybe<ConstValueId<'db>> {
1377        let db = self.as_dyn_database();
1378        Ok(constant_semantic_data(db, const_id, false).maybe_as_ref()?.const_value)
1379    }
1380    /// Returns information required for const calculations.
1381    fn const_calc_info(&'db self) -> Arc<ConstCalcInfo<'db>> {
1382        const_calc_info_tracked(self.as_dyn_database())
1383    }
1384}
1385impl<'db, T: Database + ?Sized> ConstantSemantic<'db> for T {}
1386
1387/// A range of values of a numeric type.
1388#[derive(Clone, Debug, PartialEq, Eq, salsa::Update)]
1389pub struct TypeRange {
1390    /// The minimum value of the range.
1391    pub min: BigInt,
1392    /// The maximum value of the range.
1393    pub max: BigInt,
1394}
1395impl TypeRange {
1396    pub fn new(min: impl Into<BigInt>, max: impl Into<BigInt>) -> Self {
1397        Self { min: min.into(), max: max.into() }
1398    }
1399}