cairo_lang_sierra_to_casm/invocations/mod.rs

1use assert_matches::assert_matches;
2use cairo_lang_casm::ap_change::ApChange;
3use cairo_lang_casm::builder::{CasmBuildResult, CasmBuilder, Var};
4use cairo_lang_casm::cell_expression::CellExpression;
5use cairo_lang_casm::instructions::Instruction;
6use cairo_lang_casm::operand::{CellRef, Register};
7use cairo_lang_sierra::extensions::circuit::CircuitInfo;
8use cairo_lang_sierra::extensions::core::CoreConcreteLibfunc::{self, *};
9use cairo_lang_sierra::extensions::coupon::CouponConcreteLibfunc;
10use cairo_lang_sierra::extensions::gas::CostTokenType;
11use cairo_lang_sierra::extensions::lib_func::{BranchSignature, OutputVarInfo, SierraApChange};
12use cairo_lang_sierra::extensions::{ConcreteLibfunc, OutputVarReferenceInfo};
13use cairo_lang_sierra::ids::ConcreteTypeId;
14use cairo_lang_sierra::program::{BranchInfo, BranchTarget, Invocation, StatementIdx};
15use cairo_lang_sierra_ap_change::core_libfunc_ap_change::{
16    InvocationApChangeInfoProvider, core_libfunc_ap_change,
17};
18use cairo_lang_sierra_gas::core_libfunc_cost::{InvocationCostInfoProvider, core_libfunc_cost};
19use cairo_lang_sierra_gas::objects::ConstCost;
20use cairo_lang_sierra_type_size::TypeSizeMap;
21use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
22use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
23use itertools::{Itertools, chain, zip_eq};
24use num_bigint::BigInt;
25use thiserror::Error;
26
27use crate::circuit::CircuitsInfo;
28use crate::environment::Environment;
29use crate::environment::frame_state::{FrameState, FrameStateError};
30use crate::metadata::Metadata;
31use crate::references::{
32    OutputReferenceValue, OutputReferenceValueIntroductionPoint, ReferenceExpression,
33    ReferenceValue,
34};
35use crate::relocations::{InstructionsWithRelocations, Relocation, RelocationEntry};
36
37mod array;
38mod bitwise;
39mod blake;
40mod boolean;
41mod boxing;
42mod bytes31;
43mod casts;
44mod circuit;
45mod const_type;
46mod debug;
47mod ec;
48pub mod enm;
49mod felt252;
50mod felt252_dict;
51mod function_call;
52mod gas;
53mod int;
54mod mem;
55mod misc;
56mod nullable;
57mod pedersen;
58mod poseidon;
59mod qm31;
60mod range;
61mod range_reduction;
62mod squashed_felt252_dict;
63mod starknet;
64mod structure;
65mod trace;
66
67#[cfg(test)]
68mod test_utils;
69
70#[derive(Error, Debug, Eq, PartialEq)]
71pub enum InvocationError {
72    #[error("One of the arguments does not satisfy the requirements of the libfunc.")]
73    InvalidReferenceExpressionForArgument,
74    #[error("Unexpected error - an unregistered type id used.")]
75    UnknownTypeId(ConcreteTypeId),
76    #[error("Expected a different number of arguments.")]
77    WrongNumberOfArguments { expected: usize, actual: usize },
78    #[error("The requested functionality is not implemented yet.")]
79    NotImplemented(Invocation),
80    #[error("The requested functionality is not implemented yet: {message}")]
81    NotImplementedStr { invocation: Invocation, message: String },
82    #[error("The functionality is supported only for sized types.")]
83    NotSized(Invocation),
84    #[error("Expected type data not found.")]
85    UnknownTypeData,
86    #[error("Expected variable data for statement not found.")]
87    UnknownVariableData,
88    #[error("An integer overflow occurred.")]
89    InvalidGenericArg,
90    #[error("Invalid generic argument for libfunc.")]
91    IntegerOverflow,
92    #[error(transparent)]
93    FrameStateError(#[from] FrameStateError),
94    // TODO(lior): Remove this error once not used.
95    #[error("This libfunc does not support pre-cost metadata yet.")]
96    PreCostMetadataNotSupported,
97    #[error("{output_ty} is not a contained in the circuit {circuit_ty}.")]
98    InvalidCircuitOutput { output_ty: ConcreteTypeId, circuit_ty: ConcreteTypeId },
99}
100
/// How a libfunc branch affects the ap tracking status itself.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ApTrackingChange {
    /// Turns ap tracking on, unless it is already on.
    Enable,
    /// Turns ap tracking off.
    Disable,
    /// Leaves the ap tracking status untouched.
    None,
}
111
112/// Describes the changes to the set of references at a single branch target, as well as changes to
113/// the environment.
114#[derive(Clone, Debug, Eq, PartialEq)]
115pub struct BranchChanges {
116    /// New references defined at a given branch.
117    /// should correspond to BranchInfo.results.
118    pub refs: Vec<OutputReferenceValue>,
119    /// The change to AP caused by the libfunc in the branch.
120    pub ap_change: ApChange,
121    /// A change to the ap tracking status.
122    pub ap_tracking_change: ApTrackingChange,
123    /// The change to the remaining gas value in the wallet.
124    pub gas_change: OrderedHashMap<CostTokenType, i64>,
125    /// Should the stack be cleared due to a gap between stack items.
126    pub clear_old_stack: bool,
127    /// The expected size of the known stack after the change.
128    pub new_stack_size: usize,
129}
130impl BranchChanges {
131    /// Creates a `BranchChanges` object.
132    /// `param_ref` is used to fetch the reference value of a param of the libfunc.
133    fn new<'a, ParamRef: Fn(usize) -> &'a ReferenceValue>(
134        ap_change: ApChange,
135        ap_tracking_change: ApTrackingChange,
136        gas_change: OrderedHashMap<CostTokenType, i64>,
137        expressions: impl ExactSizeIterator<Item = ReferenceExpression>,
138        branch_signature: &BranchSignature,
139        prev_env: &Environment,
140        param_ref: ParamRef,
141    ) -> Self {
142        assert_eq!(
143            expressions.len(),
144            branch_signature.vars.len(),
145            "The number of expressions does not match the number of expected results in the \
146             branch."
147        );
148        let clear_old_stack =
149            !matches!(&branch_signature.ap_change, SierraApChange::Known { new_vars_only: true });
150        let stack_base = if clear_old_stack { 0 } else { prev_env.stack_size };
151        let mut new_stack_size = stack_base;
152
153        let refs: Vec<_> = zip_eq(expressions, &branch_signature.vars)
154            .enumerate()
155            .map(|(output_idx, (expression, OutputVarInfo { ref_info, ty }))| {
156                validate_output_var_refs(ref_info, &expression);
157                let stack_idx =
158                    calc_output_var_stack_idx(ref_info, stack_base, clear_old_stack, &param_ref);
159                if let Some(stack_idx) = stack_idx {
160                    new_stack_size = new_stack_size.max(stack_idx + 1);
161                }
162                let introduction_point =
163                    if let OutputVarReferenceInfo::SameAsParam { param_idx } = ref_info {
164                        OutputReferenceValueIntroductionPoint::Existing(
165                            param_ref(*param_idx).introduction_point.clone(),
166                        )
167                    } else {
168                        // Marking the statement as unknown to be fixed later.
169                        OutputReferenceValueIntroductionPoint::New(output_idx)
170                    };
171                OutputReferenceValue { expression, ty: ty.clone(), stack_idx, introduction_point }
172            })
173            .collect();
174        validate_stack_top(ap_change, branch_signature, &refs);
175        Self { refs, ap_change, ap_tracking_change, gas_change, clear_old_stack, new_stack_size }
176    }
177}
178
179/// Validates that a new temp or local var have valid references in their matching expression.
180fn validate_output_var_refs(ref_info: &OutputVarReferenceInfo, expression: &ReferenceExpression) {
181    match ref_info {
182        OutputVarReferenceInfo::SameAsParam { .. } => {}
183        _ if expression.cells.is_empty() => {
184            assert_matches!(ref_info, OutputVarReferenceInfo::ZeroSized);
185        }
186        OutputVarReferenceInfo::ZeroSized => {
187            unreachable!("Non empty ReferenceExpression for zero sized variable.")
188        }
189        OutputVarReferenceInfo::NewTempVar { .. } => {
190            expression.cells.iter().for_each(|cell| {
191                assert_matches!(cell, CellExpression::Deref(CellRef { register: Register::AP, .. }))
192            });
193        }
194        OutputVarReferenceInfo::NewLocalVar => {
195            expression.cells.iter().for_each(|cell| {
196                assert_matches!(cell, CellExpression::Deref(CellRef { register: Register::FP, .. }))
197            });
198        }
199        OutputVarReferenceInfo::SimpleDerefs => {
200            expression
201                .cells
202                .iter()
203                .for_each(|cell| assert_matches!(cell, CellExpression::Deref(_)));
204        }
205        OutputVarReferenceInfo::PartialParam { .. } | OutputVarReferenceInfo::Deferred(_) => {}
206    };
207}
208
209/// Validates that the variables that are now on the top of the stack are contiguous and that if the
210/// stack was not broken the size of all the variables is consistent with the ap change.
211fn validate_stack_top(
212    ap_change: ApChange,
213    branch_signature: &BranchSignature,
214    refs: &[OutputReferenceValue],
215) {
216    // A mapping for the new temp vars allocated on the top of the stack from their index on the
217    // top of the stack to their index in the `refs` vector.
218    let stack_top_vars = UnorderedHashMap::<usize, usize>::from_iter(
219        branch_signature.vars.iter().enumerate().filter_map(|(arg_idx, var)| {
220            if let OutputVarReferenceInfo::NewTempVar { idx: stack_idx } = var.ref_info {
221                Some((stack_idx, arg_idx))
222            } else {
223                None
224            }
225        }),
226    );
227    let mut prev_ap_offset = None;
228    let mut stack_top_size = 0;
229    for i in 0..stack_top_vars.len() {
230        let Some(arg) = stack_top_vars.get(&i) else {
231            panic!("Missing top stack var #{i} out of {}.", stack_top_vars.len());
232        };
233        let cells = &refs[*arg].expression.cells;
234        stack_top_size += cells.len();
235        for cell in cells {
236            let ap_offset = match cell {
237                CellExpression::Deref(CellRef { register: Register::AP, offset }) => *offset,
238                _ => unreachable!("Tested in `validate_output_var_refs`."),
239            };
240            if let Some(prev_ap_offset) = prev_ap_offset {
241                assert_eq!(ap_offset, prev_ap_offset + 1, "Top stack vars are not contiguous.");
242            }
243            prev_ap_offset = Some(ap_offset);
244        }
245    }
246    if matches!(branch_signature.ap_change, SierraApChange::Known { new_vars_only: true }) {
247        assert_eq!(
248            ap_change,
249            ApChange::Known(stack_top_size),
250            "New tempvar variables are not contiguous with the old stack."
251        );
252    }
253    // TODO(orizi): Add assertion for the non-new_vars_only case, that it is optimal.
254}
255
256/// Calculates the continuous stack index for an output var of a branch.
257/// `param_ref` is used to fetch the reference value of a param of the libfunc.
258fn calc_output_var_stack_idx<'a, ParamRef: Fn(usize) -> &'a ReferenceValue>(
259    ref_info: &OutputVarReferenceInfo,
260    stack_base: usize,
261    clear_old_stack: bool,
262    param_ref: &ParamRef,
263) -> Option<usize> {
264    match ref_info {
265        OutputVarReferenceInfo::NewTempVar { idx } => Some(stack_base + idx),
266        OutputVarReferenceInfo::SameAsParam { param_idx } if !clear_old_stack => {
267            param_ref(*param_idx).stack_idx
268        }
269        OutputVarReferenceInfo::SameAsParam { .. }
270        | OutputVarReferenceInfo::SimpleDerefs
271        | OutputVarReferenceInfo::NewLocalVar
272        | OutputVarReferenceInfo::PartialParam { .. }
273        | OutputVarReferenceInfo::Deferred(_)
274        | OutputVarReferenceInfo::ZeroSized => None,
275    }
276}
277
278/// The result from a compilation of a single invocation statement.
279#[derive(Debug)]
280pub struct CompiledInvocation {
281    /// A vector of instructions that implement the invocation.
282    pub instructions: Vec<Instruction>,
283    /// A vector of static relocations.
284    pub relocations: Vec<RelocationEntry>,
285    /// A vector of BranchRefChanges, should correspond to the branches of the invocation
286    /// statement.
287    pub results: Vec<BranchChanges>,
288    /// The environment after the invocation statement.
289    pub environment: Environment,
290}
291
292/// Checks that the list of references is contiguous on the stack and ends at ap - 1.
293/// This is the requirement for function call and return statements.
294pub fn check_references_on_stack(refs: &[ReferenceValue]) -> Result<(), InvocationError> {
295    let mut expected_offset: i16 = -1;
296    for reference in refs.iter().rev() {
297        for cell_expr in reference.expression.cells.iter().rev() {
298            match cell_expr {
299                CellExpression::Deref(CellRef { register: Register::AP, offset })
300                    if *offset == expected_offset =>
301                {
302                    expected_offset -= 1;
303                }
304                _ => return Err(InvocationError::InvalidReferenceExpressionForArgument),
305            }
306        }
307    }
308    Ok(())
309}
310
311/// The cells per returned Sierra variables, in casm-builder vars.
312type VarCells = [Var];
313/// The configuration for all Sierra variables returned from a libfunc.
314type AllVars<'a> = [&'a VarCells];
315
316impl InvocationApChangeInfoProvider for CompiledInvocationBuilder<'_> {
317    fn type_size(&self, ty: &ConcreteTypeId) -> usize {
318        self.program_info.type_sizes[ty] as usize
319    }
320
321    fn token_usages(&self, token_type: CostTokenType) -> usize {
322        self.program_info
323            .metadata
324            .gas_info
325            .variable_values
326            .get(&(self.idx, token_type))
327            .copied()
328            .unwrap_or(0) as usize
329    }
330}
331
332impl InvocationCostInfoProvider for CompiledInvocationBuilder<'_> {
333    fn type_size(&self, ty: &ConcreteTypeId) -> usize {
334        self.program_info.type_sizes[ty] as usize
335    }
336
337    fn ap_change_var_value(&self) -> usize {
338        self.program_info
339            .metadata
340            .ap_change_info
341            .variable_values
342            .get(&self.idx)
343            .copied()
344            .unwrap_or_default()
345    }
346
347    fn token_usages(&self, token_type: CostTokenType) -> usize {
348        InvocationApChangeInfoProvider::token_usages(self, token_type)
349    }
350
351    fn circuit_info(&self, ty: &ConcreteTypeId) -> &CircuitInfo {
352        self.program_info.circuits_info.circuits.get(ty).unwrap()
353    }
354}
355
356/// Cost validation info for a builtin.
357struct BuiltinInfo {
358    /// The cost token type associated with the builtin.
359    cost_token_ty: CostTokenType,
360    /// The builtin pointer at the start of the libfunc.
361    start: Var,
362    /// The builtin pointer at the end of all the libfunc branches.
363    end: Var,
364}
365
366/// Information required for validating libfunc cost.
367#[derive(Default)]
368struct CostValidationInfo<const BRANCH_COUNT: usize> {
369    /// infos about builtin usage.
370    pub builtin_infos: Vec<BuiltinInfo>,
371    /// Possible extra cost per branch.
372    /// Useful for amortized costs, as well as gas withdrawal libfuncs.
373    pub extra_costs: Option<[i32; BRANCH_COUNT]>,
374}
375
376/// Helper for building compiled invocations.
377pub struct CompiledInvocationBuilder<'a> {
378    pub program_info: ProgramInfo<'a>,
379    pub invocation: &'a Invocation,
380    pub libfunc: &'a CoreConcreteLibfunc,
381    pub idx: StatementIdx,
382    /// The arguments of the libfunc.
383    pub refs: &'a [ReferenceValue],
384    pub environment: Environment,
385}
386impl CompiledInvocationBuilder<'_> {
387    /// Creates a new invocation.
388    fn build(
389        self,
390        instructions: Vec<Instruction>,
391        relocations: Vec<RelocationEntry>,
392        output_expressions: impl ExactSizeIterator<
393            Item = impl ExactSizeIterator<Item = ReferenceExpression>,
394        >,
395    ) -> CompiledInvocation {
396        let gas_changes =
397            core_libfunc_cost(&self.program_info.metadata.gas_info, &self.idx, self.libfunc, &self);
398
399        let branch_signatures = self.libfunc.branch_signatures();
400        assert_eq!(
401            branch_signatures.len(),
402            output_expressions.len(),
403            "The number of output expressions does not match signature."
404        );
405        let ap_changes = core_libfunc_ap_change(self.libfunc, &self);
406        assert_eq!(
407            branch_signatures.len(),
408            ap_changes.len(),
409            "The number of ap changes does not match signature."
410        );
411        assert_eq!(
412            branch_signatures.len(),
413            gas_changes.len(),
414            "The number of gas changes does not match signature."
415        );
416
417        CompiledInvocation {
418            instructions,
419            relocations,
420            results: zip_eq(
421                zip_eq(branch_signatures, gas_changes),
422                zip_eq(output_expressions, ap_changes),
423            )
424            .map(|((branch_signature, gas_change), (expressions, ap_change))| {
425                let ap_tracking_change = match ap_change {
426                    cairo_lang_sierra_ap_change::ApChange::EnableApTracking => {
427                        ApTrackingChange::Enable
428                    }
429                    cairo_lang_sierra_ap_change::ApChange::DisableApTracking => {
430                        ApTrackingChange::Disable
431                    }
432                    _ => ApTrackingChange::None,
433                };
434                let ap_change = match ap_change {
435                    cairo_lang_sierra_ap_change::ApChange::Known(x) => ApChange::Known(x),
436                    cairo_lang_sierra_ap_change::ApChange::AtLocalsFinalization(_)
437                    | cairo_lang_sierra_ap_change::ApChange::EnableApTracking
438                    | cairo_lang_sierra_ap_change::ApChange::DisableApTracking => {
439                        ApChange::Known(0)
440                    }
441                    cairo_lang_sierra_ap_change::ApChange::FinalizeLocals => {
442                        if let FrameState::Finalized { allocated } = self.environment.frame_state {
443                            ApChange::Known(allocated)
444                        } else {
445                            panic!("Unexpected frame state.")
446                        }
447                    }
448                    cairo_lang_sierra_ap_change::ApChange::FunctionCall(id) => self
449                        .program_info
450                        .metadata
451                        .ap_change_info
452                        .function_ap_change
453                        .get(&id)
454                        .map_or(ApChange::Unknown, |x| ApChange::Known(x + 2)),
455                    cairo_lang_sierra_ap_change::ApChange::FromMetadata => ApChange::Known(
456                        *self
457                            .program_info
458                            .metadata
459                            .ap_change_info
460                            .variable_values
461                            .get(&self.idx)
462                            .unwrap_or(&0),
463                    ),
464                    cairo_lang_sierra_ap_change::ApChange::Unknown => ApChange::Unknown,
465                };
466
467                BranchChanges::new(
468                    ap_change,
469                    ap_tracking_change,
470                    gas_change.iter().map(|(token_type, val)| (*token_type, -val)).collect(),
471                    expressions,
472                    branch_signature,
473                    &self.environment,
474                    |idx| &self.refs[idx],
475                )
476            })
477            .collect(),
478            environment: self.environment,
479        }
480    }
481
482    /// Builds a `CompiledInvocation` from a casm builder and branch extractions.
483    /// Per branch requires `(name, result_variables, target_statement_id)`.
484    fn build_from_casm_builder<const BRANCH_COUNT: usize>(
485        self,
486        casm_builder: CasmBuilder,
487        branch_extractions: [(&str, &AllVars<'_>, Option<StatementIdx>); BRANCH_COUNT],
488        cost_validation: CostValidationInfo<BRANCH_COUNT>,
489    ) -> CompiledInvocation {
490        self.build_from_casm_builder_ex(
491            casm_builder,
492            branch_extractions,
493            cost_validation,
494            Default::default(),
495        )
496    }
497
498    /// Builds a `CompiledInvocation` from a casm builder and branch extractions.
499    /// Per branch requires `(name, result_variables, target_statement_id)`.
500    ///
501    /// `pre_instructions` - Instructions to execute before the ones created by the builder.
502    fn build_from_casm_builder_ex<const BRANCH_COUNT: usize>(
503        self,
504        casm_builder: CasmBuilder,
505        branch_extractions: [(&str, &AllVars<'_>, Option<StatementIdx>); BRANCH_COUNT],
506        cost_validation: CostValidationInfo<BRANCH_COUNT>,
507        pre_instructions: InstructionsWithRelocations,
508    ) -> CompiledInvocation {
509        let CasmBuildResult { instructions, branches } =
510            casm_builder.build(branch_extractions.map(|(name, _, _)| name));
511        let expected_ap_changes = core_libfunc_ap_change(self.libfunc, &self);
512        let actual_ap_changes = branches
513            .iter()
514            .map(|(state, _)| cairo_lang_sierra_ap_change::ApChange::Known(state.ap_change));
515        if !itertools::equal(expected_ap_changes.iter().cloned(), actual_ap_changes.clone()) {
516            panic!(
517                "Wrong ap changes for {}. Expected: {expected_ap_changes:?}, actual: {:?}.",
518                self.invocation,
519                actual_ap_changes.collect_vec(),
520            );
521        }
522        let gas_changes =
523            core_libfunc_cost(&self.program_info.metadata.gas_info, &self.idx, self.libfunc, &self)
524                .into_iter()
525                .map(|costs| costs.get(&CostTokenType::Const).copied().unwrap_or_default());
526        let mut final_costs: [ConstCost; BRANCH_COUNT] =
527            std::array::from_fn(|_| Default::default());
528        for (cost, (state, _)) in final_costs.iter_mut().zip(branches.iter()) {
529            cost.steps += state.steps as i32;
530        }
531
532        for BuiltinInfo { cost_token_ty, start, end } in cost_validation.builtin_infos {
533            for (cost, (state, _)) in final_costs.iter_mut().zip(branches.iter()) {
534                let (start_base, start_offset) =
535                    state.get_adjusted(start).to_deref_with_offset().unwrap();
536                let (end_base, end_offset) =
537                    state.get_adjusted(end).to_deref_with_offset().unwrap();
538                assert_eq!(start_base, end_base);
539                let diff = end_offset - start_offset;
540                match cost_token_ty {
541                    CostTokenType::RangeCheck => {
542                        cost.range_checks += diff;
543                    }
544                    CostTokenType::RangeCheck96 => {
545                        cost.range_checks96 += diff;
546                    }
547                    _ => panic!("Cost token type not supported."),
548                }
549            }
550        }
551
552        let extra_costs =
553            cost_validation.extra_costs.unwrap_or(std::array::from_fn(|_| Default::default()));
554        let final_costs_with_extra =
555            final_costs.iter().zip(extra_costs).map(|(final_cost, extra)| {
556                (final_cost.cost() + extra + pre_instructions.cost.cost()) as i64
557            });
558        if !itertools::equal(gas_changes.clone(), final_costs_with_extra.clone()) {
559            panic!(
560                "Wrong costs for {}. Expected: {:?}, actual: {:?}, Costs from casm_builder: {:?}.",
561                self.invocation,
562                gas_changes.collect_vec(),
563                final_costs_with_extra.collect_vec(),
564                final_costs,
565            );
566        }
567        let branch_relocations = branches.iter().zip_eq(branch_extractions.iter()).flat_map(
568            |((_, relocations), (_, _, target))| {
569                assert_eq!(
570                    relocations.is_empty(),
571                    target.is_none(),
572                    "No relocations if nowhere to relocate to."
573                );
574                relocations.iter().map(|idx| RelocationEntry {
575                    instruction_idx: pre_instructions.instructions.len() + *idx,
576                    relocation: Relocation::RelativeStatementId(target.unwrap()),
577                })
578            },
579        );
580        let relocations = chain!(pre_instructions.relocations, branch_relocations).collect();
581        let output_expressions =
582            zip_eq(branches, branch_extractions).map(|((state, _), (_, vars, _))| {
583                vars.iter().map(move |var_cells| ReferenceExpression {
584                    cells: var_cells.iter().map(|cell| state.get_adjusted(*cell)).collect(),
585                })
586            });
587        self.build(
588            chain!(pre_instructions.instructions, instructions).collect(),
589            relocations,
590            output_expressions,
591        )
592    }
593
594    /// Creates a new invocation with only reference changes.
595    fn build_only_reference_changes(
596        self,
597        output_expressions: impl ExactSizeIterator<Item = ReferenceExpression>,
598    ) -> CompiledInvocation {
599        self.build(vec![], vec![], [output_expressions].into_iter())
600    }
601
602    /// Returns the reference expressions if the size is correct.
603    pub fn try_get_refs<const COUNT: usize>(
604        &self,
605    ) -> Result<[&ReferenceExpression; COUNT], InvocationError> {
606        if self.refs.len() == COUNT {
607            Ok(core::array::from_fn(|i| &self.refs[i].expression))
608        } else {
609            Err(InvocationError::WrongNumberOfArguments {
610                expected: COUNT,
611                actual: self.refs.len(),
612            })
613        }
614    }
615
616    /// Returns the reference expressions, assuming all contains one cell if the size is correct.
617    pub fn try_get_single_cells<const COUNT: usize>(
618        &self,
619    ) -> Result<[&CellExpression; COUNT], InvocationError> {
620        let refs = self.try_get_refs::<COUNT>()?;
621        let mut last_err = None;
622        const FAKE_CELL: CellExpression =
623            CellExpression::Deref(CellRef { register: Register::AP, offset: 0 });
624        // TODO(orizi): Use `refs.try_map` once it is a stable feature.
625        let result = refs.map(|r| match r.try_unpack_single() {
626            Ok(cell) => cell,
627            Err(err) => {
628                last_err = Some(err);
629                &FAKE_CELL
630            }
631        });
632        if let Some(err) = last_err { Err(err) } else { Ok(result) }
633    }
634}
635
636/// Information in the program level required for compiling an invocation.
637pub struct ProgramInfo<'a> {
638    pub metadata: &'a Metadata,
639    pub type_sizes: &'a TypeSizeMap,
640    /// Information about the circuits in the program.
641    pub circuits_info: &'a CircuitsInfo,
642    /// Returns the given a const type returns a vector of cells value representing it.
643    pub const_data_values: &'a dyn Fn(&ConcreteTypeId) -> Vec<BigInt>,
644}
645
646/// Given a Sierra invocation statement and concrete libfunc, creates a compiled casm representation
647/// of the Sierra statement.
648pub fn compile_invocation(
649    program_info: ProgramInfo<'_>,
650    invocation: &Invocation,
651    libfunc: &CoreConcreteLibfunc,
652    idx: StatementIdx,
653    refs: &[ReferenceValue],
654    environment: Environment,
655) -> Result<CompiledInvocation, InvocationError> {
656    let builder =
657        CompiledInvocationBuilder { program_info, invocation, libfunc, idx, refs, environment };
658    match libfunc {
659        Felt252(libfunc) => felt252::build(libfunc, builder),
660        Felt252SquashedDict(libfunc) => squashed_felt252_dict::build(libfunc, builder),
661        Bool(libfunc) => boolean::build(libfunc, builder),
662        Cast(libfunc) => casts::build(libfunc, builder),
663        Ec(libfunc) => ec::build(libfunc, builder),
664        Uint8(libfunc) => int::unsigned::build_uint::<_, 0x100>(libfunc, builder),
665        Uint16(libfunc) => int::unsigned::build_uint::<_, 0x10000>(libfunc, builder),
666        Uint32(libfunc) => int::unsigned::build_uint::<_, 0x100000000>(libfunc, builder),
667        Uint64(libfunc) => int::unsigned::build_uint::<_, 0x10000000000000000>(libfunc, builder),
668        Uint128(libfunc) => int::unsigned128::build(libfunc, builder),
669        Uint256(libfunc) => int::unsigned256::build(libfunc, builder),
670        Uint512(libfunc) => int::unsigned512::build(libfunc, builder),
671        Sint8(libfunc) => {
672            int::signed::build_sint::<_, { i8::MIN as i128 }, { i8::MAX as i128 }>(libfunc, builder)
673        }
674        Sint16(libfunc) => {
675            int::signed::build_sint::<_, { i16::MIN as i128 }, { i16::MAX as i128 }>(
676                libfunc, builder,
677            )
678        }
679        Sint32(libfunc) => {
680            int::signed::build_sint::<_, { i32::MIN as i128 }, { i32::MAX as i128 }>(
681                libfunc, builder,
682            )
683        }
684        Sint64(libfunc) => {
685            int::signed::build_sint::<_, { i64::MIN as i128 }, { i64::MAX as i128 }>(
686                libfunc, builder,
687            )
688        }
689        Sint128(libfunc) => int::signed128::build(libfunc, builder),
690        Gas(libfunc) => gas::build(libfunc, builder),
691        BranchAlign(_) => misc::build_branch_align(builder),
692        Array(libfunc) => array::build(libfunc, builder),
693        Drop(_) => misc::build_drop(builder),
694        Dup(_) => misc::build_dup(builder),
695        Mem(libfunc) => mem::build(libfunc, builder),
696        UnwrapNonZero(_) => misc::build_identity(builder),
697        FunctionCall(libfunc) | CouponCall(libfunc) => function_call::build(libfunc, builder),
698        UnconditionalJump(_) => misc::build_jump(builder),
699        ApTracking(_) => misc::build_update_ap_tracking(builder),
700        Box(libfunc) => boxing::build(libfunc, builder),
701        Enum(libfunc) => enm::build(libfunc, builder),
702        Struct(libfunc) => structure::build(libfunc, builder),
703        Felt252Dict(libfunc) => felt252_dict::build_dict(libfunc, builder),
704        Pedersen(libfunc) => pedersen::build(libfunc, builder),
705        Poseidon(libfunc) => poseidon::build(libfunc, builder),
706        Starknet(libfunc) => starknet::build(libfunc, builder),
707        Nullable(libfunc) => nullable::build(libfunc, builder),
708        Debug(libfunc) => debug::build(libfunc, builder),
709        SnapshotTake(_) => misc::build_dup(builder),
710        Felt252DictEntry(libfunc) => felt252_dict::build_entry(libfunc, builder),
711        Bytes31(libfunc) => bytes31::build(libfunc, builder),
712        Const(libfunc) => const_type::build(libfunc, builder),
713        Coupon(libfunc) => match libfunc {
714            CouponConcreteLibfunc::Buy(_) => Ok(builder
715                .build_only_reference_changes([ReferenceExpression::zero_sized()].into_iter())),
716            CouponConcreteLibfunc::Refund(_) => {
717                Ok(builder.build_only_reference_changes([].into_iter()))
718            }
719        },
720        BoundedInt(libfunc) => int::bounded::build(libfunc, builder),
721        Circuit(libfunc) => circuit::build(libfunc, builder),
722        IntRange(libfunc) => range::build(libfunc, builder),
723        Blake(libfunc) => blake::build(libfunc, builder),
724        Trace(libfunc) => trace::build(libfunc, builder),
725        QM31(libfunc) => qm31::build(libfunc, builder),
726    }
727}
728
729/// A trait for views of the Complex ReferenceExpressions as specific data structures (e.g.
730/// enum/array).
731trait ReferenceExpressionView: Sized {
732    type Error;
733    /// Extracts the specific view from the reference expressions. Can include validations and thus
734    /// returns a result.
735    /// `concrete_type_id` - the concrete type this view should represent.
736    fn try_get_view(
737        expr: &ReferenceExpression,
738        program_info: &ProgramInfo<'_>,
739        concrete_type_id: &ConcreteTypeId,
740    ) -> Result<Self, Self::Error>;
741    /// Converts the view into a ReferenceExpression.
742    fn to_reference_expression(self) -> ReferenceExpression;
743}
744
745/// Fetches the non-fallthrough jump target of the invocation, assuming this invocation is a
746/// conditional jump.
747pub fn get_non_fallthrough_statement_id(builder: &CompiledInvocationBuilder<'_>) -> StatementIdx {
748    match builder.invocation.branches.as_slice() {
749        [
750            BranchInfo { target: BranchTarget::Fallthrough, results: _ },
751            BranchInfo { target: BranchTarget::Statement(target_statement_id), results: _ },
752        ] => *target_statement_id,
753        _ => panic!("malformed invocation"),
754    }
755}
756
/// Adds the given libfunc arguments as variables of a `CasmBuilder`, checking that each one has
/// the requested form.
///
/// Every entry has the form `<kind> <var>;`, where `<kind>` is `deref`, `deref_or_immediate` or
/// `buffer(<slack>)`. Each `<var>` is rebound to the newly added builder variable. If an argument
/// does not fit its kind, the enclosing function returns
/// `InvocationError::InvalidReferenceExpressionForArgument`.
macro_rules! add_input_variables {
    // No entries left.
    ($casm_builder:ident,) => {};
    // A single deref cell.
    ($casm_builder:ident, deref $var:ident; $($tok:tt)*) => {
        let $var = {
            let cell =
                $var.to_deref().ok_or(InvocationError::InvalidReferenceExpressionForArgument)?;
            $casm_builder.add_var(cairo_lang_casm::cell_expression::CellExpression::Deref(cell))
        };
        $crate::invocations::add_input_variables!($casm_builder, $($tok)*)
    };
    // Either a deref cell or an immediate value.
    ($casm_builder:ident, deref_or_immediate $var:ident; $($tok:tt)*) => {
        let $var = {
            let operand = $var
                .to_deref_or_immediate()
                .ok_or(InvocationError::InvalidReferenceExpressionForArgument)?;
            $casm_builder.add_var(match operand {
                cairo_lang_casm::operand::DerefOrImmediate::Deref(cell) => {
                    cairo_lang_casm::cell_expression::CellExpression::Deref(cell)
                }
                cairo_lang_casm::operand::DerefOrImmediate::Immediate(imm) => {
                    cairo_lang_casm::cell_expression::CellExpression::Immediate(imm.value)
                }
            })
        };
        $crate::invocations::add_input_variables!($casm_builder, $($tok)*)
    };
    // A buffer, created with the given slack.
    ($casm_builder:ident, buffer($slack:expr) $var:ident; $($tok:tt)*) => {
        let $var = {
            let buffer = $var
                .to_buffer($slack)
                .ok_or(InvocationError::InvalidReferenceExpressionForArgument)?;
            $casm_builder.add_var(buffer)
        };
        $crate::invocations::add_input_variables!($casm_builder, $($tok)*)
    };
}
use add_input_variables;