1use serde::Deserialize;
2use std::rc::Rc;
3
4use tree_sitter::Node;
5
6use crate::{
7 linter::{range_from_tree_sitter, Context, RuleLinter, RuleViolation},
8 rules::{Rule, RuleType},
9};
10
11#[derive(Debug, PartialEq, Clone, Deserialize, Default)]
13pub struct MD043RequiredHeadingsTable {
14 #[serde(default)]
15 pub headings: Vec<String>,
16 #[serde(default)]
17 pub match_case: bool,
18}
19
20#[derive(Debug, Clone)]
21struct HeadingInfo {
22 content: String,
23 level: u8,
24 range: tree_sitter::Range,
25}
26
27pub(crate) struct MD043Linter {
28 context: Rc<Context>,
29 violations: Vec<RuleViolation>,
30 headings: Vec<HeadingInfo>,
31}
32
33impl MD043Linter {
34 pub fn new(context: Rc<Context>) -> Self {
35 Self {
36 context,
37 violations: Vec::new(),
38 headings: Vec::new(),
39 }
40 }
41
42 fn extract_heading_content(&self, node: &Node) -> String {
43 let source = self.context.get_document_content();
44 let start_byte = node.start_byte();
45 let end_byte = node.end_byte();
46 let full_text = &source[start_byte..end_byte];
47
48 match node.kind() {
49 "atx_heading" => {
50 let text = full_text
52 .trim_start_matches('#')
53 .trim()
54 .trim_end_matches('#')
55 .trim();
56 text.to_string()
57 }
58 "setext_heading" => {
59 if let Some(line) = full_text.lines().next() {
61 line.trim().to_string()
62 } else {
63 String::new()
64 }
65 }
66 _ => String::new(),
67 }
68 }
69
70 fn extract_heading_level(&self, node: &Node) -> u8 {
71 match node.kind() {
72 "atx_heading" => {
73 for i in 0..node.child_count() {
74 let child = node.child(i).unwrap();
75 if child.kind().starts_with("atx_h") && child.kind().ends_with("_marker") {
76 return child.kind().chars().nth(5).unwrap().to_digit(10).unwrap() as u8;
77 }
78 }
79 1 }
81 "setext_heading" => {
82 for i in 0..node.child_count() {
83 let child = node.child(i).unwrap();
84 if child.kind() == "setext_h1_underline" {
85 return 1;
86 } else if child.kind() == "setext_h2_underline" {
87 return 2;
88 }
89 }
90 1 }
92 _ => 1,
93 }
94 }
95
96 fn format_heading(&self, content: &str, level: u8) -> String {
97 format!("{} {}", "#".repeat(level as usize), content)
98 }
99
100 fn compare_headings(&self, expected: &str, actual: &str) -> bool {
101 let config = &self.context.config.linters.settings.required_headings;
102 if config.match_case {
103 expected == actual
104 } else {
105 expected.to_lowercase() == actual.to_lowercase()
106 }
107 }
108
109 fn check_required_headings(&mut self) {
110 let config = &self.context.config.linters.settings.required_headings;
111
112 if config.headings.is_empty() {
113 return; }
115
116 let mut required_index = 0;
117 let mut match_any = false;
118 let mut has_error = false;
119 let any_headings = !self.headings.is_empty();
120
121 for heading in &self.headings {
122 if has_error {
123 break;
124 }
125
126 let actual = self.format_heading(&heading.content, heading.level);
127
128 if required_index >= config.headings.len() {
129 break;
131 }
132
133 let expected = &config.headings[required_index];
134
135 match expected.as_str() {
136 "*" => {
137 if required_index + 1 < config.headings.len() {
139 let next_expected = &config.headings[required_index + 1];
140 if self.compare_headings(next_expected, &actual) {
141 required_index += 2; match_any = false;
143 } else {
144 match_any = true;
145 }
146 } else {
147 match_any = true;
148 }
149 }
150 "+" => {
151 match_any = true;
153 required_index += 1;
154 }
155 "?" => {
156 required_index += 1;
158 }
159 _ => {
160 if self.compare_headings(expected, &actual) {
162 required_index += 1;
163 match_any = false;
164 } else if match_any {
165 continue;
167 } else {
168 self.violations.push(RuleViolation::new(
170 &MD043,
171 format!("Expected: {expected}; Actual: {actual}"),
172 self.context.file_path.clone(),
173 range_from_tree_sitter(&heading.range),
174 ));
175 has_error = true;
176 }
177 }
178 }
179 }
180
181 let extra_headings = config.headings.len() - required_index;
183 if !has_error
184 && ((extra_headings > 1)
185 || ((extra_headings == 1) && (config.headings[required_index] != "*")))
186 && (any_headings || !config.headings.iter().all(|h| h == "*"))
187 {
188 let last_line = self.context.get_document_content().lines().count();
190 let missing_heading = &config.headings[required_index];
191
192 let end_range = tree_sitter::Range {
194 start_byte: self.context.get_document_content().len(),
195 end_byte: self.context.get_document_content().len(),
196 start_point: tree_sitter::Point {
197 row: last_line,
198 column: 0,
199 },
200 end_point: tree_sitter::Point {
201 row: last_line,
202 column: 0,
203 },
204 };
205
206 self.violations.push(RuleViolation::new(
207 &MD043,
208 format!("Missing heading: {missing_heading}"),
209 self.context.file_path.clone(),
210 range_from_tree_sitter(&end_range),
211 ));
212 }
213 }
214}
215
216impl RuleLinter for MD043Linter {
217 fn feed(&mut self, node: &Node) {
218 if node.kind() == "atx_heading" || node.kind() == "setext_heading" {
219 let content = self.extract_heading_content(node);
220 let level = self.extract_heading_level(node);
221
222 self.headings.push(HeadingInfo {
223 content,
224 level,
225 range: node.range(),
226 });
227 }
228 }
229
230 fn finalize(&mut self) -> Vec<RuleViolation> {
231 self.check_required_headings();
232 std::mem::take(&mut self.violations)
233 }
234}
235
236pub const MD043: Rule = Rule {
237 id: "MD043",
238 alias: "required-headings",
239 tags: &["headings"],
240 description: "Required heading structure",
241 rule_type: RuleType::Document,
242 required_nodes: &["atx_heading", "setext_heading"],
243 new_linter: |context| Box::new(MD043Linter::new(context)),
244};
245
#[cfg(test)]
mod test {
    use std::path::PathBuf;

