quickmark_core/rules/
md043.rs

1use serde::Deserialize;
2use std::rc::Rc;
3
4use tree_sitter::Node;
5
6use crate::{
7    linter::{range_from_tree_sitter, Context, RuleLinter, RuleViolation},
8    rules::{Rule, RuleType},
9};
10
11// MD043-specific configuration types
12#[derive(Debug, PartialEq, Clone, Deserialize, Default)]
13pub struct MD043RequiredHeadingsTable {
14    #[serde(default)]
15    pub headings: Vec<String>,
16    #[serde(default)]
17    pub match_case: bool,
18}
19
20#[derive(Debug, Clone)]
21struct HeadingInfo {
22    content: String,
23    level: u8,
24    range: tree_sitter::Range,
25}
26
27pub(crate) struct MD043Linter {
28    context: Rc<Context>,
29    violations: Vec<RuleViolation>,
30    headings: Vec<HeadingInfo>,
31}
32
33impl MD043Linter {
34    pub fn new(context: Rc<Context>) -> Self {
35        Self {
36            context,
37            violations: Vec::new(),
38            headings: Vec::new(),
39        }
40    }
41
42    fn extract_heading_content(&self, node: &Node) -> String {
43        let source = self.context.get_document_content();
44        let start_byte = node.start_byte();
45        let end_byte = node.end_byte();
46        let full_text = &source[start_byte..end_byte];
47
48        match node.kind() {
49            "atx_heading" => {
50                // Remove leading #s and trailing #s if present
51                let text = full_text
52                    .trim_start_matches('#')
53                    .trim()
54                    .trim_end_matches('#')
55                    .trim();
56                text.to_string()
57            }
58            "setext_heading" => {
59                // For setext, take first line (before underline)
60                if let Some(line) = full_text.lines().next() {
61                    line.trim().to_string()
62                } else {
63                    String::new()
64                }
65            }
66            _ => String::new(),
67        }
68    }
69
70    fn extract_heading_level(&self, node: &Node) -> u8 {
71        match node.kind() {
72            "atx_heading" => {
73                for i in 0..node.child_count() {
74                    let child = node.child(i).unwrap();
75                    if child.kind().starts_with("atx_h") && child.kind().ends_with("_marker") {
76                        return child.kind().chars().nth(5).unwrap().to_digit(10).unwrap() as u8;
77                    }
78                }
79                1 // fallback
80            }
81            "setext_heading" => {
82                for i in 0..node.child_count() {
83                    let child = node.child(i).unwrap();
84                    if child.kind() == "setext_h1_underline" {
85                        return 1;
86                    } else if child.kind() == "setext_h2_underline" {
87                        return 2;
88                    }
89                }
90                1 // fallback
91            }
92            _ => 1,
93        }
94    }
95
96    fn format_heading(&self, content: &str, level: u8) -> String {
97        format!("{} {}", "#".repeat(level as usize), content)
98    }
99
100    fn compare_headings(&self, expected: &str, actual: &str) -> bool {
101        let config = &self.context.config.linters.settings.required_headings;
102        if config.match_case {
103            expected == actual
104        } else {
105            expected.to_lowercase() == actual.to_lowercase()
106        }
107    }
108
109    fn check_required_headings(&mut self) {
110        let config = &self.context.config.linters.settings.required_headings;
111
112        if config.headings.is_empty() {
113            return; // Nothing to check
114        }
115
116        let mut required_index = 0;
117        let mut match_any = false;
118        let mut has_error = false;
119        let any_headings = !self.headings.is_empty();
120
121        for heading in &self.headings {
122            if has_error {
123                break;
124            }
125
126            let actual = self.format_heading(&heading.content, heading.level);
127
128            if required_index >= config.headings.len() {
129                // No more required headings, but we have more actual headings
130                break;
131            }
132
133            let expected = &config.headings[required_index];
134
135            match expected.as_str() {
136                "*" => {
137                    // Zero or more unspecified headings
138                    if required_index + 1 < config.headings.len() {
139                        let next_expected = &config.headings[required_index + 1];
140                        if self.compare_headings(next_expected, &actual) {
141                            required_index += 2; // Skip "*" and match the next
142                            match_any = false;
143                        } else {
144                            match_any = true;
145                        }
146                    } else {
147                        match_any = true;
148                    }
149                }
150                "+" => {
151                    // One or more unspecified headings
152                    match_any = true;
153                    required_index += 1;
154                }
155                "?" => {
156                    // Exactly one unspecified heading
157                    required_index += 1;
158                }
159                _ => {
160                    // Specific heading required
161                    if self.compare_headings(expected, &actual) {
162                        required_index += 1;
163                        match_any = false;
164                    } else if match_any {
165                        // We're in a "match any" state, so continue without advancing
166                        continue;
167                    } else {
168                        // Expected specific heading but got something else
169                        self.violations.push(RuleViolation::new(
170                            &MD043,
171                            format!("Expected: {expected}; Actual: {actual}"),
172                            self.context.file_path.clone(),
173                            range_from_tree_sitter(&heading.range),
174                        ));
175                        has_error = true;
176                    }
177                }
178            }
179        }
180
181        // Check if there are unmatched required headings at the end
182        let extra_headings = config.headings.len() - required_index;
183        if !has_error
184            && ((extra_headings > 1)
185                || ((extra_headings == 1) && (config.headings[required_index] != "*")))
186            && (any_headings || !config.headings.iter().all(|h| h == "*"))
187        {
188            // Report missing heading at end of file
189            let last_line = self.context.get_document_content().lines().count();
190            let missing_heading = &config.headings[required_index];
191
192            // Create a range for the end of file
193            let end_range = tree_sitter::Range {
194                start_byte: self.context.get_document_content().len(),
195                end_byte: self.context.get_document_content().len(),
196                start_point: tree_sitter::Point {
197                    row: last_line,
198                    column: 0,
199                },
200                end_point: tree_sitter::Point {
201                    row: last_line,
202                    column: 0,
203                },
204            };
205
206            self.violations.push(RuleViolation::new(
207                &MD043,
208                format!("Missing heading: {missing_heading}"),
209                self.context.file_path.clone(),
210                range_from_tree_sitter(&end_range),
211            ));
212        }
213    }
214}
215
216impl RuleLinter for MD043Linter {
217    fn feed(&mut self, node: &Node) {
218        if node.kind() == "atx_heading" || node.kind() == "setext_heading" {
219            let content = self.extract_heading_content(node);
220            let level = self.extract_heading_level(node);
221
222            self.headings.push(HeadingInfo {
223                content,
224                level,
225                range: node.range(),
226            });
227        }
228    }
229
230    fn finalize(&mut self) -> Vec<RuleViolation> {
231        self.check_required_headings();
232        std::mem::take(&mut self.violations)
233    }
234}
235
236pub const MD043: Rule = Rule {
237    id: "MD043",
238    alias: "required-headings",
239    tags: &["headings"],
240    description: "Required heading structure",
241    rule_type: RuleType::Document,
242    required_nodes: &["atx_heading", "setext_heading"],
243    new_linter: |context| Box::new(MD043Linter::new(context)),
244};
245
#[cfg(test)]
mod test {
    use std::path::PathBuf;

