1use std::{cmp::Ordering, ops::Range};
2
3use thiserror::Error;
4use tree_sitter::{Language, LanguageError, Parser, TreeCursor, MIN_COMPATIBLE_LANGUAGE_VERSION};
5
6use crate::{
7 splitter::{SemanticLevel, Splitter},
8 trim::Trim,
9 ChunkConfig, ChunkSizer,
10};
11
12#[derive(Error, Debug)]
16#[error(transparent)]
17#[allow(clippy::module_name_repetitions)]
18pub struct CodeSplitterError(#[from] CodeSplitterErrorRepr);
19
20#[derive(Error, Debug)]
22enum CodeSplitterErrorRepr {
23 #[error(
24 "Language version {0:?} is too old. Expected at least version {min_version}",
25 min_version=MIN_COMPATIBLE_LANGUAGE_VERSION,
26 )]
27 LanguageError(LanguageError),
28}
29
30#[derive(Debug)]
34#[allow(clippy::module_name_repetitions)]
35pub struct CodeSplitter<Sizer>
36where
37 Sizer: ChunkSizer,
38{
39 chunk_config: ChunkConfig<Sizer>,
41 language: Language,
43}
44
45impl<Sizer> CodeSplitter<Sizer>
46where
47 Sizer: ChunkSizer,
48{
49 pub fn new(
63 language: impl Into<Language>,
64 chunk_config: impl Into<ChunkConfig<Sizer>>,
65 ) -> Result<Self, CodeSplitterError> {
66 let mut parser = Parser::new();
68 let language = language.into();
69 parser
70 .set_language(&language)
71 .map_err(CodeSplitterErrorRepr::LanguageError)?;
72 Ok(Self {
73 chunk_config: chunk_config.into(),
74 language,
75 })
76 }
77
78 pub fn chunks<'splitter, 'text: 'splitter>(
109 &'splitter self,
110 text: &'text str,
111 ) -> impl Iterator<Item = &'text str> + 'splitter {
112 Splitter::<_>::chunks(self, text)
113 }
114
115 pub fn chunk_indices<'splitter, 'text: 'splitter>(
129 &'splitter self,
130 text: &'text str,
131 ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
132 Splitter::<_>::chunk_indices(self, text)
133 }
134}
135
136impl<Sizer> Splitter<Sizer> for CodeSplitter<Sizer>
137where
138 Sizer: ChunkSizer,
139{
140 type Level = Depth;
141
142 const TRIM: Trim = Trim::PreserveIndentation;
143
144 fn chunk_config(&self) -> &ChunkConfig<Sizer> {
145 &self.chunk_config
146 }
147
148 fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
149 let mut parser = Parser::new();
150 parser
151 .set_language(&self.language)
152 .expect("Error loading language");
154 let tree = parser.parse(text, None).expect("Error parsing source code");
159
160 CursorOffsets::new(tree.walk()).collect()
161 }
162}
163
/// Depth of a node within a syntax tree, used as the splitter's level type.
///
/// Ordering is deliberately reversed relative to the wrapped number: a larger
/// depth compares as `Less`, so sorting places the deepest levels first.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Depth(usize);

impl PartialOrd for Depth {
    /// Always `Some`; defers to the total order defined by [`Ord`].
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}

impl Ord for Depth {
    /// Compares the wrapped depths and flips the result.
    fn cmp(&self, rhs: &Self) -> Ordering {
        self.0.cmp(&rhs.0).reverse()
    }
}
181
182struct CursorOffsets<'cursor> {
186 cursor: TreeCursor<'cursor>,
187}
188
189impl<'cursor> CursorOffsets<'cursor> {
190 fn new(cursor: TreeCursor<'cursor>) -> Self {
191 Self { cursor }
192 }
193}
194
195impl Iterator for CursorOffsets<'_> {
196 type Item = (Depth, Range<usize>);
197
198 fn next(&mut self) -> Option<Self::Item> {
199 if self.cursor.goto_first_child() {
201 return Some((
202 Depth(self.cursor.depth() as usize),
203 self.cursor.node().byte_range(),
204 ));
205 }
206
207 loop {
208 if self.cursor.goto_next_sibling() {
210 return Some((
211 Depth(self.cursor.depth() as usize),
212 self.cursor.node().byte_range(),
213 ));
214 } else if self.cursor.goto_parent() {
216 continue;
217 }
218
219 return None;
221 }
222 }
223}
224
225impl SemanticLevel for Depth {}
226
#[cfg(test)]
mod tests {
    use tree_sitter::{Node, Tree};

