// openai_tools/embedding/request.rs
//! OpenAI Embeddings API Request Module
//!
//! This module provides the functionality to build and send requests to the OpenAI Embeddings API.
//! It offers a builder pattern for constructing embedding requests, allowing you to convert text
//! into numerical vector representations that capture semantic meaning.
//!
//! # Key Features
//!
//! - **Builder Pattern**: Fluent API for constructing embedding requests
//! - **Single & Batch Input**: Support for single text or multiple texts at once
//! - **Encoding Formats**: Support for `float` and `base64` output formats
//! - **Error Handling**: Robust error management and validation
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use openai_tools::embedding::request::Embedding;
//! use openai_tools::common::models::EmbeddingModel;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     // Initialize the embedding client
//!     let mut embedding = Embedding::new()?;
//!
//!     // Generate embedding for a single text
//!     let response = embedding
//!         .model(EmbeddingModel::TextEmbedding3Small)
//!         .input_text("Hello, world!")
//!         .embed()
//!         .await?;
//!
//!     let vector = response.data[0].embedding.as_1d().unwrap();
//!     println!("Embedding dimension: {}", vector.len());
//!     Ok(())
//! }
//! ```
//!
//! # Batch Processing
//!
//! ```rust,no_run
//! use openai_tools::embedding::request::Embedding;
//! use openai_tools::common::models::EmbeddingModel;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut embedding = Embedding::new()?;
//!
//!     // Embed multiple texts in a single request
//!     let texts = vec!["First text", "Second text", "Third text"];
//!
//!     let response = embedding
//!         .model(EmbeddingModel::TextEmbedding3Small)
//!         .input_text_array(texts)
//!         .embed()
//!         .await?;
//!
//!     for data in &response.data {
//!         println!("Index {}: {} dimensions",
//!                  data.index,
//!                  data.embedding.as_1d().unwrap().len());
//!     }
//!     Ok(())
//! }
//! ```
65
use core::str;
use std::time::Duration;

use serde::{Deserialize, Serialize};

