1use std::{cell::RefCell, collections::HashMap, fmt::Display, path::PathBuf, rc::Rc};
2use tree_sitter::{Node, Parser};
3use tree_sitter_md::LANGUAGE;
4
5use crate::{
6    config::{QuickmarkConfig, RuleSeverity},
7    rules::{Rule, ALL_RULES},
8    tree_sitter_walker::TreeSitterWalker,
9};
10
/// A position inside a document.
///
/// Positions produced by `range_from_tree_sitter` are zero-based: `line` is the tree-sitter
/// row. NOTE(review): in that case `character` is tree-sitter's `Point::column`, which
/// tree-sitter measures in bytes. On lines containing multi-byte UTF-8 text it is therefore a
/// byte offset, not a character count.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CharPosition {
    /// Zero-based line (row) index.
    pub line: usize,
    /// Zero-based offset within the line (see the note on the type).
    pub character: usize,
}

/// A span of a document between two positions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Range {
    /// Where the span begins.
    pub start: CharPosition,
    /// Where the span ends. When copied from tree-sitter this is its (exclusive) end point.
    pub end: CharPosition,
}

/// A [`Range`] inside a specific file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Location {
    /// The file the range refers to.
    pub file_path: PathBuf,
    /// The span inside that file.
    pub range: Range,
}
27
/// A single problem reported by a rule: where it is, what it says, and how severe it is.
#[derive(Debug)]
pub struct RuleViolation {
    /// File and range the violation refers to.
    location: Location,
    /// Message text supplied by the rule when the violation was created.
    message: String,
    /// The rule that reported the violation.
    rule: &'static Rule,
    /// Severity. `RuleViolation::new` sets it to `RuleSeverity::Error`, and
    /// `MultiRuleLinter::analyze` overwrites it with the configured severity.
    pub(crate) severity: RuleSeverity,
}
35
36impl RuleViolation {
37    pub fn new(rule: &'static Rule, message: String, file_path: PathBuf, range: Range) -> Self {
38        Self {
39            rule,
40            message,
41            location: Location { file_path, range },
42            severity: RuleSeverity::Error, }
44    }
45
46    pub fn location(&self) -> &Location {
47        &self.location
48    }
49
50    pub fn message(&self) -> &str {
51        &self.message
52    }
53
54    pub fn rule(&self) -> &'static Rule {
55        self.rule
56    }
57
58    pub fn severity(&self) -> &RuleSeverity {
59        &self.severity
60    }
61}
62
63pub fn range_from_tree_sitter(ts_range: &tree_sitter::Range) -> Range {
65    Range {
66        start: CharPosition {
67            line: ts_range.start_point.row,
68            character: ts_range.start_point.column,
69        },
70        end: CharPosition {
71            line: ts_range.end_point.row,
72            character: ts_range.end_point.column,
73        },
74    }
75}
76
77impl Display for RuleViolation {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        write!(
80            f,
81            "{}:{}:{} {}/{} {}",
82            self.location().file_path.to_string_lossy(),
83            self.location().range.start.line,
84            self.location().range.start.character,
85            self.rule().id,
86            self.rule().alias,
87            self.message()
88        )
89    }
90}
91
/// Per-document data shared by every rule linter during one lint run.
///
/// `MultiRuleLinter::new_for_document` builds one `Context`, wraps it in an `Rc`, and gives
/// each active rule's linter a clone of that `Rc`.
///
/// NOTE(review): nothing in this file mutates the `RefCell` fields after construction. They
/// are presumably wrapped so rules can update them; confirm before relying on that.
#[derive(Debug)]
pub struct Context {
    /// Path of the document being linted.
    pub file_path: PathBuf,
    /// Configuration the document is linted with.
    pub config: QuickmarkConfig,
    /// Source split with `str::lines`. An extra empty line is appended when the source
    /// ends with `'\n'`. Index `i` holds zero-based line `i`.
    pub lines: RefCell<Vec<String>>,
    /// Every tree-sitter node of the document, keyed by `Node::kind()`. Nodes whose kind
    /// contains `"heading"` are also listed under the synthetic key `"*heading*"`.
    pub node_cache: RefCell<HashMap<String, Vec<NodeInfo>>>,
    /// The full source text, stored verbatim.
    pub document_content: RefCell<String>,
}
109
/// Line span and kind of one tree-sitter node, as stored in `Context::node_cache`.
#[derive(Debug, Clone)]
pub struct NodeInfo {
    /// Zero-based row of the node's start position (`Node::start_position().row`).
    pub line_start: usize,
    /// Zero-based row of the node's end position (`Node::end_position().row`).
    /// NOTE(review): tree-sitter end points are exclusive. For a node that ends with a newline
    /// this is likely the row after its last character.
    pub line_end: usize,
    /// The node's `Node::kind()`. This is the real kind, even for entries filed under
    /// `"*heading*"`.
    pub kind: String,
}
117
118impl Context {
119    pub fn new(
120        file_path: PathBuf,
121        config: QuickmarkConfig,
122        source: &str,
123        root_node: &Node,
124    ) -> Self {
125        let mut lines: Vec<String> = source.lines().map(String::from).collect();
128
129        if source.ends_with('\n') {
131            lines.push(String::new());
132        }
133        let node_cache = Self::build_node_cache(root_node);
134
135        Self {
136            file_path,
137            config,
138            lines: RefCell::new(lines),
139            node_cache: RefCell::new(node_cache),
140            document_content: RefCell::new(source.to_string()),
141        }
142    }
143
144    pub fn get_document_content(&self) -> std::cell::Ref<'_, String> {
147        self.document_content.borrow()
148    }
149
150    fn build_node_cache(root_node: &Node) -> HashMap<String, Vec<NodeInfo>> {
152        let mut cache = HashMap::new();
153        Self::collect_nodes_recursive(root_node, &mut cache);
154        cache
155    }
156
157    fn collect_nodes_recursive(node: &Node, cache: &mut HashMap<String, Vec<NodeInfo>>) {
158        let kind = node.kind();
159        let kind_string = kind.to_string();
160        let node_info = NodeInfo {
161            line_start: node.start_position().row,
162            line_end: node.end_position().row,
163            kind: kind_string.clone(),
164        };
165
166        cache
168            .entry(kind_string)
169            .or_default()
170            .push(node_info.clone());
171
172        if kind.contains("heading") {
174            cache
175                .entry("*heading*".to_string())
176                .or_default()
177                .push(node_info);
178        }
179
180        for i in 0..node.child_count() {
182            if let Some(child) = node.child(i) {
183                Self::collect_nodes_recursive(&child, cache);
184            }
185        }
186    }
187
188    pub fn get_nodes(&self, node_types: &[&str]) -> Vec<NodeInfo> {
190        let cache = self.node_cache.borrow();
191        let mut result = Vec::new();
192        for node_type in node_types {
193            if let Some(nodes) = cache.get(*node_type) {
194                result.extend(nodes.iter().cloned());
195            }
196        }
197        result
198    }
199
200    pub fn get_node_type_for_line(&self, line_number: usize) -> String {
202        let cache = self.node_cache.borrow();
203        let mut best_match: Option<&NodeInfo> = None;
205        let mut smallest_range = usize::MAX;
206
207        for nodes in cache.values() {
208            for node in nodes {
209                if line_number >= node.line_start && line_number <= node.line_end {
210                    let range_size = node.line_end - node.line_start;
211                    if range_size < smallest_range {
212                        smallest_range = range_size;
213                        best_match = Some(node);
214                    }
215                }
216            }
217        }
218
219        best_match
220            .map(|n| n.kind.clone())
221            .unwrap_or_else(|| "text".to_string())
222    }
223}
224
/// A single rule's linter: receives the document's nodes one at a time, then reports.
///
/// Instances are created by each active rule's `new_linter` constructor, called in
/// `MultiRuleLinter::new_for_document` with the shared `Rc<Context>`.
pub trait RuleLinter {
    /// Inspects one node. `MultiRuleLinter::analyze` calls this for every node produced by
    /// `TreeSitterWalker::walk`, in the order the walker yields them.
    fn feed(&mut self, node: &Node);

