1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
use crate::chunk::Chunk;
use crate::error::Result;
use crate::sizer::Sizer;

use std::str;
use tree_sitter::{Language, Node, Parser};

/// Default maximum size of a chunk.
///
/// `Splitter::new` starts every splitter with this limit; it can be replaced
/// through `Splitter::with_max_size`. The value is compared against whatever
/// the splitter's `Sizer` reports, so its unit is the sizer's unit.
const DEFAULT_MAX_SIZE: usize = 512;

/// Splits source code into size-bounded chunks by walking its syntax tree.
pub struct Splitter<T>
where
    T: Sizer,
{
    /// Grammar handed to the parser when code is split.
    language: Language,
    /// Measures how large a piece of code is.
    sizer: T,
    /// Upper bound on the measured size of a single chunk.
    max_size: usize,
}

impl<T> Splitter<T>
where
    T: Sizer,
{
    /// Build a `Splitter` for `language` that measures code chunks with `sizer`.
    ///
    /// The maximum chunk size starts at the default of 512; call
    /// `with_max_size` to pick a different limit.
    ///
    /// # Errors
    ///
    /// Returns an error if a freshly created parser refuses `language`.
    ///
    /// # Example: split by characters
    /// ```
    /// use code_splitter::{CharCounter, Splitter};
    ///
    /// let lang = tree_sitter_md::language();
    /// let splitter = Splitter::new(lang, CharCounter).unwrap();
    /// let chunks = splitter.split(b"hello, world!").unwrap();
    /// ```
    ///
    /// # Example: split by words
    /// ```
    /// use code_splitter::{Splitter, WordCounter};
    ///
    /// let lang = tree_sitter_md::language();
    /// let splitter = Splitter::new(lang, WordCounter).unwrap();
    /// let chunks = splitter.split(b"hello, world!").unwrap();
    /// ```
    ///
    /// # Example: split by tokens with huggingface tokenizer
    /// ```
    /// # #[cfg(feature = "tokenizers")]
    /// # {
    /// use code_splitter::Splitter;
    /// use tokenizers::Tokenizer;
    ///
    /// let lang = tree_sitter_md::language();
    /// let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
    /// let splitter = Splitter::new(lang, tokenizer).unwrap();
    /// let chunks = splitter.split(b"hello, world!").unwrap();
    /// # }
    /// ```
    ///
    /// # Example: split by tokens with tiktoken core BPE
    /// ```
    /// # #[cfg(feature = "tiktoken-rs")]
    /// # {
    /// use code_splitter::Splitter;
    /// use tiktoken_rs::cl100k_base;
    ///
    /// let lang = tree_sitter_md::language();
    /// let bpe = cl100k_base().unwrap();
    /// let splitter = Splitter::new(lang, bpe).unwrap();
    /// let chunks = splitter.split(b"hello, world!").unwrap();
    /// # }
    /// ```
    pub fn new(language: Language, sizer: T) -> Result<Self> {
        // Probe with a throwaway parser so a language the parser rejects is
        // reported here, at construction time.
        let mut probe = Parser::new();
        probe.set_language(&language)?;

        let splitter = Self {
            language,
            sizer,
            max_size: DEFAULT_MAX_SIZE,
        };
        Ok(splitter)
    }

    /// Replace the maximum chunk size (512 unless changed) and return the splitter.
    ///
    /// Takes the splitter by value and hands it back, so the call can be
    /// chained directly after `new`.
    ///
    /// # Example: set the maximum size to 256
    /// ```
    /// use code_splitter::{CharCounter, Splitter};
    ///
    /// let lang = tree_sitter_md::language();
    /// let splitter = Splitter::new(lang, CharCounter)
    ///   .unwrap()
    ///   .with_max_size(256);
    /// let chunks = splitter.split(b"hello, world!").unwrap();
    /// ```
    pub fn with_max_size(self, max_size: usize) -> Self {
        let mut configured = self;
        configured.max_size = max_size;
        configured
    }

    /// Split `code` into chunks whose size does not exceed `max_size`.
    ///
    /// The input is parsed with the configured language and the resulting
    /// syntax tree is split starting from its root node. Empty input returns
    /// an empty vector without invoking the parser.
    ///
    /// # Errors
    ///
    /// Returns an error if the parser rejects the configured language, if
    /// parsing yields no tree, or if splitting the tree fails (for example
    /// when a node's text is not valid UTF-8 or the sizer reports an error).
    pub fn split(&self, code: &[u8]) -> Result<Vec<Chunk>> {
        if code.is_empty() {
            return Ok(Vec::new());
        }

        let mut parser = Parser::new();
        // Propagate a language failure the same way `new` does rather than
        // panicking inside a library call.
        parser.set_language(&self.language)?;
        let tree = parser.parse(code, None).ok_or("Error parsing code")?;
        let root_node = tree.root_node();

        self.split_node(&root_node, 0, code)
    }

    /// Recursively split `node` into chunks that each fit within `max_size`.
    ///
    /// * A node whose text measures zero produces no chunks.
    /// * A node that fits is returned as a single chunk labelled with its
    ///   formatted tree line and size.
    /// * Otherwise every child is split at `depth + 1`, and wherever the last
    ///   chunk gathered so far and the first chunk of the next child still fit
    ///   together, the two are merged into one.
    ///
    /// A node that is too large and has no children yields no chunks at all.
    ///
    /// Fails if node text is not valid UTF-8 or the sizer returns an error.
    fn split_node(&self, node: &Node, depth: usize, code: &[u8]) -> Result<Vec<Chunk>> {
        let size = self.sizer.size(node.utf8_text(code)?)?;

        if size == 0 {
            return Ok(Vec::new());
        }

        if size <= self.max_size {
            let whole = Chunk {
                subtree: format!("{}: {}", format_node(node, depth), size),
                range: node.range(),
                size,
            };
            return Ok(vec![whole]);
        }

        // Split every child first (depth-first); any failure aborts before
        // merging starts.
        let mut cursor = node.walk();
        let per_child = node
            .children(&mut cursor)
            .map(|child| self.split_node(&child, depth + 1, code))
            .collect::<Result<Vec<_>>>()?;

        let mut merged: Vec<Chunk> = Vec::new();
        for mut group in per_child {
            // Only the boundary between the chunks gathered so far and this
            // child's chunks is a merge candidate; an empty group leaves the
            // gathered chunks untouched.
            if let Some(head) = group.first_mut() {
                if let Some(tail) = merged.pop() {
                    let joined_size = self.joined_size(&tail, head, code)?;
                    if joined_size > self.max_size {
                        // Too big together: restore the tail unchanged.
                        merged.push(tail);
                    } else {
                        // Fold the tail into the head: stack the labels and
                        // stretch the head's range back to where the tail began.
                        head.subtree = format!("{}\n{}", tail.subtree, head.subtree);
                        head.range.start_byte = tail.range.start_byte;
                        head.range.start_point = tail.range.start_point;
                        head.size = joined_size;
                    }
                }
            }
            merged.append(&mut group);
        }

        Ok(merged)
    }

    /// Measure the source text running from the start of `chunk` to the end of `next`.
    ///
    /// The span is sliced straight out of `code`, so any bytes lying between
    /// the two chunks are measured as well. Fails if the span is not valid
    /// UTF-8 or if the sizer returns an error.
    fn joined_size(&self, chunk: &Chunk, next: &Chunk, code: &[u8]) -> Result<usize> {
        let span = chunk.range.start_byte..next.range.end_byte;
        let text = str::from_utf8(&code[span])?;
        self.sizer.size(text)
    }
}

/// Render a one-line description of `node` for a chunk's `subtree` label.
///
/// The line consists of tree-drawing indentation derived from `depth`, the
/// node kind left-aligned and padded to 32 columns, and the node's start and
/// end rows in brackets.
fn format_node(node: &Node, depth: usize) -> String {
    // Depth 0 gets no branch marker; deeper nodes get one marker preceded by
    // a vertical guide for each level beyond the first.
    let branch = match depth {
        0 => "",
        _ => "├─",
    };
    let indent = "│  ".repeat(depth.saturating_sub(1));
    let kind = node.kind();
    let start = node.start_position().row;
    let end = node.end_position().row;

    format!("{indent}{branch} {kind:<32} [{start}..{end}]")
}