1use cairo_lang_casm::ap_change::{ApChangeError, ApplyApChange};
2use cairo_lang_sierra::edit_state::{put_results, take_args};
3use cairo_lang_sierra::ids::{ConcreteTypeId, FunctionId, VarId};
4use cairo_lang_sierra::program::{BranchInfo, Function, StatementIdx};
5use cairo_lang_sierra_type_size::TypeSizeMap;
6use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
7use itertools::{chain, zip_eq};
8use thiserror::Error;
9
10use crate::environment::ap_tracking::update_ap_tracking;
11use crate::environment::frame_state::FrameStateError;
12use crate::environment::gas_wallet::{GasWallet, GasWalletError};
13use crate::environment::{
14 ApTracking, ApTrackingBase, Environment, EnvironmentError, validate_environment_equality,
15 validate_final_environment,
16};
17use crate::invocations::{ApTrackingChange, BranchChanges};
18use crate::metadata::Metadata;
19use crate::references::{
20 IntroductionPoint, OutputReferenceValueIntroductionPoint, ReferenceExpression, ReferenceValue,
21 ReferencesError, StatementRefs, build_function_parameters_refs, check_types_match,
22};
23
24#[derive(Error, Debug, Eq, PartialEq)]
25pub enum AnnotationError {
26 #[error("#{statement_idx}: Inconsistent references annotations: {error}")]
27 InconsistentReferencesAnnotation {
28 statement_idx: StatementIdx,
29 error: InconsistentReferenceError,
30 },
31 #[error("#{source_statement_idx}->#{destination_statement_idx}: Annotation was already set.")]
32 AnnotationAlreadySet {
33 source_statement_idx: StatementIdx,
34 destination_statement_idx: StatementIdx,
35 },
36 #[error("#{statement_idx}: {error}")]
37 InconsistentEnvironments { statement_idx: StatementIdx, error: EnvironmentError },
38 #[error("#{statement_idx}: Belongs to two different functions.")]
39 InconsistentFunctionId { statement_idx: StatementIdx },
40 #[error("#{statement_idx}: Invalid convergence.")]
41 InvalidConvergence { statement_idx: StatementIdx },
42 #[error("InvalidStatementIdx")]
43 InvalidStatementIdx,
44 #[error("MissingAnnotationsForStatement")]
45 MissingAnnotationsForStatement(StatementIdx),
46 #[error("#{statement_idx}: {var_id} is undefined.")]
47 MissingReferenceError { statement_idx: StatementIdx, var_id: VarId },
48 #[error("#{source_statement_idx}->#{destination_statement_idx}: {var_id} was overridden.")]
49 OverrideReferenceError {
50 source_statement_idx: StatementIdx,
51 destination_statement_idx: StatementIdx,
52 var_id: VarId,
53 },
54 #[error(transparent)]
55 FrameStateError(#[from] FrameStateError),
56 #[error("#{source_statement_idx}->#{destination_statement_idx}: {error}")]
57 GasWalletError {
58 source_statement_idx: StatementIdx,
59 destination_statement_idx: StatementIdx,
60 error: GasWalletError,
61 },
62 #[error("#{statement_idx}: {error}")]
63 ReferencesError { statement_idx: StatementIdx, error: ReferencesError },
64 #[error("#{statement_idx}: Attempting to enable ap tracking when already enabled.")]
65 ApTrackingAlreadyEnabled { statement_idx: StatementIdx },
66 #[error(
67 "#{source_statement_idx}->#{destination_statement_idx}: Got '{error}' error while moving \
68 {var_id} introduced at {introduction_point}."
69 )]
70 ApChangeError {
71 var_id: VarId,
72 source_statement_idx: StatementIdx,
73 destination_statement_idx: StatementIdx,
74 introduction_point: IntroductionPoint,
75 error: ApChangeError,
76 },
77 #[error("#{source_statement_idx} -> #{destination_statement_idx}: Ap tracking error")]
78 ApTrackingError {
79 source_statement_idx: StatementIdx,
80 destination_statement_idx: StatementIdx,
81 error: ApChangeError,
82 },
83 #[error(
84 "#{statement_idx}: Invalid function ap change annotation. Expected ap tracking: \
85 {expected:?}, got: {actual:?}."
86 )]
87 InvalidFunctionApChange {
88 statement_idx: StatementIdx,
89 expected: ApTracking,
90 actual: ApTracking,
91 },
92}
93
94impl AnnotationError {
95 pub fn stmt_indices(&self) -> Vec<StatementIdx> {
96 match self {
97 AnnotationError::ApChangeError {
98 source_statement_idx,
99 destination_statement_idx,
100 introduction_point,
101 ..
102 } => chain!(
103 [source_statement_idx, destination_statement_idx],
104 &introduction_point.source_statement_idx,
105 [&introduction_point.destination_statement_idx]
106 )
107 .cloned()
108 .collect(),
109 _ => vec![],
110 }
111 }
112}
113
114#[derive(Error, Debug, Eq, PartialEq)]
116pub enum InconsistentReferenceError {
117 #[error("Variable {var} type mismatch. Expected `{expected}`, got `{actual}`.")]
118 TypeMismatch { var: VarId, expected: ConcreteTypeId, actual: ConcreteTypeId },
119 #[error("Variable {var} expression mismatch. Expected `{expected}`, got `{actual}`.")]
120 ExpressionMismatch { var: VarId, expected: ReferenceExpression, actual: ReferenceExpression },
121 #[error("Variable {var} stack index mismatch. Expected `{expected:?}`, got `{actual:?}`.")]
122 StackIndexMismatch { var: VarId, expected: Option<usize>, actual: Option<usize> },
123 #[error("Variable {var} introduction point mismatch. Expected `{expected}`, got `{actual}`.")]
124 IntroductionPointMismatch { var: VarId, expected: IntroductionPoint, actual: IntroductionPoint },
125 #[error("Variable count mismatch.")]
126 VariableCountMismatch,
127 #[error("Missing expected variable {0}.")]
128 VariableMissing(VarId),
129 #[error("Ap tracking is disabled while trying to merge {0}.")]
130 ApTrackingDisabled(VarId),
131}
132
133#[derive(Clone, Debug)]
135pub struct StatementAnnotations {
136 pub refs: StatementRefs,
137 pub function_id: FunctionId,
139 pub convergence_allowed: bool,
141 pub environment: Environment,
142}
143
144pub struct ProgramAnnotations {
147 per_statement_annotations: Vec<Option<StatementAnnotations>>,
149 backwards_jump_indices: UnorderedHashSet<StatementIdx>,
151}
152impl ProgramAnnotations {
153 fn new(n_statements: usize, backwards_jump_indices: UnorderedHashSet<StatementIdx>) -> Self {
154 ProgramAnnotations {
155 per_statement_annotations: vec![None; n_statements],
156 backwards_jump_indices,
157 }
158 }
159
160 pub fn create(
163 n_statements: usize,
164 backwards_jump_indices: UnorderedHashSet<StatementIdx>,
165 functions: &[Function],
166 metadata: &Metadata,
167 gas_usage_check: bool,
168 type_sizes: &TypeSizeMap,
169 ) -> Result<Self, AnnotationError> {
170 let mut annotations = ProgramAnnotations::new(n_statements, backwards_jump_indices);
171 for func in functions {
172 annotations.set_or_assert(
173 func.entry_point,
174 StatementAnnotations {
175 refs: build_function_parameters_refs(func, type_sizes).map_err(|error| {
176 AnnotationError::ReferencesError { statement_idx: func.entry_point, error }
177 })?,
178 function_id: func.id.clone(),
179 convergence_allowed: false,
180 environment: Environment::new(if gas_usage_check {
181 GasWallet::Value(metadata.gas_info.function_costs[&func.id].clone())
182 } else {
183 GasWallet::Disabled
184 }),
185 },
186 )?
187 }
188
189 Ok(annotations)
190 }
191
192 pub fn set_or_assert(
197 &mut self,
198 statement_idx: StatementIdx,
199 annotations: StatementAnnotations,
200 ) -> Result<(), AnnotationError> {
201 let idx = statement_idx.0;
202 match self.per_statement_annotations.get(idx).ok_or(AnnotationError::InvalidStatementIdx)? {
203 None => self.per_statement_annotations[idx] = Some(annotations),
204 Some(expected_annotations) => {
205 if expected_annotations.function_id != annotations.function_id {
206 return Err(AnnotationError::InconsistentFunctionId { statement_idx });
207 }
208 validate_environment_equality(
209 &expected_annotations.environment,
210 &annotations.environment,
211 )
212 .map_err(|error| AnnotationError::InconsistentEnvironments {
213 statement_idx,
214 error,
215 })?;
216 self.test_references_consistency(&annotations, expected_annotations).map_err(
217 |error| AnnotationError::InconsistentReferencesAnnotation {
218 statement_idx,
219 error,
220 },
221 )?;
222
223 if !expected_annotations.convergence_allowed {
226 return Err(AnnotationError::InvalidConvergence { statement_idx });
227 }
228 }
229 };
230 Ok(())
231 }
232
233 fn test_references_consistency(
236 &self,
237 actual: &StatementAnnotations,
238 expected: &StatementAnnotations,
239 ) -> Result<(), InconsistentReferenceError> {
240 if actual.refs.len() != expected.refs.len() {
242 return Err(InconsistentReferenceError::VariableCountMismatch);
243 }
244 let ap_tracking_enabled =
245 matches!(actual.environment.ap_tracking, ApTracking::Enabled { .. });
246 for (var_id, actual_ref) in actual.refs.iter() {
247 let Some(expected_ref) = expected.refs.get(var_id) else {
249 return Err(InconsistentReferenceError::VariableMissing(var_id.clone()));
250 };
251 if actual_ref.ty != expected_ref.ty {
253 return Err(InconsistentReferenceError::TypeMismatch {
254 var: var_id.clone(),
255 expected: expected_ref.ty.clone(),
256 actual: actual_ref.ty.clone(),
257 });
258 }
259 if actual_ref.expression != expected_ref.expression {
260 return Err(InconsistentReferenceError::ExpressionMismatch {
261 var: var_id.clone(),
262 expected: expected_ref.expression.clone(),
263 actual: actual_ref.expression.clone(),
264 });
265 }
266 if actual_ref.stack_idx != expected_ref.stack_idx {
267 return Err(InconsistentReferenceError::StackIndexMismatch {
268 var: var_id.clone(),
269 expected: expected_ref.stack_idx,
270 actual: actual_ref.stack_idx,
271 });
272 }
273 test_var_consistency(var_id, actual_ref, expected_ref, ap_tracking_enabled)?;
274 }
275 Ok(())
276 }
277
278 pub fn get_annotations_after_take_args<'a>(
282 &mut self,
283 statement_idx: StatementIdx,
284 ref_ids: impl Iterator<Item = &'a VarId>,
285 ) -> Result<(StatementAnnotations, Vec<ReferenceValue>), AnnotationError> {
286 let existing = self.per_statement_annotations[statement_idx.0]
287 .as_mut()
288 .ok_or(AnnotationError::MissingAnnotationsForStatement(statement_idx))?;
289 let mut updated = if self.backwards_jump_indices.contains(&statement_idx) {
290 existing.clone()
291 } else {
292 std::mem::replace(
293 existing,
294 StatementAnnotations {
295 refs: Default::default(),
296 function_id: existing.function_id.clone(),
297 convergence_allowed: false,
299 environment: existing.environment.clone(),
300 },
301 )
302 };
303 let refs = std::mem::take(&mut updated.refs);
304 let (statement_refs, taken_refs) = take_args(refs, ref_ids).map_err(|error| {
305 AnnotationError::MissingReferenceError { statement_idx, var_id: error.var_id() }
306 })?;
307 updated.refs = statement_refs;
308 Ok((updated, taken_refs))
309 }
310
311 pub fn propagate_annotations(
317 &mut self,
318 source_statement_idx: StatementIdx,
319 destination_statement_idx: StatementIdx,
320 mut annotations: StatementAnnotations,
321 branch_info: &BranchInfo,
322 branch_changes: BranchChanges,
323 must_set: bool,
324 ) -> Result<(), AnnotationError> {
325 if must_set && self.per_statement_annotations[destination_statement_idx.0].is_some() {
326 return Err(AnnotationError::AnnotationAlreadySet {
327 source_statement_idx,
328 destination_statement_idx,
329 });
330 }
331
332 for (var_id, ref_value) in annotations.refs.iter_mut() {
333 if branch_changes.clear_old_stack {
334 ref_value.stack_idx = None;
335 }
336 ref_value.expression =
337 std::mem::replace(&mut ref_value.expression, ReferenceExpression::zero_sized())
338 .apply_ap_change(branch_changes.ap_change)
339 .map_err(|error| AnnotationError::ApChangeError {
340 var_id: var_id.clone(),
341 source_statement_idx,
342 destination_statement_idx,
343 introduction_point: ref_value.introduction_point.clone(),
344 error,
345 })?;
346 }
347 let mut refs = put_results(
348 annotations.refs,
349 zip_eq(
350 &branch_info.results,
351 branch_changes.refs.into_iter().map(|value| ReferenceValue {
352 expression: value.expression,
353 ty: value.ty,
354 stack_idx: value.stack_idx,
355 introduction_point: match value.introduction_point {
356 OutputReferenceValueIntroductionPoint::New(output_idx) => {
357 IntroductionPoint {
358 source_statement_idx: Some(source_statement_idx),
359 destination_statement_idx,
360 output_idx,
361 }
362 }
363 OutputReferenceValueIntroductionPoint::Existing(introduction_point) => {
364 introduction_point
365 }
366 },
367 }),
368 ),
369 )
370 .map_err(|error| AnnotationError::OverrideReferenceError {
371 source_statement_idx,
372 destination_statement_idx,
373 var_id: error.var_id(),
374 })?;
375
376 let available_stack_indices: UnorderedHashSet<_> =
380 refs.values().flat_map(|r| r.stack_idx).collect();
381 let new_stack_size_opt = (0..branch_changes.new_stack_size)
382 .find(|i| !available_stack_indices.contains(&(branch_changes.new_stack_size - 1 - i)));
383 let stack_size = if let Some(new_stack_size) = new_stack_size_opt {
384 let stack_removal = branch_changes.new_stack_size - new_stack_size;
386 for (_, r) in refs.iter_mut() {
387 r.stack_idx =
390 r.stack_idx.and_then(|stack_idx| stack_idx.checked_sub(stack_removal));
391 }
392 new_stack_size
393 } else {
394 branch_changes.new_stack_size
395 };
396
397 let ap_tracking = match branch_changes.ap_tracking_change {
398 ApTrackingChange::Disable => ApTracking::Disabled,
399 ApTrackingChange::Enable => {
400 if !matches!(annotations.environment.ap_tracking, ApTracking::Disabled) {
401 return Err(AnnotationError::ApTrackingAlreadyEnabled {
402 statement_idx: source_statement_idx,
403 });
404 }
405 ApTracking::Enabled {
406 ap_change: 0,
407 base: ApTrackingBase::Statement(destination_statement_idx),
408 }
409 }
410 ApTrackingChange::None => {
411 update_ap_tracking(annotations.environment.ap_tracking, branch_changes.ap_change)
412 .map_err(|error| AnnotationError::ApTrackingError {
413 source_statement_idx,
414 destination_statement_idx,
415 error,
416 })?
417 }
418 };
419
420 self.set_or_assert(
421 destination_statement_idx,
422 StatementAnnotations {
423 refs,
424 function_id: annotations.function_id,
425 convergence_allowed: !must_set,
426 environment: Environment {
427 ap_tracking,
428 stack_size,
429 frame_state: annotations.environment.frame_state,
430 gas_wallet: annotations
431 .environment
432 .gas_wallet
433 .update(branch_changes.gas_change)
434 .map_err(|error| AnnotationError::GasWalletError {
435 source_statement_idx,
436 destination_statement_idx,
437 error,
438 })?,
439 },
440 },
441 )
442 }
443
444 pub fn validate_return_properties(
446 &self,
447 statement_idx: StatementIdx,
448 annotations: &StatementAnnotations,
449 functions: &[Function],
450 metadata: &Metadata,
451 return_refs: &[ReferenceValue],
452 ) -> Result<(), AnnotationError> {
453 let func = &functions.iter().find(|func| func.id == annotations.function_id).unwrap();
455
456 let expected_ap_tracking = match metadata.ap_change_info.function_ap_change.get(&func.id) {
457 Some(x) => ApTracking::Enabled { ap_change: *x, base: ApTrackingBase::FunctionStart },
458 None => ApTracking::Disabled,
459 };
460 if annotations.environment.ap_tracking != expected_ap_tracking {
461 return Err(AnnotationError::InvalidFunctionApChange {
462 statement_idx,
463 expected: expected_ap_tracking,
464 actual: annotations.environment.ap_tracking,
465 });
466 }
467
468 check_types_match(return_refs, &func.signature.ret_types)
470 .map_err(|error| AnnotationError::ReferencesError { statement_idx, error })?;
471 Ok(())
472 }
473
474 pub fn validate_final_annotations(
476 &self,
477 statement_idx: StatementIdx,
478 annotations: &StatementAnnotations,
479 functions: &[Function],
480 metadata: &Metadata,
481 return_refs: &[ReferenceValue],
482 ) -> Result<(), AnnotationError> {
483 self.validate_return_properties(
484 statement_idx,
485 annotations,
486 functions,
487 metadata,
488 return_refs,
489 )?;
490 validate_final_environment(&annotations.environment)
491 .map_err(|error| AnnotationError::InconsistentEnvironments { statement_idx, error })
492 }
493}
494
495fn test_var_consistency(
499 var_id: &VarId,
500 actual: &ReferenceValue,
501 expected: &ReferenceValue,
502 ap_tracking_enabled: bool,
503) -> Result<(), InconsistentReferenceError> {
504 if actual.stack_idx.is_some() {
506 return Ok(());
507 }
508 if actual.expression.can_apply_unknown() {
511 return Ok(());
512 }
513 if !ap_tracking_enabled {
515 return Err(InconsistentReferenceError::ApTrackingDisabled(var_id.clone()));
516 }
517 if actual.introduction_point == expected.introduction_point {
519 Ok(())
520 } else {
521 Err(InconsistentReferenceError::IntroductionPointMismatch {
522 var: var_id.clone(),
523 expected: expected.introduction_point.clone(),
524 actual: actual.introduction_point.clone(),
525 })
526 }
527}