use std::{cmp::Ordering, ops::Range};
use thiserror::Error;
use tree_sitter::{Language, LanguageError, Parser, TreeCursor, MIN_COMPATIBLE_LANGUAGE_VERSION};
use crate::{
splitter::{SemanticLevel, Splitter},
trim::Trim,
ChunkConfig, ChunkSizer,
};
use super::ChunkCharIndex;
/// Indicates there was an error with creating a `CodeSplitter`.
///
/// Opaque-error pattern: wraps the private `CodeSplitterErrorRepr` enum so
/// the concrete variants stay out of the public API and can evolve freely.
#[derive(Error, Debug)]
#[error(transparent)]
pub struct CodeSplitterError(#[from] CodeSplitterErrorRepr);
/// Private error variants, surfaced to callers via the opaque
/// [`CodeSplitterError`] wrapper.
#[derive(Error, Debug)]
enum CodeSplitterErrorRepr {
    /// `Parser::set_language` rejected the grammar — typically because it
    /// was generated against an ABI older than the linked tree-sitter
    /// runtime supports (`MIN_COMPATIBLE_LANGUAGE_VERSION`).
    #[error(
        "Language version {0:?} is too old. Expected at least version {min_version}",
        min_version=MIN_COMPATIBLE_LANGUAGE_VERSION,
    )]
    LanguageError(LanguageError),
}
/// Splits source code into chunks, using a tree-sitter grammar to locate
/// semantic boundaries (syntax-tree node ranges) within the text.
#[derive(Debug)]
pub struct CodeSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Configuration controlling chunk sizing behavior.
    chunk_config: ChunkConfig<Sizer>,
    /// Tree-sitter grammar used to parse the text into a syntax tree.
    language: Language,
}
impl<Sizer> CodeSplitter<Sizer>
where
Sizer: ChunkSizer,
{
pub fn new(
language: impl Into<Language>,
chunk_config: impl Into<ChunkConfig<Sizer>>,
) -> Result<Self, CodeSplitterError> {
let mut parser = Parser::new();
let language = language.into();
parser
.set_language(&language)
.map_err(CodeSplitterErrorRepr::LanguageError)?;
Ok(Self {
chunk_config: chunk_config.into(),
language,
})
}
pub fn chunks<'splitter, 'text: 'splitter>(
&'splitter self,
text: &'text str,
) -> impl Iterator<Item = &'text str> + 'splitter {
Splitter::<_>::chunks(self, text)
}
pub fn chunk_indices<'splitter, 'text: 'splitter>(
&'splitter self,
text: &'text str,
) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
Splitter::<_>::chunk_indices(self, text)
}
pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
&'splitter self,
text: &'text str,
) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
Splitter::<_>::chunk_char_indices(self, text)
}
}
impl<Sizer> Splitter<Sizer> for CodeSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    type Level = Depth;

    // Keep leading whitespace when trimming chunks — indentation is
    // meaningful in source code.
    const TRIM: Trim = Trim::PreserveIndentation;

    fn chunk_config(&self) -> &ChunkConfig<Sizer> {
        &self.chunk_config
    }

    /// Parses `text` with the configured grammar and returns the depth and
    /// byte range of every syntax-tree node (root excluded), which the
    /// splitter consumes as candidate semantic boundaries.
    fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
        let mut parser = Parser::new();
        // The grammar was already validated in `new()`, so a failure here
        // would be a bug rather than a recoverable condition.
        parser
            .set_language(&self.language)
            .expect("Error loading language");
        let tree = parser.parse(text, None).expect("Error parsing source code");
        Vec::from_iter(CursorOffsets::new(tree.walk()))
    }
}
/// Depth of a node within the syntax tree, used as the semantic level for
/// splitting.
///
/// The ordering is intentionally *reversed*: a deeper node is a smaller
/// semantic unit, so `Depth(2) < Depth(1)`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Depth(usize);

impl PartialOrd for Depth {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Delegate to `Ord` so the two orderings can never disagree.
        Some(self.cmp(other))
    }
}

impl Ord for Depth {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reverse the natural ordering of the inner depth value.
        self.0.cmp(&other.0).reverse()
    }
}
/// Iterator over the depth and byte range of every node in a syntax tree,
/// driven by a tree-sitter `TreeCursor` walk.
struct CursorOffsets<'cursor> {
    /// Cursor holding the current position within the tree walk.
    cursor: TreeCursor<'cursor>,
}
impl<'cursor> CursorOffsets<'cursor> {
fn new(cursor: TreeCursor<'cursor>) -> Self {
Self { cursor }
}
}
impl Iterator for CursorOffsets<'_> {
    type Item = (Depth, Range<usize>);

    /// Advances the cursor one node in depth-first pre-order (the root
    /// itself is never yielded, since iteration starts by descending from
    /// it) and returns that node's depth and byte range.
    fn next(&mut self) -> Option<Self::Item> {
        // Try to descend first; otherwise move to the next sibling, popping
        // back up to ancestors until a sibling exists or the root is reached.
        let advanced = self.cursor.goto_first_child()
            || loop {
                if self.cursor.goto_next_sibling() {
                    break true;
                }
                if !self.cursor.goto_parent() {
                    // Back at the root with no siblings left: walk is done.
                    break false;
                }
            };
        advanced.then(|| {
            (
                Depth(self.cursor.depth() as usize),
                self.cursor.node().byte_range(),
            )
        })
    }
}
impl SemanticLevel for Depth {}
#[cfg(test)]
mod tests {
    use tree_sitter::{Node, Tree};

