1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use itertools::{chain, izip};
5use thiserror::Error;
6
7use crate::extensions::lib_func::{
8 SierraApChange, SignatureSpecializationContext, SpecializationContext,
9};
10use crate::extensions::type_specialization_context::TypeSpecializationContext;
11use crate::extensions::types::TypeInfo;
12use crate::extensions::{
13 ConcreteLibfunc, ConcreteType, ExtensionError, GenericLibfunc, GenericLibfuncEx, GenericType,
14 GenericTypeEx,
15};
16use crate::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId, GenericTypeId};
17use crate::program::{
18 BranchTarget, DeclaredTypeInfo, Function, FunctionSignature, GenericArg, Program, Statement,
19 StatementIdx, TypeDeclaration,
20};
21
22#[cfg(test)]
23#[path = "program_registry_test.rs"]
24mod test;
25
/// Errors that can occur while building a [`ProgramRegistry`] or looking items up in it.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum ProgramRegistryError {
    /// Two functions of the program have the same id.
    #[error("used the same function id twice `{0}`.")]
    FunctionIdAlreadyExists(FunctionId),
    /// A function id was looked up but is not present in the registry.
    #[error("Could not find the requested function `{0}`.")]
    MissingFunction(FunctionId),
    /// Specializing the generic type of the type declaration `concrete_id` failed.
    #[error("Error during type specialization of `{concrete_id}`: {error}")]
    TypeSpecialization { concrete_id: ConcreteTypeId, error: ExtensionError },
    /// Two type declarations use the same concrete type id.
    #[error("Used concrete type id `{0}` twice")]
    TypeConcreteIdAlreadyExists(ConcreteTypeId),
    /// Two type declarations have the same generic type id and generic arguments.
    /// NOTE(review): the declaration is boxed, presumably to keep this enum small.
    #[error("Declared concrete type `{0}` twice")]
    TypeAlreadyDeclared(Box<TypeDeclaration>),
    /// A concrete type id was looked up but is not present in the registry.
    #[error("Could not find requested type `{0}`.")]
    MissingType(ConcreteTypeId),
    /// Specializing the generic libfunc of the libfunc declaration `concrete_id` failed.
    #[error("Error during libfunc specialization of {concrete_id}: {error}")]
    LibfuncSpecialization { concrete_id: ConcreteLibfuncId, error: ExtensionError },
    /// Two libfunc declarations use the same concrete libfunc id.
    #[error("Used concrete libfunc id `{0}` twice.")]
    LibfuncConcreteIdAlreadyExists(ConcreteLibfuncId),
    /// A concrete libfunc id was looked up but is not present in the registry.
    #[error("Could not find requested libfunc `{0}`.")]
    MissingLibfunc(ConcreteLibfuncId),
    /// The type info explicitly declared for a type differs from the info of the type produced
    /// by specialization.
    #[error("Type info declaration mismatch for `{0}`.")]
    TypeInfoDeclarationMismatch(ConcreteTypeId),
    /// A type in a function's signature is not storable. Despite the wording of the message,
    /// this is reported for return types as well as parameter types.
    #[error("Function `{func_id}`'s parameter type `{ty}` is not storable.")]
    FunctionWithUnstorableType { func_id: FunctionId, ty: ConcreteTypeId },
    /// A function's entry point is not a valid index into the program's statements.
    #[error("Function `{0}` points to non existing entry point statement.")]
    FunctionNonExistingEntryPoint(FunctionId),
    /// The invocation at this statement passes a different number of arguments than the
    /// libfunc's parameter signatures.
    #[error("#{0}: Libfunc invocation input count mismatch")]
    LibfuncInvocationInputCountMismatch(StatementIdx),
    /// The invocation at this statement has a different number of branches than the libfunc's
    /// branch signatures.
    #[error("#{0}: Libfunc invocation branch count mismatch")]
    LibfuncInvocationBranchCountMismatch(StatementIdx),
    /// A branch (by index) of the invocation at this statement binds a different number of
    /// results than the libfunc's matching branch signature provides.
    #[error("#{0}: Libfunc invocation branch #{1} result count mismatch")]
    LibfuncInvocationBranchResultCountMismatch(StatementIdx, usize),
    /// A branch (by index) of the invocation at this statement is a fallthrough while not being
    /// the libfunc's fallthrough branch, or vice versa.
    #[error("#{0}: Libfunc invocation branch #{1} target mismatch")]
    LibfuncInvocationBranchTargetMismatch(StatementIdx, usize),
    /// A branch of a multi-branch invocation at `src` targets the earlier statement `dst`.
    #[error("#{src}: Branch jump backwards to {dst}")]
    BranchBackwards { src: StatementIdx, dst: StatementIdx },
    /// Statement `dst`, targeted by a branch of `src`, invokes a libfunc whose branch does not
    /// have the `BranchAlign` ap change.
    #[error("#{src}: Branch jump to a non-branch align statement #{dst}")]
    BranchNotToBranchAlign { src: StatementIdx, dst: StatementIdx },
    /// Branches of the statements `src1` and `src2` both target statement `dst`.
    #[error("#{src1}, #{src2}: Jump to the same statement #{dst}")]
    MultipleJumpsToSameStatement { src1: StatementIdx, src2: StatementIdx, dst: StatementIdx },
    /// A branch of the invocation at this statement targets an index past the last statement.
    #[error("#{0}: Jump out of range")]
    JumpOutOfRange(StatementIdx),
    /// NOTE(review): not constructed by the code visible in this module; presumably reported by
    /// a type-size computation when the size of `dep` is needed for `ty` but unknown.
    #[error("Type size computation failed for `{ty}`: missing size information for `{dep}`")]
    TypeSizeDependencyMissing { ty: ConcreteTypeId, dep: ConcreteTypeId },
    /// NOTE(review): not constructed by the code visible in this module; presumably reported
    /// when computing the size of the type overflows.
    #[error("Type size computation failed for `{0}`: size overflow.")]
    TypeSizeOverflow(ConcreteTypeId),
}
74
/// Map from a concrete type id to a per-type value (e.g. a concrete type or its type info).
type TypeMap<TType> = HashMap<ConcreteTypeId, TType>;
/// Map from a concrete libfunc id to a per-libfunc value (e.g. a concrete libfunc).
type LibfuncMap<TLibfunc> = HashMap<ConcreteLibfuncId, TLibfunc>;
/// Map from a function id to the function.
type FunctionMap = HashMap<FunctionId, Function>;
/// Reverse lookup from a type's generic id and generic arguments to its concrete id. The
/// arguments slice borrows from the program the map was built from.
type ConcreteTypeIdMap<'a> = HashMap<(GenericTypeId, &'a [GenericArg]), ConcreteTypeId>;
81
/// Registry of a program's functions, concrete types and concrete libfuncs, built and validated
/// by [`ProgramRegistry::new`].
pub struct ProgramRegistry<TType: GenericType, TLibfunc: GenericLibfunc> {
    /// The program's functions, keyed by function id.
    functions: FunctionMap,
    /// The specialized concrete types, keyed by concrete type id.
    concrete_types: TypeMap<TType::Concrete>,
    /// The specialized concrete libfuncs, keyed by concrete libfunc id.
    concrete_libfuncs: LibfuncMap<TLibfunc::Concrete>,
}
91impl<TType: GenericType, TLibfunc: GenericLibfunc> ProgramRegistry<TType, TLibfunc> {
92 pub fn new(
94 program: &Program,
95 ) -> Result<ProgramRegistry<TType, TLibfunc>, Box<ProgramRegistryError>> {
96 let functions = get_functions(program)?;
97 let (concrete_types, concrete_type_ids) = get_concrete_types_maps::<TType>(program)?;
98 let concrete_libfuncs = get_concrete_libfuncs::<TType, TLibfunc>(
99 program,
100 &SpecializationContextForRegistry {
101 functions: &functions,
102 concrete_type_ids: &concrete_type_ids,
103 concrete_types: &concrete_types,
104 },
105 )?;
106 let registry = ProgramRegistry { functions, concrete_types, concrete_libfuncs };
107 registry.validate(program)?;
108 Ok(registry)
109 }
110
111 pub fn get_function<'a>(
113 &'a self,
114 id: &FunctionId,
115 ) -> Result<&'a Function, Box<ProgramRegistryError>> {
116 self.functions
117 .get(id)
118 .ok_or_else(|| Box::new(ProgramRegistryError::MissingFunction(id.clone())))
119 }
120 pub fn get_type<'a>(
122 &'a self,
123 id: &ConcreteTypeId,
124 ) -> Result<&'a TType::Concrete, Box<ProgramRegistryError>> {
125 self.concrete_types
126 .get(id)
127 .ok_or_else(|| Box::new(ProgramRegistryError::MissingType(id.clone())))
128 }
129 pub fn get_libfunc<'a>(
131 &'a self,
132 id: &ConcreteLibfuncId,
133 ) -> Result<&'a TLibfunc::Concrete, Box<ProgramRegistryError>> {
134 self.concrete_libfuncs
135 .get(id)
136 .ok_or_else(|| Box::new(ProgramRegistryError::MissingLibfunc(id.clone())))
137 }
138
139 fn validate(&self, program: &Program) -> Result<(), Box<ProgramRegistryError>> {
143 for func in self.functions.values() {
145 for ty in chain!(func.signature.param_types.iter(), func.signature.ret_types.iter()) {
146 if !self.get_type(ty)?.info().storable {
147 return Err(Box::new(ProgramRegistryError::FunctionWithUnstorableType {
148 func_id: func.id.clone(),
149 ty: ty.clone(),
150 }));
151 }
152 }
153 if func.entry_point.0 >= program.statements.len() {
154 return Err(Box::new(ProgramRegistryError::FunctionNonExistingEntryPoint(
155 func.id.clone(),
156 )));
157 }
158 }
159 let mut branches: HashMap<StatementIdx, StatementIdx> =
163 HashMap::<StatementIdx, StatementIdx>::default();
164 for (i, statement) in program.statements.iter().enumerate() {
165 self.validate_statement(program, StatementIdx(i), statement, &mut branches)?;
166 }
167 Ok(())
168 }
169
170 fn validate_statement(
172 &self,
173 program: &Program,
174 index: StatementIdx,
175 statement: &Statement,
176 branches: &mut HashMap<StatementIdx, StatementIdx>,
177 ) -> Result<(), Box<ProgramRegistryError>> {
178 let Statement::Invocation(invocation) = statement else {
179 return Ok(());
180 };
181 let libfunc = self.get_libfunc(&invocation.libfunc_id)?;
182 if invocation.args.len() != libfunc.param_signatures().len() {
183 return Err(Box::new(ProgramRegistryError::LibfuncInvocationInputCountMismatch(index)));
184 }
185 let libfunc_branches = libfunc.branch_signatures();
186 if invocation.branches.len() != libfunc_branches.len() {
187 return Err(Box::new(ProgramRegistryError::LibfuncInvocationBranchCountMismatch(
188 index,
189 )));
190 }
191 let libfunc_fallthrough = libfunc.fallthrough();
192 for (branch_index, (invocation_branch, libfunc_branch)) in
193 izip!(&invocation.branches, libfunc_branches).enumerate()
194 {
195 if invocation_branch.results.len() != libfunc_branch.vars.len() {
196 return Err(Box::new(
197 ProgramRegistryError::LibfuncInvocationBranchResultCountMismatch(
198 index,
199 branch_index,
200 ),
201 ));
202 }
203 if matches!(libfunc_fallthrough, Some(target) if target == branch_index)
204 != (invocation_branch.target == BranchTarget::Fallthrough)
205 {
206 return Err(Box::new(ProgramRegistryError::LibfuncInvocationBranchTargetMismatch(
207 index,
208 branch_index,
209 )));
210 }
211 if !matches!(libfunc_branch.ap_change, SierraApChange::BranchAlign)
212 && let Some(prev) = branches.get(&index)
213 {
214 return Err(Box::new(ProgramRegistryError::BranchNotToBranchAlign {
215 src: *prev,
216 dst: index,
217 }));
218 }
219 let next = index.next(&invocation_branch.target);
220 if next.0 >= program.statements.len() {
221 return Err(Box::new(ProgramRegistryError::JumpOutOfRange(index)));
222 }
223 if libfunc_branches.len() > 1 {
224 if next.0 < index.0 {
225 return Err(Box::new(ProgramRegistryError::BranchBackwards {
226 src: index,
227 dst: next,
228 }));
229 }
230 match branches.entry(next) {
231 Entry::Occupied(e) => {
232 return Err(Box::new(ProgramRegistryError::MultipleJumpsToSameStatement {
233 src1: *e.get(),
234 src2: index,
235 dst: next,
236 }));
237 }
238 Entry::Vacant(e) => {
239 e.insert(index);
240 }
241 }
242 }
243 }
244 Ok(())
245 }
246}
247
248fn get_functions(program: &Program) -> Result<FunctionMap, Box<ProgramRegistryError>> {
250 let mut functions = FunctionMap::new();
251 for func in &program.funcs {
252 match functions.entry(func.id.clone()) {
253 Entry::Occupied(_) => {
254 Err(ProgramRegistryError::FunctionIdAlreadyExists(func.id.clone()))
255 }
256 Entry::Vacant(entry) => Ok(entry.insert(func.clone())),
257 }?;
258 }
259 Ok(functions)
260}
261
/// Type-specialization context used while the program's type declarations are being
/// specialized.
struct TypeSpecializationContextForRegistry<'a, TType: GenericType> {
    /// Concrete types available for lookup (while building the registry: the types specialized
    /// so far).
    pub concrete_types: &'a TypeMap<TType::Concrete>,
    /// Type infos explicitly declared in the program's type declarations, keyed by concrete id.
    pub declared_type_info: &'a TypeMap<TypeInfo>,
}
266impl<TType: GenericType> TypeSpecializationContext
267 for TypeSpecializationContextForRegistry<'_, TType>
268{
269 fn try_get_type_info(&self, id: ConcreteTypeId) -> Option<TypeInfo> {
270 self.declared_type_info
271 .get(&id)
272 .or_else(|| self.concrete_types.get(&id).map(|ty| ty.info()))
273 .cloned()
274 }
275}
276
277fn get_concrete_types_maps<TType: GenericType>(
280 program: &Program,
281) -> Result<(TypeMap<TType::Concrete>, ConcreteTypeIdMap<'_>), Box<ProgramRegistryError>> {
282 let mut concrete_types = HashMap::new();
283 let mut concrete_type_ids = HashMap::<(GenericTypeId, &[GenericArg]), ConcreteTypeId>::new();
284 let declared_type_info = program
285 .type_declarations
286 .iter()
287 .filter_map(|declaration| {
288 let TypeDeclaration { id, long_id, declared_type_info } = declaration;
289 let DeclaredTypeInfo { storable, droppable, duplicatable, zero_sized } =
290 declared_type_info.as_ref().cloned()?;
291 Some((
292 id.clone(),
293 TypeInfo {
294 long_id: long_id.clone(),
295 storable,
296 droppable,
297 duplicatable,
298 zero_sized,
299 },
300 ))
301 })
302 .collect();
303 for declaration in &program.type_declarations {
304 let concrete_type = TType::specialize_by_id(
305 &TypeSpecializationContextForRegistry::<TType> {
306 concrete_types: &concrete_types,
307 declared_type_info: &declared_type_info,
308 },
309 &declaration.long_id.generic_id,
310 &declaration.long_id.generic_args,
311 )
312 .map_err(|error| {
313 Box::new(ProgramRegistryError::TypeSpecialization {
314 concrete_id: declaration.id.clone(),
315 error,
316 })
317 })?;
318 if let Some(declared_info) = declared_type_info.get(&declaration.id)
320 && concrete_type.info() != declared_info
321 {
322 return Err(Box::new(ProgramRegistryError::TypeInfoDeclarationMismatch(
323 declaration.id.clone(),
324 )));
325 }
326
327 match concrete_types.entry(declaration.id.clone()) {
328 Entry::Occupied(_) => Err(Box::new(ProgramRegistryError::TypeConcreteIdAlreadyExists(
329 declaration.id.clone(),
330 ))),
331 Entry::Vacant(entry) => Ok(entry.insert(concrete_type)),
332 }?;
333 match concrete_type_ids
334 .entry((declaration.long_id.generic_id.clone(), &declaration.long_id.generic_args[..]))
335 {
336 Entry::Occupied(_) => Err(Box::new(ProgramRegistryError::TypeAlreadyDeclared(
337 Box::new(declaration.clone()),
338 ))),
339 Entry::Vacant(entry) => Ok(entry.insert(declaration.id.clone())),
340 }?;
341 }
342 Ok((concrete_types, concrete_type_ids))
343}
344
/// Context used for libfunc specialization. It gives access to the program's functions and to
/// its already specialized concrete types.
pub struct SpecializationContextForRegistry<'a, TType: GenericType> {
    /// The program's functions, keyed by function id.
    pub functions: &'a FunctionMap,
    /// Reverse lookup from a type's generic id and generic arguments to its concrete id.
    pub concrete_type_ids: &'a ConcreteTypeIdMap<'a>,
    /// The specialized concrete types, keyed by concrete type id.
    pub concrete_types: &'a TypeMap<TType::Concrete>,
}
351impl<TType: GenericType> TypeSpecializationContext for SpecializationContextForRegistry<'_, TType> {
352 fn try_get_type_info(&self, id: ConcreteTypeId) -> Option<TypeInfo> {
353 self.concrete_types.get(&id).map(|ty| ty.info().clone())
354 }
355}
356impl<TType: GenericType> SignatureSpecializationContext
357 for SpecializationContextForRegistry<'_, TType>
358{
359 fn try_get_concrete_type(
360 &self,
361 id: GenericTypeId,
362 generic_args: &[GenericArg],
363 ) -> Option<ConcreteTypeId> {
364 self.concrete_type_ids.get(&(id, generic_args)).cloned()
365 }
366
367 fn try_get_function_signature(&self, function_id: &FunctionId) -> Option<FunctionSignature> {
368 self.try_get_function(function_id).map(|f| f.signature)
369 }
370}
371impl<TType: GenericType> SpecializationContext for SpecializationContextForRegistry<'_, TType> {
372 fn try_get_function(&self, function_id: &FunctionId) -> Option<Function> {
373 self.functions.get(function_id).cloned()
374 }
375}
376
377fn get_concrete_libfuncs<TType: GenericType, TLibfunc: GenericLibfunc>(
379 program: &Program,
380 context: &SpecializationContextForRegistry<'_, TType>,
381) -> Result<LibfuncMap<TLibfunc::Concrete>, Box<ProgramRegistryError>> {
382 let mut concrete_libfuncs = HashMap::new();
383 for declaration in &program.libfunc_declarations {
384 let concrete_libfunc = TLibfunc::specialize_by_id(
385 context,
386 &declaration.long_id.generic_id,
387 &declaration.long_id.generic_args,
388 )
389 .map_err(|error| ProgramRegistryError::LibfuncSpecialization {
390 concrete_id: declaration.id.clone(),
391 error,
392 })?;
393 match concrete_libfuncs.entry(declaration.id.clone()) {
394 Entry::Occupied(_) => {
395 Err(ProgramRegistryError::LibfuncConcreteIdAlreadyExists(declaration.id.clone()))
396 }
397 Entry::Vacant(entry) => Ok(entry.insert(concrete_libfunc)),
398 }?;
399 }
400 Ok(concrete_libfuncs)
401}