// openai_tools/embedding/request.rs

1//! OpenAI Embeddings API Request Module
2//!
3//! This module provides the functionality to build and send requests to the OpenAI Embeddings API.
4//! It offers a builder pattern for constructing embedding requests, allowing you to convert text
5//! into numerical vector representations that capture semantic meaning.
6//!
7//! # Key Features
8//!
9//! - **Builder Pattern**: Fluent API for constructing embedding requests
10//! - **Single & Batch Input**: Support for single text or multiple texts at once
11//! - **Encoding Formats**: Support for `float` and `base64` output formats
12//! - **Error Handling**: Robust error management and validation
13//!
14//! # Quick Start
15//!
16//! ```rust,no_run
17//! use openai_tools::embedding::request::Embedding;
18//!
19//! #[tokio::main]
20//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
21//!     // Initialize the embedding client
22//!     let mut embedding = Embedding::new()?;
23//!     
24//!     // Generate embedding for a single text
25//!     let response = embedding
26//!         .model("text-embedding-3-small")
27//!         .input_text("Hello, world!")
28//!         .embed()
29//!         .await?;
30//!         
31//!     let vector = response.data[0].embedding.as_1d().unwrap();
32//!     println!("Embedding dimension: {}", vector.len());
33//!     Ok(())
34//! }
35//! ```
36//!
37//! # Batch Processing
38//!
39//! ```rust,no_run
40//! use openai_tools::embedding::request::Embedding;
41//!
42//! #[tokio::main]
43//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
44//!     let mut embedding = Embedding::new()?;
45//!     
46//!     // Embed multiple texts in a single request
47//!     let texts = vec!["First text", "Second text", "Third text"];
48//!     
49//!     let response = embedding
50//!         .model("text-embedding-3-small")
51//!         .input_text_array(texts)
52//!         .embed()
53//!         .await?;
54//!         
55//!     for data in &response.data {
56//!         println!("Index {}: {} dimensions",
57//!                  data.index,
58//!                  data.embedding.as_1d().unwrap().len());
59//!     }
60//!     Ok(())
61//! }
62//! ```
63
64use crate::common::errors::{OpenAIToolError, Result};
65use crate::embedding::response::Response;
66use core::str;
67use dotenvy::dotenv;
68use serde::{Deserialize, Serialize};
69use std::env;
70
71/// Internal structure for handling input text in embedding requests.
72///
73/// This struct supports two input formats:
74/// - Single text string (`input_text`)
75/// - Array of text strings (`input_text_array`)
76///
77/// The custom `Serialize` implementation ensures proper JSON formatting
78/// based on which input type is provided.
79#[derive(Debug, Clone, Deserialize, Default)]
80struct Input {
81    /// Single input text for embedding
82    #[serde(skip_serializing_if = "String::is_empty")]
83    input_text: String,
84    /// Array of input texts for batch embedding
85    #[serde(skip_serializing_if = "Vec::is_empty")]
86    input_text_array: Vec<String>,
87}
88
89impl Input {
90    /// Creates an Input from a single text string.
91    ///
92    /// # Arguments
93    ///
94    /// * `input_text` - The text to embed
95    ///
96    /// # Returns
97    ///
98    /// A new `Input` instance with the single text set
99    pub fn from_text(input_text: &str) -> Self {
100        Self { input_text: input_text.to_string(), input_text_array: vec![] }
101    }
102
103    /// Creates an Input from an array of text strings.
104    ///
105    /// # Arguments
106    ///
107    /// * `input_text_array` - Vector of texts to embed
108    ///
109    /// # Returns
110    ///
111    /// A new `Input` instance with the text array set
112    pub fn from_text_array(input_text_array: Vec<String>) -> Self {
113        Self { input_text: String::new(), input_text_array }
114    }
115}
116
117/// Custom serialization for Input to match OpenAI API format.
118///
119/// The OpenAI Embeddings API accepts either a single string or an array of strings
120/// for the `input` field. This implementation serializes to the appropriate format
121/// based on which field is populated.
122impl Serialize for Input {
123    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
124    where
125        S: serde::Serializer,
126    {
127        if !self.input_text.is_empty() && self.input_text_array.is_empty() {
128            self.input_text.serialize(serializer)
129        } else if self.input_text.is_empty() && !self.input_text_array.is_empty() {
130            self.input_text_array.serialize(serializer)
131        } else {
132            // Default to empty string if both are empty
133            "".serialize(serializer)
134        }
135    }
136}
137
138/// Request body structure for the OpenAI Embeddings API.
139///
140/// Contains all parameters that can be sent to the API endpoint.
141#[derive(Debug, Clone, Deserialize, Serialize, Default)]
142struct Body {
143    /// The model ID to use for embedding generation
144    model: String,
145    /// The input text(s) to embed
146    input: Input,
147    /// The format for the output embeddings ("float" or "base64")
148    encoding_format: Option<String>,
149}
150
/// Main struct for building and sending embedding requests to the OpenAI API.
///
/// This struct provides a builder pattern interface for constructing embedding
/// requests with various parameters. Use [`Embedding::new()`] to create a new
/// instance, then chain methods to configure the request before calling [`embed()`].
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::embedding::request::Embedding;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let mut embedding = Embedding::new()?;
///     
///     let response = embedding
///         .model("text-embedding-3-small")
///         .input_text("Sample text")
///         .embed()
///         .await?;
///         
///     Ok(())
/// }
/// ```
pub struct Embedding {
    /// OpenAI API key for authentication; loaded from `OPENAI_API_KEY` in `new()`
    /// and sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
    /// Request body containing model and input parameters; populated
    /// incrementally by the builder methods (`model`, `input_text`, ...).
    body: Body,
}
181
182impl Embedding {
183    /// Creates a new Embedding instance.
184    ///
185    /// Initializes the embedding client by loading the OpenAI API key from
186    /// the environment variable `OPENAI_API_KEY`. Supports `.env` file loading
187    /// via dotenvy.
188    ///
189    /// # Returns
190    ///
191    /// * `Ok(Embedding)` - A new embedding instance ready for configuration
192    /// * `Err(OpenAIToolError)` - If the API key is not found in the environment
193    ///
194    /// # Example
195    ///
196    /// ```rust,no_run
197    /// use openai_tools::embedding::request::Embedding;
198    ///
199    /// let embedding = Embedding::new().expect("API key should be set");
200    /// ```
201    pub fn new() -> Result<Self> {
202        dotenv().ok();
203        let api_key = env::var("OPENAI_API_KEY").map_err(|e| OpenAIToolError::Error(format!("OPENAI_API_KEY not set in environment: {}", e)))?;
204        let body = Body::default();
205        Ok(Self { api_key, body })
206    }
207
208    /// Sets the model to use for embedding generation.
209    ///
210    /// # Arguments
211    ///
212    /// * `model` - The model identifier (e.g., "text-embedding-3-small", "text-embedding-3-large")
213    ///
214    /// # Returns
215    ///
216    /// A mutable reference to self for method chaining
217    ///
218    /// # Example
219    ///
220    /// ```rust,no_run
221    /// # use openai_tools::embedding::request::Embedding;
222    /// # let mut embedding = Embedding::new().unwrap();
223    /// embedding.model("text-embedding-3-small");
224    /// ```
225    pub fn model<T: AsRef<str>>(&mut self, model: T) -> &mut Self {
226        self.body.model = model.as_ref().to_string();
227        self
228    }
229
230    /// Sets a single text input for embedding.
231    ///
232    /// Use this method when you want to embed a single piece of text.
233    /// For multiple texts, use [`input_text_array`] instead.
234    ///
235    /// # Arguments
236    ///
237    /// * `input_text` - The text to convert into an embedding vector
238    ///
239    /// # Returns
240    ///
241    /// A mutable reference to self for method chaining
242    ///
243    /// # Example
244    ///
245    /// ```rust,no_run
246    /// # use openai_tools::embedding::request::Embedding;
247    /// # let mut embedding = Embedding::new().unwrap();
248    /// embedding.input_text("Hello, world!");
249    /// ```
250    pub fn input_text<T: AsRef<str>>(&mut self, input_text: T) -> &mut Self {
251        self.body.input = Input::from_text(input_text.as_ref());
252        self
253    }
254
255    /// Sets multiple text inputs for batch embedding.
256    ///
257    /// Use this method when you want to embed multiple texts in a single API call.
258    /// This is more efficient than making separate requests for each text.
259    ///
260    /// # Arguments
261    ///
262    /// * `input_text_array` - Vector of texts to convert into embedding vectors
263    ///
264    /// # Returns
265    ///
266    /// A mutable reference to self for method chaining
267    ///
268    /// # Example
269    ///
270    /// ```rust,no_run
271    /// # use openai_tools::embedding::request::Embedding;
272    /// # let mut embedding = Embedding::new().unwrap();
273    /// let texts = vec!["First text", "Second text", "Third text"];
274    /// embedding.input_text_array(texts);
275    /// ```
276    pub fn input_text_array<T: AsRef<str>>(&mut self, input_text_array: Vec<T>) -> &mut Self {
277        let input_strings = input_text_array.into_iter().map(|s| s.as_ref().to_string()).collect();
278        self.body.input = Input::from_text_array(input_strings);
279        self
280    }
281
282    /// Sets the encoding format for the output embeddings.
283    ///
284    /// # Arguments
285    ///
286    /// * `encoding_format` - Either "float" (default) or "base64"
287    ///   - `"float"`: Returns embeddings as arrays of floating point numbers
288    ///   - `"base64"`: Returns embeddings as base64-encoded strings (more compact)
289    ///
290    /// # Returns
291    ///
292    /// A mutable reference to self for method chaining
293    ///
294    /// # Panics
295    ///
296    /// Panics if `encoding_format` is not "float" or "base64"
297    ///
298    /// # Example
299    ///
300    /// ```rust,no_run
301    /// # use openai_tools::embedding::request::Embedding;
302    /// # let mut embedding = Embedding::new().unwrap();
303    /// embedding.encoding_format("float");
304    /// ```
305    pub fn encoding_format<T: AsRef<str>>(&mut self, encoding_format: T) -> &mut Self {
306        let encoding_format = encoding_format.as_ref();
307        assert!(encoding_format == "float" || encoding_format == "base64", "encoding_format must be either 'float' or 'base64'");
308        self.body.encoding_format = Some(encoding_format.to_string());
309        self
310    }
311
312    /// Sends the embedding request to the OpenAI API.
313    ///
314    /// This method validates the request parameters, constructs the HTTP request,
315    /// sends it to the OpenAI Embeddings API endpoint, and parses the response.
316    ///
317    /// # Returns
318    ///
319    /// * `Ok(Response)` - The embedding response containing vectors and metadata
320    /// * `Err(OpenAIToolError)` - If validation fails, the request fails, or parsing fails
321    ///
322    /// # Errors
323    ///
324    /// Returns an error if:
325    /// - API key is not set
326    /// - Model ID is not set
327    /// - Input text is not set
328    /// - Network request fails
329    /// - Response parsing fails
330    ///
331    /// # Example
332    ///
333    /// ```rust,no_run
334    /// # use openai_tools::embedding::request::Embedding;
335    /// # #[tokio::main]
336    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
337    /// let mut embedding = Embedding::new()?;
338    /// let response = embedding
339    ///     .model("text-embedding-3-small")
340    ///     .input_text("Hello, world!")
341    ///     .embed()
342    ///     .await?;
343    /// # Ok(())
344    /// # }
345    /// ```
346    pub async fn embed(&self) -> Result<Response> {
347        if self.api_key.is_empty() {
348            return Err(OpenAIToolError::Error("API key is not set.".into()));
349        }
350        if self.body.model.is_empty() {
351            return Err(OpenAIToolError::Error("Model ID is not set.".into()));
352        }
353        if self.body.input.input_text.is_empty() && self.body.input.input_text_array.is_empty() {
354            return Err(OpenAIToolError::Error("Input text is not set.".into()));
355        }
356
357        let body = serde_json::to_string(&self.body)?;
358        let url = "https://api.openai.com/v1/embeddings";
359
360        let client = request::Client::new();
361        let mut header = request::header::HeaderMap::new();
362        header.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
363        header.insert("Authorization", request::header::HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap());
364        header.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
365
366        if cfg!(test) {
367            // Replace API key with a placeholder in debug mode
368            let body_for_debug = serde_json::to_string_pretty(&self.body).unwrap().replace(&self.api_key, "*************");
369            tracing::info!("Request body: {}", body_for_debug);
370        }
371
372        let response = client.post(url).headers(header).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
373        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
374
375        if cfg!(test) {
376            tracing::info!("Response content: {}", content);
377        }
378
379        serde_json::from_str::<Response>(&content).map_err(OpenAIToolError::SerdeJsonError)
380    }
381}