abu_rag/document/
chunk.rs1use std::convert::Infallible;
2use super::{Chunk, Document, Id};
3
4pub trait DocumentChunk {
5 type Error: std::error::Error + 'static + Send + Sync;
6 fn chunk(&self, doc: &Document) -> Result<Vec<Chunk>, Self::Error>;
7}
8
/// Configuration for splitting text into fixed-size, optionally overlapping
/// windows.
///
/// Both values are stored as given; no validation happens at construction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FixedChunker {
    /// Maximum size of one chunk.
    pub chunk_size: usize,
    /// Amount shared between two consecutive chunks; expected to be smaller
    /// than `chunk_size`.
    pub overlap: usize,
}
13
14impl FixedChunker {
15 pub fn new(chunk_size: usize, overlap: usize) -> Self {
16 Self { chunk_size, overlap }
17 }
18
19 pub fn without_overlap(chunk_size: usize) -> Self {
20 Self::new(chunk_size, 0)
21 }
22}
23
24impl DocumentChunk for FixedChunker {
25 type Error = Infallible;
26 fn chunk(&self, doc: &Document) -> Result<Vec<Chunk>, Self::Error> {
27 let mut chunks = vec![];
28 let mut start = 0;
29 let chars: Vec<char> = doc.text.chars().collect();
30
31 while start < chars.len() {
32 let end = (start + self.chunk_size).min(chars.len());
33 let chunk_text: String = chars[start..end].iter().collect();
34
35 let chunk = Chunk {
36 id: Id::uuid(),
37 document_id: doc.id.clone(),
38 text: chunk_text,
39 start, end
40 };
41 chunks.push(chunk);
42
43 if end == chars.len() {
44 break;
45 }
46 start += self.chunk_size - self.overlap;
47 }
48
49 Ok(chunks)
50 }
51}
52
/// Chunking strategy that splits text at blank-line paragraph boundaries.
///
/// The type carries no configuration.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ParagraphChunker;
54
55impl DocumentChunk for ParagraphChunker {
56 type Error = Infallible;
57 fn chunk(&self, doc: &Document) -> Result<Vec<Chunk>, Self::Error> {
58 let mut chunks = vec![];
59 let mut current_offset = 0;
60
61 for (i, para) in doc.text.split("\n\n").enumerate() {
62 let start = current_offset;
63 let end = start + para.len();
64
65 let chunk = Chunk {
66 id: Id::new(format!("{}_{}", doc.id, i)),
67 document_id: doc.id.clone(),
68 text: para.to_string(),
69 start, end
70 };
71 chunks.push(chunk);
72
73 current_offset = end + 2;
74 }
75
76 Ok(chunks)
77 }
78}
79
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;

    /// Builds an in-memory document with a placeholder source path.
    fn make_doc(id: &str, text: &str) -> Document {
        Document {
            id: Id::new(id),
            source: PathBuf::from("dummy.txt"),
            text: text.to_string(),
        }
    }

    /// Two generated ids must differ and must not look like the nil UUID.
    #[test]
    fn test_id_generation() {
        let first = Id::uuid();
        let second = Id::uuid();

        assert_ne!(first, second, "两次生成的 UUID 不应该相同");

        let rendered = first.to_string();
        assert!(!rendered.contains("00000000-0000"), "不应该是全 0 的 UUID");
    }

    /// ASCII text, window of 4 with an overlap of 2.
    #[test]
    fn test_fixed_chunker_basic() {
        let doc = make_doc("doc_1", "1234567890");

        let chunks = FixedChunker::new(4, 2).chunk(&doc).unwrap();
        let texts: Vec<&str> = chunks.iter().map(|c| c.text.as_str()).collect();

        assert_eq!(texts.len(), 4);
        assert_eq!(texts[0], "1234");
        assert_eq!(texts[1], "3456");
        assert_eq!(texts[3], "7890");
    }

    /// Multi-byte text must be windowed by characters, not bytes.
    #[test]
    fn test_fixed_chunker_utf8_chinese() {
        let doc = make_doc("doc_2", "你好世界,这是一个测试");

        let chunks = FixedChunker::new(5, 2).chunk(&doc).unwrap();
        let texts: Vec<&str> = chunks.iter().map(|c| c.text.as_str()).collect();

        assert_eq!(texts.len(), 3);
        assert_eq!(texts[0], "你好世界,");
        assert_eq!(texts[1], "界,这是一");
        assert_eq!(texts[2], "是一个测试");
    }

    /// Paragraph texts plus the byte offsets of the first two paragraphs.
    #[test]
    fn test_paragraph_chunker() {
        let doc = make_doc("doc_3", "段落一\n\n段落二的内容\n\n段落三");

        let chunks = ParagraphChunker.chunk(&doc).unwrap();

        assert_eq!(chunks.len(), 3);
        let expected_texts = ["段落一", "段落二的内容", "段落三"];
        for (chunk, expected) in chunks.iter().zip(expected_texts.iter()) {
            assert_eq!(chunk.text, *expected);
        }

        // Each CJK character here is 3 bytes; the separator adds 2 more.
        let first = &chunks[0];
        assert_eq!(first.start, 0);
        assert_eq!(first.end, 9);

        let second = &chunks[1];
        assert_eq!(second.start, 11);
        assert_eq!(second.end, 11 + "段落二的内容".len());
    }
}