cairo_lang_sierra_to_casm/invocations/
mod.rs

1use assert_matches::assert_matches;
2use cairo_lang_casm::ap_change::ApChange;
3use cairo_lang_casm::builder::{CasmBuildResult, CasmBuilder, Var};
4use cairo_lang_casm::cell_expression::CellExpression;
5use cairo_lang_casm::instructions::Instruction;
6use cairo_lang_casm::operand::{CellRef, Register};
7use cairo_lang_sierra::extensions::circuit::CircuitInfo;
8use cairo_lang_sierra::extensions::core::CoreConcreteLibfunc::{self, *};
9use cairo_lang_sierra::extensions::coupon::CouponConcreteLibfunc;
10use cairo_lang_sierra::extensions::gas::{CostTokenMap, CostTokenType};
11use cairo_lang_sierra::extensions::lib_func::{BranchSignature, OutputVarInfo, SierraApChange};
12use cairo_lang_sierra::extensions::{ConcreteLibfunc, OutputVarReferenceInfo};
13use cairo_lang_sierra::ids::ConcreteTypeId;
14use cairo_lang_sierra::program::{BranchInfo, BranchTarget, Invocation, StatementIdx};
15use cairo_lang_sierra_ap_change::core_libfunc_ap_change::{
16    InvocationApChangeInfoProvider, core_libfunc_ap_change,
17};
18use cairo_lang_sierra_gas::core_libfunc_cost::{InvocationCostInfoProvider, core_libfunc_cost};
19use cairo_lang_sierra_gas::objects::ConstCost;
20use cairo_lang_sierra_type_size::TypeSizeMap;
21use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
22use itertools::{Itertools, chain, zip_eq};
23use num_bigint::BigInt;
24use thiserror::Error;
25
26use crate::circuit::CircuitsInfo;
27use crate::environment::Environment;
28use crate::environment::frame_state::{FrameState, FrameStateError};
29use crate::metadata::Metadata;
30use crate::references::{
31    OutputReferenceValue, OutputReferenceValueIntroductionPoint, ReferenceExpression,
32    ReferenceValue,
33};
34use crate::relocations::{InstructionsWithRelocations, Relocation, RelocationEntry};
35
36mod array;
37mod bitwise;
38mod blake;
39mod boolean;
40mod boxing;
41mod bytes31;
42mod casts;
43mod circuit;
44mod const_type;
45mod debug;
46mod ec;
47pub mod enm;
48mod felt252;
49mod felt252_dict;
50mod function_call;
51mod gas;
52mod gas_reserve;
53mod int;
54mod mem;
55mod misc;
56mod nullable;
57mod pedersen;
58mod poseidon;
59mod qm31;
60mod range;
61mod range_reduction;
62mod squashed_felt252_dict;
63mod starknet;
64mod structure;
65mod trace;
66mod unsafe_panic;
67
68#[cfg(test)]
69mod test_utils;
70
71#[derive(Error, Debug, Eq, PartialEq)]
72pub enum InvocationError {
73    #[error("One of the arguments does not satisfy the requirements of the libfunc.")]
74    InvalidReferenceExpressionForArgument,
75    #[error("Unexpected error - an unregistered type id used.")]
76    UnknownTypeId(ConcreteTypeId),
77    #[error("Expected a different number of arguments.")]
78    WrongNumberOfArguments { expected: usize, actual: usize },
79    #[error("The requested functionality is not implemented yet.")]
80    NotImplemented(Invocation),
81    #[error("The requested functionality is not implemented yet: {message}")]
82    NotImplementedStr { invocation: Invocation, message: String },
83    #[error("The functionality is supported only for sized types.")]
84    NotSized(Invocation),
85    #[error("Expected type data not found.")]
86    UnknownTypeData,
87    #[error("Expected variable data for statement not found.")]
88    UnknownVariableData,
89    #[error("Invalid generic argument for libfunc.")]
90    InvalidGenericArg,
91    #[error("An integer overflow occurred.")]
92    IntegerOverflow,
93    #[error(transparent)]
94    FrameStateError(#[from] FrameStateError),
95    // TODO(lior): Remove this error once not used.
96    #[error("This libfunc does not support pre-cost metadata yet.")]
97    PreCostMetadataNotSupported,
98    #[error("{output_ty} is not a contained in the circuit {circuit_ty}.")]
99    InvalidCircuitOutput { output_ty: ConcreteTypeId, circuit_ty: ConcreteTypeId },
100}
101
/// Describes a simple change in the ap tracking itself.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ApTrackingChange {
    /// Enables the tracking if it is not already enabled.
    Enable,
    /// Disables the tracking.
    Disable,
    /// Leaves the tracking status unchanged.
    None,
}
112
113/// Describes the changes to the set of references at a single branch target, as well as changes to
114/// the environment.
115#[derive(Clone, Debug, Eq, PartialEq)]
116pub struct BranchChanges {
117    /// New references defined at a given branch.
118    /// should correspond to BranchInfo.results.
119    pub refs: Vec<OutputReferenceValue>,
120    /// The change to AP caused by the libfunc in the branch.
121    pub ap_change: ApChange,
122    /// A change to the ap tracking status.
123    pub ap_tracking_change: ApTrackingChange,
124    /// The change to the remaining gas value in the wallet.
125    pub gas_change: CostTokenMap<i64>,
126    /// Should the stack be cleared due to a gap between stack items.
127    pub clear_old_stack: bool,
128    /// The expected size of the known stack after the change.
129    pub new_stack_size: usize,
130}
131impl BranchChanges {
132    /// Creates a `BranchChanges` object.
133    /// `param_ref` is used to fetch the reference value of a param of the libfunc.
134    fn new<'a, ParamRef: Fn(usize) -> &'a ReferenceValue>(
135        ap_change: ApChange,
136        ap_tracking_change: ApTrackingChange,
137        gas_change: CostTokenMap<i64>,
138        expressions: impl ExactSizeIterator<Item = ReferenceExpression>,
139        branch_signature: &BranchSignature,
140        prev_env: &Environment,
141        param_ref: ParamRef,
142    ) -> Self {
143        assert_eq!(
144            expressions.len(),
145            branch_signature.vars.len(),
146            "The number of expressions does not match the number of expected results in the \
147             branch."
148        );
149        let clear_old_stack =
150            !matches!(&branch_signature.ap_change, SierraApChange::Known { new_vars_only: true });
151        let stack_base = if clear_old_stack { 0 } else { prev_env.stack_size };
152        let mut new_stack_size = stack_base;
153
154        let refs: Vec<_> = zip_eq(expressions, &branch_signature.vars)
155            .enumerate()
156            .map(|(output_idx, (expression, OutputVarInfo { ref_info, ty }))| {
157                validate_output_var_refs(ref_info, &expression);
158                let stack_idx =
159                    calc_output_var_stack_idx(ref_info, stack_base, clear_old_stack, &param_ref);
160                if let Some(stack_idx) = stack_idx {
161                    new_stack_size = new_stack_size.max(stack_idx + 1);
162                }
163                let introduction_point =
164                    if let OutputVarReferenceInfo::SameAsParam { param_idx } = ref_info {
165                        OutputReferenceValueIntroductionPoint::Existing(
166                            param_ref(*param_idx).introduction_point.clone(),
167                        )
168                    } else {
169                        // Marking the statement as unknown to be fixed later.
170                        OutputReferenceValueIntroductionPoint::New(output_idx)
171                    };
172                OutputReferenceValue { expression, ty: ty.clone(), stack_idx, introduction_point }
173            })
174            .collect();
175        validate_stack_top(ap_change, branch_signature, &refs);
176        Self { refs, ap_change, ap_tracking_change, gas_change, clear_old_stack, new_stack_size }
177    }
178}
179
180/// Validates that a new temp or local var have valid references in their matching expression.
181fn validate_output_var_refs(ref_info: &OutputVarReferenceInfo, expression: &ReferenceExpression) {
182    match ref_info {
183        OutputVarReferenceInfo::SameAsParam { .. } => {}
184        _ if expression.cells.is_empty() => {
185            assert_matches!(ref_info, OutputVarReferenceInfo::ZeroSized);
186        }
187        OutputVarReferenceInfo::ZeroSized => {
188            unreachable!("Non empty ReferenceExpression for zero sized variable.")
189        }
190        OutputVarReferenceInfo::NewTempVar { .. } => {
191            expression.cells.iter().for_each(|cell| {
192                assert_matches!(cell, CellExpression::Deref(CellRef { register: Register::AP, .. }))
193            });
194        }
195        OutputVarReferenceInfo::NewLocalVar => {
196            expression.cells.iter().for_each(|cell| {
197                assert_matches!(cell, CellExpression::Deref(CellRef { register: Register::FP, .. }))
198            });
199        }
200        OutputVarReferenceInfo::SimpleDerefs => {
201            expression
202                .cells
203                .iter()
204                .for_each(|cell| assert_matches!(cell, CellExpression::Deref(_)));
205        }
206        OutputVarReferenceInfo::PartialParam { .. } | OutputVarReferenceInfo::Deferred(_) => {}
207    };
208}
209
210/// Validates that the variables that are now on the top of the stack are contiguous and that if the
211/// stack was not broken the size of all the variables is consistent with the ap change.
212fn validate_stack_top(
213    ap_change: ApChange,
214    branch_signature: &BranchSignature,
215    refs: &[OutputReferenceValue],
216) {
217    // A mapping for the new temp vars allocated on the top of the stack from their index on the
218    // top of the stack to their index in the `refs` vector.
219    let stack_top_vars = UnorderedHashMap::<usize, usize>::from_iter(
220        branch_signature.vars.iter().enumerate().filter_map(|(arg_idx, var)| {
221            if let OutputVarReferenceInfo::NewTempVar { idx: stack_idx } = var.ref_info {
222                Some((stack_idx, arg_idx))
223            } else {
224                None
225            }
226        }),
227    );
228    let mut prev_ap_offset = None;
229    let mut stack_top_size = 0;
230    for i in 0..stack_top_vars.len() {
231        let Some(arg) = stack_top_vars.get(&i) else {
232            panic!("Missing top stack var #{i} out of {}.", stack_top_vars.len());
233        };
234        let cells = &refs[*arg].expression.cells;
235        stack_top_size += cells.len();
236        for cell in cells {
237            let ap_offset = match cell {
238                CellExpression::Deref(CellRef { register: Register::AP, offset }) => *offset,
239                _ => unreachable!("Tested in `validate_output_var_refs`."),
240            };
241            if let Some(prev_ap_offset) = prev_ap_offset {
242                assert_eq!(ap_offset, prev_ap_offset + 1, "Top stack vars are not contiguous.");
243            }
244            prev_ap_offset = Some(ap_offset);
245        }
246    }
247    if matches!(branch_signature.ap_change, SierraApChange::Known { new_vars_only: true }) {
248        assert_eq!(
249            ap_change,
250            ApChange::Known(stack_top_size),
251            "New tempvar variables are not contiguous with the old stack."
252        );
253    }
254    // TODO(orizi): Add assertion for the non-new_vars_only case, that it is optimal.
255}
256
257/// Calculates the continuous stack index for an output var of a branch.
258/// `param_ref` is used to fetch the reference value of a param of the libfunc.
259fn calc_output_var_stack_idx<'a, ParamRef: Fn(usize) -> &'a ReferenceValue>(
260    ref_info: &OutputVarReferenceInfo,
261    stack_base: usize,
262    clear_old_stack: bool,
263    param_ref: &ParamRef,
264) -> Option<usize> {
265    match ref_info {
266        OutputVarReferenceInfo::NewTempVar { idx } => Some(stack_base + idx),
267        OutputVarReferenceInfo::SameAsParam { param_idx } if !clear_old_stack => {
268            param_ref(*param_idx).stack_idx
269        }
270        OutputVarReferenceInfo::SameAsParam { .. }
271        | OutputVarReferenceInfo::SimpleDerefs
272        | OutputVarReferenceInfo::NewLocalVar
273        | OutputVarReferenceInfo::PartialParam { .. }
274        | OutputVarReferenceInfo::Deferred(_)
275        | OutputVarReferenceInfo::ZeroSized => None,
276    }
277}
278
279/// The result from a compilation of a single invocation statement.
280#[derive(Debug)]
281pub struct CompiledInvocation {
282    /// A vector of instructions that implement the invocation.
283    pub instructions: Vec<Instruction>,
284    /// A vector of static relocations.
285    pub relocations: Vec<RelocationEntry>,
286    /// A vector of BranchRefChanges, should correspond to the branches of the invocation
287    /// statement.
288    pub results: Vec<BranchChanges>,
289    /// The environment after the invocation statement.
290    pub environment: Environment,
291}
292
293/// Checks that the list of references is contiguous on the stack and ends at ap - 1.
294/// This is the requirement for function call and return statements.
295pub fn check_references_on_stack(refs: &[ReferenceValue]) -> Result<(), InvocationError> {
296    let mut expected_offset: i16 = -1;
297    for reference in refs.iter().rev() {
298        for cell_expr in reference.expression.cells.iter().rev() {
299            match cell_expr {
300                CellExpression::Deref(CellRef { register: Register::AP, offset })
301                    if *offset == expected_offset =>
302                {
303                    expected_offset -= 1;
304                }
305                _ => return Err(InvocationError::InvalidReferenceExpressionForArgument),
306            }
307        }
308    }
309    Ok(())
310}
311
312/// The cells per returned Sierra variables, in CASM-builder vars.
313type VarCells = [Var];
314/// The configuration for all Sierra variables returned from a libfunc.
315type AllVars<'a> = [&'a VarCells];
316
317impl InvocationApChangeInfoProvider for CompiledInvocationBuilder<'_> {
318    fn type_size(&self, ty: &ConcreteTypeId) -> usize {
319        self.program_info.type_sizes[ty] as usize
320    }
321
322    fn token_usages(&self, token_type: CostTokenType) -> usize {
323        self.program_info
324            .metadata
325            .gas_info
326            .variable_values
327            .get(&(self.idx, token_type))
328            .copied()
329            .unwrap_or(0) as usize
330    }
331}
332
333impl InvocationCostInfoProvider for CompiledInvocationBuilder<'_> {
334    fn type_size(&self, ty: &ConcreteTypeId) -> usize {
335        self.program_info.type_sizes[ty] as usize
336    }
337
338    fn ap_change_var_value(&self) -> usize {
339        self.program_info
340            .metadata
341            .ap_change_info
342            .variable_values
343            .get(&self.idx)
344            .copied()
345            .unwrap_or_default()
346    }
347
348    fn token_usages(&self, token_type: CostTokenType) -> usize {
349        InvocationApChangeInfoProvider::token_usages(self, token_type)
350    }
351
352    fn circuit_info(&self, ty: &ConcreteTypeId) -> &CircuitInfo {
353        self.program_info.circuits_info.circuits.get(ty).unwrap()
354    }
355}
356
357/// Cost validation info for a builtin.
358struct BuiltinInfo {
359    /// The cost token type associated with the builtin.
360    cost_token_ty: CostTokenType,
361    /// The builtin pointer at the start of the libfunc.
362    start: Var,
363    /// The builtin pointer at the end of all the libfunc branches.
364    end: Var,
365}
366
367/// Information required for validating libfunc cost.
368#[derive(Default)]
369struct CostValidationInfo<const BRANCH_COUNT: usize> {
370    /// infos about builtin usage.
371    pub builtin_infos: Vec<BuiltinInfo>,
372    /// Possible extra cost per branch.
373    /// Useful for amortized costs, as well as gas withdrawal libfuncs.
374    pub extra_costs: Option<[i32; BRANCH_COUNT]>,
375}
376
377/// Helper for building compiled invocations.
378pub struct CompiledInvocationBuilder<'a> {
379    pub program_info: ProgramInfo<'a>,
380    pub invocation: &'a Invocation,
381    pub libfunc: &'a CoreConcreteLibfunc,
382    pub idx: StatementIdx,
383    /// The arguments of the libfunc.
384    pub refs: &'a [ReferenceValue],
385    pub environment: Environment,
386}
387impl CompiledInvocationBuilder<'_> {
388    /// Creates a new invocation.
389    fn build(
390        self,
391        instructions: Vec<Instruction>,
392        relocations: Vec<RelocationEntry>,
393        output_expressions: impl ExactSizeIterator<
394            Item = impl ExactSizeIterator<Item = ReferenceExpression>,
395        >,
396    ) -> CompiledInvocation {
397        let gas_changes =
398            core_libfunc_cost(&self.program_info.metadata.gas_info, &self.idx, self.libfunc, &self);
399
400        let branch_signatures = self.libfunc.branch_signatures();
401        assert_eq!(
402            branch_signatures.len(),
403            output_expressions.len(),
404            "The number of output expressions does not match signature."
405        );
406        let ap_changes = core_libfunc_ap_change(self.libfunc, &self);
407        assert_eq!(
408            branch_signatures.len(),
409            ap_changes.len(),
410            "The number of ap changes does not match signature."
411        );
412        assert_eq!(
413            branch_signatures.len(),
414            gas_changes.len(),
415            "The number of gas changes does not match signature."
416        );
417
418        CompiledInvocation {
419            instructions,
420            relocations,
421            results: zip_eq(
422                zip_eq(branch_signatures, gas_changes),
423                zip_eq(output_expressions, ap_changes),
424            )
425            .map(|((branch_signature, gas_change), (expressions, ap_change))| {
426                let ap_tracking_change = match ap_change {
427                    cairo_lang_sierra_ap_change::ApChange::EnableApTracking => {
428                        ApTrackingChange::Enable
429                    }
430                    cairo_lang_sierra_ap_change::ApChange::DisableApTracking => {
431                        ApTrackingChange::Disable
432                    }
433                    _ => ApTrackingChange::None,
434                };
435                let ap_change = match ap_change {
436                    cairo_lang_sierra_ap_change::ApChange::Known(x) => ApChange::Known(x),
437                    cairo_lang_sierra_ap_change::ApChange::AtLocalsFinalization(_)
438                    | cairo_lang_sierra_ap_change::ApChange::EnableApTracking
439                    | cairo_lang_sierra_ap_change::ApChange::DisableApTracking => {
440                        ApChange::Known(0)
441                    }
442                    cairo_lang_sierra_ap_change::ApChange::FinalizeLocals => {
443                        if let FrameState::Finalized { allocated } = self.environment.frame_state {
444                            ApChange::Known(allocated)
445                        } else {
446                            panic!("Unexpected frame state.")
447                        }
448                    }
449                    cairo_lang_sierra_ap_change::ApChange::FunctionCall(id) => self
450                        .program_info
451                        .metadata
452                        .ap_change_info
453                        .function_ap_change
454                        .get(&id)
455                        .map_or(ApChange::Unknown, |x| ApChange::Known(x + 2)),
456                    cairo_lang_sierra_ap_change::ApChange::FromMetadata => ApChange::Known(
457                        *self
458                            .program_info
459                            .metadata
460                            .ap_change_info
461                            .variable_values
462                            .get(&self.idx)
463                            .unwrap_or(&0),
464                    ),
465                    cairo_lang_sierra_ap_change::ApChange::Unknown => ApChange::Unknown,
466                };
467
468                BranchChanges::new(
469                    ap_change,
470                    ap_tracking_change,
471                    gas_change.iter().map(|(token_type, val)| (*token_type, -val)).collect(),
472                    expressions,
473                    branch_signature,
474                    &self.environment,
475                    |idx| &self.refs[idx],
476                )
477            })
478            .collect(),
479            environment: self.environment,
480        }
481    }
482
483    /// Builds a `CompiledInvocation` from a CASM builder and branch extractions.
484    /// Per branch requires `(name, result_variables, target_statement_id)`.
485    fn build_from_casm_builder<const BRANCH_COUNT: usize>(
486        self,
487        casm_builder: CasmBuilder,
488        branch_extractions: [(&str, &AllVars<'_>, Option<StatementIdx>); BRANCH_COUNT],
489        cost_validation: CostValidationInfo<BRANCH_COUNT>,
490    ) -> CompiledInvocation {
491        self.build_from_casm_builder_ex(
492            casm_builder,
493            branch_extractions,
494            cost_validation,
495            Default::default(),
496        )
497    }
498
499    /// Builds a `CompiledInvocation` from a CASM builder and branch extractions.
500    /// Per branch requires `(name, result_variables, target_statement_id)`.
501    ///
502    /// `pre_instructions` - Instructions to execute before the ones created by the builder.
503    fn build_from_casm_builder_ex<const BRANCH_COUNT: usize>(
504        self,
505        casm_builder: CasmBuilder,
506        branch_extractions: [(&str, &AllVars<'_>, Option<StatementIdx>); BRANCH_COUNT],
507        cost_validation: CostValidationInfo<BRANCH_COUNT>,
508        pre_instructions: InstructionsWithRelocations,
509    ) -> CompiledInvocation {
510        let CasmBuildResult { instructions, branches } =
511            casm_builder.build(branch_extractions.map(|(name, _, _)| name));
512        let expected_ap_changes = core_libfunc_ap_change(self.libfunc, &self);
513        let actual_ap_changes = branches
514            .iter()
515            .map(|(state, _)| cairo_lang_sierra_ap_change::ApChange::Known(state.ap_change));
516        if !itertools::equal(expected_ap_changes.iter().cloned(), actual_ap_changes.clone()) {
517            panic!(
518                "Wrong ap changes for {}. Expected: {expected_ap_changes:?}, actual: {:?}.",
519                self.invocation,
520                actual_ap_changes.collect_vec(),
521            );
522        }
523        let gas_changes =
524            core_libfunc_cost(&self.program_info.metadata.gas_info, &self.idx, self.libfunc, &self)
525                .into_iter()
526                .map(|costs| costs.get(&CostTokenType::Const).copied().unwrap_or_default());
527        let mut final_costs: [ConstCost; BRANCH_COUNT] =
528            std::array::from_fn(|_| Default::default());
529        for (cost, (state, _)) in final_costs.iter_mut().zip(branches.iter()) {
530            cost.steps += state.steps as i32;
531        }
532
533        for BuiltinInfo { cost_token_ty, start, end } in cost_validation.builtin_infos {
534            for (cost, (state, _)) in final_costs.iter_mut().zip(branches.iter()) {
535                let (start_base, start_offset) =
536                    state.get_adjusted(start).to_deref_with_offset().unwrap();
537                let (end_base, end_offset) =
538                    state.get_adjusted(end).to_deref_with_offset().unwrap();
539                assert_eq!(start_base, end_base);
540                let diff = end_offset - start_offset;
541                match cost_token_ty {
542                    CostTokenType::RangeCheck => {
543                        cost.range_checks += diff;
544                    }
545                    CostTokenType::RangeCheck96 => {
546                        cost.range_checks96 += diff;
547                    }
548                    _ => panic!("Cost token type not supported."),
549                }
550            }
551        }
552
553        let extra_costs =
554            cost_validation.extra_costs.unwrap_or(std::array::from_fn(|_| Default::default()));
555        let final_costs_with_extra =
556            final_costs.iter().zip(extra_costs).map(|(final_cost, extra)| {
557                (final_cost.cost() + extra + pre_instructions.cost.cost()) as i64
558            });
559        if !itertools::equal(gas_changes.clone(), final_costs_with_extra.clone()) {
560            panic!(
561                "Wrong costs for {}. Expected: {:?}, actual: {:?}, Costs from casm_builder: {:?}.",
562                self.invocation,
563                gas_changes.collect_vec(),
564                final_costs_with_extra.collect_vec(),
565                final_costs,
566            );
567        }
568        let branch_relocations = branches.iter().zip_eq(branch_extractions.iter()).flat_map(
569            |((_, relocations), (_, _, target))| {
570                assert_eq!(
571                    relocations.is_empty(),
572                    target.is_none(),
573                    "No relocations if nowhere to relocate to."
574                );
575                relocations.iter().map(|idx| RelocationEntry {
576                    instruction_idx: pre_instructions.instructions.len() + *idx,
577                    relocation: Relocation::RelativeStatementId(target.unwrap()),
578                })
579            },
580        );
581        let relocations = chain!(pre_instructions.relocations, branch_relocations).collect();
582        let output_expressions =
583            zip_eq(branches, branch_extractions).map(|((state, _), (_, vars, _))| {
584                vars.iter().map(move |var_cells| ReferenceExpression {
585                    cells: var_cells.iter().map(|cell| state.get_adjusted(*cell)).collect(),
586                })
587            });
588        self.build(
589            chain!(pre_instructions.instructions, instructions).collect(),
590            relocations,
591            output_expressions,
592        )
593    }
594
595    /// Creates a new invocation with only reference changes.
596    fn build_only_reference_changes(
597        self,
598        output_expressions: impl ExactSizeIterator<Item = ReferenceExpression>,
599    ) -> CompiledInvocation {
600        self.build(vec![], vec![], [output_expressions].into_iter())
601    }
602
603    /// Returns the reference expressions if the size is correct.
604    pub fn try_get_refs<const COUNT: usize>(
605        &self,
606    ) -> Result<[&ReferenceExpression; COUNT], InvocationError> {
607        if self.refs.len() == COUNT {
608            Ok(core::array::from_fn(|i| &self.refs[i].expression))
609        } else {
610            Err(InvocationError::WrongNumberOfArguments {
611                expected: COUNT,
612                actual: self.refs.len(),
613            })
614        }
615    }
616
617    /// Returns the reference expressions, assuming all contains one cell if the size is correct.
618    pub fn try_get_single_cells<const COUNT: usize>(
619        &self,
620    ) -> Result<[&CellExpression; COUNT], InvocationError> {
621        let refs = self.try_get_refs::<COUNT>()?;
622        let mut last_err = None;
623        const FAKE_CELL: CellExpression =
624            CellExpression::Deref(CellRef { register: Register::AP, offset: 0 });
625        // TODO(orizi): Use `refs.try_map` once it is a stable feature.
626        let result = refs.map(|r| match r.try_unpack_single() {
627            Ok(cell) => cell,
628            Err(err) => {
629                last_err = Some(err);
630                &FAKE_CELL
631            }
632        });
633        if let Some(err) = last_err { Err(err) } else { Ok(result) }
634    }
635}
636
637/// Information in the program level required for compiling an invocation.
638pub struct ProgramInfo<'a> {
639    pub metadata: &'a Metadata,
640    pub type_sizes: &'a TypeSizeMap,
641    /// Information about the circuits in the program.
642    pub circuits_info: &'a CircuitsInfo,
643    /// Returns the given a const type returns a vector of cells value representing it.
644    pub const_data_values: &'a dyn Fn(&ConcreteTypeId) -> Vec<BigInt>,
645}
646
647/// Given a Sierra invocation statement and concrete libfunc, creates a compiled CASM representation
648/// of the Sierra statement.
649pub fn compile_invocation(
650    program_info: ProgramInfo<'_>,
651    invocation: &Invocation,
652    libfunc: &CoreConcreteLibfunc,
653    idx: StatementIdx,
654    refs: &[ReferenceValue],
655    environment: Environment,
656) -> Result<CompiledInvocation, InvocationError> {
657    let builder =
658        CompiledInvocationBuilder { program_info, invocation, libfunc, idx, refs, environment };
659    match libfunc {
660        Felt252(libfunc) => felt252::build(libfunc, builder),
661        Felt252SquashedDict(libfunc) => squashed_felt252_dict::build(libfunc, builder),
662        Bool(libfunc) => boolean::build(libfunc, builder),
663        Cast(libfunc) => casts::build(libfunc, builder),
664        Ec(libfunc) => ec::build(libfunc, builder),
665        Uint8(libfunc) => int::unsigned::build_uint::<_, 0x100>(libfunc, builder),
666        Uint16(libfunc) => int::unsigned::build_uint::<_, 0x10000>(libfunc, builder),
667        Uint32(libfunc) => int::unsigned::build_uint::<_, 0x100000000>(libfunc, builder),
668        Uint64(libfunc) => int::unsigned::build_uint::<_, 0x10000000000000000>(libfunc, builder),
669        Uint128(libfunc) => int::unsigned128::build(libfunc, builder),
670        Uint256(libfunc) => int::unsigned256::build(libfunc, builder),
671        Uint512(libfunc) => int::unsigned512::build(libfunc, builder),
672        Sint8(libfunc) => {
673            int::signed::build_sint::<_, { i8::MIN as i128 }, { i8::MAX as i128 }>(libfunc, builder)
674        }
675        Sint16(libfunc) => {
676            int::signed::build_sint::<_, { i16::MIN as i128 }, { i16::MAX as i128 }>(
677                libfunc, builder,
678            )
679        }
680        Sint32(libfunc) => {
681            int::signed::build_sint::<_, { i32::MIN as i128 }, { i32::MAX as i128 }>(
682                libfunc, builder,
683            )
684        }
685        Sint64(libfunc) => {
686            int::signed::build_sint::<_, { i64::MIN as i128 }, { i64::MAX as i128 }>(
687                libfunc, builder,
688            )
689        }
690        Sint128(libfunc) => int::signed128::build(libfunc, builder),
691        Gas(libfunc) => gas::build(libfunc, builder),
692        GasReserve(libfunc) => gas_reserve::build(libfunc, builder),
693        BranchAlign(_) => misc::build_branch_align(builder),
694        Array(libfunc) => array::build(libfunc, builder),
695        Drop(_) => misc::build_drop(builder),
696        Dup(_) => misc::build_dup(builder),
697        Mem(libfunc) => mem::build(libfunc, builder),
698        UnwrapNonZero(_) => misc::build_identity(builder),
699        FunctionCall(libfunc) | CouponCall(libfunc) => function_call::build(libfunc, builder, true),
700        DummyFunctionCall(libfunc) => function_call::build(libfunc, builder, false),
701        UnconditionalJump(_) => misc::build_jump(builder),
702        ApTracking(_) => misc::build_update_ap_tracking(builder),
703        Box(libfunc) => boxing::build(libfunc, builder),
704        Enum(libfunc) => enm::build(libfunc, builder),
705        Struct(libfunc) => structure::build(libfunc, builder),
706        Felt252Dict(libfunc) => felt252_dict::build_dict(libfunc, builder),
707        Pedersen(libfunc) => pedersen::build(libfunc, builder),
708        Poseidon(libfunc) => poseidon::build(libfunc, builder),
709        Starknet(libfunc) => starknet::build(libfunc, builder),
710        Nullable(libfunc) => nullable::build(libfunc, builder),
711        Debug(libfunc) => debug::build(libfunc, builder),
712        SnapshotTake(_) => misc::build_dup(builder),
713        Felt252DictEntry(libfunc) => felt252_dict::build_entry(libfunc, builder),
714        Bytes31(libfunc) => bytes31::build(libfunc, builder),
715        Const(libfunc) => const_type::build(libfunc, builder),
716        Coupon(libfunc) => match libfunc {
717            CouponConcreteLibfunc::Buy(_) => Ok(builder
718                .build_only_reference_changes([ReferenceExpression::zero_sized()].into_iter())),
719            CouponConcreteLibfunc::Refund(_) => {
720                Ok(builder.build_only_reference_changes([].into_iter()))
721            }
722        },
723        BoundedInt(libfunc) => int::bounded::build(libfunc, builder),
724        Circuit(libfunc) => circuit::build(libfunc, builder),
725        IntRange(libfunc) => range::build(libfunc, builder),
726        Blake(libfunc) => blake::build(libfunc, builder),
727        Trace(libfunc) => trace::build(libfunc, builder),
728        QM31(libfunc) => qm31::build(libfunc, builder),
729        UnsafePanic(_) => unsafe_panic::build(builder),
730    }
731}
732
733/// A trait for views of the Complex ReferenceExpressions as specific data structures (e.g.
734/// enum/array).
735trait ReferenceExpressionView: Sized {
736    type Error;
737    /// Extracts the specific view from the reference expressions. Can include validations and thus
738    /// returns a result.
739    /// `concrete_type_id` - the concrete type this view should represent.
740    fn try_get_view(
741        expr: &ReferenceExpression,
742        program_info: &ProgramInfo<'_>,
743        concrete_type_id: &ConcreteTypeId,
744    ) -> Result<Self, Self::Error>;
745    /// Converts the view into a ReferenceExpression.
746    fn to_reference_expression(self) -> ReferenceExpression;
747}
748
749/// Fetches the non-fallthrough jump target of the invocation, assuming this invocation is a
750/// conditional jump.
751pub fn get_non_fallthrough_statement_id(builder: &CompiledInvocationBuilder<'_>) -> StatementIdx {
752    match builder.invocation.branches.as_slice() {
753        [
754            BranchInfo { target: BranchTarget::Fallthrough, results: _ },
755            BranchInfo { target: BranchTarget::Statement(target_statement_id), results: _ },
756        ] => *target_statement_id,
757        _ => panic!("malformed invocation"),
758    }
759}
760
/// Adds input variables into the builder while validating their type.
///
/// Takes the CASM builder, followed by a `;` terminated list of `<kind> <var>` items. Each `var`
/// is rebound to the builder variable created for it. The supported kinds are `deref`,
/// `deref_or_immediate` and `buffer(<slack>)`.
/// An argument that does not fit the requested kind causes an early return (using `?`) of
/// `InvocationError::InvalidReferenceExpressionForArgument` from the enclosing function.
macro_rules! add_input_variables {
    // No more variables to add.
    ($builder:ident,) => {};
    // The variable must be a deref.
    ($builder:ident, deref $var:ident; $($rest:tt)*) => {
        let $var =
            $var.to_deref().ok_or(InvocationError::InvalidReferenceExpressionForArgument)?;
        let $var =
            $builder.add_var(cairo_lang_casm::cell_expression::CellExpression::Deref($var));
        $crate::invocations::add_input_variables!($builder, $($rest)*)
    };
    // The variable must be either a deref or an immediate.
    ($builder:ident, deref_or_immediate $var:ident; $($rest:tt)*) => {
        let $var = $var
            .to_deref_or_immediate()
            .ok_or(InvocationError::InvalidReferenceExpressionForArgument)?;
        let $var = $builder.add_var(match $var {
            cairo_lang_casm::operand::DerefOrImmediate::Deref(cell) => {
                cairo_lang_casm::cell_expression::CellExpression::Deref(cell)
            }
            cairo_lang_casm::operand::DerefOrImmediate::Immediate(imm) => {
                cairo_lang_casm::cell_expression::CellExpression::Immediate(imm.value)
            }
        });
        $crate::invocations::add_input_variables!($builder, $($rest)*)
    };
    // The variable must be convertible into a buffer with the given slack.
    ($builder:ident, buffer($slack:expr) $var:ident; $($rest:tt)*) => {
        let $var = $var
            .to_buffer($slack)
            .ok_or(InvocationError::InvalidReferenceExpressionForArgument)?;
        let $var = $builder.add_var($var);
        $crate::invocations::add_input_variables!($builder, $($rest)*)
    };
}
use add_input_variables;