async_openai/
embedding.rs

1use crate::{
2    config::Config,
3    error::OpenAIError,
4    types::{CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse},
5    Client,
6};
7
8#[cfg(not(feature = "byot"))]
9use crate::types::EncodingFormat;
10
11/// Get a vector representation of a given input that can be easily
12/// consumed by machine learning models and algorithms.
13///
14/// Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
15pub struct Embeddings<'c, C: Config> {
16    client: &'c Client<C>,
17}
18
19impl<'c, C: Config> Embeddings<'c, C> {
20    pub fn new(client: &'c Client<C>) -> Self {
21        Self { client }
22    }
23
24    /// Creates an embedding vector representing the input text.
25    ///
26    /// byot: In serialized `request` you must ensure "encoding_format" is not "base64"
27    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
28    pub async fn create(
29        &self,
30        request: CreateEmbeddingRequest,
31    ) -> Result<CreateEmbeddingResponse, OpenAIError> {
32        #[cfg(not(feature = "byot"))]
33        {
34            if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
35                return Err(OpenAIError::InvalidArgument(
36                    "When encoding_format is base64, use Embeddings::create_base64".into(),
37                ));
38            }
39        }
40        self.client.post("/embeddings", request).await
41    }
42
43    /// Creates an embedding vector representing the input text.
44    ///
45    /// The response will contain the embedding in base64 format.
46    ///
47    /// byot: In serialized `request` you must ensure "encoding_format" is "base64"
48    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
49    pub async fn create_base64(
50        &self,
51        request: CreateEmbeddingRequest,
52    ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
53        #[cfg(not(feature = "byot"))]
54        {
55            if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
56                return Err(OpenAIError::InvalidArgument(
57                    "When encoding_format is not base64, use Embeddings::create".into(),
58                ));
59            }
60        }
61        self.client.post("/embeddings", request).await
62    }
63}
64
#[cfg(test)]
mod tests {
    // NOTE(review): every test builds a `Client::new()` and awaits a real
    // `create` / `create_base64` call, so they presumably need network access
    // and valid API credentials in the environment — confirm before running
    // them in CI.

    // `OpenAIError` is only named by the base64-rejection test, which is
    // compiled out under `byot`; gate the import the same way so `byot` test
    // builds do not emit an unused-import warning.
    #[cfg(not(feature = "byot"))]
    use crate::error::OpenAIError;
    use crate::types::{CreateEmbeddingResponse, Embedding, EncodingFormat};
    use crate::{types::CreateEmbeddingRequestArgs, Client};

    /// A single string is accepted as input.
    #[tokio::test]
    async fn test_embedding_string() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input("The food was delicious and the waiter...")
            .build()
            .unwrap();

        // `expect` prints the underlying error on failure, unlike
        // `assert!(response.is_ok())`.
        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding a single string should succeed");
    }

    /// An array of strings is accepted as input.
    #[tokio::test]
    async fn test_embedding_string_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input(["The food was delicious", "The waiter was good"])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding an array of strings should succeed");
    }

    /// An array of integers is accepted as input.
    #[tokio::test]
    async fn test_embedding_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([1, 2, 3])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding an integer array should succeed");
    }

    /// A fixed-size matrix (array of equally sized integer arrays) is accepted as input.
    #[tokio::test]
    async fn test_embedding_array_of_integer_array_matrix() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding an integer matrix should succeed");
    }

    /// An array of integer vectors with differing lengths is accepted as input.
    #[tokio::test]
    async fn test_embedding_array_of_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]])
            .build()
            .unwrap();

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding ragged integer arrays should succeed");
    }

    /// Requesting `dimensions` yields exactly one embedding of that length.
    #[tokio::test]
    async fn test_embedding_with_reduced_dimensions() {
        let client = Client::new();
        let dimensions = 256u32;
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("The food was delicious and the waiter...")
            .dimensions(dimensions)
            .build()
            .unwrap();

        // A single `expect` replaces the former `assert!(is_ok)` + `unwrap`
        // pair and reports the error when the request fails.
        let CreateEmbeddingResponse { mut data, .. } = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding with reduced dimensions should succeed");

        assert_eq!(data.len(), 1);
        let Embedding { embedding, .. } = data.pop().unwrap();
        assert_eq!(embedding.len(), dimensions as usize);
    }

    /// `create` must refuse a base64 request with `InvalidArgument`.
    #[tokio::test]
    #[cfg(not(feature = "byot"))]
    async fn test_cannot_use_base64_encoding_with_normal_create_request() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "You shall not pass.";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client.embeddings().create(b64_request).await;
        assert!(
            matches!(b64_response, Err(OpenAIError::InvalidArgument(_))),
            "create() should reject base64 requests with InvalidArgument"
        );
    }

    /// The decoded base64 embedding matches the plain float embedding for
    /// the same model and input, component by component.
    #[tokio::test]
    async fn test_embedding_create_base64() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "CoLoop will eat the other qual research tools...";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client
            .embeddings()
            .create_base64(b64_request)
            .await
            .expect("base64 embedding request should succeed");
        let b64_embedding = b64_response.data.into_iter().next().unwrap().embedding;
        let b64_embedding: Vec<f32> = b64_embedding.into();

        let request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .build()
            .unwrap();
        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("float embedding request should succeed");
        let embedding = response.data.into_iter().next().unwrap().embedding;

        assert_eq!(b64_embedding.len(), embedding.len());
        for (index, (b64, normal)) in b64_embedding.iter().zip(embedding.iter()).enumerate() {
            // Absolute tolerance: the two encodings should agree up to
            // float round-off.
            assert!(
                (b64 - normal).abs() < 1e-6,
                "component {} differs: base64 = {}, float = {}",
                index,
                b64,
                normal
            );
        }
    }
}