rumdl_lib/utils/
ast_utils.rs

1//!
2//! AST parsing utilities and caching for rumdl
3//!
4//! This module provides shared AST parsing and caching functionality to avoid
5//! reparsing the same Markdown content multiple times across different rules.
6
7use crate::rule::MarkdownAst;
8use lazy_static::lazy_static;
9use std::collections::HashMap;
10use std::panic;
11use std::sync::{Arc, Mutex};
12
13/// Cache for parsed AST nodes
14#[derive(Debug)]
15pub struct AstCache {
16    cache: HashMap<u64, Arc<MarkdownAst>>,
17    usage_stats: HashMap<u64, u64>,
18}
19
20impl Default for AstCache {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl AstCache {
27    pub fn new() -> Self {
28        Self {
29            cache: HashMap::new(),
30            usage_stats: HashMap::new(),
31        }
32    }
33
34    /// Get or parse AST for the given content
35    pub fn get_or_parse(&mut self, content: &str) -> Arc<MarkdownAst> {
36        let content_hash = crate::utils::fast_hash(content);
37
38        if let Some(ast) = self.cache.get(&content_hash) {
39            *self.usage_stats.entry(content_hash).or_insert(0) += 1;
40            return ast.clone();
41        }
42
43        // Parse the AST
44        let ast = Arc::new(parse_markdown_ast(content));
45        self.cache.insert(content_hash, ast.clone());
46        *self.usage_stats.entry(content_hash).or_insert(0) += 1;
47
48        ast
49    }
50
51    /// Get cache statistics
52    pub fn get_stats(&self) -> HashMap<u64, u64> {
53        self.usage_stats.clone()
54    }
55
56    /// Clear the cache
57    pub fn clear(&mut self) {
58        self.cache.clear();
59        self.usage_stats.clear();
60    }
61
62    /// Get cache size
63    pub fn len(&self) -> usize {
64        self.cache.len()
65    }
66
67    /// Check if cache is empty
68    pub fn is_empty(&self) -> bool {
69        self.cache.is_empty()
70    }
71}
72
73lazy_static! {
74    /// Global AST cache instance
75    static ref GLOBAL_AST_CACHE: Arc<Mutex<AstCache>> = Arc::new(Mutex::new(AstCache::new()));
76}
77
78/// Get or parse AST from the global cache
79pub fn get_cached_ast(content: &str) -> Arc<MarkdownAst> {
80    let mut cache = GLOBAL_AST_CACHE.lock().unwrap();
81    cache.get_or_parse(content)
82}
83
84/// Get AST cache statistics
85pub fn get_ast_cache_stats() -> HashMap<u64, u64> {
86    let cache = GLOBAL_AST_CACHE.lock().unwrap();
87    cache.get_stats()
88}
89
90/// Clear the global AST cache
91pub fn clear_ast_cache() {
92    let mut cache = GLOBAL_AST_CACHE.lock().unwrap();
93    cache.clear();
94}
95
96/// Parse Markdown content into an AST
97pub fn parse_markdown_ast(content: &str) -> MarkdownAst {
98    // Check for problematic patterns that cause the markdown crate to panic
99    if content_has_problematic_lists(content) {
100        log::debug!("Detected problematic list patterns, skipping AST parsing");
101        return MarkdownAst::Root(markdown::mdast::Root {
102            children: vec![],
103            position: None,
104        });
105    }
106
107    // Try to parse AST with GFM extensions enabled, but handle panics from the markdown crate
108    match panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
109        let mut parse_options = markdown::ParseOptions::gfm();
110        parse_options.constructs.frontmatter = true; // Also enable frontmatter parsing
111        markdown::to_mdast(content, &parse_options)
112    })) {
113        Ok(Ok(ast)) => {
114            // Successfully parsed AST
115            ast
116        }
117        Ok(Err(err)) => {
118            // Parsing failed with an error
119            log::debug!("Failed to parse markdown AST in ast_utils: {err:?}");
120            MarkdownAst::Root(markdown::mdast::Root {
121                children: vec![],
122                position: None,
123            })
124        }
125        Err(_) => {
126            // Parsing panicked
127            log::debug!("Markdown AST parsing panicked in ast_utils, falling back to empty AST");
128            MarkdownAst::Root(markdown::mdast::Root {
129                children: vec![],
130                position: None,
131            })
132        }
133    }
134}
135
/// Check if content contains patterns that cause the markdown crate to panic
///
/// Returns `true` when two directly adjacent lines outside fenced code blocks
/// are both bullet items (`* `, `+ ` or `- ` after optional indentation) but
/// use different bullet characters. Indentation is ignored, so a nested item
/// with a different bullet than its parent counts as well.
///
/// Fenced code blocks are tracked by the fence that opened them: a block is
/// closed only by a line made of the same fence character, at least as many
/// of them as the opening fence, and nothing but whitespace after. A `~~~`
/// line inside a ```` ``` ```` block, or a shorter fence inside a longer one,
/// is therefore treated as code rather than as the end of the block.
fn content_has_problematic_lists(content: &str) -> bool {
    /// Fence character and run length if `line` starts with 3+ backticks or tildes.
    fn fence_run(line: &str) -> Option<(char, usize)> {
        let fence_char = line.chars().next().filter(|c| matches!(c, '`' | '~'))?;
        let run = line.chars().take_while(|&c| c == fence_char).count();
        if run >= 3 {
            Some((fence_char, run))
        } else {
            None
        }
    }

    /// Bullet character if `line` starts with `* `, `+ ` or `- `.
    fn bullet(line: &str) -> Option<char> {
        let mut chars = line.chars();
        match (chars.next(), chars.next()) {
            (Some(marker), Some(' ')) if matches!(marker, '*' | '+' | '-') => Some(marker),
            _ => None,
        }
    }

    // Fence (character, length) of the code block we are currently inside, if any.
    let mut open_fence: Option<(char, usize)> = None;
    let mut lines = content.lines().map(str::trim_start).peekable();

    while let Some(line) = lines.next() {
        if let Some((fence_char, run)) = fence_run(line) {
            match open_fence {
                None => open_fence = Some((fence_char, run)),
                Some((open_char, open_run)) => {
                    // Fence characters are ASCII, so `run` is also a valid byte offset.
                    let closes = fence_char == open_char && run >= open_run && line[run..].trim().is_empty();
                    if closes {
                        open_fence = None;
                    }
                }
            }
            continue;
        }

        // Skip lines inside code blocks
        if open_fence.is_some() {
            continue;
        }

        // Compare this line with the one right after it. A following fence line
        // never looks like a bullet, so it needs no special case.
        if let Some(next) = lines.peek().copied() {
            if let (Some(first), Some(second)) = (bullet(line), bullet(next)) {
                // If different markers, this could cause a panic
                if first != second {
                    return true;
                }
            }
        }
    }

    false
}
187
188/// Check if AST contains specific node types
189pub fn ast_contains_node_type(ast: &MarkdownAst, node_type: &str) -> bool {
190    match ast {
191        MarkdownAst::Root(root) => root
192            .children
193            .iter()
194            .any(|child| ast_contains_node_type(child, node_type)),
195        MarkdownAst::Heading(_) if node_type == "heading" => true,
196        MarkdownAst::List(_) if node_type == "list" => true,
197        MarkdownAst::Link(_) if node_type == "link" => true,
198        MarkdownAst::Image(_) if node_type == "image" => true,
199        MarkdownAst::Code(_) if node_type == "code" => true,
200        MarkdownAst::InlineCode(_) if node_type == "inline_code" => true,
201        MarkdownAst::Emphasis(_) if node_type == "emphasis" => true,
202        MarkdownAst::Strong(_) if node_type == "strong" => true,
203        MarkdownAst::Html(_) if node_type == "html" => true,
204        MarkdownAst::Blockquote(_) if node_type == "blockquote" => true,
205        MarkdownAst::Table(_) if node_type == "table" => true,
206        _ => {
207            // Check children recursively
208            if let Some(children) = ast.children() {
209                children.iter().any(|child| ast_contains_node_type(child, node_type))
210            } else {
211                false
212            }
213        }
214    }
215}
216
217/// Extract all nodes of a specific type from the AST
218pub fn extract_nodes_by_type<'a>(ast: &'a MarkdownAst, node_type: &str) -> Vec<&'a MarkdownAst> {
219    let mut nodes = Vec::new();
220    extract_nodes_by_type_recursive(ast, node_type, &mut nodes);
221    nodes
222}
223
224fn extract_nodes_by_type_recursive<'a>(ast: &'a MarkdownAst, node_type: &str, nodes: &mut Vec<&'a MarkdownAst>) {
225    match ast {
226        MarkdownAst::Heading(_) if node_type == "heading" => nodes.push(ast),
227        MarkdownAst::List(_) if node_type == "list" => nodes.push(ast),
228        MarkdownAst::Link(_) if node_type == "link" => nodes.push(ast),
229        MarkdownAst::Image(_) if node_type == "image" => nodes.push(ast),
230        MarkdownAst::Code(_) if node_type == "code" => nodes.push(ast),
231        MarkdownAst::InlineCode(_) if node_type == "inline_code" => nodes.push(ast),
232        MarkdownAst::Emphasis(_) if node_type == "emphasis" => nodes.push(ast),
233        MarkdownAst::Strong(_) if node_type == "strong" => nodes.push(ast),
234        MarkdownAst::Html(_) if node_type == "html" => nodes.push(ast),
235        MarkdownAst::Blockquote(_) if node_type == "blockquote" => nodes.push(ast),
236        MarkdownAst::Table(_) if node_type == "table" => nodes.push(ast),
237        _ => {}
238    }
239
240    // Check children recursively
241    if let Some(children) = ast.children() {
242        for child in children {
243            extract_nodes_by_type_recursive(child, node_type, nodes);
244        }
245    }
246}
247
248/// Utility function to get text content from AST nodes
249pub fn get_text_content(ast: &MarkdownAst) -> String {
250    match ast {
251        MarkdownAst::Text(text) => text.value.clone(),
252        MarkdownAst::InlineCode(code) => code.value.clone(),
253        MarkdownAst::Code(code) => code.value.clone(),
254        _ => {
255            if let Some(children) = ast.children() {
256                children.iter().map(get_text_content).collect::<Vec<_>>().join("")
257            } else {
258                String::new()
259            }
260        }
261    }
262}
263
// Unit tests for the AST cache, the parser wrapper, the list-marker pre-check
// and the node query helpers defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // A second lookup of the same content must reuse the cached `Arc` and bump
    // the usage counter.
    #[test]
    fn test_ast_cache() {
        let mut cache = AstCache::new();
        let content = "# Hello World\n\nThis is a test.";

        let ast1 = cache.get_or_parse(content);
        let ast2 = cache.get_or_parse(content);

        // Should return the same Arc (cached)
        assert!(Arc::ptr_eq(&ast1, &ast2));
        assert_eq!(cache.len(), 1);

        // Test usage stats
        let stats = cache.get_stats();
        let content_hash = crate::utils::fast_hash(content);
        assert_eq!(stats.get(&content_hash), Some(&2));
    }

    // Distinct documents get distinct entries; revisiting one adds no entry.
    #[test]
    fn test_ast_cache_multiple_documents() {
        let mut cache = AstCache::new();
        let content1 = "# Document 1";
        let content2 = "# Document 2";
        let content3 = "# Document 3";

        let _ast1 = cache.get_or_parse(content1);
        let _ast2 = cache.get_or_parse(content2);
        let _ast3 = cache.get_or_parse(content3);
        assert_eq!(cache.len(), 3);

        // Access first document again
        let _ast1_again = cache.get_or_parse(content1);
        assert_eq!(cache.len(), 3); // Still 3 documents

        let stats = cache.get_stats();
        let hash1 = crate::utils::fast_hash(content1);
        assert_eq!(stats.get(&hash1), Some(&2)); // Accessed twice
    }

    // `clear` must empty both the AST map and the usage statistics.
    #[test]
    fn test_ast_cache_clear() {
        let mut cache = AstCache::new();
        cache.get_or_parse("# Test");
        cache.get_or_parse("## Another");

        assert_eq!(cache.len(), 2);
        assert!(!cache.is_empty());

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert!(cache.get_stats().is_empty());
    }

    // Parsing ordinary content yields a root node.
    #[test]
    fn test_parse_markdown_ast() {
        let content = "# Hello World\n\nThis is a test.";
        let ast = parse_markdown_ast(content);

        assert!(matches!(ast, MarkdownAst::Root(_)));
    }

    // Exercises the pre-check that decides whether the parser is skipped.
    #[test]
    fn test_problematic_list_detection() {
        // Mixed list markers that would cause panic
        let problematic = "* Item 1\n- Item 2\n+ Item 3";
        assert!(content_has_problematic_lists(problematic));

        // Consistent markers should be fine
        let ok_content = "* Item 1\n* Item 2\n* Item 3";
        assert!(!content_has_problematic_lists(ok_content));

        // Different marker types separated by content
        let separated = "* Item 1\n\nSome text\n\n- Item 2";
        assert!(!content_has_problematic_lists(separated));

        // Edge case: markers with different indentation
        // (indentation is ignored by the check, so a nested item with a
        // different marker is flagged too)
        let indented = "* Item 1\n  - Subitem";
        assert!(content_has_problematic_lists(indented));
    }

    // None of these inputs may make `parse_markdown_ast` panic or return
    // anything other than a root node.
    #[test]
    fn test_parse_malformed_markdown() {
        // Test various malformed markdown that might cause issues
        let test_cases = vec![
            "",                           // Empty
            "\n\n\n",                     // Only newlines
            "```",                        // Unclosed code block
            "```\ncode\n```extra```",     // Multiple code blocks
            "[link]()",                   // Empty link URL
            "![]()",                      // Empty image
            "|table|without|header|",     // Malformed table
            "> > > deeply nested quotes", // Deep nesting
            "# \n## \n### ",              // Empty headings
            "*unclosed emphasis",         // Unclosed emphasis
            "**unclosed strong",          // Unclosed strong
            "[unclosed link",             // Unclosed link
            "![unclosed image",           // Unclosed image
            "---\ntitle: test",           // Unclosed front matter
        ];

        for content in test_cases {
            let ast = parse_markdown_ast(content);
            // Should always return a valid AST, even if empty
            assert!(matches!(ast, MarkdownAst::Root(_)));
        }
    }

    // Content flagged by the pre-check must come back as an empty root.
    #[test]
    fn test_ast_with_mixed_list_markers() {
        // This should trigger the problematic list detection
        let content = "* First\n- Second\n+ Third";
        let ast = parse_markdown_ast(content);

        // Should return empty AST due to problematic pattern
        if let MarkdownAst::Root(root) = ast {
            assert!(root.children.is_empty());
        } else {
            panic!("Expected Root AST node");
        }
    }

    // Presence checks on a small document and on an empty root.
    #[test]
    fn test_ast_contains_node_type() {
        let content = "# Hello World\n\nThis is a [link](http://example.com).";
        let ast = parse_markdown_ast(content);

        assert!(ast_contains_node_type(&ast, "heading"));
        assert!(ast_contains_node_type(&ast, "link"));
        assert!(!ast_contains_node_type(&ast, "table"));

        // Test with empty AST
        let empty_ast = MarkdownAst::Root(markdown::mdast::Root {
            children: vec![],
            position: None,
        });
        assert!(!ast_contains_node_type(&empty_ast, "heading"));
    }

    // Nodes nested inside a blockquote must still be found.
    #[test]
    fn test_ast_contains_nested_nodes() {
        let content = "> # Heading in blockquote\n> \n> With a [link](url)";
        let ast = parse_markdown_ast(content);

        assert!(ast_contains_node_type(&ast, "blockquote"));
        assert!(ast_contains_node_type(&ast, "heading"));
        assert!(ast_contains_node_type(&ast, "link"));
    }

    // Extraction returns every matching node and nothing for absent types.
    #[test]
    fn test_extract_nodes_by_type() {
        let content = "# Heading 1\n\n## Heading 2\n\nSome text.";
        let ast = parse_markdown_ast(content);

        let headings = extract_nodes_by_type(&ast, "heading");
        assert_eq!(headings.len(), 2);

        // Test extracting non-existent type
        let tables = extract_nodes_by_type(&ast, "table");
        assert_eq!(tables.len(), 0);
    }

    // One node of each inline/block kind, each extracted exactly once.
    #[test]
    fn test_extract_multiple_node_types() {
        let content = "# Heading\n\n*emphasis* and **strong** and `code`\n\n[link](url) and ![image](img.png)";
        let ast = parse_markdown_ast(content);

        assert_eq!(extract_nodes_by_type(&ast, "heading").len(), 1);
        assert_eq!(extract_nodes_by_type(&ast, "emphasis").len(), 1);
        assert_eq!(extract_nodes_by_type(&ast, "strong").len(), 1);
        assert_eq!(extract_nodes_by_type(&ast, "inline_code").len(), 1);
        assert_eq!(extract_nodes_by_type(&ast, "link").len(), 1);
        assert_eq!(extract_nodes_by_type(&ast, "image").len(), 1);
    }

    // Leaf nodes that carry a `value` return it unchanged.
    #[test]
    fn test_get_text_content() {
        let content = "Hello world";
        let ast = MarkdownAst::Text(markdown::mdast::Text {
            value: content.to_string(),
            position: None,
        });

        assert_eq!(get_text_content(&ast), content);

        // Test inline code
        let code_ast = MarkdownAst::InlineCode(markdown::mdast::InlineCode {
            value: "code".to_string(),
            position: None,
        });
        assert_eq!(get_text_content(&code_ast), "code");

        // Test code block
        let block_ast = MarkdownAst::Code(markdown::mdast::Code {
            value: "fn main() {}".to_string(),
            lang: None,
            meta: None,
            position: None,
        });
        assert_eq!(get_text_content(&block_ast), "fn main() {}");
    }

    // Container nodes concatenate the text of their children in order.
    #[test]
    fn test_get_text_content_nested() {
        // Create a paragraph with mixed content
        let paragraph = MarkdownAst::Paragraph(markdown::mdast::Paragraph {
            children: vec![
                MarkdownAst::Text(markdown::mdast::Text {
                    value: "Hello ".to_string(),
                    position: None,
                }),
                MarkdownAst::Strong(markdown::mdast::Strong {
                    children: vec![MarkdownAst::Text(markdown::mdast::Text {
                        value: "world".to_string(),
                        position: None,
                    })],
                    position: None,
                }),
                MarkdownAst::Text(markdown::mdast::Text {
                    value: "!".to_string(),
                    position: None,
                }),
            ],
            position: None,
        });

        assert_eq!(get_text_content(&paragraph), "Hello world!");
    }

    // Exercises the free functions that wrap the global cache. This is the only
    // test in this module that touches the global instance.
    #[test]
    fn test_global_cache_functions() {
        // Clear cache first to ensure clean state
        clear_ast_cache();

        let content = "# Global cache test";
        let ast1 = get_cached_ast(content);
        let ast2 = get_cached_ast(content);

        // Should be the same instance
        assert!(Arc::ptr_eq(&ast1, &ast2));

        // Check stats
        let stats = get_ast_cache_stats();
        assert!(!stats.is_empty());

        // Clear and verify
        clear_ast_cache();
        let stats_after = get_ast_cache_stats();
        assert!(stats_after.is_empty());
    }

    // Non-ASCII headings and emoji must parse like any other content.
    #[test]
    fn test_unicode_content() {
        let unicode_content = "# 你好世界\n\n这是一个测试。\n\n## Ñoño\n\n🚀 Emoji content!";
        let ast = parse_markdown_ast(unicode_content);

        assert!(matches!(ast, MarkdownAst::Root(_)));
        assert!(ast_contains_node_type(&ast, "heading"));

        // Extract headings and verify count
        let headings = extract_nodes_by_type(&ast, "heading");
        assert_eq!(headings.len(), 2);
    }

    // A document with 1000 headings must parse fully and yield all of them.
    #[test]
    fn test_very_large_document() {
        // Generate a large document
        let mut content = String::new();
        for i in 0..1000 {
            content.push_str(&format!("# Heading {i}\n\nParagraph {i}\n\n"));
        }

        let ast = parse_markdown_ast(&content);
        assert!(matches!(ast, MarkdownAst::Root(_)));

        // Should have 1000 headings
        let headings = extract_nodes_by_type(&ast, "heading");
        assert_eq!(headings.len(), 1000);
    }

    // Several levels of blockquote nesting must not break parsing or lookup.
    #[test]
    fn test_deeply_nested_structure() {
        let content = "> > > > > Deeply nested blockquote\n> > > > > > Even deeper";
        let ast = parse_markdown_ast(content);

        assert!(matches!(ast, MarkdownAst::Root(_)));
        assert!(ast_contains_node_type(&ast, "blockquote"));
    }

    // One document containing every node kind that `ast_contains_node_type`
    // knows by name.
    #[test]
    fn test_all_node_types() {
        let comprehensive_content = r#"# Heading

> Blockquote

- List item

| Table | Header |
|-------|--------|
| Cell  | Cell   |

```rust
code block
```

*emphasis* **strong** `inline code`

[link](url) ![image](img.png)

<div>HTML</div>

---
"#;

        let ast = parse_markdown_ast(comprehensive_content);

        // Test all node type detections
        assert!(ast_contains_node_type(&ast, "heading"));
        assert!(ast_contains_node_type(&ast, "blockquote"));
        assert!(ast_contains_node_type(&ast, "list"));
        // Tables are now supported with GFM extension enabled
        assert!(ast_contains_node_type(&ast, "table"));
        assert!(ast_contains_node_type(&ast, "code"));
        assert!(ast_contains_node_type(&ast, "emphasis"));
        assert!(ast_contains_node_type(&ast, "strong"));
        assert!(ast_contains_node_type(&ast, "inline_code"));
        assert!(ast_contains_node_type(&ast, "link"));
        assert!(ast_contains_node_type(&ast, "image"));
        assert!(ast_contains_node_type(&ast, "html"));
    }

    // Pipe tables, with and without alignment markers, must produce table nodes.
    #[test]
    fn test_gfm_table_parsing() {
        // Test that GFM tables are properly parsed
        let table_content = r#"| Column 1 | Column 2 |
|----------|----------|
| Cell 1   | Cell 2   |
| Cell 3   | Cell 4   |"#;

        let ast = parse_markdown_ast(table_content);
        assert!(ast_contains_node_type(&ast, "table"));

        let tables = extract_nodes_by_type(&ast, "table");
        assert_eq!(tables.len(), 1);

        // Test more complex table with alignment
        let complex_table = r#"| Left | Center | Right |
|:-----|:------:|------:|
| L    |   C    |     R |
| Left |  Mid   | Right |"#;

        let ast2 = parse_markdown_ast(complex_table);
        assert!(ast_contains_node_type(&ast2, "table"));
    }

    // Nodes with nothing to contribute must yield an empty string.
    #[test]
    fn test_edge_case_empty_nodes() {
        // Test with nodes that have empty content
        let empty_text = MarkdownAst::Text(markdown::mdast::Text {
            value: String::new(),
            position: None,
        });
        assert_eq!(get_text_content(&empty_text), "");

        // Test with a node kind that carries neither a value nor children
        let thematic_break = MarkdownAst::ThematicBreak(markdown::mdast::ThematicBreak { position: None });
        assert_eq!(get_text_content(&thematic_break), "");
    }
}