1use anyhow::{anyhow, Result};
2use std::collections::BTreeMap;
3use std::path::{Path, PathBuf};
4use syn::parse::{Error as ParseError, Result as ParseResult};
5use syn::{Ident, ImplItem, ImplItemConst, Type, TypePath};
6
7pub struct CrateContext {
11 modules: BTreeMap<String, ParsedModule>,
12}
13
14impl CrateContext {
15 pub fn parse(root: impl AsRef<Path>) -> Result<Self> {
16 Ok(CrateContext {
17 modules: ParsedModule::parse_recursive(root.as_ref())?,
18 })
19 }
20
21 pub fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
22 self.modules.iter().flat_map(|(_, ctx)| ctx.consts())
23 }
24
25 pub fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &syn::ImplItemConst)> {
26 self.modules.iter().flat_map(|(_, ctx)| ctx.impl_consts())
27 }
28
29 pub fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
30 self.modules.iter().flat_map(|(_, ctx)| ctx.structs())
31 }
32
33 pub fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
34 self.modules.iter().flat_map(|(_, ctx)| ctx.enums())
35 }
36
37 pub fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
38 self.modules.iter().flat_map(|(_, ctx)| ctx.type_aliases())
39 }
40
41 pub fn modules(&self) -> impl Iterator<Item = ModuleContext> {
42 self.modules.values().map(|detail| ModuleContext { detail })
43 }
44
45 pub fn root_module(&self) -> ModuleContext {
46 ModuleContext {
47 detail: self.modules.get("crate").unwrap(),
48 }
49 }
50
51 pub fn safety_checks(&self) -> Result<()> {
53 for ctx in self.modules.values() {
55 for unsafe_field in ctx.unsafe_struct_fields() {
56 let is_documented = unsafe_field.attrs.iter().any(|attr| {
58 attr.tokens.clone().into_iter().any(|token| match token {
59 proc_macro2::TokenTree::Literal(s) => s.to_string().contains("CHECK"),
61 _ => false,
62 })
63 });
64 if !is_documented {
65 let ident = unsafe_field.ident.as_ref().unwrap();
66 let span = ident.span();
67 return Err(anyhow!(
69 r#"
70 {}:{}:{}
71 Struct field "{}" is unsafe, but is not documented.
72 Please add a `/// CHECK:` doc comment explaining why no checks through types are necessary.
73 See https://www.mainstay-lang.com/docs/the-accounts-struct#safety-checks for more information.
74 "#,
75 ctx.file.canonicalize().unwrap().display(),
76 span.start().line,
77 span.start().column,
78 ident.to_string()
79 ));
80 };
81 }
82 }
83 Ok(())
84 }
85}
86
87#[derive(Copy, Clone)]
91pub struct ModuleContext<'krate> {
92 detail: &'krate ParsedModule,
93}
94
95impl<'krate> ModuleContext<'krate> {
96 pub fn items(&self) -> impl Iterator<Item = &syn::Item> {
97 self.detail.items.iter()
98 }
99}
100struct ParsedModule {
101 name: String,
102 file: PathBuf,
103 path: String,
104 items: Vec<syn::Item>,
105}
106
107struct UnparsedModule {
108 file: PathBuf,
109 path: String,
110 name: String,
111 item: syn::ItemMod,
112}
113
114impl ParsedModule {
115 fn parse_recursive(root: &Path) -> Result<BTreeMap<String, ParsedModule>> {
116 let mut modules = BTreeMap::new();
117
118 let root_content = std::fs::read_to_string(root)?;
119 let root_file = syn::parse_file(&root_content)?;
120 let root_mod = Self::new(
121 String::new(),
122 root.to_owned(),
123 "crate".to_owned(),
124 root_file.items,
125 );
126
127 let mut unparsed = root_mod.unparsed_submodules();
128 while let Some(to_parse) = unparsed.pop() {
129 let path = format!("{}::{}", to_parse.path, to_parse.name);
130 let module = Self::from_item_mod(&to_parse.file, &path, to_parse.item)?;
131
132 unparsed.extend(module.unparsed_submodules());
133 modules.insert(format!("{}{}", module.path, to_parse.name), module);
134 }
135
136 modules.insert(root_mod.name.clone(), root_mod);
137
138 Ok(modules)
139 }
140
141 fn from_item_mod(
142 parent_file: &Path,
143 parent_path: &str,
144 item: syn::ItemMod,
145 ) -> ParseResult<Self> {
146 Ok(match item.content {
147 Some((_, items)) => {
148 Self::new(
150 parent_path.to_owned(),
151 parent_file.to_owned(),
152 item.ident.to_string(),
153 items,
154 )
155 }
156 None => {
157 let parent_dir = parent_file.parent().unwrap();
160 let parent_filename = parent_file.file_stem().unwrap().to_str().unwrap();
161 let parent_mod_dir = parent_dir.join(parent_filename);
162
163 let possible_file_paths = vec![
164 parent_dir.join(format!("{}.rs", item.ident)),
165 parent_dir.join(format!("{}/mod.rs", item.ident)),
166 parent_mod_dir.join(format!("{}.rs", item.ident)),
167 parent_mod_dir.join(format!("{}/mod.rs", item.ident)),
168 ];
169
170 let mod_file_path = possible_file_paths
171 .into_iter()
172 .find(|p| p.exists())
173 .ok_or_else(|| ParseError::new_spanned(&item, "could not find file"))?;
174 let mod_file_content = std::fs::read_to_string(&mod_file_path)
175 .map_err(|_| ParseError::new_spanned(&item, "could not read file"))?;
176 let mod_file = syn::parse_file(&mod_file_content)?;
177
178 Self::new(
179 parent_path.to_owned(),
180 mod_file_path,
181 item.ident.to_string(),
182 mod_file.items,
183 )
184 }
185 })
186 }
187
188 fn new(path: String, file: PathBuf, name: String, items: Vec<syn::Item>) -> Self {
189 Self {
190 name,
191 file,
192 path,
193 items,
194 }
195 }
196
197 fn unparsed_submodules(&self) -> Vec<UnparsedModule> {
198 self.submodules()
199 .map(|item| UnparsedModule {
200 file: self.file.clone(),
201 path: self.path.clone(),
202 name: item.ident.to_string(),
203 item: item.clone(),
204 })
205 .collect()
206 }
207
208 fn submodules(&self) -> impl Iterator<Item = &syn::ItemMod> {
209 self.items.iter().filter_map(|i| match i {
210 syn::Item::Mod(item) => Some(item),
211 _ => None,
212 })
213 }
214
215 fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
216 self.items.iter().filter_map(|i| match i {
217 syn::Item::Struct(item) => Some(item),
218 _ => None,
219 })
220 }
221
222 fn unsafe_struct_fields(&self) -> impl Iterator<Item = &syn::Field> {
223 let accounts_filter = |item_struct: &&syn::ItemStruct| {
224 item_struct.attrs.iter().any(|attr| {
225 match attr.parse_meta() {
226 Ok(syn::Meta::List(syn::MetaList{path, nested, ..})) => {
227 path.is_ident("derive") && nested.iter().any(|nested| {
228 matches!(nested, syn::NestedMeta::Meta(syn::Meta::Path(path)) if path.is_ident("Accounts"))
229 })
230 }
231 _ => false
232 }
233 })
234 };
235
236 self.structs()
237 .filter(accounts_filter)
238 .flat_map(|s| &s.fields)
239 .filter(|f| match &f.ty {
240 syn::Type::Path(syn::TypePath {
241 path: syn::Path { segments, .. },
242 ..
243 }) => {
244 segments.len() == 1 && segments[0].ident == "UncheckedAccount"
245 || segments[0].ident == "AccountInfo"
246 }
247 _ => false,
248 })
249 }
250
251 fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
252 self.items.iter().filter_map(|i| match i {
253 syn::Item::Enum(item) => Some(item),
254 _ => None,
255 })
256 }
257
258 fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
259 self.items.iter().filter_map(|i| match i {
260 syn::Item::Type(item) => Some(item),
261 _ => None,
262 })
263 }
264
265 fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
266 self.items.iter().filter_map(|i| match i {
267 syn::Item::Const(item) => Some(item),
268 _ => None,
269 })
270 }
271
272 fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &ImplItemConst)> {
273 self.items
274 .iter()
275 .filter_map(|i| match i {
276 syn::Item::Impl(syn::ItemImpl {
277 self_ty: ty, items, ..
278 }) => {
279 if let Type::Path(TypePath {
280 qself: None,
281 path: p,
282 }) = ty.as_ref()
283 {
284 if let Some(ident) = p.get_ident() {
285 let mut to_return = Vec::new();
286 items.iter().for_each(|item| {
287 if let ImplItem::Const(item) = item {
288 to_return.push((ident, item));
289 }
290 });
291 Some(to_return)
292 } else {
293 None
294 }
295 } else {
296 None
297 }
298 }
299 _ => None,
300 })
301 .flatten()
302 }
303}