async_openai_wasm/
embedding.rs1use crate::{
2 Client, RequestOptions,
3 config::Config,
4 error::OpenAIError,
5 types::embeddings::{
6 CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse,
7 },
8};
9
10#[cfg(not(feature = "byot"))]
11use crate::types::embeddings::EncodingFormat;
12
13pub struct Embeddings<'c, C: Config> {
18 client: &'c Client<C>,
19 pub(crate) request_options: RequestOptions,
20}
21
22impl<'c, C: Config> Embeddings<'c, C> {
23 pub fn new(client: &'c Client<C>) -> Self {
24 Self {
25 client,
26 request_options: RequestOptions::new(),
27 }
28 }
29
30 #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
34 pub async fn create(
35 &self,
36 request: CreateEmbeddingRequest,
37 ) -> Result<CreateEmbeddingResponse, OpenAIError> {
38 #[cfg(not(feature = "byot"))]
39 {
40 if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
41 return Err(OpenAIError::InvalidArgument(
42 "When encoding_format is base64, use Embeddings::create_base64".into(),
43 ));
44 }
45 }
46 self.client
47 .post("/embeddings", request, &self.request_options)
48 .await
49 }
50
51 #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
57 pub async fn create_base64(
58 &self,
59 request: CreateEmbeddingRequest,
60 ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
61 #[cfg(not(feature = "byot"))]
62 {
63 if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
64 return Err(OpenAIError::InvalidArgument(
65 "When encoding_format is not base64, use Embeddings::create".into(),
66 ));
67 }
68 }
69 self.client
70 .post("/embeddings", request, &self.request_options)
71 .await
72 }
73}
74
// Tests for `Embeddings`.
//
// Most of these tests send requests through `Client::new()` with no mocking,
// presumably against the live API with credentials taken from the environment
// (confirm against `Client`'s default config). The two encoding-format guard
// tests return before any request is sent. They are compiled only without the
// `byot` feature, because the guards they check are compiled out under `byot`.
#[cfg(all(test, feature = "embedding"))]
mod tests {
    use crate::types::embeddings::{CreateEmbeddingResponse, Embedding, EncodingFormat};
    use crate::{Client, types::embeddings::CreateEmbeddingRequestArgs};

    #[tokio::test]
    async fn test_embedding_string() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("The food was delicious and the waiter...")
            .build()
            .unwrap();

        // `expect` (unlike `assert!(is_ok())`) reports the actual error on failure.
        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding a single string should succeed");
    }

    #[tokio::test]
    async fn test_embedding_string_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input(["The food was delicious", "The waiter was good"])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding an array of strings should succeed");
    }

    #[tokio::test]
    async fn test_embedding_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input([1, 2, 3])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding an integer array should succeed");
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array_matrix() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding a fixed-width matrix of integer arrays should succeed");
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding ragged integer arrays should succeed");
    }

    #[tokio::test]
    async fn test_embedding_with_reduced_dimensions() {
        let client = Client::new();
        let dimensions = 256u32;
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("The food was delicious and the waiter...")
            .dimensions(dimensions)
            .build()
            .unwrap();

        let CreateEmbeddingResponse { mut data, .. } = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding with reduced dimensions should succeed");

        assert_eq!(data.len(), 1);
        let Embedding { embedding, .. } = data.pop().unwrap();
        assert_eq!(embedding.len(), dimensions as usize);
    }

    // The guard that rejects base64 in `create` exists only without `byot`.
    #[cfg(not(feature = "byot"))]
    #[tokio::test]
    async fn test_cannot_use_base64_encoding_with_normal_create_request() {
        use crate::error::OpenAIError;
        let client = Client::new();

        const MODEL: &str = "text-embedding-3-small";
        const INPUT: &str = "You shall not pass.";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client.embeddings().create(b64_request).await;
        assert!(matches!(b64_response, Err(OpenAIError::InvalidArgument(_))));
    }

    // Mirror of the test above: `create_base64` must reject any request that
    // does not explicitly ask for base64 (here, one that leaves it unset).
    #[cfg(not(feature = "byot"))]
    #[tokio::test]
    async fn test_cannot_use_non_base64_encoding_with_create_base64_request() {
        use crate::error::OpenAIError;
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("You shall not pass.")
            .build()
            .unwrap();
        let response = client.embeddings().create_base64(request).await;
        assert!(matches!(response, Err(OpenAIError::InvalidArgument(_))));
    }

    #[tokio::test]
    async fn test_embedding_create_base64() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-3-small";
        const INPUT: &str = "a head full of dreams";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client
            .embeddings()
            .create_base64(b64_request)
            .await
            .expect("base64 embedding request should succeed");
        let b64_embedding = b64_response
            .data
            .into_iter()
            .next()
            .expect("base64 response should contain an embedding")
            .embedding;
        let b64_embedding: Vec<f32> = b64_embedding.into();

        let request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .build()
            .unwrap();
        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("float embedding request should succeed");
        let embedding = response
            .data
            .into_iter()
            .next()
            .expect("float response should contain an embedding")
            .embedding;

        // Guard against a trivially passing 0 == 0 comparison.
        assert!(!embedding.is_empty());
        assert_eq!(b64_embedding.len(), embedding.len());
    }
}