cairo_lang_sierra_generator/
program_generator.rs

1use std::collections::VecDeque;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_defs::db::DefsGroup;
5use cairo_lang_defs::diagnostic_utils::StableLocation;
6use cairo_lang_diagnostics::{Maybe, get_location_marks};
7use cairo_lang_filesystem::ids::{CrateId, Tracked};
8use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
9use cairo_lang_sierra::extensions::GenericLibfuncEx;
10use cairo_lang_sierra::extensions::core::CoreLibfunc;
11use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId};
12use cairo_lang_sierra::program::{self, DeclaredTypeInfo, Program, StatementIdx};
13use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
14use cairo_lang_utils::try_extract_matches;
15use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
16use itertools::{Itertools, chain};
17use salsa::Database;
18
19use crate::db::{SierraGenGroup, sierra_concrete_long_id};
20use crate::extra_sierra_info::type_has_const_size;
21use crate::pre_sierra;
22use crate::replace_ids::{DebugReplacer, SierraIdReplacer};
23use crate::resolve_labels::{LabelReplacer, resolve_labels_and_extract_locations};
24use crate::specialization_context::SierraSignatureSpecializationContext;
25use crate::statements_locations::StatementsLocations;
26
27#[cfg(test)]
28#[path = "program_generator_test.rs"]
29mod test;
30
31/// Generates the list of [cairo_lang_sierra::program::LibfuncDeclaration] for the given list of
32/// [pre_sierra::StatementWithLocation].
33fn collect_and_generate_libfunc_declarations<'db>(
34    db: &dyn Database,
35    statements: &[pre_sierra::StatementWithLocation<'db>],
36) -> Vec<program::LibfuncDeclaration> {
37    let mut declared_libfuncs = UnorderedHashSet::<ConcreteLibfuncId>::default();
38    statements
39        .iter()
40        .filter_map(|statement| match &statement.statement {
41            pre_sierra::Statement::Sierra(program::GenStatement::Invocation(invocation)) => {
42                declared_libfuncs.insert(invocation.libfunc_id.clone()).then(|| {
43                    program::LibfuncDeclaration {
44                        id: invocation.libfunc_id.clone(),
45                        long_id: db.lookup_concrete_lib_func(&invocation.libfunc_id),
46                    }
47                })
48            }
49            pre_sierra::Statement::Sierra(program::GenStatement::Return(_))
50            | pre_sierra::Statement::Label(_) => None,
51            pre_sierra::Statement::PushValues(_) => {
52                panic!("Unexpected pre_sierra::Statement::PushValues in collect_used_libfuncs().")
53            }
54        })
55        .collect()
56}
57
58/// Generates the list of [cairo_lang_sierra::program::TypeDeclaration] for the given list of
59/// [ConcreteTypeId].
60fn generate_type_declarations(
61    db: &dyn Database,
62    libfunc_declarations: &[program::LibfuncDeclaration],
63    functions: &[program::Function],
64) -> Vec<program::TypeDeclaration> {
65    let mut declarations = vec![];
66    let mut already_declared = UnorderedHashSet::default();
67    let mut remaining_types = collect_used_types(db, libfunc_declarations, functions);
68    while let Some(ty) = remaining_types.iter().next().cloned() {
69        remaining_types.swap_remove(&ty);
70        generate_type_declarations_helper(
71            db,
72            &ty,
73            &mut declarations,
74            &mut remaining_types,
75            &mut already_declared,
76        );
77    }
78    declarations
79}
80
81/// Helper to ensure declaring types ordered in such a way that no type appears before types it
82/// depends on for knowing its size.
83/// `remaining_types` are types that will later be checked.
84/// We may add types to there if we are not sure their dependencies are already declared.
85fn generate_type_declarations_helper(
86    db: &dyn Database,
87    ty: &ConcreteTypeId,
88    declarations: &mut Vec<program::TypeDeclaration>,
89    remaining_types: &mut OrderedHashSet<ConcreteTypeId>,
90    already_declared: &mut UnorderedHashSet<ConcreteTypeId>,
91) {
92    if already_declared.contains(ty) {
93        return;
94    }
95    let long_id = sierra_concrete_long_id(db, ty.clone()).unwrap();
96    already_declared.insert(ty.clone());
97    let inner_tys = long_id
98        .generic_args
99        .iter()
100        .filter_map(|arg| try_extract_matches!(arg, program::GenericArg::Type));
101    // Making sure we order the types such that types that require others to know their size are
102    // after the required types. Types that always have a known size would be first.
103    if type_has_const_size(&long_id.generic_id) {
104        remaining_types.extend(inner_tys.cloned());
105    } else {
106        for inner_ty in inner_tys {
107            generate_type_declarations_helper(
108                db,
109                inner_ty,
110                declarations,
111                remaining_types,
112                already_declared,
113            );
114        }
115    }
116
117    let type_info = db.get_type_info(ty.clone()).unwrap();
118    declarations.push(program::TypeDeclaration {
119        id: ty.clone(),
120        long_id: long_id.as_ref().clone(),
121        declared_type_info: Some(DeclaredTypeInfo {
122            storable: type_info.storable,
123            droppable: type_info.droppable,
124            duplicatable: type_info.duplicatable,
125            zero_sized: type_info.zero_sized,
126        }),
127    });
128}
129
130/// Collects the set of all [ConcreteTypeId] that are used in the given lists of
131/// [program::LibfuncDeclaration] and user functions.
132fn collect_used_types(
133    db: &dyn Database,
134    libfunc_declarations: &[program::LibfuncDeclaration],
135    functions: &[program::Function],
136) -> OrderedHashSet<ConcreteTypeId> {
137    let mut all_types = OrderedHashSet::default();
138    // Collect types that appear in libfuncs.
139    for libfunc in libfunc_declarations {
140        let types = db.priv_libfunc_dependencies(libfunc.id.clone());
141        all_types.extend(types.iter().cloned());
142    }
143
144    // Gather types used in user-defined functions.
145    // This is necessary for types that are used as entry point arguments but do not appear in any
146    // libfunc. For instance, if an entry point takes and returns an empty struct and no
147    // libfuncs are involved, we still need to declare that struct.
148    // Additionally, we include the return types of functions, since with unsafe panic enabled,
149    // a function that always panics might declare a return type that does not appear in anywhere
150    // else in the program.
151    all_types.extend(
152        functions.iter().flat_map(|func| {
153            chain!(&func.signature.param_types, &func.signature.ret_types).cloned()
154        }),
155    );
156    all_types
157}
158
159/// Query implementation of [SierraGenGroup::priv_libfunc_dependencies].
160#[salsa::tracked(returns(ref))]
161pub fn priv_libfunc_dependencies(
162    db: &dyn Database,
163    _tracked: Tracked,
164    libfunc_id: ConcreteLibfuncId,
165) -> Vec<ConcreteTypeId> {
166    let long_id = db.lookup_concrete_lib_func(&libfunc_id);
167    let signature = CoreLibfunc::specialize_signature_by_id(
168        &SierraSignatureSpecializationContext(db),
169        &long_id.generic_id,
170        &long_id.generic_args,
171    )
172    // If panic happens here, make sure the specified libfunc name is in one of the STR_IDs of
173    // the libfuncs in the [`CoreLibfunc`] structured enum.
174    .unwrap_or_else(|err| panic!("Failed to specialize: `{}`. Error: {err}",
175        DebugReplacer { db }.replace_libfunc_id(&libfunc_id)));
176    // Collecting types as a vector since the set should be very small.
177    let mut all_types = vec![];
178    let mut add_ty = |ty: ConcreteTypeId| {
179        if !all_types.contains(&ty) {
180            all_types.push(ty);
181        }
182    };
183    for param_signature in signature.param_signatures {
184        add_ty(param_signature.ty);
185    }
186    for info in signature.branch_signatures {
187        for var in info.vars {
188            add_ty(var.ty);
189        }
190    }
191    for arg in long_id.generic_args {
192        if let program::GenericArg::Type(ty) = arg {
193            add_ty(ty);
194        }
195    }
196    all_types
197}
198
/// A Sierra program bundled with its debug information.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SierraProgramWithDebug<'db> {
    /// The generated Sierra program.
    pub program: cairo_lang_sierra::program::Program,
    /// Debug information (statement locations) for `program`.
    pub debug_info: SierraProgramDebugInfo<'db>,
}
204
// NOTE(review): `salsa::Update` is an unsafe trait. This impl decides whether the stored value
// changed purely by `PartialEq` — confirm that this satisfies the trait's safety contract.
unsafe impl<'db> salsa::Update for SierraProgramWithDebug<'db> {
    /// Replaces the value behind `old_pointer` with `new_value`, unless the two compare equal.
    ///
    /// Returns `true` if the stored value was overwritten and `false` if it was left untouched.
    ///
    /// # Safety
    /// `old_pointer` is dereferenced as a mutable reference, so it must be valid for reads and
    /// writes of `Self` for the duration of the call.
    unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
        // SAFETY: relies on the caller passing a valid, exclusive pointer (see `# Safety`).
        let old_value = unsafe { &mut *old_pointer };
        // Equal values are not rewritten, and the caller is told that nothing changed.
        if old_value == &new_value {
            return false;
        }
        *old_value = new_value;
        true
    }
}
215/// Implementation for a debug print of a Sierra program with all locations.
216/// The print is a valid textual Sierra program.
217impl<'db> DebugWithDb<'db> for SierraProgramWithDebug<'db> {
218    type Db = dyn Database;
219
220    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn Database) -> std::fmt::Result {
221        let sierra_program = DebugReplacer { db }.apply(&self.program);
222        for declaration in &sierra_program.type_declarations {
223            writeln!(f, "{declaration};")?;
224        }
225        writeln!(f)?;
226        for declaration in &sierra_program.libfunc_declarations {
227            writeln!(f, "{declaration};")?;
228        }
229        writeln!(f)?;
230        let mut funcs = sierra_program.funcs.iter().peekable();
231        while let Some(func) = funcs.next() {
232            let start = func.entry_point.0;
233            let end = funcs
234                .peek()
235                .map(|f| f.entry_point.0)
236                .unwrap_or_else(|| sierra_program.statements.len());
237            writeln!(f, "// {}:", func.id)?;
238            for param in &func.params {
239                writeln!(f, "//   {param}")?;
240            }
241            for i in start..end {
242                writeln!(f, "{}; // {i}", sierra_program.statements[i])?;
243                if let Some(loc) =
244                    &self.debug_info.statements_locations.locations.get(&StatementIdx(i))
245                {
246                    let loc =
247                        get_location_marks(db, &loc.first().unwrap().diagnostic_location(db), true);
248                    println!("{}", loc.split('\n').map(|l| format!("// {l}")).join("\n"));
249                }
250            }
251        }
252        writeln!(f)?;
253        for func in &sierra_program.funcs {
254            writeln!(f, "{func};")?;
255        }
256        Ok(())
257    }
258}
259
/// Debug information that accompanies a generated Sierra program.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SierraProgramDebugInfo<'db> {
    /// The locations associated with the statements of the program.
    pub statements_locations: StatementsLocations<'db>,
}
264
265#[salsa::tracked(returns(ref))]
266pub fn get_sierra_program_for_functions<'db>(
267    db: &'db dyn Database,
268    _tracked: Tracked,
269    requested_function_ids: Vec<ConcreteFunctionWithBodyId<'db>>,
270) -> Maybe<SierraProgramWithDebug<'db>> {
271    let mut functions: Vec<&'db pre_sierra::Function<'_>> = vec![];
272    let mut statements: Vec<pre_sierra::StatementWithLocation<'_>> = vec![];
273    let mut processed_function_ids = UnorderedHashSet::<ConcreteFunctionWithBodyId<'_>>::default();
274    let mut function_id_queue: VecDeque<ConcreteFunctionWithBodyId<'_>> =
275        requested_function_ids.into_iter().collect();
276    while let Some(function_id) = function_id_queue.pop_front() {
277        if !processed_function_ids.insert(function_id) {
278            continue;
279        }
280        let function = db.function_with_body_sierra(function_id)?;
281        functions.push(function);
282        statements.extend_from_slice(&function.body);
283
284        for statement in &function.body {
285            if let Some(related_function_id) = try_get_function_with_body_id(db, statement) {
286                function_id_queue.push_back(related_function_id);
287            }
288        }
289    }
290
291    let (program, statements_locations) = assemble_program(db, functions, statements);
292    Ok(SierraProgramWithDebug {
293        program,
294        debug_info: SierraProgramDebugInfo {
295            statements_locations: StatementsLocations::from_locations_vec(&statements_locations),
296        },
297    })
298}
299
300/// Given a list of functions and statements, generates a Sierra program.
301/// Returns the program and the locations of the statements in the program.
302pub fn assemble_program<'db>(
303    db: &dyn Database,
304    functions: Vec<&'db pre_sierra::Function<'db>>,
305    statements: Vec<pre_sierra::StatementWithLocation<'db>>,
306) -> (program::Program, Vec<Vec<StableLocation<'db>>>) {
307    let label_replacer = LabelReplacer::from_statements(&statements);
308    let funcs = functions
309        .into_iter()
310        .map(|function| {
311            let sierra_signature = db.get_function_signature(function.id.clone()).unwrap();
312            program::Function::new(
313                function.id.clone(),
314                function.parameters.clone(),
315                sierra_signature.ret_types.clone(),
316                label_replacer.handle_label_id(function.entry_point),
317            )
318        })
319        .collect_vec();
320
321    let libfunc_declarations = collect_and_generate_libfunc_declarations(db, &statements);
322    let type_declarations = generate_type_declarations(db, &libfunc_declarations, &funcs);
323    // Resolve labels.
324    let (resolved_statements, statements_locations) =
325        resolve_labels_and_extract_locations(statements, &label_replacer);
326    let program = program::Program {
327        type_declarations,
328        libfunc_declarations,
329        statements: resolved_statements,
330        funcs,
331    };
332    (program, statements_locations)
333}
334
335/// Tries extracting a ConcreteFunctionWithBodyId from a pre-Sierra statement.
336pub fn try_get_function_with_body_id<'db>(
337    db: &'db dyn Database,
338    statement: &pre_sierra::StatementWithLocation<'db>,
339) -> Option<ConcreteFunctionWithBodyId<'db>> {
340    let invc = try_extract_matches!(
341        try_extract_matches!(&statement.statement, pre_sierra::Statement::Sierra)?,
342        program::GenStatement::Invocation
343    )?;
344    let libfunc = db.lookup_concrete_lib_func(&invc.libfunc_id);
345    let inner_function = if libfunc.generic_id == "function_call".into()
346        || libfunc.generic_id == "coupon_call".into()
347    {
348        libfunc.generic_args.first()?.clone()
349    } else if libfunc.generic_id == "coupon_buy".into()
350        || libfunc.generic_id == "coupon_refund".into()
351    {
352        // TODO(lior): Instead of this code, unused coupons should be replaced with the unit type
353        //   or with a zero-valued coupon. Currently, the code is not optimal (since the coupon
354        //   costs more than it should) and some programs may not compile (if the coupon is
355        //   mentioned as a type but not mentioned in any libfuncs).
356        let coupon_ty = try_extract_matches!(
357            libfunc.generic_args.first()?,
358            cairo_lang_sierra::program::GenericArg::Type
359        )?;
360        let coupon_long_id = sierra_concrete_long_id(db, coupon_ty.clone()).unwrap();
361        coupon_long_id.generic_args.first()?.clone()
362    } else {
363        return None;
364    };
365
366    db.lookup_sierra_function(&try_extract_matches!(
367        inner_function,
368        cairo_lang_sierra::program::GenericArg::UserFunc
369    )?)
370    .body(db)
371    .expect("No diagnostics at this stage.")
372}
373
374#[salsa::tracked(returns(ref))]
375pub fn get_sierra_program<'db>(
376    db: &'db dyn Database,
377    _tracked: Tracked,
378    requested_crate_ids: Vec<CrateId<'db>>,
379) -> Maybe<SierraProgramWithDebug<'db>> {
380    let requested_function_ids = find_all_free_function_ids(db, requested_crate_ids)?;
381    db.get_sierra_program_for_functions(requested_function_ids).cloned()
382}
383
384/// Return [`ConcreteFunctionWithBodyId`] for all free functions in the given list of crates.
385pub fn find_all_free_function_ids<'db>(
386    db: &'db dyn Database,
387    requested_crate_ids: Vec<CrateId<'db>>,
388) -> Maybe<Vec<ConcreteFunctionWithBodyId<'db>>> {
389    let mut requested_function_ids = vec![];
390    for crate_id in requested_crate_ids {
391        for module_id in db.crate_modules(crate_id).iter() {
392            for (free_func_id, _) in module_id.module_data(db)?.free_functions(db).iter() {
393                // TODO(spapini): Search Impl functions.
394                if let Some(function) =
395                    ConcreteFunctionWithBodyId::from_no_generics_free(db, *free_func_id)
396                {
397                    requested_function_ids.push(function)
398                }
399            }
400        }
401    }
402    Ok(requested_function_ids)
403}
404
405/// Given `function_id` generates a dummy program with the body of the relevant function
406/// and dummy helper functions that allows the program to be compiled to CASM.
407/// The generated program is not valid, but it can be used to estimate the size of the
408/// relevant function.
409pub fn get_dummy_program_for_size_estimation(
410    db: &dyn Database,
411    function_id: ConcreteFunctionWithBodyId<'_>,
412) -> Maybe<Program> {
413    let function = db.function_with_body_sierra(function_id)?;
414
415    let mut processed_function_ids =
416        UnorderedHashSet::<ConcreteFunctionWithBodyId<'_>>::from_iter([function_id]);
417
418    let mut functions = vec![function];
419
420    for statement in &function.body {
421        if let Some(function_id) = try_get_function_with_body_id(db, statement) {
422            if processed_function_ids.insert(function_id) {
423                functions.push(db.priv_get_dummy_function(function_id)?);
424            }
425        }
426    }
427    // Since we are not interested in the locations, we can remove them from the statements.
428    let statements = functions
429        .iter()
430        .flat_map(|f| f.body.iter())
431        .map(|s| s.statement.clone().into_statement_without_location())
432        .collect();
433
434    let (program, _statements_locations) = assemble_program(db, functions, statements);
435
436    Ok(program)
437}