anchor_syn/parser/context.rs

1use {
2    anyhow::{anyhow, Result},
3    std::{
4        collections::BTreeMap,
5        path::{Path, PathBuf},
6    },
7    syn::{
8        parse::{Error as ParseError, Result as ParseResult},
9        Ident, ImplItem, ImplItemConst, Type, TypePath,
10    },
11};
12
13/// Crate parse context
14///
15/// Keeps track of modules defined within a crate.
16pub struct CrateContext {
17    modules: BTreeMap<String, ParsedModule>,
18}
19
20impl CrateContext {
21    pub fn parse(root: impl AsRef<Path>) -> Result<Self> {
22        Ok(CrateContext {
23            modules: ParsedModule::parse_recursive(root.as_ref())?,
24        })
25    }
26
27    pub fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
28        self.modules.iter().flat_map(|(_, ctx)| ctx.consts())
29    }
30
31    pub fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &syn::ImplItemConst)> {
32        self.modules.iter().flat_map(|(_, ctx)| ctx.impl_consts())
33    }
34
35    pub fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
36        self.modules.iter().flat_map(|(_, ctx)| ctx.structs())
37    }
38
39    pub fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
40        self.modules.iter().flat_map(|(_, ctx)| ctx.enums())
41    }
42
43    pub fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
44        self.modules.iter().flat_map(|(_, ctx)| ctx.type_aliases())
45    }
46
47    pub fn modules(&self) -> impl Iterator<Item = ModuleContext<'_>> {
48        self.modules.values().map(|detail| ModuleContext { detail })
49    }
50
51    pub fn root_module(&self) -> ModuleContext<'_> {
52        #[allow(
53            clippy::unwrap_used,
54            reason = "\"crate\" module is always inserted during parsing"
55        )]
56        let detail = self.modules.get("crate").unwrap();
57        ModuleContext { detail }
58    }
59
60    // Perform Anchor safety checks on the parsed create
61    pub fn safety_checks(&self) -> Result<()> {
62        // Check all structs for unsafe field types, i.e. AccountInfo and UncheckedAccount.
63        for ctx in self.modules.values() {
64            for unsafe_field in ctx.unsafe_struct_fields() {
65                // Check if unsafe field type has been documented with a /// SAFETY: doc string.
66                let is_documented = unsafe_field.attrs.iter().any(|attr| {
67                    if let syn::Meta::NameValue(syn::MetaNameValue {
68                        value:
69                            syn::Expr::Lit(syn::ExprLit {
70                                lit: syn::Lit::Str(s),
71                                ..
72                            }),
73                        ..
74                    }) = &attr.meta
75                    {
76                        s.value().contains("CHECK")
77                    } else {
78                        false
79                    }
80                });
81                if !is_documented {
82                    #[allow(
83                        clippy::unwrap_used,
84                        reason = "unsafe fields always have idents (named fields only)"
85                    )]
86                    let ident = unsafe_field.ident.as_ref().unwrap();
87                    let span = ident.span();
88                    // Error if undocumented.
89                    #[allow(
90                        clippy::unwrap_used,
91                        reason = "file paths are always valid during compilation"
92                    )]
93                    let canonical = ctx.file.canonicalize().unwrap();
94                    return Err(anyhow!(
95                        r#"
96        {}:{}:{}
97        Struct field "{}" is unsafe, but is not documented.
98        Please add a `/// CHECK:` doc comment explaining why no checks through types are necessary.
99        Alternatively, for reasons like quick prototyping, you may disable the safety checks
100        by using the `skip-lint` option.
101        See https://www.anchor-lang.com/docs/basics/program-structure#account-validation for more information.
102                    "#,
103                        canonical.display(),
104                        span.start().line,
105                        span.start().column,
106                        ident
107                    ));
108                };
109            }
110        }
111        Ok(())
112    }
113}
114
115/// Module parse context
116///
117/// Keeps track of items defined within a module.
118#[derive(Copy, Clone)]
119pub struct ModuleContext<'krate> {
120    detail: &'krate ParsedModule,
121}
122
123impl ModuleContext<'_> {
124    pub fn items(&self) -> impl Iterator<Item = &syn::Item> {
125        self.detail.items.iter()
126    }
127}
128struct ParsedModule {
129    name: String,
130    file: PathBuf,
131    path: String,
132    items: Vec<syn::Item>,
133}
134
135struct UnparsedModule {
136    file: PathBuf,
137    path: String,
138    name: String,
139    item: syn::ItemMod,
140}
141
142impl ParsedModule {
143    fn parse_recursive(root: &Path) -> Result<BTreeMap<String, ParsedModule>> {
144        let mut modules = BTreeMap::new();
145
146        let root_content = std::fs::read_to_string(root)?;
147        let root_file = syn::parse_file(&root_content)?;
148        let root_mod = Self::new(
149            String::new(),
150            root.to_owned(),
151            "crate".to_owned(),
152            root_file.items,
153        );
154
155        let mut unparsed = root_mod.unparsed_submodules();
156        while let Some(to_parse) = unparsed.pop() {
157            let path = format!("{}::{}", to_parse.path, to_parse.name);
158            let module = Self::from_item_mod(&to_parse.file, &path, to_parse.item)?;
159
160            unparsed.extend(module.unparsed_submodules());
161            modules.insert(format!("{}{}", module.path, to_parse.name), module);
162        }
163
164        modules.insert(root_mod.name.clone(), root_mod);
165
166        Ok(modules)
167    }
168
169    fn from_item_mod(
170        parent_file: &Path,
171        parent_path: &str,
172        item: syn::ItemMod,
173    ) -> ParseResult<Self> {
174        Ok(match item.content {
175            Some((_, items)) => {
176                // The module content is within the parent file being parsed
177                Self::new(
178                    parent_path.to_owned(),
179                    parent_file.to_owned(),
180                    item.ident.to_string(),
181                    items,
182                )
183            }
184            None => {
185                // The module is referencing some other file, so we need to load that
186                // to parse the items it has.
187                #[allow(
188                    clippy::unwrap_used,
189                    reason = "file paths always have parent directories during compilation"
190                )]
191                let parent_dir = parent_file.parent().unwrap();
192                #[allow(
193                    clippy::unwrap_used,
194                    reason = "file stems are always valid UTF-8 Rust identifiers"
195                )]
196                let parent_filename = parent_file.file_stem().unwrap().to_str().unwrap();
197                let parent_mod_dir = parent_dir.join(parent_filename);
198
199                let possible_file_paths = vec![
200                    parent_dir.join(format!("{}.rs", item.ident)),
201                    parent_dir.join(format!("{}/mod.rs", item.ident)),
202                    parent_mod_dir.join(format!("{}.rs", item.ident)),
203                    parent_mod_dir.join(format!("{}/mod.rs", item.ident)),
204                ];
205
206                let mod_file_path = possible_file_paths
207                    .into_iter()
208                    .find(|p| p.exists())
209                    .ok_or_else(|| ParseError::new_spanned(&item, "could not find file"))?;
210                let mod_file_content = std::fs::read_to_string(&mod_file_path)
211                    .map_err(|_| ParseError::new_spanned(&item, "could not read file"))?;
212                let mod_file = syn::parse_file(&mod_file_content)?;
213
214                Self::new(
215                    parent_path.to_owned(),
216                    mod_file_path,
217                    item.ident.to_string(),
218                    mod_file.items,
219                )
220            }
221        })
222    }
223
224    fn new(path: String, file: PathBuf, name: String, items: Vec<syn::Item>) -> Self {
225        Self {
226            name,
227            file,
228            path,
229            items,
230        }
231    }
232
233    fn unparsed_submodules(&self) -> Vec<UnparsedModule> {
234        self.submodules()
235            .map(|item| UnparsedModule {
236                file: self.file.clone(),
237                path: self.path.clone(),
238                name: item.ident.to_string(),
239                item: item.clone(),
240            })
241            .collect()
242    }
243
244    fn submodules(&self) -> impl Iterator<Item = &syn::ItemMod> {
245        self.items.iter().filter_map(|i| match i {
246            syn::Item::Mod(item) => Some(item),
247            _ => None,
248        })
249    }
250
251    fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
252        self.items.iter().filter_map(|i| match i {
253            syn::Item::Struct(item) => Some(item),
254            _ => None,
255        })
256    }
257
258    fn unsafe_struct_fields(&self) -> impl Iterator<Item = &syn::Field> {
259        let accounts_filter = |item_struct: &&syn::ItemStruct| {
260            item_struct.attrs.iter().any(|attr| {
261                attr.path().is_ident("derive")
262                    && attr
263                        .parse_args_with(
264                            syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
265                        )
266                        .ok()
267                        .is_some_and(|args| {
268                            args.iter()
269                                .any(|m| matches!(m, syn::Meta::Path(p) if p.is_ident("Accounts")))
270                        })
271            })
272        };
273
274        self.structs()
275            .filter(accounts_filter)
276            .flat_map(|s| &s.fields)
277            .filter(|f| match &f.ty {
278                syn::Type::Path(syn::TypePath {
279                    path: syn::Path { segments, .. },
280                    ..
281                }) => {
282                    segments.len() == 1 && segments[0].ident == "UncheckedAccount"
283                        || segments[0].ident == "AccountInfo"
284                }
285                _ => false,
286            })
287    }
288
289    fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
290        self.items.iter().filter_map(|i| match i {
291            syn::Item::Enum(item) => Some(item),
292            _ => None,
293        })
294    }
295
296    fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
297        self.items.iter().filter_map(|i| match i {
298            syn::Item::Type(item) => Some(item),
299            _ => None,
300        })
301    }
302
303    fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
304        self.items.iter().filter_map(|i| match i {
305            syn::Item::Const(item) => Some(item),
306            _ => None,
307        })
308    }
309
310    fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &ImplItemConst)> {
311        self.items
312            .iter()
313            .filter_map(|i| match i {
314                syn::Item::Impl(syn::ItemImpl {
315                    self_ty: ty, items, ..
316                }) => {
317                    if let Type::Path(TypePath {
318                        qself: None,
319                        path: p,
320                    }) = ty.as_ref()
321                    {
322                        if let Some(ident) = p.get_ident() {
323                            let mut to_return = Vec::new();
324                            items.iter().for_each(|item| {
325                                if let ImplItem::Const(item) = item {
326                                    to_return.push((ident, item));
327                                }
328                            });
329                            Some(to_return)
330                        } else {
331                            None
332                        }
333                    } else {
334                        None
335                    }
336                }
337                _ => None,
338            })
339            .flatten()
340    }
341}