1use std::{collections::HashSet, env, fmt::Display, fs, path::Path};
2
3use error::Tokens;
4use itertools::Itertools;
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::quote;
8use ra_ap_syntax::{
9 ast::{self, HasModuleItem, HasName, Type},
10 AstNode, NodeOrToken, SourceFile,
11};
12use syn::{
13 bracketed,
14 parse::{Parse, ParseStream},
15 parse_macro_input,
16 token::Comma,
17 Error, Ident, LitStr, Result, Token,
18};
19
20mod error;
21
/// Evaluates to an `Err(..)` whose error is produced by `crate::error::call_site`
/// applied to the `format!`-style message built from the macro arguments.
///
/// Typically used as `call_site_error!("...")?;` to bail out of a function.
macro_rules! call_site_error {
    ($($fmt_args:tt)*) => {
        Err($crate::error::call_site(format!($($fmt_args)*)))
    };
}

/// Evaluates to `Err(Error::new(span, message))`, where `Error` is the crate-root
/// import (syn's `Error`) and `message` is built `format!`-style from the
/// remaining macro arguments.
///
/// Typically used as `error!(span, "...")?;` to bail out of a function.
macro_rules! error {
    ($span:expr, $($fmt_args:tt)*) => {
        Err($crate::Error::new($span, format!($($fmt_args)*)))
    };
}
33
34#[proc_macro]
35pub fn source_file(input: TokenStream) -> TokenStream {
36 let file: LitStr = parse_macro_input!(input);
37
38 doc_function_body(file, Ident::new("main", Span::call_site()), None).tokens()
39}
40
41#[proc_macro]
42pub fn function_body(input: TokenStream) -> TokenStream {
43 let args: FunctionBodyArgs = parse_macro_input!(input);
44 let function_body = args.function_body.to_string();
45 let mut dependencies = HashSet::new();
46
47 dependencies.extend(args.dependencies.iter().map(Ident::to_string));
48
49 if dependencies.contains(&function_body) {
50 return Error::new(
51 args.function_body.span(),
52 "Function body can't be in dependencies",
53 )
54 .into_compile_error()
55 .into();
56 }
57
58 doc_function_body(args.file, args.function_body, Some(&dependencies)).tokens()
59}
60
61struct FunctionBodyArgs {
62 file: LitStr,
63 function_body: Ident,
64 dependencies: Vec<Ident>,
65}
66
67impl Parse for FunctionBodyArgs {
68 fn parse(input: ParseStream) -> Result<Self> {
69 let file = input.parse()?;
70 input.parse::<Comma>()?;
71 let function_body = input.parse()?;
72 input.parse::<Comma>()?;
73 let dependencies;
74 bracketed!(dependencies in input);
75 let dependencies = dependencies
76 .parse_terminated(Ident::parse, Token![,])?
77 .into_iter()
78 .collect();
79
80 Ok(Self {
81 file,
82 function_body,
83 dependencies,
84 })
85 }
86}
87
88fn doc_function_body(
89 file: LitStr,
90 function_body_ident: Ident,
91 deps: Option<&HashSet<String>>,
92) -> Result<proc_macro2::TokenStream> {
93 let source = parse_file(&file)?;
94
95 let mut found_body = false;
96 let function_body = function_body_ident.to_string();
97 let mut track_deps = HashSet::new();
98
99 let parts = source.items().filter_map(|item| match item {
100 ast::Item::Use(use_item) => Some(hide_in_doc(use_item)),
101 ast::Item::Fn(function) => function.name().and_then(|name| {
102 let name = name.text();
103
104 if name.as_str() == function_body {
105 found_body = true;
106 extract_function_body(&function)
107 } else if is_dependency(&name, deps, &mut track_deps) {
108 include_always(&function)
109 } else {
110 None
111 }
112 }),
113 ast::Item::Const(item) => include_if_dependency(&item, deps, &mut track_deps),
114 ast::Item::Enum(item) => include_if_dependency(&item, deps, &mut track_deps),
115 ast::Item::ExternBlock(item) => include_always(&item),
116 ast::Item::ExternCrate(item) => include_always(&item),
117 ast::Item::Impl(item) => {
118 if is_type_dependency(&item.self_ty(), deps, &mut track_deps)
119 || is_type_dependency(&item.trait_(), deps, &mut track_deps)
120 {
121 include_always(&item)
122 } else {
123 None
124 }
125 }
126 ast::Item::MacroCall(item) => include_always(&item),
127 ast::Item::MacroRules(item) => include_if_dependency(&item, deps, &mut track_deps),
128 ast::Item::MacroDef(item) => include_if_dependency(&item, deps, &mut track_deps),
129 ast::Item::Module(item) => include_if_dependency(&item, deps, &mut track_deps),
130 ast::Item::Static(item) => include_if_dependency(&item, deps, &mut track_deps),
131 ast::Item::Struct(item) => include_if_dependency(&item, deps, &mut track_deps),
132 ast::Item::Trait(item) => include_if_dependency(&item, deps, &mut track_deps),
133 ast::Item::TypeAlias(item) => include_if_dependency(&item, deps, &mut track_deps),
134 ast::Item::Union(item) => include_if_dependency(&item, deps, &mut track_deps),
135 ast::Item::TraitAlias(item) => include_if_dependency(&item, deps, &mut track_deps),
136 });
137
138 let doc = parts.collect::<Vec<String>>().join("\n");
139
140 if let Some(deps) = deps {
141 let missing_deps = deps.difference(&track_deps).join(", ");
142
143 if !missing_deps.is_empty() {
144 call_site_error!("Not all dependencies were found: [{missing_deps}]")?;
145 }
146 }
147
148 if !found_body {
149 error!(function_body_ident.span(), "{function_body} not found")?;
150 }
151
152 Ok(quote!(#doc))
153}
154
/// Unconditionally emits `node`: its `Display` text followed by a newline.
fn include_always<T: Display>(node: &T) -> Option<String> {
    let mut rendered = node.to_string();
    rendered.push('\n');
    Some(rendered)
}
158
159fn include_if_dependency<T: HasName + Display>(
160 node: &T,
161 dependencies: Option<&HashSet<String>>,
162 dependency_tracker: &mut HashSet<String>,
163) -> Option<String> {
164 node.name().and_then(|name| {
165 let name = name.text();
166
167 if is_dependency(&name, dependencies, dependency_tracker) {
168 Some(format!("{node}\n"))
169 } else {
170 None
171 }
172 })
173}
174
175fn is_type_dependency(
176 ty: &Option<Type>,
177 dependencies: Option<&HashSet<String>>,
178 dependency_tracker: &mut HashSet<String>,
179) -> bool {
180 let Some(ty) = ty else {
181 return false;
182 };
183
184 ty.syntax()
185 .descendants_with_tokens()
186 .any(|token| match token {
187 NodeOrToken::Node(_) => false,
188 NodeOrToken::Token(token) => {
189 is_dependency(token.text(), dependencies, dependency_tracker)
190 }
191 })
192}
193
/// Reports whether `name` should be treated as a dependency.
///
/// With no dependency list (`None`), every name qualifies. Otherwise `name`
/// qualifies only if it is in the list; in that case it is also recorded in
/// `dependency_tracker`, so callers can later detect listed names that never
/// matched.
fn is_dependency(
    name: impl AsRef<str>,
    dependencies: Option<&HashSet<String>>,
    dependency_tracker: &mut HashSet<String>,
) -> bool {
    let Some(wanted) = dependencies else {
        return true;
    };

    let name = name.as_ref();
    let matched = wanted.contains(name);
    if matched {
        dependency_tracker.insert(name.to_owned());
    }
    matched
}
210
211fn extract_function_body(function: &ast::Fn) -> Option<String> {
212 function.body().map(|body| {
213 if function.async_token().is_some() {
214 format!("async {body};\n")
215 } else {
216 remove_indent(
217 body.to_string()
218 .trim()
219 .trim_start_matches('{')
220 .trim_end_matches('}'),
221 ) + "\n"
222 }
223 })
224}
225
/// Strips the common leading indentation from every line of `text`.
///
/// The indentation to remove is the smallest `indent_size` over the non-blank
/// lines (0 if there are none), measured in bytes. Lines no longer than that
/// indent become empty. Leading and trailing newlines of the result are trimmed.
fn remove_indent(text: &str) -> String {
    let min_indent = text.lines().filter_map(indent_size).min().unwrap_or(0);

    text.lines()
        // `get` instead of `&line[min_indent..]`: a whitespace-only line may contain
        // multi-byte Unicode whitespace (e.g. U+3000) where `min_indent` is not a
        // char boundary, and slicing would panic. Such lines become empty.
        .map(|line| line.get(min_indent..).unwrap_or(""))
        .collect::<Vec<_>>()
        .join("\n")
        .trim_matches('\n')
        .to_string()
}

/// Returns the byte width of the leading spaces/tabs of `text`.
///
/// Returns `None` for blank lines (empty or whitespace-only), so they do not
/// influence the common indentation.
fn indent_size(text: &str) -> Option<usize> {
    if text.trim().is_empty() {
        None
    } else {
        text.find(|c: char| c != ' ' && c != '\t')
    }
}
249
250fn parse_file(file_expr: &LitStr) -> Result<SourceFile> {
251 let source_code = read_file(file_expr)?;
252 let parse = SourceFile::parse(&source_code);
253 let source = parse.tree();
254
255 if !parse.errors().is_empty() {
256 error!(file_expr.span(), "Errors in source file")?;
257 }
258
259 Ok(source)
260}
261
262fn read_file(file_expr: &LitStr) -> Result<String> {
263 let file = file_expr.value();
264
265 let dir = env::var("CARGO_MANIFEST_DIR").map_err(error::call_site)?;
266 let path = Path::new(&dir).join(file);
267 fs::read_to_string(path).map_err(|e| Error::new(file_expr.span(), e))
268}
269
270fn hide_in_doc(item: impl Display) -> String {
271 format!("# {}\n", item.to_string().lines().format("")) + "#"
279}