mainstay_syn/parser/
context.rs

1use anyhow::{anyhow, Result};
2use std::collections::BTreeMap;
3use std::path::{Path, PathBuf};
4use syn::parse::{Error as ParseError, Result as ParseResult};
5use syn::{Ident, ImplItem, ImplItemConst, Type, TypePath};
6
7/// Crate parse context
8///
9/// Keeps track of modules defined within a crate.
10pub struct CrateContext {
11    modules: BTreeMap<String, ParsedModule>,
12}
13
14impl CrateContext {
15    pub fn parse(root: impl AsRef<Path>) -> Result<Self> {
16        Ok(CrateContext {
17            modules: ParsedModule::parse_recursive(root.as_ref())?,
18        })
19    }
20
21    pub fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
22        self.modules.iter().flat_map(|(_, ctx)| ctx.consts())
23    }
24
25    pub fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &syn::ImplItemConst)> {
26        self.modules.iter().flat_map(|(_, ctx)| ctx.impl_consts())
27    }
28
29    pub fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
30        self.modules.iter().flat_map(|(_, ctx)| ctx.structs())
31    }
32
33    pub fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
34        self.modules.iter().flat_map(|(_, ctx)| ctx.enums())
35    }
36
37    pub fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
38        self.modules.iter().flat_map(|(_, ctx)| ctx.type_aliases())
39    }
40
41    pub fn modules(&self) -> impl Iterator<Item = ModuleContext> {
42        self.modules.values().map(|detail| ModuleContext { detail })
43    }
44
45    pub fn root_module(&self) -> ModuleContext {
46        ModuleContext {
47            detail: self.modules.get("crate").unwrap(),
48        }
49    }
50
51    // Perform Mainstay safety checks on the parsed create
52    pub fn safety_checks(&self) -> Result<()> {
53        // Check all structs for unsafe field types, i.e. AccountInfo and UncheckedAccount.
54        for ctx in self.modules.values() {
55            for unsafe_field in ctx.unsafe_struct_fields() {
56                // Check if unsafe field type has been documented with a /// SAFETY: doc string.
57                let is_documented = unsafe_field.attrs.iter().any(|attr| {
58                    attr.tokens.clone().into_iter().any(|token| match token {
59                        // Check for doc comments containing CHECK
60                        proc_macro2::TokenTree::Literal(s) => s.to_string().contains("CHECK"),
61                        _ => false,
62                    })
63                });
64                if !is_documented {
65                    let ident = unsafe_field.ident.as_ref().unwrap();
66                    let span = ident.span();
67                    // Error if undocumented.
68                    return Err(anyhow!(
69                        r#"
70        {}:{}:{}
71        Struct field "{}" is unsafe, but is not documented.
72        Please add a `/// CHECK:` doc comment explaining why no checks through types are necessary.
73        See https://www.mainstay-lang.com/docs/the-accounts-struct#safety-checks for more information.
74                    "#,
75                        ctx.file.canonicalize().unwrap().display(),
76                        span.start().line,
77                        span.start().column,
78                        ident.to_string()
79                    ));
80                };
81            }
82        }
83        Ok(())
84    }
85}
86
87/// Module parse context
88///
89/// Keeps track of items defined within a module.
90#[derive(Copy, Clone)]
91pub struct ModuleContext<'krate> {
92    detail: &'krate ParsedModule,
93}
94
95impl<'krate> ModuleContext<'krate> {
96    pub fn items(&self) -> impl Iterator<Item = &syn::Item> {
97        self.detail.items.iter()
98    }
99}
100struct ParsedModule {
101    name: String,
102    file: PathBuf,
103    path: String,
104    items: Vec<syn::Item>,
105}
106
107struct UnparsedModule {
108    file: PathBuf,
109    path: String,
110    name: String,
111    item: syn::ItemMod,
112}
113
114impl ParsedModule {
115    fn parse_recursive(root: &Path) -> Result<BTreeMap<String, ParsedModule>> {
116        let mut modules = BTreeMap::new();
117
118        let root_content = std::fs::read_to_string(root)?;
119        let root_file = syn::parse_file(&root_content)?;
120        let root_mod = Self::new(
121            String::new(),
122            root.to_owned(),
123            "crate".to_owned(),
124            root_file.items,
125        );
126
127        let mut unparsed = root_mod.unparsed_submodules();
128        while let Some(to_parse) = unparsed.pop() {
129            let path = format!("{}::{}", to_parse.path, to_parse.name);
130            let module = Self::from_item_mod(&to_parse.file, &path, to_parse.item)?;
131
132            unparsed.extend(module.unparsed_submodules());
133            modules.insert(format!("{}{}", module.path, to_parse.name), module);
134        }
135
136        modules.insert(root_mod.name.clone(), root_mod);
137
138        Ok(modules)
139    }
140
141    fn from_item_mod(
142        parent_file: &Path,
143        parent_path: &str,
144        item: syn::ItemMod,
145    ) -> ParseResult<Self> {
146        Ok(match item.content {
147            Some((_, items)) => {
148                // The module content is within the parent file being parsed
149                Self::new(
150                    parent_path.to_owned(),
151                    parent_file.to_owned(),
152                    item.ident.to_string(),
153                    items,
154                )
155            }
156            None => {
157                // The module is referencing some other file, so we need to load that
158                // to parse the items it has.
159                let parent_dir = parent_file.parent().unwrap();
160                let parent_filename = parent_file.file_stem().unwrap().to_str().unwrap();
161                let parent_mod_dir = parent_dir.join(parent_filename);
162
163                let possible_file_paths = vec![
164                    parent_dir.join(format!("{}.rs", item.ident)),
165                    parent_dir.join(format!("{}/mod.rs", item.ident)),
166                    parent_mod_dir.join(format!("{}.rs", item.ident)),
167                    parent_mod_dir.join(format!("{}/mod.rs", item.ident)),
168                ];
169
170                let mod_file_path = possible_file_paths
171                    .into_iter()
172                    .find(|p| p.exists())
173                    .ok_or_else(|| ParseError::new_spanned(&item, "could not find file"))?;
174                let mod_file_content = std::fs::read_to_string(&mod_file_path)
175                    .map_err(|_| ParseError::new_spanned(&item, "could not read file"))?;
176                let mod_file = syn::parse_file(&mod_file_content)?;
177
178                Self::new(
179                    parent_path.to_owned(),
180                    mod_file_path,
181                    item.ident.to_string(),
182                    mod_file.items,
183                )
184            }
185        })
186    }
187
188    fn new(path: String, file: PathBuf, name: String, items: Vec<syn::Item>) -> Self {
189        Self {
190            name,
191            file,
192            path,
193            items,
194        }
195    }
196
197    fn unparsed_submodules(&self) -> Vec<UnparsedModule> {
198        self.submodules()
199            .map(|item| UnparsedModule {
200                file: self.file.clone(),
201                path: self.path.clone(),
202                name: item.ident.to_string(),
203                item: item.clone(),
204            })
205            .collect()
206    }
207
208    fn submodules(&self) -> impl Iterator<Item = &syn::ItemMod> {
209        self.items.iter().filter_map(|i| match i {
210            syn::Item::Mod(item) => Some(item),
211            _ => None,
212        })
213    }
214
215    fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
216        self.items.iter().filter_map(|i| match i {
217            syn::Item::Struct(item) => Some(item),
218            _ => None,
219        })
220    }
221
222    fn unsafe_struct_fields(&self) -> impl Iterator<Item = &syn::Field> {
223        let accounts_filter = |item_struct: &&syn::ItemStruct| {
224            item_struct.attrs.iter().any(|attr| {
225                match attr.parse_meta() {
226                    Ok(syn::Meta::List(syn::MetaList{path, nested, ..})) => {
227                        path.is_ident("derive") && nested.iter().any(|nested| {
228                            matches!(nested, syn::NestedMeta::Meta(syn::Meta::Path(path)) if path.is_ident("Accounts"))
229                        })
230                    }
231                    _ => false
232                }
233            })
234        };
235
236        self.structs()
237            .filter(accounts_filter)
238            .flat_map(|s| &s.fields)
239            .filter(|f| match &f.ty {
240                syn::Type::Path(syn::TypePath {
241                    path: syn::Path { segments, .. },
242                    ..
243                }) => {
244                    segments.len() == 1 && segments[0].ident == "UncheckedAccount"
245                        || segments[0].ident == "AccountInfo"
246                }
247                _ => false,
248            })
249    }
250
251    fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
252        self.items.iter().filter_map(|i| match i {
253            syn::Item::Enum(item) => Some(item),
254            _ => None,
255        })
256    }
257
258    fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
259        self.items.iter().filter_map(|i| match i {
260            syn::Item::Type(item) => Some(item),
261            _ => None,
262        })
263    }
264
265    fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
266        self.items.iter().filter_map(|i| match i {
267            syn::Item::Const(item) => Some(item),
268            _ => None,
269        })
270    }
271
272    fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &ImplItemConst)> {
273        self.items
274            .iter()
275            .filter_map(|i| match i {
276                syn::Item::Impl(syn::ItemImpl {
277                    self_ty: ty, items, ..
278                }) => {
279                    if let Type::Path(TypePath {
280                        qself: None,
281                        path: p,
282                    }) = ty.as_ref()
283                    {
284                        if let Some(ident) = p.get_ident() {
285                            let mut to_return = Vec::new();
286                            items.iter().for_each(|item| {
287                                if let ImplItem::Const(item) = item {
288                                    to_return.push((ident, item));
289                                }
290                            });
291                            Some(to_return)
292                        } else {
293                            None
294                        }
295                    } else {
296                        None
297                    }
298                }
299                _ => None,
300            })
301            .flatten()
302    }
303}