1mod context_splitter;
8pub mod entity_splitter;
9mod line_spliter;
10
11#[cfg(test)]
12#[path = "./splitter/test_java.rs"]
13mod test_java;
14#[cfg(test)]
15#[path = "./splitter/test_python.rs"]
16mod test_python;
17#[cfg(test)]
18#[path = "./splitter/test_rust.rs"]
19mod test_rust;
20#[cfg(test)]
21#[path = "./splitter/test_solidity.rs"]
22mod test_solidity;
23#[cfg(test)]
24#[path = "./splitter/test_ts.rs"]
25mod test_ts;
26
27use crate::{
28 lang::{
29 Lang,
30 LangConfig,
31 },
32 Chunk,
33 Entity,
34 EntityType,
35 SplitOptions,
36};
37use anyhow::Result;
38use std::{
39 collections::{
40 BTreeMap,
41 HashMap,
42 },
43 ops::Range,
44};
45use tree_sitter::{
46 Node,
47 Parser,
48 Query,
49 QueryCursor,
50 Tree,
51};
52
53#[derive(Debug, Clone, PartialEq)]
55pub struct CodeEntity {
56 pub parent_name: Option<String>,
58 pub name: String,
60 pub interface_names: Vec<String>,
62 pub comment_line_range: Option<Range<usize>>,
64 pub body_line_range: Range<usize>,
66 pub entity_type: EntityType,
68 pub comment_byte_range: Option<Range<usize>>,
70 pub body_byte_range: Range<usize>,
72 pub parent_line_range: Option<Range<usize>>,
74}
75
76#[derive(Debug, Clone, PartialEq, Default)]
77pub struct CodeChunk {
78 pub line_range: Range<usize>,
79 pub entities: Vec<CodeEntity>,
82}
83
/// Location of one query capture (or of several same-named captures merged
/// together) inside the parsed source.
#[derive(Debug, Clone)]
pub struct EntityNode {
    /// Byte offsets covered by the captured node(s).
    pub byte_range: Range<usize>,
    /// Rows covered by the captured node(s): `start` is the row of the node's
    /// start position and `end` is the row of its end position, as reported by
    /// tree-sitter.
    pub line_range: Range<usize>,
}
91
92fn parse_capture_for_entity<'a>(
93 lang_config: &LangConfig,
94 code: &'a str,
95 tree: &'a Tree,
96) -> Result<Vec<(HashMap<String, EntityNode>, Vec<Node<'a>>)>> {
97 let query = Query::new(&(lang_config.grammar)(), lang_config.query)?;
98 let mut query_cursor = QueryCursor::new();
99 let matches = query_cursor.matches(&query, tree.root_node(), code.as_bytes());
100 let mut entity_captures_map: BTreeMap<usize, (HashMap<String, EntityNode>, Vec<Node>)> =
103 BTreeMap::new();
104 for m in matches {
105 let mut captures: HashMap<String, EntityNode> = HashMap::new();
106 let mut parent_captures: HashMap<String, EntityNode> = HashMap::new();
107 let mut nodes = vec![];
108 let mut definition_start = 0;
109 for c in m.captures {
110 let capture_name = query.capture_names()[c.index as usize];
111 if capture_name.contains("class") || capture_name.contains("interface") {
116 parent_captures.insert(
117 capture_name.to_string(),
118 EntityNode {
119 byte_range: c.node.byte_range(),
120 line_range: c.node.start_position().row..c.node.end_position().row,
121 },
122 );
123 continue;
124 }
125 if let Some(existing_node) = captures.get_mut(capture_name) {
128 existing_node.byte_range = Range {
129 start: existing_node
130 .byte_range
131 .start
132 .min(c.node.byte_range().start),
133 end: existing_node.byte_range.end.max(c.node.byte_range().end),
134 };
135 existing_node.line_range = Range {
136 start: existing_node
137 .line_range
138 .start
139 .min(c.node.start_position().row),
140 end: existing_node.line_range.end.max(c.node.end_position().row),
141 };
142 } else {
143 captures.insert(
144 capture_name.to_string(),
145 EntityNode {
146 byte_range: c.node.byte_range(),
147 line_range: c.node.start_position().row..c.node.end_position().row,
148 },
149 );
150 }
151
152 if capture_name.ends_with(".definition") {
154 definition_start = c.node.byte_range().start;
155 }
156
157 if capture_name.ends_with(".name") {
159 parent_captures.iter().for_each(|(k, v)| {
161 captures.insert(k.clone(), v.clone());
162 });
163 entity_captures_map.insert(definition_start, (captures.clone(), nodes));
164 captures = HashMap::new();
166 nodes = vec![];
167 } else {
168 nodes.push(c.node);
169 }
170 }
171 }
172 Ok(entity_captures_map
173 .iter()
174 .map(|(_start, (captures, nodes))| (captures.clone(), nodes.clone()))
175 .collect::<Vec<(HashMap<String, EntityNode>, Vec<Node>)>>())
176}
177
178pub fn split(filename: &str, code: &str, options: &SplitOptions) -> Result<Vec<Chunk>> {
205 let Some(lang_config) = Lang::from_filename(filename) else {
206 return Err(anyhow::anyhow!("Unsupported language"));
207 };
208 let lines = code.lines().collect::<Vec<&str>>();
209 let mut parser = Parser::new();
210 parser.set_language(&(lang_config.grammar)())?;
211 let tree = parser
212 .parse(code, None)
213 .ok_or(anyhow::anyhow!("Failed to parse code"))?;
214 if lang_config.query.is_empty() {
215 return line_spliter::split_tree_node(
216 &lines,
217 &tree.root_node(),
218 options.chunk_line_limit,
219 options.chunk_line_limit / 2,
220 );
221 }
222 let captures = parse_capture_for_entity(&lang_config, code, &tree)?;
223 if captures.is_empty() {
224 return line_spliter::split_tree_node(
225 &lines,
226 &tree.root_node(),
227 options.chunk_line_limit,
228 options.chunk_line_limit / 2,
229 );
230 }
231 let entities = captures
232 .iter()
233 .filter_map(|(captures, nodes)| {
234 match context_splitter::convert_node_to_code_entity(captures, code) {
235 Ok(entity) => Some((entity, nodes.to_vec())),
236 Err(_e) => None,
237 }
238 })
239 .collect::<Vec<(CodeEntity, Vec<Node>)>>();
240 let chunks = context_splitter::merge_code_entities(code, &entities, options)?;
241 Ok(chunks
242 .iter()
243 .map(|code_chunk| {
244 let entities = code_chunk
245 .entities
246 .iter()
247 .map(|entity| {
248 let chunk_line_range = Range {
249 start: code_chunk
250 .line_range
251 .start
252 .max(entity.body_line_range.start),
253 end: code_chunk.line_range.end.min(entity.body_line_range.end),
254 };
255 Entity {
256 name: entity.name.clone(),
257 entity_type: entity.entity_type.clone(),
258 parent: entity.parent_name.clone(),
259 completed_line_range: entity.body_line_range.clone(),
260 chunk_line_range,
261 parent_line_range: entity.parent_line_range.clone(),
262 }
263 })
264 .collect::<Vec<Entity>>();
265 let chunk = Chunk {
266 line_range: code_chunk.line_range.clone(),
267 entities,
268 };
269 chunk
270 })
271 .collect::<Vec<Chunk>>())
272}
273
/// Test helper shared by the per-language test modules.
///
/// Parses `code` with the grammar selected by `filename`, extracts the entity
/// captures and checks selected captures against expected line ranges.
///
/// * `capture_names` - pairs of (index of the entity in the capture list,
///   capture name to look up in that entity).
/// * `line_ranges` - expected `line_range` for each pair, in the same order.
///
/// Panics when the language is unknown, parsing fails, an index or capture
/// name is missing, or a line range differs from the expected one.
#[cfg(test)]
fn run_test_case(
    filename: &str,
    code: &str,
    capture_names: Vec<(usize, &str)>,
    line_ranges: Vec<Range<usize>>,
) {
    let config = Lang::from_filename(filename).unwrap();
    let mut ts_parser = Parser::new();
    ts_parser.set_language(&(config.grammar)()).unwrap();
    let parsed = ts_parser.parse(code, None);
    let syntax_tree = parsed
        .ok_or(anyhow::anyhow!("Failed to parse code"))
        .unwrap();
    let captures = parse_capture_for_entity(&config, code, &syntax_tree).unwrap();
    println!("captures: {:?}", captures);
    for (position, &(entity_index, capture_name)) in capture_names.iter().enumerate() {
        let actual = captures[entity_index].0.get(capture_name).unwrap();
        assert_eq!(
            actual.line_range, line_ranges[position],
            "capture_name: {}",
            capture_name
        );
    }
}
300
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;

    /// Smoke test: splitting a small Rust source with a 5-line chunk limit
    /// succeeds. The resulting chunks are only printed, not inspected.
    #[rstest]
    fn test_rust_split_demo() {
        let code = r#"
fn main() {
    println!("Hello, world!");
}

struct Test {
    a: i32,
    b: i32,
}

impl Test {
    fn test() {
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        println!("Hello, world!");
    }


    fn test_rust_split_2() {
        println!("test_rust_split_2");
    }
}
"#;
        let split_options = SplitOptions {
            chunk_line_limit: 5,
        };
        let outcome = split("test.rs", code, &split_options);
        assert_eq!(outcome.is_ok(), true);
        outcome
            .unwrap()
            .iter()
            .for_each(|chunk| println!("chunk: {:?}", chunk));
    }
}