// chat_gpt_lib_rs/api_resources/embeddings.rs
//! This module provides functionality for creating embeddings using the
//! [OpenAI Embeddings API](https://platform.openai.com/docs/api-reference/embeddings).
//!
//! The Embeddings API takes in text or tokenized text and returns a vector representation
//! (embedding) that can be used for tasks like similarity searches, clustering, or classification
//! in vector databases.
//!
//! # Overview
//!
//! The core usage involves calling [`create_embeddings`] with a [`CreateEmbeddingsRequest`],
//! which includes the `model` name (e.g., `"text-embedding-ada-002"`) and the input text(s).
//!
//! ```rust,no_run
//! use chat_gpt_lib_rs::api_resources::embeddings::{create_embeddings, CreateEmbeddingsRequest, EmbeddingsInput};
//! use chat_gpt_lib_rs::error::OpenAIError;
//! use chat_gpt_lib_rs::OpenAIClient;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), OpenAIError> {
//!     let client = OpenAIClient::new(None)?; // Reads API key from OPENAI_API_KEY
//!
//!     let request = CreateEmbeddingsRequest {
//!         model: "text-embedding-ada-002".into(),
//!         input: EmbeddingsInput::String("Hello world".to_string()),
//!         user: None,
//!     };
//!
//!     let response = create_embeddings(&client, &request).await?;
//!     for (i, emb) in response.data.iter().enumerate() {
//!         println!("Embedding #{}: vector size = {}", i, emb.embedding.len());
//!     }
//!     println!("Model used: {:?}", response.model);
//!     if let Some(usage) = &response.usage {
//!         println!("Usage => prompt_tokens: {}, total_tokens: {}",
//!             usage.prompt_tokens, usage.total_tokens);
//!     }
//!
//!     Ok(())
//! }
//! ```

use serde::{Deserialize, Serialize};

use crate::api::post_json;
use crate::config::OpenAIClient;
use crate::error::OpenAIError;

