1use {
2 anyhow::{anyhow, Result},
3 std::{
4 collections::BTreeMap,
5 path::{Path, PathBuf},
6 },
7 syn::{
8 parse::{Error as ParseError, Result as ParseResult},
9 Ident, ImplItem, ImplItemConst, Type, TypePath,
10 },
11};
12
13pub struct CrateContext {
17 modules: BTreeMap<String, ParsedModule>,
18}
19
20impl CrateContext {
21 pub fn parse(root: impl AsRef<Path>) -> Result<Self> {
22 Ok(CrateContext {
23 modules: ParsedModule::parse_recursive(root.as_ref())?,
24 })
25 }
26
27 pub fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
28 self.modules.iter().flat_map(|(_, ctx)| ctx.consts())
29 }
30
31 pub fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &syn::ImplItemConst)> {
32 self.modules.iter().flat_map(|(_, ctx)| ctx.impl_consts())
33 }
34
35 pub fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
36 self.modules.iter().flat_map(|(_, ctx)| ctx.structs())
37 }
38
39 pub fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
40 self.modules.iter().flat_map(|(_, ctx)| ctx.enums())
41 }
42
43 pub fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
44 self.modules.iter().flat_map(|(_, ctx)| ctx.type_aliases())
45 }
46
47 pub fn modules(&self) -> impl Iterator<Item = ModuleContext<'_>> {
48 self.modules.values().map(|detail| ModuleContext { detail })
49 }
50
51 pub fn root_module(&self) -> ModuleContext<'_> {
52 #[allow(
53 clippy::unwrap_used,
54 reason = "\"crate\" module is always inserted during parsing"
55 )]
56 let detail = self.modules.get("crate").unwrap();
57 ModuleContext { detail }
58 }
59
60 pub fn safety_checks(&self) -> Result<()> {
62 for ctx in self.modules.values() {
64 for unsafe_field in ctx.unsafe_struct_fields() {
65 let is_documented = unsafe_field.attrs.iter().any(|attr| {
67 if let syn::Meta::NameValue(syn::MetaNameValue {
68 value:
69 syn::Expr::Lit(syn::ExprLit {
70 lit: syn::Lit::Str(s),
71 ..
72 }),
73 ..
74 }) = &attr.meta
75 {
76 s.value().contains("CHECK")
77 } else {
78 false
79 }
80 });
81 if !is_documented {
82 #[allow(
83 clippy::unwrap_used,
84 reason = "unsafe fields always have idents (named fields only)"
85 )]
86 let ident = unsafe_field.ident.as_ref().unwrap();
87 let span = ident.span();
88 #[allow(
90 clippy::unwrap_used,
91 reason = "file paths are always valid during compilation"
92 )]
93 let canonical = ctx.file.canonicalize().unwrap();
94 return Err(anyhow!(
95 r#"
96 {}:{}:{}
97 Struct field "{}" is unsafe, but is not documented.
98 Please add a `/// CHECK:` doc comment explaining why no checks through types are necessary.
99 Alternatively, for reasons like quick prototyping, you may disable the safety checks
100 by using the `skip-lint` option.
101 See https://www.anchor-lang.com/docs/basics/program-structure#account-validation for more information.
102 "#,
103 canonical.display(),
104 span.start().line,
105 span.start().column,
106 ident
107 ));
108 };
109 }
110 }
111 Ok(())
112 }
113}
114
115#[derive(Copy, Clone)]
119pub struct ModuleContext<'krate> {
120 detail: &'krate ParsedModule,
121}
122
123impl ModuleContext<'_> {
124 pub fn items(&self) -> impl Iterator<Item = &syn::Item> {
125 self.detail.items.iter()
126 }
127}
128struct ParsedModule {
129 name: String,
130 file: PathBuf,
131 path: String,
132 items: Vec<syn::Item>,
133}
134
135struct UnparsedModule {
136 file: PathBuf,
137 path: String,
138 name: String,
139 item: syn::ItemMod,
140}
141
142impl ParsedModule {
143 fn parse_recursive(root: &Path) -> Result<BTreeMap<String, ParsedModule>> {
144 let mut modules = BTreeMap::new();
145
146 let root_content = std::fs::read_to_string(root)?;
147 let root_file = syn::parse_file(&root_content)?;
148 let root_mod = Self::new(
149 String::new(),
150 root.to_owned(),
151 "crate".to_owned(),
152 root_file.items,
153 );
154
155 let mut unparsed = root_mod.unparsed_submodules();
156 while let Some(to_parse) = unparsed.pop() {
157 let path = format!("{}::{}", to_parse.path, to_parse.name);
158 let module = Self::from_item_mod(&to_parse.file, &path, to_parse.item)?;
159
160 unparsed.extend(module.unparsed_submodules());
161 modules.insert(format!("{}{}", module.path, to_parse.name), module);
162 }
163
164 modules.insert(root_mod.name.clone(), root_mod);
165
166 Ok(modules)
167 }
168
169 fn from_item_mod(
170 parent_file: &Path,
171 parent_path: &str,
172 item: syn::ItemMod,
173 ) -> ParseResult<Self> {
174 Ok(match item.content {
175 Some((_, items)) => {
176 Self::new(
178 parent_path.to_owned(),
179 parent_file.to_owned(),
180 item.ident.to_string(),
181 items,
182 )
183 }
184 None => {
185 #[allow(
188 clippy::unwrap_used,
189 reason = "file paths always have parent directories during compilation"
190 )]
191 let parent_dir = parent_file.parent().unwrap();
192 #[allow(
193 clippy::unwrap_used,
194 reason = "file stems are always valid UTF-8 Rust identifiers"
195 )]
196 let parent_filename = parent_file.file_stem().unwrap().to_str().unwrap();
197 let parent_mod_dir = parent_dir.join(parent_filename);
198
199 let possible_file_paths = vec![
200 parent_dir.join(format!("{}.rs", item.ident)),
201 parent_dir.join(format!("{}/mod.rs", item.ident)),
202 parent_mod_dir.join(format!("{}.rs", item.ident)),
203 parent_mod_dir.join(format!("{}/mod.rs", item.ident)),
204 ];
205
206 let mod_file_path = possible_file_paths
207 .into_iter()
208 .find(|p| p.exists())
209 .ok_or_else(|| ParseError::new_spanned(&item, "could not find file"))?;
210 let mod_file_content = std::fs::read_to_string(&mod_file_path)
211 .map_err(|_| ParseError::new_spanned(&item, "could not read file"))?;
212 let mod_file = syn::parse_file(&mod_file_content)?;
213
214 Self::new(
215 parent_path.to_owned(),
216 mod_file_path,
217 item.ident.to_string(),
218 mod_file.items,
219 )
220 }
221 })
222 }
223
224 fn new(path: String, file: PathBuf, name: String, items: Vec<syn::Item>) -> Self {
225 Self {
226 name,
227 file,
228 path,
229 items,
230 }
231 }
232
233 fn unparsed_submodules(&self) -> Vec<UnparsedModule> {
234 self.submodules()
235 .map(|item| UnparsedModule {
236 file: self.file.clone(),
237 path: self.path.clone(),
238 name: item.ident.to_string(),
239 item: item.clone(),
240 })
241 .collect()
242 }
243
244 fn submodules(&self) -> impl Iterator<Item = &syn::ItemMod> {
245 self.items.iter().filter_map(|i| match i {
246 syn::Item::Mod(item) => Some(item),
247 _ => None,
248 })
249 }
250
251 fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
252 self.items.iter().filter_map(|i| match i {
253 syn::Item::Struct(item) => Some(item),
254 _ => None,
255 })
256 }
257
258 fn unsafe_struct_fields(&self) -> impl Iterator<Item = &syn::Field> {
259 let accounts_filter = |item_struct: &&syn::ItemStruct| {
260 item_struct.attrs.iter().any(|attr| {
261 attr.path().is_ident("derive")
262 && attr
263 .parse_args_with(
264 syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
265 )
266 .ok()
267 .is_some_and(|args| {
268 args.iter()
269 .any(|m| matches!(m, syn::Meta::Path(p) if p.is_ident("Accounts")))
270 })
271 })
272 };
273
274 self.structs()
275 .filter(accounts_filter)
276 .flat_map(|s| &s.fields)
277 .filter(|f| match &f.ty {
278 syn::Type::Path(syn::TypePath {
279 path: syn::Path { segments, .. },
280 ..
281 }) => {
282 segments.len() == 1 && segments[0].ident == "UncheckedAccount"
283 || segments[0].ident == "AccountInfo"
284 }
285 _ => false,
286 })
287 }
288
289 fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
290 self.items.iter().filter_map(|i| match i {
291 syn::Item::Enum(item) => Some(item),
292 _ => None,
293 })
294 }
295
296 fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
297 self.items.iter().filter_map(|i| match i {
298 syn::Item::Type(item) => Some(item),
299 _ => None,
300 })
301 }
302
303 fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
304 self.items.iter().filter_map(|i| match i {
305 syn::Item::Const(item) => Some(item),
306 _ => None,
307 })
308 }
309
310 fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &ImplItemConst)> {
311 self.items
312 .iter()
313 .filter_map(|i| match i {
314 syn::Item::Impl(syn::ItemImpl {
315 self_ty: ty, items, ..
316 }) => {
317 if let Type::Path(TypePath {
318 qself: None,
319 path: p,
320 }) = ty.as_ref()
321 {
322 if let Some(ident) = p.get_ident() {
323 let mut to_return = Vec::new();
324 items.iter().for_each(|item| {
325 if let ImplItem::Const(item) = item {
326 to_return.push((ident, item));
327 }
328 });
329 Some(to_return)
330 } else {
331 None
332 }
333 } else {
334 None
335 }
336 }
337 _ => None,
338 })
339 .flatten()
340 }
341}