async_openai/
embedding.rs1use crate::{
2 config::Config,
3 error::OpenAIError,
4 types::{CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse},
5 Client,
6};
7
8#[cfg(not(feature = "byot"))]
9use crate::types::EncodingFormat;
10
11pub struct Embeddings<'c, C: Config> {
16 client: &'c Client<C>,
17}
18
19impl<'c, C: Config> Embeddings<'c, C> {
20 pub fn new(client: &'c Client<C>) -> Self {
21 Self { client }
22 }
23
24 #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
28 pub async fn create(
29 &self,
30 request: CreateEmbeddingRequest,
31 ) -> Result<CreateEmbeddingResponse, OpenAIError> {
32 #[cfg(not(feature = "byot"))]
33 {
34 if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
35 return Err(OpenAIError::InvalidArgument(
36 "When encoding_format is base64, use Embeddings::create_base64".into(),
37 ));
38 }
39 }
40 self.client.post("/embeddings", request).await
41 }
42
43 #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
49 pub async fn create_base64(
50 &self,
51 request: CreateEmbeddingRequest,
52 ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
53 #[cfg(not(feature = "byot"))]
54 {
55 if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
56 return Err(OpenAIError::InvalidArgument(
57 "When encoding_format is not base64, use Embeddings::create".into(),
58 ));
59 }
60 }
61 self.client.post("/embeddings", request).await
62 }
63}
64
#[cfg(test)]
mod tests {
    // Every test here builds a default `Client::new()` and, except for
    // `test_cannot_use_base64_encoding_with_normal_create_request` (rejected before
    // any request is sent), calls the live embeddings endpoint. They presumably need
    // working credentials in the environment; TODO confirm for CI.
    //
    // Failures use `expect(...)` instead of `assert!(response.is_ok())` so the actual
    // `OpenAIError` is printed when a call fails.
    use crate::types::{CreateEmbeddingResponse, Embedding, EncodingFormat};
    use crate::{types::CreateEmbeddingRequestArgs, Client};

    #[tokio::test]
    async fn test_embedding_string() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input("The food was delicious and the waiter...")
            .build()
            .unwrap();

        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request for a single string failed");

        // One embedding per input item.
        assert_eq!(response.data.len(), 1);
    }

    #[tokio::test]
    async fn test_embedding_string_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input(["The food was delicious", "The waiter was good"])
            .build()
            .unwrap();

        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request for a string array failed");

        assert_eq!(response.data.len(), 2);
    }

    #[tokio::test]
    async fn test_embedding_integer_array() {
        let client = Client::new();

        // A single flat array of token ids is one input.
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([1, 2, 3])
            .build()
            .unwrap();

        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request for a token array failed");

        assert_eq!(response.data.len(), 1);
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array_matrix() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
            .build()
            .unwrap();

        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request for a fixed-size token matrix failed");

        assert_eq!(response.data.len(), 3);
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]])
            .build()
            .unwrap();

        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request for ragged token arrays failed");

        assert_eq!(response.data.len(), 3);
    }

    #[tokio::test]
    async fn test_embedding_with_reduced_dimensions() {
        let client = Client::new();
        let dimensions = 256u32;
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("The food was delicious and the waiter...")
            .dimensions(dimensions)
            .build()
            .unwrap();

        let CreateEmbeddingResponse { mut data, .. } = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request with reduced dimensions failed");

        assert_eq!(data.len(), 1);
        let Embedding { embedding, .. } = data.pop().unwrap();
        assert_eq!(embedding.len(), dimensions as usize);
    }

    #[tokio::test]
    #[cfg(not(feature = "byot"))]
    async fn test_cannot_use_base64_encoding_with_normal_create_request() {
        // Imported here because only this feature-gated test uses it; a module-level
        // import would be unused (and warn) when the `byot` feature is enabled.
        use crate::error::OpenAIError;

        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "You shall not pass.";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client.embeddings().create(b64_request).await;
        assert!(matches!(b64_response, Err(OpenAIError::InvalidArgument(_))));
    }

    #[tokio::test]
    async fn test_embedding_create_base64() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "CoLoop will eat the other qual research tools...";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client
            .embeddings()
            .create_base64(b64_request)
            .await
            .expect("base64 embedding request failed");
        let b64_embedding = b64_response
            .data
            .into_iter()
            .next()
            .expect("base64 response contained no embeddings")
            .embedding;
        let b64_embedding: Vec<f32> = b64_embedding.into();

        let request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .build()
            .unwrap();
        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("float embedding request failed");
        let embedding = response
            .data
            .into_iter()
            .next()
            .expect("float response contained no embeddings")
            .embedding;

        // The decoded base64 vector should match the float vector component-wise.
        assert_eq!(b64_embedding.len(), embedding.len());
        for (i, (b64, normal)) in b64_embedding.iter().zip(embedding.iter()).enumerate() {
            assert!(
                (b64 - normal).abs() < 1e-6,
                "component {i} differs: base64 {b64} vs float {normal}"
            );
        }
    }
}