code_digest/core/token.rs

//! Token counting functionality using tiktoken-rs
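//!
//! A minimal usage sketch (error handling context and crate paths assumed; adjust
//! to the surrounding crate layout):
//!
//! ```ignore
//! let counter = TokenCounter::new()?;
//! let tokens = counter.count_tokens("fn main() {}")?;
//! println!("{tokens} tokens");
//! ```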

use anyhow::Result;
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tiktoken_rs::{cl100k_base, CoreBPE};

/// Token counter with caching support
pub struct TokenCounter {
    /// The tiktoken encoder
    encoder: Arc<CoreBPE>,
    /// Cache of token counts for content hashes
    cache: Arc<Mutex<HashMap<u64, usize>>>,
}

impl TokenCounter {
    /// Create a new token counter with cl100k_base encoding (GPT-4)
    pub fn new() -> Result<Self> {
        let encoder = cl100k_base()?;
        Ok(TokenCounter {
            encoder: Arc::new(encoder),
            cache: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Count tokens in a single text
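    ///
    /// Results are memoized by a hash of the input text, so repeated calls with
    /// identical content are answered from the in-memory cache.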
    pub fn count_tokens(&self, text: &str) -> Result<usize> {
        // Calculate hash for caching
        let hash = calculate_hash(text);

        // Check cache first
        if let Ok(cache) = self.cache.lock() {
            if let Some(&count) = cache.get(&hash) {
                return Ok(count);
            }
        }

        // Count tokens
        let tokens = self.encoder.encode_with_special_tokens(text);
        let count = tokens.len();

        // Store in cache
        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(hash, count);
        }

        Ok(count)
    }

    /// Count tokens in multiple texts in parallel
    pub fn count_tokens_parallel(&self, texts: &[String]) -> Result<Vec<usize>> {
        texts.par_iter().map(|text| self.count_tokens(text)).collect()
    }

    /// Count tokens for a file's content with metadata
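    ///
    /// Overhead accounts for the markdown wrapper emitted around each file: a
    /// `## <path>` heading plus the opening and closing code fences.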
    pub fn count_file_tokens(&self, content: &str, path: &str) -> Result<FileTokenCount> {
        let content_tokens = self.count_tokens(content)?;

        // Count tokens in the file path/header that will be included in markdown
        let header = format!("## {path}\n\n```\n");
        let footer = "\n```\n\n";
        let header_tokens = self.count_tokens(&header)?;
        let footer_tokens = self.count_tokens(footer)?;

        Ok(FileTokenCount {
            content_tokens,
            overhead_tokens: header_tokens + footer_tokens,
            total_tokens: content_tokens + header_tokens + footer_tokens,
        })
    }

    /// Estimate tokens for multiple files
    pub fn estimate_total_tokens(&self, files: &[(String, String)]) -> Result<TotalTokenEstimate> {
        let mut total_content = 0;
        let mut total_overhead = 0;
        let mut file_counts = Vec::new();

        for (path, content) in files {
            let count = self.count_file_tokens(content, path)?;
            total_content += count.content_tokens;
            total_overhead += count.overhead_tokens;
            file_counts.push((path.clone(), count));
        }

        Ok(TotalTokenEstimate {
            total_tokens: total_content + total_overhead,
            content_tokens: total_content,
            overhead_tokens: total_overhead,
            file_counts,
        })
    }
}

impl Default for TokenCounter {
    fn default() -> Self {
        Self::new().expect("Failed to create token counter")
    }
}

/// Token count for a single file
#[derive(Debug, Clone)]
pub struct FileTokenCount {
    /// Tokens in the file content
    pub content_tokens: usize,
    /// Tokens in markdown formatting overhead
    pub overhead_tokens: usize,
    /// Total tokens (content + overhead)
    pub total_tokens: usize,
}

/// Total token estimate for multiple files
#[derive(Debug)]
pub struct TotalTokenEstimate {
    /// Total tokens across all files
    pub total_tokens: usize,
    /// Total content tokens
    pub content_tokens: usize,
    /// Total overhead tokens
    pub overhead_tokens: usize,
    /// Individual file counts
    pub file_counts: Vec<(String, FileTokenCount)>,
}

/// Calculate a hash for content caching
fn calculate_hash(text: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

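    // DefaultHasher's algorithm is not guaranteed to stay stable across Rust
    // releases, which is fine here: the hash only keys an in-process cache and
    // is never persisted.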
    let mut hasher = DefaultHasher::new();
    text.hash(&mut hasher);
    hasher.finish()
}

/// Check if adding a file would exceed token limit
pub fn would_exceed_limit(current_tokens: usize, file_tokens: usize, max_tokens: usize) -> bool {
    current_tokens + file_tokens > max_tokens
}

/// Calculate remaining token budget
pub fn remaining_tokens(current_tokens: usize, max_tokens: usize) -> usize {
    max_tokens.saturating_sub(current_tokens)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counting() {
        let counter = TokenCounter::new().unwrap();

        // Test simple text
        let count = counter.count_tokens("Hello, world!").unwrap();
        assert!(count > 0);

        // Test empty text
        let count = counter.count_tokens("").unwrap();
        assert_eq!(count, 0);

        // Test caching
        let text = "This is a test text for caching";
        let count1 = counter.count_tokens(text).unwrap();
        let count2 = counter.count_tokens(text).unwrap();
        assert_eq!(count1, count2);
    }

    #[test]
    fn test_file_token_counting() {
        let counter = TokenCounter::new().unwrap();

        let content = "fn main() {\n    println!(\"Hello, world!\");\n}";
        let path = "src/main.rs";

        let count = counter.count_file_tokens(content, path).unwrap();
        assert!(count.content_tokens > 0);
        assert!(count.overhead_tokens > 0);
        assert_eq!(count.total_tokens, count.content_tokens + count.overhead_tokens);
    }

    #[test]
    fn test_parallel_counting() {
        let counter = TokenCounter::new().unwrap();

        let texts =
            vec!["First text".to_string(), "Second text".to_string(), "Third text".to_string()];

        let counts = counter.count_tokens_parallel(&texts).unwrap();
        assert_eq!(counts.len(), 3);
        assert!(counts.iter().all(|&c| c > 0));
    }

    #[test]
    fn test_token_limit_checks() {
        assert!(would_exceed_limit(900, 200, 1000));
        assert!(!would_exceed_limit(800, 200, 1000));

        assert_eq!(remaining_tokens(300, 1000), 700);
        assert_eq!(remaining_tokens(1100, 1000), 0);
    }

    #[test]
    fn test_total_estimation() {
        let counter = TokenCounter::new().unwrap();

        let files = vec![
            ("file1.rs".to_string(), "content1".to_string()),
            ("file2.rs".to_string(), "content2".to_string()),
        ];

        let estimate = counter.estimate_total_tokens(&files).unwrap();
        assert!(estimate.total_tokens > 0);
        assert_eq!(estimate.file_counts.len(), 2);
    }
}