
openai_tools/images/request.rs

//! OpenAI Images API Request Module
//!
//! This module provides the functionality to interact with the OpenAI Images API.
//! It allows you to generate, edit, and create variations of images using DALL-E and GPT Image models.
//!
//! # Key Features
//!
//! - **Generate**: Create images from text prompts
//! - **Edit**: Modify existing images with new prompts and masks
//! - **Variations**: Create variations of existing images (DALL-E 2 only)
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use openai_tools::images::request::{Images, GenerateOptions};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let images = Images::new()?;
//!
//!     // Generate an image
//!     let response = images.generate("A white cat", GenerateOptions::default()).await?;
//!     println!("Image URL: {:?}", response.data[0].url);
//!
//!     Ok(())
//! }
//! ```

use crate::common::auth::AuthProvider;
use crate::common::client::create_http_client;
use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
use crate::images::response::ImageResponse;
use reqwest::multipart::{Form, Part};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::time::Duration;

/// Default API path for Images
const IMAGES_PATH: &str = "images";

/// Image generation models.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ImageModel {
    /// DALL-E 2 model - supports variations, smaller sizes
    #[serde(rename = "dall-e-2")]
    DallE2,
    /// DALL-E 3 model - higher quality, HD support, style options
    #[serde(rename = "dall-e-3")]
    #[default]
    DallE3,
    /// GPT Image model - latest generation
    #[serde(rename = "gpt-image-1")]
    GptImage1,
}

impl ImageModel {
    /// Returns the model identifier string.
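    ///
    /// # Example
    ///
    /// A quick check of the identifier mapping (mirrors the match arms below;
    /// `to_string` goes through the `Display` impl further down):
    ///
    /// ```rust
    /// use openai_tools::images::request::ImageModel;
    ///
    /// assert_eq!(ImageModel::GptImage1.as_str(), "gpt-image-1");
    /// assert_eq!(ImageModel::DallE3.to_string(), "dall-e-3");
    /// ```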
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::DallE2 => "dall-e-2",
            Self::DallE3 => "dall-e-3",
            Self::GptImage1 => "gpt-image-1",
        }
    }
}

impl std::fmt::Display for ImageModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

/// Image sizes for generation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ImageSize {
    /// 256x256 pixels (DALL-E 2 only)
    #[serde(rename = "256x256")]
    Size256x256,
    /// 512x512 pixels (DALL-E 2 only)
    #[serde(rename = "512x512")]
    Size512x512,
    /// 1024x1024 pixels (all models)
    #[serde(rename = "1024x1024")]
    #[default]
    Size1024x1024,
    /// 1792x1024 pixels - landscape (DALL-E 3 only)
    #[serde(rename = "1792x1024")]
    Size1792x1024,
    /// 1024x1792 pixels - portrait (DALL-E 3 only)
    #[serde(rename = "1024x1792")]
    Size1024x1792,
}

impl ImageSize {
    /// Returns the size string.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Size256x256 => "256x256",
            Self::Size512x512 => "512x512",
            Self::Size1024x1024 => "1024x1024",
            Self::Size1792x1024 => "1792x1024",
            Self::Size1024x1792 => "1024x1792",
        }
    }
}

impl std::fmt::Display for ImageSize {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

/// Image quality options (DALL-E 3 only).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ImageQuality {
    /// Standard quality
    #[default]
    Standard,
    /// High definition quality
    Hd,
}

impl ImageQuality {
    /// Returns the quality string.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Standard => "standard",
            Self::Hd => "hd",
        }
    }
}

/// Image style options (DALL-E 3 only).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ImageStyle {
    /// Vivid - hyper-real and dramatic
    #[default]
    Vivid,
    /// Natural - more natural, less hyper-real
    Natural,
}

impl ImageStyle {
    /// Returns the style string.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Vivid => "vivid",
            Self::Natural => "natural",
        }
    }
}

/// Response format for images.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ResponseFormat {
    /// Return URLs to the generated images (valid for 60 minutes)
    #[default]
    Url,
    /// Return base64-encoded image data
    B64Json,
}

impl ResponseFormat {
    /// Returns the format string.
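    ///
    /// Note the snake_case wire value for the base64 variant:
    ///
    /// ```rust
    /// use openai_tools::images::request::ResponseFormat;
    ///
    /// assert_eq!(ResponseFormat::Url.as_str(), "url");
    /// assert_eq!(ResponseFormat::B64Json.as_str(), "b64_json");
    /// ```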
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Url => "url",
            Self::B64Json => "b64_json",
        }
    }
}

/// Options for image generation.
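///
/// Every field is optional; fields left as `None` are omitted from the request
/// body so the API applies its own defaults. A typical construction uses
/// struct-update syntax:
///
/// ```rust
/// use openai_tools::images::request::{GenerateOptions, ImageQuality, ImageSize};
///
/// let options = GenerateOptions {
///     quality: Some(ImageQuality::Hd),
///     size: Some(ImageSize::Size1024x1024),
///     ..Default::default()
/// };
/// ```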
#[derive(Debug, Clone, Default)]
pub struct GenerateOptions {
    /// The model to use (defaults to DALL-E 3)
    pub model: Option<ImageModel>,
    /// Number of images to generate (1-10, DALL-E 3 only supports 1)
    pub n: Option<u32>,
    /// Image quality (DALL-E 3 only)
    pub quality: Option<ImageQuality>,
    /// Response format (URL or base64)
    pub response_format: Option<ResponseFormat>,
    /// Image size
    pub size: Option<ImageSize>,
    /// Image style (DALL-E 3 only)
    pub style: Option<ImageStyle>,
    /// User identifier for abuse monitoring
    pub user: Option<String>,
}

/// Options for image editing.
#[derive(Debug, Clone, Default)]
pub struct EditOptions {
    /// Path to the mask image (transparent areas will be edited)
    pub mask: Option<String>,
    /// The model to use (only DALL-E 2 supports editing)
    pub model: Option<ImageModel>,
    /// Number of images to generate (1-10)
    pub n: Option<u32>,
    /// Image size
    pub size: Option<ImageSize>,
    /// Response format
    pub response_format: Option<ResponseFormat>,
    /// User identifier for abuse monitoring
    pub user: Option<String>,
}

/// Options for image variations.
#[derive(Debug, Clone, Default)]
pub struct VariationOptions {
    /// The model to use (only DALL-E 2 supports variations)
    pub model: Option<ImageModel>,
    /// Number of variations to generate (1-10)
    pub n: Option<u32>,
    /// Response format
    pub response_format: Option<ResponseFormat>,
    /// Image size
    pub size: Option<ImageSize>,
    /// User identifier for abuse monitoring
    pub user: Option<String>,
}

/// Request payload for image generation.
#[derive(Debug, Clone, Serialize)]
struct GenerateRequest {
    prompt: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    quality: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    size: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    style: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
}

/// Client for interacting with the OpenAI Images API.
///
/// This struct provides methods to generate, edit, and create variations of images.
/// Use [`Images::new()`] to create a new instance.
///
/// # Example
///
/// ```rust,no_run
/// use openai_tools::images::request::{Images, GenerateOptions, ImageModel, ImageSize};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let images = Images::new()?;
///
///     let options = GenerateOptions {
///         model: Some(ImageModel::DallE3),
///         size: Some(ImageSize::Size1024x1024),
///         ..Default::default()
///     };
///
///     let response = images.generate("A sunset over mountains", options).await?;
///     println!("Generated image: {:?}", response.data[0].url);
///
///     Ok(())
/// }
/// ```
pub struct Images {
    /// Authentication provider (OpenAI or Azure)
    auth: AuthProvider,
    /// Optional request timeout duration
    timeout: Option<Duration>,
}

impl Images {
    /// Creates a new Images client for the OpenAI API.
    ///
    /// Initializes the client by loading the OpenAI API key from
    /// the environment variable `OPENAI_API_KEY`. Supports `.env` file loading
    /// via dotenvy.
    ///
    /// # Returns
    ///
    /// * `Ok(Images)` - A new Images client ready for use
    /// * `Err(OpenAIToolError)` - If the API key is not found in the environment
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::images::request::Images;
    ///
    /// let images = Images::new().expect("API key should be set");
    /// ```
    pub fn new() -> Result<Self> {
        let auth = AuthProvider::openai_from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a new Images client with a custom authentication provider.
    pub fn with_auth(auth: AuthProvider) -> Self {
        Self { auth, timeout: None }
    }

    /// Creates a new Images client for the Azure OpenAI API.
    pub fn azure() -> Result<Self> {
        let auth = AuthProvider::azure_from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a new Images client by auto-detecting the provider.
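    ///
    /// A minimal sketch; which environment variables are consulted is
    /// determined by [`AuthProvider::from_env`]:
    ///
    /// ```rust,no_run
    /// use openai_tools::images::request::Images;
    ///
    /// // Picks OpenAI or Azure based on the credentials present in the environment.
    /// let images = Images::detect_provider().expect("credentials should be set");
    /// ```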
    pub fn detect_provider() -> Result<Self> {
        let auth = AuthProvider::from_env()?;
        Ok(Self { auth, timeout: None })
    }

    /// Creates a new Images client with URL-based provider detection.
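    ///
    /// The URL and key below are placeholders for illustration only:
    ///
    /// ```rust,no_run
    /// use openai_tools::images::request::Images;
    ///
    /// // Placeholder endpoint and key; substitute your own values.
    /// let images = Images::with_url("https://api.openai.com/v1", "sk-...");
    /// ```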
    pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
        let auth = AuthProvider::from_url_with_key(base_url, api_key);
        Self { auth, timeout: None }
    }

    /// Creates a new Images client from a URL, reading credentials from environment variables.
    pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
        let auth = AuthProvider::from_url(url)?;
        Ok(Self { auth, timeout: None })
    }

    /// Returns the authentication provider.
    pub fn auth(&self) -> &AuthProvider {
        &self.auth
    }

    /// Sets the request timeout duration.
    ///
    /// # Arguments
    ///
    /// * `timeout` - The maximum time to wait for a response
    ///
    /// # Returns
    ///
    /// A mutable reference to self for method chaining.
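    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::images::request::Images;
    /// use std::time::Duration;
    ///
    /// let mut images = Images::new().expect("API key should be set");
    /// images.timeout(Duration::from_secs(60));
    /// ```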
    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
        self.timeout = Some(timeout);
        self
    }

    /// Creates the HTTP client with default headers.
    fn create_client(&self) -> Result<(reqwest::Client, reqwest::header::HeaderMap)> {
        let client = create_http_client(self.timeout)?;
        let mut headers = reqwest::header::HeaderMap::new();
        self.auth.apply_headers(&mut headers)?;
        headers.insert("User-Agent", reqwest::header::HeaderValue::from_static("openai-tools-rust"));
        Ok((client, headers))
    }

    /// Generates images from a text prompt.
    ///
    /// Creates one or more images based on the provided text description.
    ///
    /// # Arguments
    ///
    /// * `prompt` - Text description of the desired image(s)
    /// * `options` - Generation options (model, size, quality, etc.)
    ///
    /// # Returns
    ///
    /// * `Ok(ImageResponse)` - The generated image(s)
    /// * `Err(OpenAIToolError)` - If the request fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::images::request::{Images, GenerateOptions, ImageQuality, ImageStyle};
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let images = Images::new()?;
    ///
    ///     let options = GenerateOptions {
    ///         quality: Some(ImageQuality::Hd),
    ///         style: Some(ImageStyle::Natural),
    ///         ..Default::default()
    ///     };
    ///
    ///     let response = images.generate("A serene lake at dawn", options).await?;
    ///
    ///     if let Some(url) = &response.data[0].url {
    ///         println!("Image URL: {}", url);
    ///     }
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn generate(&self, prompt: &str, options: GenerateOptions) -> Result<ImageResponse> {
        let (client, mut headers) = self.create_client()?;
        headers.insert("Content-Type", reqwest::header::HeaderValue::from_static("application/json"));

        let request_body = GenerateRequest {
            prompt: prompt.to_string(),
            model: options.model.map(|m| m.as_str().to_string()),
            n: options.n,
            quality: options.quality.map(|q| q.as_str().to_string()),
            response_format: options.response_format.map(|f| f.as_str().to_string()),
            size: options.size.map(|s| s.as_str().to_string()),
            style: options.style.map(|s| s.as_str().to_string()),
            user: options.user,
        };

        let body = serde_json::to_string(&request_body).map_err(OpenAIToolError::SerdeJsonError)?;

        let url = format!("{}/generations", self.auth.endpoint(IMAGES_PATH));

        let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;

        let status = response.status();
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        if !status.is_success() {
            if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
                return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
            }
            return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
        }

        serde_json::from_str::<ImageResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }

    /// Edits an existing image based on a prompt.
    ///
    /// Creates edited versions of an image by replacing areas indicated by
    /// a transparent mask. Only available with DALL-E 2.
    ///
    /// # Arguments
    ///
    /// * `image_path` - Path to the image to edit (PNG, max 4MB, square)
    /// * `prompt` - Text description of the desired edit
    /// * `options` - Edit options (mask, size, etc.)
    ///
    /// # Returns
    ///
    /// * `Ok(ImageResponse)` - The edited image(s)
    /// * `Err(OpenAIToolError)` - If the request fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::images::request::{Images, EditOptions};
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let images = Images::new()?;
    ///
    ///     let options = EditOptions {
    ///         mask: Some("mask.png".to_string()),
    ///         ..Default::default()
    ///     };
    ///
    ///     let response = images.edit("original.png", "Add a red hat", options).await?;
    ///     println!("Edited image: {:?}", response.data[0].url);
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn edit(&self, image_path: &str, prompt: &str, options: EditOptions) -> Result<ImageResponse> {
        let (client, headers) = self.create_client()?;

        // Read the image file
        let image_content = tokio::fs::read(image_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read image: {}", e)))?;

        let image_filename = Path::new(image_path).file_name().and_then(|n| n.to_str()).unwrap_or("image.png").to_string();

        let image_part = Part::bytes(image_content)
            .file_name(image_filename)
            .mime_str("image/png")
            .map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;

        let mut form = Form::new().part("image", image_part).text("prompt", prompt.to_string());

        // Add mask if provided
        if let Some(mask_path) = options.mask {
            let mask_content = tokio::fs::read(&mask_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read mask: {}", e)))?;

            let mask_filename = Path::new(&mask_path).file_name().and_then(|n| n.to_str()).unwrap_or("mask.png").to_string();

            let mask_part = Part::bytes(mask_content)
                .file_name(mask_filename)
                .mime_str("image/png")
                .map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;

            form = form.part("mask", mask_part);
        }

        // Add optional parameters
        if let Some(model) = options.model {
            form = form.text("model", model.as_str().to_string());
        }
        if let Some(n) = options.n {
            form = form.text("n", n.to_string());
        }
        if let Some(size) = options.size {
            form = form.text("size", size.as_str().to_string());
        }
        if let Some(response_format) = options.response_format {
            form = form.text("response_format", response_format.as_str().to_string());
        }
        if let Some(user) = options.user {
            form = form.text("user", user);
        }

        let url = format!("{}/edits", self.auth.endpoint(IMAGES_PATH));

        let response = client.post(&url).headers(headers).multipart(form).send().await.map_err(OpenAIToolError::RequestError)?;

        let status = response.status();
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        if !status.is_success() {
            if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
                return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
            }
            return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
        }

        serde_json::from_str::<ImageResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }

    /// Creates variations of an existing image.
    ///
    /// Only available with DALL-E 2.
    ///
    /// # Arguments
    ///
    /// * `image_path` - Path to the image to create variations of (PNG, max 4MB, square)
    /// * `options` - Variation options (n, size, etc.)
    ///
    /// # Returns
    ///
    /// * `Ok(ImageResponse)` - The image variation(s)
    /// * `Err(OpenAIToolError)` - If the request fails
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use openai_tools::images::request::{Images, VariationOptions, ImageModel};
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
    ///     let images = Images::new()?;
    ///
    ///     let options = VariationOptions {
    ///         model: Some(ImageModel::DallE2),
    ///         n: Some(3),
    ///         ..Default::default()
    ///     };
    ///
    ///     let response = images.variation("original.png", options).await?;
    ///
    ///     for (i, image) in response.data.iter().enumerate() {
    ///         println!("Variation {}: {:?}", i + 1, image.url);
    ///     }
    ///
    ///     Ok(())
    /// }
    /// ```
    pub async fn variation(&self, image_path: &str, options: VariationOptions) -> Result<ImageResponse> {
        let (client, headers) = self.create_client()?;

        // Read the image file
        let image_content = tokio::fs::read(image_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read image: {}", e)))?;

        let image_filename = Path::new(image_path).file_name().and_then(|n| n.to_str()).unwrap_or("image.png").to_string();

        let image_part = Part::bytes(image_content)
            .file_name(image_filename)
            .mime_str("image/png")
            .map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;

        let mut form = Form::new().part("image", image_part);

        // Add optional parameters
        if let Some(model) = options.model {
            form = form.text("model", model.as_str().to_string());
        }
        if let Some(n) = options.n {
            form = form.text("n", n.to_string());
        }
        if let Some(size) = options.size {
            form = form.text("size", size.as_str().to_string());
        }
        if let Some(response_format) = options.response_format {
            form = form.text("response_format", response_format.as_str().to_string());
        }
        if let Some(user) = options.user {
            form = form.text("user", user);
        }

        let url = format!("{}/variations", self.auth.endpoint(IMAGES_PATH));

        let response = client.post(&url).headers(headers).multipart(form).send().await.map_err(OpenAIToolError::RequestError)?;

        let status = response.status();
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        if !status.is_success() {
            if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
                return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
            }
            return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
        }

        serde_json::from_str::<ImageResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }
}