cairo_lang_sierra_generator/
program_generator.rs

1use std::collections::VecDeque;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_defs::db::DefsGroup;
5use cairo_lang_diagnostics::Maybe;
6use cairo_lang_filesystem::ids::{CrateId, Tracked};
7use cairo_lang_filesystem::location_marks::get_location_marks;
8use cairo_lang_lowering::ids::{ConcreteFunctionWithBodyId, LocationId};
9use cairo_lang_sierra::extensions::GenericLibfuncEx;
10use cairo_lang_sierra::extensions::core::CoreLibfunc;
11use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId};
12use cairo_lang_sierra::program::{self, DeclaredTypeInfo, Program, StatementIdx};
13use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
14use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
15use cairo_lang_utils::try_extract_matches;
16use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
17use itertools::{Itertools, chain};
18use salsa::Database;
19
20use crate::db::{SierraGenGroup, sierra_concrete_long_id};
21use crate::debug_info::{
22    AllFunctionsDebugInfo, FunctionDebugInfo, SierraProgramDebugInfo, StatementsLocations,
23};
24use crate::extra_sierra_info::type_has_const_size;
25use crate::pre_sierra;
26use crate::replace_ids::{DebugReplacer, SierraIdReplacer};
27use crate::resolve_labels::{LabelReplacer, resolve_labels_and_extract_locations};
28use crate::specialization_context::SierraSignatureSpecializationContext;
29
30#[cfg(test)]
31#[path = "program_generator_test.rs"]
32mod test;
33
34/// Generates the list of [cairo_lang_sierra::program::LibfuncDeclaration] for the given list of
35/// [pre_sierra::StatementWithLocation].
36fn collect_and_generate_libfunc_declarations<'db>(
37    db: &dyn Database,
38    statements: &[pre_sierra::StatementWithLocation<'db>],
39) -> Vec<program::LibfuncDeclaration> {
40    let mut declared_libfuncs = UnorderedHashSet::<ConcreteLibfuncId>::default();
41    statements
42        .iter()
43        .filter_map(|statement| match &statement.statement {
44            pre_sierra::Statement::Sierra(program::GenStatement::Invocation(invocation)) => {
45                declared_libfuncs.insert(invocation.libfunc_id.clone()).then(|| {
46                    program::LibfuncDeclaration {
47                        id: invocation.libfunc_id.clone(),
48                        long_id: db.lookup_concrete_lib_func(&invocation.libfunc_id),
49                    }
50                })
51            }
52            pre_sierra::Statement::Sierra(program::GenStatement::Return(_))
53            | pre_sierra::Statement::Label(_) => None,
54            pre_sierra::Statement::PushValues(_) => {
55                panic!(
56                    "Unexpected pre_sierra::Statement::PushValues in \
57                     collect_and_generate_libfunc_declarations()."
58                )
59            }
60        })
61        .collect()
62}
63
64/// Generates the list of [cairo_lang_sierra::program::TypeDeclaration] for the given list of
65/// [ConcreteTypeId].
66fn generate_type_declarations(
67    db: &dyn Database,
68    libfunc_declarations: &[program::LibfuncDeclaration],
69    functions: &[program::Function],
70) -> Vec<program::TypeDeclaration> {
71    let mut declarations = vec![];
72    let mut already_declared = UnorderedHashSet::default();
73    let mut remaining_types = collect_used_types(db, libfunc_declarations, functions);
74    while let Some(ty) = remaining_types.iter().next().cloned() {
75        remaining_types.swap_remove(&ty);
76        generate_type_declarations_helper(
77            db,
78            &ty,
79            &mut declarations,
80            &mut remaining_types,
81            &mut already_declared,
82        );
83    }
84    declarations
85}
86
87/// Helper to ensure declaring types ordered in such a way that no type appears before types it
88/// depends on for knowing its size.
89/// `remaining_types` are types that will later be checked.
90/// We may add types to there if we are not sure their dependencies are already declared.
91fn generate_type_declarations_helper(
92    db: &dyn Database,
93    ty: &ConcreteTypeId,
94    declarations: &mut Vec<program::TypeDeclaration>,
95    remaining_types: &mut OrderedHashSet<ConcreteTypeId>,
96    already_declared: &mut UnorderedHashSet<ConcreteTypeId>,
97) {
98    if already_declared.contains(ty) {
99        return;
100    }
101    let long_id = sierra_concrete_long_id(db, ty.clone()).unwrap();
102    already_declared.insert(ty.clone());
103    let inner_tys = long_id
104        .generic_args
105        .iter()
106        .filter_map(|arg| try_extract_matches!(arg, program::GenericArg::Type));
107    // Making sure we order the types such that types that require others to know their size are
108    // after the required types. Types that always have a known size would be first.
109    if type_has_const_size(&long_id.generic_id) {
110        remaining_types.extend(inner_tys.cloned());
111    } else {
112        for inner_ty in inner_tys {
113            generate_type_declarations_helper(
114                db,
115                inner_ty,
116                declarations,
117                remaining_types,
118                already_declared,
119            );
120        }
121    }
122
123    let type_info = db.get_type_info(ty.clone()).unwrap();
124    declarations.push(program::TypeDeclaration {
125        id: ty.clone(),
126        long_id: long_id.as_ref().clone(),
127        declared_type_info: Some(DeclaredTypeInfo {
128            storable: type_info.storable,
129            droppable: type_info.droppable,
130            duplicatable: type_info.duplicatable,
131            zero_sized: type_info.zero_sized,
132        }),
133    });
134}
135
136/// Collects the set of all [ConcreteTypeId] that are used in the given lists of
137/// [program::LibfuncDeclaration] and user functions.
138fn collect_used_types(
139    db: &dyn Database,
140    libfunc_declarations: &[program::LibfuncDeclaration],
141    functions: &[program::Function],
142) -> OrderedHashSet<ConcreteTypeId> {
143    let mut all_types = OrderedHashSet::default();
144    // Collect types that appear in libfuncs.
145    for libfunc in libfunc_declarations {
146        let types = db.priv_libfunc_dependencies(libfunc.id.clone());
147        all_types.extend(types.iter().cloned());
148    }
149
150    // Gather types used in user-defined functions.
151    // This is necessary for types that are used as entry point arguments but do not appear in any
152    // libfunc. For instance, if an entry point takes and returns an empty struct and no
153    // libfuncs are involved, we still need to declare that struct.
154    // Additionally, we include the return types of functions, since with unsafe panic enabled,
155    // a function that always panics might declare a return type that does not appear in anywhere
156    // else in the program.
157    all_types.extend(
158        functions.iter().flat_map(|func| {
159            chain!(&func.signature.param_types, &func.signature.ret_types).cloned()
160        }),
161    );
162    all_types
163}
164
165/// Query implementation of [SierraGenGroup::priv_libfunc_dependencies].
166#[salsa::tracked(returns(ref))]
167pub fn priv_libfunc_dependencies(
168    db: &dyn Database,
169    _tracked: Tracked,
170    libfunc_id: ConcreteLibfuncId,
171) -> Vec<ConcreteTypeId> {
172    let long_id = db.lookup_concrete_lib_func(&libfunc_id);
173    let signature = CoreLibfunc::specialize_signature_by_id(
174        &SierraSignatureSpecializationContext(db),
175        &long_id.generic_id,
176        &long_id.generic_args,
177    )
178    // If panic happens here, make sure the specified libfunc name is in one of the STR_IDs of
179    // the libfuncs in the [`CoreLibfunc`] structured enum.
180    .unwrap_or_else(|err| panic!("Failed to specialize: `{}`. Error: {err}",
181        DebugReplacer { db }.replace_libfunc_id(&libfunc_id)));
182    // Collecting types as a vector since the set should be very small.
183    let mut all_types = vec![];
184    let mut add_ty = |ty: ConcreteTypeId| {
185        if !all_types.contains(&ty) {
186            all_types.push(ty);
187        }
188    };
189    for param_signature in signature.param_signatures {
190        add_ty(param_signature.ty);
191    }
192    for info in signature.branch_signatures {
193        for var in info.vars {
194            add_ty(var.ty);
195        }
196    }
197    for arg in long_id.generic_args {
198        if let program::GenericArg::Type(ty) = arg {
199            add_ty(ty);
200        }
201    }
202    all_types
203}
204
/// A Sierra program bundled with the debug info that was collected while generating it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SierraProgramWithDebug<'db> {
    /// The Sierra program itself.
    pub program: cairo_lang_sierra::program::Program,
    /// The debug info of `program` (statement locations and per-function debug info).
    pub debug_info: SierraProgramDebugInfo<'db>,
}
210
211unsafe impl<'db> salsa::Update for SierraProgramWithDebug<'db> {
212    unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
213        let old_value = unsafe { &mut *old_pointer };
214        if old_value == &new_value {
215            return false;
216        }
217        *old_value = new_value;
218        true
219    }
220}
221/// Implementation for a debug print of a Sierra program with all locations.
222/// The print is a valid textual Sierra program.
223impl<'db> DebugWithDb<'db> for SierraProgramWithDebug<'db> {
224    type Db = dyn Database;
225
226    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn Database) -> std::fmt::Result {
227        let sierra_program = DebugReplacer { db }.apply(&self.program);
228        for declaration in &sierra_program.type_declarations {
229            writeln!(f, "{declaration};")?;
230        }
231        writeln!(f)?;
232        for declaration in &sierra_program.libfunc_declarations {
233            writeln!(f, "{declaration};")?;
234        }
235        writeln!(f)?;
236        let mut funcs = sierra_program.funcs.iter().peekable();
237        while let Some(func) = funcs.next() {
238            let start = func.entry_point.0;
239            let end = funcs
240                .peek()
241                .map(|f| f.entry_point.0)
242                .unwrap_or_else(|| sierra_program.statements.len());
243            writeln!(f, "// {}:", func.id)?;
244            for param in &func.params {
245                writeln!(f, "//   {param}")?;
246            }
247            for i in start..end {
248                writeln!(f, "{}; // {i}", sierra_program.statements[i])?;
249                if let Some(loc) =
250                    &self.debug_info.statements_locations.locations.get(&StatementIdx(i))
251                {
252                    let loc = get_location_marks(db, &loc.first().unwrap().span_in_file(db), true);
253                    println!("{}", loc.split('\n').map(|l| format!("// {l}")).join("\n"));
254                }
255            }
256        }
257        writeln!(f)?;
258        for func in &sierra_program.funcs {
259            writeln!(f, "{func};")?;
260        }
261        Ok(())
262    }
263}
264
265#[salsa::tracked(returns(ref))]
266pub fn get_sierra_program_for_functions<'db>(
267    db: &'db dyn Database,
268    _tracked: Tracked,
269    requested_function_ids: Vec<ConcreteFunctionWithBodyId<'db>>,
270) -> Maybe<SierraProgramWithDebug<'db>> {
271    let mut functions: Vec<&'db pre_sierra::Function<'_>> = vec![];
272    let mut statements: Vec<pre_sierra::StatementWithLocation<'_>> = vec![];
273    let mut processed_function_ids = UnorderedHashSet::<ConcreteFunctionWithBodyId<'_>>::default();
274    let mut function_id_queue: VecDeque<ConcreteFunctionWithBodyId<'_>> =
275        requested_function_ids.into_iter().collect();
276    while let Some(function_id) = function_id_queue.pop_front() {
277        if !processed_function_ids.insert(function_id) {
278            continue;
279        }
280        let function = db.function_with_body_sierra(function_id)?;
281        functions.push(function);
282        statements.extend_from_slice(&function.body);
283
284        for statement in &function.body {
285            if let Some(related_function_id) = try_get_function_with_body_id(db, statement) {
286                function_id_queue.push_back(related_function_id);
287            }
288        }
289    }
290
291    let AssembledProgram { program, statements_locations, functions_info } =
292        assemble_program(db, functions, statements);
293    Ok(SierraProgramWithDebug {
294        program,
295        debug_info: SierraProgramDebugInfo {
296            statements_locations: StatementsLocations::from_locations_vec(db, statements_locations),
297            functions_info: AllFunctionsDebugInfo::new(functions_info),
298        },
299    })
300}
301
/// Return value of `assemble_program`: the assembled program together with the debug data that
/// was gathered while assembling it.
struct AssembledProgram<'db> {
    /// The actual program.
    program: program::Program,
    /// The locations per statement, as extracted while resolving the labels.
    statements_locations: Vec<Option<LocationId<'db>>>,
    /// The debug info of sierra functions, keyed by the function id.
    functions_info: OrderedHashMap<FunctionId, FunctionDebugInfo<'db>>,
}
311
312/// Given a list of functions and statements, generates a Sierra program.
313/// Returns the program and the locations of the statements in the program.
314fn assemble_program<'db>(
315    db: &dyn Database,
316    functions: Vec<&'db pre_sierra::Function<'db>>,
317    statements: Vec<pre_sierra::StatementWithLocation<'db>>,
318) -> AssembledProgram<'db> {
319    let label_replacer = LabelReplacer::from_statements(&statements);
320    let functions_info = functions
321        .iter()
322        .map(|f| {
323            (
324                f.id.clone(),
325                FunctionDebugInfo {
326                    signature_location: f.signature_location,
327                    variables_locations: f.variable_locations.iter().cloned().collect(),
328                },
329            )
330        })
331        .collect();
332    let funcs = functions
333        .into_iter()
334        .map(|function| {
335            let sierra_signature = db.get_function_signature(function.id.clone()).unwrap();
336            program::Function::new(
337                function.id.clone(),
338                function.parameters.clone(),
339                sierra_signature.ret_types.clone(),
340                label_replacer.handle_label_id(function.entry_point),
341            )
342        })
343        .collect_vec();
344
345    let libfunc_declarations = collect_and_generate_libfunc_declarations(db, &statements);
346    let type_declarations = generate_type_declarations(db, &libfunc_declarations, &funcs);
347    // Resolve labels.
348    let (resolved_statements, statements_locations) =
349        resolve_labels_and_extract_locations(statements, &label_replacer);
350    let program = program::Program {
351        type_declarations,
352        libfunc_declarations,
353        statements: resolved_statements,
354        funcs,
355    };
356    AssembledProgram { program, statements_locations, functions_info }
357}
358
359/// Tries extracting a ConcreteFunctionWithBodyId from a pre-Sierra statement.
360pub fn try_get_function_with_body_id<'db>(
361    db: &'db dyn Database,
362    statement: &pre_sierra::StatementWithLocation<'db>,
363) -> Option<ConcreteFunctionWithBodyId<'db>> {
364    let invc = try_extract_matches!(
365        try_extract_matches!(&statement.statement, pre_sierra::Statement::Sierra)?,
366        program::GenStatement::Invocation
367    )?;
368    let libfunc = db.lookup_concrete_lib_func(&invc.libfunc_id);
369    let inner_function = if libfunc.generic_id == "function_call".into()
370        || libfunc.generic_id == "coupon_call".into()
371    {
372        libfunc.generic_args.first()?.clone()
373    } else if libfunc.generic_id == "coupon_buy".into()
374        || libfunc.generic_id == "coupon_refund".into()
375    {
376        // TODO(lior): Instead of this code, unused coupons should be replaced with the unit type
377        //   or with a zero-valued coupon. Currently, the code is not optimal (since the coupon
378        //   costs more than it should) and some programs may not compile (if the coupon is
379        //   mentioned as a type but not mentioned in any libfuncs).
380        let coupon_ty = try_extract_matches!(
381            libfunc.generic_args.first()?,
382            cairo_lang_sierra::program::GenericArg::Type
383        )?;
384        let coupon_long_id = sierra_concrete_long_id(db, coupon_ty.clone()).unwrap();
385        coupon_long_id.generic_args.first()?.clone()
386    } else {
387        return None;
388    };
389
390    db.lookup_sierra_function(&try_extract_matches!(
391        inner_function,
392        cairo_lang_sierra::program::GenericArg::UserFunc
393    )?)
394    .body(db)
395    .expect("No diagnostics at this stage.")
396}
397
398#[salsa::tracked(returns(ref))]
399pub fn get_sierra_program<'db>(
400    db: &'db dyn Database,
401    _tracked: Tracked,
402    requested_crate_ids: Vec<CrateId<'db>>,
403) -> Maybe<SierraProgramWithDebug<'db>> {
404    let requested_function_ids = find_all_free_function_ids(db, requested_crate_ids)?;
405    db.get_sierra_program_for_functions(requested_function_ids).cloned()
406}
407
408/// Return [`ConcreteFunctionWithBodyId`] for all free functions in the given list of crates.
409pub fn find_all_free_function_ids<'db>(
410    db: &'db dyn Database,
411    requested_crate_ids: Vec<CrateId<'db>>,
412) -> Maybe<Vec<ConcreteFunctionWithBodyId<'db>>> {
413    let mut requested_function_ids = vec![];
414    for crate_id in requested_crate_ids {
415        for module_id in db.crate_modules(crate_id).iter() {
416            for (free_func_id, _) in module_id.module_data(db)?.free_functions(db).iter() {
417                // TODO(spapini): Search Impl functions.
418                if let Some(function) =
419                    ConcreteFunctionWithBodyId::from_no_generics_free(db, *free_func_id)
420                {
421                    requested_function_ids.push(function)
422                }
423            }
424        }
425    }
426    Ok(requested_function_ids)
427}
428
429/// Given `function_id` generates a dummy program with the body of the relevant function
430/// and dummy helper functions that allows the program to be compiled to CASM.
431/// The generated program is not valid, but it can be used to estimate the size of the
432/// relevant function.
433pub fn get_dummy_program_for_size_estimation(
434    db: &dyn Database,
435    function_id: ConcreteFunctionWithBodyId<'_>,
436) -> Maybe<Program> {
437    let function = db.function_with_body_sierra(function_id)?;
438
439    let mut processed_function_ids =
440        UnorderedHashSet::<ConcreteFunctionWithBodyId<'_>>::from_iter([function_id]);
441
442    let mut functions = vec![function];
443
444    for statement in &function.body {
445        if let Some(function_id) = try_get_function_with_body_id(db, statement) {
446            if processed_function_ids.insert(function_id) {
447                functions.push(db.priv_get_dummy_function(function_id)?);
448            }
449        }
450    }
451    // Since we are not interested in the locations, we can remove them from the statements.
452    let statements = functions
453        .iter()
454        .flat_map(|f| f.body.iter())
455        .map(|s| s.statement.clone().into_statement_without_location())
456        .collect();
457
458    Ok(assemble_program(db, functions, statements).program)
459}