Skip to main content

abu_rag/document/
chunk.rs

1use std::convert::Infallible;
2use super::{Chunk, Document, Id};
3
4pub trait DocumentChunk {
5    type Error: std::error::Error + 'static + Send + Sync;
6    fn chunk(&self, doc: &Document) -> Result<Vec<Chunk>, Self::Error>;
7}
8
/// Configuration for chunking a document into fixed-size windows.
///
/// Sizes are counted in characters (`char`s), not bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FixedChunker {
    /// Number of characters in each chunk (the last chunk may be shorter).
    pub chunk_size: usize,
    /// Number of characters shared between two consecutive chunks.
    pub overlap: usize,
}

impl FixedChunker {
    /// Creates a chunker producing windows of `chunk_size` characters where
    /// consecutive windows share `overlap` characters.
    pub fn new(chunk_size: usize, overlap: usize) -> Self {
        Self { chunk_size, overlap }
    }

    /// Creates a chunker whose windows do not overlap.
    pub fn without_overlap(chunk_size: usize) -> Self {
        Self::new(chunk_size, 0)
    }
}
23
24impl DocumentChunk for FixedChunker {
25    type Error = Infallible;
26    fn chunk(&self, doc: &Document) -> Result<Vec<Chunk>, Self::Error> {
27        let mut chunks = vec![];
28        let mut start = 0;
29        let chars: Vec<char> = doc.text.chars().collect();
30
31        while start < chars.len() {
32            let end = (start + self.chunk_size).min(chars.len());
33            let chunk_text: String = chars[start..end].iter().collect();
34
35            let chunk = Chunk {
36                id: Id::uuid(),
37                document_id: doc.id.clone(),
38                text: chunk_text,
39                start, end 
40            };
41            chunks.push(chunk);
42
43            if end == chars.len() {
44                break;
45            }
46            start += self.chunk_size - self.overlap;
47        }
48
49        Ok(chunks)
50    }
51}
52
53pub struct ParagraphChunker;
54
55impl DocumentChunk for ParagraphChunker {
56    type Error = Infallible;
57    fn chunk(&self, doc: &Document) -> Result<Vec<Chunk>, Self::Error> {
58        let mut chunks = vec![];
59        let mut current_offset = 0;
60
61        for (i, para) in doc.text.split("\n\n").enumerate() {
62            let start = current_offset;
63            let end = start + para.len();
64
65            let chunk = Chunk {
66                id: Id::new(format!("{}_{}", doc.id, i)),
67                document_id: doc.id.clone(),
68                text: para.to_string(),
69                start, end 
70            };
71            chunks.push(chunk);
72        
73            current_offset = end + 2; 
74        }
75
76        Ok(chunks)
77    }
78}
79
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;

    /// Builds an in-memory document with a placeholder source path.
    fn make_doc(id: &'static str, text: &str) -> Document {
        Document {
            id: Id::new(id),
            source: PathBuf::from("dummy.txt"),
            text: text.to_string(),
        }
    }

    /// Two generated ids must differ and must not look like a nil UUID.
    #[test]
    fn test_id_generation() {
        let (first, second) = (Id::uuid(), Id::uuid());
        assert_ne!(first, second, "两次生成的 UUID 不应该相同");

        let rendered = first.to_string();
        assert!(!rendered.contains("00000000-0000"), "不应该是全 0 的 UUID");
    }

    /// ASCII input: windows of 4 characters that overlap by 2.
    #[test]
    fn test_fixed_chunker_basic() {
        // 10 characters
        let doc = make_doc("doc_1", "1234567890");

        let chunks = FixedChunker::new(4, 2).chunk(&doc).unwrap();
        let texts: Vec<&str> = chunks.iter().map(|c| c.text.as_str()).collect();

        assert_eq!(texts.len(), 4);
        assert_eq!(texts[0], "1234");
        assert_eq!(texts[1], "3456");
        assert_eq!(texts[3], "7890");
    }

    /// Multi-byte input must be cut on character boundaries, not byte ones.
    #[test]
    fn test_fixed_chunker_utf8_chinese() {
        // 11 characters
        let doc = make_doc("doc_2", "你好世界,这是一个测试");

        let chunks = FixedChunker::new(5, 2).chunk(&doc).unwrap();
        let texts: Vec<&str> = chunks.iter().map(|c| c.text.as_str()).collect();

        assert_eq!(texts.len(), 3);
        assert_eq!(texts[0], "你好世界,");
        assert_eq!(texts[1], "界,这是一");
        assert_eq!(texts[2], "是一个测试");
    }

    // ---------------------------------------------------------
    // 3. ParagraphChunker tests (split by paragraph)
    // ---------------------------------------------------------
    #[test]
    fn test_paragraph_chunker() {
        let doc = make_doc("doc_3", "段落一\n\n段落二的内容\n\n段落三");

        let chunks = ParagraphChunker.chunk(&doc).unwrap();

        assert_eq!(chunks.len(), 3);
        let (first, second, third) = (&chunks[0], &chunks[1], &chunks[2]);

        assert_eq!(first.text, "段落一");
        assert_eq!(second.text, "段落二的内容");
        assert_eq!(third.text, "段落三");

        // Offsets are byte positions: the first paragraph is 9 bytes long and
        // is followed by the 2-byte separator.
        assert_eq!(first.start, 0);
        assert_eq!(first.end, 9);
        assert_eq!(second.start, 11);
        assert_eq!(second.end, 11 + "段落二的内容".len());
    }
}
153}