sqlite_graphrag/
chunking.rs1use crate::constants::{CHUNK_OVERLAP_TOKENS, CHUNK_SIZE_TOKENS, EMBEDDING_DIM};
5
6const CHARS_PER_TOKEN: usize = 2;
10pub const CHUNK_SIZE_CHARS: usize = CHUNK_SIZE_TOKENS * CHARS_PER_TOKEN;
11pub const CHUNK_OVERLAP_CHARS: usize = CHUNK_OVERLAP_TOKENS * CHARS_PER_TOKEN;
12
/// A contiguous region of a source text selected as one chunk.
///
/// The offsets are byte indices into the original `&str` and are meant to be
/// used as the half-open range `start_offset..end_offset`.
///
/// Equality is derived so chunks can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Chunk {
    /// Byte index where the chunk begins (inclusive).
    pub start_offset: usize,
    /// Byte index where the chunk ends (exclusive).
    pub end_offset: usize,
    /// Rough token estimate for the chunk's text.
    pub token_count_approx: usize,
}
19
20pub fn needs_chunking(body: &str) -> bool {
21 body.len() > CHUNK_SIZE_CHARS
22}
23
24pub fn split_into_chunks(body: &str) -> Vec<Chunk> {
25 if !needs_chunking(body) {
26 return vec![Chunk {
27 token_count_approx: body.chars().count() / CHARS_PER_TOKEN,
28 start_offset: 0,
29 end_offset: body.len(),
30 }];
31 }
32
33 let mut chunks = Vec::new();
34 let mut start = 0usize;
35
36 while start < body.len() {
37 start = next_char_boundary(body, start);
38 let desired_end = previous_char_boundary(body, (start + CHUNK_SIZE_CHARS).min(body.len()));
39 let end = if desired_end < body.len() {
40 find_split_boundary(body, start, desired_end)
41 } else {
42 desired_end
43 };
44
45 let end = if end <= start {
46 let fallback = previous_char_boundary(body, (start + CHUNK_SIZE_CHARS).min(body.len()));
47 if fallback > start {
48 fallback
49 } else {
50 body.len()
51 }
52 } else {
53 end
54 };
55
56 let token_count_approx = body[start..end].chars().count() / CHARS_PER_TOKEN;
57 chunks.push(Chunk {
58 start_offset: start,
59 end_offset: end,
60 token_count_approx,
61 });
62
63 if end >= body.len() {
64 break;
65 }
66
67 let next_start = next_char_boundary(body, end.saturating_sub(CHUNK_OVERLAP_CHARS));
68 start = if next_start >= end { end } else { next_start };
69 }
70
71 chunks
72}
73
74pub fn chunk_text<'a>(body: &'a str, chunk: &Chunk) -> &'a str {
75 &body[chunk.start_offset..chunk.end_offset]
76}
77
/// Picks the best place to end a chunk inside `body[start..desired_end]`.
///
/// The last occurrence of each separator is tried in order of preference:
/// a blank line (`"\n\n"`), a sentence end (`". "`), then a single space.
/// The returned byte offset is absolute (relative to `body`) and points just
/// past the separator. If no separator is present, `desired_end` is returned.
fn find_split_boundary(body: &str, start: usize, desired_end: usize) -> usize {
    let window = &body[start..desired_end];

    window
        .rfind("\n\n")
        .map(|at| at + 2)
        .or_else(|| window.rfind(". ").map(|at| at + 2))
        .or_else(|| window.rfind(' ').map(|at| at + 1))
        .map_or(desired_end, |relative| start + relative)
}
91
/// Returns the largest UTF-8 character boundary in `body` that is `<= idx`.
///
/// `idx` is clamped to `body.len()` first, so out-of-range values are safe.
fn previous_char_boundary(body: &str, idx: usize) -> usize {
    let upper = idx.min(body.len());
    // Offset 0 is always a boundary, so the search cannot come up empty.
    (0..=upper)
        .rev()
        .find(|&pos| body.is_char_boundary(pos))
        .unwrap_or(0)
}
99
/// Returns the smallest UTF-8 character boundary in `body` that is `>= idx`.
///
/// `idx` is clamped to `body.len()` first; if no boundary is found before the
/// end of the string, `body.len()` is returned.
fn next_char_boundary(body: &str, idx: usize) -> usize {
    let len = body.len();
    let from = idx.min(len);
    (from..len)
        .find(|&pos| body.is_char_boundary(pos))
        .unwrap_or(len)
}
107
108pub fn aggregate_embeddings(chunk_embeddings: &[Vec<f32>]) -> Vec<f32> {
109 if chunk_embeddings.is_empty() {
110 return vec![0.0f32; EMBEDDING_DIM];
111 }
112 if chunk_embeddings.len() == 1 {
113 return chunk_embeddings[0].clone();
114 }
115
116 let dim = chunk_embeddings[0].len();
117 let mut mean = vec![0.0f32; dim];
118 for emb in chunk_embeddings {
119 for (i, v) in emb.iter().enumerate() {
120 mean[i] += v;
121 }
122 }
123 let n = chunk_embeddings.len() as f32;
124 for v in &mut mean {
125 *v /= n;
126 }
127
128 let norm: f32 = mean.iter().map(|x| x * x).sum::<f32>().sqrt();
129 if norm > 1e-9 {
130 for v in &mut mean {
131 *v /= norm;
132 }
133 }
134 mean
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// A small input must come back as exactly one chunk equal to the input.
    #[test]
    fn test_short_body_no_chunking() {
        let text = "short text";
        assert!(!needs_chunking(text));

        let pieces = split_into_chunks(text);
        assert_eq!(pieces.len(), 1);
        assert_eq!(chunk_text(text, &pieces[0]), text);
    }

    /// A large ASCII input must be split into several non-empty chunks.
    #[test]
    fn test_long_body_produces_multiple_chunks() {
        let text = "word ".repeat(1000);
        assert!(needs_chunking(&text));

        let pieces = split_into_chunks(&text);
        assert!(pieces.len() > 1);
        assert!(pieces
            .iter()
            .all(|piece| !chunk_text(&text, piece).is_empty()));
    }

    /// Multi-byte input: every chunk is non-empty, sits on char boundaries,
    /// and consecutive chunks never move backwards.
    #[test]
    fn test_multibyte_body_preserves_progress_and_boundaries() {
        let text = "ação útil ".repeat(1000);
        let pieces = split_into_chunks(&text);
        assert!(pieces.len() > 1);

        for piece in &pieces {
            assert!(!chunk_text(&text, piece).is_empty());
            assert!(text.is_char_boundary(piece.start_offset));
            assert!(text.is_char_boundary(piece.end_offset));
            assert!(piece.end_offset > piece.start_offset);
        }

        for pair in pieces.windows(2) {
            let (earlier, later) = (&pair[0], &pair[1]);
            assert!(later.start_offset >= earlier.start_offset);
            assert!(later.end_offset > earlier.start_offset);
        }
    }

    /// Aggregating two orthogonal unit vectors must yield a unit-length vector.
    #[test]
    fn test_aggregate_embeddings_normalizes() {
        let inputs = vec![vec![1.0f32, 0.0], vec![0.0f32, 1.0]];
        let combined = aggregate_embeddings(&inputs);

        let length: f32 = combined.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < 1e-5);
    }
}