code_digest/core/token.rs

use anyhow::Result;
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tiktoken_rs::{cl100k_base, CoreBPE};

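/// Counts tokens using the OpenAI `cl100k_base` BPE encoding from
/// `tiktoken-rs`, memoizing counts in a mutex-guarded map keyed by a hash of
/// the input text so repeated texts are only encoded once.
///
/// A minimal usage sketch (marked `ignore` since the enclosing module path is
/// assumed):
///
/// ```ignore
/// let counter = TokenCounter::new()?;
/// let tokens = counter.count_tokens("Hello, world!")?;
/// ```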
pub struct TokenCounter {
    encoder: Arc<CoreBPE>,
    cache: Arc<Mutex<HashMap<u64, usize>>>,
}

impl TokenCounter {
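    /// Creates a counter with a freshly loaded `cl100k_base` encoder and an
    /// empty cache. Returns an error if the encoder cannot be constructed.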
    pub fn new() -> Result<Self> {
        let encoder = cl100k_base()?;
        Ok(TokenCounter {
            encoder: Arc::new(encoder),
            cache: Arc::new(Mutex::new(HashMap::new())),
        })
    }

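    /// Returns the number of tokens in `text`, consulting the cache first.
    /// A poisoned cache lock is treated as a cache miss rather than an error.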
    pub fn count_tokens(&self, text: &str) -> Result<usize> {
        let hash = calculate_hash(text);

        if let Ok(cache) = self.cache.lock() {
            if let Some(&count) = cache.get(&hash) {
                return Ok(count);
            }
        }

        let tokens = self.encoder.encode_with_special_tokens(text);
        let count = tokens.len();

        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(hash, count);
        }

        Ok(count)
    }

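    /// Counts tokens for each text in parallel with rayon, returning counts
    /// in input order. Each lookup still goes through the shared cache.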
    pub fn count_tokens_parallel(&self, texts: &[String]) -> Result<Vec<usize>> {
        texts.par_iter().map(|text| self.count_tokens(text)).collect()
    }

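    /// Counts the tokens a file contributes to a digest: the content itself
    /// plus the markdown heading and code-fence wrapper around it.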
    pub fn count_file_tokens(&self, content: &str, path: &str) -> Result<FileTokenCount> {
        let content_tokens = self.count_tokens(content)?;

        let header = format!("## {path}\n\n```\n");
        let footer = "\n```\n\n";
        let header_tokens = self.count_tokens(&header)?;
        let footer_tokens = self.count_tokens(footer)?;

        Ok(FileTokenCount {
            content_tokens,
            overhead_tokens: header_tokens + footer_tokens,
            total_tokens: content_tokens + header_tokens + footer_tokens,
        })
    }

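    /// Sums per-file token counts over `(path, content)` pairs, returning the
    /// aggregate totals alongside the per-file breakdown.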
    pub fn estimate_total_tokens(&self, files: &[(String, String)]) -> Result<TotalTokenEstimate> {
        let mut total_content = 0;
        let mut total_overhead = 0;
        let mut file_counts = Vec::new();

        for (path, content) in files {
            let count = self.count_file_tokens(content, path)?;
            total_content += count.content_tokens;
            total_overhead += count.overhead_tokens;
            file_counts.push((path.clone(), count));
        }

        Ok(TotalTokenEstimate {
            total_tokens: total_content + total_overhead,
            content_tokens: total_content,
            overhead_tokens: total_overhead,
            file_counts,
        })
    }
}

impl Default for TokenCounter {
    fn default() -> Self {
        Self::new().expect("Failed to create token counter")
    }
}

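/// Token counts for a single file as it appears in the digest output.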
#[derive(Debug, Clone)]
pub struct FileTokenCount {
    /// Tokens in the file content itself.
    pub content_tokens: usize,
    /// Tokens consumed by the markdown heading and code-fence wrapper.
    pub overhead_tokens: usize,
    /// Content plus overhead.
    pub total_tokens: usize,
}

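/// Aggregate token counts across a set of files.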
#[derive(Debug)]
pub struct TotalTokenEstimate {
    /// Content plus overhead across all files.
    pub total_tokens: usize,
    /// Tokens in file contents across all files.
    pub content_tokens: usize,
    /// Heading and fence wrapper tokens across all files.
    pub overhead_tokens: usize,
    /// Per-file counts, paired with each file's path.
    pub file_counts: Vec<(String, FileTokenCount)>,
}

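/// Hashes `text` with the standard library's `DefaultHasher` to form a cache
/// key. Distinct texts can in principle collide on the same `u64`, in which
/// case the first cached count would be reused; the hash is never persisted,
/// so `DefaultHasher`'s instability across Rust releases is harmless here.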
fn calculate_hash(text: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    text.hash(&mut hasher);
    hasher.finish()
}

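/// Returns `true` if adding `file_tokens` to `current_tokens` would exceed
/// `max_tokens`.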
pub fn would_exceed_limit(current_tokens: usize, file_tokens: usize, max_tokens: usize) -> bool {
    current_tokens + file_tokens > max_tokens
}

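/// Returns how many tokens remain under `max_tokens`, saturating at zero when
/// the budget is already spent.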
pub fn remaining_tokens(current_tokens: usize, max_tokens: usize) -> usize {
    max_tokens.saturating_sub(current_tokens)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counting() {
        let counter = TokenCounter::new().unwrap();

        let count = counter.count_tokens("Hello, world!").unwrap();
        assert!(count > 0);

        let count = counter.count_tokens("").unwrap();
        assert_eq!(count, 0);

        let text = "This is a test text for caching";
        let count1 = counter.count_tokens(text).unwrap();
        let count2 = counter.count_tokens(text).unwrap();
        assert_eq!(count1, count2);
    }

    #[test]
    fn test_file_token_counting() {
        let counter = TokenCounter::new().unwrap();

        let content = "fn main() {\n println!(\"Hello, world!\");\n}";
        let path = "src/main.rs";

        let count = counter.count_file_tokens(content, path).unwrap();
        assert!(count.content_tokens > 0);
        assert!(count.overhead_tokens > 0);
        assert_eq!(count.total_tokens, count.content_tokens + count.overhead_tokens);
    }

    #[test]
    fn test_parallel_counting() {
        let counter = TokenCounter::new().unwrap();

        let texts =
            vec!["First text".to_string(), "Second text".to_string(), "Third text".to_string()];

        let counts = counter.count_tokens_parallel(&texts).unwrap();
        assert_eq!(counts.len(), 3);
        assert!(counts.iter().all(|&c| c > 0));
    }

    #[test]
    fn test_token_limit_checks() {
        assert!(would_exceed_limit(900, 200, 1000));
        assert!(!would_exceed_limit(800, 200, 1000));

        assert_eq!(remaining_tokens(300, 1000), 700);
        assert_eq!(remaining_tokens(1100, 1000), 0);
    }

    #[test]
    fn test_total_estimation() {
        let counter = TokenCounter::new().unwrap();

        let files = vec![
            ("file1.rs".to_string(), "content1".to_string()),
            ("file2.rs".to_string(), "content2".to_string()),
        ];

        let estimate = counter.estimate_total_tokens(&files).unwrap();
        assert!(estimate.total_tokens > 0);
        assert_eq!(estimate.file_counts.len(), 2);
    }
}