1use crate::common::auth::AuthProvider;
30use crate::common::client::create_http_client;
31use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
32use crate::images::response::ImageResponse;
33use request::multipart::{Form, Part};
34use serde::{Deserialize, Serialize};
35use std::path::Path;
36use std::time::Duration;
37
/// Path segment handed to `AuthProvider::endpoint` for every images request; the
/// route name (`generations`, `edits`, `variations`) is appended to the resulting URL.
const IMAGES_PATH: &str = "images";
40
41#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
43pub enum ImageModel {
44 #[serde(rename = "dall-e-2")]
46 DallE2,
47 #[serde(rename = "dall-e-3")]
49 #[default]
50 DallE3,
51 #[serde(rename = "gpt-image-1")]
53 GptImage1,
54}
55
56impl ImageModel {
57 pub fn as_str(&self) -> &'static str {
59 match self {
60 Self::DallE2 => "dall-e-2",
61 Self::DallE3 => "dall-e-3",
62 Self::GptImage1 => "gpt-image-1",
63 }
64 }
65}
66
67impl std::fmt::Display for ImageModel {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 write!(f, "{}", self.as_str())
70 }
71}
72
73#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
75pub enum ImageSize {
76 #[serde(rename = "256x256")]
78 Size256x256,
79 #[serde(rename = "512x512")]
81 Size512x512,
82 #[serde(rename = "1024x1024")]
84 #[default]
85 Size1024x1024,
86 #[serde(rename = "1792x1024")]
88 Size1792x1024,
89 #[serde(rename = "1024x1792")]
91 Size1024x1792,
92}
93
94impl ImageSize {
95 pub fn as_str(&self) -> &'static str {
97 match self {
98 Self::Size256x256 => "256x256",
99 Self::Size512x512 => "512x512",
100 Self::Size1024x1024 => "1024x1024",
101 Self::Size1792x1024 => "1792x1024",
102 Self::Size1024x1792 => "1024x1792",
103 }
104 }
105}
106
107impl std::fmt::Display for ImageSize {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 write!(f, "{}", self.as_str())
110 }
111}
112
113#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
115#[serde(rename_all = "lowercase")]
116pub enum ImageQuality {
117 #[default]
119 Standard,
120 Hd,
122}
123
124impl ImageQuality {
125 pub fn as_str(&self) -> &'static str {
127 match self {
128 Self::Standard => "standard",
129 Self::Hd => "hd",
130 }
131 }
132}
133
134#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
136#[serde(rename_all = "lowercase")]
137pub enum ImageStyle {
138 #[default]
140 Vivid,
141 Natural,
143}
144
145impl ImageStyle {
146 pub fn as_str(&self) -> &'static str {
148 match self {
149 Self::Vivid => "vivid",
150 Self::Natural => "natural",
151 }
152 }
153}
154
155#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
157#[serde(rename_all = "snake_case")]
158pub enum ResponseFormat {
159 #[default]
161 Url,
162 B64Json,
164}
165
166impl ResponseFormat {
167 pub fn as_str(&self) -> &'static str {
169 match self {
170 Self::Url => "url",
171 Self::B64Json => "b64_json",
172 }
173 }
174}
175
/// Optional parameters for [`Images::generate`].
///
/// Every field left as `None` is omitted from the JSON request body.
#[derive(Debug, Clone, Default)]
pub struct GenerateOptions {
    /// Model to use, sent as the `model` field.
    pub model: Option<ImageModel>,
    /// Sent as the `n` field (the number of images requested, per the Images API).
    pub n: Option<u32>,
    /// Sent as the `quality` field.
    pub quality: Option<ImageQuality>,
    /// Sent as the `response_format` field (URL or base64 JSON).
    pub response_format: Option<ResponseFormat>,
    /// Sent as the `size` field.
    pub size: Option<ImageSize>,
    /// Sent as the `style` field.
    pub style: Option<ImageStyle>,
    /// Sent verbatim as the `user` field.
    pub user: Option<String>,
}
194
/// Optional parameters for [`Images::edit`].
///
/// Every field left as `None` is not added to the multipart form.
#[derive(Debug, Clone, Default)]
pub struct EditOptions {
    /// Filesystem path of a mask image; when set, the file is read from disk and
    /// uploaded as the `mask` part.
    pub mask: Option<String>,
    /// Model to use, sent as the `model` form field.
    pub model: Option<ImageModel>,
    /// Sent as the `n` form field (the number of images requested, per the Images API).
    pub n: Option<u32>,
    /// Sent as the `size` form field.
    pub size: Option<ImageSize>,
    /// Sent as the `response_format` form field.
    pub response_format: Option<ResponseFormat>,
    /// Sent verbatim as the `user` form field.
    pub user: Option<String>,
}
211
/// Optional parameters for [`Images::variation`].
///
/// Every field left as `None` is not added to the multipart form.
#[derive(Debug, Clone, Default)]
pub struct VariationOptions {
    /// Model to use, sent as the `model` form field.
    pub model: Option<ImageModel>,
    /// Sent as the `n` form field (the number of images requested, per the Images API).
    pub n: Option<u32>,
    /// Sent as the `response_format` form field.
    pub response_format: Option<ResponseFormat>,
    /// Sent as the `size` form field.
    pub size: Option<ImageSize>,
    /// Sent verbatim as the `user` form field.
    pub user: Option<String>,
}
226
/// JSON body posted to the image-generation route.
///
/// Enum options are pre-converted to their API strings by the caller; every
/// `None` field is skipped during serialization so the API applies its own default.
#[derive(Debug, Clone, Serialize)]
struct GenerateRequest {
    /// Text prompt describing the desired image (always sent).
    prompt: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    quality: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    size: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    style: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
}
246
/// Client for the image generation, edit and variation endpoints.
pub struct Images {
    /// Supplies the endpoint URL and applies authentication headers to each request.
    auth: AuthProvider,
    /// Timeout passed to `create_http_client`; `None` until [`Images::timeout`] is called.
    timeout: Option<Duration>,
}
279
280impl Images {
281 pub fn new() -> Result<Self> {
300 let auth = AuthProvider::openai_from_env()?;
301 Ok(Self { auth, timeout: None })
302 }
303
304 pub fn with_auth(auth: AuthProvider) -> Self {
306 Self { auth, timeout: None }
307 }
308
309 pub fn azure() -> Result<Self> {
311 let auth = AuthProvider::azure_from_env()?;
312 Ok(Self { auth, timeout: None })
313 }
314
315 pub fn detect_provider() -> Result<Self> {
317 let auth = AuthProvider::from_env()?;
318 Ok(Self { auth, timeout: None })
319 }
320
321 pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
323 let auth = AuthProvider::from_url_with_key(base_url, api_key);
324 Self { auth, timeout: None }
325 }
326
327 pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
329 let auth = AuthProvider::from_url(url)?;
330 Ok(Self { auth, timeout: None })
331 }
332
333 pub fn auth(&self) -> &AuthProvider {
335 &self.auth
336 }
337
338 pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
348 self.timeout = Some(timeout);
349 self
350 }
351
352 fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
354 let client = create_http_client(self.timeout)?;
355 let mut headers = request::header::HeaderMap::new();
356 self.auth.apply_headers(&mut headers)?;
357 headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
358 Ok((client, headers))
359 }
360
361 pub async fn generate(&self, prompt: &str, options: GenerateOptions) -> Result<ImageResponse> {
400 let (client, mut headers) = self.create_client()?;
401 headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
402
403 let request_body = GenerateRequest {
404 prompt: prompt.to_string(),
405 model: options.model.map(|m| m.as_str().to_string()),
406 n: options.n,
407 quality: options.quality.map(|q| q.as_str().to_string()),
408 response_format: options.response_format.map(|f| f.as_str().to_string()),
409 size: options.size.map(|s| s.as_str().to_string()),
410 style: options.style.map(|s| s.as_str().to_string()),
411 user: options.user,
412 };
413
414 let body = serde_json::to_string(&request_body).map_err(OpenAIToolError::SerdeJsonError)?;
415
416 let url = format!("{}/generations", self.auth.endpoint(IMAGES_PATH));
417
418 let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
419
420 let status = response.status();
421 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
422
423 if cfg!(test) {
424 tracing::info!("Response content: {}", content);
425 }
426
427 if !status.is_success() {
428 if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
429 return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
430 }
431 return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
432 }
433
434 serde_json::from_str::<ImageResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
435 }
436
437 pub async fn edit(&self, image_path: &str, prompt: &str, options: EditOptions) -> Result<ImageResponse> {
474 let (client, headers) = self.create_client()?;
475
476 let image_content = tokio::fs::read(image_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read image: {}", e)))?;
478
479 let image_filename = Path::new(image_path).file_name().and_then(|n| n.to_str()).unwrap_or("image.png").to_string();
480
481 let image_part = Part::bytes(image_content)
482 .file_name(image_filename)
483 .mime_str("image/png")
484 .map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;
485
486 let mut form = Form::new().part("image", image_part).text("prompt", prompt.to_string());
487
488 if let Some(mask_path) = options.mask {
490 let mask_content = tokio::fs::read(&mask_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read mask: {}", e)))?;
491
492 let mask_filename = Path::new(&mask_path).file_name().and_then(|n| n.to_str()).unwrap_or("mask.png").to_string();
493
494 let mask_part = Part::bytes(mask_content)
495 .file_name(mask_filename)
496 .mime_str("image/png")
497 .map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;
498
499 form = form.part("mask", mask_part);
500 }
501
502 if let Some(model) = options.model {
504 form = form.text("model", model.as_str().to_string());
505 }
506 if let Some(n) = options.n {
507 form = form.text("n", n.to_string());
508 }
509 if let Some(size) = options.size {
510 form = form.text("size", size.as_str().to_string());
511 }
512 if let Some(response_format) = options.response_format {
513 form = form.text("response_format", response_format.as_str().to_string());
514 }
515 if let Some(user) = options.user {
516 form = form.text("user", user);
517 }
518
519 let url = format!("{}/edits", self.auth.endpoint(IMAGES_PATH));
520
521 let response = client.post(&url).headers(headers).multipart(form).send().await.map_err(OpenAIToolError::RequestError)?;
522
523 let status = response.status();
524 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
525
526 if cfg!(test) {
527 tracing::info!("Response content: {}", content);
528 }
529
530 if !status.is_success() {
531 if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
532 return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
533 }
534 return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
535 }
536
537 serde_json::from_str::<ImageResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
538 }
539
540 pub async fn variation(&self, image_path: &str, options: VariationOptions) -> Result<ImageResponse> {
579 let (client, headers) = self.create_client()?;
580
581 let image_content = tokio::fs::read(image_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read image: {}", e)))?;
583
584 let image_filename = Path::new(image_path).file_name().and_then(|n| n.to_str()).unwrap_or("image.png").to_string();
585
586 let image_part = Part::bytes(image_content)
587 .file_name(image_filename)
588 .mime_str("image/png")
589 .map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;
590
591 let mut form = Form::new().part("image", image_part);
592
593 if let Some(model) = options.model {
595 form = form.text("model", model.as_str().to_string());
596 }
597 if let Some(n) = options.n {
598 form = form.text("n", n.to_string());
599 }
600 if let Some(size) = options.size {
601 form = form.text("size", size.as_str().to_string());
602 }
603 if let Some(response_format) = options.response_format {
604 form = form.text("response_format", response_format.as_str().to_string());
605 }
606 if let Some(user) = options.user {
607 form = form.text("user", user);
608 }
609
610 let url = format!("{}/variations", self.auth.endpoint(IMAGES_PATH));
611
612 let response = client.post(&url).headers(headers).multipart(form).send().await.map_err(OpenAIToolError::RequestError)?;
613
614 let status = response.status();
615 let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
616
617 if cfg!(test) {
618 tracing::info!("Response content: {}", content);
619 }
620
621 if !status.is_success() {
622 if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
623 return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
624 }
625 return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
626 }
627
628 serde_json::from_str::<ImageResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
629 }
630}