Skip to main content

sqlite_graphrag/
chunking.rs

1// src/chunking.rs
2// Token-based chunking for E5 model (512 token limit)
3
4use crate::constants::{CHUNK_OVERLAP_TOKENS, CHUNK_SIZE_TOKENS, EMBEDDING_DIM};
5
6// Heurística conservadora para reduzir o risco de subestimar o número real de tokens
7// em Markdown, código e texto multilíngue. Valor anterior 4 chars/token permitia
8// chunks grandes demais para alguns documentos reais.
9const CHARS_PER_TOKEN: usize = 2;
10pub const CHUNK_SIZE_CHARS: usize = CHUNK_SIZE_TOKENS * CHARS_PER_TOKEN;
11pub const CHUNK_OVERLAP_CHARS: usize = CHUNK_OVERLAP_TOKENS * CHARS_PER_TOKEN;
12
/// A contiguous piece of a source text, produced by the splitting functions
/// in this module.
///
/// Offsets are UTF-8 byte offsets into the body the chunk was computed from;
/// borrow the corresponding text with `chunk_text`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Chunk {
    /// Inclusive start byte offset.
    pub start_offset: usize,
    /// Exclusive end byte offset.
    pub end_offset: usize,
    /// Estimated token count: the exact window size when splitting by
    /// tokenizer offsets, a characters-per-token heuristic when splitting
    /// raw text.
    pub token_count_approx: usize,
}
19
20pub fn needs_chunking(body: &str) -> bool {
21    body.len() > CHUNK_SIZE_CHARS
22}
23
24pub fn split_into_chunks(body: &str) -> Vec<Chunk> {
25    if !needs_chunking(body) {
26        return vec![Chunk {
27            token_count_approx: body.chars().count() / CHARS_PER_TOKEN,
28            start_offset: 0,
29            end_offset: body.len(),
30        }];
31    }
32
33    let mut chunks = Vec::new();
34    let mut start = 0usize;
35
36    while start < body.len() {
37        start = next_char_boundary(body, start);
38        let desired_end = previous_char_boundary(body, (start + CHUNK_SIZE_CHARS).min(body.len()));
39        let end = if desired_end < body.len() {
40            find_split_boundary(body, start, desired_end)
41        } else {
42            desired_end
43        };
44
45        let end = if end <= start {
46            let fallback = previous_char_boundary(body, (start + CHUNK_SIZE_CHARS).min(body.len()));
47            if fallback > start {
48                fallback
49            } else {
50                body.len()
51            }
52        } else {
53            end
54        };
55
56        let token_count_approx = body[start..end].chars().count() / CHARS_PER_TOKEN;
57        chunks.push(Chunk {
58            start_offset: start,
59            end_offset: end,
60            token_count_approx,
61        });
62
63        if end >= body.len() {
64            break;
65        }
66
67        let next_start = next_char_boundary(body, end.saturating_sub(CHUNK_OVERLAP_CHARS));
68        start = if next_start >= end { end } else { next_start };
69    }
70
71    chunks
72}
73
74pub fn split_into_chunks_by_token_offsets(
75    body: &str,
76    token_offsets: &[(usize, usize)],
77) -> Vec<Chunk> {
78    if token_offsets.len() <= CHUNK_SIZE_TOKENS {
79        return vec![Chunk {
80            token_count_approx: token_offsets.len(),
81            start_offset: 0,
82            end_offset: body.len(),
83        }];
84    }
85
86    let mut chunks = Vec::new();
87    let mut start_token = 0usize;
88
89    while start_token < token_offsets.len() {
90        let end_token = (start_token + CHUNK_SIZE_TOKENS).min(token_offsets.len());
91
92        chunks.push(Chunk {
93            start_offset: if start_token == 0 {
94                0
95            } else {
96                token_offsets[start_token].0
97            },
98            end_offset: if end_token == token_offsets.len() {
99                body.len()
100            } else {
101                token_offsets[end_token - 1].1
102            },
103            token_count_approx: end_token - start_token,
104        });
105
106        if end_token == token_offsets.len() {
107            break;
108        }
109
110        let next_start = end_token.saturating_sub(CHUNK_OVERLAP_TOKENS);
111        start_token = if next_start <= start_token {
112            end_token
113        } else {
114            next_start
115        };
116    }
117
118    chunks
119}
120
121pub fn chunk_text<'a>(body: &'a str, chunk: &Chunk) -> &'a str {
122    &body[chunk.start_offset..chunk.end_offset]
123}
124
/// Chooses where a chunk covering `body[start..desired_end]` should end.
///
/// Separators are tried in priority order: paragraph break (`"\n\n"`), then
/// sentence end (`". "`), then a single space. For the first separator found,
/// the result is the byte offset just past its LAST occurrence in the window.
/// If none occurs, `desired_end` is returned unchanged. `start` and
/// `desired_end` must be char boundaries within `body`.
fn find_split_boundary(body: &str, start: usize, desired_end: usize) -> usize {
    const SEPARATORS: [&str; 3] = ["\n\n", ". ", " "];
    let window = &body[start..desired_end];
    SEPARATORS
        .iter()
        .find_map(|sep| window.rfind(sep).map(|pos| start + pos + sep.len()))
        .unwrap_or(desired_end)
}
138
/// Returns the largest `char` boundary of `body` that is `<= idx`.
///
/// `idx` is clamped to `body.len()` first. Offset 0 is always a boundary, so
/// the result is always a valid slice index for `body`.
fn previous_char_boundary(body: &str, idx: usize) -> usize {
    (0..=idx.min(body.len()))
        .rev()
        .find(|&candidate| body.is_char_boundary(candidate))
        .unwrap_or(0)
}
146
/// Returns the smallest `char` boundary of `body` that is `>= idx`.
///
/// `idx` is clamped to `body.len()` first. The end of the string is itself a
/// boundary, so the result is always a valid slice index for `body`.
fn next_char_boundary(body: &str, idx: usize) -> usize {
    let len = body.len();
    (idx.min(len)..=len)
        .find(|&candidate| body.is_char_boundary(candidate))
        .unwrap_or(len)
}
154
155pub fn aggregate_embeddings(chunk_embeddings: &[Vec<f32>]) -> Vec<f32> {
156    if chunk_embeddings.is_empty() {
157        return vec![0.0f32; EMBEDDING_DIM];
158    }
159    if chunk_embeddings.len() == 1 {
160        return chunk_embeddings[0].clone();
161    }
162
163    let dim = chunk_embeddings[0].len();
164    let mut mean = vec![0.0f32; dim];
165    for emb in chunk_embeddings {
166        for (i, v) in emb.iter().enumerate() {
167            mean[i] += v;
168        }
169    }
170    let n = chunk_embeddings.len() as f32;
171    for v in &mut mean {
172        *v /= n;
173    }
174
175    let norm: f32 = mean.iter().map(|x| x * x).sum::<f32>().sqrt();
176    if norm > 1e-9 {
177        for v in &mut mean {
178            *v /= norm;
179        }
180    }
181    mean
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_short_body_no_chunking() {
        let body = "short text";
        assert!(!needs_chunking(body));
        let chunks = split_into_chunks(body);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunk_text(body, &chunks[0]), body);
    }

    #[test]
    fn test_long_body_produces_multiple_chunks() {
        let body = "word ".repeat(1000);
        assert!(needs_chunking(&body));
        let chunks = split_into_chunks(&body);
        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|c| !chunk_text(&body, c).is_empty()));
        // No chunk may exceed the configured size budget.
        assert!(chunks
            .iter()
            .all(|c| c.end_offset - c.start_offset <= CHUNK_SIZE_CHARS));
    }

    // Respects the token limit and the overlap between windows.
    #[test]
    fn split_by_token_offsets_respeita_limite_e_overlap() {
        let body = "ab".repeat(460);
        let offsets: Vec<(usize, usize)> = (0..460)
            .map(|i| {
                let start = i * 2;
                (start, start + 2)
            })
            .collect();

        let chunks = split_into_chunks_by_token_offsets(&body, &offsets);
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].token_count_approx, CHUNK_SIZE_TOKENS);
        assert_eq!(chunks[1].token_count_approx, 110);
        assert_eq!(chunks[0].start_offset, 0);
        assert_eq!(
            chunks[1].start_offset,
            offsets[CHUNK_SIZE_TOKENS - CHUNK_OVERLAP_TOKENS].0
        );
        assert_eq!(chunks[1].end_offset, body.len());
    }

    // Returns a single chunk when all tokens fit.
    #[test]
    fn split_by_token_offsets_retorna_um_chunk_quando_cabe() {
        let body = "texto curto";
        let offsets = vec![(0, 5), (6, 11)];
        let chunks = split_into_chunks_by_token_offsets(body, &offsets);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].start_offset, 0);
        assert_eq!(chunks[0].end_offset, body.len());
        assert_eq!(chunks[0].token_count_approx, 2);
    }

    #[test]
    fn test_multibyte_body_preserves_progress_and_boundaries() {
        let body = "ação útil ".repeat(1000);
        let chunks = split_into_chunks(&body);
        assert!(chunks.len() > 1);
        // The chunks must cover the body from its first to its last byte.
        assert_eq!(chunks.first().map(|c| c.start_offset), Some(0));
        assert_eq!(chunks.last().map(|c| c.end_offset), Some(body.len()));
        for chunk in &chunks {
            assert!(!chunk_text(&body, chunk).is_empty());
            assert!(body.is_char_boundary(chunk.start_offset));
            assert!(body.is_char_boundary(chunk.end_offset));
            assert!(chunk.end_offset > chunk.start_offset);
        }
        for pair in chunks.windows(2) {
            // Strict progress: `>=` would also accept a loop stuck on one start.
            assert!(pair[1].start_offset > pair[0].start_offset);
            // No gaps: each chunk starts no later than the previous one ended.
            assert!(pair[1].start_offset <= pair[0].end_offset);
            assert!(pair[1].end_offset > pair[0].end_offset);
        }
    }

    #[test]
    fn test_aggregate_embeddings_normalizes() {
        let embs = vec![vec![1.0f32, 0.0], vec![0.0f32, 1.0]];
        let agg = aggregate_embeddings(&embs);
        let norm: f32 = agg.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
        // The mean of two orthogonal unit vectors points halfway between them.
        assert!((agg[0] - agg[1]).abs() < 1e-6);
    }

    #[test]
    fn test_aggregate_embeddings_edge_cases() {
        // No chunks: zero vector of the model's embedding width.
        assert_eq!(aggregate_embeddings(&[]), vec![0.0f32; EMBEDDING_DIM]);
        // A single embedding is returned unchanged (not re-normalized).
        let single = vec![vec![3.0f32, 4.0]];
        assert_eq!(aggregate_embeddings(&single), vec![3.0f32, 4.0]);
    }
}