cairo_lang_starknet_classes/
contract_segmentation.rs1#[cfg(test)]
2#[path = "contract_segmentation_test.rs"]
3mod test;
4
5use cairo_lang_sierra::program::{Program, Statement, StatementIdx};
6use cairo_lang_sierra_to_casm::compiler::CairoProgram;
7use cairo_lang_utils::require;
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10
/// A list of integers that may be arbitrarily nested.
///
/// `compute_bytecode_segment_lengths` uses it to report bytecode segment lengths: either a single
/// `Leaf` (for an empty bytecode) or a `Node` whose elements are `Leaf` lengths.
///
/// Because of `#[serde(untagged)]`, values are (de)serialized without a variant tag: a `Leaf` as
/// its bare integer and a `Node` as a sequence. Untagged deserialization tries the variants in
/// declaration order, so the order below is significant.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum NestedIntList {
    /// A single integer.
    Leaf(usize),
    /// A list of nested elements.
    Node(Vec<NestedIntList>),
}
25
/// Errors that can occur while splitting a program's bytecode into segments.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum SegmentationError {
    /// The smallest function entry point is not statement 0. This is also returned when the
    /// program declares no functions at all.
    #[error("Expected a function start at index 0.")]
    NoFunctionStartAtZero,
    /// A branch targets a statement outside the function it belongs to. The payload is the
    /// index of the branching statement, not of the target.
    #[error("Jump outside of function boundaries.")]
    JumpOutsideFunction(StatementIdx),
}
33
34pub fn compute_bytecode_segment_lengths(
36 program: &Program,
37 cairo_program: &CairoProgram,
38 bytecode_len: usize,
39) -> Result<NestedIntList, SegmentationError> {
40 if bytecode_len == 0 {
41 return Ok(NestedIntList::Leaf(0));
42 }
43 let functions_segment_start_statements = find_functions_segments(program)?;
44 let mut segment_start_offsets =
45 functions_statement_ids_to_offsets(cairo_program, &functions_segment_start_statements);
46 segment_start_offsets.extend(consts_segments_offsets(cairo_program, bytecode_len));
47
48 Ok(NestedIntList::Node(
49 get_segment_lengths(&segment_start_offsets, bytecode_len)
50 .iter()
51 .map(|length| NestedIntList::Leaf(*length))
52 .collect(),
53 ))
54}
55
56fn find_functions_segments(program: &Program) -> Result<Vec<usize>, SegmentationError> {
58 let mut function_statement_ids: Vec<usize> =
60 program.funcs.iter().map(|func| func.entry_point.0).collect();
61 function_statement_ids.sort();
62 require(matches!(function_statement_ids.first(), Some(0)))
63 .ok_or(SegmentationError::NoFunctionStartAtZero)?;
64
65 let mut current_function = FunctionInfo::new(0);
67 let mut next_function_idx = 1;
68
69 for (idx, statement) in program.statements.iter().enumerate() {
71 if function_statement_ids.get(next_function_idx) == Some(&idx) {
73 current_function.finalize(idx)?;
74 current_function = FunctionInfo::new(idx);
75 next_function_idx += 1;
76 }
77 current_function.visit_statement(idx, statement)?;
78 }
79
80 current_function.finalize(program.statements.len())?;
81
82 Ok(function_statement_ids)
83}
84
85fn functions_statement_ids_to_offsets(
87 cairo_program: &CairoProgram,
88 segment_starts_statements: &[usize],
89) -> Vec<usize> {
90 let statement_to_offset = |statement_id: usize| {
91 cairo_program
92 .debug_info
93 .sierra_statement_info
94 .get(statement_id)
95 .unwrap_or_else(|| panic!("Missing bytecode offset for statement id {statement_id}."))
96 .start_offset
97 };
98 segment_starts_statements.iter().map(|start| statement_to_offset(*start)).collect()
99}
100
/// Converts segment start offsets into segment lengths.
///
/// Each segment runs from its start offset to the next one, and the final segment runs to
/// `bytecode_len`. Zero-length segments (e.g. from repeated start offsets) are omitted.
///
/// # Panics
/// Panics with "Segmentation error: No function found." if `segment_starts_offsets` is empty.
/// The subtractions assume the offsets are non-decreasing and at most `bytecode_len`.
fn get_segment_lengths(segment_starts_offsets: &[usize], bytecode_len: usize) -> Vec<usize> {
    // Lengths of every segment except the last: gaps between consecutive start offsets.
    let mut lengths: Vec<usize> = segment_starts_offsets
        .windows(2)
        .map(|pair| pair[1] - pair[0])
        .filter(|&len| len != 0)
        .collect();

    let final_start = match segment_starts_offsets.last() {
        Some(&start) => start,
        None => panic!("Segmentation error: No function found."),
    };
    // The last segment extends to the end of the bytecode.
    let final_len = bytecode_len - final_start;
    if final_len != 0 {
        lengths.push(final_len);
    }
    lengths
}
122
123struct FunctionInfo {
126 entry_point: usize,
127 max_jump_in_function: usize,
129 max_jump_in_function_src: usize,
131}
132impl FunctionInfo {
133 fn new(entry_point: usize) -> Self {
135 Self {
136 entry_point,
137 max_jump_in_function: entry_point,
138 max_jump_in_function_src: entry_point,
139 }
140 }
141
142 fn finalize(self, function_end: usize) -> Result<(), SegmentationError> {
148 if self.max_jump_in_function >= function_end {
150 return Err(SegmentationError::JumpOutsideFunction(StatementIdx(
151 self.max_jump_in_function_src,
152 )));
153 }
154 Ok(())
155 }
156
157 fn visit_statement(
159 &mut self,
160 idx: usize,
161 statement: &Statement,
162 ) -> Result<(), SegmentationError> {
163 match statement {
164 Statement::Invocation(invocation) => {
165 for branch in &invocation.branches {
166 let next_statement_idx = StatementIdx(idx).next(&branch.target).0;
167 if next_statement_idx < self.entry_point {
168 return Err(SegmentationError::JumpOutsideFunction(StatementIdx(idx)));
169 }
170 if next_statement_idx > self.max_jump_in_function {
171 self.max_jump_in_function = next_statement_idx;
172 self.max_jump_in_function_src = idx;
173 }
174 }
175 }
176 Statement::Return(_) => {}
177 }
178 Ok(())
179 }
180}
181
182fn consts_segments_offsets(cairo_program: &CairoProgram, bytecode_len: usize) -> Vec<usize> {
184 let const_segments_start_offset = bytecode_len - cairo_program.consts_info.total_segments_size;
185 cairo_program
186 .consts_info
187 .segments
188 .values()
189 .map(|segment| const_segments_start_offset + segment.segment_offset)
190 .collect()
191}