oxideshield_core/tokenizer.rs

use crate::{Error, Result};

use tiktoken_rs::{cl100k_base, o200k_base, p50k_base, CoreBPE};
use tracing::debug;

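/// The tiktoken BPE encodings supported by [`Tokenizer`].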
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TokenizerModel {
    /// `cl100k_base` encoding (GPT-4, GPT-3.5-turbo). This is the default.
    #[default]
    Cl100kBase,
    /// `o200k_base` encoding (GPT-4o).
    O200kBase,
    /// `p50k_base` encoding (GPT-3 era models).
    P50kBase,
}

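// Accepts both encoding names ("cl100k_base") and model aliases ("gpt-4",
// "gpt-4o", "gpt-3.5-turbo", "gpt-3"); matching is case-insensitive.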
impl std::str::FromStr for TokenizerModel {
    type Err = String;

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "cl100k_base" | "cl100k" | "gpt-4" | "gpt-3.5-turbo" => Ok(TokenizerModel::Cl100kBase),
            "o200k_base" | "o200k" | "gpt-4o" => Ok(TokenizerModel::O200kBase),
            "p50k_base" | "p50k" | "gpt-3" => Ok(TokenizerModel::P50kBase),
            _ => Err(format!("Unknown tokenizer model: {}", s)),
        }
    }
}

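/// A text tokenizer backed by a `tiktoken-rs` [`CoreBPE`], used for counting,
/// encoding, decoding, and truncating text by token count.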
pub struct Tokenizer {
    /// The underlying BPE encoder.
    bpe: CoreBPE,
    /// The model this tokenizer was built for.
    model: TokenizerModel,
}

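// Manual `Debug` impl: only the model is reported; the BPE tables are omitted.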
impl std::fmt::Debug for Tokenizer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Tokenizer")
            .field("model", &self.model)
            .finish()
    }
}

impl Tokenizer {
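    /// Builds a tokenizer for the given model, loading the matching tiktoken
    /// BPE ranks. Returns [`Error::Tokenizer`] if loading fails.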
    pub fn new(model: TokenizerModel) -> Result<Self> {
        let bpe = match model {
            TokenizerModel::Cl100kBase => {
                cl100k_base().map_err(|e| Error::Tokenizer(e.to_string()))?
            }
            TokenizerModel::O200kBase => {
                o200k_base().map_err(|e| Error::Tokenizer(e.to_string()))?
            }
            TokenizerModel::P50kBase => p50k_base().map_err(|e| Error::Tokenizer(e.to_string()))?,
        };

        debug!("Created tokenizer with model {:?}", model);

        Ok(Self { bpe, model })
    }

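    /// Convenience constructor using the default model (`Cl100kBase`).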
    pub fn default_model() -> Result<Self> {
        Self::new(TokenizerModel::default())
    }

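    /// Returns the model this tokenizer was built with.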
    pub fn model(&self) -> TokenizerModel {
        self.model
    }

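    /// Counts the tokens in `text` using ordinary encoding (special tokens
    /// receive no special treatment).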
    pub fn count_tokens(&self, text: &str) -> usize {
        self.bpe.encode_ordinary(text).len()
    }

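    /// Encodes `text` into token ids.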
    pub fn encode(&self, text: &str) -> Vec<u32> {
        self.bpe.encode_ordinary(text)
    }

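    /// Decodes token ids back into a string, returning [`Error::Tokenizer`]
    /// if the tokens do not decode to valid UTF-8.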
    pub fn decode(&self, tokens: &[u32]) -> Result<String> {
        self.bpe
            .decode(tokens.to_vec())
            .map_err(|e| Error::Tokenizer(e.to_string()))
    }

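    /// Truncates `text` to at most `max_tokens` tokens. Input already within
    /// the limit is returned unchanged; otherwise the token sequence is cut
    /// and decoded, which may fail if the cut splits a multi-byte character.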
    pub fn truncate(&self, text: &str, max_tokens: usize) -> Result<String> {
        let tokens = self.encode(text);
        if tokens.len() <= max_tokens {
            return Ok(text.to_string());
        }

        let truncated = &tokens[..max_tokens];
        self.decode(truncated)
    }

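    /// Returns `true` if `text` contains more than `limit` tokens.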
    pub fn exceeds_limit(&self, text: &str, limit: usize) -> bool {
        self.count_tokens(text) > limit
    }
}

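// Note: `Default` panics if the cl100k_base data cannot be loaded; prefer
// `Tokenizer::new` when fallible construction is needed.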
impl Default for Tokenizer {
    fn default() -> Self {
        Self::new(TokenizerModel::default()).expect("Failed to create default tokenizer")
    }
}

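/// Per-line token statistics for a piece of text, as produced by
/// [`Tokenizer::stats`].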
#[derive(Debug, Clone, Default)]
pub struct TokenStats {
    /// Total tokens across all lines.
    pub total_tokens: usize,
    /// Average tokens per line.
    pub avg_tokens_per_line: f64,
    /// Highest token count on any single line.
    pub max_line_tokens: usize,
    /// Number of lines in the input.
    pub line_count: usize,
}

impl Tokenizer {
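    /// Computes per-line token statistics for `text`. Empty input yields
    /// all-zero stats.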
    pub fn stats(&self, text: &str) -> TokenStats {
        let lines: Vec<&str> = text.lines().collect();
        let line_count = lines.len();

        if line_count == 0 {
            return TokenStats::default();
        }

        let mut total_tokens = 0;
        let mut max_line_tokens = 0;

        for line in &lines {
            let tokens = self.count_tokens(line);
            total_tokens += tokens;
            max_line_tokens = max_line_tokens.max(tokens);
        }

        TokenStats {
            total_tokens,
            avg_tokens_per_line: total_tokens as f64 / line_count as f64,
            max_line_tokens,
            line_count,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counting() {
        let tokenizer = Tokenizer::default();
        let count = tokenizer.count_tokens("Hello, world!");
        assert!(count > 0);
        assert!(count < 10);
    }

    #[test]
    fn test_encode_decode() {
        let tokenizer = Tokenizer::default();
        let text = "Hello, world!";
        let tokens = tokenizer.encode(text);
        let decoded = tokenizer.decode(&tokens).unwrap();
        assert_eq!(decoded, text);
    }

    #[test]
    fn test_truncate() {
        let tokenizer = Tokenizer::default();
        let text = "This is a longer text that should be truncated to fewer tokens.";
        let truncated = tokenizer.truncate(text, 5).unwrap();
        assert!(tokenizer.count_tokens(&truncated) <= 5);
    }

    #[test]
    fn test_exceeds_limit() {
        let tokenizer = Tokenizer::default();
        assert!(!tokenizer.exceeds_limit("short", 100));
        assert!(tokenizer.exceeds_limit("This is a longer text.", 2));
    }

    #[test]
    fn test_stats() {
        let tokenizer = Tokenizer::default();
        let text = "Line one\nLine two\nLine three";
        let stats = tokenizer.stats(text);

        assert_eq!(stats.line_count, 3);
        assert!(stats.total_tokens > 0);
    }
}