cairo_lang_sierra_generator/
program_generator.rs

1use std::collections::VecDeque;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_defs::db::DefsGroup;
5use cairo_lang_diagnostics::Maybe;
6use cairo_lang_filesystem::ids::{CrateId, Tracked};
7use cairo_lang_filesystem::location_marks::get_location_marks;
8use cairo_lang_lowering::ids::{ConcreteFunctionWithBodyId, LocationId};
9use cairo_lang_sierra::extensions::GenericLibfuncEx;
10use cairo_lang_sierra::extensions::core::CoreLibfunc;
11use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId, VarId};
12use cairo_lang_sierra::program::{self, DeclaredTypeInfo, Program, StatementIdx};
13use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
14use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
15use cairo_lang_utils::try_extract_matches;
16use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
17use itertools::{Itertools, chain};
18use salsa::Database;
19
20use crate::db::{SierraGenGroup, sierra_concrete_long_id};
21use crate::extra_sierra_info::type_has_const_size;
22use crate::pre_sierra;
23use crate::replace_ids::{DebugReplacer, SierraIdReplacer};
24use crate::resolve_labels::{LabelReplacer, resolve_labels_and_extract_locations};
25use crate::specialization_context::SierraSignatureSpecializationContext;
26use crate::statements_locations::StatementsLocations;
27
28#[cfg(test)]
29#[path = "program_generator_test.rs"]
30mod test;
31
32/// Generates the list of [cairo_lang_sierra::program::LibfuncDeclaration] for the given list of
33/// [pre_sierra::StatementWithLocation].
34fn collect_and_generate_libfunc_declarations<'db>(
35    db: &dyn Database,
36    statements: &[pre_sierra::StatementWithLocation<'db>],
37) -> Vec<program::LibfuncDeclaration> {
38    let mut declared_libfuncs = UnorderedHashSet::<ConcreteLibfuncId>::default();
39    statements
40        .iter()
41        .filter_map(|statement| match &statement.statement {
42            pre_sierra::Statement::Sierra(program::GenStatement::Invocation(invocation)) => {
43                declared_libfuncs.insert(invocation.libfunc_id.clone()).then(|| {
44                    program::LibfuncDeclaration {
45                        id: invocation.libfunc_id.clone(),
46                        long_id: db.lookup_concrete_lib_func(&invocation.libfunc_id),
47                    }
48                })
49            }
50            pre_sierra::Statement::Sierra(program::GenStatement::Return(_))
51            | pre_sierra::Statement::Label(_) => None,
52            pre_sierra::Statement::PushValues(_) => {
53                panic!("Unexpected pre_sierra::Statement::PushValues in collect_used_libfuncs().")
54            }
55        })
56        .collect()
57}
58
59/// Generates the list of [cairo_lang_sierra::program::TypeDeclaration] for the given list of
60/// [ConcreteTypeId].
61fn generate_type_declarations(
62    db: &dyn Database,
63    libfunc_declarations: &[program::LibfuncDeclaration],
64    functions: &[program::Function],
65) -> Vec<program::TypeDeclaration> {
66    let mut declarations = vec![];
67    let mut already_declared = UnorderedHashSet::default();
68    let mut remaining_types = collect_used_types(db, libfunc_declarations, functions);
69    while let Some(ty) = remaining_types.iter().next().cloned() {
70        remaining_types.swap_remove(&ty);
71        generate_type_declarations_helper(
72            db,
73            &ty,
74            &mut declarations,
75            &mut remaining_types,
76            &mut already_declared,
77        );
78    }
79    declarations
80}
81
82/// Helper to ensure declaring types ordered in such a way that no type appears before types it
83/// depends on for knowing its size.
84/// `remaining_types` are types that will later be checked.
85/// We may add types to there if we are not sure their dependencies are already declared.
86fn generate_type_declarations_helper(
87    db: &dyn Database,
88    ty: &ConcreteTypeId,
89    declarations: &mut Vec<program::TypeDeclaration>,
90    remaining_types: &mut OrderedHashSet<ConcreteTypeId>,
91    already_declared: &mut UnorderedHashSet<ConcreteTypeId>,
92) {
93    if already_declared.contains(ty) {
94        return;
95    }
96    let long_id = sierra_concrete_long_id(db, ty.clone()).unwrap();
97    already_declared.insert(ty.clone());
98    let inner_tys = long_id
99        .generic_args
100        .iter()
101        .filter_map(|arg| try_extract_matches!(arg, program::GenericArg::Type));
102    // Making sure we order the types such that types that require others to know their size are
103    // after the required types. Types that always have a known size would be first.
104    if type_has_const_size(&long_id.generic_id) {
105        remaining_types.extend(inner_tys.cloned());
106    } else {
107        for inner_ty in inner_tys {
108            generate_type_declarations_helper(
109                db,
110                inner_ty,
111                declarations,
112                remaining_types,
113                already_declared,
114            );
115        }
116    }
117
118    let type_info = db.get_type_info(ty.clone()).unwrap();
119    declarations.push(program::TypeDeclaration {
120        id: ty.clone(),
121        long_id: long_id.as_ref().clone(),
122        declared_type_info: Some(DeclaredTypeInfo {
123            storable: type_info.storable,
124            droppable: type_info.droppable,
125            duplicatable: type_info.duplicatable,
126            zero_sized: type_info.zero_sized,
127        }),
128    });
129}
130
131/// Collects the set of all [ConcreteTypeId] that are used in the given lists of
132/// [program::LibfuncDeclaration] and user functions.
133fn collect_used_types(
134    db: &dyn Database,
135    libfunc_declarations: &[program::LibfuncDeclaration],
136    functions: &[program::Function],
137) -> OrderedHashSet<ConcreteTypeId> {
138    let mut all_types = OrderedHashSet::default();
139    // Collect types that appear in libfuncs.
140    for libfunc in libfunc_declarations {
141        let types = db.priv_libfunc_dependencies(libfunc.id.clone());
142        all_types.extend(types.iter().cloned());
143    }
144
145    // Gather types used in user-defined functions.
146    // This is necessary for types that are used as entry point arguments but do not appear in any
147    // libfunc. For instance, if an entry point takes and returns an empty struct and no
148    // libfuncs are involved, we still need to declare that struct.
149    // Additionally, we include the return types of functions, since with unsafe panic enabled,
150    // a function that always panics might declare a return type that does not appear in anywhere
151    // else in the program.
152    all_types.extend(
153        functions.iter().flat_map(|func| {
154            chain!(&func.signature.param_types, &func.signature.ret_types).cloned()
155        }),
156    );
157    all_types
158}
159
160/// Query implementation of [SierraGenGroup::priv_libfunc_dependencies].
161#[salsa::tracked(returns(ref))]
162pub fn priv_libfunc_dependencies(
163    db: &dyn Database,
164    _tracked: Tracked,
165    libfunc_id: ConcreteLibfuncId,
166) -> Vec<ConcreteTypeId> {
167    let long_id = db.lookup_concrete_lib_func(&libfunc_id);
168    let signature = CoreLibfunc::specialize_signature_by_id(
169        &SierraSignatureSpecializationContext(db),
170        &long_id.generic_id,
171        &long_id.generic_args,
172    )
173    // If panic happens here, make sure the specified libfunc name is in one of the STR_IDs of
174    // the libfuncs in the [`CoreLibfunc`] structured enum.
175    .unwrap_or_else(|err| panic!("Failed to specialize: `{}`. Error: {err}",
176        DebugReplacer { db }.replace_libfunc_id(&libfunc_id)));
177    // Collecting types as a vector since the set should be very small.
178    let mut all_types = vec![];
179    let mut add_ty = |ty: ConcreteTypeId| {
180        if !all_types.contains(&ty) {
181            all_types.push(ty);
182        }
183    };
184    for param_signature in signature.param_signatures {
185        add_ty(param_signature.ty);
186    }
187    for info in signature.branch_signatures {
188        for var in info.vars {
189            add_ty(var.ty);
190        }
191    }
192    for arg in long_id.generic_args {
193        if let program::GenericArg::Type(ty) = arg {
194            add_ty(ty);
195        }
196    }
197    all_types
198}
199
200#[derive(Clone, Debug, Eq, PartialEq)]
201pub struct SierraProgramWithDebug<'db> {
202    pub program: cairo_lang_sierra::program::Program,
203    pub debug_info: SierraProgramDebugInfo<'db>,
204}
205
206unsafe impl<'db> salsa::Update for SierraProgramWithDebug<'db> {
207    unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
208        let old_value = unsafe { &mut *old_pointer };
209        if old_value == &new_value {
210            return false;
211        }
212        *old_value = new_value;
213        true
214    }
215}
216/// Implementation for a debug print of a Sierra program with all locations.
217/// The print is a valid textual Sierra program.
218impl<'db> DebugWithDb<'db> for SierraProgramWithDebug<'db> {
219    type Db = dyn Database;
220
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn Database) -> std::fmt::Result {
222        let sierra_program = DebugReplacer { db }.apply(&self.program);
223        for declaration in &sierra_program.type_declarations {
224            writeln!(f, "{declaration};")?;
225        }
226        writeln!(f)?;
227        for declaration in &sierra_program.libfunc_declarations {
228            writeln!(f, "{declaration};")?;
229        }
230        writeln!(f)?;
231        let mut funcs = sierra_program.funcs.iter().peekable();
232        while let Some(func) = funcs.next() {
233            let start = func.entry_point.0;
234            let end = funcs
235                .peek()
236                .map(|f| f.entry_point.0)
237                .unwrap_or_else(|| sierra_program.statements.len());
238            writeln!(f, "// {}:", func.id)?;
239            for param in &func.params {
240                writeln!(f, "//   {param}")?;
241            }
242            for i in start..end {
243                writeln!(f, "{}; // {i}", sierra_program.statements[i])?;
244                if let Some(loc) =
245                    &self.debug_info.statements_locations.locations.get(&StatementIdx(i))
246                {
247                    let loc = get_location_marks(db, &loc.first().unwrap().span_in_file(db), true);
248                    println!("{}", loc.split('\n').map(|l| format!("// {l}")).join("\n"));
249                }
250            }
251        }
252        writeln!(f)?;
253        for func in &sierra_program.funcs {
254            writeln!(f, "{func};")?;
255        }
256        Ok(())
257    }
258}
259
260#[derive(Clone, Debug, Eq, PartialEq)]
261pub struct SierraProgramDebugInfo<'db> {
262    pub statements_locations: StatementsLocations<'db>,
263    pub variable_location: OrderedHashMap<FunctionId, OrderedHashMap<VarId, LocationId<'db>>>,
264}
265
266#[salsa::tracked(returns(ref))]
267pub fn get_sierra_program_for_functions<'db>(
268    db: &'db dyn Database,
269    _tracked: Tracked,
270    requested_function_ids: Vec<ConcreteFunctionWithBodyId<'db>>,
271) -> Maybe<SierraProgramWithDebug<'db>> {
272    let mut functions: Vec<&'db pre_sierra::Function<'_>> = vec![];
273    let mut statements: Vec<pre_sierra::StatementWithLocation<'_>> = vec![];
274    let mut processed_function_ids = UnorderedHashSet::<ConcreteFunctionWithBodyId<'_>>::default();
275    let mut function_id_queue: VecDeque<ConcreteFunctionWithBodyId<'_>> =
276        requested_function_ids.into_iter().collect();
277    while let Some(function_id) = function_id_queue.pop_front() {
278        if !processed_function_ids.insert(function_id) {
279            continue;
280        }
281        let function = db.function_with_body_sierra(function_id)?;
282        functions.push(function);
283        statements.extend_from_slice(&function.body);
284
285        for statement in &function.body {
286            if let Some(related_function_id) = try_get_function_with_body_id(db, statement) {
287                function_id_queue.push_back(related_function_id);
288            }
289        }
290    }
291
292    let AssembledProgram { program, statements_locations, variable_location } =
293        assemble_program(db, functions, statements);
294    Ok(SierraProgramWithDebug {
295        program,
296        debug_info: SierraProgramDebugInfo {
297            statements_locations: StatementsLocations::from_locations_vec(db, statements_locations),
298            variable_location,
299        },
300    })
301}
302
303/// Return value of `assemble_program`.
304struct AssembledProgram<'db> {
305    /// The actual program.
306    program: program::Program,
307    /// The locations per statement.
308    statements_locations: Vec<Option<LocationId<'db>>>,
309    /// The locations of variables per function.
310    variable_location: OrderedHashMap<FunctionId, OrderedHashMap<VarId, LocationId<'db>>>,
311}
312
313/// Given a list of functions and statements, generates a Sierra program.
314/// Returns the program and the locations of the statements in the program.
315fn assemble_program<'db>(
316    db: &dyn Database,
317    functions: Vec<&'db pre_sierra::Function<'db>>,
318    statements: Vec<pre_sierra::StatementWithLocation<'db>>,
319) -> AssembledProgram<'db> {
320    let label_replacer = LabelReplacer::from_statements(&statements);
321    let variable_location = functions
322        .iter()
323        .map(|f| (f.id.clone(), f.variable_locations.iter().cloned().collect()))
324        .collect();
325    let funcs = functions
326        .into_iter()
327        .map(|function| {
328            let sierra_signature = db.get_function_signature(function.id.clone()).unwrap();
329            program::Function::new(
330                function.id.clone(),
331                function.parameters.clone(),
332                sierra_signature.ret_types.clone(),
333                label_replacer.handle_label_id(function.entry_point),
334            )
335        })
336        .collect_vec();
337
338    let libfunc_declarations = collect_and_generate_libfunc_declarations(db, &statements);
339    let type_declarations = generate_type_declarations(db, &libfunc_declarations, &funcs);
340    // Resolve labels.
341    let (resolved_statements, statements_locations) =
342        resolve_labels_and_extract_locations(statements, &label_replacer);
343    let program = program::Program {
344        type_declarations,
345        libfunc_declarations,
346        statements: resolved_statements,
347        funcs,
348    };
349    AssembledProgram { program, statements_locations, variable_location }
350}
351
352/// Tries extracting a ConcreteFunctionWithBodyId from a pre-Sierra statement.
353pub fn try_get_function_with_body_id<'db>(
354    db: &'db dyn Database,
355    statement: &pre_sierra::StatementWithLocation<'db>,
356) -> Option<ConcreteFunctionWithBodyId<'db>> {
357    let invc = try_extract_matches!(
358        try_extract_matches!(&statement.statement, pre_sierra::Statement::Sierra)?,
359        program::GenStatement::Invocation
360    )?;
361    let libfunc = db.lookup_concrete_lib_func(&invc.libfunc_id);
362    let inner_function = if libfunc.generic_id == "function_call".into()
363        || libfunc.generic_id == "coupon_call".into()
364    {
365        libfunc.generic_args.first()?.clone()
366    } else if libfunc.generic_id == "coupon_buy".into()
367        || libfunc.generic_id == "coupon_refund".into()
368    {
369        // TODO(lior): Instead of this code, unused coupons should be replaced with the unit type
370        //   or with a zero-valued coupon. Currently, the code is not optimal (since the coupon
371        //   costs more than it should) and some programs may not compile (if the coupon is
372        //   mentioned as a type but not mentioned in any libfuncs).
373        let coupon_ty = try_extract_matches!(
374            libfunc.generic_args.first()?,
375            cairo_lang_sierra::program::GenericArg::Type
376        )?;
377        let coupon_long_id = sierra_concrete_long_id(db, coupon_ty.clone()).unwrap();
378        coupon_long_id.generic_args.first()?.clone()
379    } else {
380        return None;
381    };
382
383    db.lookup_sierra_function(&try_extract_matches!(
384        inner_function,
385        cairo_lang_sierra::program::GenericArg::UserFunc
386    )?)
387    .body(db)
388    .expect("No diagnostics at this stage.")
389}
390
391#[salsa::tracked(returns(ref))]
392pub fn get_sierra_program<'db>(
393    db: &'db dyn Database,
394    _tracked: Tracked,
395    requested_crate_ids: Vec<CrateId<'db>>,
396) -> Maybe<SierraProgramWithDebug<'db>> {
397    let requested_function_ids = find_all_free_function_ids(db, requested_crate_ids)?;
398    db.get_sierra_program_for_functions(requested_function_ids).cloned()
399}
400
401/// Return [`ConcreteFunctionWithBodyId`] for all free functions in the given list of crates.
402pub fn find_all_free_function_ids<'db>(
403    db: &'db dyn Database,
404    requested_crate_ids: Vec<CrateId<'db>>,
405) -> Maybe<Vec<ConcreteFunctionWithBodyId<'db>>> {
406    let mut requested_function_ids = vec![];
407    for crate_id in requested_crate_ids {
408        for module_id in db.crate_modules(crate_id).iter() {
409            for (free_func_id, _) in module_id.module_data(db)?.free_functions(db).iter() {
410                // TODO(spapini): Search Impl functions.
411                if let Some(function) =
412                    ConcreteFunctionWithBodyId::from_no_generics_free(db, *free_func_id)
413                {
414                    requested_function_ids.push(function)
415                }
416            }
417        }
418    }
419    Ok(requested_function_ids)
420}
421
422/// Given `function_id` generates a dummy program with the body of the relevant function
423/// and dummy helper functions that allows the program to be compiled to CASM.
424/// The generated program is not valid, but it can be used to estimate the size of the
425/// relevant function.
426pub fn get_dummy_program_for_size_estimation(
427    db: &dyn Database,
428    function_id: ConcreteFunctionWithBodyId<'_>,
429) -> Maybe<Program> {
430    let function = db.function_with_body_sierra(function_id)?;
431
432    let mut processed_function_ids =
433        UnorderedHashSet::<ConcreteFunctionWithBodyId<'_>>::from_iter([function_id]);
434
435    let mut functions = vec![function];
436
437    for statement in &function.body {
438        if let Some(function_id) = try_get_function_with_body_id(db, statement) {
439            if processed_function_ids.insert(function_id) {
440                functions.push(db.priv_get_dummy_function(function_id)?);
441            }
442        }
443    }
444    // Since we are not interested in the locations, we can remove them from the statements.
445    let statements = functions
446        .iter()
447        .flat_map(|f| f.body.iter())
448        .map(|s| s.statement.clone().into_statement_without_location())
449        .collect();
450
451    Ok(assemble_program(db, functions, statements).program)
452}