cairo_lang_starknet_classes/
contract_segmentation.rs

1#[cfg(test)]
2#[path = "contract_segmentation_test.rs"]
3mod test;
4
5use cairo_lang_sierra::program::{Program, Statement, StatementIdx};
6use cairo_lang_sierra_to_casm::compiler::CairoProgram;
7use cairo_lang_utils::require;
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10
11/// NestedIntList is either a list of NestedIntList or an integer.
12/// E.g., `[0, [1, 2], [3, [4]]]`.
13///
14/// Used to represent the lengths of the segments in a contract, which are in the form of a tree.
15///
16/// For example, the contract may be segmented by functions, where each function is segmented by
17/// its branches. It is also possible to have the inner segmentation only for some of the functions,
18/// while others are kept as non-segmented leaves in the tree.
19#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(untagged)]
21pub enum NestedIntList {
22    Leaf(usize),
23    Node(Vec<NestedIntList>),
24}
25
26#[derive(Error, Debug, Eq, PartialEq)]
27pub enum SegmentationError {
28    #[error("Expected a function start at index 0.")]
29    NoFunctionStartAtZero,
30    #[error("Jump outside of function boundaries.")]
31    JumpOutsideFunction(StatementIdx),
32}
33
34/// Computes the bytecode_segment_length for the given contract.
35pub fn compute_bytecode_segment_lengths(
36    program: &Program,
37    cairo_program: &CairoProgram,
38    bytecode_len: usize,
39) -> Result<NestedIntList, SegmentationError> {
40    if bytecode_len == 0 {
41        return Ok(NestedIntList::Leaf(0));
42    }
43    let functions_segment_start_statements = find_functions_segments(program)?;
44    let mut segment_start_offsets =
45        functions_statement_ids_to_offsets(cairo_program, &functions_segment_start_statements);
46    segment_start_offsets.extend(consts_segments_offsets(cairo_program, bytecode_len));
47
48    Ok(NestedIntList::Node(
49        get_segment_lengths(&segment_start_offsets, bytecode_len)
50            .iter()
51            .map(|length| NestedIntList::Leaf(*length))
52            .collect(),
53    ))
54}
55
56/// Returns a vector that contains the starts (as statement indices) of the functions.
57fn find_functions_segments(program: &Program) -> Result<Vec<usize>, SegmentationError> {
58    // Get the set of function entry points.
59    let mut function_statement_ids: Vec<usize> =
60        program.funcs.iter().map(|func| func.entry_point.0).collect();
61    function_statement_ids.sort();
62    require(matches!(function_statement_ids.first(), Some(0)))
63        .ok_or(SegmentationError::NoFunctionStartAtZero)?;
64
65    // Sanity check: go over the statements and check that there are no jumps outside of functions.
66    let mut current_function = FunctionInfo::new(0);
67    let mut next_function_idx = 1;
68
69    // Iterate over all statements and collect the segments' starts.
70    for (idx, statement) in program.statements.iter().enumerate() {
71        // Check if this is the beginning of a new function.
72        if function_statement_ids.get(next_function_idx) == Some(&idx) {
73            current_function.finalize(idx)?;
74            current_function = FunctionInfo::new(idx);
75            next_function_idx += 1;
76        }
77        current_function.visit_statement(idx, statement)?;
78    }
79
80    current_function.finalize(program.statements.len())?;
81
82    Ok(function_statement_ids)
83}
84
85/// Converts the result of [find_functions_segments] from statement ids to bytecode offsets.
86fn functions_statement_ids_to_offsets(
87    cairo_program: &CairoProgram,
88    segment_starts_statements: &[usize],
89) -> Vec<usize> {
90    let statement_to_offset = |statement_id: usize| {
91        cairo_program
92            .debug_info
93            .sierra_statement_info
94            .get(statement_id)
95            .unwrap_or_else(|| panic!("Missing bytecode offset for statement id {statement_id}."))
96            .start_offset
97    };
98    segment_starts_statements.iter().map(|start| statement_to_offset(*start)).collect()
99}
100
/// Returns the lengths of the segments, given their start offsets and the bytecode length.
///
/// Every segment spans from its start offset up to the next start offset; the last segment
/// spans up to `bytecode_len`. Empty segments (e.g., caused by repeated offsets) are omitted.
///
/// `segment_starts_offsets` must be non-empty, sorted in non-decreasing order, and its last
/// offset must not exceed `bytecode_len`.
///
/// # Panics
///
/// Panics if any of the requirements above is violated. The subtractions are checked
/// explicitly, so malformed input fails loudly in every build profile instead of wrapping
/// around to a bogus length when overflow checks are disabled.
fn get_segment_lengths(segment_starts_offsets: &[usize], bytecode_len: usize) -> Vec<usize> {
    let last_offset =
        *segment_starts_offsets.last().expect("Segmentation error: No function found.");

    // Compute the lengths of the segments from each pair of consecutive start points.
    let mut segment_lengths: Vec<usize> = segment_starts_offsets
        .windows(2)
        .map(|pair| {
            pair[1]
                .checked_sub(pair[0])
                .expect("Segmentation error: Segment start offsets are not sorted.")
        })
        .filter(|&segment_size| segment_size > 0)
        .collect();

    // Handle the last segment, which ends at the end of the bytecode.
    let last_segment_size = bytecode_len
        .checked_sub(last_offset)
        .expect("Segmentation error: Segment start offset exceeds the bytecode length.");
    if last_segment_size > 0 {
        segment_lengths.push(last_segment_size);
    }

    segment_lengths
}
122
123/// Helper struct for [find_functions_segments].
124/// Represents a single function and its segments.
125struct FunctionInfo {
126    entry_point: usize,
127    /// The maximal StatementIdx which we saw a jump to in the function.
128    max_jump_in_function: usize,
129    /// The statement that performed the jump to max_jump_in_function.
130    max_jump_in_function_src: usize,
131}
132impl FunctionInfo {
133    /// Creates a new FunctionInfo, for a function with the given entry point.
134    fn new(entry_point: usize) -> Self {
135        Self {
136            entry_point,
137            max_jump_in_function: entry_point,
138            max_jump_in_function_src: entry_point,
139        }
140    }
141
142    /// Finalizes the current function handling.
143    ///
144    /// `function_end` is the statement index following the last statement in the function.
145    ///
146    /// Returns the segment starts for the function.
147    fn finalize(self, function_end: usize) -> Result<(), SegmentationError> {
148        // Check that we did not see a jump after the function's end.
149        if self.max_jump_in_function >= function_end {
150            return Err(SegmentationError::JumpOutsideFunction(StatementIdx(
151                self.max_jump_in_function_src,
152            )));
153        }
154        Ok(())
155    }
156
157    /// Visits a statement inside the function and updates the [FunctionInfo] accordingly.
158    fn visit_statement(
159        &mut self,
160        idx: usize,
161        statement: &Statement,
162    ) -> Result<(), SegmentationError> {
163        match statement {
164            Statement::Invocation(invocation) => {
165                for branch in &invocation.branches {
166                    let next_statement_idx = StatementIdx(idx).next(&branch.target).0;
167                    if next_statement_idx < self.entry_point {
168                        return Err(SegmentationError::JumpOutsideFunction(StatementIdx(idx)));
169                    }
170                    if next_statement_idx > self.max_jump_in_function {
171                        self.max_jump_in_function = next_statement_idx;
172                        self.max_jump_in_function_src = idx;
173                    }
174                }
175            }
176            Statement::Return(_) => {}
177        }
178        Ok(())
179    }
180}
181
182/// Returns the offsets of the const segments.
183fn consts_segments_offsets(cairo_program: &CairoProgram, bytecode_len: usize) -> Vec<usize> {
184    let const_segments_start_offset = bytecode_len - cairo_program.consts_info.total_segments_size;
185    cairo_program
186        .consts_info
187        .segments
188        .values()
189        .map(|segment| const_segments_start_offset + segment.segment_offset)
190        .collect()
191}