cairo_lang_starknet_classes/
allowed_libfuncs.rs1use std::collections::HashSet;
2use std::fmt::{Display, Formatter};
3use std::fs;
4
5use cairo_lang_sierra::ids::GenericLibfuncId;
6use serde::Deserialize;
7use smol_str::SmolStr;
8use thiserror::Error;
9
10#[cfg(test)]
11#[path = "allowed_libfuncs_test.rs"]
12mod test;
13
14#[derive(Error, Debug, Eq, PartialEq)]
15pub enum AllowedLibfuncsError {
16 #[error("Invalid Sierra program.")]
17 SierraProgramError,
18 #[error("No libfunc list named '{allowed_libfuncs_list_name}' is known.")]
19 UnexpectedAllowedLibfuncsList { allowed_libfuncs_list_name: String },
20 #[error("The allowed libfuncs file '{allowed_libfuncs_list_file}' was not found.")]
21 UnknownAllowedLibfuncsFile { allowed_libfuncs_list_file: String },
22 #[error("Failed to deserialize the allowed libfuncs file '{allowed_libfuncs_list_file}'.")]
23 DeserializationError { allowed_libfuncs_list_file: String },
24 #[error(
25 "Libfunc {invalid_libfunc} is not allowed in the libfuncs list \
26 '{allowed_libfuncs_list_name}'.\n Run with '--allowed-libfuncs-list-name \
27 {BUILTIN_ALL_LIBFUNCS_LIST}' to allow all libfuncs."
28 )]
29 UnsupportedLibfunc { invalid_libfunc: String, allowed_libfuncs_list_name: String },
30}
31
/// Selects which list of allowed libfuncs to check against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum ListSelector {
    /// A built-in list chosen by name (e.g. [`BUILTIN_AUDITED_LIBFUNCS_LIST`]).
    ListName(String),
    /// A list read from the JSON file at this path.
    ListFile(String),
    /// Used when neither a name nor a file is given; [`lookup_allowed_libfuncs_list`] resolves
    /// it to the audited list.
    #[default]
    DefaultList,
}

impl ListSelector {
    /// Builds a selector from two mutually exclusive options.
    ///
    /// Returns `None` when both `list_name` and `list_file` are set. Otherwise the option that
    /// is present decides the variant; with neither set the result is
    /// [`ListSelector::DefaultList`].
    pub fn new(list_name: Option<String>, list_file: Option<String>) -> Option<Self> {
        let selector = match (list_name, list_file) {
            // Ambiguous request: the caller must pick one source.
            (Some(_), Some(_)) => return None,
            (Some(name), None) => Self::ListName(name),
            (None, Some(path)) => Self::ListFile(path),
            (None, None) => Self::DefaultList,
        };
        Some(selector)
    }
}

impl Display for ListSelector {
    /// Writes the list name or file path verbatim, or a fixed label for the default list.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::ListName(name) => name.as_str(),
            Self::ListFile(path) => path.as_str(),
            Self::DefaultList => "Default libfunc list",
        };
        f.write_str(label)
    }
}
66
67#[derive(Debug, PartialEq, Eq, Deserialize)]
69pub struct AllowedLibfuncs {
70 #[serde(deserialize_with = "deserialize_libfuncs_set::<_>")]
71 pub allowed_libfuncs: HashSet<GenericLibfuncId>,
72}
73
74fn deserialize_libfuncs_set<'de, D: serde::Deserializer<'de>>(
75 deserializer: D,
76) -> Result<HashSet<GenericLibfuncId>, D::Error> {
77 Ok(HashSet::from_iter(
78 Vec::<SmolStr>::deserialize(deserializer)?.into_iter().map(GenericLibfuncId::from_string),
79 ))
80}
81
/// Name selecting the built-in list embedded from `allowed_libfuncs_lists/audited.json`.
pub const BUILTIN_AUDITED_LIBFUNCS_LIST: &'static str = "audited";
/// Name selecting the built-in list embedded from `allowed_libfuncs_lists/experimental.json`.
pub const BUILTIN_EXPERIMENTAL_LIBFUNCS_LIST: &'static str = "experimental";
/// Name selecting the built-in list embedded from `allowed_libfuncs_lists/all.json`, which the
/// unsupported-libfunc error suggests to allow all libfuncs.
pub const BUILTIN_ALL_LIBFUNCS_LIST: &'static str = "all";
90
91pub fn lookup_allowed_libfuncs_list(
93 list_selector: ListSelector,
94) -> Result<AllowedLibfuncs, AllowedLibfuncsError> {
95 let list_name = list_selector.to_string();
96 let allowed_libfuncs_str: String = match list_selector {
97 ListSelector::ListName(list_name) => match list_name.as_str() {
98 BUILTIN_ALL_LIBFUNCS_LIST => {
99 include_str!("allowed_libfuncs_lists/all.json").to_string()
100 }
101 BUILTIN_EXPERIMENTAL_LIBFUNCS_LIST => {
102 include_str!("allowed_libfuncs_lists/experimental.json").to_string()
103 }
104 BUILTIN_AUDITED_LIBFUNCS_LIST => {
105 include_str!("allowed_libfuncs_lists/audited.json").to_string()
106 }
107 _ => {
108 return Err(AllowedLibfuncsError::UnexpectedAllowedLibfuncsList {
109 allowed_libfuncs_list_name: list_name.to_string(),
110 });
111 }
112 },
113 ListSelector::ListFile(file_path) => fs::read_to_string(&file_path).map_err(|_| {
114 AllowedLibfuncsError::UnknownAllowedLibfuncsFile {
115 allowed_libfuncs_list_file: file_path,
116 }
117 })?,
118 ListSelector::DefaultList => {
119 include_str!("allowed_libfuncs_lists/audited.json").to_string()
120 }
121 };
122 let allowed_libfuncs: Result<AllowedLibfuncs, serde_json::Error> =
123 serde_json::from_str(&allowed_libfuncs_str);
124 allowed_libfuncs.map_err(|_| AllowedLibfuncsError::DeserializationError {
125 allowed_libfuncs_list_file: list_name,
126 })
127}