use crate::common::auth::{AuthProvider, OpenAIAuth};
use crate::common::client::create_http_client;
use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
use crate::common::models::EmbeddingModel;
use crate::embedding::response::Response;
74
75/// Internal structure for handling input text in embedding requests.
76///
77/// This struct supports two input formats:
78/// - Single text string (`input_text`)
79/// - Array of text strings (`input_text_array`)
80///
81/// The custom `Serialize` implementation ensures proper JSON formatting
82/// based on which input type is provided.
83#[derive(Debug, Clone, Deserialize, Default)]
84struct Input {
85    /// Single input text for embedding
86    #[serde(skip_serializing_if = "String::is_empty")]
87    input_text: String,
88    /// Array of input texts for batch embedding
89    #[serde(skip_serializing_if = "Vec::is_empty")]
90    input_text_array: Vec<String>,
91}
92
93impl Input {
94    /// Creates an Input from a single text string.
95    ///
96    /// # Arguments
97    ///
98    /// * `input_text` - The text to embed
99    ///
100    /// # Returns
101    ///
102    /// A new `Input` instance with the single text set
103    pub fn from_text(input_text: &str) -> Self {
104        Self { input_text: input_text.to_string(), input_text_array: vec![] }
105    }
106
107    /// Creates an Input from an array of text strings.
108    ///
109    /// # Arguments
110    ///
111    /// * `input_text_array` - Vector of texts to embed
112    ///
113    /// # Returns
114    ///
115    /// A new `Input` instance with the text array set
116    pub fn from_text_array(input_text_array: Vec<String>) -> Self {
117        Self { input_text: String::new(), input_text_array }
118    }
119}
120
121/// Custom serialization for Input to match OpenAI API format.
122///
123/// The OpenAI Embeddings API accepts either a single string or an array of strings
124/// for the `input` field. This implementation serializes to the appropriate format
125/// based on which field is populated.
126impl Serialize for Input {
127    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
128    where
129        S: serde::Serializer,
130    {
131        if !self.input_text.is_empty() && self.input_text_array.is_empty() {
132            self.input_text.serialize(serializer)
133        } else if self.input_text.is_empty() && !self.input_text_array.is_empty() {
134            self.input_text_array.serialize(serializer)
135        } else {
136            // Default to empty string if both are empty
137            "".serialize(serializer)
138        }
139    }
140}
141
142/// Request body structure for the OpenAI Embeddings API.
143///
144/// Contains all parameters that can be sent to the API endpoint.
145#[derive(Debug, Clone, Deserialize, Serialize, Default)]
146struct Body {
147    /// The model to use for embedding generation
148    model: EmbeddingModel,
149    /// The input text(s) to embed
150    input: Input,
151    /// The format for the output embeddings ("float" or "base64")
152    encoding_format: Option<String>,
153}
154
/// Default API path for Embeddings, appended to the provider base URL
/// by `AuthProvider::endpoint` when building the request URL.
const EMBEDDINGS_PATH: &str = "embeddings";
157
/// Main struct for building and sending embedding requests to the OpenAI API.
///
/// This struct provides a builder pattern interface for constructing embedding
/// requests with various parameters. Use [`Embedding::new()`] to create a new
/// instance, then chain methods to configure the request before calling `embed()`.
///
/// # Providers
///
/// The client supports two providers:
/// - **OpenAI**: Standard OpenAI API (default)
/// - **Azure**: Azure OpenAI Service
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::embedding::request::Embedding;
/// use openai_tools::common::models::EmbeddingModel;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let mut embedding = Embedding::new()?;
///
///     let response = embedding
///         .model(EmbeddingModel::TextEmbedding3Small)
///         .input_text("Sample text")
///         .embed()
///         .await?;
///
///     Ok(())
/// }
/// ```
pub struct Embedding {
    /// Authentication provider (OpenAI or Azure)
    auth: AuthProvider,
    /// Request body containing model and input parameters
    body: Body,
    /// Optional request timeout duration; `None` uses the HTTP client default
    timeout: Option<Duration>,
}
197
198impl Embedding {
199    /// Creates a new Embedding instance for OpenAI API.
200    ///
201    /// Initializes the embedding client by loading the OpenAI API key from
202    /// the environment variable `OPENAI_API_KEY`. Supports `.env` file loading
203    /// via dotenvy.
204    ///
205    /// # Returns
206    ///
207    /// * `Ok(Embedding)` - A new embedding instance ready for configuration
208    /// * `Err(OpenAIToolError)` - If the API key is not found in the environment
209    ///
210    /// # Example
211    ///
212    /// ```rust,no_run
213    /// use openai_tools::embedding::request::Embedding;
214    ///
215    /// let embedding = Embedding::new().expect("API key should be set");
216    /// ```
217    pub fn new() -> Result<Self> {
218        let auth = AuthProvider::openai_from_env()?;
219        let body = Body::default();
220        Ok(Self { auth, body, timeout: None })
221    }
222
223    /// Creates a new Embedding instance with a custom authentication provider
224    ///
225    /// Use this to explicitly configure OpenAI or Azure authentication.
226    ///
227    /// # Arguments
228    ///
229    /// * `auth` - The authentication provider
230    ///
231    /// # Returns
232    ///
233    /// A new Embedding instance with the specified auth provider
234    pub fn with_auth(auth: AuthProvider) -> Self {
235        Self { auth, body: Body::default(), timeout: None }
236    }
237
238    /// Creates a new Embedding instance for Azure OpenAI API
239    ///
240    /// Loads configuration from Azure-specific environment variables.
241    ///
242    /// # Returns
243    ///
244    /// `Result<Embedding>` - Configured for Azure or error if env vars missing
245    pub fn azure() -> Result<Self> {
246        let auth = AuthProvider::azure_from_env()?;
247        Ok(Self { auth, body: Body::default(), timeout: None })
248    }
249
250    /// Creates a new Embedding instance by auto-detecting the provider
251    ///
252    /// Tries Azure first (if AZURE_OPENAI_API_KEY is set), then falls back to OpenAI.
253    pub fn detect_provider() -> Result<Self> {
254        let auth = AuthProvider::from_env()?;
255        Ok(Self { auth, body: Body::default(), timeout: None })
256    }
257
258    /// Creates a new Embedding instance with URL-based provider detection
259    ///
260    /// Analyzes the URL pattern to determine the provider:
261    /// - URLs containing `.openai.azure.com` → Azure
262    /// - All other URLs → OpenAI-compatible
263    ///
264    /// # Arguments
265    ///
266    /// * `base_url` - The complete base URL for API requests
267    /// * `api_key` - The API key or token
268    pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
269        let auth = AuthProvider::from_url_with_key(base_url, api_key);
270        Self { auth, body: Body::default(), timeout: None }
271    }
272
273    /// Creates a new Embedding instance from URL using environment variables
274    ///
275    /// Analyzes the URL pattern to determine the provider, then loads
276    /// credentials from the appropriate environment variables.
277    pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
278        let auth = AuthProvider::from_url(url)?;
279        Ok(Self { auth, body: Body::default(), timeout: None })
280    }
281
282    /// Returns the authentication provider
283    pub fn auth(&self) -> &AuthProvider {
284        &self.auth
285    }
286
287    /// Sets a custom API endpoint URL (OpenAI only)
288    ///
289    /// Use this to point to alternative OpenAI-compatible APIs.
290    ///
291    /// # Arguments
292    ///
293    /// * `url` - The base URL (e.g., "https://my-proxy.example.com/v1")
294    ///
295    /// # Returns
296    ///
297    /// A mutable reference to self for method chaining
298    pub fn base_url<T: AsRef<str>>(&mut self, url: T) -> &mut Self {
299        if let AuthProvider::OpenAI(ref openai_auth) = self.auth {
300            let new_auth = OpenAIAuth::new(openai_auth.api_key()).with_base_url(url.as_ref());
301            self.auth = AuthProvider::OpenAI(new_auth);
302        } else {
303            tracing::warn!("base_url() is only supported for OpenAI provider. Use azure() or with_auth() for Azure.");
304        }
305        self
306    }
307
308    /// Sets the model to use for embedding generation.
309    ///
310    /// # Arguments
311    ///
312    /// * `model` - The embedding model to use
313    ///
314    /// # Returns
315    ///
316    /// A mutable reference to self for method chaining
317    ///
318    /// # Example
319    ///
320    /// ```rust,no_run
321    /// use openai_tools::embedding::request::Embedding;
322    /// use openai_tools::common::models::EmbeddingModel;
323    ///
324    /// let mut embedding = Embedding::new().unwrap();
325    /// embedding.model(EmbeddingModel::TextEmbedding3Small);
326    /// ```
327    pub fn model(&mut self, model: EmbeddingModel) -> &mut Self {
328        self.body.model = model;
329        self
330    }
331
332    /// Sets the model using a string ID (for backward compatibility).
333    ///
334    /// Prefer using [`model`] with `EmbeddingModel` enum for type safety.
335    ///
336    /// # Arguments
337    ///
338    /// * `model_id` - The model identifier string (e.g., "text-embedding-3-small")
339    ///
340    /// # Returns
341    ///
342    /// A mutable reference to self for method chaining
343    #[deprecated(since = "0.2.0", note = "Use `model(EmbeddingModel)` instead for type safety")]
344    pub fn model_id<T: AsRef<str>>(&mut self, model_id: T) -> &mut Self {
345        self.body.model = EmbeddingModel::from(model_id.as_ref());
346        self
347    }
348
349    /// Sets the request timeout duration.
350    ///
351    /// # Arguments
352    ///
353    /// * `timeout` - The maximum time to wait for a response
354    ///
355    /// # Returns
356    ///
357    /// A mutable reference to self for method chaining
358    ///
359    /// # Example
360    ///
361    /// ```rust,no_run
362    /// use std::time::Duration;
363    /// use openai_tools::embedding::request::Embedding;
364    /// use openai_tools::common::models::EmbeddingModel;
365    ///
366    /// let mut embedding = Embedding::new().unwrap();
367    /// embedding.model(EmbeddingModel::TextEmbedding3Small)
368    ///     .timeout(Duration::from_secs(30));
369    /// ```
370    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
371        self.timeout = Some(timeout);
372        self
373    }
374
375    /// Sets a single text input for embedding.
376    ///
377    /// Use this method when you want to embed a single piece of text.
378    /// For multiple texts, use [`input_text_array`] instead.
379    ///
380    /// # Arguments
381    ///
382    /// * `input_text` - The text to convert into an embedding vector
383    ///
384    /// # Returns
385    ///
386    /// A mutable reference to self for method chaining
387    ///
388    /// # Example
389    ///
390    /// ```rust,no_run
391    /// # use openai_tools::embedding::request::Embedding;
392    /// # let mut embedding = Embedding::new().unwrap();
393    /// embedding.input_text("Hello, world!");
394    /// ```
395    pub fn input_text<T: AsRef<str>>(&mut self, input_text: T) -> &mut Self {
396        self.body.input = Input::from_text(input_text.as_ref());
397        self
398    }
399
400    /// Sets multiple text inputs for batch embedding.
401    ///
402    /// Use this method when you want to embed multiple texts in a single API call.
403    /// This is more efficient than making separate requests for each text.
404    ///
405    /// # Arguments
406    ///
407    /// * `input_text_array` - Vector of texts to convert into embedding vectors
408    ///
409    /// # Returns
410    ///
411    /// A mutable reference to self for method chaining
412    ///
413    /// # Example
414    ///
415    /// ```rust,no_run
416    /// # use openai_tools::embedding::request::Embedding;
417    /// # let mut embedding = Embedding::new().unwrap();
418    /// let texts = vec!["First text", "Second text", "Third text"];
419    /// embedding.input_text_array(texts);
420    /// ```
421    pub fn input_text_array<T: AsRef<str>>(&mut self, input_text_array: Vec<T>) -> &mut Self {
422        let input_strings = input_text_array.into_iter().map(|s| s.as_ref().to_string()).collect();
423        self.body.input = Input::from_text_array(input_strings);
424        self
425    }
426
427    /// Sets the encoding format for the output embeddings.
428    ///
429    /// # Arguments
430    ///
431    /// * `encoding_format` - Either "float" (default) or "base64"
432    ///   - `"float"`: Returns embeddings as arrays of floating point numbers
433    ///   - `"base64"`: Returns embeddings as base64-encoded strings (more compact)
434    ///
435    /// # Returns
436    ///
437    /// A mutable reference to self for method chaining
438    ///
439    /// # Panics
440    ///
441    /// Panics if `encoding_format` is not "float" or "base64"
442    ///
443    /// # Example
444    ///
445    /// ```rust,no_run
446    /// # use openai_tools::embedding::request::Embedding;
447    /// # let mut embedding = Embedding::new().unwrap();
448    /// embedding.encoding_format("float");
449    /// ```
450    pub fn encoding_format<T: AsRef<str>>(&mut self, encoding_format: T) -> &mut Self {
451        let encoding_format = encoding_format.as_ref();
452        assert!(encoding_format == "float" || encoding_format == "base64", "encoding_format must be either 'float' or 'base64'");
453        self.body.encoding_format = Some(encoding_format.to_string());
454        self
455    }
456
457    /// Sends the embedding request to the OpenAI API.
458    ///
459    /// This method validates the request parameters, constructs the HTTP request,
460    /// sends it to the OpenAI Embeddings API endpoint, and parses the response.
461    ///
462    /// # Returns
463    ///
464    /// * `Ok(Response)` - The embedding response containing vectors and metadata
465    /// * `Err(OpenAIToolError)` - If validation fails, the request fails, or parsing fails
466    ///
467    /// # Errors
468    ///
469    /// Returns an error if:
470    /// - API key is not set
471    /// - Model ID is not set
472    /// - Input text is not set
473    /// - Network request fails
474    /// - Response parsing fails
475    ///
476    /// # Example
477    ///
478    /// ```rust,no_run
479    /// # use openai_tools::embedding::request::Embedding;
480    /// # use openai_tools::common::models::EmbeddingModel;
481    /// # #[tokio::main]
482    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
483    /// let mut embedding = Embedding::new()?;
484    /// let response = embedding
485    ///     .model(EmbeddingModel::TextEmbedding3Small)
486    ///     .input_text("Hello, world!")
487    ///     .embed()
488    ///     .await?;
489    /// # Ok(())
490    /// # }
491    /// ```
492    pub async fn embed(&self) -> Result<Response> {
493        // Validate that input text is set
494        if self.body.input.input_text.is_empty() && self.body.input.input_text_array.is_empty() {
495            return Err(OpenAIToolError::Error("Input text is not set.".into()));
496        }
497
498        let body = serde_json::to_string(&self.body)?;
499
500        let client = create_http_client(self.timeout)?;
501        let mut headers = request::header::HeaderMap::new();
502        headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
503        headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
504
505        // Apply provider-specific authentication headers
506        self.auth.apply_headers(&mut headers)?;
507
508        if cfg!(test) {
509            // Replace API key with a placeholder in debug mode
510            let body_for_debug = serde_json::to_string_pretty(&self.body).unwrap().replace(self.auth.api_key(), "*************");
511            tracing::info!("Request body: {}", body_for_debug);
512        }
513
514        // Get the endpoint URL from the auth provider
515        let endpoint = self.auth.endpoint(EMBEDDINGS_PATH);
516
517        let response = client.post(&endpoint).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
518        let status = response.status();
519        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
520
521        if cfg!(test) {
522            tracing::info!("Response content: {}", content);
523        }
524
525        if !status.is_success() {
526            if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
527                return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
528            }
529            return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
530        }
531
532        serde_json::from_str::<Response>(&content).map_err(OpenAIToolError::SerdeJsonError)
533    }
534}