cairo_lang_sierra/
program_registry.rs

1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use itertools::{chain, izip};
5use thiserror::Error;
6
7use crate::extensions::lib_func::{
8    SierraApChange, SignatureSpecializationContext, SpecializationContext,
9};
10use crate::extensions::type_specialization_context::TypeSpecializationContext;
11use crate::extensions::types::TypeInfo;
12use crate::extensions::{
13    ConcreteLibfunc, ConcreteType, ExtensionError, GenericLibfunc, GenericLibfuncEx, GenericType,
14    GenericTypeEx,
15};
16use crate::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId, GenericTypeId};
17use crate::program::{
18    BranchTarget, DeclaredTypeInfo, Function, FunctionSignature, GenericArg, Program, Statement,
19    StatementIdx, TypeDeclaration,
20};
21
22#[cfg(test)]
23#[path = "program_registry_test.rs"]
24mod test;
25
26/// Errors encountered in the program registry.
27#[derive(Error, Debug, Eq, PartialEq)]
28pub enum ProgramRegistryError {
29    #[error("used the same function id twice `{0}`.")]
30    FunctionIdAlreadyExists(FunctionId),
31    #[error("Could not find the requested function `{0}`.")]
32    MissingFunction(FunctionId),
33    #[error("Error during type specialization of `{concrete_id}`: {error}")]
34    TypeSpecialization { concrete_id: ConcreteTypeId, error: ExtensionError },
35    #[error("Used concrete type id `{0}` twice")]
36    TypeConcreteIdAlreadyExists(ConcreteTypeId),
37    #[error("Declared concrete type `{0}` twice")]
38    TypeAlreadyDeclared(Box<TypeDeclaration>),
39    #[error("Could not find requested type `{0}`.")]
40    MissingType(ConcreteTypeId),
41    #[error("Error during libfunc specialization of {concrete_id}: {error}")]
42    LibfuncSpecialization { concrete_id: ConcreteLibfuncId, error: ExtensionError },
43    #[error("Used concrete libfunc id `{0}` twice.")]
44    LibfuncConcreteIdAlreadyExists(ConcreteLibfuncId),
45    #[error("Could not find requested libfunc `{0}`.")]
46    MissingLibfunc(ConcreteLibfuncId),
47    #[error("Type info declaration mismatch for `{0}`.")]
48    TypeInfoDeclarationMismatch(ConcreteTypeId),
49    #[error("Function `{func_id}`'s parameter type `{ty}` is not storable.")]
50    FunctionWithUnstorableType { func_id: FunctionId, ty: ConcreteTypeId },
51    #[error("Function `{0}` points to non existing entry point statement.")]
52    FunctionNonExistingEntryPoint(FunctionId),
53    #[error("#{0}: Libfunc invocation input count mismatch")]
54    LibfuncInvocationInputCountMismatch(StatementIdx),
55    #[error("#{0}: Libfunc invocation branch count mismatch")]
56    LibfuncInvocationBranchCountMismatch(StatementIdx),
57    #[error("#{0}: Libfunc invocation branch #{1} result count mismatch")]
58    LibfuncInvocationBranchResultCountMismatch(StatementIdx, usize),
59    #[error("#{0}: Libfunc invocation branch #{1} target mismatch")]
60    LibfuncInvocationBranchTargetMismatch(StatementIdx, usize),
61    #[error("#{src}: Branch jump backwards to {dst}")]
62    BranchBackwards { src: StatementIdx, dst: StatementIdx },
63    #[error("#{src}: Branch jump to a non-branch align statement #{dst}")]
64    BranchNotToBranchAlign { src: StatementIdx, dst: StatementIdx },
65    #[error("#{src1}, #{src2}: Jump to the same statement #{dst}")]
66    MultipleJumpsToSameStatement { src1: StatementIdx, src2: StatementIdx, dst: StatementIdx },
67    #[error("#{0}: Jump out of range")]
68    JumpOutOfRange(StatementIdx),
69    #[error("Type size computation failed for `{ty}`: missing size information for `{dep}`")]
70    TypeSizeDependencyMissing { ty: ConcreteTypeId, dep: ConcreteTypeId },
71    #[error("Type size computation failed for `{0}`: size overflow.")]
72    TypeSizeOverflow(ConcreteTypeId),
73}
74
75type TypeMap<TType> = HashMap<ConcreteTypeId, TType>;
76type LibfuncMap<TLibfunc> = HashMap<ConcreteLibfuncId, TLibfunc>;
77type FunctionMap = HashMap<FunctionId, Function>;
78/// Mapping from the arguments for generating a concrete type (the generic-id and the arguments) to
79/// the concrete-id that points to it.
80type ConcreteTypeIdMap<'a> = HashMap<(GenericTypeId, &'a [GenericArg]), ConcreteTypeId>;
81
82/// Registry for the data of the compiler, for all program specific data.
83pub struct ProgramRegistry<TType: GenericType, TLibfunc: GenericLibfunc> {
84    /// Mapping ids to the corresponding user function declaration from the program.
85    functions: FunctionMap,
86    /// Mapping ids to the concrete types represented by them.
87    concrete_types: TypeMap<TType::Concrete>,
88    /// Mapping ids to the concrete libfuncs represented by them.
89    concrete_libfuncs: LibfuncMap<TLibfunc::Concrete>,
90}
91impl<TType: GenericType, TLibfunc: GenericLibfunc> ProgramRegistry<TType, TLibfunc> {
92    /// Create a registry for the program.
93    pub fn new(
94        program: &Program,
95    ) -> Result<ProgramRegistry<TType, TLibfunc>, Box<ProgramRegistryError>> {
96        let functions = get_functions(program)?;
97        let (concrete_types, concrete_type_ids) = get_concrete_types_maps::<TType>(program)?;
98        let concrete_libfuncs = get_concrete_libfuncs::<TType, TLibfunc>(
99            program,
100            &SpecializationContextForRegistry {
101                functions: &functions,
102                concrete_type_ids: &concrete_type_ids,
103                concrete_types: &concrete_types,
104            },
105        )?;
106        let registry = ProgramRegistry { functions, concrete_types, concrete_libfuncs };
107        registry.validate(program)?;
108        Ok(registry)
109    }
110
111    /// Gets a function from the input program.
112    pub fn get_function<'a>(
113        &'a self,
114        id: &FunctionId,
115    ) -> Result<&'a Function, Box<ProgramRegistryError>> {
116        self.functions
117            .get(id)
118            .ok_or_else(|| Box::new(ProgramRegistryError::MissingFunction(id.clone())))
119    }
120    /// Gets a type from the input program.
121    pub fn get_type<'a>(
122        &'a self,
123        id: &ConcreteTypeId,
124    ) -> Result<&'a TType::Concrete, Box<ProgramRegistryError>> {
125        self.concrete_types
126            .get(id)
127            .ok_or_else(|| Box::new(ProgramRegistryError::MissingType(id.clone())))
128    }
129    /// Gets a libfunc from the input program.
130    pub fn get_libfunc<'a>(
131        &'a self,
132        id: &ConcreteLibfuncId,
133    ) -> Result<&'a TLibfunc::Concrete, Box<ProgramRegistryError>> {
134        self.concrete_libfuncs
135            .get(id)
136            .ok_or_else(|| Box::new(ProgramRegistryError::MissingLibfunc(id.clone())))
137    }
138
139    /// Checks the validity of the [ProgramRegistry] and runs validations on the program.
140    ///
141    /// Later compilation stages may perform more validations as well as repeat these validations.
142    fn validate(&self, program: &Program) -> Result<(), Box<ProgramRegistryError>> {
143        // Check that all the parameter and return types are storable.
144        for func in self.functions.values() {
145            for ty in chain!(func.signature.param_types.iter(), func.signature.ret_types.iter()) {
146                if !self.get_type(ty)?.info().storable {
147                    return Err(Box::new(ProgramRegistryError::FunctionWithUnstorableType {
148                        func_id: func.id.clone(),
149                        ty: ty.clone(),
150                    }));
151                }
152            }
153            if func.entry_point.0 >= program.statements.len() {
154                return Err(Box::new(ProgramRegistryError::FunctionNonExistingEntryPoint(
155                    func.id.clone(),
156                )));
157            }
158        }
159        // A branch map, mapping from a destination statement to the statement that jumps to it.
160        // A branch is considered a branch only if it has more than one target.
161        // Assuming branches into branch alignments only, this should be a bijection.
162        let mut branches: HashMap<StatementIdx, StatementIdx> =
163            HashMap::<StatementIdx, StatementIdx>::default();
164        for (i, statement) in program.statements.iter().enumerate() {
165            self.validate_statement(program, StatementIdx(i), statement, &mut branches)?;
166        }
167        Ok(())
168    }
169
170    /// Checks the validity of a statement.
171    fn validate_statement(
172        &self,
173        program: &Program,
174        index: StatementIdx,
175        statement: &Statement,
176        branches: &mut HashMap<StatementIdx, StatementIdx>,
177    ) -> Result<(), Box<ProgramRegistryError>> {
178        let Statement::Invocation(invocation) = statement else {
179            return Ok(());
180        };
181        let libfunc = self.get_libfunc(&invocation.libfunc_id)?;
182        if invocation.args.len() != libfunc.param_signatures().len() {
183            return Err(Box::new(ProgramRegistryError::LibfuncInvocationInputCountMismatch(index)));
184        }
185        let libfunc_branches = libfunc.branch_signatures();
186        if invocation.branches.len() != libfunc_branches.len() {
187            return Err(Box::new(ProgramRegistryError::LibfuncInvocationBranchCountMismatch(
188                index,
189            )));
190        }
191        let libfunc_fallthrough = libfunc.fallthrough();
192        for (branch_index, (invocation_branch, libfunc_branch)) in
193            izip!(&invocation.branches, libfunc_branches).enumerate()
194        {
195            if invocation_branch.results.len() != libfunc_branch.vars.len() {
196                return Err(Box::new(
197                    ProgramRegistryError::LibfuncInvocationBranchResultCountMismatch(
198                        index,
199                        branch_index,
200                    ),
201                ));
202            }
203            if matches!(libfunc_fallthrough, Some(target) if target == branch_index)
204                != (invocation_branch.target == BranchTarget::Fallthrough)
205            {
206                return Err(Box::new(ProgramRegistryError::LibfuncInvocationBranchTargetMismatch(
207                    index,
208                    branch_index,
209                )));
210            }
211            if !matches!(libfunc_branch.ap_change, SierraApChange::BranchAlign)
212                && let Some(prev) = branches.get(&index)
213            {
214                return Err(Box::new(ProgramRegistryError::BranchNotToBranchAlign {
215                    src: *prev,
216                    dst: index,
217                }));
218            }
219            let next = index.next(&invocation_branch.target);
220            if next.0 >= program.statements.len() {
221                return Err(Box::new(ProgramRegistryError::JumpOutOfRange(index)));
222            }
223            if libfunc_branches.len() > 1 {
224                if next.0 < index.0 {
225                    return Err(Box::new(ProgramRegistryError::BranchBackwards {
226                        src: index,
227                        dst: next,
228                    }));
229                }
230                match branches.entry(next) {
231                    Entry::Occupied(e) => {
232                        return Err(Box::new(ProgramRegistryError::MultipleJumpsToSameStatement {
233                            src1: *e.get(),
234                            src2: index,
235                            dst: next,
236                        }));
237                    }
238                    Entry::Vacant(e) => {
239                        e.insert(index);
240                    }
241                }
242            }
243        }
244        Ok(())
245    }
246}
247
248/// Creates the functions map.
249fn get_functions(program: &Program) -> Result<FunctionMap, Box<ProgramRegistryError>> {
250    let mut functions = FunctionMap::new();
251    for func in &program.funcs {
252        match functions.entry(func.id.clone()) {
253            Entry::Occupied(_) => {
254                Err(ProgramRegistryError::FunctionIdAlreadyExists(func.id.clone()))
255            }
256            Entry::Vacant(entry) => Ok(entry.insert(func.clone())),
257        }?;
258    }
259    Ok(functions)
260}
261
262struct TypeSpecializationContextForRegistry<'a, TType: GenericType> {
263    pub concrete_types: &'a TypeMap<TType::Concrete>,
264    pub declared_type_info: &'a TypeMap<TypeInfo>,
265}
266impl<TType: GenericType> TypeSpecializationContext
267    for TypeSpecializationContextForRegistry<'_, TType>
268{
269    fn try_get_type_info(&self, id: ConcreteTypeId) -> Option<TypeInfo> {
270        self.declared_type_info
271            .get(&id)
272            .or_else(|| self.concrete_types.get(&id).map(|ty| ty.info()))
273            .cloned()
274    }
275}
276
277/// Creates the type-id to concrete type map, and the reverse map from generic-id and arguments to
278/// concrete-id.
279fn get_concrete_types_maps<TType: GenericType>(
280    program: &Program,
281) -> Result<(TypeMap<TType::Concrete>, ConcreteTypeIdMap<'_>), Box<ProgramRegistryError>> {
282    let mut concrete_types = HashMap::new();
283    let mut concrete_type_ids = HashMap::<(GenericTypeId, &[GenericArg]), ConcreteTypeId>::new();
284    let declared_type_info = program
285        .type_declarations
286        .iter()
287        .filter_map(|declaration| {
288            let TypeDeclaration { id, long_id, declared_type_info } = declaration;
289            let DeclaredTypeInfo { storable, droppable, duplicatable, zero_sized } =
290                declared_type_info.as_ref().cloned()?;
291            Some((
292                id.clone(),
293                TypeInfo {
294                    long_id: long_id.clone(),
295                    storable,
296                    droppable,
297                    duplicatable,
298                    zero_sized,
299                },
300            ))
301        })
302        .collect();
303    for declaration in &program.type_declarations {
304        let concrete_type = TType::specialize_by_id(
305            &TypeSpecializationContextForRegistry::<TType> {
306                concrete_types: &concrete_types,
307                declared_type_info: &declared_type_info,
308            },
309            &declaration.long_id.generic_id,
310            &declaration.long_id.generic_args,
311        )
312        .map_err(|error| {
313            Box::new(ProgramRegistryError::TypeSpecialization {
314                concrete_id: declaration.id.clone(),
315                error,
316            })
317        })?;
318        // Check that the info is consistent with declaration.
319        if let Some(declared_info) = declared_type_info.get(&declaration.id)
320            && concrete_type.info() != declared_info
321        {
322            return Err(Box::new(ProgramRegistryError::TypeInfoDeclarationMismatch(
323                declaration.id.clone(),
324            )));
325        }
326
327        match concrete_types.entry(declaration.id.clone()) {
328            Entry::Occupied(_) => Err(Box::new(ProgramRegistryError::TypeConcreteIdAlreadyExists(
329                declaration.id.clone(),
330            ))),
331            Entry::Vacant(entry) => Ok(entry.insert(concrete_type)),
332        }?;
333        match concrete_type_ids
334            .entry((declaration.long_id.generic_id.clone(), &declaration.long_id.generic_args[..]))
335        {
336            Entry::Occupied(_) => Err(Box::new(ProgramRegistryError::TypeAlreadyDeclared(
337                Box::new(declaration.clone()),
338            ))),
339            Entry::Vacant(entry) => Ok(entry.insert(declaration.id.clone())),
340        }?;
341    }
342    Ok((concrete_types, concrete_type_ids))
343}
344
345/// Context required for specialization process.
346pub struct SpecializationContextForRegistry<'a, TType: GenericType> {
347    pub functions: &'a FunctionMap,
348    pub concrete_type_ids: &'a ConcreteTypeIdMap<'a>,
349    pub concrete_types: &'a TypeMap<TType::Concrete>,
350}
351impl<TType: GenericType> TypeSpecializationContext for SpecializationContextForRegistry<'_, TType> {
352    fn try_get_type_info(&self, id: ConcreteTypeId) -> Option<TypeInfo> {
353        self.concrete_types.get(&id).map(|ty| ty.info().clone())
354    }
355}
356impl<TType: GenericType> SignatureSpecializationContext
357    for SpecializationContextForRegistry<'_, TType>
358{
359    fn try_get_concrete_type(
360        &self,
361        id: GenericTypeId,
362        generic_args: &[GenericArg],
363    ) -> Option<ConcreteTypeId> {
364        self.concrete_type_ids.get(&(id, generic_args)).cloned()
365    }
366
367    fn try_get_function_signature(&self, function_id: &FunctionId) -> Option<FunctionSignature> {
368        self.try_get_function(function_id).map(|f| f.signature)
369    }
370}
371impl<TType: GenericType> SpecializationContext for SpecializationContextForRegistry<'_, TType> {
372    fn try_get_function(&self, function_id: &FunctionId) -> Option<Function> {
373        self.functions.get(function_id).cloned()
374    }
375}
376
377/// Creates the libfuncs map.
378fn get_concrete_libfuncs<TType: GenericType, TLibfunc: GenericLibfunc>(
379    program: &Program,
380    context: &SpecializationContextForRegistry<'_, TType>,
381) -> Result<LibfuncMap<TLibfunc::Concrete>, Box<ProgramRegistryError>> {
382    let mut concrete_libfuncs = HashMap::new();
383    for declaration in &program.libfunc_declarations {
384        let concrete_libfunc = TLibfunc::specialize_by_id(
385            context,
386            &declaration.long_id.generic_id,
387            &declaration.long_id.generic_args,
388        )
389        .map_err(|error| ProgramRegistryError::LibfuncSpecialization {
390            concrete_id: declaration.id.clone(),
391            error,
392        })?;
393        match concrete_libfuncs.entry(declaration.id.clone()) {
394            Entry::Occupied(_) => {
395                Err(ProgramRegistryError::LibfuncConcreteIdAlreadyExists(declaration.id.clone()))
396            }
397            Entry::Vacant(entry) => Ok(entry.insert(concrete_libfunc)),
398        }?;
399    }
400    Ok(concrete_libfuncs)
401}