// openai_tools/embedding/request.rs
//! OpenAI Embeddings API Request Module
//!
//! This module provides the functionality to build and send requests to the OpenAI Embeddings API.
//! It offers a builder pattern for constructing embedding requests, allowing you to convert text
//! into numerical vector representations that capture semantic meaning.
//!
//! # Key Features
//!
//! - **Builder Pattern**: Fluent API for constructing embedding requests
//! - **Single & Batch Input**: Support for single text or multiple texts at once
//! - **Encoding Formats**: Support for `float` and `base64` output formats
//! - **Error Handling**: Robust error management and validation
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use openai_tools::embedding::request::Embedding;
//! use openai_tools::common::models::EmbeddingModel;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     // Initialize the embedding client
//!     let mut embedding = Embedding::new()?;
//!
//!     // Generate embedding for a single text
//!     let response = embedding
//!         .model(EmbeddingModel::TextEmbedding3Small)
//!         .input_text("Hello, world!")
//!         .embed()
//!         .await?;
//!
//!     let vector = response.data[0].embedding.as_1d().unwrap();
//!     println!("Embedding dimension: {}", vector.len());
//!     Ok(())
//! }
//! ```
//!
//! # Batch Processing
//!
//! ```rust,no_run
//! use openai_tools::embedding::request::Embedding;
//! use openai_tools::common::models::EmbeddingModel;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut embedding = Embedding::new()?;
//!
//!     // Embed multiple texts in a single request
//!     let texts = vec!["First text", "Second text", "Third text"];
//!
//!     let response = embedding
//!         .model(EmbeddingModel::TextEmbedding3Small)
//!         .input_text_array(texts)
//!         .embed()
//!         .await?;
//!
//!     for data in &response.data {
//!         println!("Index {}: {} dimensions",
//!             data.index,
//!             data.embedding.as_1d().unwrap().len());
//!     }
//!     Ok(())
//! }
//! ```

use core::str;
use std::time::Duration;

use serde::{Deserialize, Serialize};

