use crate::chunk::Chunk;
use crate::error::Result;
use crate::sizer::Sizer;
use std::str;
use tree_sitter::{Language, Node, Parser};
/// Default maximum size of a chunk, in units measured by the sizer.
const DEFAULT_MAX_SIZE: usize = 512;
/// Splits code into chunks along the node boundaries of its tree-sitter syntax tree.
pub struct Splitter<T: Sizer> {
/// Language of the code.
language: Language,
/// Sizer for counting the size of code chunks.
sizer: T,
/// Maximum size of a code chunk.
max_size: usize,
}
impl<T> Splitter<T>
where
T: Sizer,
{
/// Create a new `Splitter` that counts the size of code chunks with the given sizer.
///
/// # Example: split by characters
/// ```
/// use code_splitter::{CharCounter, Splitter};
///
/// let lang = tree_sitter_md::language();
/// let splitter = Splitter::new(lang, CharCounter).unwrap();
/// let chunks = splitter.split(b"hello, world!").unwrap();
/// ```
///
/// # Example: split by words
/// ```
/// use code_splitter::{Splitter, WordCounter};
///
/// let lang = tree_sitter_md::language();
/// let splitter = Splitter::new(lang, WordCounter).unwrap();
/// let chunks = splitter.split(b"hello, world!").unwrap();
/// ```
///
/// # Example: split by tokens with a Hugging Face tokenizer
/// ```
/// # #[cfg(feature = "tokenizers")]
/// # {
/// use code_splitter::Splitter;
/// use tokenizers::Tokenizer;
///
/// let lang = tree_sitter_md::language();
/// let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
/// let splitter = Splitter::new(lang, tokenizer).unwrap();
/// let chunks = splitter.split(b"hello, world!").unwrap();
/// # }
/// ```
///
/// # Example: split by tokens with tiktoken-rs `CoreBPE`
/// ```
/// # #[cfg(feature = "tiktoken-rs")]
/// # {
/// use code_splitter::Splitter;
/// use tiktoken_rs::cl100k_base;
///
/// let lang = tree_sitter_md::language();
/// let bpe = cl100k_base().unwrap();
/// let splitter = Splitter::new(lang, bpe).unwrap();
/// let chunks = splitter.split(b"hello, world!").unwrap();
/// # }
/// ```
pub fn new(language: Language, sizer: T) -> Result<Self> {
// Ensure the provided tree-sitter grammar can actually be loaded into a parser
Parser::new().set_language(&language)?;
Ok(Self {
language,
sizer,
max_size: DEFAULT_MAX_SIZE,
})
}
/// Set the maximum size of a chunk. The default is 512.
///
/// # Example: set the maximum size to 256
/// ```
/// use code_splitter::{CharCounter, Splitter};
///
/// let lang = tree_sitter_md::language();
/// let splitter = Splitter::new(lang, CharCounter)
/// .unwrap()
/// .with_max_size(256);
/// let chunks = splitter.split(b"hello, world!").unwrap();
/// ```
pub fn with_max_size(mut self, max_size: usize) -> Self {
self.max_size = max_size;
self
}
/// Split the code into chunks, each no larger than `max_size`.
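///
/// The example below mirrors the constructor examples above; the markdown grammar and
/// `CharCounter` are only illustrative choices of grammar and sizer.
///
/// # Example
/// ```
/// use code_splitter::{CharCounter, Splitter};
///
/// let lang = tree_sitter_md::language();
/// let splitter = Splitter::new(lang, CharCounter).unwrap();
/// let chunks = splitter.split(b"# Title\n\nhello, world!").unwrap();
/// assert!(!chunks.is_empty());
/// ```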
pub fn split(&self, code: &[u8]) -> Result<Vec<Chunk>> {
if code.is_empty() {
return Ok(vec![]);
}
let mut parser = Parser::new();
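// `new` already validated this language, so setting it again should not fail.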
parser
.set_language(&self.language)
.expect("Error loading tree-sitter language");
let tree = parser.parse(code, None).ok_or("Error parsing code")?;
let root_node = tree.root_node();
let chunks = self.split_node(&root_node, 0, code)?;
Ok(chunks)
}
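/// Recursively split `node`: emit the node as a single chunk if it fits within
/// `max_size`, otherwise descend into its children and greedily merge adjacent
/// child chunks that still fit together. A node larger than `max_size` with no
/// children yields no chunk.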
fn split_node(&self, node: &Node, depth: usize, code: &[u8]) -> Result<Vec<Chunk>> {
let text = node.utf8_text(code)?;
let chunk_size = self.sizer.size(text)?;
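// Skip nodes whose text the sizer measures as empty (e.g. whitespace-only nodes).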
if chunk_size == 0 {
return Ok(vec![]);
}
if chunk_size <= self.max_size {
return Ok(vec![Chunk {
subtree: format!("{}: {}", format_node(node, depth), chunk_size),
range: node.range(),
size: chunk_size,
}]);
}
let chunks = node
// Traverse the children in depth-first order
.children(&mut node.walk())
.map(|child| self.split_node(&child, depth + 1, code))
.collect::<Result<Vec<_>>>()?
.into_iter()
// Join the tail and head of neighboring chunks if possible
.try_fold(Vec::new(), |mut acc, mut next| -> Result<Vec<Chunk>> {
if let Some(tail) = acc.pop() {
if let Some(head) = next.first_mut() {
let joined_size = self.joined_size(&tail, head, code)?;
if joined_size <= self.max_size {
// Concatenate the tail and head subtree summaries
head.subtree = format!("{}\n{}", tail.subtree, head.subtree);
head.range.start_byte = tail.range.start_byte;
head.range.start_point = tail.range.start_point;
head.size = joined_size;
} else {
acc.push(tail);
}
} else {
// Push the tail back if next is empty
acc.push(tail);
}
}
acc.append(&mut next);
Ok(acc)
})?;
Ok(chunks)
}
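/// Size of the source text spanning from the start of `chunk` to the end of `next`,
/// as measured by the sizer (including any bytes between the two chunks).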
fn joined_size(&self, chunk: &Chunk, next: &Chunk, code: &[u8]) -> Result<usize> {
let joined_bytes = &code[chunk.range.start_byte..next.range.end_byte];
let joined_text = str::from_utf8(joined_bytes)?;
self.sizer.size(joined_text)
}
}
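/// Render a one-line summary of `node` for the chunk's `subtree` field: an indented
/// tree branch, the node kind, and its start and end row numbers.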
fn format_node(node: &Node, depth: usize) -> String {
format!(
"{indent}{branch} {kind:<32} [{start}..{end}]",
indent = "│ ".repeat(depth.saturating_sub(1)),
branch = if depth > 0 { "├─" } else { "" },
kind = node.kind(),
start = node.start_position().row,
end = node.end_position().row
)
}
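// A minimal sketch of unit tests for the splitter. It assumes `CharCounter` is
// re-exported at the crate root (as the doctests above use it) and that the
// `tree_sitter_md` grammar is available as a dev-dependency.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CharCounter;

    #[test]
    fn empty_input_yields_no_chunks() {
        let splitter = Splitter::new(tree_sitter_md::language(), CharCounter).unwrap();
        assert!(splitter.split(b"").unwrap().is_empty());
    }

    #[test]
    fn small_input_fits_in_a_single_chunk() {
        // "hello, world!" is well under the default max size of 512 characters,
        // so the root node is returned as a single chunk.
        let splitter = Splitter::new(tree_sitter_md::language(), CharCounter).unwrap();
        let chunks = splitter.split(b"hello, world!").unwrap();
        assert_eq!(chunks.len(), 1);
    }
}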