// tiktoken_wasm/lib.rs

1//! WebAssembly bindings for the tiktoken BPE tokenizer.
2//!
3//! Provides browser-compatible wrappers around the core `tiktoken` crate,
4//! enabling high-performance token encoding, decoding, counting, and
5//! cost estimation directly in JavaScript/TypeScript applications.
6//!
7//! All encoding instances are cached globally via `OnceLock`, so repeated
8//! calls to `getEncoding()` with the same name return the same underlying data.
9
10use wasm_bindgen::prelude::{wasm_bindgen, JsError};
11
/// WASM wrapper around a tiktoken encoding instance.
///
/// Created via [`get_encoding`] or [`encoding_for_model`].
/// Call `.free()` when done to release WASM memory.
///
/// Both fields borrow `'static` data, so instances are cheap handles:
/// dropping one never tears down the underlying cached encoding.
#[wasm_bindgen]
pub struct Encoding {
    /// encoding name (e.g. "cl100k_base") — always a static string
    name: &'static str,
    /// reference to the globally cached CoreBpe instance (see module docs:
    /// encodings are cached via `OnceLock`, so this reference lives forever)
    bpe: &'static tiktoken::CoreBpe,
}
23
24#[wasm_bindgen]
25impl Encoding {
26    /// Encode text into token ids (returns `Uint32Array` in JS).
27    ///
28    /// Special tokens like `<|endoftext|>` are treated as ordinary text.
29    /// Use `encodeWithSpecialTokens()` to recognize them.
30    pub fn encode(&self, text: &str) -> Vec<u32> {
31        self.bpe.encode(text)
32    }
33
34    /// Encode text into token ids, recognizing special tokens.
35    ///
36    /// Special tokens (e.g. `<|endoftext|>`) are encoded as their designated ids
37    /// instead of being split into sub-word pieces.
38    #[wasm_bindgen(js_name = encodeWithSpecialTokens)]
39    pub fn encode_with_special_tokens(&self, text: &str) -> Vec<u32> {
40        self.bpe.encode_with_special_tokens(text)
41    }
42
43    /// Decode token ids back to a UTF-8 string.
44    ///
45    /// Uses lossy UTF-8 conversion — invalid byte sequences are replaced with U+FFFD.
46    pub fn decode(&self, tokens: &[u32]) -> String {
47        let bytes = self.bpe.decode(tokens);
48        String::from_utf8_lossy(&bytes).into_owned()
49    }
50
51    /// Count tokens without building the full token id array.
52    ///
53    /// Faster than `encode(text).length` for cases where you only need the count.
54    pub fn count(&self, text: &str) -> usize {
55        self.bpe.count(text)
56    }
57
58    /// Count tokens, recognizing special tokens.
59    ///
60    /// Like `count()` but special tokens (e.g. `<|endoftext|>`) are counted
61    /// as single tokens instead of being split into sub-word pieces.
62    #[wasm_bindgen(js_name = countWithSpecialTokens)]
63    pub fn count_with_special_tokens(&self, text: &str) -> usize {
64        self.bpe.count_with_special_tokens(text)
65    }
66
67    /// Get the number of regular (non-special) tokens in the vocabulary.
68    #[wasm_bindgen(js_name = vocabSize, getter)]
69    pub fn vocab_size(&self) -> usize {
70        self.bpe.vocab_size()
71    }
72
73    /// Get the number of special tokens in the vocabulary.
74    #[wasm_bindgen(js_name = numSpecialTokens, getter)]
75    pub fn num_special_tokens(&self) -> usize {
76        self.bpe.num_special_tokens()
77    }
78
79    /// Get the encoding name (e.g. `"cl100k_base"`).
80    #[wasm_bindgen(getter)]
81    pub fn name(&self) -> String {
82        self.name.to_string()
83    }
84}
85
86/// List all available encoding names.
87///
88/// Returns an array of strings: `["cl100k_base", "o200k_base", ...]`
89#[wasm_bindgen(js_name = listEncodings)]
90pub fn list_encodings() -> Vec<String> {
91    tiktoken::list_encodings()
92        .iter()
93        .map(|s| s.to_string())
94        .collect()
95}
96
97/// Get an encoding by name.
98///
99/// Supported encodings:
100/// - `"cl100k_base"` — GPT-4, GPT-3.5-turbo
101/// - `"o200k_base"` — GPT-4o, GPT-4.1, o1, o3
102/// - `"p50k_base"` — text-davinci-002/003
103/// - `"p50k_edit"` — text-davinci-edit
104/// - `"r50k_base"` — GPT-3 (davinci, curie, etc.)
105/// - `"llama3"` — Meta Llama 3/4
106/// - `"deepseek_v3"` — DeepSeek V3/R1
107/// - `"qwen2"` — Qwen 2/2.5/3
108/// - `"mistral_v3"` — Mistral/Codestral/Pixtral
109///
110/// Throws `Error` for unknown encoding names.
111#[wasm_bindgen(js_name = getEncoding)]
112pub fn get_encoding(name: &str) -> Result<Encoding, JsError> {
113    // look up the static name from tiktoken's canonical list (single source of truth)
114    let static_name = tiktoken::list_encodings()
115        .iter()
116        .find(|&&n| n == name)
117        .ok_or_else(|| JsError::new(&format!("unknown encoding: {name}")))?;
118    let bpe = tiktoken::get_encoding(name)
119        .ok_or_else(|| JsError::new(&format!("unknown encoding: {name}")))?;
120    Ok(Encoding {
121        name: static_name,
122        bpe,
123    })
124}
125
126/// Get an encoding for a model name (e.g. `"gpt-4o"`, `"o3-mini"`, `"llama-4"`, `"deepseek-r1"`).
127///
128/// Supports models from OpenAI, Meta, DeepSeek, Qwen, and Mistral.
129/// Automatically resolves the model name to the correct encoding.
130/// Throws `Error` for unknown model names.
131#[wasm_bindgen(js_name = encodingForModel)]
132pub fn encoding_for_model(model: &str) -> Result<Encoding, JsError> {
133    let name = tiktoken::model_to_encoding(model)
134        .ok_or_else(|| JsError::new(&format!("unknown model: {model}")))?;
135    let bpe = tiktoken::get_encoding(name)
136        .ok_or_else(|| JsError::new(&format!("unknown encoding: {name}")))?;
137    Ok(Encoding { name, bpe })
138}
139
140/// Map a model name to its encoding name without loading the encoding.
141///
142/// Returns the encoding name string (e.g. `"o200k_base"`) or `null` for unknown models.
143#[wasm_bindgen(js_name = modelToEncoding)]
144pub fn model_to_encoding(model: &str) -> Option<String> {
145    tiktoken::model_to_encoding(model).map(|s| s.to_string())
146}
147
148/// Estimate cost in USD for a given model, input token count, and output token count.
149///
150/// Supports OpenAI, Anthropic Claude, Google Gemini, Meta Llama, DeepSeek, Qwen, and Mistral models.
151/// Throws `Error` for unknown model ids.
152#[wasm_bindgen(js_name = estimateCost)]
153pub fn estimate_cost(
154    model_id: &str,
155    input_tokens: u32,
156    output_tokens: u32,
157) -> Result<f64, JsError> {
158    tiktoken::pricing::estimate_cost(model_id, input_tokens as u64, output_tokens as u64)
159        .ok_or_else(|| JsError::new(&format!("unknown model: {model_id}")))
160}
161
162/// Get model pricing and metadata.
163///
164/// Returns a typed object with: `id`, `provider`, `inputPer1m`, `outputPer1m`,
165/// `cachedInputPer1m`, `contextWindow`, `maxOutput`.
166///
167/// Throws `Error` for unknown model ids.
168#[wasm_bindgen(js_name = getModelInfo)]
169pub fn get_model_info(model_id: &str) -> Result<ModelInfo, JsError> {
170    let model = tiktoken::pricing::get_model(model_id)
171        .ok_or_else(|| JsError::new(&format!("unknown model: {model_id}")))?;
172    Ok(convert_model(model))
173}
174
175/// List all supported models with pricing info.
176///
177/// Returns an array of `ModelInfo` objects.
178#[wasm_bindgen(js_name = allModels)]
179pub fn all_models() -> Vec<ModelInfo> {
180    tiktoken::pricing::all_models()
181        .iter()
182        .map(convert_model)
183        .collect()
184}
185
186/// List models filtered by provider name.
187///
188/// Provider names: `"OpenAI"`, `"Anthropic"`, `"Google"`, `"Meta"`, `"DeepSeek"`, `"Alibaba"`, `"Mistral"`.
189/// Returns an empty array for unknown providers.
190#[wasm_bindgen(js_name = modelsByProvider)]
191pub fn models_by_provider(provider: &str) -> Vec<ModelInfo> {
192    let Some(provider) = parse_provider(provider) else {
193        return Vec::new();
194    };
195
196    tiktoken::pricing::models_by_provider(provider)
197        .iter()
198        .map(|m| convert_model(m))
199        .collect()
200}
201
202fn convert_model(m: &tiktoken::pricing::Model) -> ModelInfo {
203    ModelInfo {
204        id: m.id,
205        provider: m.provider.to_string(),
206        input_per_1m: m.pricing.input_per_1m,
207        output_per_1m: m.pricing.output_per_1m,
208        cached_input_per_1m: m.pricing.cached_input_per_1m,
209        context_window: m.context_window,
210        max_output: m.max_output,
211    }
212}
213
214fn parse_provider(s: &str) -> Option<tiktoken::pricing::Provider> {
215    match s {
216        "OpenAI" => Some(tiktoken::pricing::Provider::OpenAI),
217        "Anthropic" => Some(tiktoken::pricing::Provider::Anthropic),
218        "Google" => Some(tiktoken::pricing::Provider::Google),
219        "Meta" => Some(tiktoken::pricing::Provider::Meta),
220        "DeepSeek" => Some(tiktoken::pricing::Provider::DeepSeek),
221        "Alibaba" => Some(tiktoken::pricing::Provider::Alibaba),
222        "Mistral" => Some(tiktoken::pricing::Provider::Mistral),
223        _ => None,
224    }
225}
226
/// Model pricing and metadata.
///
/// Exposed to JS as a class with read-only getters (see the `impl` below).
#[wasm_bindgen]
#[derive(Clone)]
pub struct ModelInfo {
    // canonical model id (e.g. "gpt-4o"); static string from the pricing table
    id: &'static str,
    // provider display name (e.g. "OpenAI")
    provider: String,
    // USD price per 1M input tokens
    input_per_1m: f64,
    // USD price per 1M output tokens
    output_per_1m: f64,
    // USD price per 1M cached input tokens; None when not available for the model
    cached_input_per_1m: Option<f64>,
    // context window size, in tokens
    context_window: u32,
    // maximum output length, in tokens
    max_output: u32,
}
239
#[wasm_bindgen]
impl ModelInfo {
    /// Canonical model id (e.g. `"gpt-4o"`).
    #[wasm_bindgen(getter)]
    pub fn id(&self) -> String {
        self.id.to_string()
    }
    /// Provider display name (e.g. `"OpenAI"`).
    #[wasm_bindgen(getter)]
    pub fn provider(&self) -> String {
        self.provider.clone()
    }
    /// USD price per 1M input tokens.
    #[wasm_bindgen(getter, js_name = inputPer1m)]
    pub fn input_per_1m(&self) -> f64 {
        self.input_per_1m
    }
    /// USD price per 1M output tokens.
    #[wasm_bindgen(getter, js_name = outputPer1m)]
    pub fn output_per_1m(&self) -> f64 {
        self.output_per_1m
    }
    /// USD price per 1M cached input tokens, or `null` when not available.
    #[wasm_bindgen(getter, js_name = cachedInputPer1m)]
    pub fn cached_input_per_1m(&self) -> Option<f64> {
        self.cached_input_per_1m
    }
    /// Context window size, in tokens.
    #[wasm_bindgen(getter, js_name = contextWindow)]
    pub fn context_window(&self) -> u32 {
        self.context_window
    }
    /// Maximum output length, in tokens.
    #[wasm_bindgen(getter, js_name = maxOutput)]
    pub fn max_output(&self) -> u32 {
        self.max_output
    }
}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Every bundled encoding must encode→decode back to the original text,
    /// including multi-byte UTF-8 (CJK, emoji).
    #[test]
    fn all_encodings_roundtrip() {
        for &name in tiktoken::list_encodings() {
            let enc = get_encoding(name).unwrap();
            let text = "hello world 你好 🚀";
            let tokens = enc.encode(text);
            let decoded = enc.decode(&tokens);
            assert_eq!(decoded, text, "roundtrip failed for {name}");
        }
    }

    /// A representative model from each supported family must resolve.
    #[test]
    fn encoding_for_known_models() {
        let models = [
            "gpt-4o", "gpt-4", "gpt-3.5-turbo", "llama-4", "deepseek-r1", "qwen3", "mistral-large",
        ];
        for model in models {
            let enc = encoding_for_model(model);
            assert!(enc.is_ok(), "encoding_for_model failed for {model}");
        }
    }

    /// Compare against the core crate's list instead of a hard-coded count
    /// (previously `9`), so the test does not break when an encoding is
    /// added — mirrors `all_models_count` below.
    #[test]
    fn list_encodings_count() {
        let names = list_encodings();
        assert_eq!(names.len(), tiktoken::list_encodings().len());
    }

    /// The wrapper must expose exactly the models the core crate knows.
    #[test]
    fn all_models_count() {
        let models = all_models();
        assert_eq!(models.len(), tiktoken::pricing::all_models().len());
    }

    #[test]
    fn models_by_valid_provider() {
        let openai = models_by_provider("OpenAI");
        assert!(!openai.is_empty());
        // Every returned model must actually belong to the requested provider.
        for m in &openai {
            assert_eq!(m.provider, "OpenAI");
        }
    }

    /// Unknown providers yield an empty list, not an error.
    #[test]
    fn models_by_invalid_provider() {
        let unknown = models_by_provider("NonExistent");
        assert!(unknown.is_empty());
    }

    #[test]
    fn estimate_cost_known_model() {
        let cost = estimate_cost("gpt-4o", 1000, 1000).unwrap();
        assert!(cost > 0.0);
    }

    #[test]
    fn estimate_cost_unknown_model() {
        assert!(estimate_cost("fake-model", 1000, 1000).is_err());
    }

    #[test]
    fn get_model_info_known() {
        let info = get_model_info("gpt-4o").unwrap();
        assert_eq!(info.id(), "gpt-4o");
        assert_eq!(info.provider(), "OpenAI");
        assert!(info.context_window() > 0);
    }

    #[test]
    fn get_model_info_unknown() {
        assert!(get_model_info("fake-model").is_err());
    }

    #[test]
    fn unknown_encoding_error() {
        assert!(get_encoding("nonexistent").is_err());
    }

    #[test]
    fn unknown_model_encoding_error() {
        assert!(encoding_for_model("nonexistent-model-xyz").is_err());
    }

    #[test]
    fn model_to_encoding_known() {
        let name = model_to_encoding("gpt-4o");
        assert_eq!(name.as_deref(), Some("o200k_base"));
    }

    #[test]
    fn model_to_encoding_unknown() {
        assert!(model_to_encoding("fake-model").is_none());
    }

    /// All seven provider names parse; anything else is rejected.
    #[test]
    fn parse_provider_all_variants() {
        let known = [
            "OpenAI", "Anthropic", "Google", "Meta", "DeepSeek", "Alibaba", "Mistral",
        ];
        for name in known {
            assert!(parse_provider(name).is_some(), "parse_provider failed for {name}");
        }
        assert!(parse_provider("Unknown").is_none());
    }
}
381}