// rag_toolchain/common/embedding_shared.rs

1use serde::{Deserialize, Serialize};
2use tiktoken_rs::tokenizer::Tokenizer;
3use tiktoken_rs::CoreBPE;
4
5// ---------------------- Embedding Models ----------------------
6/// # [`EmbeddingModel`]
7/// This trait is used for methods to understand the requirements
8/// set out by which embedding model is being used such as embedding
9/// dimensions and max tokens
10pub trait EmbeddingModel {
11    fn metadata(&self) -> EmbeddingModelMetadata;
12}
13
14/// # [`EmbeddingModelMetadata`]
15/// Struct to contain all of the relevant metadata for an embedding model
16pub struct EmbeddingModelMetadata {
17    /// The dimension of the vectors produced by the embedding model
18    pub dimensions: usize,
19    /// The maximum amount of tokens that can be sent to the embedding model
20    pub max_tokens: usize,
21    /// The tokenizer that the embedding model uses
22    pub tokenizer: Box<dyn TokenizerWrapper>,
23}
24
/// # [`TokenizerWrapper`]
/// A common tokenization interface shared by all embedding models.
///
/// Each model's native tokenizer is wrapped in a type implementing this trait,
/// so callers can tokenize text without knowing which library backs it.
pub trait TokenizerWrapper {
    /// Splits `text` into the model's tokens, each returned as a `String`.
    ///
    /// Returns `None` when tokenization fails.
    // TODO: consider returning a `Result` so the failure reason is preserved.
    fn tokenize(&self, text: &str) -> Option<Vec<String>>;
}
32// -------------------------------------------------------------
33
34// ------------------ OpenAI Embedding Models ------------------
35/// # [`OpenAIEmbeddingModel`]
36/// Top level enum to hold all OpenAI embedding model variants.
37#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
38#[serde(rename_all = "snake_case")]
39pub enum OpenAIEmbeddingModel {
40    #[serde(rename = "text-embedding-ada-002")]
41    TextEmbeddingAda002,
42    #[serde(rename = "text-embedding-3-small")]
43    TextEmbedding3Small,
44    #[serde(rename = "text-embedding-3-large")]
45    TextEmbedding3Large,
46}
47
48/// Implementation of the EmbeddingModel trait for OpenAIEmbeddingModel
49/// This just sets out the requirements for each of the OpenAI models.
50/// This can then be used by things such as stores to understand what size
51/// vectors it has to store.
52impl EmbeddingModel for OpenAIEmbeddingModel {
53    fn metadata(&self) -> EmbeddingModelMetadata {
54        match self {
55            OpenAIEmbeddingModel::TextEmbeddingAda002 => EmbeddingModelMetadata {
56                dimensions: 1536,
57                max_tokens: 8192,
58                tokenizer: Box::new(OpenAITokenizer::new(Tokenizer::Cl100kBase)),
59            },
60            OpenAIEmbeddingModel::TextEmbedding3Small => EmbeddingModelMetadata {
61                dimensions: 1536,
62                max_tokens: 8192,
63                tokenizer: Box::new(OpenAITokenizer::new(Tokenizer::Cl100kBase)),
64            },
65            OpenAIEmbeddingModel::TextEmbedding3Large => EmbeddingModelMetadata {
66                dimensions: 3072,
67                max_tokens: 8192,
68                tokenizer: Box::new(OpenAITokenizer::new(Tokenizer::Cl100kBase)),
69            },
70        }
71    }
72}
73
74/// We use the tiktoken_rs library to handle the tokenization for OpenAI models.
75/// So this is the struct will implement [`TokenizerWrapper`] which we can then
76/// use in the rest of the library to tokenize text.
77struct OpenAITokenizer {
78    bpe: CoreBPE,
79}
80
81/// Added new function to hide the unwrap
82// The panic here should be fine as this shouldn't fail as we use an enum variant.
83impl OpenAITokenizer {
84    pub fn new(model: Tokenizer) -> Self {
85        OpenAITokenizer {
86            bpe: tiktoken_rs::get_bpe_from_tokenizer(model).unwrap(),
87        }
88    }
89}
90
91// Implement the tokenize function for OpenAITokenizer
92impl TokenizerWrapper for OpenAITokenizer {
93    fn tokenize(&self, text: &str) -> Option<Vec<String>> {
94        if let Ok(tokens) = self.bpe.split_by_token(text, true) {
95            Some(tokens)
96        } else {
97            None
98        }
99    }
100}
101// ------------------ OpenAI Embedding Models ------------------
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `model` reports the given vector width and token limit.
    fn assert_metadata(
        model: OpenAIEmbeddingModel,
        expected_dimensions: usize,
        expected_max_tokens: usize,
    ) {
        let EmbeddingModelMetadata {
            dimensions,
            max_tokens,
            ..
        } = model.metadata();
        assert_eq!(dimensions, expected_dimensions);
        assert_eq!(max_tokens, expected_max_tokens);
    }

    #[test]
    fn openai_ada002_metadata() {
        assert_metadata(OpenAIEmbeddingModel::TextEmbeddingAda002, 1536, 8192);
    }

    #[test]
    fn openai_3_small_metadata() {
        assert_metadata(OpenAIEmbeddingModel::TextEmbedding3Small, 1536, 8192);
    }

    #[test]
    fn openai_3_large_metadata() {
        assert_metadata(OpenAIEmbeddingModel::TextEmbedding3Large, 3072, 8192);
    }
}