async_openai/
embedding.rs1use crate::{
2 config::Config,
3 error::OpenAIError,
4 types::{CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse},
5 Client,
6};
7
8#[cfg(not(feature = "byot"))]
9use crate::types::EncodingFormat;
10
11pub struct Embeddings<'c, C: Config> {
16 client: &'c Client<C>,
17}
18
19impl<'c, C: Config> Embeddings<'c, C> {
20 pub fn new(client: &'c Client<C>) -> Self {
21 Self { client }
22 }
23
24 #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
28 pub async fn create(
29 &self,
30 request: CreateEmbeddingRequest,
31 ) -> Result<CreateEmbeddingResponse, OpenAIError> {
32 #[cfg(not(feature = "byot"))]
33 {
34 if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
35 return Err(OpenAIError::InvalidArgument(
36 "When encoding_format is base64, use Embeddings::create_base64".into(),
37 ));
38 }
39 }
40 self.client.post("/embeddings", request).await
41 }
42
43 #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
49 pub async fn create_base64(
50 &self,
51 request: CreateEmbeddingRequest,
52 ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
53 #[cfg(not(feature = "byot"))]
54 {
55 if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
56 return Err(OpenAIError::InvalidArgument(
57 "When encoding_format is not base64, use Embeddings::create".into(),
58 ));
59 }
60 }
61 self.client.post("/embeddings", request).await
62 }
63}
64
#[cfg(test)]
mod tests {
    // NOTE(review): every test builds a real `Client::new()` and awaits the endpoint.
    // They presumably need API credentials and network access; confirm before running in CI.

    // Only the byot-gated test below matches on `OpenAIError`. Gating this import the
    // same way keeps it from being an unused import when `byot` is enabled.
    #[cfg(not(feature = "byot"))]
    use crate::error::OpenAIError;
    use crate::types::{CreateEmbeddingResponse, Embedding, EncodingFormat};
    use crate::{types::CreateEmbeddingRequestArgs, Client};

    #[tokio::test]
    async fn test_embedding_string() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input("The food was delicious and the waiter...")
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok(), "embedding request failed: {:?}", response.err());
    }

    #[tokio::test]
    async fn test_embedding_string_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input(["The food was delicious", "The waiter was good"])
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok(), "embedding request failed: {:?}", response.err());
    }

    #[tokio::test]
    async fn test_embedding_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([1, 2, 3])
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok(), "embedding request failed: {:?}", response.err());
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array_matrix() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok(), "embedding request failed: {:?}", response.err());
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]])
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok(), "embedding request failed: {:?}", response.err());
    }

    #[tokio::test]
    async fn test_embedding_with_reduced_dimensions() {
        let client = Client::new();
        let dimensions = 256u32;
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("The food was delicious and the waiter...")
            .dimensions(dimensions)
            .build()
            .unwrap();

        let CreateEmbeddingResponse { mut data, .. } = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request failed");

        // One input string must yield exactly one embedding of the requested size.
        assert_eq!(data.len(), 1);
        let Embedding { embedding, .. } = data.pop().unwrap();
        assert_eq!(embedding.len(), dimensions as usize);
    }

    #[tokio::test]
    #[cfg(not(feature = "byot"))]
    async fn test_cannot_use_base64_encoding_with_normal_create_request() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "You shall not pass.";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client.embeddings().create(b64_request).await;
        assert!(matches!(b64_response, Err(OpenAIError::InvalidArgument(_))));
    }

    #[tokio::test]
    async fn test_embedding_create_base64() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "CoLoop will eat the other qual research tools...";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client
            .embeddings()
            .create_base64(b64_request)
            .await
            .expect("base64 embedding request failed");
        let b64_embedding = b64_response
            .data
            .into_iter()
            .next()
            .expect("base64 response contained no embeddings")
            .embedding;
        let b64_embedding: Vec<f32> = b64_embedding.into();

        let request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .build()
            .unwrap();
        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("float embedding request failed");
        let embedding = response
            .data
            .into_iter()
            .next()
            .expect("float response contained no embeddings")
            .embedding;

        // The decoded base64 vector should match the float vector element by element.
        assert_eq!(b64_embedding.len(), embedding.len());
        for (i, (b64, normal)) in b64_embedding.iter().zip(embedding.iter()).enumerate() {
            assert!(
                (b64 - normal).abs() < 1e-6,
                "component {} differs: base64 {} vs float {}",
                i,
                b64,
                normal
            );
        }
    }
}