    use super::*;

    /// Parses `source` with the Rust grammar, panicking on any failure.
    fn parse_rust(source: &str) -> Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_rust::LANGUAGE.into())
            .expect("Error loading Rust grammar");
        parser
            .parse(source, None)
            .expect("Error parsing source code")
    }

    #[test]
    fn rust_splitter() {
        let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 16).unwrap();
        // Four spaces of indentation make the middle chunk exactly 16
        // characters long, matching the configured capacity.
        let text = "fn main() {\n    let x = 5;\n}";
        let chunks = splitter.chunks(text).collect::<Vec<_>>();

        assert_eq!(chunks, vec!["fn main()", "{\n    let x = 5;", "}"]);
    }

    #[test]
    fn rust_splitter_indices() {
        let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 16).unwrap();
        // With four spaces of indentation the closing brace sits at byte 27,
        // which is the offset asserted below.
        let text = "fn main() {\n    let x = 5;\n}";
        let chunks = splitter.chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(
            chunks,
            vec![(0, "fn main()"), (10, "{\n    let x = 5;"), (27, "}")]
        );
    }

    #[test]
    fn depth_partialord() {
        assert_eq!(Depth(0).partial_cmp(&Depth(1)), Some(Ordering::Greater));
        assert_eq!(Depth(1).partial_cmp(&Depth(2)), Some(Ordering::Greater));
        assert_eq!(Depth(1).partial_cmp(&Depth(1)), Some(Ordering::Equal));
        assert_eq!(Depth(2).partial_cmp(&Depth(1)), Some(Ordering::Less));
    }

    #[test]
    fn depth_ord() {
        assert_eq!(Depth(0).cmp(&Depth(1)), Ordering::Greater);
        assert_eq!(Depth(1).cmp(&Depth(2)), Ordering::Greater);
        assert_eq!(Depth(1).cmp(&Depth(1)), Ordering::Equal);
        assert_eq!(Depth(2).cmp(&Depth(1)), Ordering::Less);
    }

    #[test]
    fn depth_sorting() {
        let mut depths = vec![Depth(0), Depth(1), Depth(2)];
        depths.sort();
        assert_eq!(depths, [Depth(2), Depth(1), Depth(0)]);
    }

    /// The cursor-based walk must agree with the straightforward recursive
    /// traversal on a single function.
    #[test]
    fn optimized_code_offsets() {
        let tree = parse_rust("fn test() {\n    let x = 1;\n}");

        let offsets = CursorOffsets::new(tree.walk()).collect::<Vec<_>>();

        assert_eq!(offsets, naive_offsets(&tree));
    }

    /// Several top-level items exercise the sibling and parent steps of the
    /// cursor-based walk.
    #[test]
    fn multiple_top_siblings() {
        let tree = parse_rust("\nfn fn1() {}\nfn fn2() {}\nfn fn3() {}\nfn fn4() {}");

        let offsets = CursorOffsets::new(tree.walk()).collect::<Vec<_>>();

        assert_eq!(offsets, naive_offsets(&tree));
    }

    /// Reference implementation: collects offsets by recursing from the root.
    fn naive_offsets(tree: &Tree) -> Vec<(Depth, Range<usize>)> {
        let mut offsets = vec![];
        recursive_naive_offsets(&mut offsets, tree.root_node(), 0);
        offsets
    }

    /// Appends `node` (unless it is the root at depth 0) and then all of its
    /// descendants to `collection`, parents before children.
    fn recursive_naive_offsets(
        collection: &mut Vec<(Depth, Range<usize>)>,
        node: Node<'_>,
        depth: usize,
    ) {
        // The root is skipped, mirroring `CursorOffsets`, which never yields
        // the node its cursor starts on.
        if depth > 0 {
            collection.push((Depth(depth), node.byte_range()));
        }

        for child in node.children(&mut node.walk()) {
            recursive_naive_offsets(collection, child, depth + 1);
        }
    }
}