// async_openai/embedding.rs

1use crate::{
2    config::Config,
3    error::OpenAIError,
4    types::embeddings::{
5        CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse,
6    },
7    Client,
8};
9
10#[cfg(not(feature = "byot"))]
11use crate::types::embeddings::EncodingFormat;
12
13/// Get a vector representation of a given input that can be easily
14/// consumed by machine learning models and algorithms.
15///
16/// Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings)
17pub struct Embeddings<'c, C: Config> {
18    client: &'c Client<C>,
19}
20
21impl<'c, C: Config> Embeddings<'c, C> {
22    pub fn new(client: &'c Client<C>) -> Self {
23        Self { client }
24    }
25
26    /// Creates an embedding vector representing the input text.
27    ///
28    /// byot: In serialized `request` you must ensure "encoding_format" is not "base64"
29    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
30    pub async fn create(
31        &self,
32        request: CreateEmbeddingRequest,
33    ) -> Result<CreateEmbeddingResponse, OpenAIError> {
34        #[cfg(not(feature = "byot"))]
35        {
36            if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
37                return Err(OpenAIError::InvalidArgument(
38                    "When encoding_format is base64, use Embeddings::create_base64".into(),
39                ));
40            }
41        }
42        self.client.post("/embeddings", request).await
43    }
44
45    /// Creates an embedding vector representing the input text.
46    ///
47    /// The response will contain the embedding in base64 format.
48    ///
49    /// byot: In serialized `request` you must ensure "encoding_format" is "base64"
50    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
51    pub async fn create_base64(
52        &self,
53        request: CreateEmbeddingRequest,
54    ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
55        #[cfg(not(feature = "byot"))]
56        {
57            if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
58                return Err(OpenAIError::InvalidArgument(
59                    "When encoding_format is not base64, use Embeddings::create".into(),
60                ));
61            }
62        }
63        self.client.post("/embeddings", request).await
64    }
65}
66
#[cfg(test)]
mod tests {
    //! Tests for `Embeddings`.
    //!
    //! All tests except `test_cannot_use_base64_encoding_with_normal_create_request`
    //! send a live request through `Client::new()`, so they need network access
    //! (and presumably API credentials from the environment — confirm).

    use crate::types::embeddings::{CreateEmbeddingResponse, Embedding, EncodingFormat};
    use crate::{types::embeddings::CreateEmbeddingRequestArgs, Client};

    #[tokio::test]
    async fn test_embedding_string() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input("The food was delicious and the waiter...")
            .build()
            .unwrap();

        // `expect` surfaces the actual error; `assert!(is_ok())` would hide it.
        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request for a single string failed");

        // One input string -> one embedding.
        assert_eq!(response.data.len(), 1);
    }

    #[tokio::test]
    async fn test_embedding_string_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input(["The food was delicious", "The waiter was good"])
            .build()
            .unwrap();

        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request for a string array failed");

        // Two input strings -> two embeddings.
        assert_eq!(response.data.len(), 2);
    }

    #[tokio::test]
    async fn test_embedding_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([1, 2, 3])
            .build()
            .unwrap();

        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request for a token array failed");

        // A flat token array is a single input -> one embedding.
        assert_eq!(response.data.len(), 1);
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array_matrix() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
            .build()
            .unwrap();

        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request for a fixed-size token matrix failed");

        // Three token arrays -> three embeddings.
        assert_eq!(response.data.len(), 3);
    }

    #[tokio::test]
    async fn test_embedding_array_of_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]])
            .build()
            .unwrap();

        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request for ragged token arrays failed");

        // Three token arrays -> three embeddings.
        assert_eq!(response.data.len(), 3);
    }

    #[tokio::test]
    async fn test_embedding_with_reduced_dimensions() {
        let client = Client::new();
        let dimensions = 256u32;
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("The food was delicious and the waiter...")
            .dimensions(dimensions)
            .build()
            .unwrap();

        let CreateEmbeddingResponse { mut data, .. } = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding request with reduced dimensions failed");

        assert_eq!(data.len(), 1);
        let Embedding { embedding, .. } = data.pop().unwrap();
        assert_eq!(embedding.len(), dimensions as usize);
    }

    // `Embeddings::create` only rejects base64 when the `byot` feature is
    // disabled (its check sits under `#[cfg(not(feature = "byot"))]`), so this
    // test is compiled only in that configuration.
    #[cfg(not(feature = "byot"))]
    #[tokio::test]
    async fn test_cannot_use_base64_encoding_with_normal_create_request() {
        use crate::error::OpenAIError;

        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "You shall not pass.";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client.embeddings().create(b64_request).await;
        assert!(
            matches!(b64_response, Err(OpenAIError::InvalidArgument(_))),
            "expected InvalidArgument, got {b64_response:?}"
        );
    }

    #[tokio::test]
    async fn test_embedding_create_base64() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "a head full of dreams";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .unwrap();
        let b64_response = client
            .embeddings()
            .create_base64(b64_request)
            .await
            .expect("base64 embedding request failed");
        let b64_embedding = b64_response.data.into_iter().next().unwrap().embedding;
        let b64_embedding: Vec<f32> = b64_embedding.into();

        let request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .build()
            .unwrap();
        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("float embedding request failed");
        let embedding = response.data.into_iter().next().unwrap().embedding;

        assert_eq!(b64_embedding.len(), embedding.len());
        // NOTE(review): the two vectors come from separate requests; if the
        // service is not bit-for-bit deterministic this tolerance could be
        // flaky — confirm.
        for (i, (b64, normal)) in b64_embedding.iter().zip(&embedding).enumerate() {
            assert!(
                (b64 - normal).abs() < 1e-6,
                "component {i} differs: base64 {b64} vs float {normal}"
            );
        }
    }
}