Skip to main content

mdbook_rust/
lib.rs

1use std::{cmp::min, collections::VecDeque, fmt::Display};
2
3use anyhow::{bail, Result};
4use itertools::Itertools;
5use ra_ap_syntax::{
6    AstNode, AstToken, Edition, NodeOrToken, SourceFile, SyntaxKind, SyntaxNode, SyntaxToken, ast::{self, HasModuleItem, HasName, Item}
7};
8
9pub fn write_module(source_text: &str) -> Result<Option<String>> {
10    let source = parse_module(source_text)?;
11
12    for item in source.items() {
13        if let Item::Fn(function) = item {
14            if is_named(&function, "body") {
15                if let Some(new_content) = write_function(function)? {
16                    return Ok(Some(new_content));
17                }
18            }
19        }
20    }
21
22    Ok(None)
23}
24
25fn write_function(function: ast::Fn) -> Result<Option<String>> {
26    if let Some(stmts) = function.body().and_then(|body| body.stmt_list()) {
27        let mut stmts: VecDeque<_> = stmts.syntax().children_with_tokens().collect();
28
29        expect_kind(SyntaxKind::L_CURLY, stmts.pop_front())?;
30        expect_kind(SyntaxKind::R_CURLY, stmts.pop_back())?;
31
32        let body_text = stmts.iter().map(|s| s.to_string()).collect::<String>();
33        let ws_prefixes = body_text.lines().filter_map(whitespace_prefix);
34        let longest_prefix = longest_prefix(ws_prefixes);
35
36        if stmts
37            .front()
38            .and_then(|node| node.as_token())
39            .is_some_and(|token| ast::Whitespace::can_cast(token.kind()))
40        {
41            stmts.pop_front();
42        }
43
44        Ok(Some(write_body(stmts, longest_prefix)))
45    } else {
46        Ok(None)
47    }
48}
49
50fn write_body(
51    stmts: impl IntoIterator<Item = NodeOrToken<SyntaxNode, SyntaxToken>>,
52    longest_prefix: &str,
53) -> String {
54    let mut whitespace = String::new();
55    let mut in_code_block = false;
56    let mut output = String::new();
57
58    for node in stmts {
59        write_node_or_token(
60            &mut output,
61            &mut in_code_block,
62            &mut whitespace,
63            node,
64            longest_prefix,
65        );
66    }
67
68    if in_code_block {
69        output.push_str("\n```");
70    }
71
72    output.push('\n');
73
74    output
75}
76
77fn write_node_or_token(
78    output: &mut String,
79    in_code_block: &mut bool,
80    whitespace: &mut String,
81    node: NodeOrToken<SyntaxNode, SyntaxToken>,
82    longest_prefix: &str,
83) {
84    match &node {
85        NodeOrToken::Node(node) => {
86            let mut children = node.children_with_tokens();
87
88            // `Fn` nodes will have comments associated with them, rather than the parent.
89            // We want to include these comments as markdown.
90            for child in children.by_ref() {
91                if child.kind() == SyntaxKind::COMMENT || child.kind() == SyntaxKind::WHITESPACE {
92                    write_node_or_token(output, in_code_block, whitespace, child, longest_prefix);
93                } else {
94                    output.push_str(ensure_in_code_block(in_code_block, whitespace));
95                    output.push_str(&write_lines(child, longest_prefix));
96                    break;
97                }
98            }
99
100            for child in children {
101                output.push_str(&write_lines(child, longest_prefix));
102            }
103
104            whitespace.clear();
105        }
106        NodeOrToken::Token(token) => {
107            write_token(output, in_code_block, whitespace, token, longest_prefix);
108        }
109    }
110}
111
112fn write_token(
113    output: &mut String,
114    in_code_block: &mut bool,
115    whitespace: &mut String,
116    token: &SyntaxToken,
117    longest_prefix: &str,
118) {
119    if let Some(comment) = ast::Comment::cast(token.clone()) {
120        if comment.is_doc() {
121            output.push_str(ensure_in_code_block(in_code_block, &*whitespace));
122            output.push_str(&write_lines(comment, longest_prefix));
123        } else {
124            output.push_str(ensure_in_markdown(in_code_block, &*whitespace));
125            output.push_str(&write_comment(comment, longest_prefix));
126        }
127
128        whitespace.clear();
129    } else if ast::Whitespace::can_cast(token.kind()) {
130        *whitespace = "\n".repeat(token.to_string().chars().filter(|c| *c == '\n').count())
131    } else {
132        output.push_str(&*whitespace);
133        output.push_str(&write_lines(token, longest_prefix));
134        whitespace.clear();
135    }
136}
137
/// Renders `text` and strips `prefix` from the start of every line that begins
/// with it.
///
/// Lines are split on `\n` only and rejoined with `\n`, so a trailing newline
/// and any `\r` characters are preserved.
fn write_lines(text: impl Display, prefix: &str) -> String {
    let rendered = text.to_string();
    let mut out = String::with_capacity(rendered.len());

    for (index, line) in rendered.split('\n').enumerate() {
        if index > 0 {
            out.push('\n');
        }
        out.push_str(line.strip_prefix(prefix).unwrap_or(line));
    }

    out
}
144
145fn write_comment(comment: ast::Comment, prefix: &str) -> String {
146    let comment_suffix = &comment.text()[comment.prefix().len()..];
147    let comment_text = match comment.kind().shape {
148        ast::CommentShape::Line => comment_suffix,
149        ast::CommentShape::Block => comment_suffix.strip_suffix("*/").unwrap_or(comment_suffix),
150    };
151
152    let mut lines = comment_text.split('\n');
153    let mut output = String::new();
154
155    if let Some(first_line) = lines.next() {
156        output.push_str(first_line.strip_prefix(' ').unwrap_or(first_line));
157    }
158
159    for line in lines {
160        output.push('\n');
161        output.push_str(line.strip_prefix(prefix).unwrap_or(line))
162    }
163
164    output
165}
166
167fn parse_module(source_text: &str) -> Result<SourceFile> {
168    let parsed = SourceFile::parse(source_text, Edition::Edition2024);
169    let errors = parsed.errors();
170
171    if !errors.is_empty() {
172        bail!(errors.iter().join("\n"))
173    }
174
175    Ok(parsed.tree())
176}
177
178fn is_named(item: &impl HasName, name: &str) -> bool {
179    item.name().is_some_and(|n| n.text().as_ref() == name)
180}
181
/// Returns the longest prefix shared by every string in `prefixes`, or `""`
/// if `prefixes` is empty.
///
/// The result is a slice of the first string. It always ends on a `char`
/// boundary, so it is valid for any UTF-8 input, not just the ASCII
/// whitespace produced by `whitespace_prefix`.
fn longest_prefix<'a>(mut prefixes: impl Iterator<Item = &'a str>) -> &'a str {
    let Some(first) = prefixes.next() else {
        return "";
    };

    prefixes.fold(first, |common, prefix| {
        // Byte offset of the first differing char. `char_indices` yields byte
        // offsets, so the slice below can never split a multi-byte char.
        let end = common
            .char_indices()
            .zip(prefix.chars())
            .find(|((_, x), y)| x != y)
            // No mismatch: the shorter string is a prefix of the other, so its
            // length is a valid boundary in `common` too.
            .map_or_else(|| min(common.len(), prefix.len()), |((idx, _), _)| idx);
        &common[..end]
    })
}
204
/// Returns the separator to emit before markdown text, and marks the writer as
/// being outside a code block.
///
/// If a code block was open, the separator closes its fence and `whitespace`
/// is discarded. Otherwise `whitespace` is returned unchanged.
fn ensure_in_markdown<'a>(in_code_block: &mut bool, whitespace: &'a str) -> &'a str {
    let was_in_code_block = std::mem::replace(in_code_block, false);
    if was_in_code_block {
        "\n```\n\n"
    } else {
        whitespace
    }
}
215
/// Returns the separator to emit before code, and marks the writer as being
/// inside a code block.
///
/// If no code block was open, the separator opens a ```` ```rust,ignore ````
/// fence and `whitespace` is discarded. Otherwise `whitespace` is returned
/// unchanged.
fn ensure_in_code_block<'a>(in_code_block: &mut bool, whitespace: &'a str) -> &'a str {
    let was_in_code_block = std::mem::replace(in_code_block, true);
    if was_in_code_block {
        whitespace
    } else {
        "\n\n```rust,ignore\n"
    }
}
226
/// Returns the run of spaces and tabs at the start of `line`.
///
/// Returns `None` if the line is empty or consists only of spaces and tabs, so
/// blank lines do not affect the shared indentation.
fn whitespace_prefix(line: &str) -> Option<&str> {
    let content_start = line.find(|c: char| !matches!(c, ' ' | '\t'))?;
    Some(&line[..content_start])
}
231
232fn expect_kind(
233    expected: SyntaxKind,
234    actual: Option<NodeOrToken<SyntaxNode, SyntaxToken>>,
235) -> Result<()> {
236    let actual_kind = actual
237        .and_then(|last| last.into_token())
238        .map(|token| token.kind());
239
240    if Some(expected) == actual_kind {
241        Ok(())
242    } else {
243        bail!("Unexpected token")
244    }
245}