async_openai_wasm/embedding.rs

use crate::{
    Client, RequestOptions,
    config::Config,
    error::OpenAIError,
    types::embeddings::{
        CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse,
    },
};

#[cfg(not(feature = "byot"))]
use crate::types::embeddings::EncodingFormat;

/// Get a vector representation of a given input that can be easily
/// consumed by machine learning models and algorithms.
///
/// Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings)
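///
/// A minimal usage sketch, assuming `OPENAI_API_KEY` is set in the environment,
/// the `embedding` feature is enabled, and the default client config is used:
///
/// ```no_run
/// use async_openai_wasm::{Client, types::embeddings::CreateEmbeddingRequestArgs};
/// # use async_openai_wasm::error::OpenAIError;
///
/// # async fn run() -> Result<(), OpenAIError> {
/// let client = Client::new();
///
/// // Build a request for a single input string.
/// let request = CreateEmbeddingRequestArgs::default()
///     .model("text-embedding-3-small")
///     .input("The food was delicious and the waiter...")
///     .build()?;
///
/// // Each entry in `data` holds one embedding vector (`Vec<f32>`).
/// let response = client.embeddings().create(request).await?;
/// println!("vector length: {}", response.data[0].embedding.len());
/// # Ok(())
/// # }
/// ```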
pub struct Embeddings<'c, C: Config> {
    client: &'c Client<C>,
    pub(crate) request_options: RequestOptions,
}

impl<'c, C: Config> Embeddings<'c, C> {
    pub fn new(client: &'c Client<C>) -> Self {
        Self {
            client,
            request_options: RequestOptions::new(),
        }
    }

    /// Creates an embedding vector representing the input text.
    ///
    /// byot: In serialized `request` you must ensure "encoding_format" is not "base64"
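    ///
    /// A minimal sketch (assuming `OPENAI_API_KEY` is set and the `embedding`
    /// feature is enabled) showing batched inputs and reduced dimensionality:
    ///
    /// ```no_run
    /// # use async_openai_wasm::{Client, error::OpenAIError, types::embeddings::CreateEmbeddingRequestArgs};
    /// # async fn run() -> Result<(), OpenAIError> {
    /// let client = Client::new();
    /// let request = CreateEmbeddingRequestArgs::default()
    ///     .model("text-embedding-3-small")
    ///     .input(["The food was delicious", "The waiter was good"])
    ///     .dimensions(256u32)
    ///     .build()?;
    /// // One embedding is returned per input, in the same order.
    /// let response = client.embeddings().create(request).await?;
    /// assert_eq!(response.data.len(), 2);
    /// # Ok(())
    /// # }
    /// ```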
    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn create(
        &self,
        request: CreateEmbeddingRequest,
    ) -> Result<CreateEmbeddingResponse, OpenAIError> {
        #[cfg(not(feature = "byot"))]
        {
            if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
                return Err(OpenAIError::InvalidArgument(
                    "When encoding_format is base64, use Embeddings::create_base64".into(),
                ));
            }
        }
        self.client
            .post("/embeddings", request, &self.request_options)
            .await
    }

    /// Creates an embedding vector representing the input text.
    ///
    /// The response will contain the embedding in base64 format.
    ///
    /// byot: In serialized `request` you must ensure "encoding_format" is "base64"
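    ///
    /// A minimal sketch (assuming `OPENAI_API_KEY` is set and the `embedding`
    /// feature is enabled); the request must set `encoding_format` to base64:
    ///
    /// ```no_run
    /// # use async_openai_wasm::{Client, error::OpenAIError, types::embeddings::{CreateEmbeddingRequestArgs, EncodingFormat}};
    /// # async fn run() -> Result<(), OpenAIError> {
    /// let client = Client::new();
    /// let request = CreateEmbeddingRequestArgs::default()
    ///     .model("text-embedding-3-small")
    ///     .input("a head full of dreams")
    ///     .encoding_format(EncodingFormat::Base64)
    ///     .build()?;
    /// let response = client.embeddings().create_base64(request).await?;
    /// // The base64 payload converts into a plain `Vec<f32>`.
    /// let embedding: Vec<f32> = response.data.into_iter().next().unwrap().embedding.into();
    /// println!("vector length: {}", embedding.len());
    /// # Ok(())
    /// # }
    /// ```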
    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn create_base64(
        &self,
        request: CreateEmbeddingRequest,
    ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
        #[cfg(not(feature = "byot"))]
        {
            if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
                return Err(OpenAIError::InvalidArgument(
                    "When encoding_format is not base64, use Embeddings::create".into(),
                ));
            }
        }
        self.client
            .post("/embeddings", request, &self.request_options)
            .await
    }
}

#[cfg(all(test, feature = "embedding"))]
mod tests {
    use crate::types::embeddings::{CreateEmbeddingResponse, Embedding, EncodingFormat};
    use crate::{Client, types::embeddings::CreateEmbeddingRequestArgs};

    #[tokio::test]
    async fn test_embedding_string() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("The food was delicious and the waiter...")
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_embedding_string_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input(["The food was delicious", "The waiter was good"])
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_embedding_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input([1, 2, 3])
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array_matrix() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]])
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_embedding_with_reduced_dimensions() {
        let client = Client::new();
        let dimensions = 256u32;
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("The food was delicious and the waiter...")
            .dimensions(dimensions)
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok());

        let CreateEmbeddingResponse { mut data, .. } = response.unwrap();
        assert_eq!(data.len(), 1);
        let Embedding { embedding, .. } = data.pop().unwrap();
        assert_eq!(embedding.len(), dimensions as usize);
    }

    #[tokio::test]
    async fn test_cannot_use_base64_encoding_with_normal_create_request() {
        use crate::error::OpenAIError;
        let client = Client::new();

        const MODEL: &str = "text-embedding-3-small";
        const INPUT: &str = "You shall not pass.";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client.embeddings().create(b64_request).await;
        assert!(matches!(b64_response, Err(OpenAIError::InvalidArgument(_))));
    }

    #[tokio::test]
    async fn test_embedding_create_base64() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-3-small";
        const INPUT: &str = "a head full of dreams";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client
            .embeddings()
            .create_base64(b64_request)
            .await
            .unwrap();
        let b64_embedding = b64_response.data.into_iter().next().unwrap().embedding;
        let b64_embedding: Vec<f32> = b64_embedding.into();

        let request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .build()
            .unwrap();
        let response = client.embeddings().create(request).await.unwrap();
        let embedding = response.data.into_iter().next().unwrap().embedding;

        assert_eq!(b64_embedding.len(), embedding.len());
    }
}