Skip to main content

sqlite_graphrag/
chunking.rs

1// src/chunking.rs
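// Chunking for the E5 embedding model (512-token limit). Token counts are approximated from byte length (len / CHARS_PER_TOKEN), not computed with a real tokenizer.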
3
4use crate::constants::{CHUNK_OVERLAP_TOKENS, CHUNK_SIZE_TOKENS, EMBEDDING_DIM};
5
6const CHARS_PER_TOKEN: usize = 4;
7pub const CHUNK_SIZE_CHARS: usize = CHUNK_SIZE_TOKENS * CHARS_PER_TOKEN;
8pub const CHUNK_OVERLAP_CHARS: usize = CHUNK_OVERLAP_TOKENS * CHARS_PER_TOKEN;
9
/// One contiguous piece of a document body, as produced by [`split_into_chunks`].
///
/// `PartialEq`/`Eq` are derived so chunks can be compared directly
/// (e.g. with `assert_eq!`) without comparing field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Chunk {
    /// Owned copy of `body[start_offset..end_offset]`.
    pub text: String,
    /// Byte offset of the chunk's first byte within the original body.
    pub start_offset: usize,
    /// Byte offset one past the chunk's last byte within the original body.
    pub end_offset: usize,
    /// Rough token estimate: the text's byte length divided by `CHARS_PER_TOKEN`.
    pub token_count_approx: usize,
}
17
18pub fn needs_chunking(body: &str) -> bool {
19    body.len() > CHUNK_SIZE_CHARS
20}
21
22pub fn split_into_chunks(body: &str) -> Vec<Chunk> {
23    if !needs_chunking(body) {
24        return vec![Chunk {
25            token_count_approx: body.len() / CHARS_PER_TOKEN,
26            text: body.to_string(),
27            start_offset: 0,
28            end_offset: body.len(),
29        }];
30    }
31
32    let mut chunks = Vec::new();
33    let mut start = 0usize;
34
35    while start < body.len() {
36        let desired_end = (start + CHUNK_SIZE_CHARS).min(body.len());
37        let end = if desired_end < body.len() {
38            find_split_boundary(body, start, desired_end)
39        } else {
40            desired_end
41        };
42
43        let text = body[start..end].to_string();
44        let token_count_approx = text.len() / CHARS_PER_TOKEN;
45        chunks.push(Chunk {
46            text,
47            start_offset: start,
48            end_offset: end,
49            token_count_approx,
50        });
51
52        if end >= body.len() {
53            break;
54        }
55        start = end.saturating_sub(CHUNK_OVERLAP_CHARS);
56    }
57
58    chunks
59}
60
61fn find_split_boundary(body: &str, start: usize, desired_end: usize) -> usize {
62    let slice = &body[start..desired_end];
63    if let Some(pos) = slice.rfind("\n\n") {
64        return start + pos + 2;
65    }
66    if let Some(pos) = slice.rfind(". ") {
67        return start + pos + 2;
68    }
69    if let Some(pos) = slice.rfind(' ') {
70        return start + pos + 1;
71    }
72    desired_end
73}
74
75pub fn aggregate_embeddings(chunk_embeddings: &[Vec<f32>]) -> Vec<f32> {
76    if chunk_embeddings.is_empty() {
77        return vec![0.0f32; EMBEDDING_DIM];
78    }
79    if chunk_embeddings.len() == 1 {
80        return chunk_embeddings[0].clone();
81    }
82
83    let dim = chunk_embeddings[0].len();
84    let mut mean = vec![0.0f32; dim];
85    for emb in chunk_embeddings {
86        for (i, v) in emb.iter().enumerate() {
87            mean[i] += v;
88        }
89    }
90    let n = chunk_embeddings.len() as f32;
91    for v in &mut mean {
92        *v /= n;
93    }
94
95    let norm: f32 = mean.iter().map(|x| x * x).sum::<f32>().sqrt();
96    if norm > 1e-9 {
97        for v in &mut mean {
98            *v /= norm;
99        }
100    }
101    mean
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// A body under the chunk threshold is returned untouched as one chunk.
    #[test]
    fn test_short_body_no_chunking() {
        let input = "short text";
        assert!(!needs_chunking(input));

        let result = split_into_chunks(input);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].text, input);
    }

    /// A body well over the threshold is split into more than one chunk.
    #[test]
    fn test_long_body_produces_multiple_chunks() {
        let input = "word ".repeat(1000);
        assert!(needs_chunking(&input));
        assert!(split_into_chunks(&input).len() > 1);
    }

    /// Averaging two orthogonal unit vectors still yields a unit-length result.
    #[test]
    fn test_aggregate_embeddings_normalizes() {
        let embeddings = vec![vec![1.0f32, 0.0], vec![0.0f32, 1.0]];
        let aggregated = aggregate_embeddings(&embeddings);
        let length = aggregated
            .iter()
            .fold(0.0f32, |acc, x| acc + x * x)
            .sqrt();
        assert!((length - 1.0).abs() < 1e-5);
    }
}
132}