async_openai/
embedding.rs

1use crate::{
2    config::Config,
3    error::OpenAIError,
4    types::{CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse},
5    Client,
6};
7
8#[cfg(not(feature = "byot"))]
9use crate::types::EncodingFormat;
10
11/// Get a vector representation of a given input that can be easily
12/// consumed by machine learning models and algorithms.
13///
14/// Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
15pub struct Embeddings<'c, C: Config> {
16    client: &'c Client<C>,
17}
18
19impl<'c, C: Config> Embeddings<'c, C> {
20    pub fn new(client: &'c Client<C>) -> Self {
21        Self { client }
22    }
23
24    /// Creates an embedding vector representing the input text.
25    ///
26    /// byot: In serialized `request` you must ensure "encoding_format" is not "base64"
27    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
28    pub async fn create(
29        &self,
30        request: CreateEmbeddingRequest,
31    ) -> Result<CreateEmbeddingResponse, OpenAIError> {
32        #[cfg(not(feature = "byot"))]
33        {
34            if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
35                return Err(OpenAIError::InvalidArgument(
36                    "When encoding_format is base64, use Embeddings::create_base64".into(),
37                ));
38            }
39        }
40        self.client.post("/embeddings", request).await
41    }
42
43    /// Creates an embedding vector representing the input text.
44    ///
45    /// The response will contain the embedding in base64 format.
46    ///
47    /// byot: In serialized `request` you must ensure "encoding_format" is "base64"
48    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
49    pub async fn create_base64(
50        &self,
51        request: CreateEmbeddingRequest,
52    ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
53        #[cfg(not(feature = "byot"))]
54        {
55            if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
56                return Err(OpenAIError::InvalidArgument(
57                    "When encoding_format is not base64, use Embeddings::create".into(),
58                ));
59            }
60        }
61        self.client.post("/embeddings", request).await
62    }
63}
64
#[cfg(test)]
mod tests {
    // NOTE(review): these tests issue real requests through `Client::new()`; they
    // presumably require valid API credentials and network access — confirm.
    use crate::types::{CreateEmbeddingResponse, Embedding, EncodingFormat};
    use crate::{types::CreateEmbeddingRequestArgs, Client};

    #[tokio::test]
    async fn test_embedding_string() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input("The food was delicious and the waiter...")
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_embedding_string_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input(["The food was delicious", "The waiter was good"])
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_embedding_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([1, 2, 3])
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array_matrix() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]])
            .build()
            .unwrap();

        let response = client.embeddings().create(request).await;

        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_embedding_with_reduced_dimensions() {
        let client = Client::new();
        let dimensions = 256u32;
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("The food was delicious and the waiter...")
            .dimensions(dimensions)
            .build()
            .unwrap();

        // `expect` both asserts success and surfaces the error on failure,
        // replacing the former `assert!(is_ok())` + `unwrap()` pair.
        let CreateEmbeddingResponse { mut data, .. } = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request with reduced dimensions failed");

        assert_eq!(data.len(), 1);
        let Embedding { embedding, .. } = data.pop().unwrap();
        assert_eq!(embedding.len(), dimensions as usize);
    }

    #[tokio::test]
    #[cfg(not(feature = "byot"))]
    async fn test_cannot_use_base64_encoding_with_normal_create_request() {
        // `OpenAIError` was previously not in scope in this module, so the test
        // build failed. Importing it here, inside the cfg-gated test, keeps the
        // import from going unused when the `byot` feature is enabled.
        use crate::error::OpenAIError;

        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "You shall not pass.";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client.embeddings().create(b64_request).await;
        assert!(matches!(b64_response, Err(OpenAIError::InvalidArgument(_))));
    }

    #[tokio::test]
    async fn test_embedding_create_base64() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "CoLoop will eat the other qual research tools...";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client
            .embeddings()
            .create_base64(b64_request)
            .await
            .unwrap();
        let b64_embedding = b64_response.data.into_iter().next().unwrap().embedding;
        let b64_embedding: Vec<f32> = b64_embedding.into();

        let request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .build()
            .unwrap();
        let response = client.embeddings().create(request).await.unwrap();
        let embedding = response.data.into_iter().next().unwrap().embedding;

        // The same input is requested twice, once per encoding; the decoded
        // base64 vector should agree element-wise with the float vector.
        assert_eq!(b64_embedding.len(), embedding.len());
        for (b64, normal) in b64_embedding.iter().zip(embedding.iter()) {
            assert!((b64 - normal).abs() < 1e-6);
        }
    }
}