use super::models::Model;
49
50/// Represents the diverse ways the input can be supplied for embeddings:
51///
52/// - A single string
53/// - Multiple strings
54/// - A single sequence of token IDs
55/// - Multiple sequences of token IDs
56///
57/// This is analogous to how prompt inputs can be specified in the Completions API,
58/// so we mirror that flexibility here.
59#[derive(Debug, Serialize, Deserialize, Clone)]
60#[serde(untagged)]
61pub enum EmbeddingsInput {
62 /// A single string
63 String(String),
64 /// Multiple strings
65 Strings(Vec<String>),
66 /// A single sequence of token IDs
67 Ints(Vec<i64>),
68 /// Multiple sequences of token IDs
69 MultiInts(Vec<Vec<i64>>),
70}
71
72/// A request struct for creating embeddings with the OpenAI API.
73///
74/// For more details, see the [API documentation](https://platform.openai.com/docs/api-reference/embeddings).
75#[derive(Debug, Serialize, Clone)]
76pub struct CreateEmbeddingsRequest {
77 /// **Required.** The ID of the model to use.
78 /// For example: `"text-embedding-ada-002"`.
79 pub model: Model,
80 /// **Required.** The input text or tokens for which you want to generate embeddings.
81 pub input: EmbeddingsInput,
82 /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
83 #[serde(skip_serializing_if = "Option::is_none")]
84 pub user: Option<String>,
85}
86
87/// The response returned by the OpenAI Embeddings API.
88///
89/// Contains one or more embeddings (depending on whether multiple inputs were provided),
90/// along with the model used and usage metrics.
91#[derive(Debug, Deserialize)]
92pub struct CreateEmbeddingsResponse {
93 /// An identifier for this embedding request (e.g. "embedding-xxxxxx").
94 pub object: String,
95 /// The list of embeddings returned, each containing an index and the embedding vector.
96 pub data: Vec<EmbeddingData>,
97 /// The model used for creating these embeddings.
98 pub model: Model,
99 /// Optional usage statistics for this request.
100 #[serde(default)]
101 pub usage: Option<EmbeddingsUsage>,
102}
103
104/// The embedding result for a single input item.
105#[derive(Debug, Deserialize)]
106pub struct EmbeddingData {
107 /// The type of object returned, usually "embedding".
108 pub object: String,
109 /// The position/index of this embedding in the input array.
110 pub index: u32,
111 /// The embedding vector itself.
112 pub embedding: Vec<f32>,
113}
114
115/// Usage statistics for an embeddings request, if provided by the API.
116#[derive(Debug, Deserialize)]
117pub struct EmbeddingsUsage {
118 /// Number of tokens present in the prompt(s).
119 pub prompt_tokens: u32,
120 /// Total number of tokens consumed by this request.
121 ///
122 /// For embeddings, this is typically the same as `prompt_tokens`, unless the API
123 /// changes how it reports usage data in the future.
124 pub total_tokens: u32,
125}
126
127/// Creates embeddings using the [OpenAI Embeddings API](https://platform.openai.com/docs/api-reference/embeddings).
128///
129/// # Parameters
130///
131/// * `client` - The [`OpenAIClient`](crate::config::OpenAIClient) to use for the request.
132/// * `request` - A [`CreateEmbeddingsRequest`] specifying the model and input(s).
133///
134/// # Returns
135///
136/// A [`CreateEmbeddingsResponse`] containing one or more embedding vectors.
137///
138/// # Errors
139///
140/// - [`OpenAIError::HTTPError`]: if the request fails at the network layer.
141/// - [`OpenAIError::DeserializeError`]: if the response fails to parse.
142/// - [`OpenAIError::APIError`]: if OpenAI returns an error (e.g., invalid request).
143pub async fn create_embeddings(
144 client: &OpenAIClient,
145 request: &CreateEmbeddingsRequest,
146) -> Result<CreateEmbeddingsResponse, OpenAIError> {
147 // According to the OpenAI docs, the endpoint for embeddings is:
148 // POST /v1/embeddings
149 let endpoint = "embeddings";
150 post_json(client, endpoint, request).await
151}
152
#[cfg(test)]
mod tests {
    /// # Tests for the `embeddings` module
    ///
    /// We rely on [`wiremock`](https://crates.io/crates/wiremock) to mock responses from the
    /// `/v1/embeddings` endpoint. The tests ensure:
    /// 1. A **success** case where we receive a valid embedding response (`CreateEmbeddingsResponse`).
    /// 2. A **failure** case returning an OpenAI-style error (mapped to `OpenAIError::APIError`).
    /// 3. A **deserialization error** case when the JSON is malformed.
    ///
    use super::*;
    use crate::config::OpenAIClient;
    use crate::error::OpenAIError;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[tokio::test]
    async fn test_create_embeddings_success() {
        // Start the local mock server
        let mock_server = MockServer::start().await;

        // Define a successful response JSON
        let success_body = json!({
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "index": 0,
                    "embedding": [0.123, -0.456, 0.789]
                },
                {
                    "object": "embedding",
                    "index": 1,
                    "embedding": [0.111, 0.222, 0.333]
                }
            ],
            "model": "text-embedding-ada-002",
            "usage": {
                "prompt_tokens": 5,
                "total_tokens": 5
            }
        });

        // Mock a POST to /v1/embeddings
        Mock::given(method("POST"))
            .and(path("/embeddings"))
            .respond_with(ResponseTemplate::new(200).set_body_json(success_body))
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        let req = CreateEmbeddingsRequest {
            model: "text-embedding-ada-002".into(),
            input: EmbeddingsInput::Strings(vec!["Hello".to_string(), "World".to_string()]),
            user: None,
        };

        let result = create_embeddings(&client, &req).await;
        assert!(result.is_ok(), "Expected Ok, got: {:?}", result);

        let resp = result.unwrap();
        assert_eq!(resp.object, "list");
        assert_eq!(resp.data.len(), 2);
        assert_eq!(resp.model, "text-embedding-ada-002".into());

        let first = &resp.data[0];
        assert_eq!(first.object, "embedding");
        assert_eq!(first.index, 0);
        assert_eq!(first.embedding, vec![0.123, -0.456, 0.789]);

        let usage = resp.usage.as_ref().unwrap();
        assert_eq!(usage.prompt_tokens, 5);
        assert_eq!(usage.total_tokens, 5);
    }

    #[tokio::test]
    async fn test_create_embeddings_api_error() {
        let mock_server = MockServer::start().await;

        // Simulate a 400 error with an OpenAI-style error body
        let error_body = json!({
            "error": {
                "message": "Invalid model: text-embedding-ada-999",
                "type": "invalid_request_error",
                "code": "model_invalid"
            }
        });

        Mock::given(method("POST"))
            .and(path("/embeddings"))
            .respond_with(ResponseTemplate::new(400).set_body_json(error_body))
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        let req = CreateEmbeddingsRequest {
            model: "text-embedding-ada-999".into(),
            input: EmbeddingsInput::String("test input".to_string()),
            user: Some("user-123".to_string()),
        };

        let result = create_embeddings(&client, &req).await;
        match result {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(message.contains("Invalid model: text-embedding-ada-999"));
            }
            other => panic!("Expected APIError, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_create_embeddings_deserialize_error() {
        let mock_server = MockServer::start().await;

        // Return 200 but invalid or mismatched JSON
        let malformed_json = r#"{
            "object": "list",
            "data": "should be an array of embeddings, not a string",
            "model": "text-embedding-ada-002"
        }"#;

        Mock::given(method("POST"))
            .and(path("/embeddings"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(malformed_json, "application/json"),
            )
            .mount(&mock_server)
            .await;

        let client = OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&mock_server.uri())
            .build()
            .unwrap();

        let req = CreateEmbeddingsRequest {
            model: "text-embedding-ada-002".into(),
            input: EmbeddingsInput::String("Hello".to_string()),
            user: None,
        };

        let result = create_embeddings(&client, &req).await;
        match result {
            Err(OpenAIError::DeserializeError(_)) => {
                // success
            }
            other => panic!("Expected DeserializeError, got {:?}", other),
        }
    }
}