1use std::{cmp::min, collections::VecDeque, fmt::Display};
2
3use anyhow::{bail, Result};
4use itertools::Itertools;
5use ra_ap_syntax::{
6 AstNode, AstToken, Edition, NodeOrToken, SourceFile, SyntaxKind, SyntaxNode, SyntaxToken, ast::{self, HasModuleItem, HasName, Item}
7};
8
9pub fn write_module(source_text: &str) -> Result<Option<String>> {
10 let source = parse_module(source_text)?;
11
12 for item in source.items() {
13 if let Item::Fn(function) = item {
14 if is_named(&function, "body") {
15 if let Some(new_content) = write_function(function)? {
16 return Ok(Some(new_content));
17 }
18 }
19 }
20 }
21
22 Ok(None)
23}
24
25fn write_function(function: ast::Fn) -> Result<Option<String>> {
26 if let Some(stmts) = function.body().and_then(|body| body.stmt_list()) {
27 let mut stmts: VecDeque<_> = stmts.syntax().children_with_tokens().collect();
28
29 expect_kind(SyntaxKind::L_CURLY, stmts.pop_front())?;
30 expect_kind(SyntaxKind::R_CURLY, stmts.pop_back())?;
31
32 let body_text = stmts.iter().map(|s| s.to_string()).collect::<String>();
33 let ws_prefixes = body_text.lines().filter_map(whitespace_prefix);
34 let longest_prefix = longest_prefix(ws_prefixes);
35
36 if stmts
37 .front()
38 .and_then(|node| node.as_token())
39 .is_some_and(|token| ast::Whitespace::can_cast(token.kind()))
40 {
41 stmts.pop_front();
42 }
43
44 Ok(Some(write_body(stmts, longest_prefix)))
45 } else {
46 Ok(None)
47 }
48}
49
50fn write_body(
51 stmts: impl IntoIterator<Item = NodeOrToken<SyntaxNode, SyntaxToken>>,
52 longest_prefix: &str,
53) -> String {
54 let mut whitespace = String::new();
55 let mut in_code_block = false;
56 let mut output = String::new();
57
58 for node in stmts {
59 write_node_or_token(
60 &mut output,
61 &mut in_code_block,
62 &mut whitespace,
63 node,
64 longest_prefix,
65 );
66 }
67
68 if in_code_block {
69 output.push_str("\n```");
70 }
71
72 output.push('\n');
73
74 output
75}
76
77fn write_node_or_token(
78 output: &mut String,
79 in_code_block: &mut bool,
80 whitespace: &mut String,
81 node: NodeOrToken<SyntaxNode, SyntaxToken>,
82 longest_prefix: &str,
83) {
84 match &node {
85 NodeOrToken::Node(node) => {
86 let mut children = node.children_with_tokens();
87
88 for child in children.by_ref() {
91 if child.kind() == SyntaxKind::COMMENT || child.kind() == SyntaxKind::WHITESPACE {
92 write_node_or_token(output, in_code_block, whitespace, child, longest_prefix);
93 } else {
94 output.push_str(ensure_in_code_block(in_code_block, whitespace));
95 output.push_str(&write_lines(child, longest_prefix));
96 break;
97 }
98 }
99
100 for child in children {
101 output.push_str(&write_lines(child, longest_prefix));
102 }
103
104 whitespace.clear();
105 }
106 NodeOrToken::Token(token) => {
107 write_token(output, in_code_block, whitespace, token, longest_prefix);
108 }
109 }
110}
111
112fn write_token(
113 output: &mut String,
114 in_code_block: &mut bool,
115 whitespace: &mut String,
116 token: &SyntaxToken,
117 longest_prefix: &str,
118) {
119 if let Some(comment) = ast::Comment::cast(token.clone()) {
120 if comment.is_doc() {
121 output.push_str(ensure_in_code_block(in_code_block, &*whitespace));
122 output.push_str(&write_lines(comment, longest_prefix));
123 } else {
124 output.push_str(ensure_in_markdown(in_code_block, &*whitespace));
125 output.push_str(&write_comment(comment, longest_prefix));
126 }
127
128 whitespace.clear();
129 } else if ast::Whitespace::can_cast(token.kind()) {
130 *whitespace = "\n".repeat(token.to_string().chars().filter(|c| *c == '\n').count())
131 } else {
132 output.push_str(&*whitespace);
133 output.push_str(&write_lines(token, longest_prefix));
134 whitespace.clear();
135 }
136}
137
/// Formats `text` and removes `prefix` from the start of every line that begins with it.
///
/// Lines are split on `'\n'` only. Lines that do not start with `prefix` are kept
/// unchanged, and the original line structure (including a trailing newline) is preserved.
fn write_lines(text: impl Display, prefix: &str) -> String {
    let rendered = text.to_string();
    let mut result = String::with_capacity(rendered.len());

    for (index, line) in rendered.split('\n').enumerate() {
        if index > 0 {
            result.push('\n');
        }
        result.push_str(line.strip_prefix(prefix).unwrap_or(line));
    }

    result
}
144
145fn write_comment(comment: ast::Comment, prefix: &str) -> String {
146 let comment_suffix = &comment.text()[comment.prefix().len()..];
147 let comment_text = match comment.kind().shape {
148 ast::CommentShape::Line => comment_suffix,
149 ast::CommentShape::Block => comment_suffix.strip_suffix("*/").unwrap_or(comment_suffix),
150 };
151
152 let mut lines = comment_text.split('\n');
153 let mut output = String::new();
154
155 if let Some(first_line) = lines.next() {
156 output.push_str(first_line.strip_prefix(' ').unwrap_or(first_line));
157 }
158
159 for line in lines {
160 output.push('\n');
161 output.push_str(line.strip_prefix(prefix).unwrap_or(line))
162 }
163
164 output
165}
166
167fn parse_module(source_text: &str) -> Result<SourceFile> {
168 let parsed = SourceFile::parse(source_text, Edition::Edition2024);
169 let errors = parsed.errors();
170
171 if !errors.is_empty() {
172 bail!(errors.iter().join("\n"))
173 }
174
175 Ok(parsed.tree())
176}
177
178fn is_named(item: &impl HasName, name: &str) -> bool {
179 item.name().is_some_and(|n| n.text().as_ref() == name)
180}
181
/// Returns the longest string that is a prefix of every item yielded by `prefixes`.
///
/// Returns `""` for an empty iterator. The result is a slice of the first item and is
/// always cut on a `char` boundary, so non-ASCII input cannot cause a panic.
fn longest_prefix<'a>(mut prefixes: impl Iterator<Item = &'a str>) -> &'a str {
    let Some(first) = prefixes.next() else {
        return "";
    };

    prefixes.fold(first, |common, prefix| {
        // Use byte offsets (from `char_indices`), not char counts: slicing a `str` takes a
        // byte index. When no char differs, one string is a prefix of the other, so the
        // shared part ends where the shorter one ends. That end is a char boundary in both,
        // because every char before it matched.
        let end = common
            .char_indices()
            .zip(prefix.chars())
            .find(|((_, x), y)| x != y)
            .map_or_else(|| min(common.len(), prefix.len()), |((offset, _), _)| offset);
        &common[..end]
    })
}
204
/// Switches the rendering state to prose and returns the separator text to emit.
///
/// # Returns
///
/// - If a code block was open: a closing fence followed by a blank line. The pending
///   `whitespace` is discarded.
/// - Otherwise: `whitespace` unchanged.
///
/// In both cases `in_code_block` is `false` afterwards.
fn ensure_in_markdown<'a>(in_code_block: &mut bool, whitespace: &'a str) -> &'a str {
    let was_in_code = std::mem::replace(in_code_block, false);
    if was_in_code {
        "\n```\n\n"
    } else {
        whitespace
    }
}
215
/// Switches the rendering state to code and returns the separator text to emit.
///
/// # Returns
///
/// - If a code block is already open: `whitespace` unchanged.
/// - Otherwise: a blank line followed by an opening ```` ```rust,ignore ```` fence. The
///   pending `whitespace` is discarded.
///
/// In both cases `in_code_block` is `true` afterwards.
fn ensure_in_code_block<'a>(in_code_block: &mut bool, whitespace: &'a str) -> &'a str {
    let was_in_code = std::mem::replace(in_code_block, true);
    if was_in_code {
        whitespace
    } else {
        "\n\n```rust,ignore\n"
    }
}
226
/// Returns the run of leading spaces and tabs in `line`.
///
/// Returns `None` when the line contains only spaces and tabs (or is empty), so such
/// blank lines do not affect indentation detection.
fn whitespace_prefix(line: &str) -> Option<&str> {
    let first_non_ws = line.find(|c: char| !matches!(c, ' ' | '\t'))?;
    Some(&line[..first_non_ws])
}
231
232fn expect_kind(
233 expected: SyntaxKind,
234 actual: Option<NodeOrToken<SyntaxNode, SyntaxToken>>,
235) -> Result<()> {
236 let actual_kind = actual
237 .and_then(|last| last.into_token())
238 .map(|token| token.kind());
239
240 if Some(expected) == actual_kind {
241 Ok(())
242 } else {
243 bail!("Unexpected token")
244 }
245}