devgen_splitter/
splitter.rs

1//
2// splitter.rs
3// Copyright (C) 2024 imotai <codego.me@gmail.com>
4// Distributed under terms of the MIT license.
5//
6
7mod context_splitter;
8pub mod entity_splitter;
9mod line_spliter;
10
11#[cfg(test)]
12#[path = "./splitter/test_java.rs"]
13mod test_java;
14#[cfg(test)]
15#[path = "./splitter/test_python.rs"]
16mod test_python;
17#[cfg(test)]
18#[path = "./splitter/test_rust.rs"]
19mod test_rust;
20#[cfg(test)]
21#[path = "./splitter/test_solidity.rs"]
22mod test_solidity;
23#[cfg(test)]
24#[path = "./splitter/test_ts.rs"]
25mod test_ts;
26
27use crate::{
28    lang::{
29        Lang,
30        LangConfig,
31    },
32    Chunk,
33    Entity,
34    EntityType,
35    SplitOptions,
36};
37use anyhow::Result;
38use std::{
39    collections::{
40        BTreeMap,
41        HashMap,
42    },
43    ops::Range,
44};
45use tree_sitter::{
46    Node,
47    Parser,
48    Query,
49    QueryCursor,
50    Tree,
51};
52
/// Represents a code entity with associated metadata.
///
/// In this file, values of this type are obtained from
/// `context_splitter::convert_node_to_code_entity` and are later mapped by `split`
/// onto the public `Entity` type.
#[derive(Debug, Clone, PartialEq)]
pub struct CodeEntity {
    /// Name of the parent entity (e.g., the class name for a method), if there is one.
    /// Exposed as `Entity::parent` by `split`.
    pub parent_name: Option<String>,
    /// Name of the entity (e.g., function name, class name).
    pub name: String,
    /// Names of interfaces or traits implemented by this entity.
    pub interface_names: Vec<String>,
    /// Range of lines containing the entity's documentation comments, if any.
    pub comment_line_range: Option<Range<usize>>,
    /// Range of lines containing the entity's actual code body.
    /// Exposed as `Entity::completed_line_range` by `split`, and clamped against the
    /// owning chunk's line range to compute `Entity::chunk_line_range`.
    pub body_line_range: Range<usize>,
    /// Type of the entity (e.g., Class, Function, Interface, Method).
    pub entity_type: EntityType,
    /// Byte range of the documentation comment, if any.
    pub comment_byte_range: Option<Range<usize>>,
    /// Byte range of the body.
    pub body_byte_range: Range<usize>,
    /// Line range of the parent entity, if there is one.
    pub parent_line_range: Option<Range<usize>>,
}
75
/// A range of lines together with the code entities associated with it.
///
/// `split` converts each `CodeChunk` returned by `context_splitter::merge_code_entities`
/// into a public `Chunk`, copying `line_range` and mapping every entry of `entities`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct CodeChunk {
    /// Line range covered by the chunk.
    pub line_range: Range<usize>,
    /// Entities in the chunk. An entity's body may extend beyond `line_range`; `split`
    /// clamps the body range to `line_range` when reporting the per-chunk range.
    pub entities: Vec<CodeEntity>,
}
83
/// Byte and line span of a captured syntax node.
///
/// When the same capture name occurs several times within one query match, the spans
/// are merged into a single `EntityNode` covering all of them.
#[derive(Debug, Clone)]
pub struct EntityNode {
    /// The byte range of the node.
    pub byte_range: Range<usize>,
    /// The line range of the node, built from the node's start row and end row.
    /// Unlike the usual half-open `Range` convention, the end row is included.
    pub line_range: Range<usize>,
}
91
92fn parse_capture_for_entity<'a>(
93    lang_config: &LangConfig,
94    code: &'a str,
95    tree: &'a Tree,
96) -> Result<Vec<(HashMap<String, EntityNode>, Vec<Node<'a>>)>> {
97    let query = Query::new(&(lang_config.grammar)(), lang_config.query)?;
98    let mut query_cursor = QueryCursor::new();
99    let matches = query_cursor.matches(&query, tree.root_node(), code.as_bytes());
100    // only the method, function, struct, enum will be pushed to entity_captures_map
101    // Note: if the method and function has the same location, only the method will be captured
102    let mut entity_captures_map: BTreeMap<usize, (HashMap<String, EntityNode>, Vec<Node>)> =
103        BTreeMap::new();
104    for m in matches {
105        let mut captures: HashMap<String, EntityNode> = HashMap::new();
106        let mut parent_captures: HashMap<String, EntityNode> = HashMap::new();
107        let mut nodes = vec![];
108        let mut definition_start = 0;
109        for c in m.captures {
110            let capture_name = query.capture_names()[c.index as usize];
111            // handle the parent capture. current the list of parent capture
112            // 1. class.definition
113            // 2. method.class.name
114            // 3. method.interface.name
115            if capture_name.contains("class") || capture_name.contains("interface") {
116                parent_captures.insert(
117                    capture_name.to_string(),
118                    EntityNode {
119                        byte_range: c.node.byte_range(),
120                        line_range: c.node.start_position().row..c.node.end_position().row,
121                    },
122                );
123                continue;
124            }
125            // handle the multi times for the same capture name
126            // the line comment and block comment will be merged
127            if let Some(existing_node) = captures.get_mut(capture_name) {
128                existing_node.byte_range = Range {
129                    start: existing_node
130                        .byte_range
131                        .start
132                        .min(c.node.byte_range().start),
133                    end: existing_node.byte_range.end.max(c.node.byte_range().end),
134                };
135                existing_node.line_range = Range {
136                    start: existing_node
137                        .line_range
138                        .start
139                        .min(c.node.start_position().row),
140                    end: existing_node.line_range.end.max(c.node.end_position().row),
141                };
142            } else {
143                captures.insert(
144                    capture_name.to_string(),
145                    EntityNode {
146                        byte_range: c.node.byte_range(),
147                        line_range: c.node.start_position().row..c.node.end_position().row,
148                    },
149                );
150            }
151
152            // handle the all the definition
153            if capture_name.ends_with(".definition") {
154                definition_start = c.node.byte_range().start;
155            }
156
157            // handle the name node match
158            if capture_name.ends_with(".name") {
159                // copy the parent capture to the captures
160                parent_captures.iter().for_each(|(k, v)| {
161                    captures.insert(k.clone(), v.clone());
162                });
163                entity_captures_map.insert(definition_start, (captures.clone(), nodes));
164                // reset the captures and nodes
165                captures = HashMap::new();
166                nodes = vec![];
167            } else {
168                nodes.push(c.node);
169            }
170        }
171    }
172    Ok(entity_captures_map
173        .iter()
174        .map(|(_start, (captures, nodes))| (captures.clone(), nodes.clone()))
175        .collect::<Vec<(HashMap<String, EntityNode>, Vec<Node>)>>())
176}
177
178/// Splits the given code into chunks based on the provided options.
179///
180/// # Arguments
181///
182/// * `filename` - The name of the file containing the code.
183/// * `code` - The source code to be split.
184/// * `options` - The options for splitting the code.
185///
186/// # Returns
187///
188/// A `Result` containing a vector of `Chunk`s if successful, or an error if parsing fails.
189///
190/// # Example
191///
192/// ```
193/// use devgen_splitter::{
194///     split,
195///     SplitOptions,
196/// };
197///
198/// let code = "fn main() { println!(\"Hello, world!\"); }";
199/// let options = SplitOptions {
200///     chunk_line_limit: 5,
201/// };
202/// let chunks = split("example.rs", code, &options).unwrap();
203/// ```
204pub fn split(filename: &str, code: &str, options: &SplitOptions) -> Result<Vec<Chunk>> {
205    let Some(lang_config) = Lang::from_filename(filename) else {
206        return Err(anyhow::anyhow!("Unsupported language"));
207    };
208    let lines = code.lines().collect::<Vec<&str>>();
209    let mut parser = Parser::new();
210    parser.set_language(&(lang_config.grammar)())?;
211    let tree = parser
212        .parse(code, None)
213        .ok_or(anyhow::anyhow!("Failed to parse code"))?;
214    if lang_config.query.is_empty() {
215        return line_spliter::split_tree_node(
216            &lines,
217            &tree.root_node(),
218            options.chunk_line_limit,
219            options.chunk_line_limit / 2,
220        );
221    }
222    let captures = parse_capture_for_entity(&lang_config, code, &tree)?;
223    if captures.is_empty() {
224        return line_spliter::split_tree_node(
225            &lines,
226            &tree.root_node(),
227            options.chunk_line_limit,
228            options.chunk_line_limit / 2,
229        );
230    }
231    let entities = captures
232        .iter()
233        .filter_map(|(captures, nodes)| {
234            match context_splitter::convert_node_to_code_entity(captures, code) {
235                Ok(entity) => Some((entity, nodes.to_vec())),
236                Err(_e) => None,
237            }
238        })
239        .collect::<Vec<(CodeEntity, Vec<Node>)>>();
240    let chunks = context_splitter::merge_code_entities(code, &entities, options)?;
241    Ok(chunks
242        .iter()
243        .map(|code_chunk| {
244            let entities = code_chunk
245                .entities
246                .iter()
247                .map(|entity| {
248                    let chunk_line_range = Range {
249                        start: code_chunk
250                            .line_range
251                            .start
252                            .max(entity.body_line_range.start),
253                        end: code_chunk.line_range.end.min(entity.body_line_range.end),
254                    };
255                    Entity {
256                        name: entity.name.clone(),
257                        entity_type: entity.entity_type.clone(),
258                        parent: entity.parent_name.clone(),
259                        completed_line_range: entity.body_line_range.clone(),
260                        chunk_line_range,
261                        parent_line_range: entity.parent_line_range.clone(),
262                    }
263                })
264                .collect::<Vec<Entity>>();
265            let chunk = Chunk {
266                line_range: code_chunk.line_range.clone(),
267                entities,
268            };
269            chunk
270        })
271        .collect::<Vec<Chunk>>())
272}
273
/// Test helper: parses `code`, extracts the entity captures and checks their line ranges.
///
/// `capture_names[i]` is an `(entity index, capture name)` pair selecting one capture of
/// one extracted entity; its line range must equal `line_ranges[i]`.
///
/// Panics if the language is unknown, parsing or capture extraction fails, an index or
/// capture name does not exist, or a line range differs from the expected one.
#[cfg(test)]
fn run_test_case(
    filename: &str,
    code: &str,
    capture_names: Vec<(usize, &str)>,
    line_ranges: Vec<Range<usize>>,
) {
    let lang_config = Lang::from_filename(filename).unwrap();
    let mut parser = Parser::new();
    parser.set_language(&(lang_config.grammar)()).unwrap();
    let parsed = parser.parse(code, None);
    let tree = parsed
        .ok_or(anyhow::anyhow!("Failed to parse code"))
        .unwrap();
    let captures = parse_capture_for_entity(&lang_config, code, &tree).unwrap();
    println!("captures: {:?}", captures);
    for (position, (entity_index, capture_name)) in capture_names.iter().enumerate() {
        let (entity_captures, _nodes) = &captures[*entity_index];
        let actual = entity_captures.get(*capture_name).unwrap();
        let expected = line_ranges[position].clone();
        assert_eq!(
            actual.line_range, expected,
            "capture_name: {}",
            capture_name
        );
    }
}
300
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;

    /// Smoke test: splitting a small Rust file with a function, a struct and an `impl`
    /// block must succeed. The resulting chunks are printed for manual inspection.
    #[rstest]
    fn test_rust_split_demo() {
        let code = r#"
fn main() { 
    println!("Hello, world!");
}

struct Test {
    a: i32,
    b: i32,
}

impl Test {
    fn test() {
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        for i in 0..10 {
            println!("i: {}", i);
        }
        println!("Hello, world!");
    }


    fn test_rust_split_2() {
        println!("test_rust_split_2");
    }
}
"#;
        let options = SplitOptions {
            chunk_line_limit: 5,
        };
        // `expect` reports the underlying error when splitting fails, which a boolean
        // comparison on `is_ok()` would hide.
        let chunks =
            split("test.rs", code, &options).expect("splitting valid Rust code should succeed");
        for chunk in &chunks {
            println!("chunk: {:?}", chunk);
        }
    }
}
363}