async_openai/
embedding.rs

1use crate::{
2    config::Config,
3    error::OpenAIError,
4    types::embeddings::{
5        CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse,
6    },
7    Client, RequestOptions,
8};
9
10#[cfg(not(feature = "byot"))]
11use crate::types::embeddings::EncodingFormat;
12
/// Get a vector representation of a given input that can be easily
/// consumed by machine learning models and algorithms.
///
/// Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings)
pub struct Embeddings<'c, C: Config> {
    // Borrowed API client used to issue requests; lifetime `'c` ties this
    // group to the `Client` that created it.
    client: &'c Client<C>,
    // Options forwarded with every POST to `/embeddings` made through this
    // group (see `create` / `create_base64`). Crate-visible so the client
    // can configure them.
    pub(crate) request_options: RequestOptions,
}
21
22impl<'c, C: Config> Embeddings<'c, C> {
23    pub fn new(client: &'c Client<C>) -> Self {
24        Self {
25            client,
26            request_options: RequestOptions::new(),
27        }
28    }
29
30    /// Creates an embedding vector representing the input text.
31    ///
32    /// byot: In serialized `request` you must ensure "encoding_format" is not "base64"
33    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
34    pub async fn create(
35        &self,
36        request: CreateEmbeddingRequest,
37    ) -> Result<CreateEmbeddingResponse, OpenAIError> {
38        #[cfg(not(feature = "byot"))]
39        {
40            if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
41                return Err(OpenAIError::InvalidArgument(
42                    "When encoding_format is base64, use Embeddings::create_base64".into(),
43                ));
44            }
45        }
46        self.client
47            .post("/embeddings", request, &self.request_options)
48            .await
49    }
50
51    /// Creates an embedding vector representing the input text.
52    ///
53    /// The response will contain the embedding in base64 format.
54    ///
55    /// byot: In serialized `request` you must ensure "encoding_format" is "base64"
56    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
57    pub async fn create_base64(
58        &self,
59        request: CreateEmbeddingRequest,
60    ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
61        #[cfg(not(feature = "byot"))]
62        {
63            if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
64                return Err(OpenAIError::InvalidArgument(
65                    "When encoding_format is not base64, use Embeddings::create".into(),
66                ));
67            }
68        }
69        self.client
70            .post("/embeddings", request, &self.request_options)
71            .await
72    }
73}
74
#[cfg(all(test, feature = "embedding"))]
mod tests {
    //! Integration tests for the embeddings group. NOTE(review): these build
    //! `Client::new()` and call the live endpoint, so they need network
    //! access and valid API credentials to pass.
    use crate::{
        error::OpenAIError,
        types::embeddings::{
            CreateEmbeddingRequestArgs, CreateEmbeddingResponse, Embedding, EncodingFormat,
        },
        Client,
    };

    // Single string input.
    #[tokio::test]
    async fn test_embedding_string() {
        let client = Client::new();

        let req = CreateEmbeddingRequestArgs::default()
            .input("The food was delicious and the waiter...")
            .model("text-embedding-3-small")
            .build()
            .unwrap();

        assert!(client.embeddings().create(req).await.is_ok());
    }

    // Array-of-strings input.
    #[tokio::test]
    async fn test_embedding_string_array() {
        let client = Client::new();

        let req = CreateEmbeddingRequestArgs::default()
            .input(["The food was delicious", "The waiter was good"])
            .model("text-embedding-3-small")
            .build()
            .unwrap();

        assert!(client.embeddings().create(req).await.is_ok());
    }

    // Pre-tokenized input: a single array of token ids.
    #[tokio::test]
    async fn test_embedding_integer_array() {
        let client = Client::new();

        let req = CreateEmbeddingRequestArgs::default()
            .input([1, 2, 3])
            .model("text-embedding-3-small")
            .build()
            .unwrap();

        assert!(client.embeddings().create(req).await.is_ok());
    }

    // Pre-tokenized input: fixed-size matrix of token ids.
    #[tokio::test]
    async fn test_embedding_array_of_integer_array_matrix() {
        let client = Client::new();

        let req = CreateEmbeddingRequestArgs::default()
            .input([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
            .model("text-embedding-3-small")
            .build()
            .unwrap();

        assert!(client.embeddings().create(req).await.is_ok());
    }

    // Pre-tokenized input: ragged rows of token ids.
    #[tokio::test]
    async fn test_embedding_array_of_integer_array() {
        let client = Client::new();

        let req = CreateEmbeddingRequestArgs::default()
            .input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]])
            .model("text-embedding-3-small")
            .build()
            .unwrap();

        assert!(client.embeddings().create(req).await.is_ok());
    }

    // The `dimensions` parameter should shrink the returned vector.
    #[tokio::test]
    async fn test_embedding_with_reduced_dimensions() {
        let client = Client::new();
        let dims = 256u32;
        let req = CreateEmbeddingRequestArgs::default()
            .input("The food was delicious and the waiter...")
            .model("text-embedding-3-small")
            .dimensions(dims)
            .build()
            .unwrap();

        let resp = client.embeddings().create(req).await;
        assert!(resp.is_ok());

        let CreateEmbeddingResponse { data, .. } = resp.unwrap();
        assert_eq!(data.len(), 1);
        let Embedding { embedding, .. } = data.into_iter().next().unwrap();
        assert_eq!(embedding.len(), dims as usize);
    }

    // `create` must reject base64 requests client-side (no network round trip).
    #[tokio::test]
    async fn test_cannot_use_base64_encoding_with_normal_create_request() {
        const MODEL: &str = "text-embedding-3-small";
        const INPUT: &str = "You shall not pass.";

        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .input(INPUT)
            .model(MODEL)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let result = client.embeddings().create(request).await;
        assert!(matches!(result, Err(OpenAIError::InvalidArgument(_))));
    }

    // Base64 and float responses for the same input should agree in length.
    #[tokio::test]
    async fn test_embedding_create_base64() {
        const MODEL: &str = "text-embedding-3-small";
        const INPUT: &str = "a head full of dreams";

        let client = Client::new();

        let b64_request = CreateEmbeddingRequestArgs::default()
            .input(INPUT)
            .model(MODEL)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client
            .embeddings()
            .create_base64(b64_request)
            .await
            .unwrap();
        let decoded: Vec<f32> = b64_response
            .data
            .into_iter()
            .next()
            .unwrap()
            .embedding
            .into();

        let float_request = CreateEmbeddingRequestArgs::default()
            .input(INPUT)
            .model(MODEL)
            .build()
            .unwrap();
        let float_response = client.embeddings().create(float_request).await.unwrap();
        let floats = float_response.data.into_iter().next().unwrap().embedding;

        assert_eq!(decoded.len(), floats.len());
    }
}
225}