use crate::core::Encoding;
use crate::encodings::get_encoding;
use crate::errors::{Result, TiktokenError};
use std::collections::HashMap;
use std::sync::OnceLock;
/// Process-wide table mapping model-name prefixes to encoding names.
///
/// Starts empty and is filled on first use through
/// `get_or_init(init_prefix_registry)`; `encoding_name_for_model` consults it
/// when a model name has no exact entry in [`MODEL_REGISTRY`].
static MODEL_PREFIX_REGISTRY: OnceLock<HashMap<&'static str, &'static str>> = OnceLock::new();
/// Process-wide table mapping exact model names to encoding names.
///
/// Starts empty and is filled on first use through
/// `get_or_init(init_model_registry)`.
static MODEL_REGISTRY: OnceLock<HashMap<&'static str, &'static str>> = OnceLock::new();
/// Builds the prefix table: each key is a model-name prefix and each value is
/// the name of the encoding chosen for model names that start with it.
///
/// Some keys overlap (`"ft:gpt-4"` is itself a prefix of `"ft:gpt-4o"`), so
/// callers scanning this map must not rely on iteration order.
fn init_prefix_registry() -> HashMap<&'static str, &'static str> {
    HashMap::from([
        // Prefixes ending in '-', e.g. "gpt-4-0314" falls under "gpt-4-".
        ("o1-", "o200k_base"),
        ("o3-", "o200k_base"),
        ("o4-mini-", "o200k_base"),
        ("gpt-5-", "o200k_base"),
        ("gpt-4.5-", "o200k_base"),
        ("gpt-4.1-", "o200k_base"),
        ("chatgpt-4o-", "o200k_base"),
        ("gpt-4o-", "o200k_base"),
        ("gpt-4-", "cl100k_base"),
        ("gpt-3.5-turbo-", "cl100k_base"),
        ("gpt-35-turbo-", "cl100k_base"),
        ("gpt-oss-", "o200k_harmony"),
        // "ft:"-prefixed names (presumably fine-tuned models — confirm).
        ("ft:gpt-4o", "o200k_base"),
        ("ft:gpt-4", "cl100k_base"),
        ("ft:gpt-3.5-turbo", "cl100k_base"),
        ("ft:davinci-002", "cl100k_base"),
        ("ft:babbage-002", "cl100k_base"),
    ])
}
/// Builds the exact-match table: each key is a full model name and each value
/// is the name of the encoding that model uses.
///
/// The data is written grouped by encoding and flattened into
/// `(model, encoding)` pairs when the map is collected.
fn init_model_registry() -> HashMap<&'static str, &'static str> {
    const MODELS_BY_ENCODING: &[(&str, &[&str])] = &[
        (
            "o200k_base",
            &["o1", "o3", "o4-mini", "gpt-5", "gpt-4.1", "gpt-4o"],
        ),
        (
            "cl100k_base",
            &[
                "gpt-4",
                "gpt-3.5-turbo",
                "gpt-3.5",
                "gpt-35-turbo",
                "davinci-002",
                "babbage-002",
                "text-embedding-ada-002",
                "text-embedding-3-small",
                "text-embedding-3-large",
            ],
        ),
        (
            "p50k_base",
            &[
                "text-davinci-003",
                "text-davinci-002",
                "code-davinci-002",
                "code-davinci-001",
                "code-cushman-002",
                "code-cushman-001",
                "davinci-codex",
                "cushman-codex",
            ],
        ),
        (
            "r50k_base",
            &[
                "text-davinci-001",
                "text-curie-001",
                "text-babbage-001",
                "text-ada-001",
                "davinci",
                "curie",
                "babbage",
                "ada",
                "text-similarity-davinci-001",
                "text-similarity-curie-001",
                "text-similarity-babbage-001",
                "text-similarity-ada-001",
                "text-search-davinci-doc-001",
                "text-search-curie-doc-001",
                "text-search-babbage-doc-001",
                "text-search-ada-doc-001",
                "code-search-babbage-code-001",
                "code-search-ada-code-001",
            ],
        ),
        (
            "p50k_edit",
            &["text-davinci-edit-001", "code-davinci-edit-001"],
        ),
        ("gpt2", &["gpt2", "gpt-2"]),
    ];

    MODELS_BY_ENCODING
        .iter()
        .flat_map(|&(encoding, models)| models.iter().map(move |&model| (model, encoding)))
        .collect()
}
/// Resolves the name of the encoding used by `model_name`.
///
/// Lookup happens in two stages:
/// 1. an exact match against the model registry;
/// 2. otherwise, the longest registered prefix that `model_name` starts with.
///
/// The longest prefix wins because registered prefixes may overlap
/// (`"ft:gpt-4"` is itself a prefix of `"ft:gpt-4o"`) and `HashMap` iteration
/// order is unspecified; returning the first hit would let the answer for a
/// name such as `"ft:gpt-4o:org:id"` vary from run to run.
///
/// # Errors
///
/// Returns [`TiktokenError::UnknownModel`] carrying `model_name` when neither
/// stage finds a match.
pub fn encoding_name_for_model(model_name: &str) -> Result<String> {
    let model_registry = MODEL_REGISTRY.get_or_init(init_model_registry);
    if let Some(&encoding_name) = model_registry.get(model_name) {
        return Ok(encoding_name.to_string());
    }

    let prefix_registry = MODEL_PREFIX_REGISTRY.get_or_init(init_prefix_registry);
    prefix_registry
        .iter()
        .filter(|(prefix, _)| model_name.starts_with(**prefix))
        // Two distinct prefixes of the same string never share a length, so
        // the maximum is unique and the result is deterministic.
        .max_by_key(|(prefix, _)| prefix.len())
        .map(|(_, &encoding_name)| encoding_name.to_string())
        .ok_or_else(|| TiktokenError::UnknownModel(model_name.to_string()))
}
/// Returns the [`Encoding`] associated with `model_name`.
///
/// The model name is first resolved to an encoding name with
/// [`encoding_name_for_model`]; that name is then handed to `get_encoding`.
///
/// # Errors
///
/// Propagates the error from [`encoding_name_for_model`] when the model is
/// not recognised, or whatever error `get_encoding` returns.
pub fn encoding_for_model(model_name: &str) -> Result<Encoding> {
    encoding_name_for_model(model_name).and_then(|resolved| get_encoding(&resolved))
}
/// Lists every model name that has an exact entry in the model registry.
///
/// Names that are only recognised through a prefix match are not included,
/// and the order of the returned names follows `HashMap` key iteration, so it
/// is unspecified.
pub fn list_supported_models() -> Vec<String> {
    MODEL_REGISTRY
        .get_or_init(init_model_registry)
        .keys()
        .map(|name| String::from(*name))
        .collect()
}
/// Reports whether `model_name` can be resolved to an encoding name.
///
/// Returns `true` exactly when [`encoding_name_for_model`] succeeds for the
/// given name, whether by exact entry or by prefix.
pub fn is_model_supported(model_name: &str) -> bool {
    let lookup = encoding_name_for_model(model_name);
    lookup.is_ok()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that each `(model, expected encoding name)` pair resolves.
    fn assert_resolves(cases: &[(&str, &str)]) {
        for &(model, expected) in cases {
            let resolved = encoding_name_for_model(model).unwrap();
            assert_eq!(resolved, expected);
        }
    }

    /// Model names that have their own entry in the exact-match table.
    #[test]
    fn test_exact_model_mapping() {
        assert_resolves(&[
            ("gpt-4", "cl100k_base"),
            ("gpt-3.5-turbo", "cl100k_base"),
            ("gpt-4o", "o200k_base"),
            ("text-davinci-003", "p50k_base"),
        ]);
    }

    /// Model names that are only reachable through a registered prefix.
    #[test]
    fn test_prefix_model_mapping() {
        assert_resolves(&[
            ("gpt-4-0314", "cl100k_base"),
            ("gpt-4o-2024-05-13", "o200k_base"),
            ("gpt-3.5-turbo-0301", "cl100k_base"),
        ]);
    }

    /// A name matching neither table must be reported as an error.
    #[test]
    fn test_unknown_model() {
        let outcome = encoding_name_for_model("unknown-model");
        assert!(outcome.is_err());
    }

    /// The encoding returned for a model carries the resolved encoding name.
    #[test]
    fn test_encoding_for_model() {
        let resolved = encoding_for_model("gpt-4").unwrap();
        assert_eq!(resolved.name, "cl100k_base");
    }
}