Skip to main content

kode_markdown/
parse.rs

1use arborium_tree_sitter::{InputEdit, Language, Parser, Point, Tree};
2
3use crate::nodes::NodeKind;
4
5/// Wraps a tree-sitter parser configured for markdown.
6/// Supports incremental re-parsing after edits.
7pub struct MarkdownTree {
8    parser: Parser,
9    tree: Option<Tree>,
10    source: String,
11}
12
13impl MarkdownTree {
14    /// Create a new markdown parser and parse the given source.
15    pub fn new(source: &str) -> Self {
16        let language = Language::new(arborium_markdown::language());
17        let mut parser = Parser::new();
18        parser
19            .set_language(&language)
20            .expect("failed to set markdown language");
21
22        let tree = parser.parse(source, None);
23
24        Self {
25            parser,
26            tree,
27            source: source.to_string(),
28        }
29    }
30
31    /// Get the current source text.
32    pub fn source(&self) -> &str {
33        &self.source
34    }
35
36    /// Get the current parse tree, if available.
37    pub fn tree(&self) -> Option<&Tree> {
38        self.tree.as_ref()
39    }
40
41    /// Replace the entire source and re-parse from scratch.
42    pub fn set_source(&mut self, source: &str) {
43        self.source = source.to_string();
44        self.tree = self.parser.parse(source, None);
45    }
46
47    /// Apply an edit and incrementally re-parse.
48    ///
49    /// `start_byte` / `old_end_byte` / `new_end_byte` describe the edit in byte offsets.
50    /// Points describe the same edit in row/column coordinates.
51    ///
52    /// # Panics
53    /// Panics if `start_byte` or `old_end_byte` are not on UTF-8 char boundaries.
54    pub fn edit(
55        &mut self,
56        start_byte: usize,
57        old_end_byte: usize,
58        new_text: &str,
59        start_point: Point,
60        old_end_point: Point,
61    ) {
62        // Apply the edit to our source string
63        let new_end_byte = start_byte + new_text.len();
64        self.source.replace_range(start_byte..old_end_byte, new_text);
65
66        // Calculate new end point
67        let new_end_point = byte_offset_to_point(&self.source, new_end_byte);
68
69        // Tell tree-sitter about the edit
70        if let Some(tree) = &mut self.tree {
71            tree.edit(&InputEdit {
72                start_byte,
73                old_end_byte,
74                new_end_byte,
75                start_position: start_point,
76                old_end_position: old_end_point,
77                new_end_position: new_end_point,
78            });
79        }
80
81        // Re-parse incrementally
82        self.tree = self.parser.parse(&self.source, self.tree.as_ref());
83    }
84
85    /// Get the S-expression representation of the parse tree (for debugging).
86    pub fn sexp(&self) -> Option<String> {
87        self.tree.as_ref().map(|t| t.root_node().to_sexp())
88    }
89
90    /// Walk the top-level blocks of the document, calling the visitor for each.
91    pub fn walk_blocks<F>(&self, mut visitor: F)
92    where
93        F: FnMut(BlockInfo),
94    {
95        let Some(tree) = &self.tree else { return };
96        let root = tree.root_node();
97        walk_blocks_recursive(&root, &self.source, &mut visitor);
98    }
99
100    /// Find the block node at the given byte offset.
101    pub fn block_at_byte(&self, byte_offset: usize) -> Option<BlockInfo> {
102        let tree = self.tree.as_ref()?;
103        let root = tree.root_node();
104
105        // Find the deepest named node at this offset
106        let node = root.named_descendant_for_byte_range(byte_offset, byte_offset)?;
107
108        // Walk up to find the nearest block-level node
109        let mut current = node;
110        loop {
111            let kind = NodeKind::from_ts_kind(current.kind());
112            if kind.is_block() {
113                return Some(block_info_from_node(&current, &self.source));
114            }
115            match current.parent() {
116                Some(parent) if parent.kind() != "document" => current = parent,
117                _ => return Some(block_info_from_node(&current, &self.source)),
118            }
119        }
120    }
121
122    /// Find the innermost node at the given byte offset.
123    pub fn node_at_byte(&self, byte_offset: usize) -> Option<NodeInfo> {
124        let tree = self.tree.as_ref()?;
125        let root = tree.root_node();
126        let node = root.descendant_for_byte_range(byte_offset, byte_offset)?;
127        let kind = refine_node_kind(&node);
128        Some(NodeInfo {
129            kind,
130            start_byte: node.start_byte(),
131            end_byte: node.end_byte(),
132            start_point: node.start_position(),
133            end_point: node.end_position(),
134        })
135    }
136}
137
138/// Information about a block-level element.
139#[derive(Debug, Clone)]
140pub struct BlockInfo {
141    pub kind: NodeKind,
142    pub start_byte: usize,
143    pub end_byte: usize,
144    pub start_point: Point,
145    pub end_point: Point,
146    /// The raw text of this block.
147    pub text: String,
148}
149
150/// Information about any node in the tree.
151#[derive(Debug, Clone)]
152pub struct NodeInfo {
153    pub kind: NodeKind,
154    pub start_byte: usize,
155    pub end_byte: usize,
156    pub start_point: Point,
157    pub end_point: Point,
158}
159
160fn is_block_node(kind: &str) -> bool {
161    NodeKind::from_ts_kind(kind).is_block()
162}
163
164/// Refine a NodeKind from tree-sitter, resolving heading levels and list types.
165fn refine_node_kind(node: &arborium_tree_sitter::Node) -> NodeKind {
166    let mut kind = NodeKind::from_ts_kind(node.kind());
167
168    // Refine heading level
169    if matches!(kind, NodeKind::Heading { .. }) {
170        let level = detect_heading_level(node);
171        kind = NodeKind::Heading { level };
172    }
173
174    // Refine list type (bullet vs ordered)
175    if matches!(kind, NodeKind::BulletList) {
176        let ordered = node
177            .children(&mut node.walk())
178            .find(|c| c.kind() == "list_item")
179            .map(|item| {
180                item.children(&mut item.walk())
181                    .any(|c| c.kind() == "list_marker_dot" || c.kind() == "list_marker_parenthesis")
182            })
183            .unwrap_or(false);
184        kind = if ordered {
185            NodeKind::OrderedList
186        } else {
187            NodeKind::BulletList
188        };
189    }
190
191    kind
192}
193
194fn block_info_from_node(node: &arborium_tree_sitter::Node, source: &str) -> BlockInfo {
195    let kind = refine_node_kind(node);
196
197    let start_byte = node.start_byte();
198    let end_byte = node.end_byte();
199    let text = source[start_byte..end_byte].to_string();
200
201    BlockInfo {
202        kind,
203        start_byte,
204        end_byte,
205        start_point: node.start_position(),
206        end_point: node.end_position(),
207        text,
208    }
209}
210
211fn detect_heading_level(node: &arborium_tree_sitter::Node) -> u8 {
212    if node.kind() == "setext_heading" {
213        let has_h1 = node
214            .children(&mut node.walk())
215            .any(|c| c.kind() == "setext_h1_underline");
216        return if has_h1 { 1 } else { 2 };
217    }
218    for i in 0..node.child_count() {
219        if let Some(child) = node.child(i as u32) {
220            match child.kind() {
221                "atx_h1_marker" => return 1,
222                "atx_h2_marker" => return 2,
223                "atx_h3_marker" => return 3,
224                "atx_h4_marker" => return 4,
225                "atx_h5_marker" => return 5,
226                "atx_h6_marker" => return 6,
227                _ => {}
228            }
229        }
230    }
231    1
232}
233
234fn walk_blocks_recursive<F>(
235    node: &arborium_tree_sitter::Node,
236    source: &str,
237    visitor: &mut F,
238) where
239    F: FnMut(BlockInfo),
240{
241    for i in 0..node.named_child_count() {
242        if let Some(child) = node.named_child(i as u32) {
243            let kind_str = child.kind();
244            if is_block_node(kind_str) {
245                visitor(block_info_from_node(&child, source));
246                // Recurse into containers
247                let kind = NodeKind::from_ts_kind(kind_str);
248                if kind.is_container() {
249                    walk_blocks_recursive(&child, source, visitor);
250                }
251            }
252        }
253    }
254}
255
256/// Convert a byte offset in a string to a tree-sitter Point (row, column in bytes).
257fn byte_offset_to_point(source: &str, byte_offset: usize) -> Point {
258    let offset = byte_offset.min(source.len());
259    let slice = &source[..offset];
260    let row = slice.matches('\n').count();
261    let last_newline = slice.rfind('\n').map(|i| i + 1).unwrap_or(0);
262    let column = offset - last_newline;
263    Point { row, column }
264}
265
266/// Extract the info string (language) from a fenced code block node.
267pub fn code_block_language<'a>(
268    node: &arborium_tree_sitter::Node,
269    source: &'a str,
270) -> Option<&'a str> {
271    for i in 0..node.child_count() {
272        if let Some(child) = node.child(i as u32) {
273            if child.kind() == "info_string" {
274                let text = &source[child.start_byte()..child.end_byte()];
275                let lang = text.trim();
276                if !lang.is_empty() {
277                    return Some(lang);
278                }
279            }
280        }
281    }
282    None
283}
284
285/// Extract the content of a fenced code block (without fences).
286pub fn code_block_content<'a>(
287    node: &arborium_tree_sitter::Node,
288    source: &'a str,
289) -> Option<&'a str> {
290    for i in 0..node.child_count() {
291        if let Some(child) = node.child(i as u32) {
292            if child.kind() == "code_fence_content" {
293                return Some(&source[child.start_byte()..child.end_byte()]);
294            }
295        }
296    }
297    None
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `markdown` and collect the kind of every block `walk_blocks` reports.
    fn block_kinds(markdown: &str) -> Vec<NodeKind> {
        let doc = MarkdownTree::new(markdown);
        let mut kinds = Vec::new();
        doc.walk_blocks(|info| kinds.push(info.kind));
        kinds
    }

    /// Depth-first search for the first named descendant of `node` with the given kind.
    fn find_child_by_kind<'a>(
        node: &arborium_tree_sitter::Node<'a>,
        kind: &str,
    ) -> Option<arborium_tree_sitter::Node<'a>> {
        (0..node.named_child_count())
            .filter_map(|index| node.named_child(index as u32))
            .find_map(|child| {
                if child.kind() == kind {
                    Some(child)
                } else {
                    find_child_by_kind(&child, kind)
                }
            })
    }

    #[test]
    fn parse_basic_markdown() {
        let doc = MarkdownTree::new("# Hello\n\nThis is a paragraph.\n");
        assert!(doc.tree().is_some());

        let sexp = doc.sexp().unwrap();
        assert!(sexp.contains("atx_heading"));
        assert!(sexp.contains("paragraph"));
    }

    #[test]
    fn walk_blocks_finds_all() {
        let kinds = block_kinds(
            "# Title\n\nParagraph text.\n\n- item 1\n- item 2\n\n```rust\nfn main() {}\n```\n",
        );

        assert!(kinds.contains(&NodeKind::Heading { level: 1 }));
        assert!(kinds.contains(&NodeKind::Paragraph));
        assert!(kinds.contains(&NodeKind::BulletList));
        assert!(kinds.contains(&NodeKind::FencedCodeBlock));
    }

    #[test]
    fn heading_levels() {
        let levels: Vec<_> = block_kinds("# H1\n\n## H2\n\n### H3\n")
            .into_iter()
            .filter_map(|kind| match kind {
                NodeKind::Heading { level } => Some(level),
                _ => None,
            })
            .collect();

        assert_eq!(levels, vec![1, 2, 3]);
    }

    #[test]
    fn ordered_vs_unordered_list() {
        // `false` marks a bullet list, `true` an ordered one.
        let ordered_flags: Vec<bool> = block_kinds("- bullet\n- list\n\n1. ordered\n2. list\n")
            .into_iter()
            .filter_map(|kind| match kind {
                NodeKind::BulletList => Some(false),
                NodeKind::OrderedList => Some(true),
                _ => None,
            })
            .collect();

        assert_eq!(ordered_flags, vec![false, true]);
    }

    #[test]
    fn fenced_code_block_language() {
        let md = "```rust\nfn main() {}\n```\n";
        let doc = MarkdownTree::new(md);
        let root = doc.tree().unwrap().root_node();

        let mut found_lang: Option<String> = None;
        for index in 0..root.named_child_count() {
            let child = root.named_child(index as u32).unwrap();
            // The code block may be a direct child or nested further down.
            let fence = match child.kind() {
                "fenced_code_block" => Some(child),
                _ => find_child_by_kind(&child, "fenced_code_block"),
            };
            if let Some(fence) = fence {
                found_lang = code_block_language(&fence, md).map(str::to_owned);
            }
        }

        assert_eq!(found_lang.as_deref(), Some("rust"));
    }

    #[test]
    fn incremental_edit() {
        let mut doc = MarkdownTree::new("# Hello\n\nWorld\n");

        // Replace "World" (bytes 9..14 on row 2) with "Rust".
        let start = Point { row: 2, column: 0 };
        let old_end = Point { row: 2, column: 5 };
        doc.edit(9, 14, "Rust", start, old_end);

        assert_eq!(doc.source(), "# Hello\n\nRust\n");
        assert!(doc.tree().is_some());

        let sexp = doc.sexp().unwrap();
        assert!(sexp.contains("atx_heading"));
        assert!(sexp.contains("paragraph"));
    }

    #[test]
    fn block_at_byte_offset() {
        let doc = MarkdownTree::new("# Title\n\nSome paragraph.\n");

        let heading = doc.block_at_byte(0).unwrap();
        assert!(matches!(heading.kind, NodeKind::Heading { level: 1 }));

        let paragraph = doc.block_at_byte(10).unwrap();
        assert_eq!(paragraph.kind, NodeKind::Paragraph);
    }

    #[test]
    fn empty_document() {
        let doc = MarkdownTree::new("");
        assert!(doc.tree().is_some());
        assert!(block_kinds("").is_empty());
    }

    #[test]
    fn node_at_byte_uses_node_kind() {
        let doc = MarkdownTree::new("# Hello\n");
        let node = doc.node_at_byte(2).unwrap();
        // The node must carry a typed NodeKind rather than fall through to Unknown.
        assert!(!matches!(node.kind, NodeKind::Unknown));
    }
}