async_openai/
embedding.rs1use crate::{
2 config::Config,
3 error::OpenAIError,
4 types::embeddings::{
5 CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse,
6 },
7 Client, RequestOptions,
8};
9
10#[cfg(not(feature = "byot"))]
11use crate::types::embeddings::EncodingFormat;
12
13pub struct Embeddings<'c, C: Config> {
18 client: &'c Client<C>,
19 pub(crate) request_options: RequestOptions,
20}
21
22impl<'c, C: Config> Embeddings<'c, C> {
23 pub fn new(client: &'c Client<C>) -> Self {
24 Self {
25 client,
26 request_options: RequestOptions::new(),
27 }
28 }
29
30 #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
34 pub async fn create(
35 &self,
36 request: CreateEmbeddingRequest,
37 ) -> Result<CreateEmbeddingResponse, OpenAIError> {
38 #[cfg(not(feature = "byot"))]
39 {
40 if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
41 return Err(OpenAIError::InvalidArgument(
42 "When encoding_format is base64, use Embeddings::create_base64".into(),
43 ));
44 }
45 }
46 self.client
47 .post("/embeddings", request, &self.request_options)
48 .await
49 }
50
51 #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
57 pub async fn create_base64(
58 &self,
59 request: CreateEmbeddingRequest,
60 ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
61 #[cfg(not(feature = "byot"))]
62 {
63 if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
64 return Err(OpenAIError::InvalidArgument(
65 "When encoding_format is not base64, use Embeddings::create".into(),
66 ));
67 }
68 }
69 self.client
70 .post("/embeddings", request, &self.request_options)
71 .await
72 }
73}
74
#[cfg(test)]
mod tests {
    // Most of these tests call the live `/embeddings` endpoint through `Client::new()`,
    // so they need network access (and presumably credentials in the environment).
    // Responses are unwrapped with `expect` so a failure reports the underlying
    // `OpenAIError` rather than a bare `assertion failed: response.is_ok()`.

    // Only the validation tests name `OpenAIError`, and that validation is compiled out
    // under `byot`; gating the import avoids an unused-import warning in that build.
    #[cfg(not(feature = "byot"))]
    use crate::error::OpenAIError;
    use crate::types::embeddings::{CreateEmbeddingResponse, Embedding, EncodingFormat};
    use crate::{types::embeddings::CreateEmbeddingRequestArgs, Client};

    #[tokio::test]
    async fn test_embedding_string() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input("The food was delicious and the waiter...")
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding a single string should succeed");
    }

    #[tokio::test]
    async fn test_embedding_string_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input(["The food was delicious", "The waiter was good"])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding an array of strings should succeed");
    }

    #[tokio::test]
    async fn test_embedding_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([1, 2, 3])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding an integer array should succeed");
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array_matrix() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding a fixed-size integer matrix should succeed");
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding ragged integer arrays should succeed");
    }

    #[tokio::test]
    async fn test_embedding_with_reduced_dimensions() {
        let client = Client::new();
        let dimensions = 256u32;
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("The food was delicious and the waiter...")
            .dimensions(dimensions)
            .build()
            .unwrap();

        let CreateEmbeddingResponse { mut data, .. } = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding with reduced dimensions should succeed");

        assert_eq!(data.len(), 1);
        let Embedding { embedding, .. } = data.pop().unwrap();
        assert_eq!(embedding.len(), dimensions as usize);
    }

    // `create` rejects base64 only when built without `byot` (the check is cfg-gated
    // there), so this test does not apply when that feature is enabled.
    #[cfg(not(feature = "byot"))]
    #[tokio::test]
    async fn test_cannot_use_base64_encoding_with_normal_create_request() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "You shall not pass.";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client.embeddings().create(b64_request).await;
        assert!(matches!(b64_response, Err(OpenAIError::InvalidArgument(_))));
    }

    // Mirror of the test above: `create_base64` must reject a non-base64 request.
    // The guard returns before any network call, and it is likewise absent under `byot`.
    #[cfg(not(feature = "byot"))]
    #[tokio::test]
    async fn test_cannot_use_non_base64_encoding_with_create_base64_request() {
        let client = Client::new();

        // `encoding_format` is left at the builder default, which is not base64
        // (plain `create` accepts such a request in `test_embedding_create_base64`).
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input("You shall not pass.")
            .build()
            .unwrap();
        let response = client.embeddings().create_base64(request).await;
        assert!(matches!(response, Err(OpenAIError::InvalidArgument(_))));
    }

    #[tokio::test]
    async fn test_embedding_create_base64() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "a head full of dreams";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client
            .embeddings()
            .create_base64(b64_request)
            .await
            .expect("base64 embedding request should succeed");
        let b64_embedding = b64_response.data.into_iter().next().unwrap().embedding;
        let b64_embedding: Vec<f32> = b64_embedding.into();

        let request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .build()
            .unwrap();
        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("float embedding request should succeed");
        let embedding = response.data.into_iter().next().unwrap().embedding;

        // Same model and input, so the decoded base64 vector should match the float
        // vector element-wise.
        // NOTE(review): this assumes two separate calls return (near-)identical vectors;
        // if it proves flaky, the 1e-6 tolerance may need loosening.
        assert_eq!(b64_embedding.len(), embedding.len());
        for (b64, normal) in b64_embedding.iter().zip(embedding.iter()) {
            assert!(
                (b64 - normal).abs() < 1e-6,
                "base64 value {b64} differs from float value {normal}"
            );
        }
    }
}