async_openai_alt/
embedding.rs1use crate::{
2 config::Config,
3 error::OpenAIError,
4 types::{
5 CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse,
6 EncodingFormat,
7 },
8 Client,
9};
10
11pub struct Embeddings<'c, C: Config> {
16 client: &'c Client<C>,
17}
18
19impl<'c, C: Config> Embeddings<'c, C> {
20 pub fn new(client: &'c Client<C>) -> Self {
21 Self { client }
22 }
23
24 pub async fn create(
26 &self,
27 request: CreateEmbeddingRequest,
28 ) -> Result<CreateEmbeddingResponse, OpenAIError> {
29 if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
30 return Err(OpenAIError::InvalidArgument(
31 "When encoding_format is base64, use Embeddings::create_base64".into(),
32 ));
33 }
34 self.client.post("/embeddings", request).await
35 }
36
37 pub async fn create_base64(
41 &self,
42 request: CreateEmbeddingRequest,
43 ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
44 if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
45 return Err(OpenAIError::InvalidArgument(
46 "When encoding_format is not base64, use Embeddings::create".into(),
47 ));
48 }
49
50 self.client.post("/embeddings", request).await
51 }
52}
53
#[cfg(test)]
mod tests {
    // NOTE(review): every test here builds a client with `Client::new()` and
    // awaits real `create` / `create_base64` calls; presumably this needs
    // network access and API credentials in the environment — confirm before
    // running in CI.
    use crate::{
        error::OpenAIError,
        types::{CreateEmbeddingRequestArgs, EncodingFormat},
        Client,
    };

    /// Model name shared by every test except the reduced-dimensions one.
    const ADA_MODEL: &str = "text-embedding-ada-002";

    /// A single string input is accepted by `create`.
    #[tokio::test]
    async fn test_embedding_string() {
        let api = Client::new();

        let req = CreateEmbeddingRequestArgs::default()
            .model(ADA_MODEL)
            .input("The food was delicious and the waiter...")
            .build()
            .unwrap();

        let outcome = api.embeddings().create(req).await;
        assert!(outcome.is_ok());
    }

    /// An array of strings is accepted by `create`.
    #[tokio::test]
    async fn test_embedding_string_array() {
        let api = Client::new();

        let req = CreateEmbeddingRequestArgs::default()
            .model(ADA_MODEL)
            .input(["The food was delicious", "The waiter was good"])
            .build()
            .unwrap();

        let outcome = api.embeddings().create(req).await;
        assert!(outcome.is_ok());
    }

    /// A flat array of integers is accepted by `create`.
    #[tokio::test]
    async fn test_embedding_integer_array() {
        let api = Client::new();

        let req = CreateEmbeddingRequestArgs::default()
            .model(ADA_MODEL)
            .input([1, 2, 3])
            .build()
            .unwrap();

        let outcome = api.embeddings().create(req).await;
        assert!(outcome.is_ok());
    }

    /// A fixed-size matrix (array of equal-length integer arrays) is accepted.
    #[tokio::test]
    async fn test_embedding_array_of_integer_array_matrix() {
        let api = Client::new();

        let req = CreateEmbeddingRequestArgs::default()
            .model(ADA_MODEL)
            .input([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
            .build()
            .unwrap();

        let outcome = api.embeddings().create(req).await;
        assert!(outcome.is_ok());
    }

    /// An array of integer vectors with differing lengths is accepted.
    #[tokio::test]
    async fn test_embedding_array_of_integer_array() {
        let api = Client::new();

        let req = CreateEmbeddingRequestArgs::default()
            .model(ADA_MODEL)
            .input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]])
            .build()
            .unwrap();

        let outcome = api.embeddings().create(req).await;
        assert!(outcome.is_ok());
    }

    /// Requesting an explicit `dimensions` value yields exactly one embedding
    /// whose length equals that value.
    #[tokio::test]
    async fn test_embedding_with_reduced_dimensions() {
        const DIMENSIONS: u32 = 256;

        let api = Client::new();
        let req = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("The food was delicious and the waiter...")
            .dimensions(DIMENSIONS)
            .build()
            .unwrap();

        let outcome = api.embeddings().create(req).await;
        assert!(outcome.is_ok());

        let mut items = outcome.unwrap().data;
        assert_eq!(items.len(), 1);

        let vector = items.pop().unwrap().embedding;
        assert_eq!(vector.len(), DIMENSIONS as usize);
    }

    /// `create` rejects a request whose encoding format is base64 with
    /// `OpenAIError::InvalidArgument`.
    #[tokio::test]
    async fn test_cannot_use_base64_encoding_with_normal_create_request() {
        let api = Client::new();

        let base64_req = CreateEmbeddingRequestArgs::default()
            .model(ADA_MODEL)
            .input("You shall not pass.")
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();

        let outcome = api.embeddings().create(base64_req).await;
        assert!(matches!(outcome, Err(OpenAIError::InvalidArgument(_))));
    }

    /// The embedding returned by `create_base64`, once converted to
    /// `Vec<f32>`, matches the one from `create` element-wise within `1e-6`.
    #[tokio::test]
    async fn test_embedding_create_base64() {
        const TEXT: &str = "CoLoop will eat the other qual research tools...";

        let api = Client::new();

        // Base64 request goes out first, as in the plain request below it.
        let base64_req = CreateEmbeddingRequestArgs::default()
            .model(ADA_MODEL)
            .input(TEXT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let base64_resp = api.embeddings().create_base64(base64_req).await.unwrap();
        let decoded: Vec<f32> = base64_resp
            .data
            .into_iter()
            .next()
            .unwrap()
            .embedding
            .into();

        let plain_req = CreateEmbeddingRequestArgs::default()
            .model(ADA_MODEL)
            .input(TEXT)
            .build()
            .unwrap();
        let plain_resp = api.embeddings().create(plain_req).await.unwrap();
        let plain = plain_resp.data.into_iter().next().unwrap().embedding;

        assert_eq!(decoded.len(), plain.len());
        assert!(decoded
            .iter()
            .zip(plain.iter())
            .all(|(from_b64, normal)| (from_b64 - normal).abs() < 1e-6));
    }
}