use crate::common::auth::{AuthProvider, OpenAIAuth};
use crate::common::client::create_http_client;
use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
use crate::common::models::EmbeddingModel;
use crate::embedding::response::Response;
74
75/// Internal structure for handling input text in embedding requests.
76///
77/// This struct supports two input formats:
78/// - Single text string (`input_text`)
79/// - Array of text strings (`input_text_array`)
80///
81/// The custom `Serialize` implementation ensures proper JSON formatting
82/// based on which input type is provided.
83#[derive(Debug, Clone, Deserialize, Default)]
84struct Input {
85 /// Single input text for embedding
86 #[serde(skip_serializing_if = "String::is_empty")]
87 input_text: String,
88 /// Array of input texts for batch embedding
89 #[serde(skip_serializing_if = "Vec::is_empty")]
90 input_text_array: Vec<String>,
91}
92
93impl Input {
94 /// Creates an Input from a single text string.
95 ///
96 /// # Arguments
97 ///
98 /// * `input_text` - The text to embed
99 ///
100 /// # Returns
101 ///
102 /// A new `Input` instance with the single text set
103 pub fn from_text(input_text: &str) -> Self {
104 Self { input_text: input_text.to_string(), input_text_array: vec![] }
105 }
106
107 /// Creates an Input from an array of text strings.
108 ///
109 /// # Arguments
110 ///
111 /// * `input_text_array` - Vector of texts to embed
112 ///
113 /// # Returns
114 ///
115 /// A new `Input` instance with the text array set
116 pub fn from_text_array(input_text_array: Vec<String>) -> Self {
117 Self { input_text: String::new(), input_text_array }
118 }
119}
120
121/// Custom serialization for Input to match OpenAI API format.
122///
123/// The OpenAI Embeddings API accepts either a single string or an array of strings
124/// for the `input` field. This implementation serializes to the appropriate format
125/// based on which field is populated.
126impl Serialize for Input {
127 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
128 where
129 S: serde::Serializer,
130 {
131 if !self.input_text.is_empty() && self.input_text_array.is_empty() {
132 self.input_text.serialize(serializer)
133 } else if self.input_text.is_empty() && !self.input_text_array.is_empty() {
134 self.input_text_array.serialize(serializer)
135 } else {
136 // Default to empty string if both are empty
137 "".serialize(serializer)
138 }
139 }
140}
141
142/// Request body structure for the OpenAI Embeddings API.
143///
144/// Contains all parameters that can be sent to the API endpoint.
145#[derive(Debug, Clone, Deserialize, Serialize, Default)]
146struct Body {
147 /// The model to use for embedding generation
148 model: EmbeddingModel,
149 /// The input text(s) to embed
150 input: Input,
151 /// The format for the output embeddings ("float" or "base64")
152 encoding_format: Option<String>,
153}
154
/// Relative API path for the Embeddings endpoint, appended to the provider's base URL.
const EMBEDDINGS_PATH: &str = "embeddings";
157
/// Main struct for building and sending embedding requests to the OpenAI API.
///
/// This struct provides a builder pattern interface for constructing embedding
/// requests with various parameters. Use [`Embedding::new()`] to create a new
/// instance, then chain methods to configure the request before calling [`embed()`].
///
/// # Providers
///
/// The client supports two providers:
/// - **OpenAI**: Standard OpenAI API (default)
/// - **Azure**: Azure OpenAI Service
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::embedding::request::Embedding;
/// use openai_tools::common::models::EmbeddingModel;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let mut embedding = Embedding::new()?;
///
///     let response = embedding
///         .model(EmbeddingModel::TextEmbedding3Small)
///         .input_text("Sample text")
///         .embed()
///         .await?;
///
///     Ok(())
/// }
/// ```
pub struct Embedding {
    /// Authentication provider (OpenAI or Azure)
    auth: AuthProvider,
    /// Request body containing model and input parameters
    body: Body,
    /// Optional request timeout duration; `None` uses the HTTP client default
    timeout: Option<Duration>,
}
197
198impl Embedding {
199 /// Creates a new Embedding instance for OpenAI API.
200 ///
201 /// Initializes the embedding client by loading the OpenAI API key from
202 /// the environment variable `OPENAI_API_KEY`. Supports `.env` file loading
203 /// via dotenvy.
204 ///
205 /// # Returns
206 ///
207 /// * `Ok(Embedding)` - A new embedding instance ready for configuration
208 /// * `Err(OpenAIToolError)` - If the API key is not found in the environment
209 ///
210 /// # Example
211 ///
212 /// ```rust,no_run
213 /// use openai_tools::embedding::request::Embedding;
214 ///
215 /// let embedding = Embedding::new().expect("API key should be set");
216 /// ```
217 pub fn new() -> Result<Self> {
218 let auth = AuthProvider::openai_from_env()?;
219 let body = Body::default();
220 Ok(Self { auth, body, timeout: None })
221 }
222
223 /// Creates a new Embedding instance with a custom authentication provider
224 ///
225 /// Use this to explicitly configure OpenAI or Azure authentication.
226 ///
227 /// # Arguments
228 ///
229 /// * `auth` - The authentication provider
230 ///
231 /// # Returns
232 ///
233 /// A new Embedding instance with the specified auth provider
234 pub fn with_auth(auth: AuthProvider) -> Self {
235 Self { auth, body: Body::default(), timeout: None }
236 }
237
238 /// Creates a new Embedding instance for Azure OpenAI API
239 ///
240 /// Loads configuration from Azure-specific environment variables.
241 ///
242 /// # Returns
243 ///
244 /// `Result<Embedding>` - Configured for Azure or error if env vars missing
245 pub fn azure() -> Result<Self> {
246 let auth = AuthProvider::azure_from_env()?;
247 Ok(Self { auth, body: Body::default(), timeout: None })
248 }
249
250 /// Creates a new Embedding instance by auto-detecting the provider
251 ///
252 /// Tries Azure first (if AZURE_OPENAI_API_KEY is set), then falls back to OpenAI.
253 pub fn detect_provider() -> Result<Self> {
254 let auth = AuthProvider::from_env()?;
255 Ok(Self { auth, body: Body::default(), timeout: None })
256 }
257
258 /// Creates a new Embedding instance with URL-based provider detection
259 ///
260 /// Analyzes the URL pattern to determine the provider:
261 /// - URLs containing `.openai.azure.com` → Azure
262 /// - All other URLs → OpenAI-compatible
263 ///
264 /// # Arguments
265 ///
266 /// * `base_url` - The complete base URL for API requests
267 /// * `api_key` - The API key or token
268 pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
269 let auth = AuthProvider::from_url_with_key(base_url, api_key);
270 Self { auth, body: Body::default(), timeout: None }
271 }
272
273 /// Creates a new Embedding instance from URL using environment variables
274 ///
275 /// Analyzes the URL pattern to determine the provider, then loads
276 /// credentials from the appropriate environment variables.
277 pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
278 let auth = AuthProvider::from_url(url)?;
279 Ok(Self { auth, body: Body::default(), timeout: None })
280 }
281
282 /// Returns the authentication provider
283 pub fn auth(&self) -> &AuthProvider {
284 &self.auth
285 }
286
287 /// Sets a custom API endpoint URL (OpenAI only)
288 ///
289 /// Use this to point to alternative OpenAI-compatible APIs.
290 ///
291 /// # Arguments
292 ///
293 /// * `url` - The base URL (e.g., "https://my-proxy.example.com/v1")
294 ///
295 /// # Returns
296 ///
297 /// A mutable reference to self for method chaining
298 pub fn base_url<T: AsRef<str>>(&mut self, url: T) -> &mut Self {
299 if let AuthProvider::OpenAI(ref openai_auth) = self.auth {
300 let new_auth = OpenAIAuth::new(openai_auth.api_key()).with_base_url(url.as_ref());
301 self.auth = AuthProvider::OpenAI(new_auth);
302 } else {
303 tracing::warn!("base_url() is only supported for OpenAI provider. Use azure() or with_auth() for Azure.");
304 }
305 self
306 }
307
308 /// Sets the model to use for embedding generation.
309 ///
310 /// # Arguments
311 ///
312 /// * `model` - The embedding model to use
313 ///
314 /// # Returns
315 ///
316 /// A mutable reference to self for method chaining
317 ///
318 /// # Example
319 ///
320 /// ```rust,no_run
321 /// use openai_tools::embedding::request::Embedding;
322 /// use openai_tools::common::models::EmbeddingModel;
323 ///
324 /// let mut embedding = Embedding::new().unwrap();
325 /// embedding.model(EmbeddingModel::TextEmbedding3Small);
326 /// ```
327 pub fn model(&mut self, model: EmbeddingModel) -> &mut Self {
328 self.body.model = model;
329 self
330 }
331
332 /// Sets the model using a string ID (for backward compatibility).
333 ///
334 /// Prefer using [`model`] with `EmbeddingModel` enum for type safety.
335 ///
336 /// # Arguments
337 ///
338 /// * `model_id` - The model identifier string (e.g., "text-embedding-3-small")
339 ///
340 /// # Returns
341 ///
342 /// A mutable reference to self for method chaining
343 #[deprecated(since = "0.2.0", note = "Use `model(EmbeddingModel)` instead for type safety")]
344 pub fn model_id<T: AsRef<str>>(&mut self, model_id: T) -> &mut Self {
345 self.body.model = EmbeddingModel::from(model_id.as_ref());
346 self
347 }
348
349 /// Sets the request timeout duration.
350 ///
351 /// # Arguments
352 ///
353 /// * `timeout` - The maximum time to wait for a response
354 ///
355 /// # Returns
356 ///
357 /// A mutable reference to self for method chaining
358 ///
359 /// # Example
360 ///
361 /// ```rust,no_run
362 /// use std::time::Duration;
363 /// use openai_tools::embedding::request::Embedding;
364 /// use openai_tools::common::models::EmbeddingModel;
365 ///
366 /// let mut embedding = Embedding::new().unwrap();
367 /// embedding.model(EmbeddingModel::TextEmbedding3Small)
368 /// .timeout(Duration::from_secs(30));
369 /// ```
370 pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
371 self.timeout = Some(timeout);
372 self
373 }
374
375 /// Sets a single text input for embedding.
376 ///
377 /// Use this method when you want to embed a single piece of text.
378 /// For multiple texts, use [`input_text_array`] instead.
379 ///
380 /// # Arguments
381 ///
382 /// * `input_text` - The text to convert into an embedding vector
383 ///
384 /// # Returns
385 ///
386 /// A mutable reference to self for method chaining
387 ///
388 /// # Example
389 ///
390 /// ```rust,no_run
391 /// # use openai_tools::embedding::request::Embedding;
392 /// # let mut embedding = Embedding::new().unwrap();
393 /// embedding.input_text("Hello, world!");
394 /// ```
395 pub fn input_text<T: AsRef<str>>(&mut self, input_text: T) -> &mut Self {
396 self.body.input = Input::from_text(input_text.as_ref());
397 self
398 }
399
400 /// Sets multiple text inputs for batch embedding.
401 ///
402 /// Use this method when you want to embed multiple texts in a single API call.
403 /// This is more efficient than making separate requests for each text.
404 ///
405 /// # Arguments
406 ///
407 /// * `input_text_array` - Vector of texts to convert into embedding vectors
408 ///
409 /// # Returns
410 ///
411 /// A mutable reference to self for method chaining
412 ///
413 /// # Example
414 ///
415 /// ```rust,no_run
416 /// # use openai_tools::embedding::request::Embedding;
417 /// # let mut embedding = Embedding::new().unwrap();
418 /// let texts = vec!["First text", "Second text", "Third text"];
419 /// embedding.input_text_array(texts);
420 /// ```
421 pub fn input_text_array<T: AsRef<str>>(&mut self, input_text_array: Vec<T>) -> &mut Self {
422 let input_strings = input_text_array.into_iter().map(|s| s.as_ref().to_string()).collect();
423 self.body.input = Input::from_text_array(input_strings);
424 self
425 }
426
427 /// Sets the encoding format for the output embeddings.
428 ///
429 /// # Arguments
430 ///
431 /// * `encoding_format` - Either "float" (default) or "base64"
432 /// - `"float"`: Returns embeddings as arrays of floating point numbers
433 /// - `"base64"`: Returns embeddings as base64-encoded strings (more compact)
434 ///
435 /// # Returns
436 ///
437 /// A mutable reference to self for method chaining
438 ///
439 /// # Panics
440 ///
441 /// Panics if `encoding_format` is not "float" or "base64"
442 ///
443 /// # Example
444 ///
445 /// ```rust,no_run
446 /// # use openai_tools::embedding::request::Embedding;
447 /// # let mut embedding = Embedding::new().unwrap();
448 /// embedding.encoding_format("float");
449 /// ```
450 pub fn encoding_format<T: AsRef<str>>(&mut self, encoding_format: T) -> &mut Self {
451 let encoding_format = encoding_format.as_ref();
452 assert!(encoding_format == "float" || encoding_format == "base64", "encoding_format must be either 'float' or 'base64'");
453 self.body.encoding_format = Some(encoding_format.to_string());
454 self
455 }
456
457 /// Sends the embedding request to the OpenAI API.
458 ///
459 /// This method validates the request parameters, constructs the HTTP request,
460 /// sends it to the OpenAI Embeddings API endpoint, and parses the response.
461 ///
462 /// # Returns
463 ///
464 /// * `Ok(Response)` - The embedding response containing vectors and metadata
465 /// * `Err(OpenAIToolError)` - If validation fails, the request fails, or parsing fails
466 ///
467 /// # Errors
468 ///
469 /// Returns an error if:
470 /// - API key is not set
471 /// - Model ID is not set
472 /// - Input text is not set
473 /// - Network request fails
474 /// - Response parsing fails
475 ///
476 /// # Example
477 ///
478 /// ```rust,no_run
479 /// # use openai_tools::embedding::request::Embedding;
480 /// # use openai_tools::common::models::EmbeddingModel;
481 /// # #[tokio::main]
482 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
483 /// let mut embedding = Embedding::new()?;
484 /// let response = embedding
485 /// .model(EmbeddingModel::TextEmbedding3Small)
486 /// .input_text("Hello, world!")
487 /// .embed()
488 /// .await?;
489 /// # Ok(())
490 /// # }
491 /// ```
492 pub async fn embed(&self) -> Result<Response> {
493 // Validate that input text is set
494 if self.body.input.input_text.is_empty() && self.body.input.input_text_array.is_empty() {
495 return Err(OpenAIToolError::Error("Input text is not set.".into()));
496 }
497
498 let body = serde_json::to_string(&self.body)?;
499
500 let client = create_http_client(self.timeout)?;
501 let mut headers = request::header::HeaderMap::new();
502 headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
503 headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
504
505 // Apply provider-specific authentication headers
506 self.auth.apply_headers(&mut headers)?;
507
508 if cfg!(test) {
509 // Replace API key with a placeholder in debug mode
510 let body_for_debug = serde_json::to_string_pretty(&self.body).unwrap().replace(self.auth.api_key(), "*************");
511 tracing::info!("Request body: {}", body_for_debug);
512 }
513
514 // Get the endpoint URL from the auth provider
515 let endpoint = self.auth.endpoint(EMBEDDINGS_PATH);
516
517 let response = client.post(&endpoint).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
518 let status = response.status();
519 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
520
521 if cfg!(test) {
522 tracing::info!("Response content: {}", content);
523 }
524
525 if !status.is_success() {
526 if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
527 return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
528 }
529 return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
530 }
531
532 serde_json::from_str::<Response>(&content).map_err(OpenAIToolError::SerdeJsonError)
533 }
534}