cairo_lang_sierra_generator/
program_generator.rs

1use std::collections::VecDeque;
2use std::sync::Arc;
3
4use cairo_lang_debug::DebugWithDb;
5use cairo_lang_defs::diagnostic_utils::StableLocation;
6use cairo_lang_diagnostics::{Maybe, get_location_marks};
7use cairo_lang_filesystem::ids::CrateId;
8use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
9use cairo_lang_sierra::extensions::GenericLibfuncEx;
10use cairo_lang_sierra::extensions::core::CoreLibfunc;
11use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId};
12use cairo_lang_sierra::program::{self, DeclaredTypeInfo, Program, StatementIdx};
13use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
14use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
15use cairo_lang_utils::{LookupIntern, try_extract_matches};
16use itertools::{Itertools, chain};
17
18use crate::db::{SierraGenGroup, sierra_concrete_long_id};
19use crate::extra_sierra_info::type_has_const_size;
20use crate::pre_sierra;
21use crate::replace_ids::{DebugReplacer, SierraIdReplacer};
22use crate::resolve_labels::{LabelReplacer, resolve_labels_and_extract_locations};
23use crate::specialization_context::SierraSignatureSpecializationContext;
24use crate::statements_locations::StatementsLocations;
25
26#[cfg(test)]
27#[path = "program_generator_test.rs"]
28mod test;
29
30/// Generates the list of [cairo_lang_sierra::program::LibfuncDeclaration] for the given list of
31/// [ConcreteLibfuncId].
32fn generate_libfunc_declarations<'a>(
33    db: &dyn SierraGenGroup,
34    libfuncs: impl Iterator<Item = &'a ConcreteLibfuncId>,
35) -> Vec<program::LibfuncDeclaration> {
36    libfuncs
37        .into_iter()
38        .map(|libfunc_id| program::LibfuncDeclaration {
39            id: libfunc_id.clone(),
40            long_id: libfunc_id.lookup_intern(db),
41        })
42        .collect()
43}
44
45/// Collects the set of all [ConcreteLibfuncId] used in the given list of [pre_sierra::Statement].
46fn collect_used_libfuncs(
47    statements: &[pre_sierra::StatementWithLocation],
48) -> OrderedHashSet<ConcreteLibfuncId> {
49    statements
50        .iter()
51        .filter_map(|statement| match &statement.statement {
52            pre_sierra::Statement::Sierra(program::GenStatement::Invocation(invocation)) => {
53                Some(invocation.libfunc_id.clone())
54            }
55            pre_sierra::Statement::Sierra(program::GenStatement::Return(_))
56            | pre_sierra::Statement::Label(_) => None,
57            pre_sierra::Statement::PushValues(_) => {
58                panic!("Unexpected pre_sierra::Statement::PushValues in collect_used_libfuncs().")
59            }
60        })
61        .collect()
62}
63
64/// Generates the list of [cairo_lang_sierra::program::TypeDeclaration] for the given list of
65/// [ConcreteTypeId].
66fn generate_type_declarations(
67    db: &dyn SierraGenGroup,
68    mut remaining_types: OrderedHashSet<ConcreteTypeId>,
69) -> Vec<program::TypeDeclaration> {
70    let mut declarations = vec![];
71    let mut already_declared = UnorderedHashSet::default();
72    while let Some(ty) = remaining_types.iter().next().cloned() {
73        remaining_types.swap_remove(&ty);
74        generate_type_declarations_helper(
75            db,
76            &ty,
77            &mut declarations,
78            &mut remaining_types,
79            &mut already_declared,
80        );
81    }
82    declarations
83}
84
85/// Helper to ensure declaring types ordered in such a way that no type appears before types it
86/// depends on for knowing its size.
87/// `remaining_types` are types that will later be checked.
88/// We may add types to there if we are not sure their dependencies are already declared.
89fn generate_type_declarations_helper(
90    db: &dyn SierraGenGroup,
91    ty: &ConcreteTypeId,
92    declarations: &mut Vec<program::TypeDeclaration>,
93    remaining_types: &mut OrderedHashSet<ConcreteTypeId>,
94    already_declared: &mut UnorderedHashSet<ConcreteTypeId>,
95) {
96    if already_declared.contains(ty) {
97        return;
98    }
99    let long_id = sierra_concrete_long_id(db, ty.clone()).unwrap();
100    already_declared.insert(ty.clone());
101    let inner_tys = long_id
102        .generic_args
103        .iter()
104        .filter_map(|arg| try_extract_matches!(arg, program::GenericArg::Type));
105    // Making sure we order the types such that types that require others to know their size are
106    // after the required types. Types that always have a known size would be first.
107    if type_has_const_size(&long_id.generic_id) {
108        remaining_types.extend(inner_tys.cloned());
109    } else {
110        for inner_ty in inner_tys {
111            generate_type_declarations_helper(
112                db,
113                inner_ty,
114                declarations,
115                remaining_types,
116                already_declared,
117            );
118        }
119    }
120
121    let type_info = db.get_type_info(ty.clone()).unwrap();
122    declarations.push(program::TypeDeclaration {
123        id: ty.clone(),
124        long_id: long_id.as_ref().clone(),
125        declared_type_info: Some(DeclaredTypeInfo {
126            storable: type_info.storable,
127            droppable: type_info.droppable,
128            duplicatable: type_info.duplicatable,
129            zero_sized: type_info.zero_sized,
130        }),
131    });
132}
133
134/// Collects the set of all [ConcreteTypeId] that are used in the given lists of
135/// [program::LibfuncDeclaration] and user functions.
136fn collect_used_types(
137    db: &dyn SierraGenGroup,
138    libfunc_declarations: &[program::LibfuncDeclaration],
139    functions: &[program::Function],
140) -> OrderedHashSet<ConcreteTypeId> {
141    // Collect types that appear in libfuncs.
142    let types_in_libfuncs = libfunc_declarations.iter().flat_map(|libfunc| {
143        // TODO(orizi): replace expect() with a diagnostic (unless this can never happen).
144        let signature = CoreLibfunc::specialize_signature_by_id(
145                &SierraSignatureSpecializationContext(db),
146                &libfunc.long_id.generic_id,
147                &libfunc.long_id.generic_args,
148            )
149            // If panic happens here, make sure the specified libfunc name is in one of the STR_IDs of
150            // the libfuncs in the [`CoreLibfunc`] structured enum.
151            .unwrap_or_else(|err| panic!("Failed to specialize: `{}`. Error: {err}",
152                DebugReplacer { db }.replace_libfunc_id(&libfunc.id)));
153        chain!(
154            signature.param_signatures.into_iter().map(|param_signature| param_signature.ty),
155            signature.branch_signatures.into_iter().flat_map(|info| info.vars).map(|var| var.ty),
156            libfunc.long_id.generic_args.iter().filter_map(|arg| match arg {
157                program::GenericArg::Type(ty) => Some(ty.clone()),
158                _ => None,
159            })
160        )
161    });
162
163    // Gather types used in user-defined functions.
164    // This is necessary for types that are used as entry point arguments but do not appear in any
165    // libfunc. For instance, if an entry point takes and returns an empty struct and no
166    // libfuncs are involved, we still need to declare that struct.
167    // Additionally, we include the return types of functions, since with unsafe panic enabled,
168    // a function that always panics might declare a return type that does not appear in anywhere
169    // else in the program.
170    let types_in_user_functions = functions
171        .iter()
172        .flat_map(|func| chain!(&func.signature.param_types, &func.signature.ret_types).cloned());
173
174    chain!(types_in_libfuncs, types_in_user_functions).collect()
175}
176
177#[derive(Clone, Debug, Eq, PartialEq)]
178pub struct SierraProgramWithDebug {
179    pub program: cairo_lang_sierra::program::Program,
180    pub debug_info: SierraProgramDebugInfo,
181}
182/// Implementation for a debug print of a Sierra program with all locations.
183/// The print is a valid textual Sierra program.
184impl DebugWithDb<dyn SierraGenGroup> for SierraProgramWithDebug {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn SierraGenGroup) -> std::fmt::Result {
186        let sierra_program = DebugReplacer { db }.apply(&self.program);
187        for declaration in &sierra_program.type_declarations {
188            writeln!(f, "{declaration};")?;
189        }
190        writeln!(f)?;
191        for declaration in &sierra_program.libfunc_declarations {
192            writeln!(f, "{declaration};")?;
193        }
194        writeln!(f)?;
195        let mut funcs = sierra_program.funcs.iter().peekable();
196        while let Some(func) = funcs.next() {
197            let start = func.entry_point.0;
198            let end = funcs
199                .peek()
200                .map(|f| f.entry_point.0)
201                .unwrap_or_else(|| sierra_program.statements.len());
202            writeln!(f, "// {}:", func.id)?;
203            for param in &func.params {
204                writeln!(f, "//   {param}")?;
205            }
206            for i in start..end {
207                writeln!(f, "{}; // {i}", sierra_program.statements[i])?;
208                if let Some(loc) =
209                    &self.debug_info.statements_locations.locations.get(&StatementIdx(i))
210                {
211                    let loc =
212                        get_location_marks(db, &loc.first().unwrap().diagnostic_location(db), true);
213                    println!("{}", loc.split('\n').map(|l| format!("// {l}")).join("\n"));
214                }
215            }
216        }
217        writeln!(f)?;
218        for func in &sierra_program.funcs {
219            writeln!(f, "{func};")?;
220        }
221        Ok(())
222    }
223}
224
225#[derive(Clone, Debug, Eq, PartialEq)]
226pub struct SierraProgramDebugInfo {
227    pub statements_locations: StatementsLocations,
228}
229
230pub fn get_sierra_program_for_functions(
231    db: &dyn SierraGenGroup,
232    requested_function_ids: Vec<ConcreteFunctionWithBodyId>,
233) -> Maybe<Arc<SierraProgramWithDebug>> {
234    let mut functions: Vec<Arc<pre_sierra::Function>> = vec![];
235    let mut statements: Vec<pre_sierra::StatementWithLocation> = vec![];
236    let mut processed_function_ids = UnorderedHashSet::<ConcreteFunctionWithBodyId>::default();
237    let mut function_id_queue: VecDeque<ConcreteFunctionWithBodyId> =
238        requested_function_ids.into_iter().collect();
239    while let Some(function_id) = function_id_queue.pop_front() {
240        if !processed_function_ids.insert(function_id) {
241            continue;
242        }
243        let function: Arc<pre_sierra::Function> = db.function_with_body_sierra(function_id)?;
244        functions.push(function.clone());
245        statements.extend_from_slice(&function.body);
246
247        for statement in &function.body {
248            if let Some(related_function_id) = try_get_function_with_body_id(db, statement) {
249                function_id_queue.push_back(related_function_id);
250            }
251        }
252    }
253
254    let (program, statements_locations) = assemble_program(db, functions, statements);
255    Ok(Arc::new(SierraProgramWithDebug {
256        program,
257        debug_info: SierraProgramDebugInfo {
258            statements_locations: StatementsLocations::from_locations_vec(&statements_locations),
259        },
260    }))
261}
262
263/// Given a list of functions and statements, generates a Sierra program.
264/// Returns the program and the locations of the statements in the program.
265pub fn assemble_program(
266    db: &dyn SierraGenGroup,
267    functions: Vec<Arc<pre_sierra::Function>>,
268    statements: Vec<pre_sierra::StatementWithLocation>,
269) -> (program::Program, Vec<Vec<StableLocation>>) {
270    let label_replacer = LabelReplacer::from_statements(&statements);
271    let funcs = functions
272        .into_iter()
273        .map(|function| {
274            let sierra_signature = db.get_function_signature(function.id.clone()).unwrap();
275            program::Function::new(
276                function.id.clone(),
277                function.parameters.clone(),
278                sierra_signature.ret_types.clone(),
279                label_replacer.handle_label_id(function.entry_point),
280            )
281        })
282        .collect_vec();
283
284    let libfunc_declarations =
285        generate_libfunc_declarations(db, collect_used_libfuncs(&statements).iter());
286    let type_declarations =
287        generate_type_declarations(db, collect_used_types(db, &libfunc_declarations, &funcs));
288    // Resolve labels.
289    let (resolved_statements, statements_locations) =
290        resolve_labels_and_extract_locations(statements, &label_replacer);
291    let program = program::Program {
292        type_declarations,
293        libfunc_declarations,
294        statements: resolved_statements,
295        funcs,
296    };
297    (program, statements_locations)
298}
299
300/// Tries extracting a ConcreteFunctionWithBodyId from a pre-Sierra statement.
301pub fn try_get_function_with_body_id(
302    db: &dyn SierraGenGroup,
303    statement: &pre_sierra::StatementWithLocation,
304) -> Option<ConcreteFunctionWithBodyId> {
305    let invc = try_extract_matches!(
306        try_extract_matches!(&statement.statement, pre_sierra::Statement::Sierra)?,
307        program::GenStatement::Invocation
308    )?;
309    let libfunc = invc.libfunc_id.lookup_intern(db);
310    let inner_function = if libfunc.generic_id == "function_call".into()
311        || libfunc.generic_id == "coupon_call".into()
312    {
313        libfunc.generic_args.first()?.clone()
314    } else if libfunc.generic_id == "coupon_buy".into()
315        || libfunc.generic_id == "coupon_refund".into()
316    {
317        // TODO(lior): Instead of this code, unused coupons should be replaced with the unit type
318        //   or with a zero-valued coupon. Currently, the code is not optimal (since the coupon
319        //   costs more than it should) and some programs may not compile (if the coupon is
320        //   mentioned as a type but not mentioned in any libfuncs).
321        let coupon_ty = try_extract_matches!(
322            libfunc.generic_args.first()?,
323            cairo_lang_sierra::program::GenericArg::Type
324        )?;
325        let coupon_long_id = sierra_concrete_long_id(db, coupon_ty.clone()).unwrap();
326        coupon_long_id.generic_args.first()?.clone()
327    } else {
328        return None;
329    };
330
331    try_extract_matches!(inner_function, cairo_lang_sierra::program::GenericArg::UserFunc)?
332        .lookup_intern(db)
333        .body(db)
334        .expect("No diagnostics at this stage.")
335}
336
337pub fn get_sierra_program(
338    db: &dyn SierraGenGroup,
339    requested_crate_ids: Vec<CrateId>,
340) -> Maybe<Arc<SierraProgramWithDebug>> {
341    let mut requested_function_ids = vec![];
342    for crate_id in requested_crate_ids {
343        for module_id in db.crate_modules(crate_id).iter() {
344            for (free_func_id, _) in db.module_free_functions(*module_id)?.iter() {
345                // TODO(spapini): Search Impl functions.
346                if let Some(function) =
347                    ConcreteFunctionWithBodyId::from_no_generics_free(db, *free_func_id)
348                {
349                    requested_function_ids.push(function)
350                }
351            }
352        }
353    }
354    db.get_sierra_program_for_functions(requested_function_ids)
355}
356
357/// Given `function_id` generates a dummy program with the body of the relevant function
358/// and dummy helper functions that allows the program to be compiled to casm.
359/// The generated program is not valid, but it can be used to estimate the size of the
360/// relevant function.
361pub fn get_dummy_program_for_size_estimation(
362    db: &dyn SierraGenGroup,
363    function_id: ConcreteFunctionWithBodyId,
364) -> Maybe<Program> {
365    let function: Arc<pre_sierra::Function> = db.function_with_body_sierra(function_id)?;
366
367    let mut processed_function_ids =
368        UnorderedHashSet::<ConcreteFunctionWithBodyId>::from_iter([function_id]);
369
370    let mut functions = vec![function.clone()];
371
372    for statement in &function.body {
373        if let Some(function_id) = try_get_function_with_body_id(db, statement) {
374            if processed_function_ids.insert(function_id) {
375                functions.push(db.priv_get_dummy_function(function_id)?);
376            }
377        }
378    }
379    // Since we are not interested in the locations, we can remove them from the statements.
380    let statements = functions
381        .iter()
382        .flat_map(|f| f.body.iter())
383        .map(|s| s.statement.clone().into_statement_without_location())
384        .collect();
385
386    let (program, _statements_locations) = assemble_program(db, functions, statements);
387
388    Ok(program)
389}