1use std::ops::Range;
2
3use anyhow::{anyhow, Context, Result};
4use tree_sitter::{Node, Query, QueryCursor, Tree, TreeCursor};
5
6#[derive(PartialEq, Eq, Clone, Hash, Debug)]
7pub struct BlockTree<'tree> {
8 pub block: Block<'tree>,
9 pub children: Vec<BlockTree<'tree>>,
10}
11
12#[derive(PartialEq, Eq, Clone, Hash, Debug)]
13pub struct Block<'tree> {
14 nodes: Vec<Node<'tree>>,
15}
16
17impl<'tree> Block<'tree> {
18 pub fn new(nodes: Vec<Node<'tree>>) -> Result<Self> {
19 if nodes.is_empty() {
20 Err(anyhow!("Can't create block empty nodes vec"))
21 } else {
22 Ok(Self { nodes })
23 }
24 }
25
26 pub fn head(&self) -> &Node<'tree> {
27 self.nodes.first().expect("empty nodes vec")
28 }
29
30 pub fn tail(&self) -> &Node<'tree> {
31 self.nodes.last().expect("empty nodes vec")
32 }
33
34 pub fn head_tail(&self) -> (&Node<'tree>, &Node<'tree>) {
35 (self.head(), self.tail())
36 }
37
38 pub fn byte_range(&self) -> Range<usize> {
39 self.head().start_byte()..self.tail().end_byte()
40 }
41}
42
43pub fn get_query_subtrees<'tree>(
44 queries: &[Query],
45 tree: &'tree Tree,
46 text: &str,
47) -> Vec<BlockTree<'tree>> {
48 let mut blocks = get_blocks(queries, tree, text);
49
50 build_block_tree(&mut blocks, &mut tree.walk())
51}
52
53pub fn move_block<'tree>(
54 src_block: Block<'tree>,
55 dst_block: Block<'tree>,
56 text: &str,
57 assert_move_legal_fn: Option<impl Fn(&Block, &Block) -> Result<()>>,
58 force: bool,
59) -> Result<(String, usize, usize)> {
60 if !force {
61 if let Some(move_is_legal) = assert_move_legal_fn {
62 move_is_legal(&src_block, &dst_block).context("Illegal move operation")?;
63 }
64 }
65
66 let (src_head, src_tail) = src_block.head_tail();
67 let (dst_head, dst_tail) = dst_block.head_tail();
68 let src_block_range = src_block.byte_range();
69
70 if src_head.start_position() == dst_head.start_position() {
71 return Ok((
72 text.to_string(),
73 src_head.start_byte(),
74 dst_head.start_byte(),
75 ));
76 }
77
78 let mut new_text = text.to_string();
79
80 let src_text = &text[src_block_range.clone()];
81
82 let spaces = [
83 src_head
84 .prev_sibling()
85 .map(|s| &text[s.end_byte()..src_head.start_byte()]),
86 src_tail
87 .next_sibling()
88 .map(|s| &text[src_tail.end_byte()..s.start_byte()]),
89 dst_head
90 .prev_sibling()
91 .map(|s| &text[s.end_byte()..dst_head.start_byte()]),
92 dst_tail
93 .next_sibling()
94 .map(|s| &text[dst_tail.end_byte()..s.start_byte()]),
95 ];
96
97 let max_space = spaces
98 .into_iter()
99 .flatten()
100 .max_by(|s1, s2| s1.len().cmp(&s2.len()))
101 .unwrap_or_default();
102
103 let src_range = match (src_head.prev_sibling(), src_tail.next_sibling()) {
104 (Some(p), Some(n)) => {
105 let p_space = p.end_byte()..src_tail.end_byte();
106 let n_space = src_head.start_byte()..n.start_byte();
107
108 if p_space.len() >= n_space.len() {
109 p_space
110 } else {
111 n_space
112 }
113 }
114 (None, Some(n)) => src_head.start_byte()..n.start_byte(),
115 (Some(p), None) => p.end_byte()..src_tail.end_byte(),
116 (None, None) => src_block_range,
117 };
118
119 let src_range_len = src_range.end - src_range.start;
120
121 let (new_src_start, new_dst_start) = if src_head.end_byte() < dst_head.end_byte() {
124 new_text.insert_str(dst_tail.end_byte(), src_text);
125 new_text.insert_str(dst_tail.end_byte(), max_space);
126 new_text.replace_range(src_range, "");
127
128 if src_head.start_byte() < dst_head.start_byte() {
129 (
130 dst_tail.end_byte() + max_space.len() - src_range_len,
131 dst_head.start_byte() - src_range_len,
132 )
133 } else {
134 (
135 dst_tail.end_byte() + max_space.len() - src_range_len,
136 dst_head.start_byte(),
137 )
138 }
139 }
140 else {
142 new_text.replace_range(src_range, "");
143 new_text.insert_str(dst_tail.end_byte(), src_text);
144 new_text.insert_str(dst_tail.end_byte(), max_space);
145
146 (dst_tail.end_byte() + max_space.len(), dst_head.start_byte())
147 };
148
149 Ok((new_text, new_src_start, new_dst_start))
150}
151
152fn get_blocks<'tree>(queries: &[Query], tree: &'tree Tree, text: &str) -> Vec<Block<'tree>> {
153 let mut blocks = vec![];
154
155 for query in queries {
156 let mut query_cursor = QueryCursor::new();
157 let captures = query_cursor.captures(query, tree.root_node(), text.as_bytes());
158
159 for (q_match, index) in captures {
160 if index != 0 {
161 continue;
162 }
163
164 let mut block = vec![];
165 for capture in q_match.captures {
166 block.push((capture.index, capture.node));
167 }
168
169 block.sort_by(|(i1, _), (i2, _)| i1.cmp(i2));
170
171 let nodes = block.into_iter().map(|(_, n)| n).collect::<Vec<_>>();
172
173 blocks.push(Block { nodes });
174 }
175 }
176
177 blocks
178}
179
180fn build_block_tree<'tree>(
181 blocks: &mut Vec<Block<'tree>>,
182 cursor: &mut TreeCursor<'tree>,
183) -> Vec<BlockTree<'tree>> {
184 let node = cursor.node();
185 let mut trees = vec![];
186
187 if cursor.goto_first_child() {
188 let mut children = build_block_tree(blocks, cursor);
189
190 if let Some(index) = blocks.iter().position(|b| b.tail() == &node) {
191 let block = blocks.remove(index);
192 trees.push(BlockTree { block, children });
193 } else {
194 trees.append(&mut children);
195 }
196
197 cursor.goto_parent();
198 }
199
200 if cursor.goto_next_sibling() {
201 trees.append(&mut build_block_tree(blocks, cursor));
202 }
203
204 trees
205}