    use super::*;

    // The whole function exceeds the chunk size of 16, so it is split into
    // its child nodes; the body still fits as a single chunk.
    #[test]
    fn rust_splitter() {
        let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 16).unwrap();
        let text = "fn main() {\n    let x = 5;\n}";
        let chunks = splitter.chunks(text).collect::<Vec<_>>();

        assert_eq!(chunks, vec!["fn main()", "{\n    let x = 5;", "}"]);
    }

    // Same input as above, checking the byte offset reported for each chunk.
    #[test]
    fn rust_splitter_indices() {
        let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 16).unwrap();
        let text = "fn main() {\n    let x = 5;\n}";
        let chunks = splitter.chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(
            chunks,
            vec![(0, "fn main()"), (10, "{\n    let x = 5;"), (27, "}")]
        );
    }

    // ASCII-only input, so byte and char offsets are expected to coincide.
    #[test]
    fn rust_splitter_char_indices() {
        let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 16).unwrap();
        let text = "fn main() {\n    let x = 5;\n}";
        let chunks = splitter.chunk_char_indices(text).collect::<Vec<_>>();

        assert_eq!(
            chunks,
            vec![
                ChunkCharIndex {
                    chunk: "fn main()",
                    byte_offset: 0,
                    char_offset: 0
                },
                ChunkCharIndex {
                    chunk: "{\n    let x = 5;",
                    byte_offset: 10,
                    char_offset: 10
                },
                ChunkCharIndex {
                    chunk: "}",
                    byte_offset: 27,
                    char_offset: 27
                }
            ]
        );
    }

    // `Depth` ordering is deliberately reversed: smaller depth values
    // (shallower nodes) compare as Greater.
    #[test]
    fn depth_partialord() {
        assert_eq!(Depth(0).partial_cmp(&Depth(1)), Some(Ordering::Greater));
        assert_eq!(Depth(1).partial_cmp(&Depth(2)), Some(Ordering::Greater));
        assert_eq!(Depth(1).partial_cmp(&Depth(1)), Some(Ordering::Equal));
        assert_eq!(Depth(2).partial_cmp(&Depth(1)), Some(Ordering::Less));
    }

    // `Ord` must agree with `PartialOrd` on the reversed ordering.
    #[test]
    fn depth_ord() {
        assert_eq!(Depth(0).cmp(&Depth(1)), Ordering::Greater);
        assert_eq!(Depth(1).cmp(&Depth(2)), Ordering::Greater);
        assert_eq!(Depth(1).cmp(&Depth(1)), Ordering::Equal);
        assert_eq!(Depth(2).cmp(&Depth(1)), Ordering::Less);
    }

    // Sorting ascending therefore puts the deepest node first.
    #[test]
    fn depth_sorting() {
        let mut depths = vec![Depth(0), Depth(1), Depth(2)];
        depths.sort();
        assert_eq!(depths, [Depth(2), Depth(1), Depth(0)]);
    }

    // Cross-checks the iterative cursor traversal against a naive recursive
    // walk of the same tree.
    #[test]
    fn optimized_code_offsets() {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_rust::LANGUAGE.into())
            .expect("Error loading Rust grammar");
        let source_code = "fn test() {
    let x = 1;
}";
        let tree = parser
            .parse(source_code, None)
            .expect("Error parsing source code");

        let offsets = CursorOffsets::new(tree.walk()).collect::<Vec<_>>();

        assert_eq!(offsets, naive_offsets(&tree));
    }

    // Regression case: the cursor must visit every top-level sibling, not
    // just the first, before ending the walk.
    #[test]
    fn multiple_top_siblings() {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_rust::LANGUAGE.into())
            .expect("Error loading Rust grammar");
        let source_code = "
fn fn1() {}
fn fn2() {}
fn fn3() {}
fn fn4() {}";
        let tree = parser
            .parse(source_code, None)
            .expect("Error parsing source code");

        let offsets = CursorOffsets::new(tree.walk()).collect::<Vec<_>>();

        assert_eq!(offsets, naive_offsets(&tree));
    }

    // Reference implementation: collect (depth, byte range) for every node
    // by plain recursion.
    fn naive_offsets(tree: &Tree) -> Vec<(Depth, Range<usize>)> {
        let root_node = tree.root_node();
        let mut offsets = vec![];
        recursive_naive_offsets(&mut offsets, root_node, 0);
        offsets
    }

    fn recursive_naive_offsets(
        collection: &mut Vec<(Depth, Range<usize>)>,
        node: Node<'_>,
        depth: usize,
    ) {
        // Skip the root node (depth 0) to match `CursorOffsets`, which only
        // yields descendants.
        if depth > 0 {
            collection.push((Depth(depth), node.byte_range()));
        }

        for child in node.children(&mut node.walk()) {
            recursive_naive_offsets(collection, child, depth + 1);
        }
    }
}