// chat_gpt_lib_rs/api_resources/embeddings.rs

1//! This module provides functionality for creating embeddings using the
2//! [OpenAI Embeddings API](https://platform.openai.com/docs/api-reference/embeddings).
3//!
4//! The Embeddings API takes in text or tokenized text and returns a vector representation
5//! (embedding) that can be used for tasks like similarity searches, clustering, or classification
6//! in vector databases.
7//!
8//! # Overview
9//!
10//! The core usage involves calling [`create_embeddings`] with a [`CreateEmbeddingsRequest`],
11//! which includes the `model` name (e.g., `"text-embedding-ada-002"`) and the input text(s).
12//!
13//! ```rust,no_run
14//! use chat_gpt_lib_rs::api_resources::embeddings::{create_embeddings, CreateEmbeddingsRequest, EmbeddingsInput};
15//! use chat_gpt_lib_rs::error::OpenAIError;
16//! use chat_gpt_lib_rs::OpenAIClient;
17//!
18//! #[tokio::main]
19//! async fn main() -> Result<(), OpenAIError> {
20//!     let client = OpenAIClient::new(None)?; // Reads API key from OPENAI_API_KEY
21//!
22//!     let request = CreateEmbeddingsRequest {
23//!         model: "text-embedding-ada-002".into(),
24//!         input: EmbeddingsInput::String("Hello world".to_string()),
25//!         user: None,
26//!     };
27//!
28//!     let response = create_embeddings(&client, &request).await?;
29//!     for (i, emb) in response.data.iter().enumerate() {
30//!         println!("Embedding #{}: vector size = {}", i, emb.embedding.len());
31//!     }
32//!     println!("Model used: {:?}", response.model);
33//!     if let Some(usage) = &response.usage {
34//!         println!("Usage => prompt_tokens: {}, total_tokens: {}",
35//!             usage.prompt_tokens, usage.total_tokens);
36//!     }
37//!
38//!     Ok(())
39//! }
40//! ```
41
42use serde::{Deserialize, Serialize};
43
44use crate::api::post_json;
45use crate::config::OpenAIClient;
46use crate::error::OpenAIError;
47
48use super::models::Model;
49
/// Represents the diverse ways the input can be supplied for embeddings:
///
/// - A single string
/// - Multiple strings
/// - A single sequence of token IDs
/// - Multiple sequences of token IDs
///
/// This is analogous to how prompt inputs can be specified in the Completions API,
/// so we mirror that flexibility here.
///
/// Because this enum is `#[serde(untagged)]`, serde matches variants in
/// declaration order when deserializing, and serialization emits the bare
/// value with no variant tag.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum EmbeddingsInput {
    /// A single input string, e.g. `"Hello world"`.
    String(String),
    /// A batch of input strings, all embedded in one request.
    Strings(Vec<String>),
    /// A single input that has already been tokenized into token IDs.
    Ints(Vec<i64>),
    /// A batch of pre-tokenized inputs (one `Vec<i64>` per input).
    MultiInts(Vec<Vec<i64>>),
}
71
/// A request struct for creating embeddings with the OpenAI API.
///
/// Serialized as the JSON body of `POST /v1/embeddings`.
///
/// For more details, see the [API documentation](https://platform.openai.com/docs/api-reference/embeddings).
#[derive(Debug, Serialize, Clone)]
pub struct CreateEmbeddingsRequest {
    /// **Required.** The ID of the model to use.
    /// For example: `"text-embedding-ada-002"`.
    pub model: Model,
    /// **Required.** The input text or tokens for which you want to generate embeddings.
    /// See [`EmbeddingsInput`] for the accepted shapes (single/multiple strings or token IDs).
    pub input: EmbeddingsInput,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    /// Omitted from the serialized request body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
86
/// The response returned by the OpenAI Embeddings API.
///
/// Contains one or more embeddings (depending on whether multiple inputs were provided),
/// along with the model used and usage metrics.
#[derive(Debug, Deserialize)]
pub struct CreateEmbeddingsResponse {
    /// The object type of this response — `"list"` for an embeddings response
    /// (not a per-request identifier).
    pub object: String,
    /// The list of embeddings returned, each containing an index and the embedding vector.
    pub data: Vec<EmbeddingData>,
    /// The model used for creating these embeddings.
    pub model: Model,
    /// Optional usage statistics for this request.
    /// `#[serde(default)]` makes this `None` when the field is absent from the JSON.
    #[serde(default)]
    pub usage: Option<EmbeddingsUsage>,
}
103
/// The embedding result for a single input item.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// The type of object returned, usually "embedding".
    pub object: String,
    /// The position/index of this embedding in the input array,
    /// used to match each result back to its corresponding input.
    pub index: u32,
    /// The embedding vector itself (one `f32` per dimension).
    pub embedding: Vec<f32>,
}
114
/// Usage statistics for an embeddings request, if provided by the API.
#[derive(Debug, Deserialize)]
pub struct EmbeddingsUsage {
    /// Number of tokens present in the prompt(s), i.e. the input text/tokens.
    pub prompt_tokens: u32,
    /// Total number of tokens consumed by this request.
    ///
    /// For embeddings, this is typically the same as `prompt_tokens`, unless the API
    /// changes how it reports usage data in the future.
    pub total_tokens: u32,
}
126
127/// Creates embeddings using the [OpenAI Embeddings API](https://platform.openai.com/docs/api-reference/embeddings).
128///
129/// # Parameters
130///
131/// * `client` - The [`OpenAIClient`](crate::config::OpenAIClient) to use for the request.
132/// * `request` - A [`CreateEmbeddingsRequest`] specifying the model and input(s).
133///
134/// # Returns
135///
136/// A [`CreateEmbeddingsResponse`] containing one or more embedding vectors.
137///
138/// # Errors
139///
140/// - [`OpenAIError::HTTPError`]: if the request fails at the network layer.
141/// - [`OpenAIError::DeserializeError`]: if the response fails to parse.
142/// - [`OpenAIError::APIError`]: if OpenAI returns an error (e.g., invalid request).
143pub async fn create_embeddings(
144    client: &OpenAIClient,
145    request: &CreateEmbeddingsRequest,
146) -> Result<CreateEmbeddingsResponse, OpenAIError> {
147    // According to the OpenAI docs, the endpoint for embeddings is:
148    // POST /v1/embeddings
149    let endpoint = "embeddings";
150    post_json(client, endpoint, request).await
151}
152
#[cfg(test)]
mod tests {
    // # Tests for the `embeddings` module
    //
    // These tests rely on `wiremock` (https://crates.io/crates/wiremock) to stand in
    // for the `/v1/embeddings` endpoint, covering three scenarios:
    //   1. a **success** case yielding a valid `CreateEmbeddingsResponse`,
    //   2. a **failure** case returning an OpenAI-style error (`OpenAIError::APIError`),
    //   3. a **deserialization error** case when the JSON is malformed.
    use super::*;
    use crate::config::OpenAIClient;
    use crate::error::OpenAIError;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client pointed at the given mock server with a dummy API key.
    fn test_client(server: &MockServer) -> OpenAIClient {
        OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&server.uri())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_create_embeddings_success() {
        let server = MockServer::start().await;

        // Two embeddings plus usage statistics, mirroring a real API reply.
        let success_body = json!({
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "index": 0,
                    "embedding": [0.123, -0.456, 0.789]
                },
                {
                    "object": "embedding",
                    "index": 1,
                    "embedding": [0.111, 0.222, 0.333]
                }
            ],
            "model": "text-embedding-ada-002",
            "usage": {
                "prompt_tokens": 5,
                "total_tokens": 5
            }
        });

        Mock::given(method("POST"))
            .and(path("/embeddings"))
            .respond_with(ResponseTemplate::new(200).set_body_json(success_body))
            .mount(&server)
            .await;

        let request = CreateEmbeddingsRequest {
            model: "text-embedding-ada-002".into(),
            input: EmbeddingsInput::Strings(vec!["Hello".to_string(), "World".to_string()]),
            user: None,
        };

        let result = create_embeddings(&test_client(&server), &request).await;
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result);
        let resp = result.unwrap();

        assert_eq!(resp.object, "list");
        assert_eq!(resp.data.len(), 2);
        assert_eq!(resp.model, "text-embedding-ada-002".into());

        // Spot-check the first embedding entry.
        let first = &resp.data[0];
        assert_eq!(first.object, "embedding");
        assert_eq!(first.index, 0);
        assert_eq!(first.embedding, vec![0.123, -0.456, 0.789]);

        // Usage statistics must round-trip as well.
        let usage = resp.usage.as_ref().unwrap();
        assert_eq!(usage.prompt_tokens, 5);
        assert_eq!(usage.total_tokens, 5);
    }

    #[tokio::test]
    async fn test_create_embeddings_api_error() {
        let server = MockServer::start().await;

        // Simulate a 400 response carrying an OpenAI-style error payload.
        let error_body = json!({
            "error": {
                "message": "Invalid model: text-embedding-ada-999",
                "type": "invalid_request_error",
                "code": "model_invalid"
            }
        });

        Mock::given(method("POST"))
            .and(path("/embeddings"))
            .respond_with(ResponseTemplate::new(400).set_body_json(error_body))
            .mount(&server)
            .await;

        let request = CreateEmbeddingsRequest {
            model: "text-embedding-ada-999".into(),
            input: EmbeddingsInput::String("test input".to_string()),
            user: Some("user-123".to_string()),
        };

        match create_embeddings(&test_client(&server), &request).await {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(message.contains("Invalid model: text-embedding-ada-999"));
            }
            other => panic!("Expected APIError, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_create_embeddings_deserialize_error() {
        let server = MockServer::start().await;

        // A 200 response whose `data` field has the wrong type, forcing a
        // deserialization failure on the client side.
        let malformed_json = r#"{
            "object": "list",
            "data": "should be an array of embeddings, not a string",
            "model": "text-embedding-ada-002"
        }"#;

        Mock::given(method("POST"))
            .and(path("/embeddings"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(malformed_json, "application/json"),
            )
            .mount(&server)
            .await;

        let request = CreateEmbeddingsRequest {
            model: "text-embedding-ada-002".into(),
            input: EmbeddingsInput::String("Hello".to_string()),
            user: None,
        };

        match create_embeddings(&test_client(&server), &request).await {
            Err(OpenAIError::DeserializeError(_)) => {
                // expected path — nothing further to check
            }
            other => panic!("Expected DeserializeError, got {:?}", other),
        }
    }
}