use anyhow::{ensure, Result};
use code_blocks::{get_query_subtrees, Block, BlockTree};
use tree_sitter::{Point, Query};
#[path = "../test_utils.rs"]
mod test_utils;
use test_utils::build_tree;
const PYTHON_QUERY_STRINGS: [&str; 3] = [
"(class_definition) @item",
"(function_definition) @item",
"(decorated_definition) @item",
];
fn python_queries() -> [Query; 3] {
PYTHON_QUERY_STRINGS.map(|q| Query::new(tree_sitter_python::language(), q).unwrap())
}
#[test]
fn test_get_query_subtrees() {
let text = r#"
@decor1
@decor2
class A:
...
class C:
"""class docstring"""
def __init__(self):
pass
@staticmethod()
def foo():
"""method docstring"""
def bar():
...
def func():
...
a = lambda: 5
"#;
let tree = build_tree(text, tree_sitter_python::language());
let subtrees = get_query_subtrees(&python_queries(), &tree, text);
fn get_tree_blocks(subtree: &BlockTree, blocks: &mut Vec<String>, text: &str) {
blocks.push(text[subtree.block.byte_range()].to_string());
for child in &subtree.children {
get_tree_blocks(child, blocks, text);
}
}
let mut blocks = vec![];
for t in &subtrees {
get_tree_blocks(t, &mut blocks, text);
}
insta::assert_yaml_snapshot!(blocks);
}
fn copy_item_below<'tree>(
ident: &str,
text: &str,
trees: &Vec<BlockTree<'tree>>,
) -> Option<Block<'tree>> {
let pos = text
.lines()
.enumerate()
.find_map(|(row, line)| line.find(ident).map(|col| Point::new(row + 1, col - 1)))?;
for tree in trees {
if tree.block.tail().start_position() == pos {
return Some(tree.block.clone());
}
if let Some(node) = copy_item_below(ident, text, &tree.children) {
return Some(node);
}
}
None
}
macro_rules! check {
(check: $check_fn:expr, force: $force:literal, $text:literal) => {
let text = $text;
let force = false;
let tree = build_tree($text, tree_sitter_python::language());
let items = get_query_subtrees(&python_queries(), &tree, $text);
let src_block = copy_item_below("Vsrc", $text, &items).unwrap();
let dst_item = copy_item_below("Vdst", $text, &items);
let fail_item = copy_item_below("Vfail", $text, &items);
let snapshot = if let Some(dst_item) = dst_item {
let (new_text, mut new_src_start, mut new_dst_start) =
code_blocks::move_block(src_block, dst_item, text, Some(check_fn), force).unwrap();
let mut new_lines = vec![];
let mut added_src = false;
let mut added_dst = false;
for line in new_text.lines() {
new_lines.push(line.to_string());
if new_src_start > line.len() {
new_src_start -= line.len() + 1;
} else if !added_src {
new_lines.push(" ".repeat(new_src_start) + "^ Source");
added_src = true;
}
if new_dst_start > line.len() {
new_dst_start -= line.len() + 1;
} else if !added_dst {
new_lines.push(" ".repeat(new_dst_start) + "^ Dest");
added_dst = true;
}
}
let new_text = new_lines.join("\n");
format!("input:\n{}\n---\noutput:\n{}", text, new_text)
} else if let Some(fail_item) = fail_item {
let result = code_blocks::move_block(src_block, fail_item, $text, $check_fn, $force);
assert!(result.is_err());
format!("{}\n\n{:?}", $text, result.err().unwrap())
} else {
panic!("no dst/fail item in input");
};
insta::assert_display_snapshot!(snapshot);
};
}
fn check_fn(s: &Block, d: &Block) -> Result<()> {
ensure!(
s.head().parent() == d.head().parent(),
"Blocks have different parents"
);
Ok(())
}
#[test]
fn test_move_block() {
check!(
check: Some(check_fn),
force: false,
r#"
#Vsrc
@decor1
@decor2
class A:
...
#Vdst
class C:
"""class docstring"""
def __init__(self):
pass
@staticmethod()
def foo():
"""method docstring"""
def bar():
...
def func():
...
"#
);
check!(
check: Some(check_fn),
force: false,
r#"
#Vsrc
@decor1
@decor2
class A:
...
class C:
"""class docstring"""
#Vfail
def __init__(self):
pass
@staticmethod()
def foo():
"""method docstring"""
def bar():
...
def func():
...
"#
);
check!(
check: Some(check_fn),
force: false,
r#"
@decor1
@decor2
class A:
...
class C:
"""class docstring"""
#Vsrc
def __init__(self):
pass
#Vdst
@staticmethod()
def foo():
"""method docstring"""
def bar():
...
def func():
...
"#
);
check!(
check: Some(check_fn),
force: false,
r#"
@decor1
@decor2
class A:
...
class C:
"""class docstring"""
def __init__(self):
pass
#Vsrc
@staticmethod()
def foo():
"""method docstring"""
def bar():
...
#Vfail
def func():
...
"#
);
check!(
check: Some(check_fn),
force: false,
r#"
@decor1
@decor2
class A:
...
class C:
"""class docstring"""
def __init__(self):
pass
#Vsrc
@staticmethod()
def foo():
"""method docstring"""
#Vdst
def bar():
...
def func():
...
"#
);
}