// anchor_syn/parser/context.rs

1use {
2    anyhow::{anyhow, Result},
3    std::{
4        collections::BTreeMap,
5        path::{Path, PathBuf},
6    },
7    syn::{
8        parse::{Error as ParseError, Result as ParseResult},
9        Ident, ImplItem, ImplItemConst, Type, TypePath,
10    },
11};
12
13/// Crate parse context
14///
15/// Keeps track of modules defined within a crate.
16pub struct CrateContext {
17    modules: BTreeMap<String, ParsedModule>,
18}
19
20impl CrateContext {
21    pub fn parse(root: impl AsRef<Path>) -> Result<Self> {
22        Ok(CrateContext {
23            modules: ParsedModule::parse_recursive(root.as_ref())?,
24        })
25    }
26
27    pub fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
28        self.modules.iter().flat_map(|(_, ctx)| ctx.consts())
29    }
30
31    pub fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &syn::ImplItemConst)> {
32        self.modules.iter().flat_map(|(_, ctx)| ctx.impl_consts())
33    }
34
35    pub fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
36        self.modules.iter().flat_map(|(_, ctx)| ctx.structs())
37    }
38
39    pub fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
40        self.modules.iter().flat_map(|(_, ctx)| ctx.enums())
41    }
42
43    pub fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
44        self.modules.iter().flat_map(|(_, ctx)| ctx.type_aliases())
45    }
46
47    pub fn modules(&self) -> impl Iterator<Item = ModuleContext<'_>> {
48        self.modules.values().map(|detail| ModuleContext { detail })
49    }
50
51    pub fn root_module(&self) -> ModuleContext<'_> {
52        ModuleContext {
53            detail: self.modules.get("crate").unwrap(),
54        }
55    }
56
57    // Perform Anchor safety checks on the parsed create
58    pub fn safety_checks(&self) -> Result<()> {
59        // Check all structs for unsafe field types, i.e. AccountInfo and UncheckedAccount.
60        for ctx in self.modules.values() {
61            for unsafe_field in ctx.unsafe_struct_fields() {
62                // Check if unsafe field type has been documented with a /// SAFETY: doc string.
63                let is_documented = unsafe_field.attrs.iter().any(|attr| {
64                    attr.tokens.clone().into_iter().any(|token| match token {
65                        // Check for doc comments containing CHECK
66                        proc_macro2::TokenTree::Literal(s) => s.to_string().contains("CHECK"),
67                        _ => false,
68                    })
69                });
70                if !is_documented {
71                    let ident = unsafe_field.ident.as_ref().unwrap();
72                    let span = ident.span();
73                    // Error if undocumented.
74                    return Err(anyhow!(
75                        r#"
76        {}:{}:{}
77        Struct field "{}" is unsafe, but is not documented.
78        Please add a `/// CHECK:` doc comment explaining why no checks through types are necessary.
79        Alternatively, for reasons like quick prototyping, you may disable the safety checks
80        by using the `skip-lint` option.
81        See https://www.anchor-lang.com/docs/basics/program-structure#account-validation for more information.
82                    "#,
83                        ctx.file.canonicalize().unwrap().display(),
84                        span.start().line,
85                        span.start().column,
86                        ident
87                    ));
88                };
89            }
90        }
91        Ok(())
92    }
93}
94
95/// Module parse context
96///
97/// Keeps track of items defined within a module.
98#[derive(Copy, Clone)]
99pub struct ModuleContext<'krate> {
100    detail: &'krate ParsedModule,
101}
102
103impl ModuleContext<'_> {
104    pub fn items(&self) -> impl Iterator<Item = &syn::Item> {
105        self.detail.items.iter()
106    }
107}
108struct ParsedModule {
109    name: String,
110    file: PathBuf,
111    path: String,
112    items: Vec<syn::Item>,
113}
114
115struct UnparsedModule {
116    file: PathBuf,
117    path: String,
118    name: String,
119    item: syn::ItemMod,
120}
121
122impl ParsedModule {
123    fn parse_recursive(root: &Path) -> Result<BTreeMap<String, ParsedModule>> {
124        let mut modules = BTreeMap::new();
125
126        let root_content = std::fs::read_to_string(root)?;
127        let root_file = syn::parse_file(&root_content)?;
128        let root_mod = Self::new(
129            String::new(),
130            root.to_owned(),
131            "crate".to_owned(),
132            root_file.items,
133        );
134
135        let mut unparsed = root_mod.unparsed_submodules();
136        while let Some(to_parse) = unparsed.pop() {
137            let path = format!("{}::{}", to_parse.path, to_parse.name);
138            let module = Self::from_item_mod(&to_parse.file, &path, to_parse.item)?;
139
140            unparsed.extend(module.unparsed_submodules());
141            modules.insert(format!("{}{}", module.path, to_parse.name), module);
142        }
143
144        modules.insert(root_mod.name.clone(), root_mod);
145
146        Ok(modules)
147    }
148
149    fn from_item_mod(
150        parent_file: &Path,
151        parent_path: &str,
152        item: syn::ItemMod,
153    ) -> ParseResult<Self> {
154        Ok(match item.content {
155            Some((_, items)) => {
156                // The module content is within the parent file being parsed
157                Self::new(
158                    parent_path.to_owned(),
159                    parent_file.to_owned(),
160                    item.ident.to_string(),
161                    items,
162                )
163            }
164            None => {
165                // The module is referencing some other file, so we need to load that
166                // to parse the items it has.
167                let parent_dir = parent_file.parent().unwrap();
168                let parent_filename = parent_file.file_stem().unwrap().to_str().unwrap();
169                let parent_mod_dir = parent_dir.join(parent_filename);
170
171                let possible_file_paths = vec![
172                    parent_dir.join(format!("{}.rs", item.ident)),
173                    parent_dir.join(format!("{}/mod.rs", item.ident)),
174                    parent_mod_dir.join(format!("{}.rs", item.ident)),
175                    parent_mod_dir.join(format!("{}/mod.rs", item.ident)),
176                ];
177
178                let mod_file_path = possible_file_paths
179                    .into_iter()
180                    .find(|p| p.exists())
181                    .ok_or_else(|| ParseError::new_spanned(&item, "could not find file"))?;
182                let mod_file_content = std::fs::read_to_string(&mod_file_path)
183                    .map_err(|_| ParseError::new_spanned(&item, "could not read file"))?;
184                let mod_file = syn::parse_file(&mod_file_content)?;
185
186                Self::new(
187                    parent_path.to_owned(),
188                    mod_file_path,
189                    item.ident.to_string(),
190                    mod_file.items,
191                )
192            }
193        })
194    }
195
196    fn new(path: String, file: PathBuf, name: String, items: Vec<syn::Item>) -> Self {
197        Self {
198            name,
199            file,
200            path,
201            items,
202        }
203    }
204
205    fn unparsed_submodules(&self) -> Vec<UnparsedModule> {
206        self.submodules()
207            .map(|item| UnparsedModule {
208                file: self.file.clone(),
209                path: self.path.clone(),
210                name: item.ident.to_string(),
211                item: item.clone(),
212            })
213            .collect()
214    }
215
216    fn submodules(&self) -> impl Iterator<Item = &syn::ItemMod> {
217        self.items.iter().filter_map(|i| match i {
218            syn::Item::Mod(item) => Some(item),
219            _ => None,
220        })
221    }
222
223    fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
224        self.items.iter().filter_map(|i| match i {
225            syn::Item::Struct(item) => Some(item),
226            _ => None,
227        })
228    }
229
230    fn unsafe_struct_fields(&self) -> impl Iterator<Item = &syn::Field> {
231        let accounts_filter = |item_struct: &&syn::ItemStruct| {
232            item_struct.attrs.iter().any(|attr| {
233                match attr.parse_meta() {
234                    Ok(syn::Meta::List(syn::MetaList{path, nested, ..})) => {
235                        path.is_ident("derive") && nested.iter().any(|nested| {
236                            matches!(nested, syn::NestedMeta::Meta(syn::Meta::Path(path)) if path.is_ident("Accounts"))
237                        })
238                    }
239                    _ => false
240                }
241            })
242        };
243
244        self.structs()
245            .filter(accounts_filter)
246            .flat_map(|s| &s.fields)
247            .filter(|f| match &f.ty {
248                syn::Type::Path(syn::TypePath {
249                    path: syn::Path { segments, .. },
250                    ..
251                }) => {
252                    segments.len() == 1 && segments[0].ident == "UncheckedAccount"
253                        || segments[0].ident == "AccountInfo"
254                }
255                _ => false,
256            })
257    }
258
259    fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
260        self.items.iter().filter_map(|i| match i {
261            syn::Item::Enum(item) => Some(item),
262            _ => None,
263        })
264    }
265
266    fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
267        self.items.iter().filter_map(|i| match i {
268            syn::Item::Type(item) => Some(item),
269            _ => None,
270        })
271    }
272
273    fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
274        self.items.iter().filter_map(|i| match i {
275            syn::Item::Const(item) => Some(item),
276            _ => None,
277        })
278    }
279
280    fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &ImplItemConst)> {
281        self.items
282            .iter()
283            .filter_map(|i| match i {
284                syn::Item::Impl(syn::ItemImpl {
285                    self_ty: ty, items, ..
286                }) => {
287                    if let Type::Path(TypePath {
288                        qself: None,
289                        path: p,
290                    }) = ty.as_ref()
291                    {
292                        if let Some(ident) = p.get_ident() {
293                            let mut to_return = Vec::new();
294                            items.iter().for_each(|item| {
295                                if let ImplItem::Const(item) = item {
296                                    to_return.push((ident, item));
297                                }
298                            });
299                            Some(to_return)
300                        } else {
301                            None
302                        }
303                    } else {
304                        None
305                    }
306                }
307                _ => None,
308            })
309            .flatten()
310    }
311}