sqlite_graphrag/
tokenizer.rs1use crate::errors::AppError;
13
/// Numerator of the words-to-tokens ratio applied by `approx_tokens`:
/// the whitespace-separated word count is multiplied by this value first.
const WORDS_TO_TOKENS_NUMERATOR: usize = 3;
/// Denominator of the words-to-tokens ratio applied by `approx_tokens`:
/// the scaled word count is divided by this value, rounding up, so the
/// effective ratio is 3/2 = 1.5 tokens per word.
const WORDS_TO_TOKENS_DENOMINATOR: usize = 2;
21
22pub fn count_passage_tokens(text: &str) -> Result<usize, AppError> {
25 Ok(approx_tokens(&format!(
26 "{}{}",
27 crate::constants::PASSAGE_PREFIX,
28 text
29 )))
30}
31
32pub fn passage_token_offsets(text: &str) -> Result<Vec<(usize, usize)>, AppError> {
37 let mut offsets = Vec::new();
38 let mut start = None;
39 for (i, c) in text.char_indices() {
40 if c.is_whitespace() {
41 if let Some(s) = start.take() {
42 if i > s {
43 offsets.push((s, i));
44 }
45 }
46 } else if start.is_none() {
47 start = Some(i);
48 }
49 }
50 if let Some(s) = start {
51 if text.len() > s {
52 offsets.push((s, text.len()));
53 }
54 }
55 Ok(offsets)
56}
57
58pub fn get_model_max_length() -> usize {
63 crate::constants::EMBEDDING_MAX_TOKENS
64}
65
66fn approx_tokens(text: &str) -> usize {
67 let words = text.split_whitespace().count();
68 let num = words.saturating_mul(WORDS_TO_TOKENS_NUMERATOR);
70 let (tokens, rem) = (
71 num / WORDS_TO_TOKENS_DENOMINATOR,
72 num % WORDS_TO_TOKENS_DENOMINATOR,
73 );
74 if rem == 0 {
75 tokens
76 } else {
77 tokens + 1
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    // The `count_passage_tokens` expectations below assume that
    // `crate::constants::PASSAGE_PREFIX` contributes exactly one
    // whitespace-separated word to the measured string — confirm against the
    // constant's definition if these start failing.

    /// Empty and whitespace-only inputs contain no words, hence no tokens.
    #[test]
    fn empty_string_has_zero_tokens() {
        assert_eq!(approx_tokens(""), 0);
        assert_eq!(approx_tokens(" \n\t "), 0);
    }

    /// One word scales to 1.5 tokens, which rounds up to 2.
    #[test]
    fn single_word_rounds_up() {
        assert_eq!(approx_tokens("hello"), 2);
    }

    /// Four words scale to exactly 6 tokens, so no rounding is applied.
    #[test]
    fn four_words_rounds_to_six() {
        assert_eq!(approx_tokens("the quick brown fox"), 6);
    }

    /// Each word is reported as a half-open byte range; separators are skipped.
    #[test]
    fn passage_offsets_skip_whitespace() {
        let offsets = passage_token_offsets("hello world foo").unwrap();
        assert_eq!(offsets, vec![(0, 5), (6, 11), (12, 15)]);
    }

    /// Two leading and two trailing spaces: the only word spans bytes 2..7.
    #[test]
    fn passage_offsets_handle_leading_and_trailing_whitespace() {
        let offsets = passage_token_offsets("  hello  ").unwrap();
        assert_eq!(offsets, vec![(2, 7)]);
    }

    /// Three input words plus the prefix word give 4 words, i.e. 6 tokens.
    #[test]
    fn count_passage_tokens_matches_approx_tokens() {
        assert_eq!(count_passage_tokens("rust sqlite graphrag").unwrap(), 6);
    }

    /// Four input words plus the prefix word give 5 words, i.e. 8 tokens.
    #[test]
    fn count_passage_tokens_includes_prefix_for_short_inputs() {
        assert_eq!(count_passage_tokens("teste fix real 4").unwrap(), 8);
    }

    /// The prefix is prepended even when the text already starts with it, so
    /// five input words become 6 words, i.e. 9 tokens.
    #[test]
    fn count_passage_tokens_matches_embedding_when_text_already_has_prefix() {
        assert_eq!(
            count_passage_tokens("passage: teste fix real 5").unwrap(),
            9
        );
    }
}