    use crate::config::{LintersSettingsTable, MD043RequiredHeadingsTable, RuleSeverity};
    use crate::linter::MultiRuleLinter;
    use crate::test_utils::test_helpers::test_config_with_settings;

    /// Builds a configuration that enables only `required-headings` (at error
    /// severity) with the given heading list and case-sensitivity flag.
    fn test_config(headings: Vec<String>, match_case: bool) -> crate::config::QuickmarkConfig {
        let settings = LintersSettingsTable {
            required_headings: MD043RequiredHeadingsTable {
                headings,
                match_case,
            },
            ..Default::default()
        };
        test_config_with_settings(vec![("required-headings", RuleSeverity::Error)], settings)
    }

    /// Lints `input` as `test.md` against the `required` heading list and
    /// returns the message of every reported violation, in order.
    fn lint_messages(required: &[&str], match_case: bool, input: &str) -> Vec<String> {
        let owned = required.iter().copied().map(String::from).collect();
        let config = test_config(owned, match_case);
        let mut linter = MultiRuleLinter::new_for_document(PathBuf::from("test.md"), config, input);
        let violations = linter.analyze();
        violations
            .iter()
            .map(|violation| violation.message().to_string())
            .collect()
    }

    #[test]
    fn test_no_required_headings() {
        let messages = lint_messages(&[], false, "# Title\n\n## Section\n\nContent");
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_exact_match() {
        let messages = lint_messages(
            &["# Title", "## Section", "### Details"],
            false,
            "# Title\n\n## Section\n\n### Details\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_missing_heading() {
        let messages = lint_messages(
            &["# Title", "## Section", "### Details"],
            false,
            "# Title\n\n### Details\n\nContent",
        );
        assert_eq!(messages.len(), 1);
        assert!(messages[0].contains("Expected: ## Section"));
    }

    #[test]
    fn test_wrong_heading() {
        let messages = lint_messages(
            &["# Title", "## Section"],
            false,
            "# Title\n\n## Wrong Section\n\nContent",
        );
        assert_eq!(messages.len(), 1);
        assert!(messages[0].contains("Expected: ## Section"));
        assert!(messages[0].contains("Actual: ## Wrong Section"));
    }

    #[test]
    fn test_case_insensitive_match() {
        let messages = lint_messages(
            &["# Title", "## Section"],
            false,
            "# TITLE\n\n## section\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_case_sensitive_match() {
        let messages = lint_messages(
            &["# Title", "## Section"],
            true,
            "# TITLE\n\n## section\n\nContent",
        );
        assert_eq!(messages.len(), 1);
        assert!(messages[0].contains("Expected: # Title"));
        assert!(messages[0].contains("Actual: # TITLE"));
    }

    #[test]
    fn test_zero_or_more_wildcard() {
        let messages = lint_messages(
            &["# Title", "*", "## Important"],
            false,
            "# Title\n\n## Random\n\n### Sub\n\n## Important\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_one_or_more_wildcard() {
        let messages = lint_messages(
            &["# Title", "+", "## Important"],
            false,
            "# Title\n\n## Random\n\n### Sub\n\n## Important\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_question_mark_wildcard() {
        let messages = lint_messages(
            &["?", "## Section"],
            false,
            "# Any Title\n\n## Section\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_missing_heading_at_end() {
        let messages = lint_messages(
            &["# Title", "## Section", "### Details"],
            false,
            "# Title\n\n## Section\n\nContent",
        );
        assert_eq!(messages.len(), 1);
        assert!(messages[0].contains("Missing heading: ### Details"));
    }

    #[test]
    fn test_setext_headings() {
        let messages = lint_messages(
            &["# Title", "## Section"],
            false,
            "Title\n=====\n\nSection\n-------\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_mixed_heading_styles() {
        let messages = lint_messages(
            &["# Title", "## Section"],
            false,
            "Title\n=====\n\n## Section\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_closed_atx_headings() {
        let messages = lint_messages(
            &["# Title", "## Section"],
            false,
            "# Title #\n\n## Section ##\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }
}