Skip to main content

sh_layer3/
text_splitters.rs

1//! # Text Splitters
2//!
3//! 文本分割器:将长文本分割为小块。
4
5use crate::retriever_engine::{Chunk, ChunkPosition, ChunkingStrategy, Document};
6
/// Recursive character text splitter.
///
/// Splits text by walking a prioritized list of separators.
// `Debug` and `Clone` are derived so the splitter can be logged and duplicated
// without hand-written impls.
#[derive(Debug, Clone)]
pub struct RecursiveCharacterTextSplitter {
    /// Chunk size; the splitting code compares it against `str::len`, i.e. bytes.
    chunk_size: usize,
    /// Overlap size between chunks.
    chunk_overlap: usize,
    /// Separators in priority order (tried from first to last).
    separators: Vec<String>,
    /// Whether separators are kept in the produced text.
    keep_separator: bool,
}
20
21impl RecursiveCharacterTextSplitter {
22    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
23        Self {
24            chunk_size,
25            chunk_overlap,
26            separators: vec![
27                "\n\n".to_string(), // 段落
28                "\n".to_string(),   // 行
29                " ".to_string(),    // 词
30                "".to_string(),     // 字符
31            ],
32            keep_separator: true,
33        }
34    }
35
36    pub fn with_separators(mut self, separators: Vec<String>) -> Self {
37        self.separators = separators;
38        self
39    }
40}
41
42impl Default for RecursiveCharacterTextSplitter {
43    fn default() -> Self {
44        Self::new(1000, 200)
45    }
46}
47
48impl ChunkingStrategy for RecursiveCharacterTextSplitter {
49    fn chunk(&self, document: &Document) -> Vec<Chunk> {
50        let content = &document.content;
51        self.split_text(content, document)
52    }
53}
54
55impl RecursiveCharacterTextSplitter {
56    fn split_text(&self, text: &str, document: &Document) -> Vec<Chunk> {
57        if text.len() <= self.chunk_size {
58            return vec![self.create_chunk(text, 0, 1, document)];
59        }
60
61        // 尝试按分隔符分割
62        for separator in &self.separators {
63            if separator.is_empty() {
64                // 按字符分割
65                return self.split_by_characters(text, document);
66            }
67
68            if text.contains(separator) {
69                return self.split_by_separator(text, separator, document);
70            }
71        }
72
73        self.split_by_characters(text, document)
74    }
75
76    fn split_by_separator(&self, text: &str, separator: &str, document: &Document) -> Vec<Chunk> {
77        let parts: Vec<&str> = text.split(separator).collect();
78        let mut chunks = Vec::new();
79        let mut current_chunk = String::new();
80        let mut start = 0;
81        let mut index = 0;
82
83        for part in parts {
84            let part_len = part.len();
85            let sep_len = if self.keep_separator {
86                separator.len()
87            } else {
88                0
89            };
90
91            if current_chunk.len() + part_len + sep_len > self.chunk_size
92                && !current_chunk.is_empty()
93            {
94                chunks.push(self.create_chunk(&current_chunk, start, index, document));
95                start += current_chunk.len().saturating_sub(self.chunk_overlap);
96                current_chunk = String::new();
97                index += 1;
98            }
99
100            current_chunk.push_str(part);
101            if self.keep_separator && !current_chunk.is_empty() {
102                current_chunk.push_str(separator);
103            }
104        }
105
106        if !current_chunk.is_empty() {
107            chunks.push(self.create_chunk(&current_chunk, start, index, document));
108        }
109
110        let total = chunks.len();
111        for chunk in &mut chunks {
112            chunk.position.total = total;
113        }
114
115        chunks
116    }
117
118    fn split_by_characters(&self, text: &str, document: &Document) -> Vec<Chunk> {
119        let mut chunks = Vec::new();
120        let mut start = 0;
121        let mut index = 0;
122
123        while start < text.len() {
124            let end = (start + self.chunk_size).min(text.len());
125            chunks.push(self.create_chunk(&text[start..end], start, index, document));
126            start = end.saturating_sub(self.chunk_overlap);
127            index += 1;
128        }
129
130        let total = chunks.len();
131        for chunk in &mut chunks {
132            chunk.position.total = total;
133        }
134
135        chunks
136    }
137
138    fn create_chunk(
139        &self,
140        content: &str,
141        start: usize,
142        index: usize,
143        document: &Document,
144    ) -> Chunk {
145        Chunk {
146            id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
147            doc_id: document.id.clone().unwrap_or_default(),
148            content: content.to_string(),
149            position: ChunkPosition {
150                start,
151                end: start + content.len(),
152                index,
153                total: 0, // 将在最后更新
154            },
155            metadata: document.metadata.clone(),
156        }
157    }
158}
159
/// Markdown text splitter.
// `chunk_overlap` is stored but not read by the code visible in this file,
// hence the `dead_code` allowances. `Debug` and `Clone` are derived so the
// splitter can be logged and duplicated.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct MarkdownTextSplitter {
    /// Chunk size used as the splitting threshold.
    chunk_size: usize,
    /// Requested overlap between chunks.
    #[allow(dead_code)]
    chunk_overlap: usize,
}
167
168impl MarkdownTextSplitter {
169    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
170        Self {
171            chunk_size,
172            chunk_overlap,
173        }
174    }
175}
176
177impl Default for MarkdownTextSplitter {
178    fn default() -> Self {
179        Self::new(1000, 200)
180    }
181}
182
183impl ChunkingStrategy for MarkdownTextSplitter {
184    fn chunk(&self, document: &Document) -> Vec<Chunk> {
185        // 按 Markdown 标题分割
186        let content = &document.content;
187        let mut chunks = Vec::new();
188        let mut start = 0;
189        let mut index = 0;
190
191        // 按 ## 标题分割
192        let lines: Vec<&str> = content.lines().collect();
193        let mut current_chunk = String::new();
194        let mut chunk_start = 0;
195
196        for line in lines {
197            if line.starts_with("#") && current_chunk.len() > self.chunk_size / 2 {
198                // 新标题,保存当前块
199                if !current_chunk.trim().is_empty() {
200                    chunks.push(Chunk {
201                        id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
202                        doc_id: document.id.clone().unwrap_or_default(),
203                        content: current_chunk.trim().to_string(),
204                        position: ChunkPosition {
205                            start: chunk_start,
206                            end: chunk_start + current_chunk.len(),
207                            index,
208                            total: 0,
209                        },
210                        metadata: document.metadata.clone(),
211                    });
212                    index += 1;
213                }
214                current_chunk = String::new();
215                chunk_start = start;
216            }
217            current_chunk.push_str(line);
218            current_chunk.push('\n');
219            start += line.len() + 1;
220        }
221
222        if !current_chunk.trim().is_empty() {
223            chunks.push(Chunk {
224                id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
225                doc_id: document.id.clone().unwrap_or_default(),
226                content: current_chunk.trim().to_string(),
227                position: ChunkPosition {
228                    start: chunk_start,
229                    end: start,
230                    index,
231                    total: 0,
232                },
233                metadata: document.metadata.clone(),
234            });
235        }
236
237        let total = chunks.len();
238        for chunk in &mut chunks {
239            chunk.position.total = total;
240        }
241
242        chunks
243    }
244}
245
/// Source-code text splitter.
// `chunk_overlap` and `language` are stored but not read by the code visible
// in this file, hence the `dead_code` allowances. `Debug` and `Clone` are
// derived so the splitter can be logged and duplicated.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CodeTextSplitter {
    /// Chunk size used as the splitting threshold.
    chunk_size: usize,
    /// Requested overlap between chunks.
    #[allow(dead_code)]
    chunk_overlap: usize,
    /// Language label supplied by the caller.
    #[allow(dead_code)]
    language: String,
}
255
256impl CodeTextSplitter {
257    pub fn new(chunk_size: usize, chunk_overlap: usize, language: impl Into<String>) -> Self {
258        Self {
259            chunk_size,
260            chunk_overlap,
261            language: language.into(),
262        }
263    }
264}
265
266impl ChunkingStrategy for CodeTextSplitter {
267    fn chunk(&self, document: &Document) -> Vec<Chunk> {
268        // 按 函数/类 分割
269        let content = &document.content;
270        let mut chunks = Vec::new();
271
272        // 简化实现:按函数定义分割
273        let lines: Vec<&str> = content.lines().collect();
274        let mut current_chunk = String::new();
275        let mut start = 0;
276        let mut index = 0;
277        let mut chunk_start = 0;
278
279        for line in lines {
280            // 检测函数/类定义
281            let is_definition = line.trim().starts_with("fn ")
282                || line.trim().starts_with("pub fn ")
283                || line.trim().starts_with("async fn ")
284                || line.trim().starts_with("class ")
285                || line.trim().starts_with("def ")
286                || line.trim().starts_with("public ")
287                || line.trim().starts_with("function ");
288
289            if is_definition
290                && current_chunk.len() > self.chunk_size / 2
291                && !current_chunk.trim().is_empty()
292            {
293                chunks.push(Chunk {
294                    id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
295                    doc_id: document.id.clone().unwrap_or_default(),
296                    content: current_chunk.trim().to_string(),
297                    position: ChunkPosition {
298                        start: chunk_start,
299                        end: chunk_start + current_chunk.len(),
300                        index,
301                        total: 0,
302                    },
303                    metadata: document.metadata.clone(),
304                });
305                index += 1;
306                current_chunk = String::new();
307                chunk_start = start;
308            }
309            current_chunk.push_str(line);
310            current_chunk.push('\n');
311            start += line.len() + 1;
312        }
313
314        if !current_chunk.trim().is_empty() {
315            chunks.push(Chunk {
316                id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
317                doc_id: document.id.clone().unwrap_or_default(),
318                content: current_chunk.trim().to_string(),
319                position: ChunkPosition {
320                    start: chunk_start,
321                    end: start,
322                    index,
323                    total: 0,
324                },
325                metadata: document.metadata.clone(),
326            });
327        }
328
329        let total = chunks.len();
330        for chunk in &mut chunks {
331            chunk.position.total = total;
332        }
333
334        chunks
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// The default recursive splitter uses a chunk size of 1000 and an overlap of 200.
    #[test]
    fn test_recursive_splitter_default() {
        let RecursiveCharacterTextSplitter {
            chunk_size,
            chunk_overlap,
            ..
        } = RecursiveCharacterTextSplitter::default();

        assert_eq!(chunk_size, 1000);
        assert_eq!(chunk_overlap, 200);
    }

    /// Chunking a small Markdown document yields at least one chunk.
    #[test]
    fn test_markdown_splitter() {
        let document = Document::new("# Title\n\nContent\n\n## Section\n\nMore content");
        let produced = MarkdownTextSplitter::default().chunk(&document);
        assert!(!produced.is_empty());
    }

    /// Chunking a small code snippet yields at least one chunk.
    #[test]
    fn test_code_splitter() {
        let document = Document::new("fn foo() {}\n\nfn bar() {}");
        let produced = CodeTextSplitter::new(500, 100, "rust").chunk(&document);
        assert!(!produced.is_empty());
    }
}