1use arborium_tree_sitter::{InputEdit, Language, Parser, Point, Tree};
2
3use crate::nodes::NodeKind;
4
5pub struct MarkdownTree {
8 parser: Parser,
9 tree: Option<Tree>,
10 source: String,
11}
12
13impl MarkdownTree {
14 pub fn new(source: &str) -> Self {
16 let language = Language::new(arborium_markdown::language());
17 let mut parser = Parser::new();
18 parser
19 .set_language(&language)
20 .expect("failed to set markdown language");
21
22 let tree = parser.parse(source, None);
23
24 Self {
25 parser,
26 tree,
27 source: source.to_string(),
28 }
29 }
30
31 pub fn source(&self) -> &str {
33 &self.source
34 }
35
36 pub fn tree(&self) -> Option<&Tree> {
38 self.tree.as_ref()
39 }
40
41 pub fn set_source(&mut self, source: &str) {
43 self.source = source.to_string();
44 self.tree = self.parser.parse(source, None);
45 }
46
47 pub fn edit(
55 &mut self,
56 start_byte: usize,
57 old_end_byte: usize,
58 new_text: &str,
59 start_point: Point,
60 old_end_point: Point,
61 ) {
62 let new_end_byte = start_byte + new_text.len();
64 self.source.replace_range(start_byte..old_end_byte, new_text);
65
66 let new_end_point = byte_offset_to_point(&self.source, new_end_byte);
68
69 if let Some(tree) = &mut self.tree {
71 tree.edit(&InputEdit {
72 start_byte,
73 old_end_byte,
74 new_end_byte,
75 start_position: start_point,
76 old_end_position: old_end_point,
77 new_end_position: new_end_point,
78 });
79 }
80
81 self.tree = self.parser.parse(&self.source, self.tree.as_ref());
83 }
84
85 pub fn sexp(&self) -> Option<String> {
87 self.tree.as_ref().map(|t| t.root_node().to_sexp())
88 }
89
90 pub fn walk_blocks<F>(&self, mut visitor: F)
92 where
93 F: FnMut(BlockInfo),
94 {
95 let Some(tree) = &self.tree else { return };
96 let root = tree.root_node();
97 walk_blocks_recursive(&root, &self.source, &mut visitor);
98 }
99
100 pub fn block_at_byte(&self, byte_offset: usize) -> Option<BlockInfo> {
102 let tree = self.tree.as_ref()?;
103 let root = tree.root_node();
104
105 let node = root.named_descendant_for_byte_range(byte_offset, byte_offset)?;
107
108 let mut current = node;
110 loop {
111 let kind = NodeKind::from_ts_kind(current.kind());
112 if kind.is_block() {
113 return Some(block_info_from_node(¤t, &self.source));
114 }
115 match current.parent() {
116 Some(parent) if parent.kind() != "document" => current = parent,
117 _ => return Some(block_info_from_node(¤t, &self.source)),
118 }
119 }
120 }
121
122 pub fn node_at_byte(&self, byte_offset: usize) -> Option<NodeInfo> {
124 let tree = self.tree.as_ref()?;
125 let root = tree.root_node();
126 let node = root.descendant_for_byte_range(byte_offset, byte_offset)?;
127 let kind = refine_node_kind(&node);
128 Some(NodeInfo {
129 kind,
130 start_byte: node.start_byte(),
131 end_byte: node.end_byte(),
132 start_point: node.start_position(),
133 end_point: node.end_position(),
134 })
135 }
136}
137
138#[derive(Debug, Clone)]
140pub struct BlockInfo {
141 pub kind: NodeKind,
142 pub start_byte: usize,
143 pub end_byte: usize,
144 pub start_point: Point,
145 pub end_point: Point,
146 pub text: String,
148}
149
150#[derive(Debug, Clone)]
152pub struct NodeInfo {
153 pub kind: NodeKind,
154 pub start_byte: usize,
155 pub end_byte: usize,
156 pub start_point: Point,
157 pub end_point: Point,
158}
159
160fn is_block_node(kind: &str) -> bool {
161 NodeKind::from_ts_kind(kind).is_block()
162}
163
164fn refine_node_kind(node: &arborium_tree_sitter::Node) -> NodeKind {
166 let mut kind = NodeKind::from_ts_kind(node.kind());
167
168 if matches!(kind, NodeKind::Heading { .. }) {
170 let level = detect_heading_level(node);
171 kind = NodeKind::Heading { level };
172 }
173
174 if matches!(kind, NodeKind::BulletList) {
176 let ordered = node
177 .children(&mut node.walk())
178 .find(|c| c.kind() == "list_item")
179 .map(|item| {
180 item.children(&mut item.walk())
181 .any(|c| c.kind() == "list_marker_dot" || c.kind() == "list_marker_parenthesis")
182 })
183 .unwrap_or(false);
184 kind = if ordered {
185 NodeKind::OrderedList
186 } else {
187 NodeKind::BulletList
188 };
189 }
190
191 kind
192}
193
194fn block_info_from_node(node: &arborium_tree_sitter::Node, source: &str) -> BlockInfo {
195 let kind = refine_node_kind(node);
196
197 let start_byte = node.start_byte();
198 let end_byte = node.end_byte();
199 let text = source[start_byte..end_byte].to_string();
200
201 BlockInfo {
202 kind,
203 start_byte,
204 end_byte,
205 start_point: node.start_position(),
206 end_point: node.end_position(),
207 text,
208 }
209}
210
211fn detect_heading_level(node: &arborium_tree_sitter::Node) -> u8 {
212 if node.kind() == "setext_heading" {
213 let has_h1 = node
214 .children(&mut node.walk())
215 .any(|c| c.kind() == "setext_h1_underline");
216 return if has_h1 { 1 } else { 2 };
217 }
218 for i in 0..node.child_count() {
219 if let Some(child) = node.child(i as u32) {
220 match child.kind() {
221 "atx_h1_marker" => return 1,
222 "atx_h2_marker" => return 2,
223 "atx_h3_marker" => return 3,
224 "atx_h4_marker" => return 4,
225 "atx_h5_marker" => return 5,
226 "atx_h6_marker" => return 6,
227 _ => {}
228 }
229 }
230 }
231 1
232}
233
234fn walk_blocks_recursive<F>(
235 node: &arborium_tree_sitter::Node,
236 source: &str,
237 visitor: &mut F,
238) where
239 F: FnMut(BlockInfo),
240{
241 for i in 0..node.named_child_count() {
242 if let Some(child) = node.named_child(i as u32) {
243 let kind_str = child.kind();
244 if is_block_node(kind_str) {
245 visitor(block_info_from_node(&child, source));
246 let kind = NodeKind::from_ts_kind(kind_str);
248 if kind.is_container() {
249 walk_blocks_recursive(&child, source, visitor);
250 }
251 }
252 }
253 }
254}
255
256fn byte_offset_to_point(source: &str, byte_offset: usize) -> Point {
258 let offset = byte_offset.min(source.len());
259 let slice = &source[..offset];
260 let row = slice.matches('\n').count();
261 let last_newline = slice.rfind('\n').map(|i| i + 1).unwrap_or(0);
262 let column = offset - last_newline;
263 Point { row, column }
264}
265
266pub fn code_block_language<'a>(
268 node: &arborium_tree_sitter::Node,
269 source: &'a str,
270) -> Option<&'a str> {
271 for i in 0..node.child_count() {
272 if let Some(child) = node.child(i as u32) {
273 if child.kind() == "info_string" {
274 let text = &source[child.start_byte()..child.end_byte()];
275 let lang = text.trim();
276 if !lang.is_empty() {
277 return Some(lang);
278 }
279 }
280 }
281 }
282 None
283}
284
285pub fn code_block_content<'a>(
287 node: &arborium_tree_sitter::Node,
288 source: &'a str,
289) -> Option<&'a str> {
290 for i in 0..node.child_count() {
291 if let Some(child) = node.child(i as u32) {
292 if child.kind() == "code_fence_content" {
293 return Some(&source[child.start_byte()..child.end_byte()]);
294 }
295 }
296 }
297 None
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `md` and returns every block reported by `walk_blocks`, in
    /// visiting order.
    fn collect_blocks(md: &str) -> Vec<BlockInfo> {
        let mut blocks = Vec::new();
        MarkdownTree::new(md).walk_blocks(|info| blocks.push(info));
        blocks
    }

    /// Depth-first search below `node` for the first named descendant whose
    /// kind is `kind`.
    fn find_child_by_kind<'a>(
        node: &arborium_tree_sitter::Node<'a>,
        kind: &str,
    ) -> Option<arborium_tree_sitter::Node<'a>> {
        (0..node.named_child_count())
            .filter_map(|index| node.named_child(index as u32))
            .find_map(|child| {
                if child.kind() == kind {
                    Some(child)
                } else {
                    find_child_by_kind(&child, kind)
                }
            })
    }

    #[test]
    fn parse_basic_markdown() {
        let parsed = MarkdownTree::new("# Hello\n\nThis is a paragraph.\n");
        assert!(parsed.tree().is_some());

        let sexp = parsed.sexp().unwrap();
        for expected in ["atx_heading", "paragraph"] {
            assert!(sexp.contains(expected));
        }
    }

    #[test]
    fn walk_blocks_finds_all() {
        let blocks = collect_blocks(
            "# Title\n\nParagraph text.\n\n- item 1\n- item 2\n\n```rust\nfn main() {}\n```\n",
        );

        let expected_kinds = [
            NodeKind::Heading { level: 1 },
            NodeKind::Paragraph,
            NodeKind::BulletList,
            NodeKind::FencedCodeBlock,
        ];
        for expected in expected_kinds {
            assert!(blocks.iter().any(|block| block.kind == expected));
        }
    }

    #[test]
    fn heading_levels() {
        let levels: Vec<_> = collect_blocks("# H1\n\n## H2\n\n### H3\n")
            .iter()
            .filter_map(|block| match block.kind {
                NodeKind::Heading { level } => Some(level),
                _ => None,
            })
            .collect();

        assert_eq!(levels, vec![1, 2, 3]);
    }

    #[test]
    fn ordered_vs_unordered_list() {
        // `false` marks a bullet list, `true` an ordered one.
        let ordered_flags: Vec<bool> = collect_blocks("- bullet\n- list\n\n1. ordered\n2. list\n")
            .iter()
            .filter_map(|block| match block.kind {
                NodeKind::BulletList => Some(false),
                NodeKind::OrderedList => Some(true),
                _ => None,
            })
            .collect();

        assert_eq!(ordered_flags, vec![false, true]);
    }

    #[test]
    fn fenced_code_block_language() {
        let md = "```rust\nfn main() {}\n```\n";
        let parsed = MarkdownTree::new(md);
        let root = parsed.tree().unwrap().root_node();

        // The last code block found under the root decides the result.
        let found_lang = (0..root.named_child_count())
            .map(|index| root.named_child(index as u32).unwrap())
            .filter_map(|child| {
                if child.kind() == "fenced_code_block" {
                    Some(child)
                } else {
                    find_child_by_kind(&child, "fenced_code_block")
                }
            })
            .last()
            .and_then(|code| code_block_language(&code, md).map(String::from));

        assert_eq!(found_lang.as_deref(), Some("rust"));
    }

    #[test]
    fn incremental_edit() {
        let mut parsed = MarkdownTree::new("# Hello\n\nWorld\n");

        // Replace "World" (bytes 9..14, row 2) with "Rust".
        let start = Point { row: 2, column: 0 };
        let old_end = Point { row: 2, column: 5 };
        parsed.edit(9, 14, "Rust", start, old_end);

        assert_eq!(parsed.source(), "# Hello\n\nRust\n");
        assert!(parsed.tree().is_some());

        let sexp = parsed.sexp().unwrap();
        for expected in ["atx_heading", "paragraph"] {
            assert!(sexp.contains(expected));
        }
    }

    #[test]
    fn block_at_byte_offset() {
        let parsed = MarkdownTree::new("# Title\n\nSome paragraph.\n");

        let heading = parsed.block_at_byte(0).unwrap();
        assert!(matches!(heading.kind, NodeKind::Heading { level: 1 }));

        let paragraph = parsed.block_at_byte(10).unwrap();
        assert_eq!(paragraph.kind, NodeKind::Paragraph);
    }

    #[test]
    fn empty_document() {
        let parsed = MarkdownTree::new("");
        assert!(parsed.tree().is_some());

        let mut blocks = Vec::new();
        parsed.walk_blocks(|info| blocks.push(info));
        assert!(blocks.is_empty());
    }

    #[test]
    fn node_at_byte_uses_node_kind() {
        let parsed = MarkdownTree::new("# Hello\n");
        let node = parsed.node_at_byte(2).unwrap();
        assert!(!matches!(node.kind, NodeKind::Unknown));
    }
}