include_doc_macro/
lib.rs

1use std::{collections::HashSet, env, fmt::Display, fs, path::Path};
2
3use itertools::Itertools;
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use proc_macro_error::{abort, abort_call_site, proc_macro_error};
7use quote::quote;
8use ra_ap_syntax::{
9    ast::{self, HasModuleItem, HasName, Type},
10    AstNode, NodeOrToken, SourceFile,
11};
12use syn::{
13    bracketed,
14    parse::{Parse, ParseStream},
15    parse_macro_input,
16    token::Comma,
17    Ident, LitStr, Token,
18};
19
20#[proc_macro]
21#[proc_macro_error]
22pub fn source_file(input: TokenStream) -> TokenStream {
23    let file: LitStr = parse_macro_input!(input);
24
25    doc_function_body(file, Ident::new("main", Span::call_site()), None)
26}
27
28#[proc_macro]
29#[proc_macro_error]
30pub fn function_body(input: TokenStream) -> TokenStream {
31    let args: FunctionBodyArgs = parse_macro_input!(input);
32    let function_body = args.function_body.to_string();
33    let mut dependencies = HashSet::new();
34
35    dependencies.extend(args.dependencies.iter().map(Ident::to_string));
36
37    if dependencies.contains(&function_body) {
38        abort_call_site!("Function body can't be in dependencies");
39    }
40
41    doc_function_body(args.file, args.function_body, Some(&dependencies))
42}
43
44struct FunctionBodyArgs {
45    file: LitStr,
46    function_body: Ident,
47    dependencies: Vec<Ident>,
48}
49
50impl Parse for FunctionBodyArgs {
51    fn parse(input: ParseStream) -> syn::Result<Self> {
52        let file = input.parse()?;
53        input.parse::<Comma>()?;
54        let function_body = input.parse()?;
55        input.parse::<Comma>()?;
56        let dependencies;
57        bracketed!(dependencies in input);
58        let dependencies = dependencies
59            .parse_terminated(Ident::parse, Token![,])?
60            .into_iter()
61            .collect();
62
63        Ok(Self {
64            file,
65            function_body,
66            dependencies,
67        })
68    }
69}
70
71fn doc_function_body(
72    file: LitStr,
73    function_body_ident: Ident,
74    deps: Option<&HashSet<String>>,
75) -> TokenStream {
76    let source = parse_file(&file);
77
78    let mut found_body = false;
79    let function_body = function_body_ident.to_string();
80    let mut track_deps = HashSet::new();
81
82    let parts = source.items().filter_map(|item| match item {
83        ast::Item::Use(use_item) => Some(hide_in_doc(use_item)),
84        ast::Item::Fn(function) => function.name().and_then(|name| {
85            let name = name.text();
86
87            if name.as_str() == function_body {
88                found_body = true;
89                extract_function_body(&function)
90            } else if is_dependency(&name, deps, &mut track_deps) {
91                include_always(&function)
92            } else {
93                None
94            }
95        }),
96        ast::Item::Const(item) => include_if_dependency(&item, deps, &mut track_deps),
97        ast::Item::Enum(item) => include_if_dependency(&item, deps, &mut track_deps),
98        ast::Item::ExternBlock(item) => include_always(&item),
99        ast::Item::ExternCrate(item) => include_always(&item),
100        ast::Item::Impl(item) => {
101            if is_type_dependency(&item.self_ty(), deps, &mut track_deps)
102                || is_type_dependency(&item.trait_(), deps, &mut track_deps)
103            {
104                include_always(&item)
105            } else {
106                None
107            }
108        }
109        ast::Item::MacroCall(item) => include_always(&item),
110        ast::Item::MacroRules(item) => include_if_dependency(&item, deps, &mut track_deps),
111        ast::Item::MacroDef(item) => include_if_dependency(&item, deps, &mut track_deps),
112        ast::Item::Module(item) => include_if_dependency(&item, deps, &mut track_deps),
113        ast::Item::Static(item) => include_if_dependency(&item, deps, &mut track_deps),
114        ast::Item::Struct(item) => include_if_dependency(&item, deps, &mut track_deps),
115        ast::Item::Trait(item) => include_if_dependency(&item, deps, &mut track_deps),
116        ast::Item::TypeAlias(item) => include_if_dependency(&item, deps, &mut track_deps),
117        ast::Item::Union(item) => include_if_dependency(&item, deps, &mut track_deps),
118        ast::Item::TraitAlias(item) => include_if_dependency(&item, deps, &mut track_deps),
119    });
120
121    let doc = parts.collect::<Vec<String>>().join("\n");
122
123    if let Some(deps) = deps {
124        let missing_deps = deps.difference(&track_deps).join(", ");
125
126        if !missing_deps.is_empty() {
127            abort_call_site!("Not all dependencies were found: [{}]", missing_deps);
128        }
129    }
130
131    if !found_body {
132        abort!(function_body_ident, "{} not found", function_body);
133    }
134
135    quote!(#doc).into()
136}
137
/// Unconditionally renders `node` as a newline-terminated string.
fn include_always<T: Display>(node: &T) -> Option<String> {
    let mut rendered = node.to_string();
    rendered.push('\n');
    Some(rendered)
}
141
142fn include_if_dependency<T: HasName + Display>(
143    node: &T,
144    dependencies: Option<&HashSet<String>>,
145    dependency_tracker: &mut HashSet<String>,
146) -> Option<String> {
147    node.name().and_then(|name| {
148        let name = name.text();
149
150        if is_dependency(&name, dependencies, dependency_tracker) {
151            Some(format!("{node}\n"))
152        } else {
153            None
154        }
155    })
156}
157
158fn is_type_dependency(
159    ty: &Option<Type>,
160    dependencies: Option<&HashSet<String>>,
161    dependency_tracker: &mut HashSet<String>,
162) -> bool {
163    let Some(ty) = ty else {
164        return false;
165    };
166
167    ty.syntax()
168        .descendants_with_tokens()
169        .any(|token| match token {
170            NodeOrToken::Node(_) => false,
171            NodeOrToken::Token(token) => {
172                is_dependency(token.text(), dependencies, dependency_tracker)
173            }
174        })
175}
176
/// Decides whether `name` counts as a dependency.
///
/// Without a dependency list (`None`), every name qualifies and nothing is
/// recorded. Otherwise `name` qualifies only if it is in the list. In that case
/// it is also added to `dependency_tracker`.
fn is_dependency(
    name: impl AsRef<str>,
    dependencies: Option<&HashSet<String>>,
    dependency_tracker: &mut HashSet<String>,
) -> bool {
    let Some(deps) = dependencies else {
        return true;
    };

    let name = name.as_ref();
    if !deps.contains(name) {
        return false;
    }

    dependency_tracker.insert(name.to_owned());
    true
}
193
194fn extract_function_body(function: &ast::Fn) -> Option<String> {
195    function.body().map(|body| {
196        if function.async_token().is_some() {
197            format!("async {body};\n")
198        } else {
199            remove_indent(
200                body.to_string()
201                    .trim()
202                    .trim_start_matches('{')
203                    .trim_end_matches('}'),
204            ) + "\n"
205        }
206    })
207}
208
209fn remove_indent(text: &str) -> String {
210    let min_indent = text.lines().filter_map(indent_size).min().unwrap_or(0);
211
212    text.lines()
213        .map(|line| {
214            if line.len() > min_indent {
215                &line[min_indent..]
216            } else {
217                ""
218            }
219        })
220        .join("\n")
221        .trim_matches('\n')
222        .to_string()
223}
224
/// Returns the byte length of the leading run of spaces and tabs in `text`.
///
/// Returns `None` when the line is blank (contains only whitespace).
fn indent_size(text: &str) -> Option<usize> {
    if text.trim().is_empty() {
        return None;
    }

    text.char_indices()
        .find(|&(_, ch)| ch != ' ' && ch != '\t')
        .map(|(idx, _)| idx)
}
232
233fn parse_file(file_expr: &LitStr) -> SourceFile {
234    let source_code = read_file(file_expr);
235    let parse = SourceFile::parse(&source_code);
236    let source = parse.tree();
237
238    if !parse.errors().is_empty() {
239        abort!(file_expr, "Errors in source file");
240    }
241
242    source
243}
244
245fn read_file(file_expr: &LitStr) -> String {
246    let file = file_expr.value();
247
248    let dir = env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|e| abort_call_site!(e));
249    let path = Path::new(&dir).join(file);
250    fs::read_to_string(path).unwrap_or_else(|e| abort!(file_expr, e))
251}
252
/// Renders `item` as hidden rustdoc lines.
///
/// Every line of the item is prefixed with `# `, which rustdoc hides in code
/// examples. Prefixing each line, rather than collapsing the item onto a single
/// line, keeps multi-line `use` items valid even when a line ends with a `//`
/// comment, which would otherwise swallow the rest of the item.
fn hide_in_doc(item: impl Display) -> String {
    let mut hidden: String = item
        .to_string()
        .lines()
        .map(|line| format!("# {line}\n"))
        .collect();

    // We need the extra trailing `#` line as otherwise rustdoc won't include
    // attributes after hidden items. e.g.
    //
    // ```
    // # use blah
    // #[attribute_will_also_be_hidden]
    // ```
    hidden.push('#');
    hidden
}