    use crate::config::{LintersSettingsTable, MD043RequiredHeadingsTable, RuleSeverity};
    use crate::linter::MultiRuleLinter;
    use crate::test_utils::test_helpers::test_config_with_settings;

    /// Builds a configuration that enables only `required-headings` with the
    /// given settings.
    fn test_config(headings: Vec<String>, match_case: bool) -> crate::config::QuickmarkConfig {
        let required_headings = MD043RequiredHeadingsTable {
            headings,
            match_case,
        };
        let settings = LintersSettingsTable {
            required_headings,
            ..Default::default()
        };
        test_config_with_settings(vec![("required-headings", RuleSeverity::Error)], settings)
    }

    /// Lints `input` as `test.md` against the `required` heading list and
    /// returns the message of every reported violation, in order.
    fn lint(required: &[&str], match_case: bool, input: &str) -> Vec<String> {
        let headings = required.iter().map(|h| (*h).to_owned()).collect();
        let config = test_config(headings, match_case);
        let mut linter = MultiRuleLinter::new_for_document(PathBuf::from("test.md"), config, input);
        linter
            .analyze()
            .iter()
            .map(|violation| violation.message().to_string())
            .collect()
    }

    #[test]
    fn test_no_required_headings() {
        let messages = lint(&[], false, "# Title\n\n## Section\n\nContent");
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_exact_match() {
        let messages = lint(
            &["# Title", "## Section", "### Details"],
            false,
            "# Title\n\n## Section\n\n### Details\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_missing_heading() {
        let messages = lint(
            &["# Title", "## Section", "### Details"],
            false,
            "# Title\n\n### Details\n\nContent",
        );
        assert_eq!(messages.len(), 1);
        assert!(messages[0].contains("Expected: ## Section"));
    }

    #[test]
    fn test_wrong_heading() {
        let messages = lint(
            &["# Title", "## Section"],
            false,
            "# Title\n\n## Wrong Section\n\nContent",
        );
        assert_eq!(messages.len(), 1);
        assert!(messages[0].contains("Expected: ## Section"));
        assert!(messages[0].contains("Actual: ## Wrong Section"));
    }

    #[test]
    fn test_case_insensitive_match() {
        let messages = lint(
            &["# Title", "## Section"],
            false,
            "# TITLE\n\n## section\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_case_sensitive_match() {
        let messages = lint(
            &["# Title", "## Section"],
            true,
            "# TITLE\n\n## section\n\nContent",
        );
        // Only the first mismatch is reported.
        assert_eq!(messages.len(), 1);
        assert!(messages[0].contains("Expected: # Title"));
        assert!(messages[0].contains("Actual: # TITLE"));
    }

    #[test]
    fn test_zero_or_more_wildcard() {
        let messages = lint(
            &["# Title", "*", "## Important"],
            false,
            "# Title\n\n## Random\n\n### Sub\n\n## Important\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_one_or_more_wildcard() {
        let messages = lint(
            &["# Title", "+", "## Important"],
            false,
            "# Title\n\n## Random\n\n### Sub\n\n## Important\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_question_mark_wildcard() {
        let messages = lint(
            &["?", "## Section"],
            false,
            "# Any Title\n\n## Section\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_missing_heading_at_end() {
        let messages = lint(
            &["# Title", "## Section", "### Details"],
            false,
            "# Title\n\n## Section\n\nContent",
        );
        assert_eq!(messages.len(), 1);
        assert!(messages[0].contains("Missing heading: ### Details"));
    }

    #[test]
    fn test_setext_headings() {
        let messages = lint(
            &["# Title", "## Section"],
            false,
            "Title\n=====\n\nSection\n-------\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_mixed_heading_styles() {
        let messages = lint(
            &["# Title", "## Section"],
            false,
            "Title\n=====\n\n## Section\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }

    #[test]
    fn test_closed_atx_headings() {
        let messages = lint(
            &["# Title", "## Section"],
            false,
            "# Title #\n\n## Section ##\n\nContent",
        );
        assert_eq!(messages.len(), 0);
    }
}
439}