use serde::{Deserialize, Serialize};
use tiktoken_rs::tokenizer::Tokenizer;
use tiktoken_rs::CoreBPE;
/// An embedding model that can describe itself through static metadata.
pub trait EmbeddingModel {
/// Returns the metadata record for this model: its `dimensions`,
/// `max_tokens`, and a boxed tokenizer.
fn metadata(&self) -> EmbeddingModelMetadata;
}
/// Metadata record returned by [`EmbeddingModel::metadata`].
pub struct EmbeddingModelMetadata {
/// Presumably the length of each embedding vector the model produces
/// (inferred from the field name; confirm against callers).
pub dimensions: usize,
/// Presumably the largest number of input tokens the model accepts
/// (inferred from the field name; confirm against callers).
pub max_tokens: usize,
/// Tokenizer associated with the model, stored as a trait object so that
/// different tokenizer implementations can share this struct.
pub tokenizer: Box<dyn TokenizerWrapper>,
}
/// Abstraction over a tokenizer that splits text into token strings.
pub trait TokenizerWrapper {
/// Splits `text` into tokens, each returned as an owned `String`.
///
/// Returns `None` when the implementation is unable to tokenize `text`.
fn tokenize(&self, text: &str) -> Option<Vec<String>>;
}
/// The OpenAI embedding models known to this module.
///
/// Every variant carries an explicit serde `rename`, so each one is
/// serialized and deserialized as the hyphenated string shown on it.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum OpenAIEmbeddingModel {
/// Serialized as `text-embedding-ada-002`.
#[serde(rename = "text-embedding-ada-002")]
TextEmbeddingAda002,
/// Serialized as `text-embedding-3-small`.
#[serde(rename = "text-embedding-3-small")]
TextEmbedding3Small,
/// Serialized as `text-embedding-3-large`.
#[serde(rename = "text-embedding-3-large")]
TextEmbedding3Large,
}
impl EmbeddingModel for OpenAIEmbeddingModel {
    /// Returns the metadata for the selected model.
    ///
    /// All variants report `max_tokens` of 8192 and use a freshly constructed
    /// `Cl100kBase` tokenizer; only `dimensions` differs between them.
    fn metadata(&self) -> EmbeddingModelMetadata {
        // The match stays exhaustive so a newly added variant fails to compile
        // until its dimensions are specified here.
        let dimensions = match self {
            Self::TextEmbeddingAda002 | Self::TextEmbedding3Small => 1536,
            Self::TextEmbedding3Large => 3072,
        };

        EmbeddingModelMetadata {
            dimensions,
            max_tokens: 8192,
            tokenizer: Box::new(OpenAITokenizer::new(Tokenizer::Cl100kBase)),
        }
    }
}
/// [`TokenizerWrapper`] implementation backed by a `tiktoken_rs` BPE encoder.
struct OpenAITokenizer {
/// The `tiktoken_rs` encoder that performs the actual token splitting.
bpe: CoreBPE,
}
impl OpenAITokenizer {
    /// Builds a tokenizer around the BPE encoder that `tiktoken_rs` provides
    /// for `model`.
    ///
    /// # Panics
    ///
    /// Panics if `tiktoken_rs::get_bpe_from_tokenizer` returns an error for
    /// `model`.
    pub fn new(model: Tokenizer) -> Self {
        // A failure here is treated as a broken invariant rather than a
        // recoverable condition, so state what went wrong instead of using a
        // bare `unwrap()`.
        let bpe = tiktoken_rs::get_bpe_from_tokenizer(model)
            .expect("tiktoken_rs could not build the BPE encoder for the requested tokenizer");
        Self { bpe }
    }
}
impl TokenizerWrapper for OpenAITokenizer {
    /// Splits `text` into token strings using the wrapped BPE encoder.
    ///
    /// Returns `None` when `split_by_token` reports an error; the error value
    /// itself is discarded.
    fn tokenize(&self, text: &str) -> Option<Vec<String>> {
        // The second argument is forwarded to `split_by_token` unchanged as `true`.
        self.bpe.split_by_token(text, true).ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `model` reports the given `dimensions` and `max_tokens`.
    ///
    /// `#[track_caller]` makes a failing assertion point at the calling test.
    #[track_caller]
    fn check_metadata(model: OpenAIEmbeddingModel, dimensions: usize, max_tokens: usize) {
        let reported = model.metadata();
        assert_eq!(reported.dimensions, dimensions);
        assert_eq!(reported.max_tokens, max_tokens);
    }

    #[test]
    fn openai_ada002_metadata() {
        check_metadata(OpenAIEmbeddingModel::TextEmbeddingAda002, 1536, 8192);
    }

    #[test]
    fn openai_3_small_metadata() {
        check_metadata(OpenAIEmbeddingModel::TextEmbedding3Small, 1536, 8192);
    }

    #[test]
    fn openai_3_large_metadata() {
        check_metadata(OpenAIEmbeddingModel::TextEmbedding3Large, 3072, 8192);
    }
}