context_creator/core/token.rs

//! Token counting functionality using tiktoken-rs

use anyhow::Result;
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tiktoken_rs::{cl100k_base, CoreBPE};

/// Token counter with caching support
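///
/// # Example
///
/// A minimal usage sketch; the import path is assumed from this file's
/// location, so the doctest is marked `ignore`:
///
/// ```ignore
/// use context_creator::core::token::TokenCounter;
///
/// let counter = TokenCounter::new().unwrap();
/// assert!(counter.count_tokens("Hello, world!").unwrap() > 0);
/// ```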
pub struct TokenCounter {
    /// The tiktoken encoder
    encoder: Arc<CoreBPE>,
    /// Cache of token counts for content hashes
    cache: Arc<Mutex<HashMap<u64, usize>>>,
}

impl TokenCounter {
    /// Create a new token counter with cl100k_base encoding (GPT-4)
    pub fn new() -> Result<Self> {
        let encoder = cl100k_base()?;
        Ok(TokenCounter {
            encoder: Arc::new(encoder),
            cache: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Count tokens in a single text
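    ///
    /// Results are memoized by content hash, so repeated calls with the same
    /// text are served from the cache. A sketch, given a `counter` built with
    /// `TokenCounter::new()`:
    ///
    /// ```ignore
    /// let first = counter.count_tokens("some text").unwrap();
    /// let second = counter.count_tokens("some text").unwrap(); // cache hit
    /// assert_eq!(first, second);
    /// ```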
    pub fn count_tokens(&self, text: &str) -> Result<usize> {
        // Calculate hash for caching
        let hash = calculate_hash(text);

        // Check cache first; if the lock is poisoned (a writer panicked),
        // fall through and recount instead of propagating the panic
        if let Ok(cache) = self.cache.lock() {
            if let Some(&count) = cache.get(&hash) {
                return Ok(count);
            }
        }

        // Count tokens
        let tokens = self.encoder.encode_with_special_tokens(text);
        let count = tokens.len();

        // Store in cache
        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(hash, count);
        }

        Ok(count)
    }

    /// Count tokens in multiple texts in parallel
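    ///
    /// Counting is spread across the rayon thread pool; the shared cache is
    /// still consulted per text. A sketch, given a `counter` from
    /// `TokenCounter::new()`:
    ///
    /// ```ignore
    /// let texts = vec!["first".to_string(), "second".to_string()];
    /// let counts = counter.count_tokens_parallel(&texts).unwrap();
    /// assert_eq!(counts.len(), 2);
    /// ```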
    pub fn count_tokens_parallel(&self, texts: &[String]) -> Result<Vec<usize>> {
        texts
            .par_iter()
            .map(|text| self.count_tokens(text))
            .collect()
    }

    /// Count tokens for a file's content with metadata
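    ///
    /// The overhead figure covers the markdown header and code fence that
    /// wrap the file in the generated output. A sketch, given a `counter`
    /// from `TokenCounter::new()`:
    ///
    /// ```ignore
    /// let count = counter.count_file_tokens("fn main() {}", "src/main.rs").unwrap();
    /// assert_eq!(count.total_tokens, count.content_tokens + count.overhead_tokens);
    /// ```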
    pub fn count_file_tokens(&self, content: &str, path: &str) -> Result<FileTokenCount> {
        let content_tokens = self.count_tokens(content)?;

        // Count tokens in the file path/header that will be included in markdown
        let header = format!("## {path}\n\n```\n");
        let footer = "\n```\n\n";
        let header_tokens = self.count_tokens(&header)?;
        let footer_tokens = self.count_tokens(footer)?;

        Ok(FileTokenCount {
            content_tokens,
            overhead_tokens: header_tokens + footer_tokens,
            total_tokens: content_tokens + header_tokens + footer_tokens,
        })
    }

    /// Estimate tokens for multiple files
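    ///
    /// A sketch, given a `counter` from `TokenCounter::new()`; `files` pairs
    /// each path with its content:
    ///
    /// ```ignore
    /// let files = vec![("main.rs".to_string(), "fn main() {}".to_string())];
    /// let estimate = counter.estimate_total_tokens(&files).unwrap();
    /// assert_eq!(estimate.total_tokens, estimate.content_tokens + estimate.overhead_tokens);
    /// ```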
    pub fn estimate_total_tokens(&self, files: &[(String, String)]) -> Result<TotalTokenEstimate> {
        let mut total_content = 0;
        let mut total_overhead = 0;
        let mut file_counts = Vec::new();

        for (path, content) in files {
            let count = self.count_file_tokens(content, path)?;
            total_content += count.content_tokens;
            total_overhead += count.overhead_tokens;
            file_counts.push((path.clone(), count));
        }

        Ok(TotalTokenEstimate {
            total_tokens: total_content + total_overhead,
            content_tokens: total_content,
            overhead_tokens: total_overhead,
            file_counts,
        })
    }
}

impl Default for TokenCounter {
    fn default() -> Self {
        // Panics if the cl100k_base encoder data cannot be loaded
        Self::new().expect("Failed to create token counter")
    }
}

/// Token count for a single file
#[derive(Debug, Clone)]
pub struct FileTokenCount {
    /// Tokens in the file content
    pub content_tokens: usize,
    /// Tokens in markdown formatting overhead
    pub overhead_tokens: usize,
    /// Total tokens (content + overhead)
    pub total_tokens: usize,
}

/// Total token estimate for multiple files
#[derive(Debug)]
pub struct TotalTokenEstimate {
    /// Total tokens across all files
    pub total_tokens: usize,
    /// Total content tokens
    pub content_tokens: usize,
    /// Total overhead tokens
    pub overhead_tokens: usize,
    /// Individual file counts
    pub file_counts: Vec<(String, FileTokenCount)>,
}

/// Calculate a hash for content caching
fn calculate_hash(text: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // DefaultHasher output is not guaranteed stable across Rust releases,
    // which is fine here: the cache only lives in memory for one process
    let mut hasher = DefaultHasher::new();
    text.hash(&mut hasher);
    hasher.finish()
}

/// Check if adding a file would exceed the token limit
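///
/// A sketch (the values mirror the unit tests below):
///
/// ```ignore
/// assert!(would_exceed_limit(900, 200, 1000));
/// assert!(!would_exceed_limit(800, 200, 1000));
/// ```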
pub fn would_exceed_limit(current_tokens: usize, file_tokens: usize, max_tokens: usize) -> bool {
    // Saturating add avoids overflow on pathological inputs
    current_tokens.saturating_add(file_tokens) > max_tokens
}

/// Calculate the remaining token budget
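///
/// Saturates at zero once the budget is spent:
///
/// ```ignore
/// assert_eq!(remaining_tokens(300, 1000), 700);
/// assert_eq!(remaining_tokens(1100, 1000), 0);
/// ```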
pub fn remaining_tokens(current_tokens: usize, max_tokens: usize) -> usize {
    max_tokens.saturating_sub(current_tokens)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counting() {
        let counter = TokenCounter::new().unwrap();

        // Test simple text
        let count = counter.count_tokens("Hello, world!").unwrap();
        assert!(count > 0);

        // Test empty text
        let count = counter.count_tokens("").unwrap();
        assert_eq!(count, 0);

        // Test caching
        let text = "This is a test text for caching";
        let count1 = counter.count_tokens(text).unwrap();
        let count2 = counter.count_tokens(text).unwrap();
        assert_eq!(count1, count2);
    }

    #[test]
    fn test_file_token_counting() {
        let counter = TokenCounter::new().unwrap();

        let content = "fn main() {\n    println!(\"Hello, world!\");\n}";
        let path = "src/main.rs";

        let count = counter.count_file_tokens(content, path).unwrap();
        assert!(count.content_tokens > 0);
        assert!(count.overhead_tokens > 0);
        assert_eq!(
            count.total_tokens,
            count.content_tokens + count.overhead_tokens
        );
    }

    #[test]
    fn test_parallel_counting() {
        let counter = TokenCounter::new().unwrap();

        let texts = vec![
            "First text".to_string(),
            "Second text".to_string(),
            "Third text".to_string(),
        ];

        let counts = counter.count_tokens_parallel(&texts).unwrap();
        assert_eq!(counts.len(), 3);
        assert!(counts.iter().all(|&c| c > 0));
    }

    #[test]
    fn test_token_limit_checks() {
        assert!(would_exceed_limit(900, 200, 1000));
        assert!(!would_exceed_limit(800, 200, 1000));

        assert_eq!(remaining_tokens(300, 1000), 700);
        assert_eq!(remaining_tokens(1100, 1000), 0);
    }

    #[test]
    fn test_total_estimation() {
        let counter = TokenCounter::new().unwrap();

        let files = vec![
            ("file1.rs".to_string(), "content1".to_string()),
            ("file2.rs".to_string(), "content2".to_string()),
        ];

        let estimate = counter.estimate_total_tokens(&files).unwrap();
        assert!(estimate.total_tokens > 0);
        assert_eq!(estimate.file_counts.len(), 2);
    }
}