cairo_lang_sierra_generator/
program_generator.rs

1use std::collections::VecDeque;
2use std::sync::Arc;
3
4use cairo_lang_debug::DebugWithDb;
5use cairo_lang_diagnostics::{Maybe, get_location_marks};
6use cairo_lang_filesystem::ids::CrateId;
7use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
8use cairo_lang_sierra::extensions::GenericLibfuncEx;
9use cairo_lang_sierra::extensions::core::CoreLibfunc;
10use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId};
11use cairo_lang_sierra::program::{self, DeclaredTypeInfo, StatementIdx};
12use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
13use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
14use cairo_lang_utils::{LookupIntern, try_extract_matches};
15use itertools::{Itertools, chain};
16
17use crate::db::{SierraGenGroup, sierra_concrete_long_id};
18use crate::extra_sierra_info::type_has_const_size;
19use crate::pre_sierra;
20use crate::replace_ids::{DebugReplacer, SierraIdReplacer};
21use crate::resolve_labels::{LabelReplacer, resolve_labels_and_extract_locations};
22use crate::specialization_context::SierraSignatureSpecializationContext;
23use crate::statements_locations::StatementsLocations;
24
25#[cfg(test)]
26#[path = "program_generator_test.rs"]
27mod test;
28
29/// Generates the list of [cairo_lang_sierra::program::LibfuncDeclaration] for the given list of
30/// [ConcreteLibfuncId].
31fn generate_libfunc_declarations<'a>(
32    db: &dyn SierraGenGroup,
33    libfuncs: impl Iterator<Item = &'a ConcreteLibfuncId>,
34) -> Vec<program::LibfuncDeclaration> {
35    libfuncs
36        .into_iter()
37        .map(|libfunc_id| program::LibfuncDeclaration {
38            id: libfunc_id.clone(),
39            long_id: libfunc_id.lookup_intern(db),
40        })
41        .collect()
42}
43
44/// Collects the set of all [ConcreteLibfuncId] used in the given list of [pre_sierra::Statement].
45fn collect_used_libfuncs(
46    statements: &[pre_sierra::StatementWithLocation],
47) -> OrderedHashSet<ConcreteLibfuncId> {
48    statements
49        .iter()
50        .filter_map(|statement| match &statement.statement {
51            pre_sierra::Statement::Sierra(program::GenStatement::Invocation(invocation)) => {
52                Some(invocation.libfunc_id.clone())
53            }
54            pre_sierra::Statement::Sierra(program::GenStatement::Return(_))
55            | pre_sierra::Statement::Label(_) => None,
56            pre_sierra::Statement::PushValues(_) => {
57                panic!("Unexpected pre_sierra::Statement::PushValues in collect_used_libfuncs().")
58            }
59        })
60        .collect()
61}
62
63/// Generates the list of [cairo_lang_sierra::program::TypeDeclaration] for the given list of
64/// [ConcreteTypeId].
65fn generate_type_declarations(
66    db: &dyn SierraGenGroup,
67    mut remaining_types: OrderedHashSet<ConcreteTypeId>,
68) -> Vec<program::TypeDeclaration> {
69    let mut declarations = vec![];
70    let mut already_declared = UnorderedHashSet::default();
71    while let Some(ty) = remaining_types.iter().next().cloned() {
72        remaining_types.swap_remove(&ty);
73        generate_type_declarations_helper(
74            db,
75            &ty,
76            &mut declarations,
77            &mut remaining_types,
78            &mut already_declared,
79        );
80    }
81    declarations
82}
83
84/// Helper to ensure declaring types ordered in such a way that no type appears before types it
85/// depends on for knowing its size.
86/// `remaining_types` are types that will later be checked.
87/// We may add types to there if we are not sure their dependencies are already declared.
88fn generate_type_declarations_helper(
89    db: &dyn SierraGenGroup,
90    ty: &ConcreteTypeId,
91    declarations: &mut Vec<program::TypeDeclaration>,
92    remaining_types: &mut OrderedHashSet<ConcreteTypeId>,
93    already_declared: &mut UnorderedHashSet<ConcreteTypeId>,
94) {
95    if already_declared.contains(ty) {
96        return;
97    }
98    let long_id = sierra_concrete_long_id(db, ty.clone()).unwrap();
99    already_declared.insert(ty.clone());
100    let inner_tys = long_id
101        .generic_args
102        .iter()
103        .filter_map(|arg| try_extract_matches!(arg, program::GenericArg::Type));
104    // Making sure we order the types such that types that require others to know their size are
105    // after the required types. Types that always have a known size would be first.
106    if type_has_const_size(&long_id.generic_id) {
107        remaining_types.extend(inner_tys.cloned());
108    } else {
109        for inner_ty in inner_tys {
110            generate_type_declarations_helper(
111                db,
112                inner_ty,
113                declarations,
114                remaining_types,
115                already_declared,
116            );
117        }
118    }
119
120    let type_info = db.get_type_info(ty.clone()).unwrap();
121    declarations.push(program::TypeDeclaration {
122        id: ty.clone(),
123        long_id: long_id.as_ref().clone(),
124        declared_type_info: Some(DeclaredTypeInfo {
125            storable: type_info.storable,
126            droppable: type_info.droppable,
127            duplicatable: type_info.duplicatable,
128            zero_sized: type_info.zero_sized,
129        }),
130    });
131}
132
133/// Collects the set of all [ConcreteTypeId] that are used in the given lists of
134/// [program::LibfuncDeclaration] and user functions.
135fn collect_used_types(
136    db: &dyn SierraGenGroup,
137    libfunc_declarations: &[program::LibfuncDeclaration],
138    functions: &[Arc<pre_sierra::Function>],
139) -> OrderedHashSet<ConcreteTypeId> {
140    // Collect types that appear in libfuncs.
141    let types_in_libfuncs = libfunc_declarations.iter().flat_map(|libfunc| {
142        // TODO(orizi): replace expect() with a diagnostic (unless this can never happen).
143        let signature = CoreLibfunc::specialize_signature_by_id(
144                &SierraSignatureSpecializationContext(db),
145                &libfunc.long_id.generic_id,
146                &libfunc.long_id.generic_args,
147            )
148            // If panic happens here, make sure the specified libfunc name is in one of the STR_IDs of
149            // the libfuncs in the [`CoreLibfunc`] structured enum.
150            .unwrap_or_else(|err| panic!("Failed to specialize: `{}`. Error: {err}",
151                DebugReplacer { db }.replace_libfunc_id(&libfunc.id)));
152        chain!(
153            signature.param_signatures.into_iter().map(|param_signature| param_signature.ty),
154            signature.branch_signatures.into_iter().flat_map(|info| info.vars).map(|var| var.ty),
155            libfunc.long_id.generic_args.iter().filter_map(|arg| match arg {
156                program::GenericArg::Type(ty) => Some(ty.clone()),
157                _ => None,
158            })
159        )
160    });
161
162    // Collect types that appear in user functions.
163    // This is only relevant for types that are arguments to entry points and are not used in
164    // any libfunc. For example, an empty entry point that gets and returns an empty struct, will
165    // have no libfuncs, but we still need to declare the struct.
166    let types_in_user_functions = functions.iter().flat_map(|func| {
167        chain!(func.parameters.iter().map(|param| param.ty.clone()), func.ret_types.iter().cloned())
168    });
169
170    chain!(types_in_libfuncs, types_in_user_functions).collect()
171}
172
173#[derive(Clone, Debug, Eq, PartialEq)]
174pub struct SierraProgramWithDebug {
175    pub program: cairo_lang_sierra::program::Program,
176    pub debug_info: SierraProgramDebugInfo,
177}
178/// Implementation for a debug print of a Sierra program with all locations.
179/// The print is a valid textual Sierra program.
180impl DebugWithDb<dyn SierraGenGroup> for SierraProgramWithDebug {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn SierraGenGroup) -> std::fmt::Result {
182        let sierra_program = DebugReplacer { db }.apply(&self.program);
183        for declaration in &sierra_program.type_declarations {
184            writeln!(f, "{declaration};")?;
185        }
186        writeln!(f)?;
187        for declaration in &sierra_program.libfunc_declarations {
188            writeln!(f, "{declaration};")?;
189        }
190        writeln!(f)?;
191        let sierra_program = DebugReplacer { db }.apply(&self.program);
192        let mut funcs = sierra_program.funcs.iter().peekable();
193        while let Some(func) = funcs.next() {
194            let start = func.entry_point.0;
195            let end = funcs
196                .peek()
197                .map(|f| f.entry_point.0)
198                .unwrap_or_else(|| sierra_program.statements.len());
199            writeln!(f, "// {}:", func.id)?;
200            for param in &func.params {
201                writeln!(f, "//   {}", param)?;
202            }
203            for i in start..end {
204                writeln!(f, "{}; // {i}", sierra_program.statements[i])?;
205                if let Some(loc) =
206                    &self.debug_info.statements_locations.locations.get(&StatementIdx(i))
207                {
208                    let loc = get_location_marks(
209                        db.upcast(),
210                        &loc.first().unwrap().diagnostic_location(db.upcast()),
211                        true,
212                    );
213                    println!("{}", loc.split('\n').map(|l| format!("// {l}")).join("\n"));
214                }
215            }
216        }
217        writeln!(f)?;
218        for func in &sierra_program.funcs {
219            writeln!(f, "{func};")?;
220        }
221        Ok(())
222    }
223}
224
225#[derive(Clone, Debug, Eq, PartialEq)]
226pub struct SierraProgramDebugInfo {
227    pub statements_locations: StatementsLocations,
228}
229
230pub fn get_sierra_program_for_functions(
231    db: &dyn SierraGenGroup,
232    requested_function_ids: Vec<ConcreteFunctionWithBodyId>,
233) -> Maybe<Arc<SierraProgramWithDebug>> {
234    let mut functions: Vec<Arc<pre_sierra::Function>> = vec![];
235    let mut statements: Vec<pre_sierra::StatementWithLocation> = vec![];
236    let mut processed_function_ids = UnorderedHashSet::<ConcreteFunctionWithBodyId>::default();
237    let mut function_id_queue: VecDeque<ConcreteFunctionWithBodyId> =
238        requested_function_ids.into_iter().collect();
239    while let Some(function_id) = function_id_queue.pop_front() {
240        if !processed_function_ids.insert(function_id) {
241            continue;
242        }
243        let function: Arc<pre_sierra::Function> = db.function_with_body_sierra(function_id)?;
244        functions.push(function.clone());
245        statements.extend_from_slice(&function.body);
246
247        for statement in &function.body {
248            if let Some(related_function_id) = try_get_function_with_body_id(db, statement) {
249                function_id_queue.push_back(related_function_id);
250            }
251        }
252    }
253
254    let libfunc_declarations =
255        generate_libfunc_declarations(db, collect_used_libfuncs(&statements).iter());
256    let type_declarations =
257        generate_type_declarations(db, collect_used_types(db, &libfunc_declarations, &functions));
258    // Resolve labels.
259    let label_replacer = LabelReplacer::from_statements(&statements);
260    let (resolved_statements, statements_locations) =
261        resolve_labels_and_extract_locations(statements, &label_replacer);
262
263    let program = program::Program {
264        type_declarations,
265        libfunc_declarations,
266        statements: resolved_statements,
267        funcs: functions
268            .into_iter()
269            .map(|function| {
270                let sierra_signature = db.get_function_signature(function.id.clone()).unwrap();
271                program::Function::new(
272                    function.id.clone(),
273                    function.parameters.clone(),
274                    sierra_signature.ret_types.clone(),
275                    label_replacer.handle_label_id(function.entry_point),
276                )
277            })
278            .collect(),
279    };
280    Ok(Arc::new(SierraProgramWithDebug {
281        program,
282        debug_info: SierraProgramDebugInfo {
283            statements_locations: StatementsLocations::from_locations_vec(&statements_locations),
284        },
285    }))
286}
287
288/// Tries extracting a ConcreteFunctionWithBodyId from a pre-Sierra statement.
289pub fn try_get_function_with_body_id(
290    db: &dyn SierraGenGroup,
291    statement: &pre_sierra::StatementWithLocation,
292) -> Option<ConcreteFunctionWithBodyId> {
293    let invc = try_extract_matches!(
294        try_extract_matches!(&statement.statement, pre_sierra::Statement::Sierra)?,
295        program::GenStatement::Invocation
296    )?;
297    let libfunc = invc.libfunc_id.lookup_intern(db);
298    let inner_function = if libfunc.generic_id == "function_call".into()
299        || libfunc.generic_id == "coupon_call".into()
300    {
301        libfunc.generic_args.first()?.clone()
302    } else if libfunc.generic_id == "coupon_buy".into()
303        || libfunc.generic_id == "coupon_refund".into()
304    {
305        // TODO(lior): Instead of this code, unused coupons should be replaced with the unit type
306        //   or with a zero-valued coupon. Currently, the code is not optimal (since the coupon
307        //   costs more than it should) and some programs may not compile (if the coupon is
308        //   mentioned as a type but not mentioned in any libfuncs).
309        let coupon_ty = try_extract_matches!(
310            libfunc.generic_args.first()?,
311            cairo_lang_sierra::program::GenericArg::Type
312        )?;
313        let coupon_long_id = sierra_concrete_long_id(db, coupon_ty.clone()).unwrap();
314        coupon_long_id.generic_args.first()?.clone()
315    } else {
316        return None;
317    };
318
319    try_extract_matches!(inner_function, cairo_lang_sierra::program::GenericArg::UserFunc)?
320        .lookup_intern(db)
321        .body(db.upcast())
322        .expect("No diagnostics at this stage.")
323}
324
325pub fn get_sierra_program(
326    db: &dyn SierraGenGroup,
327    requested_crate_ids: Vec<CrateId>,
328) -> Maybe<Arc<SierraProgramWithDebug>> {
329    let mut requested_function_ids = vec![];
330    for crate_id in requested_crate_ids {
331        for module_id in db.crate_modules(crate_id).iter() {
332            for (free_func_id, _) in db.module_free_functions(*module_id)?.iter() {
333                // TODO(spapini): Search Impl functions.
334                if let Some(function) =
335                    ConcreteFunctionWithBodyId::from_no_generics_free(db.upcast(), *free_func_id)
336                {
337                    requested_function_ids.push(function)
338                }
339            }
340        }
341    }
342    db.get_sierra_program_for_functions(requested_function_ids)
343}