1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
5use itertools::{chain, izip};
6use thiserror::Error;
7
8use crate::extensions::lib_func::{
9 SierraApChange, SignatureSpecializationContext, SpecializationContext,
10};
11use crate::extensions::type_specialization_context::TypeSpecializationContext;
12use crate::extensions::types::TypeInfo;
13use crate::extensions::{
14 ConcreteLibfunc, ConcreteType, ExtensionError, GenericLibfunc, GenericLibfuncEx, GenericType,
15 GenericTypeEx,
16};
17use crate::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId, GenericTypeId};
18use crate::program::{
19 BranchTarget, DeclaredTypeInfo, Function, FunctionSignature, GenericArg, Program, Statement,
20 StatementIdx, TypeDeclaration,
21};
22
23#[cfg(test)]
24#[path = "program_registry_test.rs"]
25mod test;
26
#[derive(Error, Debug, Eq, PartialEq)]
/// Errors raised while building or validating a [`ProgramRegistry`].
pub enum ProgramRegistryError {
    /// Two functions of the program use the same function id.
    #[error("used the same function id twice `{0}`.")]
    FunctionIdAlreadyExists(FunctionId),
    /// A lookup by function id found no function in the registry.
    #[error("Could not find the requested function `{0}`.")]
    MissingFunction(FunctionId),
    /// Specializing the type declared as `concrete_id` failed with `error`.
    #[error("Error during type specialization of `{concrete_id}`: {error}")]
    TypeSpecialization { concrete_id: ConcreteTypeId, error: ExtensionError },
    /// Two type declarations use the same concrete type id.
    #[error("Used concrete type id `{0}` twice")]
    TypeConcreteIdAlreadyExists(ConcreteTypeId),
    /// A type declaration repeats the generic id and generic arguments of an earlier declaration;
    /// carries the later (repeating) declaration.
    #[error("Declared concrete type `{0}` twice")]
    TypeAlreadyDeclared(Box<TypeDeclaration>),
    /// A lookup by concrete type id found no type in the registry.
    #[error("Could not find requested type `{0}`.")]
    MissingType(ConcreteTypeId),
    /// Specializing the libfunc declared as `concrete_id` failed with `error`.
    #[error("Error during libfunc specialization of {concrete_id}: {error}")]
    LibfuncSpecialization { concrete_id: ConcreteLibfuncId, error: ExtensionError },
    /// Two libfunc declarations use the same concrete libfunc id.
    #[error("Used concrete libfunc id `{0}` twice.")]
    LibfuncConcreteIdAlreadyExists(ConcreteLibfuncId),
    /// A lookup by concrete libfunc id found no libfunc in the registry.
    #[error("Could not find requested libfunc `{0}`.")]
    MissingLibfunc(ConcreteLibfuncId),
    /// A type's explicitly declared info differs from the info produced by specializing it.
    #[error("Type info declaration mismatch for `{0}`.")]
    TypeInfoDeclarationMismatch(ConcreteTypeId),
    /// A function's signature uses a type that is not storable.
    /// NOTE(review): the message says "parameter type", but validation raises this for return
    /// types as well.
    #[error("Function `{func_id}`'s parameter type `{ty}` is not storable.")]
    FunctionWithUnstorableType { func_id: FunctionId, ty: ConcreteTypeId },
    /// A function's entry point index is not smaller than the number of program statements.
    #[error("Function `{0}` points to non existing entry point statement.")]
    FunctionNonExistingEntryPoint(FunctionId),
    /// An invocation passes a different number of arguments than the libfunc's parameters.
    #[error("#{0}: Libfunc invocation input count mismatch")]
    LibfuncInvocationInputCountMismatch(StatementIdx),
    /// An invocation has a different number of branches than the libfunc.
    #[error("#{0}: Libfunc invocation branch count mismatch")]
    LibfuncInvocationBranchCountMismatch(StatementIdx),
    /// An invocation branch (index in the second field) binds a different number of results
    /// than the libfunc branch produces.
    #[error("#{0}: Libfunc invocation branch #{1} result count mismatch")]
    LibfuncInvocationBranchResultCountMismatch(StatementIdx, usize),
    /// An invocation branch targets `Fallthrough` although it is not the libfunc's fallthrough
    /// branch, or vice versa.
    #[error("#{0}: Libfunc invocation branch #{1} target mismatch")]
    LibfuncInvocationBranchTargetMismatch(StatementIdx, usize),
    /// A multi-branch invocation at `src` jumps to an earlier statement `dst`.
    #[error("#{src}: Branch jump backwards to {dst}")]
    BranchBackwards { src: StatementIdx, dst: StatementIdx },
    /// Statement `dst`, targeted by a multi-branch invocation at `src`, invokes a libfunc whose
    /// branch ap change is not `BranchAlign`.
    #[error("#{src}: Branch jump to a non-branch align statement #{dst}")]
    BranchNotToBranchAlign { src: StatementIdx, dst: StatementIdx },
    /// Two multi-branch invocations (`src1`, `src2`) jump to the same statement `dst`.
    #[error("#{src1}, #{src2}: Jump to the same statement #{dst}")]
    MultipleJumpsToSameStatement { src1: StatementIdx, src2: StatementIdx, dst: StatementIdx },
    /// A branch of the invocation at the given index targets a statement past the program end.
    #[error("#{0}: Jump out of range")]
    JumpOutOfRange(StatementIdx),
}
71
/// Values keyed by concrete type id (used for specialized types and for declared type infos).
type TypeMap<TType> = HashMap<ConcreteTypeId, TType>;
/// Values keyed by concrete libfunc id.
type LibfuncMap<TLibfunc> = HashMap<ConcreteLibfuncId, TLibfunc>;
/// The program's functions keyed by function id.
type FunctionMap = HashMap<FunctionId, Function>;
/// Maps a generic type id together with its generic arguments (borrowed from the program) to the
/// concrete type id it was declared as.
type ConcreteTypeIdMap<'a> = HashMap<(GenericTypeId, &'a [GenericArg]), ConcreteTypeId>;
78
/// Registry of a program's functions, specialized concrete types and specialized concrete
/// libfuncs, built and validated by `ProgramRegistry::new` / `ProgramRegistry::new_with_ap_change`.
pub struct ProgramRegistry<TType: GenericType, TLibfunc: GenericLibfunc> {
    /// The program's functions, keyed by function id.
    functions: FunctionMap,
    /// The specialized concrete types, keyed by concrete type id.
    concrete_types: TypeMap<TType::Concrete>,
    /// The specialized concrete libfuncs, keyed by concrete libfunc id.
    concrete_libfuncs: LibfuncMap<TLibfunc::Concrete>,
}
88impl<TType: GenericType, TLibfunc: GenericLibfunc> ProgramRegistry<TType, TLibfunc> {
89 pub fn new_with_ap_change(
91 program: &Program,
92 function_ap_change: OrderedHashMap<FunctionId, usize>,
93 ) -> Result<ProgramRegistry<TType, TLibfunc>, Box<ProgramRegistryError>> {
94 let functions = get_functions(program)?;
95 let (concrete_types, concrete_type_ids) = get_concrete_types_maps::<TType>(program)?;
96 let concrete_libfuncs = get_concrete_libfuncs::<TType, TLibfunc>(
97 program,
98 &SpecializationContextForRegistry {
99 functions: &functions,
100 concrete_type_ids: &concrete_type_ids,
101 concrete_types: &concrete_types,
102 function_ap_change,
103 },
104 )?;
105 let registry = ProgramRegistry { functions, concrete_types, concrete_libfuncs };
106 registry.validate(program)?;
107 Ok(registry)
108 }
109
110 pub fn new(
111 program: &Program,
112 ) -> Result<ProgramRegistry<TType, TLibfunc>, Box<ProgramRegistryError>> {
113 Self::new_with_ap_change(program, Default::default())
114 }
115 pub fn get_function<'a>(
117 &'a self,
118 id: &FunctionId,
119 ) -> Result<&'a Function, Box<ProgramRegistryError>> {
120 self.functions
121 .get(id)
122 .ok_or_else(|| Box::new(ProgramRegistryError::MissingFunction(id.clone())))
123 }
124 pub fn get_type<'a>(
126 &'a self,
127 id: &ConcreteTypeId,
128 ) -> Result<&'a TType::Concrete, Box<ProgramRegistryError>> {
129 self.concrete_types
130 .get(id)
131 .ok_or_else(|| Box::new(ProgramRegistryError::MissingType(id.clone())))
132 }
133 pub fn get_libfunc<'a>(
135 &'a self,
136 id: &ConcreteLibfuncId,
137 ) -> Result<&'a TLibfunc::Concrete, Box<ProgramRegistryError>> {
138 self.concrete_libfuncs
139 .get(id)
140 .ok_or_else(|| Box::new(ProgramRegistryError::MissingLibfunc(id.clone())))
141 }
142
143 fn validate(&self, program: &Program) -> Result<(), Box<ProgramRegistryError>> {
147 for func in self.functions.values() {
149 for ty in chain!(func.signature.param_types.iter(), func.signature.ret_types.iter()) {
150 if !self.get_type(ty)?.info().storable {
151 return Err(Box::new(ProgramRegistryError::FunctionWithUnstorableType {
152 func_id: func.id.clone(),
153 ty: ty.clone(),
154 }));
155 }
156 }
157 if func.entry_point.0 >= program.statements.len() {
158 return Err(Box::new(ProgramRegistryError::FunctionNonExistingEntryPoint(
159 func.id.clone(),
160 )));
161 }
162 }
163 let mut branches: HashMap<StatementIdx, StatementIdx> =
167 HashMap::<StatementIdx, StatementIdx>::default();
168 for (i, statement) in program.statements.iter().enumerate() {
169 self.validate_statement(program, StatementIdx(i), statement, &mut branches)?;
170 }
171 Ok(())
172 }
173
174 fn validate_statement(
176 &self,
177 program: &Program,
178 index: StatementIdx,
179 statement: &Statement,
180 branches: &mut HashMap<StatementIdx, StatementIdx>,
181 ) -> Result<(), Box<ProgramRegistryError>> {
182 let Statement::Invocation(invocation) = statement else {
183 return Ok(());
184 };
185 let libfunc = self.get_libfunc(&invocation.libfunc_id)?;
186 if invocation.args.len() != libfunc.param_signatures().len() {
187 return Err(Box::new(ProgramRegistryError::LibfuncInvocationInputCountMismatch(index)));
188 }
189 let libfunc_branches = libfunc.branch_signatures();
190 if invocation.branches.len() != libfunc_branches.len() {
191 return Err(Box::new(ProgramRegistryError::LibfuncInvocationBranchCountMismatch(
192 index,
193 )));
194 }
195 let libfunc_fallthrough = libfunc.fallthrough();
196 for (branch_index, (invocation_branch, libfunc_branch)) in
197 izip!(&invocation.branches, libfunc_branches).enumerate()
198 {
199 if invocation_branch.results.len() != libfunc_branch.vars.len() {
200 return Err(Box::new(
201 ProgramRegistryError::LibfuncInvocationBranchResultCountMismatch(
202 index,
203 branch_index,
204 ),
205 ));
206 }
207 if matches!(libfunc_fallthrough, Some(target) if target == branch_index)
208 != (invocation_branch.target == BranchTarget::Fallthrough)
209 {
210 return Err(Box::new(ProgramRegistryError::LibfuncInvocationBranchTargetMismatch(
211 index,
212 branch_index,
213 )));
214 }
215 if !matches!(libfunc_branch.ap_change, SierraApChange::BranchAlign) {
216 if let Some(prev) = branches.get(&index) {
217 return Err(Box::new(ProgramRegistryError::BranchNotToBranchAlign {
218 src: *prev,
219 dst: index,
220 }));
221 }
222 }
223 let next = index.next(&invocation_branch.target);
224 if next.0 >= program.statements.len() {
225 return Err(Box::new(ProgramRegistryError::JumpOutOfRange(index)));
226 }
227 if libfunc_branches.len() > 1 {
228 if next.0 < index.0 {
229 return Err(Box::new(ProgramRegistryError::BranchBackwards {
230 src: index,
231 dst: next,
232 }));
233 }
234 match branches.entry(next) {
235 Entry::Occupied(e) => {
236 return Err(Box::new(ProgramRegistryError::MultipleJumpsToSameStatement {
237 src1: *e.get(),
238 src2: index,
239 dst: next,
240 }));
241 }
242 Entry::Vacant(e) => {
243 e.insert(index);
244 }
245 }
246 }
247 }
248 Ok(())
249 }
250}
251
252fn get_functions(program: &Program) -> Result<FunctionMap, Box<ProgramRegistryError>> {
254 let mut functions = FunctionMap::new();
255 for func in &program.funcs {
256 match functions.entry(func.id.clone()) {
257 Entry::Occupied(_) => {
258 Err(ProgramRegistryError::FunctionIdAlreadyExists(func.id.clone()))
259 }
260 Entry::Vacant(entry) => Ok(entry.insert(func.clone())),
261 }?;
262 }
263 Ok(functions)
264}
265
/// Type specialization context used while the registry's concrete types are still being built.
struct TypeSpecializationContextForRegistry<'a, TType: GenericType> {
    /// The types specialized so far (from earlier declarations).
    pub concrete_types: &'a TypeMap<TType::Concrete>,
    /// Type infos explicitly declared in the program's type declarations.
    pub declared_type_info: &'a TypeMap<TypeInfo>,
}
270impl<TType: GenericType> TypeSpecializationContext
271 for TypeSpecializationContextForRegistry<'_, TType>
272{
273 fn try_get_type_info(&self, id: ConcreteTypeId) -> Option<TypeInfo> {
274 self.declared_type_info
275 .get(&id)
276 .or_else(|| self.concrete_types.get(&id).map(|ty| ty.info()))
277 .cloned()
278 }
279}
280
281fn get_concrete_types_maps<TType: GenericType>(
284 program: &Program,
285) -> Result<(TypeMap<TType::Concrete>, ConcreteTypeIdMap<'_>), Box<ProgramRegistryError>> {
286 let mut concrete_types = HashMap::new();
287 let mut concrete_type_ids = HashMap::<(GenericTypeId, &[GenericArg]), ConcreteTypeId>::new();
288 let declared_type_info = program
289 .type_declarations
290 .iter()
291 .filter_map(|declaration| {
292 let TypeDeclaration { id, long_id, declared_type_info } = declaration;
293 let DeclaredTypeInfo { storable, droppable, duplicatable, zero_sized } =
294 declared_type_info.as_ref().cloned()?;
295 Some((
296 id.clone(),
297 TypeInfo {
298 long_id: long_id.clone(),
299 storable,
300 droppable,
301 duplicatable,
302 zero_sized,
303 },
304 ))
305 })
306 .collect();
307 for declaration in &program.type_declarations {
308 let concrete_type = TType::specialize_by_id(
309 &TypeSpecializationContextForRegistry::<TType> {
310 concrete_types: &concrete_types,
311 declared_type_info: &declared_type_info,
312 },
313 &declaration.long_id.generic_id,
314 &declaration.long_id.generic_args,
315 )
316 .map_err(|error| {
317 Box::new(ProgramRegistryError::TypeSpecialization {
318 concrete_id: declaration.id.clone(),
319 error,
320 })
321 })?;
322 if let Some(declared_info) = declared_type_info.get(&declaration.id) {
324 if concrete_type.info() != declared_info {
325 return Err(Box::new(ProgramRegistryError::TypeInfoDeclarationMismatch(
326 declaration.id.clone(),
327 )));
328 }
329 }
330
331 match concrete_types.entry(declaration.id.clone()) {
332 Entry::Occupied(_) => Err(Box::new(ProgramRegistryError::TypeConcreteIdAlreadyExists(
333 declaration.id.clone(),
334 ))),
335 Entry::Vacant(entry) => Ok(entry.insert(concrete_type)),
336 }?;
337 match concrete_type_ids
338 .entry((declaration.long_id.generic_id.clone(), &declaration.long_id.generic_args[..]))
339 {
340 Entry::Occupied(_) => Err(Box::new(ProgramRegistryError::TypeAlreadyDeclared(
341 Box::new(declaration.clone()),
342 ))),
343 Entry::Vacant(entry) => Ok(entry.insert(declaration.id.clone())),
344 }?;
345 }
346 Ok((concrete_types, concrete_type_ids))
347}
348
/// Context used to specialize libfuncs once all of the program's types have been specialized.
pub struct SpecializationContextForRegistry<'a, TType: GenericType> {
    /// The program's functions, keyed by function id.
    pub functions: &'a FunctionMap,
    /// (generic type id, generic args) -> declared concrete type id.
    pub concrete_type_ids: &'a ConcreteTypeIdMap<'a>,
    /// The specialized concrete types, keyed by concrete type id.
    pub concrete_types: &'a TypeMap<TType::Concrete>,
    /// Functions whose ids are keys here are reported as having a known ap change; the `usize`
    /// values are not read in this module.
    pub function_ap_change: OrderedHashMap<FunctionId, usize>,
}
357impl<TType: GenericType> TypeSpecializationContext for SpecializationContextForRegistry<'_, TType> {
358 fn try_get_type_info(&self, id: ConcreteTypeId) -> Option<TypeInfo> {
359 self.concrete_types.get(&id).map(|ty| ty.info().clone())
360 }
361}
362impl<TType: GenericType> SignatureSpecializationContext
363 for SpecializationContextForRegistry<'_, TType>
364{
365 fn try_get_concrete_type(
366 &self,
367 id: GenericTypeId,
368 generic_args: &[GenericArg],
369 ) -> Option<ConcreteTypeId> {
370 self.concrete_type_ids.get(&(id, generic_args)).cloned()
371 }
372
373 fn try_get_function_signature(&self, function_id: &FunctionId) -> Option<FunctionSignature> {
374 self.try_get_function(function_id).map(|f| f.signature)
375 }
376
377 fn as_type_specialization_context(&self) -> &dyn TypeSpecializationContext {
378 self
379 }
380
381 fn try_get_function_ap_change(&self, function_id: &FunctionId) -> Option<SierraApChange> {
382 Some(if self.function_ap_change.contains_key(function_id) {
383 SierraApChange::Known { new_vars_only: false }
384 } else {
385 SierraApChange::Unknown
386 })
387 }
388}
389impl<TType: GenericType> SpecializationContext for SpecializationContextForRegistry<'_, TType> {
390 fn try_get_function(&self, function_id: &FunctionId) -> Option<Function> {
391 self.functions.get(function_id).cloned()
392 }
393
394 fn upcast(&self) -> &dyn SignatureSpecializationContext {
395 self
396 }
397}
398
399fn get_concrete_libfuncs<TType: GenericType, TLibfunc: GenericLibfunc>(
401 program: &Program,
402 context: &SpecializationContextForRegistry<'_, TType>,
403) -> Result<LibfuncMap<TLibfunc::Concrete>, Box<ProgramRegistryError>> {
404 let mut concrete_libfuncs = HashMap::new();
405 for declaration in &program.libfunc_declarations {
406 let concrete_libfunc = TLibfunc::specialize_by_id(
407 context,
408 &declaration.long_id.generic_id,
409 &declaration.long_id.generic_args,
410 )
411 .map_err(|error| ProgramRegistryError::LibfuncSpecialization {
412 concrete_id: declaration.id.clone(),
413 error,
414 })?;
415 match concrete_libfuncs.entry(declaration.id.clone()) {
416 Entry::Occupied(_) => {
417 Err(ProgramRegistryError::LibfuncConcreteIdAlreadyExists(declaration.id.clone()))
418 }
419 Entry::Vacant(entry) => Ok(entry.insert(concrete_libfunc)),
420 }?;
421 }
422 Ok(concrete_libfuncs)
423}