include_doc_macro/
lib.rs

1use std::{collections::HashSet, env, fmt::Display, fs, path::Path};
2
3use error::Tokens;
4use itertools::Itertools;
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::quote;
8use ra_ap_syntax::{
9    ast::{self, HasModuleItem, HasName, Type},
10    AstNode, NodeOrToken, SourceFile,
11};
12use syn::{
13    bracketed,
14    parse::{Parse, ParseStream},
15    parse_macro_input,
16    token::Comma,
17    Error, Ident, LitStr, Result, Token,
18};
19
20mod error;
21
/// Evaluates to an `Err` holding a call-site error (built by `error::call_site`) whose
/// message is produced from `format!`-style arguments. Typically used as
/// `call_site_error!("...")?;` to bail out of a function returning `Result`.
macro_rules! call_site_error {
    ($($fmt_args:tt)*) => {
        ::std::result::Result::Err($crate::error::call_site(::std::format!($($fmt_args)*)))
    };
}
27
/// Evaluates to an `Err` holding a `syn::Error` attached to `$span`, with a message built
/// from the remaining `format!`-style arguments.
macro_rules! error {
    ($span:expr, $($fmt_args:tt)*) => {
        ::std::result::Result::Err($crate::Error::new($span, ::std::format!($($fmt_args)*)))
    };
}
33
34#[proc_macro]
35pub fn source_file(input: TokenStream) -> TokenStream {
36    let file: LitStr = parse_macro_input!(input);
37
38    doc_function_body(file, Ident::new("main", Span::call_site()), None).tokens()
39}
40
41#[proc_macro]
42pub fn function_body(input: TokenStream) -> TokenStream {
43    let args: FunctionBodyArgs = parse_macro_input!(input);
44    let function_body = args.function_body.to_string();
45    let mut dependencies = HashSet::new();
46
47    dependencies.extend(args.dependencies.iter().map(Ident::to_string));
48
49    if dependencies.contains(&function_body) {
50        return Error::new(
51            args.function_body.span(),
52            "Function body can't be in dependencies",
53        )
54        .into_compile_error()
55        .into();
56    }
57
58    doc_function_body(args.file, args.function_body, Some(&dependencies)).tokens()
59}
60
61struct FunctionBodyArgs {
62    file: LitStr,
63    function_body: Ident,
64    dependencies: Vec<Ident>,
65}
66
67impl Parse for FunctionBodyArgs {
68    fn parse(input: ParseStream) -> Result<Self> {
69        let file = input.parse()?;
70        input.parse::<Comma>()?;
71        let function_body = input.parse()?;
72        input.parse::<Comma>()?;
73        let dependencies;
74        bracketed!(dependencies in input);
75        let dependencies = dependencies
76            .parse_terminated(Ident::parse, Token![,])?
77            .into_iter()
78            .collect();
79
80        Ok(Self {
81            file,
82            function_body,
83            dependencies,
84        })
85    }
86}
87
88fn doc_function_body(
89    file: LitStr,
90    function_body_ident: Ident,
91    deps: Option<&HashSet<String>>,
92) -> Result<proc_macro2::TokenStream> {
93    let source = parse_file(&file)?;
94
95    let mut found_body = false;
96    let function_body = function_body_ident.to_string();
97    let mut track_deps = HashSet::new();
98
99    let parts = source.items().filter_map(|item| match item {
100        ast::Item::Use(use_item) => Some(hide_in_doc(use_item)),
101        ast::Item::Fn(function) => function.name().and_then(|name| {
102            let name = name.text();
103
104            if name.as_str() == function_body {
105                found_body = true;
106                extract_function_body(&function)
107            } else if is_dependency(&name, deps, &mut track_deps) {
108                include_always(&function)
109            } else {
110                None
111            }
112        }),
113        ast::Item::Const(item) => include_if_dependency(&item, deps, &mut track_deps),
114        ast::Item::Enum(item) => include_if_dependency(&item, deps, &mut track_deps),
115        ast::Item::ExternBlock(item) => include_always(&item),
116        ast::Item::ExternCrate(item) => include_always(&item),
117        ast::Item::Impl(item) => {
118            if is_type_dependency(&item.self_ty(), deps, &mut track_deps)
119                || is_type_dependency(&item.trait_(), deps, &mut track_deps)
120            {
121                include_always(&item)
122            } else {
123                None
124            }
125        }
126        ast::Item::MacroCall(item) => include_always(&item),
127        ast::Item::MacroRules(item) => include_if_dependency(&item, deps, &mut track_deps),
128        ast::Item::MacroDef(item) => include_if_dependency(&item, deps, &mut track_deps),
129        ast::Item::Module(item) => include_if_dependency(&item, deps, &mut track_deps),
130        ast::Item::Static(item) => include_if_dependency(&item, deps, &mut track_deps),
131        ast::Item::Struct(item) => include_if_dependency(&item, deps, &mut track_deps),
132        ast::Item::Trait(item) => include_if_dependency(&item, deps, &mut track_deps),
133        ast::Item::TypeAlias(item) => include_if_dependency(&item, deps, &mut track_deps),
134        ast::Item::Union(item) => include_if_dependency(&item, deps, &mut track_deps),
135        ast::Item::TraitAlias(item) => include_if_dependency(&item, deps, &mut track_deps),
136    });
137
138    let doc = parts.collect::<Vec<String>>().join("\n");
139
140    if let Some(deps) = deps {
141        let missing_deps = deps.difference(&track_deps).join(", ");
142
143        if !missing_deps.is_empty() {
144            call_site_error!("Not all dependencies were found: [{missing_deps}]")?;
145        }
146    }
147
148    if !found_body {
149        error!(function_body_ident.span(), "{function_body} not found")?;
150    }
151
152    Ok(quote!(#doc))
153}
154
/// Renders `node` via `Display`, followed by a newline. Always returns `Some`.
fn include_always<T: Display>(node: &T) -> Option<String> {
    let rendered = node.to_string() + "\n";
    Some(rendered)
}
158
159fn include_if_dependency<T: HasName + Display>(
160    node: &T,
161    dependencies: Option<&HashSet<String>>,
162    dependency_tracker: &mut HashSet<String>,
163) -> Option<String> {
164    node.name().and_then(|name| {
165        let name = name.text();
166
167        if is_dependency(&name, dependencies, dependency_tracker) {
168            Some(format!("{node}\n"))
169        } else {
170            None
171        }
172    })
173}
174
175fn is_type_dependency(
176    ty: &Option<Type>,
177    dependencies: Option<&HashSet<String>>,
178    dependency_tracker: &mut HashSet<String>,
179) -> bool {
180    let Some(ty) = ty else {
181        return false;
182    };
183
184    ty.syntax()
185        .descendants_with_tokens()
186        .any(|token| match token {
187            NodeOrToken::Node(_) => false,
188            NodeOrToken::Token(token) => {
189                is_dependency(token.text(), dependencies, dependency_tracker)
190            }
191        })
192}
193
/// Returns whether `name` should be included as a dependency.
///
/// With no dependency list (`None`), every name is accepted and nothing is recorded.
/// Otherwise the name must be in the list, and each accepted name is inserted into
/// `dependency_tracker`.
fn is_dependency(
    name: impl AsRef<str>,
    dependencies: Option<&HashSet<String>>,
    dependency_tracker: &mut HashSet<String>,
) -> bool {
    let Some(deps) = dependencies else {
        return true;
    };

    let name = name.as_ref();
    let wanted = deps.contains(name);
    if wanted {
        dependency_tracker.insert(name.to_owned());
    }
    wanted
}
210
211fn extract_function_body(function: &ast::Fn) -> Option<String> {
212    function.body().map(|body| {
213        if function.async_token().is_some() {
214            format!("async {body};\n")
215        } else {
216            remove_indent(
217                body.to_string()
218                    .trim()
219                    .trim_start_matches('{')
220                    .trim_end_matches('}'),
221            ) + "\n"
222        }
223    })
224}
225
226fn remove_indent(text: &str) -> String {
227    let min_indent = text.lines().filter_map(indent_size).min().unwrap_or(0);
228
229    text.lines()
230        .map(|line| {
231            if line.len() > min_indent {
232                &line[min_indent..]
233            } else {
234                ""
235            }
236        })
237        .join("\n")
238        .trim_matches('\n')
239        .to_string()
240}
241
/// Returns the number of leading space/tab bytes in `text`.
///
/// Returns `None` for lines that are empty or consist only of (Unicode) whitespace.
fn indent_size(text: &str) -> Option<usize> {
    let has_content = !text.trim().is_empty();

    has_content.then(|| {
        let unindented = text.trim_start_matches(|c: char| c == ' ' || c == '\t');
        text.len() - unindented.len()
    })
}
249
250fn parse_file(file_expr: &LitStr) -> Result<SourceFile> {
251    let source_code = read_file(file_expr)?;
252    let parse = SourceFile::parse(&source_code);
253    let source = parse.tree();
254
255    if !parse.errors().is_empty() {
256        error!(file_expr.span(), "Errors in source file")?;
257    }
258
259    Ok(source)
260}
261
262fn read_file(file_expr: &LitStr) -> Result<String> {
263    let file = file_expr.value();
264
265    let dir = env::var("CARGO_MANIFEST_DIR").map_err(error::call_site)?;
266    let path = Path::new(&dir).join(file);
267    fs::read_to_string(path).map_err(|e| Error::new(file_expr.span(), e))
268}
269
/// Renders `item` as rustdoc-hidden lines (each prefixed with `# `), followed by a lone
/// `#` line and no trailing newline.
///
/// Each line gets its own prefix instead of all lines being joined into one. Joining
/// breaks items whose text contains a `//` comment (including doc comments attached
/// ahead of the item), because everything after the comment would become comment text.
fn hide_in_doc(item: impl Display) -> String {
    let mut hidden: String = item
        .to_string()
        .lines()
        .map(|line| format!("# {line}\n"))
        .collect();

    // We need the extra `"#"` line as otherwise rustdoc won't include attributes after
    // hidden items. e.g.
    //
    // ```
    // # use blah
    // #[attribute_will_also_be_hidden]
    // ```
    hidden.push('#');
    hidden
}