// cairo_lang_sierra_generator/utils.rs

1use cairo_lang_debug::DebugWithDb;
2use cairo_lang_defs::ids::NamedLanguageElementId;
3use cairo_lang_diagnostics::Maybe;
4use cairo_lang_filesystem::ids::Tracked;
5use cairo_lang_semantic::items::constant::ConstValueId;
6use cairo_lang_sierra::extensions::const_type::{
7    ConstAsBoxLibfunc, ConstAsImmediateLibfunc, ConstType,
8};
9use cairo_lang_sierra::extensions::core::CoreLibfunc;
10use cairo_lang_sierra::extensions::lib_func::LibfuncSignature;
11use cairo_lang_sierra::extensions::snapshot::SnapshotType;
12use cairo_lang_sierra::extensions::{
13    ExtensionError, GenericLibfuncEx, NamedLibfunc, NamedType, SpecializationError,
14};
15use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId, GenericLibfuncId};
16use cairo_lang_sierra::program::{self, FunctionSignature, GenericArg};
17use cairo_lang_utils::extract_matches;
18use salsa::Database;
19use semantic::items::constant::ConstValue;
20use semantic::items::functions::GenericFunctionId;
21use smol_str::SmolStr;
22use {cairo_lang_defs as defs, cairo_lang_lowering as lowering, cairo_lang_semantic as semantic};
23
24use crate::db::{SierraGenGroup, SierraGeneratorTypeLongId};
25use crate::pre_sierra;
26use crate::replace_ids::{DebugReplacer, SierraIdReplacer};
27use crate::specialization_context::SierraSignatureSpecializationContext;
28
29pub fn simple_basic_statement<'db>(
30    libfunc_id: ConcreteLibfuncId,
31    args: &[cairo_lang_sierra::ids::VarId],
32    results: &[cairo_lang_sierra::ids::VarId],
33) -> pre_sierra::Statement<'db> {
34    pre_sierra::Statement::Sierra(program::GenStatement::Invocation(program::GenInvocation {
35        libfunc_id,
36        args: args.into(),
37        branches: vec![program::GenBranchInfo {
38            target: program::GenBranchTarget::Fallthrough,
39            results: results.into(),
40        }],
41    }))
42}
43
44pub fn simple_statement<'db>(
45    libfunc_id: ConcreteLibfuncId,
46    args: &[cairo_lang_sierra::ids::VarId],
47    results: &[cairo_lang_sierra::ids::VarId],
48) -> pre_sierra::StatementWithLocation<'db> {
49    simple_basic_statement(libfunc_id, args, results).into_statement_without_location()
50}
51
52pub fn jump_statement(
53    jump: ConcreteLibfuncId,
54    label: pre_sierra::LabelId<'_>,
55) -> pre_sierra::Statement<'_> {
56    pre_sierra::Statement::Sierra(program::GenStatement::Invocation(program::GenInvocation {
57        libfunc_id: jump,
58        args: vec![],
59        branches: vec![program::GenBranchInfo {
60            target: program::GenBranchTarget::Statement(label),
61            results: vec![],
62        }],
63    }))
64}
65
66pub fn return_statement<'db>(
67    res: Vec<cairo_lang_sierra::ids::VarId>,
68) -> pre_sierra::Statement<'db> {
69    pre_sierra::Statement::Sierra(program::GenStatement::Return(res))
70}
71
72pub fn get_libfunc_id_with_generic_arg(
73    db: &dyn Database,
74    name: impl Into<SmolStr>,
75    ty: cairo_lang_sierra::ids::ConcreteTypeId,
76) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
77    db.intern_concrete_lib_func(cairo_lang_sierra::program::ConcreteLibfuncLongId {
78        generic_id: cairo_lang_sierra::ids::GenericLibfuncId::from_string(name),
79        generic_args: vec![GenericArg::Type(ty)],
80    })
81}
82
83/// Returns the [cairo_lang_sierra::program::ConcreteLibfuncLongId] associated with `store_temp`.
84pub fn store_temp_libfunc_id(
85    db: &dyn Database,
86    ty: cairo_lang_sierra::ids::ConcreteTypeId,
87) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
88    get_libfunc_id_with_generic_arg(db, "store_temp", ty)
89}
90
91/// Returns the [cairo_lang_sierra::program::ConcreteLibfuncLongId] associated with `store_local`.
92pub fn store_local_libfunc_id(
93    db: &dyn Database,
94    ty: cairo_lang_sierra::ids::ConcreteTypeId,
95) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
96    get_libfunc_id_with_generic_arg(db, "store_local", ty)
97}
98
99pub fn struct_construct_libfunc_id(
100    db: &dyn Database,
101    ty: cairo_lang_sierra::ids::ConcreteTypeId,
102) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
103    get_libfunc_id_with_generic_arg(db, "struct_construct", ty)
104}
105
106pub fn struct_deconstruct_libfunc_id(
107    db: &dyn Database,
108    ty: cairo_lang_sierra::ids::ConcreteTypeId,
109) -> Maybe<cairo_lang_sierra::ids::ConcreteLibfuncId> {
110    let long_id = &db.get_type_info(ty.clone())?.long_id;
111    let is_snapshot = long_id.generic_id == SnapshotType::id();
112    Ok(if is_snapshot {
113        let concrete_enum_type =
114            extract_matches!(&long_id.generic_args[0], GenericArg::Type).clone();
115        get_libfunc_id_with_generic_arg(db, "struct_snapshot_deconstruct", concrete_enum_type)
116    } else {
117        get_libfunc_id_with_generic_arg(db, "struct_deconstruct", ty)
118    })
119}
120
121pub fn enum_init_libfunc_id(
122    db: &dyn Database,
123    ty: cairo_lang_sierra::ids::ConcreteTypeId,
124    variant_idx: usize,
125) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
126    db.intern_concrete_lib_func(cairo_lang_sierra::program::ConcreteLibfuncLongId {
127        generic_id: cairo_lang_sierra::ids::GenericLibfuncId::from_string("enum_init"),
128        generic_args: vec![GenericArg::Type(ty), GenericArg::Value(variant_idx.into())],
129    })
130}
131
132/// Returns the [cairo_lang_sierra::program::ConcreteLibfuncLongId] associated with `snapshot_take`.
133pub fn snapshot_take_libfunc_id(
134    db: &dyn Database,
135    ty: cairo_lang_sierra::ids::ConcreteTypeId,
136) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
137    db.intern_concrete_lib_func(cairo_lang_sierra::program::ConcreteLibfuncLongId {
138        generic_id: cairo_lang_sierra::ids::GenericLibfuncId::from_string("snapshot_take"),
139        generic_args: vec![GenericArg::Type(ty)],
140    })
141}
142
143/// Returns the [cairo_lang_sierra::program::ConcreteLibfuncLongId] associated with `rename`.
144pub fn rename_libfunc_id(
145    db: &dyn Database,
146    ty: cairo_lang_sierra::ids::ConcreteTypeId,
147) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
148    db.intern_concrete_lib_func(cairo_lang_sierra::program::ConcreteLibfuncLongId {
149        generic_id: cairo_lang_sierra::ids::GenericLibfuncId::from_string("rename"),
150        generic_args: vec![GenericArg::Type(ty)],
151    })
152}
153
154fn get_libfunc_id_without_generics(
155    db: &dyn Database,
156    name: impl Into<SmolStr>,
157) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
158    db.intern_concrete_lib_func(cairo_lang_sierra::program::ConcreteLibfuncLongId {
159        generic_id: cairo_lang_sierra::ids::GenericLibfuncId::from_string(name),
160        generic_args: vec![],
161    })
162}
163
164pub fn const_libfunc_id_by_type(
165    db: &dyn Database,
166    value: ConstValueId<'_>,
167    boxed: bool,
168) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
169    let const_ty_arg = GenericArg::Type(const_type_id(db, value));
170    if boxed {
171        db.intern_concrete_lib_func(cairo_lang_sierra::program::ConcreteLibfuncLongId {
172            generic_id: cairo_lang_sierra::ids::GenericLibfuncId::from_string(
173                ConstAsBoxLibfunc::STR_ID,
174            ),
175            generic_args: vec![const_ty_arg, GenericArg::Value(0.into())],
176        })
177    } else {
178        db.intern_concrete_lib_func(cairo_lang_sierra::program::ConcreteLibfuncLongId {
179            generic_id: cairo_lang_sierra::ids::GenericLibfuncId::from_string(
180                ConstAsImmediateLibfunc::STR_ID,
181            ),
182            generic_args: vec![const_ty_arg],
183        })
184    }
185}
186
187/// Returns the [cairo_lang_sierra::ids::ConcreteTypeId] for the given `value`.
188fn const_type_id(
189    db: &dyn Database,
190    value: ConstValueId<'_>,
191) -> cairo_lang_sierra::ids::ConcreteTypeId {
192    let ty = value.ty(db).unwrap();
193    let first_arg = GenericArg::Type(db.get_concrete_type_id(ty).unwrap().clone());
194    db.intern_concrete_type(SierraGeneratorTypeLongId::Regular(
195        cairo_lang_sierra::program::ConcreteTypeLongId {
196            generic_id: ConstType::ID,
197            generic_args: match value.long(db) {
198                ConstValue::Int(v, _) => vec![first_arg, GenericArg::Value(v.clone())],
199                ConstValue::Struct(tys, _) => {
200                    let mut args = vec![first_arg];
201                    for value in tys {
202                        args.push(GenericArg::Type(const_type_id(db, *value)));
203                    }
204                    args
205                }
206                ConstValue::Enum(variant, inner) => {
207                    vec![
208                        first_arg,
209                        GenericArg::Value(variant.idx.into()),
210                        GenericArg::Type(const_type_id(db, *inner)),
211                    ]
212                }
213                ConstValue::NonZero(value) => {
214                    vec![first_arg, GenericArg::Type(const_type_id(db, *value))]
215                }
216                ConstValue::Generic(_)
217                | ConstValue::Var(_, _)
218                | ConstValue::Missing(_)
219                | ConstValue::ImplConstant(_) => {
220                    unreachable!("Should be caught by the lowering.")
221                }
222            },
223        }
224        .into(),
225    ))
226}
227
228pub fn match_enum_libfunc_id(
229    db: &dyn Database,
230    ty: cairo_lang_sierra::ids::ConcreteTypeId,
231) -> Maybe<cairo_lang_sierra::ids::ConcreteLibfuncId> {
232    let long_id = &db.get_type_info(ty.clone())?.long_id;
233    let is_snapshot = long_id.generic_id == SnapshotType::id();
234    Ok(if is_snapshot {
235        let concrete_enum_type =
236            extract_matches!(&long_id.generic_args[0], GenericArg::Type).clone();
237        get_libfunc_id_with_generic_arg(db, "enum_snapshot_match", concrete_enum_type)
238    } else {
239        get_libfunc_id_with_generic_arg(db, "enum_match", ty)
240    })
241}
242
243pub fn enum_from_bounded_int_libfunc_id(
244    db: &dyn Database,
245    ty: cairo_lang_sierra::ids::ConcreteTypeId,
246) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
247    get_libfunc_id_with_generic_arg(db, "enum_from_bounded_int", ty)
248}
249
250pub fn drop_libfunc_id(
251    db: &dyn Database,
252    ty: cairo_lang_sierra::ids::ConcreteTypeId,
253) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
254    get_libfunc_id_with_generic_arg(db, "drop", ty)
255}
256
257pub fn dup_libfunc_id(
258    db: &dyn Database,
259    ty: cairo_lang_sierra::ids::ConcreteTypeId,
260) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
261    get_libfunc_id_with_generic_arg(db, "dup", ty)
262}
263
264pub fn branch_align_libfunc_id(db: &dyn Database) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
265    get_libfunc_id_without_generics(db, "branch_align")
266}
267
268pub fn jump_libfunc_id(db: &dyn Database) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
269    get_libfunc_id_without_generics(db, "jump")
270}
271
272pub fn revoke_ap_tracking_libfunc_id(
273    db: &dyn Database,
274) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
275    get_libfunc_id_without_generics(db, "revoke_ap_tracking")
276}
277
278pub fn enable_ap_tracking_libfunc_id(
279    db: &dyn Database,
280) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
281    get_libfunc_id_without_generics(db, "enable_ap_tracking")
282}
283
284pub fn disable_ap_tracking_libfunc_id(
285    db: &dyn Database,
286) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
287    get_libfunc_id_without_generics(db, "disable_ap_tracking")
288}
289
290pub fn alloc_local_libfunc_id(
291    db: &dyn Database,
292    ty: cairo_lang_sierra::ids::ConcreteTypeId,
293) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
294    get_libfunc_id_with_generic_arg(db, "alloc_local", ty)
295}
296
297pub fn finalize_locals_libfunc_id(db: &dyn Database) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
298    get_libfunc_id_without_generics(db, "finalize_locals")
299}
300
301/// Returns the [LibfuncSignature] of the given function.
302pub fn get_libfunc_signature<'db>(
303    db: &'db dyn Database,
304    concrete_lib_func_id: &ConcreteLibfuncId,
305) -> &'db LibfuncSignature {
306    get_libfunc_signature_tracked(db, (), concrete_lib_func_id.id)
307}
308
309/// Implementation of [get_libfunc_signature] that is tracked.
310#[salsa::tracked(returns(ref))]
311fn get_libfunc_signature_tracked(
312    db: &dyn Database,
313    _tracked: Tracked,
314    concrete_lib_func_id: u64,
315) -> LibfuncSignature {
316    let libfunc_long_id = db.lookup_concrete_lib_func(&concrete_lib_func_id.into());
317    CoreLibfunc::specialize_signature_by_id(
318        &SierraSignatureSpecializationContext(db),
319        &libfunc_long_id.generic_id,
320        &libfunc_long_id.generic_args,
321    )
322    .unwrap_or_else(|err| {
323        if let ExtensionError::LibfuncSpecialization {
324            error: SpecializationError::MissingFunction(function),
325            ..
326        } = err
327        {
328            let function = db.lookup_sierra_function(&function);
329            panic!("Missing function {:?}", function.debug(db));
330        }
331        // If panic happens here, make sure the specified libfunc name is in one of the STR_IDs of
332        // the libfuncs in the [`CoreLibfunc`] structured enum.
333        panic!(
334            "Failed to specialize: `{}`. Error: {err}",
335            DebugReplacer { db }.replace_libfunc_id(&concrete_lib_func_id.into())
336        )
337    })
338}
339
340/// Returns the [ConcreteLibfuncId] for calling a user-defined function.
341pub fn function_call_libfunc_id(
342    db: &dyn Database,
343    func: lowering::ids::FunctionId<'_>,
344) -> ConcreteLibfuncId {
345    db.intern_concrete_lib_func(cairo_lang_sierra::program::ConcreteLibfuncLongId {
346        generic_id: GenericLibfuncId::from_string("function_call"),
347        generic_args: vec![GenericArg::UserFunc(db.intern_sierra_function(func))],
348    })
349}
350
351/// Returns the [ConcreteLibfuncId] for calling a user-defined function, given a coupon for that
352/// function.
353pub fn coupon_call_libfunc_id(
354    db: &dyn Database,
355    func: lowering::ids::FunctionId<'_>,
356) -> ConcreteLibfuncId {
357    db.intern_concrete_lib_func(cairo_lang_sierra::program::ConcreteLibfuncLongId {
358        generic_id: GenericLibfuncId::from_string("coupon_call"),
359        generic_args: vec![GenericArg::UserFunc(db.intern_sierra_function(func))],
360    })
361}
362
363/// Returns the [ConcreteLibfuncId] used for calling a libfunc.
364pub fn generic_libfunc_id(
365    db: &dyn Database,
366    extern_id: defs::ids::ExternFunctionId<'_>,
367    generic_args: Vec<GenericArg>,
368) -> ConcreteLibfuncId {
369    db.intern_concrete_lib_func(cairo_lang_sierra::program::ConcreteLibfuncLongId {
370        generic_id: GenericLibfuncId::from_string(extern_id.name(db).long(db).clone()),
371        generic_args,
372    })
373}
374
375/// Returns the [ConcreteLibfuncId] used for calling a function (either user-defined or libfunc).
376pub fn get_concrete_libfunc_id<'db>(
377    db: &'db dyn Database,
378    function: lowering::ids::FunctionId<'db>,
379    with_coupon: bool,
380) -> (Option<lowering::ids::ConcreteFunctionWithBodyId<'db>>, ConcreteLibfuncId) {
381    // Check if this is a user-defined function or a libfunc.
382    if let Some(body) = function.body(db).expect("No diagnostics at this stage.") {
383        if with_coupon {
384            return (Some(body), coupon_call_libfunc_id(db, function));
385        } else {
386            return (Some(body), function_call_libfunc_id(db, function));
387        }
388    }
389
390    assert!(!with_coupon, "Coupon cannot be used with extern functions.");
391
392    let semantic = extract_matches!(function.long(db), lowering::ids::FunctionLongId::Semantic);
393    let concrete_function = semantic.long(db).function.clone();
394    let GenericFunctionId::Extern(extern_id) = concrete_function.generic_function else {
395        panic!("Expected an extern function, found: {:?}", concrete_function.full_path(db));
396    };
397
398    let mut generic_args = vec![];
399    for generic_arg in &concrete_function.generic_args {
400        match generic_arg {
401            semantic::GenericArgumentId::Type(ty) => {
402                // TODO(lior): How should the following unwrap() be handled?
403                generic_args.push(GenericArg::Type(
404                    db.get_concrete_type_id(*ty)
405                        .unwrap_or_else(|_| {
406                            panic!(
407                                "Failed to obtain concrete type id for generic type argument: \
408                                 {ty:?}"
409                            )
410                        })
411                        .clone(),
412                ))
413            }
414            semantic::GenericArgumentId::Constant(value_id) => {
415                generic_args.push(GenericArg::Value(
416                    value_id.long(db).to_int().expect("Expected ConstValue::Int for size").clone(),
417                ));
418            }
419            semantic::GenericArgumentId::Impl(_) | semantic::GenericArgumentId::NegImpl(_) => {
420                // Everything after an impl generic is ignored as it does not exist in Sierra.
421                // This may still be used in high level code for getting type information that is
422                // otherwise concluded by the sierra-to-casm compiler, or addition of `where` clause
423                // style blocks.
424                break;
425            }
426        };
427    }
428
429    (None, generic_libfunc_id(db, extern_id, generic_args))
430}
431
432// Given a function id, generates a dummy function call libfunc id.
433pub fn dummy_call_libfunc_id(
434    db: &dyn Database,
435    function_id: cairo_lang_sierra::ids::FunctionId,
436    sierra_signature: &FunctionSignature,
437) -> ConcreteLibfuncId {
438    let ap_change = db
439        .get_ap_change(db.lookup_sierra_function(&function_id).body(db).unwrap().unwrap())
440        .unwrap();
441
442    let mut gargs = vec![];
443
444    gargs.push(GenericArg::UserFunc(function_id.clone()));
445
446    gargs.push(GenericArg::Value(
447        match ap_change {
448            cairo_lang_sierra::extensions::lib_func::SierraApChange::Unknown => 1,
449            cairo_lang_sierra::extensions::lib_func::SierraApChange::Known { new_vars_only: _ } => {
450                0
451            }
452            cairo_lang_sierra::extensions::lib_func::SierraApChange::BranchAlign
453            | cairo_lang_sierra::extensions::lib_func::SierraApChange::FunctionCall(_) => {
454                unreachable!("should never happen")
455            }
456        }
457        .into(),
458    ));
459
460    let as_generic_arg = |ty: &ConcreteTypeId| GenericArg::Type(ty.clone());
461
462    gargs.push(GenericArg::Value(sierra_signature.param_types.len().into()));
463    gargs.extend(sierra_signature.param_types.iter().map(as_generic_arg));
464
465    gargs.push(GenericArg::Value(sierra_signature.ret_types.len().into()));
466    gargs.extend(sierra_signature.ret_types.iter().map(as_generic_arg));
467
468    db.intern_concrete_lib_func(cairo_lang_sierra::program::ConcreteLibfuncLongId {
469        generic_id: cairo_lang_sierra::ids::GenericLibfuncId::from_string("dummy_function_call"),
470        generic_args: gargs,
471    })
472}