cairo_lang_sierra_to_casm/
annotations.rs

1use cairo_lang_casm::ap_change::{ApChangeError, ApplyApChange};
2use cairo_lang_sierra::edit_state::{put_results, take_args};
3use cairo_lang_sierra::ids::{ConcreteTypeId, FunctionId, VarId};
4use cairo_lang_sierra::program::{BranchInfo, Function, StatementIdx};
5use cairo_lang_sierra_type_size::TypeSizeMap;
6use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
7use itertools::{chain, zip_eq};
8use thiserror::Error;
9
10use crate::environment::ap_tracking::update_ap_tracking;
11use crate::environment::frame_state::FrameStateError;
12use crate::environment::gas_wallet::{GasWallet, GasWalletError};
13use crate::environment::{
14    ApTracking, ApTrackingBase, Environment, EnvironmentError, validate_environment_equality,
15    validate_final_environment,
16};
17use crate::invocations::{ApTrackingChange, BranchChanges};
18use crate::metadata::Metadata;
19use crate::references::{
20    IntroductionPoint, OutputReferenceValueIntroductionPoint, ReferenceExpression, ReferenceValue,
21    ReferencesError, StatementRefs, build_function_parameters_refs, check_types_match,
22};
23
24#[derive(Error, Debug, Eq, PartialEq)]
25pub enum AnnotationError {
26    #[error("#{statement_idx}: Inconsistent references annotations: {error}")]
27    InconsistentReferencesAnnotation {
28        statement_idx: StatementIdx,
29        error: InconsistentReferenceError,
30    },
31    #[error("#{source_statement_idx}->#{destination_statement_idx}: Annotation was already set.")]
32    AnnotationAlreadySet {
33        source_statement_idx: StatementIdx,
34        destination_statement_idx: StatementIdx,
35    },
36    #[error("#{statement_idx}: {error}")]
37    InconsistentEnvironments { statement_idx: StatementIdx, error: EnvironmentError },
38    #[error("#{statement_idx}: Belongs to two different functions.")]
39    InconsistentFunctionId { statement_idx: StatementIdx },
40    #[error("#{statement_idx}: Invalid convergence.")]
41    InvalidConvergence { statement_idx: StatementIdx },
42    #[error("InvalidStatementIdx")]
43    InvalidStatementIdx,
44    #[error("MissingAnnotationsForStatement")]
45    MissingAnnotationsForStatement(StatementIdx),
46    #[error("#{statement_idx}: {var_id} is undefined.")]
47    MissingReferenceError { statement_idx: StatementIdx, var_id: VarId },
48    #[error("#{source_statement_idx}->#{destination_statement_idx}: {var_id} was overridden.")]
49    OverrideReferenceError {
50        source_statement_idx: StatementIdx,
51        destination_statement_idx: StatementIdx,
52        var_id: VarId,
53    },
54    #[error(transparent)]
55    FrameStateError(#[from] FrameStateError),
56    #[error("#{source_statement_idx}->#{destination_statement_idx}: {error}")]
57    GasWalletError {
58        source_statement_idx: StatementIdx,
59        destination_statement_idx: StatementIdx,
60        error: GasWalletError,
61    },
62    #[error("#{statement_idx}: {error}")]
63    ReferencesError { statement_idx: StatementIdx, error: ReferencesError },
64    #[error("#{statement_idx}: Attempting to enable ap tracking when already enabled.")]
65    ApTrackingAlreadyEnabled { statement_idx: StatementIdx },
66    #[error(
67        "#{source_statement_idx}->#{destination_statement_idx}: Got '{error}' error while moving \
68         {var_id} introduced at {introduction_point}."
69    )]
70    ApChangeError {
71        var_id: VarId,
72        source_statement_idx: StatementIdx,
73        destination_statement_idx: StatementIdx,
74        introduction_point: IntroductionPoint,
75        error: ApChangeError,
76    },
77    #[error("#{source_statement_idx} -> #{destination_statement_idx}: Ap tracking error")]
78    ApTrackingError {
79        source_statement_idx: StatementIdx,
80        destination_statement_idx: StatementIdx,
81        error: ApChangeError,
82    },
83    #[error(
84        "#{statement_idx}: Invalid function ap change annotation. Expected ap tracking: \
85         {expected:?}, got: {actual:?}."
86    )]
87    InvalidFunctionApChange {
88        statement_idx: StatementIdx,
89        expected: ApTracking,
90        actual: ApTracking,
91    },
92}
93
94impl AnnotationError {
95    pub fn stmt_indices(&self) -> Vec<StatementIdx> {
96        match self {
97            AnnotationError::ApChangeError {
98                source_statement_idx,
99                destination_statement_idx,
100                introduction_point,
101                ..
102            } => chain!(
103                [source_statement_idx, destination_statement_idx],
104                &introduction_point.source_statement_idx,
105                [&introduction_point.destination_statement_idx]
106            )
107            .cloned()
108            .collect(),
109            _ => vec![],
110        }
111    }
112}
113
114/// Error representing an inconsistency in the references annotations.
115#[derive(Error, Debug, Eq, PartialEq)]
116pub enum InconsistentReferenceError {
117    #[error("Variable {var} type mismatch. Expected `{expected}`, got `{actual}`.")]
118    TypeMismatch { var: VarId, expected: ConcreteTypeId, actual: ConcreteTypeId },
119    #[error("Variable {var} expression mismatch. Expected `{expected}`, got `{actual}`.")]
120    ExpressionMismatch { var: VarId, expected: ReferenceExpression, actual: ReferenceExpression },
121    #[error("Variable {var} stack index mismatch. Expected `{expected:?}`, got `{actual:?}`.")]
122    StackIndexMismatch { var: VarId, expected: Option<usize>, actual: Option<usize> },
123    #[error("Variable {var} introduction point mismatch. Expected `{expected}`, got `{actual}`.")]
124    IntroductionPointMismatch { var: VarId, expected: IntroductionPoint, actual: IntroductionPoint },
125    #[error("Variable count mismatch.")]
126    VariableCountMismatch,
127    #[error("Missing expected variable {0}.")]
128    VariableMissing(VarId),
129    #[error("Ap tracking is disabled while trying to merge {0}.")]
130    ApTrackingDisabled(VarId),
131}
132
133/// Annotation that represents the state at each program statement.
134#[derive(Clone, Debug)]
135pub struct StatementAnnotations {
136    pub refs: StatementRefs,
137    /// The function id that the statement belongs to.
138    pub function_id: FunctionId,
139    /// Indicates whether convergence is allowed in the given statement.
140    pub convergence_allowed: bool,
141    pub environment: Environment,
142}
143
144/// Annotations of the program statements.
145/// See StatementAnnotations.
146pub struct ProgramAnnotations {
147    /// Optional per statement annotation.
148    per_statement_annotations: Vec<Option<StatementAnnotations>>,
149    /// The indices of the statements that are the targets of backwards jumps.
150    backwards_jump_indices: UnorderedHashSet<StatementIdx>,
151}
152impl ProgramAnnotations {
153    fn new(n_statements: usize, backwards_jump_indices: UnorderedHashSet<StatementIdx>) -> Self {
154        ProgramAnnotations {
155            per_statement_annotations: vec![None; n_statements],
156            backwards_jump_indices,
157        }
158    }
159
160    /// Creates a ProgramAnnotations object based on 'n_statements', a given functions list
161    /// and metadata for the program.
162    pub fn create(
163        n_statements: usize,
164        backwards_jump_indices: UnorderedHashSet<StatementIdx>,
165        functions: &[Function],
166        metadata: &Metadata,
167        gas_usage_check: bool,
168        type_sizes: &TypeSizeMap,
169    ) -> Result<Self, AnnotationError> {
170        let mut annotations = ProgramAnnotations::new(n_statements, backwards_jump_indices);
171        for func in functions {
172            annotations.set_or_assert(
173                func.entry_point,
174                StatementAnnotations {
175                    refs: build_function_parameters_refs(func, type_sizes).map_err(|error| {
176                        AnnotationError::ReferencesError { statement_idx: func.entry_point, error }
177                    })?,
178                    function_id: func.id.clone(),
179                    convergence_allowed: false,
180                    environment: Environment::new(if gas_usage_check {
181                        GasWallet::Value(metadata.gas_info.function_costs[&func.id].clone())
182                    } else {
183                        GasWallet::Disabled
184                    }),
185                },
186            )?
187        }
188
189        Ok(annotations)
190    }
191
192    /// Sets the annotations at 'statement_idx' to 'annotations'
193    /// If the annotations for this statement were set previously asserts that the previous
194    /// assignment is consistent with the new assignment and verifies that convergence_allowed
195    /// is true.
196    pub fn set_or_assert(
197        &mut self,
198        statement_idx: StatementIdx,
199        annotations: StatementAnnotations,
200    ) -> Result<(), AnnotationError> {
201        let idx = statement_idx.0;
202        match self.per_statement_annotations.get(idx).ok_or(AnnotationError::InvalidStatementIdx)? {
203            None => self.per_statement_annotations[idx] = Some(annotations),
204            Some(expected_annotations) => {
205                if expected_annotations.function_id != annotations.function_id {
206                    return Err(AnnotationError::InconsistentFunctionId { statement_idx });
207                }
208                validate_environment_equality(
209                    &expected_annotations.environment,
210                    &annotations.environment,
211                )
212                .map_err(|error| AnnotationError::InconsistentEnvironments {
213                    statement_idx,
214                    error,
215                })?;
216                self.test_references_consistency(&annotations, expected_annotations).map_err(
217                    |error| AnnotationError::InconsistentReferencesAnnotation {
218                        statement_idx,
219                        error,
220                    },
221                )?;
222
223                // Note that we ignore annotations here.
224                // A flow cannot converge with a branch target.
225                if !expected_annotations.convergence_allowed {
226                    return Err(AnnotationError::InvalidConvergence { statement_idx });
227                }
228            }
229        };
230        Ok(())
231    }
232
233    /// Checks whether or not `actual` and `expected` references are consistent.
234    /// Returns an error representing the inconsistency.
235    fn test_references_consistency(
236        &self,
237        actual: &StatementAnnotations,
238        expected: &StatementAnnotations,
239    ) -> Result<(), InconsistentReferenceError> {
240        // Check if there is a mismatch in the number of variables.
241        if actual.refs.len() != expected.refs.len() {
242            return Err(InconsistentReferenceError::VariableCountMismatch);
243        }
244        let ap_tracking_enabled =
245            matches!(actual.environment.ap_tracking, ApTracking::Enabled { .. });
246        for (var_id, actual_ref) in actual.refs.iter() {
247            // Check if the variable exists in just one of the branches.
248            let Some(expected_ref) = expected.refs.get(var_id) else {
249                return Err(InconsistentReferenceError::VariableMissing(var_id.clone()));
250            };
251            // Check if the variable doesn't match on type, expression or stack information.
252            if actual_ref.ty != expected_ref.ty {
253                return Err(InconsistentReferenceError::TypeMismatch {
254                    var: var_id.clone(),
255                    expected: expected_ref.ty.clone(),
256                    actual: actual_ref.ty.clone(),
257                });
258            }
259            if actual_ref.expression != expected_ref.expression {
260                return Err(InconsistentReferenceError::ExpressionMismatch {
261                    var: var_id.clone(),
262                    expected: expected_ref.expression.clone(),
263                    actual: actual_ref.expression.clone(),
264                });
265            }
266            if actual_ref.stack_idx != expected_ref.stack_idx {
267                return Err(InconsistentReferenceError::StackIndexMismatch {
268                    var: var_id.clone(),
269                    expected: expected_ref.stack_idx,
270                    actual: actual_ref.stack_idx,
271                });
272            }
273            test_var_consistency(var_id, actual_ref, expected_ref, ap_tracking_enabled)?;
274        }
275        Ok(())
276    }
277
278    /// Returns the result of applying take_args to the StatementAnnotations at statement_idx.
279    /// Can be called only once per item, the item is removed from the annotations, and can no
280    /// longer be used for merges.
281    pub fn get_annotations_after_take_args<'a>(
282        &mut self,
283        statement_idx: StatementIdx,
284        ref_ids: impl Iterator<Item = &'a VarId>,
285    ) -> Result<(StatementAnnotations, Vec<ReferenceValue>), AnnotationError> {
286        let existing = self.per_statement_annotations[statement_idx.0]
287            .as_mut()
288            .ok_or(AnnotationError::MissingAnnotationsForStatement(statement_idx))?;
289        let mut updated = if self.backwards_jump_indices.contains(&statement_idx) {
290            existing.clone()
291        } else {
292            std::mem::replace(
293                existing,
294                StatementAnnotations {
295                    refs: Default::default(),
296                    function_id: existing.function_id.clone(),
297                    // Merging with this data is no longer allowed.
298                    convergence_allowed: false,
299                    environment: existing.environment.clone(),
300                },
301            )
302        };
303        let refs = std::mem::take(&mut updated.refs);
304        let (statement_refs, taken_refs) = take_args(refs, ref_ids).map_err(|error| {
305            AnnotationError::MissingReferenceError { statement_idx, var_id: error.var_id() }
306        })?;
307        updated.refs = statement_refs;
308        Ok((updated, taken_refs))
309    }
310
311    /// Propagates the annotations from `statement_idx` to 'destination_statement_idx'.
312    ///
313    /// `annotations` is the result of calling get_annotations_after_take_args at
314    /// `source_statement_idx` and `branch_changes` are the reference changes at each branch.
315    ///  if `must_set` is true, asserts that destination_statement_idx wasn't annotated before.
316    pub fn propagate_annotations(
317        &mut self,
318        source_statement_idx: StatementIdx,
319        destination_statement_idx: StatementIdx,
320        mut annotations: StatementAnnotations,
321        branch_info: &BranchInfo,
322        branch_changes: BranchChanges,
323        must_set: bool,
324    ) -> Result<(), AnnotationError> {
325        if must_set && self.per_statement_annotations[destination_statement_idx.0].is_some() {
326            return Err(AnnotationError::AnnotationAlreadySet {
327                source_statement_idx,
328                destination_statement_idx,
329            });
330        }
331
332        for (var_id, ref_value) in annotations.refs.iter_mut() {
333            if branch_changes.clear_old_stack {
334                ref_value.stack_idx = None;
335            }
336            ref_value.expression =
337                std::mem::replace(&mut ref_value.expression, ReferenceExpression::zero_sized())
338                    .apply_ap_change(branch_changes.ap_change)
339                    .map_err(|error| AnnotationError::ApChangeError {
340                        var_id: var_id.clone(),
341                        source_statement_idx,
342                        destination_statement_idx,
343                        introduction_point: ref_value.introduction_point.clone(),
344                        error,
345                    })?;
346        }
347        let mut refs = put_results(
348            annotations.refs,
349            zip_eq(
350                &branch_info.results,
351                branch_changes.refs.into_iter().map(|value| ReferenceValue {
352                    expression: value.expression,
353                    ty: value.ty,
354                    stack_idx: value.stack_idx,
355                    introduction_point: match value.introduction_point {
356                        OutputReferenceValueIntroductionPoint::New(output_idx) => {
357                            IntroductionPoint {
358                                source_statement_idx: Some(source_statement_idx),
359                                destination_statement_idx,
360                                output_idx,
361                            }
362                        }
363                        OutputReferenceValueIntroductionPoint::Existing(introduction_point) => {
364                            introduction_point
365                        }
366                    },
367                }),
368            ),
369        )
370        .map_err(|error| AnnotationError::OverrideReferenceError {
371            source_statement_idx,
372            destination_statement_idx,
373            var_id: error.var_id(),
374        })?;
375
376        // Since some variables on the stack may have been consumed by the libfunc, we need to
377        // find the new stack size. This is done by searching from the bottom of the stack until we
378        // find a missing variable.
379        let available_stack_indices: UnorderedHashSet<_> =
380            refs.values().flat_map(|r| r.stack_idx).collect();
381        let new_stack_size_opt = (0..branch_changes.new_stack_size)
382            .find(|i| !available_stack_indices.contains(&(branch_changes.new_stack_size - 1 - i)));
383        let stack_size = if let Some(new_stack_size) = new_stack_size_opt {
384            // The number of stack elements which were removed.
385            let stack_removal = branch_changes.new_stack_size - new_stack_size;
386            for (_, r) in refs.iter_mut() {
387                // Subtract the number of stack elements removed. If the result is negative,
388                // `stack_idx` is set to `None` and the variable is removed from the stack.
389                r.stack_idx =
390                    r.stack_idx.and_then(|stack_idx| stack_idx.checked_sub(stack_removal));
391            }
392            new_stack_size
393        } else {
394            branch_changes.new_stack_size
395        };
396
397        let ap_tracking = match branch_changes.ap_tracking_change {
398            ApTrackingChange::Disable => ApTracking::Disabled,
399            ApTrackingChange::Enable => {
400                if !matches!(annotations.environment.ap_tracking, ApTracking::Disabled) {
401                    return Err(AnnotationError::ApTrackingAlreadyEnabled {
402                        statement_idx: source_statement_idx,
403                    });
404                }
405                ApTracking::Enabled {
406                    ap_change: 0,
407                    base: ApTrackingBase::Statement(destination_statement_idx),
408                }
409            }
410            ApTrackingChange::None => {
411                update_ap_tracking(annotations.environment.ap_tracking, branch_changes.ap_change)
412                    .map_err(|error| AnnotationError::ApTrackingError {
413                        source_statement_idx,
414                        destination_statement_idx,
415                        error,
416                    })?
417            }
418        };
419
420        self.set_or_assert(
421            destination_statement_idx,
422            StatementAnnotations {
423                refs,
424                function_id: annotations.function_id,
425                convergence_allowed: !must_set,
426                environment: Environment {
427                    ap_tracking,
428                    stack_size,
429                    frame_state: annotations.environment.frame_state,
430                    gas_wallet: annotations
431                        .environment
432                        .gas_wallet
433                        .update(branch_changes.gas_change)
434                        .map_err(|error| AnnotationError::GasWalletError {
435                            source_statement_idx,
436                            destination_statement_idx,
437                            error,
438                        })?,
439                },
440            },
441        )
442    }
443
444    /// Validates the ap change and return types in a return statement.
445    pub fn validate_return_properties(
446        &self,
447        statement_idx: StatementIdx,
448        annotations: &StatementAnnotations,
449        functions: &[Function],
450        metadata: &Metadata,
451        return_refs: &[ReferenceValue],
452    ) -> Result<(), AnnotationError> {
453        // TODO(ilya): Don't use linear search.
454        let func = &functions.iter().find(|func| func.id == annotations.function_id).unwrap();
455
456        let expected_ap_tracking = match metadata.ap_change_info.function_ap_change.get(&func.id) {
457            Some(x) => ApTracking::Enabled { ap_change: *x, base: ApTrackingBase::FunctionStart },
458            None => ApTracking::Disabled,
459        };
460        if annotations.environment.ap_tracking != expected_ap_tracking {
461            return Err(AnnotationError::InvalidFunctionApChange {
462                statement_idx,
463                expected: expected_ap_tracking,
464                actual: annotations.environment.ap_tracking,
465            });
466        }
467
468        // Checks that the list of return reference contains has the expected types.
469        check_types_match(return_refs, &func.signature.ret_types)
470            .map_err(|error| AnnotationError::ReferencesError { statement_idx, error })?;
471        Ok(())
472    }
473
474    /// Validates the final annotation in a return statement.
475    pub fn validate_final_annotations(
476        &self,
477        statement_idx: StatementIdx,
478        annotations: &StatementAnnotations,
479        functions: &[Function],
480        metadata: &Metadata,
481        return_refs: &[ReferenceValue],
482    ) -> Result<(), AnnotationError> {
483        self.validate_return_properties(
484            statement_idx,
485            annotations,
486            functions,
487            metadata,
488            return_refs,
489        )?;
490        validate_final_environment(&annotations.environment)
491            .map_err(|error| AnnotationError::InconsistentEnvironments { statement_idx, error })
492    }
493}
494
495/// Checks whether or not the references `actual` and `expected` are consistent and can be merged
496/// in a way that will be re-compilable.
497/// Returns an error representing the inconsistency.
498fn test_var_consistency(
499    var_id: &VarId,
500    actual: &ReferenceValue,
501    expected: &ReferenceValue,
502    ap_tracking_enabled: bool,
503) -> Result<(), InconsistentReferenceError> {
504    // If the variable is on the stack, it can always be merged.
505    if actual.stack_idx.is_some() {
506        return Ok(());
507    }
508    // If the variable is not ap-dependent it can always be merged.
509    // Note: This makes the assumption that empty variables are always mergeable.
510    if actual.expression.can_apply_unknown() {
511        return Ok(());
512    }
513    // Ap tracking must be enabled when merging non-stack ap-dependent variables.
514    if !ap_tracking_enabled {
515        return Err(InconsistentReferenceError::ApTrackingDisabled(var_id.clone()));
516    }
517    // Merged variables must have the same introduction point.
518    if actual.introduction_point == expected.introduction_point {
519        Ok(())
520    } else {
521        Err(InconsistentReferenceError::IntroductionPointMismatch {
522            var: var_id.clone(),
523            expected: expected.introduction_point.clone(),
524            actual: actual.introduction_point.clone(),
525        })
526    }
527}