async_openai/
embedding.rs1use crate::{
2 config::Config,
3 error::OpenAIError,
4 types::embeddings::{
5 CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse,
6 },
7 Client,
8};
9
10#[cfg(not(feature = "byot"))]
11use crate::types::embeddings::EncodingFormat;
12
13pub struct Embeddings<'c, C: Config> {
18 client: &'c Client<C>,
19}
20
21impl<'c, C: Config> Embeddings<'c, C> {
22 pub fn new(client: &'c Client<C>) -> Self {
23 Self { client }
24 }
25
26 #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
30 pub async fn create(
31 &self,
32 request: CreateEmbeddingRequest,
33 ) -> Result<CreateEmbeddingResponse, OpenAIError> {
34 #[cfg(not(feature = "byot"))]
35 {
36 if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
37 return Err(OpenAIError::InvalidArgument(
38 "When encoding_format is base64, use Embeddings::create_base64".into(),
39 ));
40 }
41 }
42 self.client.post("/embeddings", request).await
43 }
44
45 #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
51 pub async fn create_base64(
52 &self,
53 request: CreateEmbeddingRequest,
54 ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
55 #[cfg(not(feature = "byot"))]
56 {
57 if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
58 return Err(OpenAIError::InvalidArgument(
59 "When encoding_format is not base64, use Embeddings::create".into(),
60 ));
61 }
62 }
63 self.client.post("/embeddings", request).await
64 }
65}
66
#[cfg(test)]
mod tests {
    // These tests send real requests through `Client::new()`. They need network
    // access and whatever credentials that default client reads (presumably
    // `OPENAI_API_KEY` — confirm against the config type).
    use crate::types::embeddings::{CreateEmbeddingResponse, Embedding, EncodingFormat};
    use crate::{types::embeddings::CreateEmbeddingRequestArgs, Client};

    #[tokio::test]
    async fn test_embedding_string() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input("The food was delicious and the waiter...")
            .build()
            .unwrap();

        // `expect` instead of `assert!(is_ok())` so a failure prints the API error.
        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request should succeed");
    }

    #[tokio::test]
    async fn test_embedding_string_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input(["The food was delicious", "The waiter was good"])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request should succeed");
    }

    #[tokio::test]
    async fn test_embedding_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([1, 2, 3])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request should succeed");
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array_matrix() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request should succeed");
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request should succeed");
    }

    #[tokio::test]
    async fn test_embedding_with_reduced_dimensions() {
        let client = Client::new();
        let dimensions = 256u32;
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("The food was delicious and the waiter...")
            .dimensions(dimensions)
            .build()
            .unwrap();

        let CreateEmbeddingResponse { mut data, .. } = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request should succeed");

        assert_eq!(data.len(), 1);
        let Embedding { embedding, .. } = data.pop().unwrap();
        assert_eq!(embedding.len(), dimensions as usize);
    }

    // The base64 guard in `Embeddings::create` sits behind
    // `#[cfg(not(feature = "byot"))]`. With `byot` enabled the request would go to
    // the network and this assertion would fail, so the test gets the same gate.
    #[cfg(not(feature = "byot"))]
    #[tokio::test]
    async fn test_cannot_use_base64_encoding_with_normal_create_request() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "You shall not pass.";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client.embeddings().create(b64_request).await;
        // Path-qualified so the module needs no `OpenAIError` import, which would
        // otherwise be unused when this test is compiled out under `byot`.
        assert!(matches!(
            b64_response,
            Err(crate::error::OpenAIError::InvalidArgument(_))
        ));
    }

    #[tokio::test]
    async fn test_embedding_create_base64() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "a head full of dreams";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client
            .embeddings()
            .create_base64(b64_request)
            .await
            .expect("base64 embedding request should succeed");
        let b64_embedding = b64_response.data.into_iter().next().unwrap().embedding;
        let b64_embedding: Vec<f32> = b64_embedding.into();

        let request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .build()
            .unwrap();
        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request should succeed");
        let embedding = response.data.into_iter().next().unwrap().embedding;

        // Two separate calls for the same input are compared element-wise with a
        // small tolerance. The float-encoded and base64-encoded responses may not be
        // bit-identical (NOTE(review): also relies on the service returning the same
        // vector for both calls).
        assert_eq!(b64_embedding.len(), embedding.len());
        for (b64, normal) in b64_embedding.iter().zip(embedding.iter()) {
            assert!((b64 - normal).abs() < 1e-6);
        }
    }
}
220}