1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
5use itertools::{chain, izip};
6use thiserror::Error;
7
8use crate::extensions::lib_func::{
9    SierraApChange, SignatureSpecializationContext, SpecializationContext,
10};
11use crate::extensions::type_specialization_context::TypeSpecializationContext;
12use crate::extensions::types::TypeInfo;
13use crate::extensions::{
14    ConcreteLibfunc, ConcreteType, ExtensionError, GenericLibfunc, GenericLibfuncEx, GenericType,
15    GenericTypeEx,
16};
17use crate::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId, GenericTypeId};
18use crate::program::{
19    BranchTarget, DeclaredTypeInfo, Function, FunctionSignature, GenericArg, Program, Statement,
20    StatementIdx, TypeDeclaration,
21};
22
23#[cfg(test)]
24#[path = "program_registry_test.rs"]
25mod test;
26
27#[derive(Error, Debug, Eq, PartialEq)]
29pub enum ProgramRegistryError {
30    #[error("used the same function id twice `{0}`.")]
31    FunctionIdAlreadyExists(FunctionId),
32    #[error("Could not find the requested function `{0}`.")]
33    MissingFunction(FunctionId),
34    #[error("Error during type specialization of `{concrete_id}`: {error}")]
35    TypeSpecialization { concrete_id: ConcreteTypeId, error: ExtensionError },
36    #[error("Used concrete type id `{0}` twice")]
37    TypeConcreteIdAlreadyExists(ConcreteTypeId),
38    #[error("Declared concrete type `{0}` twice")]
39    TypeAlreadyDeclared(Box<TypeDeclaration>),
40    #[error("Could not find requested type `{0}`.")]
41    MissingType(ConcreteTypeId),
42    #[error("Error during libfunc specialization of {concrete_id}: {error}")]
43    LibfuncSpecialization { concrete_id: ConcreteLibfuncId, error: ExtensionError },
44    #[error("Used concrete libfunc id `{0}` twice.")]
45    LibfuncConcreteIdAlreadyExists(ConcreteLibfuncId),
46    #[error("Could not find requested libfunc `{0}`.")]
47    MissingLibfunc(ConcreteLibfuncId),
48    #[error("Type info declaration mismatch for `{0}`.")]
49    TypeInfoDeclarationMismatch(ConcreteTypeId),
50    #[error("Function `{func_id}`'s parameter type `{ty}` is not storable.")]
51    FunctionWithUnstorableType { func_id: FunctionId, ty: ConcreteTypeId },
52    #[error("Function `{0}` points to non existing entry point statement.")]
53    FunctionNonExistingEntryPoint(FunctionId),
54    #[error("#{0}: Libfunc invocation input count mismatch")]
55    LibfuncInvocationInputCountMismatch(StatementIdx),
56    #[error("#{0}: Libfunc invocation branch count mismatch")]
57    LibfuncInvocationBranchCountMismatch(StatementIdx),
58    #[error("#{0}: Libfunc invocation branch #{1} result count mismatch")]
59    LibfuncInvocationBranchResultCountMismatch(StatementIdx, usize),
60    #[error("#{0}: Libfunc invocation branch #{1} target mismatch")]
61    LibfuncInvocationBranchTargetMismatch(StatementIdx, usize),
62    #[error("#{src}: Branch jump backwards to {dst}")]
63    BranchBackwards { src: StatementIdx, dst: StatementIdx },
64    #[error("#{src}: Branch jump to a non-branch align statement #{dst}")]
65    BranchNotToBranchAlign { src: StatementIdx, dst: StatementIdx },
66    #[error("#{src1}, #{src2}: Jump to the same statement #{dst}")]
67    MultipleJumpsToSameStatement { src1: StatementIdx, src2: StatementIdx, dst: StatementIdx },
68    #[error("#{0}: Jump out of range")]
69    JumpOutOfRange(StatementIdx),
70}
71
72type TypeMap<TType> = HashMap<ConcreteTypeId, TType>;
73type LibfuncMap<TLibfunc> = HashMap<ConcreteLibfuncId, TLibfunc>;
74type FunctionMap = HashMap<FunctionId, Function>;
75type ConcreteTypeIdMap<'a> = HashMap<(GenericTypeId, &'a [GenericArg]), ConcreteTypeId>;
78
79pub struct ProgramRegistry<TType: GenericType, TLibfunc: GenericLibfunc> {
81    functions: FunctionMap,
83    concrete_types: TypeMap<TType::Concrete>,
85    concrete_libfuncs: LibfuncMap<TLibfunc::Concrete>,
87}
88impl<TType: GenericType, TLibfunc: GenericLibfunc> ProgramRegistry<TType, TLibfunc> {
89    pub fn new_with_ap_change(
91        program: &Program,
92        function_ap_change: OrderedHashMap<FunctionId, usize>,
93    ) -> Result<ProgramRegistry<TType, TLibfunc>, Box<ProgramRegistryError>> {
94        let functions = get_functions(program)?;
95        let (concrete_types, concrete_type_ids) = get_concrete_types_maps::<TType>(program)?;
96        let concrete_libfuncs = get_concrete_libfuncs::<TType, TLibfunc>(
97            program,
98            &SpecializationContextForRegistry {
99                functions: &functions,
100                concrete_type_ids: &concrete_type_ids,
101                concrete_types: &concrete_types,
102                function_ap_change,
103            },
104        )?;
105        let registry = ProgramRegistry { functions, concrete_types, concrete_libfuncs };
106        registry.validate(program)?;
107        Ok(registry)
108    }
109
110    pub fn new(
111        program: &Program,
112    ) -> Result<ProgramRegistry<TType, TLibfunc>, Box<ProgramRegistryError>> {
113        Self::new_with_ap_change(program, Default::default())
114    }
115    pub fn get_function<'a>(
117        &'a self,
118        id: &FunctionId,
119    ) -> Result<&'a Function, Box<ProgramRegistryError>> {
120        self.functions
121            .get(id)
122            .ok_or_else(|| Box::new(ProgramRegistryError::MissingFunction(id.clone())))
123    }
124    pub fn get_type<'a>(
126        &'a self,
127        id: &ConcreteTypeId,
128    ) -> Result<&'a TType::Concrete, Box<ProgramRegistryError>> {
129        self.concrete_types
130            .get(id)
131            .ok_or_else(|| Box::new(ProgramRegistryError::MissingType(id.clone())))
132    }
133    pub fn get_libfunc<'a>(
135        &'a self,
136        id: &ConcreteLibfuncId,
137    ) -> Result<&'a TLibfunc::Concrete, Box<ProgramRegistryError>> {
138        self.concrete_libfuncs
139            .get(id)
140            .ok_or_else(|| Box::new(ProgramRegistryError::MissingLibfunc(id.clone())))
141    }
142
143    fn validate(&self, program: &Program) -> Result<(), Box<ProgramRegistryError>> {
147        for func in self.functions.values() {
149            for ty in chain!(func.signature.param_types.iter(), func.signature.ret_types.iter()) {
150                if !self.get_type(ty)?.info().storable {
151                    return Err(Box::new(ProgramRegistryError::FunctionWithUnstorableType {
152                        func_id: func.id.clone(),
153                        ty: ty.clone(),
154                    }));
155                }
156            }
157            if func.entry_point.0 >= program.statements.len() {
158                return Err(Box::new(ProgramRegistryError::FunctionNonExistingEntryPoint(
159                    func.id.clone(),
160                )));
161            }
162        }
163        let mut branches: HashMap<StatementIdx, StatementIdx> =
167            HashMap::<StatementIdx, StatementIdx>::default();
168        for (i, statement) in program.statements.iter().enumerate() {
169            self.validate_statement(program, StatementIdx(i), statement, &mut branches)?;
170        }
171        Ok(())
172    }
173
174    fn validate_statement(
176        &self,
177        program: &Program,
178        index: StatementIdx,
179        statement: &Statement,
180        branches: &mut HashMap<StatementIdx, StatementIdx>,
181    ) -> Result<(), Box<ProgramRegistryError>> {
182        let Statement::Invocation(invocation) = statement else {
183            return Ok(());
184        };
185        let libfunc = self.get_libfunc(&invocation.libfunc_id)?;
186        if invocation.args.len() != libfunc.param_signatures().len() {
187            return Err(Box::new(ProgramRegistryError::LibfuncInvocationInputCountMismatch(index)));
188        }
189        let libfunc_branches = libfunc.branch_signatures();
190        if invocation.branches.len() != libfunc_branches.len() {
191            return Err(Box::new(ProgramRegistryError::LibfuncInvocationBranchCountMismatch(
192                index,
193            )));
194        }
195        let libfunc_fallthrough = libfunc.fallthrough();
196        for (branch_index, (invocation_branch, libfunc_branch)) in
197            izip!(&invocation.branches, libfunc_branches).enumerate()
198        {
199            if invocation_branch.results.len() != libfunc_branch.vars.len() {
200                return Err(Box::new(
201                    ProgramRegistryError::LibfuncInvocationBranchResultCountMismatch(
202                        index,
203                        branch_index,
204                    ),
205                ));
206            }
207            if matches!(libfunc_fallthrough, Some(target) if target == branch_index)
208                != (invocation_branch.target == BranchTarget::Fallthrough)
209            {
210                return Err(Box::new(ProgramRegistryError::LibfuncInvocationBranchTargetMismatch(
211                    index,
212                    branch_index,
213                )));
214            }
215            if !matches!(libfunc_branch.ap_change, SierraApChange::BranchAlign) {
216                if let Some(prev) = branches.get(&index) {
217                    return Err(Box::new(ProgramRegistryError::BranchNotToBranchAlign {
218                        src: *prev,
219                        dst: index,
220                    }));
221                }
222            }
223            let next = index.next(&invocation_branch.target);
224            if next.0 >= program.statements.len() {
225                return Err(Box::new(ProgramRegistryError::JumpOutOfRange(index)));
226            }
227            if libfunc_branches.len() > 1 {
228                if next.0 < index.0 {
229                    return Err(Box::new(ProgramRegistryError::BranchBackwards {
230                        src: index,
231                        dst: next,
232                    }));
233                }
234                match branches.entry(next) {
235                    Entry::Occupied(e) => {
236                        return Err(Box::new(ProgramRegistryError::MultipleJumpsToSameStatement {
237                            src1: *e.get(),
238                            src2: index,
239                            dst: next,
240                        }));
241                    }
242                    Entry::Vacant(e) => {
243                        e.insert(index);
244                    }
245                }
246            }
247        }
248        Ok(())
249    }
250}
251
252fn get_functions(program: &Program) -> Result<FunctionMap, Box<ProgramRegistryError>> {
254    let mut functions = FunctionMap::new();
255    for func in &program.funcs {
256        match functions.entry(func.id.clone()) {
257            Entry::Occupied(_) => {
258                Err(ProgramRegistryError::FunctionIdAlreadyExists(func.id.clone()))
259            }
260            Entry::Vacant(entry) => Ok(entry.insert(func.clone())),
261        }?;
262    }
263    Ok(functions)
264}
265
266struct TypeSpecializationContextForRegistry<'a, TType: GenericType> {
267    pub concrete_types: &'a TypeMap<TType::Concrete>,
268    pub declared_type_info: &'a TypeMap<TypeInfo>,
269}
270impl<TType: GenericType> TypeSpecializationContext
271    for TypeSpecializationContextForRegistry<'_, TType>
272{
273    fn try_get_type_info(&self, id: ConcreteTypeId) -> Option<TypeInfo> {
274        self.declared_type_info
275            .get(&id)
276            .or_else(|| self.concrete_types.get(&id).map(|ty| ty.info()))
277            .cloned()
278    }
279}
280
281fn get_concrete_types_maps<TType: GenericType>(
284    program: &Program,
285) -> Result<(TypeMap<TType::Concrete>, ConcreteTypeIdMap<'_>), Box<ProgramRegistryError>> {
286    let mut concrete_types = HashMap::new();
287    let mut concrete_type_ids = HashMap::<(GenericTypeId, &[GenericArg]), ConcreteTypeId>::new();
288    let declared_type_info = program
289        .type_declarations
290        .iter()
291        .filter_map(|declaration| {
292            let TypeDeclaration { id, long_id, declared_type_info } = declaration;
293            let DeclaredTypeInfo { storable, droppable, duplicatable, zero_sized } =
294                declared_type_info.as_ref().cloned()?;
295            Some((
296                id.clone(),
297                TypeInfo {
298                    long_id: long_id.clone(),
299                    storable,
300                    droppable,
301                    duplicatable,
302                    zero_sized,
303                },
304            ))
305        })
306        .collect();
307    for declaration in &program.type_declarations {
308        let concrete_type = TType::specialize_by_id(
309            &TypeSpecializationContextForRegistry::<TType> {
310                concrete_types: &concrete_types,
311                declared_type_info: &declared_type_info,
312            },
313            &declaration.long_id.generic_id,
314            &declaration.long_id.generic_args,
315        )
316        .map_err(|error| {
317            Box::new(ProgramRegistryError::TypeSpecialization {
318                concrete_id: declaration.id.clone(),
319                error,
320            })
321        })?;
322        if let Some(declared_info) = declared_type_info.get(&declaration.id) {
324            if concrete_type.info() != declared_info {
325                return Err(Box::new(ProgramRegistryError::TypeInfoDeclarationMismatch(
326                    declaration.id.clone(),
327                )));
328            }
329        }
330
331        match concrete_types.entry(declaration.id.clone()) {
332            Entry::Occupied(_) => Err(Box::new(ProgramRegistryError::TypeConcreteIdAlreadyExists(
333                declaration.id.clone(),
334            ))),
335            Entry::Vacant(entry) => Ok(entry.insert(concrete_type)),
336        }?;
337        match concrete_type_ids
338            .entry((declaration.long_id.generic_id.clone(), &declaration.long_id.generic_args[..]))
339        {
340            Entry::Occupied(_) => Err(Box::new(ProgramRegistryError::TypeAlreadyDeclared(
341                Box::new(declaration.clone()),
342            ))),
343            Entry::Vacant(entry) => Ok(entry.insert(declaration.id.clone())),
344        }?;
345    }
346    Ok((concrete_types, concrete_type_ids))
347}
348
349pub struct SpecializationContextForRegistry<'a, TType: GenericType> {
351    pub functions: &'a FunctionMap,
352    pub concrete_type_ids: &'a ConcreteTypeIdMap<'a>,
353    pub concrete_types: &'a TypeMap<TType::Concrete>,
354    pub function_ap_change: OrderedHashMap<FunctionId, usize>,
356}
357impl<TType: GenericType> TypeSpecializationContext for SpecializationContextForRegistry<'_, TType> {
358    fn try_get_type_info(&self, id: ConcreteTypeId) -> Option<TypeInfo> {
359        self.concrete_types.get(&id).map(|ty| ty.info().clone())
360    }
361}
362impl<TType: GenericType> SignatureSpecializationContext
363    for SpecializationContextForRegistry<'_, TType>
364{
365    fn try_get_concrete_type(
366        &self,
367        id: GenericTypeId,
368        generic_args: &[GenericArg],
369    ) -> Option<ConcreteTypeId> {
370        self.concrete_type_ids.get(&(id, generic_args)).cloned()
371    }
372
373    fn try_get_function_signature(&self, function_id: &FunctionId) -> Option<FunctionSignature> {
374        self.try_get_function(function_id).map(|f| f.signature)
375    }
376
377    fn try_get_function_ap_change(&self, function_id: &FunctionId) -> Option<SierraApChange> {
378        Some(if self.function_ap_change.contains_key(function_id) {
379            SierraApChange::Known { new_vars_only: false }
380        } else {
381            SierraApChange::Unknown
382        })
383    }
384}
385impl<TType: GenericType> SpecializationContext for SpecializationContextForRegistry<'_, TType> {
386    fn try_get_function(&self, function_id: &FunctionId) -> Option<Function> {
387        self.functions.get(function_id).cloned()
388    }
389}
390
391fn get_concrete_libfuncs<TType: GenericType, TLibfunc: GenericLibfunc>(
393    program: &Program,
394    context: &SpecializationContextForRegistry<'_, TType>,
395) -> Result<LibfuncMap<TLibfunc::Concrete>, Box<ProgramRegistryError>> {
396    let mut concrete_libfuncs = HashMap::new();
397    for declaration in &program.libfunc_declarations {
398        let concrete_libfunc = TLibfunc::specialize_by_id(
399            context,
400            &declaration.long_id.generic_id,
401            &declaration.long_id.generic_args,
402        )
403        .map_err(|error| ProgramRegistryError::LibfuncSpecialization {
404            concrete_id: declaration.id.clone(),
405            error,
406        })?;
407        match concrete_libfuncs.entry(declaration.id.clone()) {
408            Entry::Occupied(_) => {
409                Err(ProgramRegistryError::LibfuncConcreteIdAlreadyExists(declaration.id.clone()))
410            }
411            Entry::Vacant(entry) => Ok(entry.insert(concrete_libfunc)),
412        }?;
413    }
414    Ok(concrete_libfuncs)
415}