// quickmark_core/rules/md001.rs

1use std::rc::Rc;
2
3use tree_sitter::Node;
4
5use crate::{
6    linter::{range_from_tree_sitter, RuleViolation},
7    rules::{Context, Rule, RuleLinter, RuleType},
8};
9
10pub(crate) struct MD001Linter {
11    context: Rc<Context>,
12    current_heading_level: u8,
13    violations: Vec<RuleViolation>,
14}
15
16impl MD001Linter {
17    pub fn new(context: Rc<Context>) -> Self {
18        Self {
19            context,
20            current_heading_level: 0,
21            violations: Vec::new(),
22        }
23    }
24}
25
26fn extract_heading_level(node: &Node) -> u8 {
27    let mut cursor = node.walk();
28    match node.kind() {
29        "atx_heading" => node
30            .children(&mut cursor)
31            .find_map(|child| {
32                let kind = child.kind();
33                if kind.starts_with("atx_h") && kind.ends_with("_marker") {
34                    // "atx_h3_marker" -> 3
35                    kind.get(5..6)?.parse::<u8>().ok()
36                } else {
37                    None
38                }
39            })
40            .unwrap_or(1),
41        "setext_heading" => node
42            .children(&mut cursor)
43            .find_map(|child| match child.kind() {
44                "setext_h1_underline" => Some(1),
45                "setext_h2_underline" => Some(2),
46                _ => None,
47            })
48            .unwrap_or(1),
49        _ => 1,
50    }
51}
52
53impl RuleLinter for MD001Linter {
54    fn feed(&mut self, node: &Node) {
55        if node.kind() == "atx_heading" || node.kind() == "setext_heading" {
56            let level = extract_heading_level(node);
57
58            if self.current_heading_level > 0
59                && (level as i8 - self.current_heading_level as i8) > 1
60            {
61                self.violations.push(RuleViolation::new(
62                    &MD001,
63                    format!(
64                        "{} [Expected: h{}; Actual: h{}]",
65                        MD001.description,
66                        self.current_heading_level + 1,
67                        level
68                    ),
69                    self.context.file_path.clone(),
70                    range_from_tree_sitter(&node.range()),
71                ));
72            }
73            self.current_heading_level = level;
74        }
75    }
76
77    fn finalize(&mut self) -> Vec<RuleViolation> {
78        std::mem::take(&mut self.violations)
79    }
80}
81
82pub const MD001: Rule = Rule {
83    id: "MD001",
84    alias: "heading-increment",
85    tags: &["headings"],
86    description: "Heading levels should only increment by one level at a time",
87    rule_type: RuleType::Token,
88    required_nodes: &["atx_heading", "setext_heading"],
89    new_linter: |context| Box::new(MD001Linter::new(context)),
90};
91
#[cfg(test)]
mod test {
    use std::path::PathBuf;

    use crate::config::RuleSeverity;
    use crate::linter::{MultiRuleLinter, RuleViolation};
    use crate::test_utils::test_helpers::test_config_with_rules;

    /// Configuration with `heading-increment` as an error and `heading-style` explicitly
    /// off (presumably so mixed ATX/setext inputs don't trigger it).
    fn test_config() -> crate::config::QuickmarkConfig {
        let rules = vec![
            ("heading-increment", RuleSeverity::Error),
            ("heading-style", RuleSeverity::Off),
        ];
        test_config_with_rules(rules)
    }

    /// Lints `input` as `test.md` under `test_config` and returns the violations found.
    fn lint(input: &str) -> Vec<RuleViolation> {
        let mut linter =
            MultiRuleLinter::new_for_document(PathBuf::from("test.md"), test_config(), input);
        linter.analyze()
    }

    #[test]
    fn test_atx_positive() {
        let input = "# Heading level 1
some text
`some code`
## Heading level 2
some other text
###### Heading level 6
foobar
#### Heading level 4
### Heading level 3
";

        let violations = lint(input);
        // Only the h2 -> h6 jump is reported; stepping back down is allowed.
        assert_eq!(violations.len(), 1);
        let range = &violations[0].location().range;
        assert_eq!((range.start.line, range.start.character), (5, 0));
        assert_eq!((range.end.line, range.end.character), (6, 0));
    }

    #[test]
    fn test_atx_negative() {
        let input = "# Heading level 1
some text
`some code`
## Heading level 2
some other text
### Heading level 3
foobar
#### Heading level 4
##### Heading level 5
###### Heading level 6
";

        assert!(lint(input).is_empty());
    }

    #[test]
    fn test_atx_negative_starts_not_with_level_1() {
        let input = "## Heading level 2
some text
`some code`
### Heading level 3
some other text
#### Heading level 4
foobar
##### Heading level 5
###### Heading level 6
# level 1
";

        assert!(lint(input).is_empty());
    }

    #[test]
    fn test_setext_positive() {
        let input = "
Heading level 1
===============
some text
`some code`
### Heading level 3
some other text
         ";

        let violations = lint(input);
        // setext h1 -> atx h3 skips h2, so exactly one violation on the h3 line.
        assert_eq!(violations.len(), 1);
        let range = &violations[0].location().range;
        assert_eq!((range.start.line, range.start.character), (5, 0));
    }

    #[test]
    fn test_setext_negative() {
        let input = "
Heading level 1
===============
some text
Heading level 2
---------------
some other text
";

        // setext h1 -> setext h2 is a single-level step: nothing to report.
        assert!(lint(input).is_empty());
    }
}
210}