async_openai_alt/
embedding.rs

1use crate::{
2    config::Config,
3    error::OpenAIError,
4    types::{
5        CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse,
6        EncodingFormat,
7    },
8    Client,
9};
10
11/// Get a vector representation of a given input that can be easily
12/// consumed by machine learning models and algorithms.
13///
14/// Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
15pub struct Embeddings<'c, C: Config> {
16    client: &'c Client<C>,
17}
18
19impl<'c, C: Config> Embeddings<'c, C> {
20    pub fn new(client: &'c Client<C>) -> Self {
21        Self { client }
22    }
23
24    /// Creates an embedding vector representing the input text.
25    pub async fn create(
26        &self,
27        request: CreateEmbeddingRequest,
28    ) -> Result<CreateEmbeddingResponse, OpenAIError> {
29        if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
30            return Err(OpenAIError::InvalidArgument(
31                "When encoding_format is base64, use Embeddings::create_base64".into(),
32            ));
33        }
34        self.client.post("/embeddings", request).await
35    }
36
37    /// Creates an embedding vector representing the input text.
38    ///
39    /// The response will contain the embedding in base64 format.
40    pub async fn create_base64(
41        &self,
42        request: CreateEmbeddingRequest,
43    ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
44        if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
45            return Err(OpenAIError::InvalidArgument(
46                "When encoding_format is not base64, use Embeddings::create".into(),
47            ));
48        }
49
50        self.client.post("/embeddings", request).await
51    }
52}
53
#[cfg(test)]
mod tests {
    // Tests for the embeddings API group.
    //
    // NOTE(review): apart from the base64-rejection test, every test sends a
    // request through `Client::new()` — presumably to the live service with
    // credentials taken from the environment; confirm before running in CI.
    //
    // Results are unwrapped with `expect` instead of `assert!(r.is_ok())` so
    // that a failing test prints the underlying `OpenAIError`.
    use crate::error::OpenAIError;
    use crate::types::{CreateEmbeddingResponse, Embedding, EncodingFormat};
    use crate::{types::CreateEmbeddingRequestArgs, Client};

    /// A single string is accepted as input.
    #[tokio::test]
    async fn test_embedding_string() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input("The food was delicious and the waiter...")
            .build()
            .expect("request with a single string input should build");

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding a single string should succeed");
    }

    /// An array of strings is accepted as input.
    #[tokio::test]
    async fn test_embedding_string_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input(["The food was delicious", "The waiter was good"])
            .build()
            .expect("request with a string array input should build");

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding a string array should succeed");
    }

    /// An array of integers is accepted as input.
    #[tokio::test]
    async fn test_embedding_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([1, 2, 3])
            .build()
            .expect("request with an integer array input should build");

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding an integer array should succeed");
    }

    /// A fixed-size matrix (array of equally long integer arrays) is accepted as input.
    #[tokio::test]
    async fn test_embedding_array_of_integer_array_matrix() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
            .build()
            .expect("request with an integer matrix input should build");

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding an integer matrix should succeed");
    }

    /// An array of integer vectors with differing lengths is accepted as input.
    #[tokio::test]
    async fn test_embedding_array_of_integer_array() {
        let client = Client::new();

        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]])
            .build()
            .expect("request with a ragged integer array input should build");

        client
            .embeddings()
            .create(request)
            .await
            .expect("embedding a ragged integer array should succeed");
    }

    /// Requesting fewer dimensions yields one embedding of exactly that length.
    #[tokio::test]
    async fn test_embedding_with_reduced_dimensions() {
        let client = Client::new();
        let dimensions = 256u32;
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input("The food was delicious and the waiter...")
            .dimensions(dimensions)
            .build()
            .expect("request with reduced dimensions should build");

        // `expect` both checks success and hands over the value, so no
        // separate `is_ok` assertion is needed.
        let CreateEmbeddingResponse { mut data, .. } = client
            .embeddings()
            .create(request)
            .await
            .expect("embedding with reduced dimensions should succeed");

        assert_eq!(data.len(), 1);
        let Embedding { embedding, .. } = data
            .pop()
            .expect("response should contain exactly one embedding");
        // Widening `u32` -> `usize`; lossless on 32- and 64-bit targets.
        assert_eq!(embedding.len(), dimensions as usize);
    }

    /// `create` must refuse a request whose encoding format is base64.
    #[tokio::test]
    async fn test_cannot_use_base64_encoding_with_normal_create_request() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "You shall not pass.";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .expect("base64 request should build");
        let b64_response = client.embeddings().create(b64_request).await;
        assert!(matches!(b64_response, Err(OpenAIError::InvalidArgument(_))));
    }

    /// The base64 and the plain float responses for the same input must
    /// decode to the same vector, up to a small absolute tolerance.
    #[tokio::test]
    async fn test_embedding_create_base64() {
        let client = Client::new();

        const MODEL: &str = "text-embedding-ada-002";
        const INPUT: &str = "CoLoop will eat the other qual research tools...";

        let b64_request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .encoding_format(EncodingFormat::Base64)
            .build()
            .expect("base64 request should build");
        let b64_response = client
            .embeddings()
            .create_base64(b64_request)
            .await
            .expect("base64 embedding request should succeed");
        let b64_embedding = b64_response
            .data
            .into_iter()
            .next()
            .expect("base64 response should contain an embedding")
            .embedding;
        let b64_embedding: Vec<f32> = b64_embedding.into();

        let request = CreateEmbeddingRequestArgs::default()
            .model(MODEL)
            .input(INPUT)
            .build()
            .expect("float request should build");
        let response = client
            .embeddings()
            .create(request)
            .await
            .expect("float embedding request should succeed");
        let embedding = response
            .data
            .into_iter()
            .next()
            .expect("float response should contain an embedding")
            .embedding;

        // Lengths are checked first because `zip` would silently stop at the
        // shorter of the two vectors.
        assert_eq!(b64_embedding.len(), embedding.len());
        for (index, (b64, normal)) in b64_embedding.iter().zip(embedding.iter()).enumerate() {
            assert!(
                (b64 - normal).abs() < 1e-6,
                "component {} differs: base64 = {}, float = {}",
                index,
                b64,
                normal
            );
        }
    }
}