    /// Returns the violations found. `MultiRuleLinter::analyze` calls this after the walk, then
    /// overwrites each returned violation's severity with the configured one.
    fn finalize(&mut self) -> Vec<RuleViolation>;
}
/// Runs every rule enabled in the configuration over one Markdown document.
pub struct MultiRuleLinter {
    /// One linter per active rule, in `ALL_RULES` order.
    linters: Vec<Box<dyn RuleLinter>>,
    /// Parsed document. `None` when no rule is active, because parsing is skipped then.
    tree: Option<tree_sitter::Tree>,
    /// Configuration used to resolve each violation's severity.
    config: QuickmarkConfig,
}
277
278impl MultiRuleLinter {
279    pub fn new_for_document(file_path: PathBuf, config: QuickmarkConfig, document: &str) -> Self {
289        let active_rules: Vec<_> = ALL_RULES
291            .iter()
292            .filter(|r| {
293                config
294                    .linters
295                    .severity
296                    .get(r.alias)
297                    .map(|severity| *severity != RuleSeverity::Off)
298                    .unwrap_or(false)
299            })
300            .collect();
301
302        if active_rules.is_empty() {
304            return Self {
305                linters: Vec::new(),
306                tree: None,
307                config,
308            };
309        }
310
311        let mut parser = Parser::new();
313        parser
314            .set_language(&LANGUAGE.into())
315            .expect("Error loading Markdown grammar");
316        let tree = parser.parse(document, None).expect("Parse failed");
317
318        let context = Rc::new(Context::new(
320            file_path,
321            config.clone(),
322            document,
323            &tree.root_node(),
324        ));
325
326        let linters = active_rules
328            .iter()
329            .map(|r| ((r.new_linter)(context.clone())))
330            .collect();
331
332        Self {
333            linters,
334            tree: Some(tree),
335            config,
336        }
337    }
338
339    pub fn analyze(&mut self) -> Vec<RuleViolation> {
344        if self.linters.is_empty() {
346            return Vec::new();
347        }
348
349        let tree = match &self.tree {
351            Some(tree) => tree,
352            None => return Vec::new(),
353        };
354
355        let walker = TreeSitterWalker::new(tree);
356
357        walker.walk(|node| {
359            for linter in &mut self.linters {
360                linter.feed(&node);
361            }
362        });
363
364        let mut violations = Vec::new();
366        for linter in &mut self.linters {
367            let mut linter_violations = linter.finalize();
368            for violation in &mut linter_violations {
370                let severity = self
371                    .config
372                    .linters
373                    .severity
374                    .get(violation.rule().alias)
375                    .cloned()
376                    .unwrap_or(RuleSeverity::Error);
377                violation.severity = severity;
378            }
379            violations.extend(linter_violations);
380        }
381
382        violations
383    }
384}
385
#[cfg(test)]
mod test {
    use std::{collections::HashMap, path::PathBuf};

