code_blocks/
lib.rs

1use std::ops::Range;
2
3use anyhow::{anyhow, Context, Result};
4use tree_sitter::{Node, Query, QueryCursor, Tree, TreeCursor};
5
6#[derive(PartialEq, Eq, Clone, Hash, Debug)]
7pub struct BlockTree<'tree> {
8    pub block: Block<'tree>,
9    pub children: Vec<BlockTree<'tree>>,
10}
11
12#[derive(PartialEq, Eq, Clone, Hash, Debug)]
13pub struct Block<'tree> {
14    nodes: Vec<Node<'tree>>,
15}
16
17impl<'tree> Block<'tree> {
18    pub fn new(nodes: Vec<Node<'tree>>) -> Result<Self> {
19        if nodes.is_empty() {
20            Err(anyhow!("Can't create block empty nodes vec"))
21        } else {
22            Ok(Self { nodes })
23        }
24    }
25
26    pub fn head(&self) -> &Node<'tree> {
27        self.nodes.first().expect("empty nodes vec")
28    }
29
30    pub fn tail(&self) -> &Node<'tree> {
31        self.nodes.last().expect("empty nodes vec")
32    }
33
34    pub fn head_tail(&self) -> (&Node<'tree>, &Node<'tree>) {
35        (self.head(), self.tail())
36    }
37
38    pub fn byte_range(&self) -> Range<usize> {
39        self.head().start_byte()..self.tail().end_byte()
40    }
41}
42
43pub fn get_query_subtrees<'tree>(
44    queries: &[Query],
45    tree: &'tree Tree,
46    text: &str,
47) -> Vec<BlockTree<'tree>> {
48    let mut blocks = get_blocks(queries, tree, text);
49
50    build_block_tree(&mut blocks, &mut tree.walk())
51}
52
53pub fn move_block<'tree>(
54    src_block: Block<'tree>,
55    dst_block: Block<'tree>,
56    text: &str,
57    assert_move_legal_fn: Option<impl Fn(&Block, &Block) -> Result<()>>,
58    force: bool,
59) -> Result<(String, usize, usize)> {
60    if !force {
61        if let Some(move_is_legal) = assert_move_legal_fn {
62            move_is_legal(&src_block, &dst_block).context("Illegal move operation")?;
63        }
64    }
65
66    let (src_head, src_tail) = src_block.head_tail();
67    let (dst_head, dst_tail) = dst_block.head_tail();
68    let src_block_range = src_block.byte_range();
69
70    if src_head.start_position() == dst_head.start_position() {
71        return Ok((
72            text.to_string(),
73            src_head.start_byte(),
74            dst_head.start_byte(),
75        ));
76    }
77
78    let mut new_text = text.to_string();
79
80    let src_text = &text[src_block_range.clone()];
81
82    let spaces = [
83        src_head
84            .prev_sibling()
85            .map(|s| &text[s.end_byte()..src_head.start_byte()]),
86        src_tail
87            .next_sibling()
88            .map(|s| &text[src_tail.end_byte()..s.start_byte()]),
89        dst_head
90            .prev_sibling()
91            .map(|s| &text[s.end_byte()..dst_head.start_byte()]),
92        dst_tail
93            .next_sibling()
94            .map(|s| &text[dst_tail.end_byte()..s.start_byte()]),
95    ];
96
97    let max_space = spaces
98        .into_iter()
99        .flatten()
100        .max_by(|s1, s2| s1.len().cmp(&s2.len()))
101        .unwrap_or_default();
102
103    let src_range = match (src_head.prev_sibling(), src_tail.next_sibling()) {
104        (Some(p), Some(n)) => {
105            let p_space = p.end_byte()..src_tail.end_byte();
106            let n_space = src_head.start_byte()..n.start_byte();
107
108            if p_space.len() >= n_space.len() {
109                p_space
110            } else {
111                n_space
112            }
113        }
114        (None, Some(n)) => src_head.start_byte()..n.start_byte(),
115        (Some(p), None) => p.end_byte()..src_tail.end_byte(),
116        (None, None) => src_block_range,
117    };
118
119    let src_range_len = src_range.end - src_range.start;
120
121    // move src to be below dst
122    // move down
123    let (new_src_start, new_dst_start) = if src_head.end_byte() < dst_head.end_byte() {
124        new_text.insert_str(dst_tail.end_byte(), src_text);
125        new_text.insert_str(dst_tail.end_byte(), max_space);
126        new_text.replace_range(src_range, "");
127
128        if src_head.start_byte() < dst_head.start_byte() {
129            (
130                dst_tail.end_byte() + max_space.len() - src_range_len,
131                dst_head.start_byte() - src_range_len,
132            )
133        } else {
134            (
135                dst_tail.end_byte() + max_space.len() - src_range_len,
136                dst_head.start_byte(),
137            )
138        }
139    }
140    // move up
141    else {
142        new_text.replace_range(src_range, "");
143        new_text.insert_str(dst_tail.end_byte(), src_text);
144        new_text.insert_str(dst_tail.end_byte(), max_space);
145
146        (dst_tail.end_byte() + max_space.len(), dst_head.start_byte())
147    };
148
149    Ok((new_text, new_src_start, new_dst_start))
150}
151
152fn get_blocks<'tree>(queries: &[Query], tree: &'tree Tree, text: &str) -> Vec<Block<'tree>> {
153    let mut blocks = vec![];
154
155    for query in queries {
156        let mut query_cursor = QueryCursor::new();
157        let captures = query_cursor.captures(query, tree.root_node(), text.as_bytes());
158
159        for (q_match, index) in captures {
160            if index != 0 {
161                continue;
162            }
163
164            let mut block = vec![];
165            for capture in q_match.captures {
166                block.push((capture.index, capture.node));
167            }
168
169            block.sort_by(|(i1, _), (i2, _)| i1.cmp(i2));
170
171            let nodes = block.into_iter().map(|(_, n)| n).collect::<Vec<_>>();
172
173            blocks.push(Block { nodes });
174        }
175    }
176
177    blocks
178}
179
180fn build_block_tree<'tree>(
181    blocks: &mut Vec<Block<'tree>>,
182    cursor: &mut TreeCursor<'tree>,
183) -> Vec<BlockTree<'tree>> {
184    let node = cursor.node();
185    let mut trees = vec![];
186
187    if cursor.goto_first_child() {
188        let mut children = build_block_tree(blocks, cursor);
189
190        if let Some(index) = blocks.iter().position(|b| b.tail() == &node) {
191            let block = blocks.remove(index);
192            trees.push(BlockTree { block, children });
193        } else {
194            trees.append(&mut children);
195        }
196
197        cursor.goto_parent();
198    }
199
200    if cursor.goto_next_sibling() {
201        trees.append(&mut build_block_tree(blocks, cursor));
202    }
203
204    trees
205}