rag_toolchain/common/
embedding_shared.rs1use serde::{Deserialize, Serialize};
2use tiktoken_rs::tokenizer::Tokenizer;
3use tiktoken_rs::CoreBPE;
4
5pub trait EmbeddingModel {
11 fn metadata(&self) -> EmbeddingModelMetadata;
12}
13
14pub struct EmbeddingModelMetadata {
17 pub dimensions: usize,
19 pub max_tokens: usize,
21 pub tokenizer: Box<dyn TokenizerWrapper>,
23}
24
/// Object-safe abstraction over a text tokenizer.
///
/// Used behind `Box<dyn TokenizerWrapper>` so different tokenizer backends can
/// be stored uniformly.
pub trait TokenizerWrapper {
    /// Splits `text` into the string form of each token.
    ///
    /// Returns `None` if the underlying tokenizer fails on this input.
    fn tokenize(&self, text: &str) -> Option<Vec<String>>;
}
32#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
38#[serde(rename_all = "snake_case")]
39pub enum OpenAIEmbeddingModel {
40 #[serde(rename = "text-embedding-ada-002")]
41 TextEmbeddingAda002,
42 #[serde(rename = "text-embedding-3-small")]
43 TextEmbedding3Small,
44 #[serde(rename = "text-embedding-3-large")]
45 TextEmbedding3Large,
46}
47
48impl EmbeddingModel for OpenAIEmbeddingModel {
53 fn metadata(&self) -> EmbeddingModelMetadata {
54 match self {
55 OpenAIEmbeddingModel::TextEmbeddingAda002 => EmbeddingModelMetadata {
56 dimensions: 1536,
57 max_tokens: 8192,
58 tokenizer: Box::new(OpenAITokenizer::new(Tokenizer::Cl100kBase)),
59 },
60 OpenAIEmbeddingModel::TextEmbedding3Small => EmbeddingModelMetadata {
61 dimensions: 1536,
62 max_tokens: 8192,
63 tokenizer: Box::new(OpenAITokenizer::new(Tokenizer::Cl100kBase)),
64 },
65 OpenAIEmbeddingModel::TextEmbedding3Large => EmbeddingModelMetadata {
66 dimensions: 3072,
67 max_tokens: 8192,
68 tokenizer: Box::new(OpenAITokenizer::new(Tokenizer::Cl100kBase)),
69 },
70 }
71 }
72}
73
74struct OpenAITokenizer {
78 bpe: CoreBPE,
79}
80
81impl OpenAITokenizer {
84 pub fn new(model: Tokenizer) -> Self {
85 OpenAITokenizer {
86 bpe: tiktoken_rs::get_bpe_from_tokenizer(model).unwrap(),
87 }
88 }
89}
90
91impl TokenizerWrapper for OpenAITokenizer {
93 fn tokenize(&self, text: &str) -> Option<Vec<String>> {
94 if let Ok(tokens) = self.bpe.split_by_token(text, true) {
95 Some(tokens)
96 } else {
97 None
98 }
99 }
100}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the dimensions and token limit reported by `model`.
    fn assert_metadata(model: OpenAIEmbeddingModel, dimensions: usize, max_tokens: usize) {
        let metadata: EmbeddingModelMetadata = model.metadata();
        assert_eq!(metadata.dimensions, dimensions, "dimensions for {:?}", model);
        assert_eq!(metadata.max_tokens, max_tokens, "max_tokens for {:?}", model);
    }

    #[test]
    fn openai_ada002_metadata() {
        assert_metadata(OpenAIEmbeddingModel::TextEmbeddingAda002, 1536, 8192);
    }

    #[test]
    fn openai_3_small_metadata() {
        assert_metadata(OpenAIEmbeddingModel::TextEmbedding3Small, 1536, 8192);
    }

    #[test]
    fn openai_3_large_metadata() {
        assert_metadata(OpenAIEmbeddingModel::TextEmbedding3Large, 3072, 8192);
    }

    // Exercises the tokenizer returned by `metadata()`, which the metadata
    // tests above never touch. Byte-pair encoding is expected to be lossless,
    // so concatenating the pieces of ASCII input must reproduce it exactly.
    #[test]
    fn openai_tokenizer_round_trips_ascii_text() {
        let text = "The quick brown fox jumps over the lazy dog.";
        for model in [
            OpenAIEmbeddingModel::TextEmbeddingAda002,
            OpenAIEmbeddingModel::TextEmbedding3Small,
            OpenAIEmbeddingModel::TextEmbedding3Large,
        ] {
            let tokens = model
                .metadata()
                .tokenizer
                .tokenize(text)
                .expect("tokenizing plain ASCII text should succeed");
            assert!(!tokens.is_empty(), "no tokens for {:?}", model);
            assert_eq!(tokens.concat(), text, "round trip for {:?}", model);
        }
    }
}