1use std::{collections::HashSet, env, fmt::Display, fs, path::Path};
2
3use itertools::Itertools;
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use proc_macro_error::{abort, abort_call_site, proc_macro_error};
7use quote::quote;
8use ra_ap_syntax::{
9 ast::{self, HasModuleItem, HasName, Type},
10 AstNode, NodeOrToken, SourceFile,
11};
12use syn::{
13 bracketed,
14 parse::{Parse, ParseStream},
15 parse_macro_input,
16 token::Comma,
17 Ident, LitStr, Token,
18};
19
20#[proc_macro]
21#[proc_macro_error]
22pub fn source_file(input: TokenStream) -> TokenStream {
23 let file: LitStr = parse_macro_input!(input);
24
25 doc_function_body(file, Ident::new("main", Span::call_site()), None)
26}
27
28#[proc_macro]
29#[proc_macro_error]
30pub fn function_body(input: TokenStream) -> TokenStream {
31 let args: FunctionBodyArgs = parse_macro_input!(input);
32 let function_body = args.function_body.to_string();
33 let mut dependencies = HashSet::new();
34
35 dependencies.extend(args.dependencies.iter().map(Ident::to_string));
36
37 if dependencies.contains(&function_body) {
38 abort_call_site!("Function body can't be in dependencies");
39 }
40
41 doc_function_body(args.file, args.function_body, Some(&dependencies))
42}
43
44struct FunctionBodyArgs {
45 file: LitStr,
46 function_body: Ident,
47 dependencies: Vec<Ident>,
48}
49
50impl Parse for FunctionBodyArgs {
51 fn parse(input: ParseStream) -> syn::Result<Self> {
52 let file = input.parse()?;
53 input.parse::<Comma>()?;
54 let function_body = input.parse()?;
55 input.parse::<Comma>()?;
56 let dependencies;
57 bracketed!(dependencies in input);
58 let dependencies = dependencies
59 .parse_terminated(Ident::parse, Token![,])?
60 .into_iter()
61 .collect();
62
63 Ok(Self {
64 file,
65 function_body,
66 dependencies,
67 })
68 }
69}
70
71fn doc_function_body(
72 file: LitStr,
73 function_body_ident: Ident,
74 deps: Option<&HashSet<String>>,
75) -> TokenStream {
76 let source = parse_file(&file);
77
78 let mut found_body = false;
79 let function_body = function_body_ident.to_string();
80 let mut track_deps = HashSet::new();
81
82 let parts = source.items().filter_map(|item| match item {
83 ast::Item::Use(use_item) => Some(hide_in_doc(use_item)),
84 ast::Item::Fn(function) => function.name().and_then(|name| {
85 let name = name.text();
86
87 if name.as_str() == function_body {
88 found_body = true;
89 extract_function_body(&function)
90 } else if is_dependency(&name, deps, &mut track_deps) {
91 include_always(&function)
92 } else {
93 None
94 }
95 }),
96 ast::Item::Const(item) => include_if_dependency(&item, deps, &mut track_deps),
97 ast::Item::Enum(item) => include_if_dependency(&item, deps, &mut track_deps),
98 ast::Item::ExternBlock(item) => include_always(&item),
99 ast::Item::ExternCrate(item) => include_always(&item),
100 ast::Item::Impl(item) => {
101 if is_type_dependency(&item.self_ty(), deps, &mut track_deps)
102 || is_type_dependency(&item.trait_(), deps, &mut track_deps)
103 {
104 include_always(&item)
105 } else {
106 None
107 }
108 }
109 ast::Item::MacroCall(item) => include_always(&item),
110 ast::Item::MacroRules(item) => include_if_dependency(&item, deps, &mut track_deps),
111 ast::Item::MacroDef(item) => include_if_dependency(&item, deps, &mut track_deps),
112 ast::Item::Module(item) => include_if_dependency(&item, deps, &mut track_deps),
113 ast::Item::Static(item) => include_if_dependency(&item, deps, &mut track_deps),
114 ast::Item::Struct(item) => include_if_dependency(&item, deps, &mut track_deps),
115 ast::Item::Trait(item) => include_if_dependency(&item, deps, &mut track_deps),
116 ast::Item::TypeAlias(item) => include_if_dependency(&item, deps, &mut track_deps),
117 ast::Item::Union(item) => include_if_dependency(&item, deps, &mut track_deps),
118 ast::Item::TraitAlias(item) => include_if_dependency(&item, deps, &mut track_deps),
119 });
120
121 let doc = parts.collect::<Vec<String>>().join("\n");
122
123 if let Some(deps) = deps {
124 let missing_deps = deps.difference(&track_deps).join(", ");
125
126 if !missing_deps.is_empty() {
127 abort_call_site!("Not all dependencies were found: [{}]", missing_deps);
128 }
129 }
130
131 if !found_body {
132 abort!(function_body_ident, "{} not found", function_body);
133 }
134
135 quote!(#doc).into()
136}
137
/// Renders `node` through its `Display` impl followed by a newline.
///
/// Always returns `Some`, so it can be used directly as a `filter_map` result.
fn include_always<T: Display>(node: &T) -> Option<String> {
    let mut rendered = node.to_string();
    rendered.push('\n');
    Some(rendered)
}
141
142fn include_if_dependency<T: HasName + Display>(
143 node: &T,
144 dependencies: Option<&HashSet<String>>,
145 dependency_tracker: &mut HashSet<String>,
146) -> Option<String> {
147 node.name().and_then(|name| {
148 let name = name.text();
149
150 if is_dependency(&name, dependencies, dependency_tracker) {
151 Some(format!("{node}\n"))
152 } else {
153 None
154 }
155 })
156}
157
158fn is_type_dependency(
159 ty: &Option<Type>,
160 dependencies: Option<&HashSet<String>>,
161 dependency_tracker: &mut HashSet<String>,
162) -> bool {
163 let Some(ty) = ty else {
164 return false;
165 };
166
167 ty.syntax()
168 .descendants_with_tokens()
169 .any(|token| match token {
170 NodeOrToken::Node(_) => false,
171 NodeOrToken::Token(token) => {
172 is_dependency(token.text(), dependencies, dependency_tracker)
173 }
174 })
175}
176
/// Decides whether `name` should be included as a dependency.
///
/// With no dependency list (`None`), every name qualifies and nothing is
/// recorded. Otherwise `name` qualifies only if it is in the list. Qualifying
/// names are then inserted into `dependency_tracker`.
fn is_dependency(
    name: impl AsRef<str>,
    dependencies: Option<&HashSet<String>>,
    dependency_tracker: &mut HashSet<String>,
) -> bool {
    match dependencies {
        None => true,
        Some(deps) => {
            let name = name.as_ref();
            let matched = deps.contains(name);

            if matched {
                dependency_tracker.insert(name.to_owned());
            }

            matched
        }
    }
}
193
194fn extract_function_body(function: &ast::Fn) -> Option<String> {
195 function.body().map(|body| {
196 if function.async_token().is_some() {
197 format!("async {body};\n")
198 } else {
199 remove_indent(
200 body.to_string()
201 .trim()
202 .trim_start_matches('{')
203 .trim_end_matches('}'),
204 ) + "\n"
205 }
206 })
207}
208
/// Removes the indentation common to all non-blank lines of `text`.
///
/// The width removed is the smallest `indent_size` over non-blank lines. It is
/// measured in bytes, so a tab and a space each count as one. A line that is not
/// longer than that width becomes empty. So does a whitespace-only line where the
/// cut would fall inside a multi-byte character. Lines are re-joined with `"\n"`,
/// and leading/trailing newlines of the result are trimmed.
fn remove_indent(text: &str) -> String {
    let min_indent = text.lines().filter_map(indent_size).min().unwrap_or(0);

    text.lines()
        // `get` rather than `&line[min_indent..]`: a whitespace-only line may contain
        // multi-byte Unicode whitespace (e.g. U+3000). `min_indent` is then not a
        // char boundary, and plain slicing would panic.
        .map(|line| line.get(min_indent..).unwrap_or(""))
        .collect::<Vec<&str>>()
        .join("\n")
        .trim_matches('\n')
        .to_string()
}

/// Returns the number of leading space/tab bytes of `text`, or `None` when the
/// line is blank (empty or whitespace only).
fn indent_size(text: &str) -> Option<usize> {
    if text.trim().is_empty() {
        None
    } else {
        text.find(|c: char| c != ' ' && c != '\t')
    }
}
232
233fn parse_file(file_expr: &LitStr) -> SourceFile {
234 let source_code = read_file(file_expr);
235 let parse = SourceFile::parse(&source_code);
236 let source = parse.tree();
237
238 if !parse.errors().is_empty() {
239 abort!(file_expr, "Errors in source file");
240 }
241
242 source
243}
244
245fn read_file(file_expr: &LitStr) -> String {
246 let file = file_expr.value();
247
248 let dir = env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|e| abort_call_site!(e));
249 let path = Path::new(&dir).join(file);
250 fs::read_to_string(path).unwrap_or_else(|e| abort!(file_expr, e))
251}
252
253fn hide_in_doc(item: impl Display) -> String {
254 format!("# {}\n", item.to_string().lines().format("")) + "#"
262}