use crate::Result;
use rayon::prelude::*;
use std::sync::Arc;
use tiktoken_rs::{get_bpe_from_model, CoreBPE};

const DEFAULT_MODEL: &str = "gpt-4o";

/// Counts tokens for text using the BPE encoding of a given model.
pub struct TokenCounter {
    bpe: Arc<CoreBPE>,
    model_name: String,
}

impl Default for TokenCounter {
    fn default() -> Self {
        Self::from_model(DEFAULT_MODEL).expect("Failed to create default tokenizer")
    }
}

impl TokenCounter {
    pub fn new() -> Result<Self> {
        Self::from_model(DEFAULT_MODEL)
    }

    /// Builds a counter for any model name recognized by tiktoken-rs.
    pub fn from_model(model_name: &str) -> Result<Self> {
        let bpe = get_bpe_from_model(model_name)
            .map_err(|e| anyhow::anyhow!("Failed to get BPE for model {}: {}", model_name, e))?;

        Ok(Self {
            bpe: Arc::new(bpe),
            model_name: model_name.to_string(),
        })
    }

    pub fn model_name(&self) -> &str {
        &self.model_name
    }

    /// Counts tokens in `content`; special tokens are encoded rather than rejected.
    pub fn count_tokens(&self, content: &str) -> usize {
        self.bpe.encode_with_special_tokens(content).len()
    }

    /// Counts tokens for each item of a parallel iterator, preserving input order.
    pub fn count_tokens_parallel<'a, I>(&self, contents: I) -> Vec<usize>
    where
        I: ParallelIterator<Item = &'a str>,
    {
        // Clone the Arc so the closure owns a handle it can share across
        // rayon's worker threads.
        let bpe = Arc::clone(&self.bpe);
        contents
            .map(|content| bpe.encode_with_special_tokens(content).len())
            .collect()
    }

    /// Sums token counts across a batch and returns a rough cost estimate.
    pub fn analyze_batch<'a, I>(&self, contents: I) -> (usize, f64)
    where
        I: ParallelIterator<Item = &'a str>,
    {
        let bpe = Arc::clone(&self.bpe);
        let total_tokens: usize = contents
            .map(|content| bpe.encode_with_special_tokens(content).len())
            .sum();

        // Assumed flat rate in dollars per 1K tokens; adjust to the target
        // model's actual pricing.
        const COST_PER_1K_TOKENS: f64 = 0.003;
        let estimated_cost = (total_tokens as f64 / 1000.0) * COST_PER_1K_TOKENS;

        (total_tokens, estimated_cost)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_tokenizer() {
        let tokenizer = TokenCounter::default();
        let result = tokenizer.count_tokens("Hello, world!");
        assert!(result > 0);
    }

    #[test]
    fn test_custom_model_tokenizer() -> Result<()> {
        let tokenizer = TokenCounter::from_model("gpt-3.5-turbo")?;
        assert_eq!(tokenizer.model_name(), "gpt-3.5-turbo");
        Ok(())
    }

    #[test]
    fn test_parallel_tokenization() {
        let tokenizer = TokenCounter::default();
        let texts = vec!["Hello", "World", "Test"];
        let counts = tokenizer.count_tokens_parallel(texts.par_iter().map(|&s| s));
        assert_eq!(counts.len(), 3);
        assert!(counts.iter().all(|&x| x > 0));
    }

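    // Added sketch of an edge-case check: a BPE encoding of the empty string
    // should contain zero tokens, and an empty batch should produce an empty
    // count vector (assumptions, not documented guarantees of tiktoken-rs).
    #[test]
    fn test_empty_input() {
        let tokenizer = TokenCounter::default();
        assert_eq!(tokenizer.count_tokens(""), 0);

        let texts: Vec<&str> = Vec::new();
        let counts = tokenizer.count_tokens_parallel(texts.par_iter().map(|&s| s));
        assert!(counts.is_empty());
    }
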
    #[test]
    fn test_batch_analysis() {
        let tokenizer = TokenCounter::default();
        let texts = vec!["Hello", "World", "Test"];
        let (total_tokens, cost) = tokenizer.analyze_batch(texts.par_iter().map(|&s| s));
        assert!(total_tokens > 0);
        assert!(cost > 0.0);
    }

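    // Added sketch pinning the cost estimate to the assumed per-1K-token rate
    // used in analyze_batch; update the rate here if COST_PER_1K_TOKENS changes.
    #[test]
    fn test_cost_estimate_units() {
        let tokenizer = TokenCounter::default();
        let (total_tokens, cost) =
            tokenizer.analyze_batch(vec!["Hello, world!"].par_iter().map(|&s| s));
        let expected = (total_tokens as f64 / 1000.0) * 0.003;
        assert!((cost - expected).abs() < f64::EPSILON);
    }
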
    #[test]
    fn test_consistency() {
        let tokenizer = TokenCounter::default();
        let text = "Hello, world!";

        let single_count = tokenizer.count_tokens(text);
        let parallel_count = tokenizer.count_tokens_parallel(vec![text].par_iter().map(|&s| s));
        let (batch_count, _) = tokenizer.analyze_batch(vec![text].par_iter().map(|&s| s));

        assert_eq!(single_count, parallel_count[0]);
        assert_eq!(single_count, batch_count);
    }
}