context_creator/core/token.rs

use anyhow::Result;
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tiktoken_rs::{cl100k_base, CoreBPE};

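/// Counts tokens using tiktoken's cl100k_base encoding, caching counts by text hash.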
pub struct TokenCounter {
    encoder: Arc<CoreBPE>,
    cache: Arc<Mutex<HashMap<u64, usize>>>,
}

impl TokenCounter {
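    /// Creates a new counter backed by the cl100k_base encoder and an empty cache.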
    pub fn new() -> Result<Self> {
        let encoder = cl100k_base()?;
        Ok(TokenCounter {
            encoder: Arc::new(encoder),
            cache: Arc::new(Mutex::new(HashMap::new())),
        })
    }

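    /// Counts the tokens in `text`, consulting the cache before encoding.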
    pub fn count_tokens(&self, text: &str) -> Result<usize> {
        let hash = calculate_hash(text);

        // Return the cached count if this text has been seen before.
        if let Ok(cache) = self.cache.lock() {
            if let Some(&count) = cache.get(&hash) {
                return Ok(count);
            }
        }

        // Cache miss: encode the text and count the resulting tokens.
        let tokens = self.encoder.encode_with_special_tokens(text);
        let count = tokens.len();

        // Store the count for next time; a poisoned lock simply skips caching.
        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(hash, count);
        }

        Ok(count)
    }

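    /// Counts tokens for each text in parallel using rayon.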
    pub fn count_tokens_parallel(&self, texts: &[String]) -> Result<Vec<usize>> {
        texts
            .par_iter()
            .map(|text| self.count_tokens(text))
            .collect()
    }

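    /// Counts the tokens a file contributes to the output, including the
    /// markdown header and code-fence overhead wrapped around its contents.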
    pub fn count_file_tokens(&self, content: &str, path: &str) -> Result<FileTokenCount> {
        let content_tokens = self.count_tokens(content)?;

        // Account for the markdown header and code fences added around the file.
        let header = format!("## {path}\n\n```\n");
        let footer = "\n```\n\n";
        let header_tokens = self.count_tokens(&header)?;
        let footer_tokens = self.count_tokens(footer)?;

        Ok(FileTokenCount {
            content_tokens,
            overhead_tokens: header_tokens + footer_tokens,
            total_tokens: content_tokens + header_tokens + footer_tokens,
        })
    }

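    /// Estimates total token usage for a set of `(path, content)` pairs, per file and overall.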
    pub fn estimate_total_tokens(&self, files: &[(String, String)]) -> Result<TotalTokenEstimate> {
        let mut total_content = 0;
        let mut total_overhead = 0;
        let mut file_counts = Vec::new();

        for (path, content) in files {
            let count = self.count_file_tokens(content, path)?;
            total_content += count.content_tokens;
            total_overhead += count.overhead_tokens;
            file_counts.push((path.clone(), count));
        }

        Ok(TotalTokenEstimate {
            total_tokens: total_content + total_overhead,
            content_tokens: total_content,
            overhead_tokens: total_overhead,
            file_counts,
        })
    }
}

impl Default for TokenCounter {
    fn default() -> Self {
        Self::new().expect("Failed to create token counter")
    }
}

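/// Token counts for a single file, split into content and formatting overhead.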
#[derive(Debug, Clone)]
pub struct FileTokenCount {
    pub content_tokens: usize,
    pub overhead_tokens: usize,
    pub total_tokens: usize,
}

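/// Aggregate token counts across a set of files.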
#[derive(Debug)]
pub struct TotalTokenEstimate {
    pub total_tokens: usize,
    pub content_tokens: usize,
    pub overhead_tokens: usize,
    pub file_counts: Vec<(String, FileTokenCount)>,
}

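/// Hashes text with the standard library's `DefaultHasher` to produce a cache key.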
fn calculate_hash(text: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    text.hash(&mut hasher);
    hasher.finish()
}

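/// Returns true if adding `file_tokens` to `current_tokens` would exceed `max_tokens`.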
pub fn would_exceed_limit(current_tokens: usize, file_tokens: usize, max_tokens: usize) -> bool {
    current_tokens + file_tokens > max_tokens
}

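/// Returns how many tokens remain before `max_tokens` is reached, saturating at zero.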
pub fn remaining_tokens(current_tokens: usize, max_tokens: usize) -> usize {
    max_tokens.saturating_sub(current_tokens)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counting() {
        let counter = TokenCounter::new().unwrap();

        let count = counter.count_tokens("Hello, world!").unwrap();
        assert!(count > 0);

        let count = counter.count_tokens("").unwrap();
        assert_eq!(count, 0);

        let text = "This is a test text for caching";
        let count1 = counter.count_tokens(text).unwrap();
        let count2 = counter.count_tokens(text).unwrap();
        assert_eq!(count1, count2);
    }

    #[test]
    fn test_file_token_counting() {
        let counter = TokenCounter::new().unwrap();

        let content = "fn main() {\n println!(\"Hello, world!\");\n}";
        let path = "src/main.rs";

        let count = counter.count_file_tokens(content, path).unwrap();
        assert!(count.content_tokens > 0);
        assert!(count.overhead_tokens > 0);
        assert_eq!(
            count.total_tokens,
            count.content_tokens + count.overhead_tokens
        );
    }

    #[test]
    fn test_parallel_counting() {
        let counter = TokenCounter::new().unwrap();

        let texts = vec![
            "First text".to_string(),
            "Second text".to_string(),
            "Third text".to_string(),
        ];

        let counts = counter.count_tokens_parallel(&texts).unwrap();
        assert_eq!(counts.len(), 3);
        assert!(counts.iter().all(|&c| c > 0));
    }

    #[test]
    fn test_token_limit_checks() {
        assert!(would_exceed_limit(900, 200, 1000));
        assert!(!would_exceed_limit(800, 200, 1000));

        assert_eq!(remaining_tokens(300, 1000), 700);
        assert_eq!(remaining_tokens(1100, 1000), 0);
    }

    #[test]
    fn test_total_estimation() {
        let counter = TokenCounter::new().unwrap();

        let files = vec![
            ("file1.rs".to_string(), "content1".to_string()),
            ("file2.rs".to_string(), "content2".to_string()),
        ];

        let estimate = counter.estimate_total_tokens(&files).unwrap();
        assert!(estimate.total_tokens > 0);
        assert_eq!(estimate.file_counts.len(), 2);
    }
}