
oxideshield_core/tokenizer.rs

//! Tokenizer integration using tiktoken-rs

use crate::{Error, Result};
use tiktoken_rs::{cl100k_base, o200k_base, p50k_base, CoreBPE};
use tracing::debug;

/// Supported tokenizer models
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TokenizerModel {
    /// GPT-4, GPT-3.5-turbo, text-embedding-ada-002
    #[default]
    Cl100kBase,
    /// GPT-4o, GPT-4o-mini
    O200kBase,
    /// GPT-3 models (text-davinci-003, etc.)
    P50kBase,
}

impl std::str::FromStr for TokenizerModel {
    type Err = String;

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "cl100k_base" | "cl100k" | "gpt-4" | "gpt-3.5-turbo" => Ok(TokenizerModel::Cl100kBase),
            "o200k_base" | "o200k" | "gpt-4o" | "gpt-4o-mini" => Ok(TokenizerModel::O200kBase),
            "p50k_base" | "p50k" | "gpt-3" => Ok(TokenizerModel::P50kBase),
            _ => Err(format!("Unknown tokenizer model: {}", s)),
        }
    }
}
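
// Illustrative sketch (not part of the original file): `FromStr` lets a model
// name from configuration be parsed directly, falling back to the default
// tokenizer when the value is missing or unrecognized. The helper name below
// is hypothetical.
#[allow(dead_code)]
fn model_from_config(name: Option<&str>) -> TokenizerModel {
    name.and_then(|s| s.parse().ok()).unwrap_or_default()
}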

/// Tokenizer for counting and encoding tokens
pub struct Tokenizer {
    bpe: CoreBPE,
    model: TokenizerModel,
}

impl std::fmt::Debug for Tokenizer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Tokenizer")
            .field("model", &self.model)
            .finish()
    }
}

impl Tokenizer {
    /// Create a new tokenizer with the specified model
    pub fn new(model: TokenizerModel) -> Result<Self> {
        let bpe = match model {
            TokenizerModel::Cl100kBase => {
                cl100k_base().map_err(|e| Error::Tokenizer(e.to_string()))?
            }
            TokenizerModel::O200kBase => {
                o200k_base().map_err(|e| Error::Tokenizer(e.to_string()))?
            }
            TokenizerModel::P50kBase => p50k_base().map_err(|e| Error::Tokenizer(e.to_string()))?,
        };

        debug!("Created tokenizer with model {:?}", model);

        Ok(Self { bpe, model })
    }

    /// Create a tokenizer with the default model (cl100k_base)
    pub fn default_model() -> Result<Self> {
        Self::new(TokenizerModel::default())
    }

    /// Get the tokenizer model
    pub fn model(&self) -> TokenizerModel {
        self.model
    }

    /// Count the number of tokens in the text
    pub fn count_tokens(&self, text: &str) -> usize {
        self.bpe.encode_ordinary(text).len()
    }

    /// Encode text into token IDs
    pub fn encode(&self, text: &str) -> Vec<u32> {
        self.bpe.encode_ordinary(text)
    }

    /// Decode token IDs back to text
    pub fn decode(&self, tokens: &[u32]) -> Result<String> {
        self.bpe
            .decode(tokens.to_vec())
            .map_err(|e| Error::Tokenizer(e.to_string()))
    }

    /// Truncate text to a maximum number of tokens
    pub fn truncate(&self, text: &str, max_tokens: usize) -> Result<String> {
        let tokens = self.encode(text);
        if tokens.len() <= max_tokens {
            return Ok(text.to_string());
        }

        let truncated = &tokens[..max_tokens];
        self.decode(truncated)
    }

    /// Check if text exceeds a token limit
    pub fn exceeds_limit(&self, text: &str, limit: usize) -> bool {
        self.count_tokens(text) > limit
    }
}
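
// Sketch of intended usage (an assumption, not code from the original file):
// keep a prompt within a token budget, truncating only when it exceeds the
// limit. Note that cutting at an arbitrary token boundary can in principle
// split a multi-byte character, in which case `decode` (and thus `truncate`)
// returns an error the caller should handle. The helper name is hypothetical.
#[allow(dead_code)]
fn enforce_budget(tokenizer: &Tokenizer, prompt: &str, budget: usize) -> Result<String> {
    if tokenizer.exceeds_limit(prompt, budget) {
        tokenizer.truncate(prompt, budget)
    } else {
        Ok(prompt.to_string())
    }
}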

impl Default for Tokenizer {
    fn default() -> Self {
        Self::new(TokenizerModel::default()).expect("Failed to create default tokenizer")
    }
}

/// Token statistics for text analysis
#[derive(Debug, Clone, Default)]
pub struct TokenStats {
    /// Total token count
    pub total_tokens: usize,
    /// Average tokens per line
    pub avg_tokens_per_line: f64,
    /// Maximum tokens in a single line
    pub max_line_tokens: usize,
    /// Number of lines
    pub line_count: usize,
}

impl Tokenizer {
    /// Calculate token statistics for text
    pub fn stats(&self, text: &str) -> TokenStats {
        let lines: Vec<&str> = text.lines().collect();
        let line_count = lines.len();

        if line_count == 0 {
            return TokenStats::default();
        }

        let mut total_tokens = 0;
        let mut max_line_tokens = 0;

        for line in &lines {
            let tokens = self.count_tokens(line);
            total_tokens += tokens;
            max_line_tokens = max_line_tokens.max(tokens);
        }

        TokenStats {
            total_tokens,
            avg_tokens_per_line: total_tokens as f64 / line_count as f64,
            max_line_tokens,
            line_count,
        }
    }
}
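
// Illustrative sketch (assumption, not original code): TokenStats makes it
// easy to flag inputs where a single line dominates the token count, e.g. a
// pasted blob. The threshold and helper name are hypothetical.
#[allow(dead_code)]
fn has_dominant_line(stats: &TokenStats, threshold: f64) -> bool {
    stats.total_tokens > 0
        && stats.max_line_tokens as f64 / stats.total_tokens as f64 > threshold
}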

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counting() {
        let tokenizer = Tokenizer::default();
        let count = tokenizer.count_tokens("Hello, world!");
        assert!(count > 0);
        assert!(count < 10); // Should be around 3-4 tokens
    }

    #[test]
    fn test_encode_decode() {
        let tokenizer = Tokenizer::default();
        let text = "Hello, world!";
        let tokens = tokenizer.encode(text);
        let decoded = tokenizer.decode(&tokens).unwrap();
        assert_eq!(decoded, text);
    }

    #[test]
    fn test_truncate() {
        let tokenizer = Tokenizer::default();
        let text = "This is a longer text that should be truncated to fewer tokens.";
        let truncated = tokenizer.truncate(text, 5).unwrap();
        assert!(tokenizer.count_tokens(&truncated) <= 5);
    }

    #[test]
    fn test_exceeds_limit() {
        let tokenizer = Tokenizer::default();
        assert!(!tokenizer.exceeds_limit("short", 100));
        assert!(tokenizer.exceeds_limit("This is a longer text.", 2));
    }

    #[test]
    fn test_stats() {
        let tokenizer = Tokenizer::default();
        let text = "Line one\nLine two\nLine three";
        let stats = tokenizer.stats(text);

        assert_eq!(stats.line_count, 3);
        assert!(stats.total_tokens > 0);
    }
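
    // Additional sketch tests (not in the original file); they only exercise
    // behaviour already defined above.
    #[test]
    fn test_model_from_str() {
        assert_eq!(
            "gpt-4o".parse::<TokenizerModel>(),
            Ok(TokenizerModel::O200kBase)
        );
        assert!("unknown-model".parse::<TokenizerModel>().is_err());
    }

    #[test]
    fn test_stats_empty_text() {
        let tokenizer = Tokenizer::default();
        let stats = tokenizer.stats("");
        assert_eq!(stats.line_count, 0);
        assert_eq!(stats.total_tokens, 0);
    }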
}