// git_rune/tokenizer.rs

1use crate::Result;
2use rayon::prelude::*;
3use std::sync::Arc;
4use tiktoken_rs::{get_bpe_from_model, CoreBPE};
5
/// Model whose tokenizer is used by `TokenCounter::new` / `Default`.
const DEFAULT_MODEL: &str = "gpt-4o";
7
/// Counts tokens using a tiktoken BPE tokenizer for a specific model.
pub struct TokenCounter {
    // Shared via `Arc` so rayon worker closures can use the same tokenizer.
    bpe: Arc<CoreBPE>,
    // Model name the BPE was loaded for; exposed through `model_name()`.
    model_name: String,
}
12
impl Default for TokenCounter {
    /// Builds a counter for `DEFAULT_MODEL`.
    ///
    /// Panics if the default model's BPE data cannot be loaded; callers that
    /// need to handle that failure should use `TokenCounter::new()` instead.
    fn default() -> Self {
        Self::from_model(DEFAULT_MODEL).expect("Failed to create default tokenizer")
    }
}
18
19impl TokenCounter {
20    pub fn new() -> Result<Self> {
21        Self::from_model(DEFAULT_MODEL)
22    }
23
24    pub fn from_model(model_name: &str) -> Result<Self> {
25        let bpe = get_bpe_from_model(model_name)
26            .map_err(|e| anyhow::anyhow!("Failed to get BPE for model {}: {}", model_name, e))?;
27
28        Ok(Self {
29            bpe: Arc::new(bpe),
30            model_name: model_name.to_string(),
31        })
32    }
33
34    pub fn model_name(&self) -> &str {
35        &self.model_name
36    }
37
38    pub fn count_tokens(&self, content: &str) -> usize {
39        self.bpe.encode_with_special_tokens(content).len()
40    }
41
42    pub fn count_tokens_parallel<'a, I>(&self, contents: I) -> Vec<usize>
43    where
44        I: ParallelIterator<Item = &'a str>,
45    {
46        let bpe = Arc::clone(&self.bpe);
47        contents
48            .map(|content| bpe.encode_with_special_tokens(content).len())
49            .collect()
50    }
51
52    pub fn analyze_batch<'a, I>(&self, contents: I) -> (usize, f64)
53    where
54        I: ParallelIterator<Item = &'a str>,
55    {
56        let bpe = Arc::clone(&self.bpe);
57        let total_tokens: usize = contents
58            .map(|content| bpe.encode_with_special_tokens(content).len())
59            .sum();
60
61        // claude 3.5 sonnet pricing ($3/mtok)
62        let estimated_cost = (total_tokens as f64) * 0.003;
63
64        (total_tokens, estimated_cost)
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_tokenizer() {
        let counter = TokenCounter::default();
        assert!(counter.count_tokens("Hello, world!") > 0);
    }

    #[test]
    fn test_custom_model_tokenizer() -> Result<()> {
        let counter = TokenCounter::from_model("gpt-3.5-turbo")?;
        assert_eq!(counter.model_name(), "gpt-3.5-turbo");
        Ok(())
    }

    #[test]
    fn test_parallel_tokenization() {
        let counter = TokenCounter::default();
        let inputs = vec!["Hello", "World", "Test"];
        let counts = counter.count_tokens_parallel(inputs.par_iter().copied());
        assert_eq!(counts.len(), 3);
        assert!(counts.iter().all(|&c| c > 0));
    }

    #[test]
    fn test_batch_analysis() {
        let counter = TokenCounter::default();
        let inputs = vec!["Hello", "World", "Test"];
        let (total_tokens, cost) = counter.analyze_batch(inputs.par_iter().copied());
        assert!(total_tokens > 0);
        assert!(cost > 0.0);
    }

    // The three counting entry points must agree on the same input.
    #[test]
    fn test_consistency() {
        let counter = TokenCounter::default();
        let sample = "Hello, world!";

        let direct = counter.count_tokens(sample);
        let via_parallel = counter.count_tokens_parallel(vec![sample].par_iter().copied());
        let (via_batch, _) = counter.analyze_batch(vec![sample].par_iter().copied());

        assert_eq!(direct, via_parallel[0]);
        assert_eq!(direct, via_batch);
    }
}