    use crate::{
        config::{self, QuickmarkConfig, RuleSeverity},
        rules::{md001::MD001, md003::MD003, md013::MD013},
    };

    use super::MultiRuleLinter;

    /// Config enabling MD001, MD003 and MD013 at `Error`, with MD003 expecting ATX headings.
    fn md001_md003_md013_config() -> QuickmarkConfig {
        let mut severity = HashMap::new();
        for alias in [MD001.alias, MD003.alias, MD013.alias] {
            severity.insert(alias.to_string(), RuleSeverity::Error);
        }

        QuickmarkConfig {
            linters: config::LintersTable {
                severity,
                settings: config::LintersSettingsTable {
                    heading_style: config::MD003HeadingStyleTable {
                        style: config::HeadingStyle::ATX,
                    },
                    ..Default::default()
                },
            },
        }
    }

    #[test]
    fn test_multiple_violations() {
        // The literal starts with a newline, so "# First heading" is zero-based line 1.
        let input = "
# First heading
Second heading
==============
#### Fourth level
";

        let mut linter = MultiRuleLinter::new_for_document(
            PathBuf::from("test.md"),
            md001_md003_md013_config(),
            input,
        );
        let violations = linter.analyze();

        assert_eq!(
            2,
            violations.len(),
            "Should find both MD001 and MD003 violations"
        );

        let heading_increment = &violations[0];
        assert_eq!(MD001.id, heading_increment.rule().id);
        assert_eq!(4, heading_increment.location().range.start.line);

        let heading_style = &violations[1];
        assert_eq!(MD003.id, heading_style.rule().id);
        assert_eq!(2, heading_style.location().range.start.line);
    }
}