use crate::common::auth::AuthProvider;
use crate::common::client::create_http_client;
use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
use crate::images::response::ImageResponse;
use request::multipart::{Form, Part};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::time::Duration;
/// Path segment handed to `AuthProvider::endpoint` when building every images URL
/// (`.../generations`, `.../edits`, `.../variations`).
const IMAGES_PATH: &str = "images";
/// Image model identifier sent to the API.
///
/// The serde representation matches the string returned by [`ImageModel::as_str`];
/// the default variant is [`ImageModel::DallE3`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ImageModel {
    /// Wire name `"dall-e-2"`.
    #[serde(rename = "dall-e-2")]
    DallE2,
    /// Wire name `"dall-e-3"` (default).
    #[default]
    #[serde(rename = "dall-e-3")]
    DallE3,
    /// Wire name `"gpt-image-1"`.
    #[serde(rename = "gpt-image-1")]
    GptImage1,
}

impl ImageModel {
    /// Returns the model's wire name, e.g. `"dall-e-3"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            ImageModel::DallE2 => "dall-e-2",
            ImageModel::DallE3 => "dall-e-3",
            ImageModel::GptImage1 => "gpt-image-1",
        }
    }
}

impl std::fmt::Display for ImageModel {
    /// Writes the same string as [`ImageModel::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Output image dimensions, expressed as `"WIDTHxHEIGHT"` on the wire.
///
/// The serde representation matches [`ImageSize::as_str`]; the default is
/// [`ImageSize::Size1024x1024`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ImageSize {
    /// `"256x256"`
    #[serde(rename = "256x256")]
    Size256x256,
    /// `"512x512"`
    #[serde(rename = "512x512")]
    Size512x512,
    /// `"1024x1024"` (default).
    #[default]
    #[serde(rename = "1024x1024")]
    Size1024x1024,
    /// `"1792x1024"`
    #[serde(rename = "1792x1024")]
    Size1792x1024,
    /// `"1024x1792"`
    #[serde(rename = "1024x1792")]
    Size1024x1792,
}

impl ImageSize {
    /// Returns the size as the `"WIDTHxHEIGHT"` string sent to the API.
    pub fn as_str(&self) -> &'static str {
        match *self {
            ImageSize::Size256x256 => "256x256",
            ImageSize::Size512x512 => "512x512",
            ImageSize::Size1024x1024 => "1024x1024",
            ImageSize::Size1792x1024 => "1792x1024",
            ImageSize::Size1024x1792 => "1024x1792",
        }
    }
}

impl std::fmt::Display for ImageSize {
    /// Writes the same string as [`ImageSize::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Quality setting for generated images.
///
/// Serialized in lowercase (`"standard"` / `"hd"`), matching [`ImageQuality::as_str`].
/// The default is [`ImageQuality::Standard`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ImageQuality {
    /// Wire value `"standard"` (default).
    #[default]
    Standard,
    /// Wire value `"hd"`.
    Hd,
}

impl ImageQuality {
    /// Returns the wire value sent in the `quality` request field.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Standard => "standard",
            Self::Hd => "hd",
        }
    }
}

// Provided for parity with `ImageModel` / `ImageSize`, which already implement `Display`.
impl std::fmt::Display for ImageQuality {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
/// Style setting for generated images.
///
/// Serialized in lowercase (`"vivid"` / `"natural"`), matching [`ImageStyle::as_str`].
/// The default is [`ImageStyle::Vivid`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ImageStyle {
    /// Wire value `"vivid"` (default).
    #[default]
    Vivid,
    /// Wire value `"natural"`.
    Natural,
}

impl ImageStyle {
    /// Returns the wire value sent in the `style` request field.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Vivid => "vivid",
            Self::Natural => "natural",
        }
    }
}

// Provided for parity with `ImageModel` / `ImageSize`, which already implement `Display`.
impl std::fmt::Display for ImageStyle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
/// How the API should return image data.
///
/// Serialized in snake_case (`"url"` / `"b64_json"`), matching [`ResponseFormat::as_str`].
/// The default is [`ResponseFormat::Url`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ResponseFormat {
    /// Wire value `"url"` (default).
    #[default]
    Url,
    /// Wire value `"b64_json"`.
    B64Json,
}

impl ResponseFormat {
    /// Returns the wire value sent in the `response_format` request field.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Url => "url",
            Self::B64Json => "b64_json",
        }
    }
}

// Provided for parity with `ImageModel` / `ImageSize`, which already implement `Display`.
impl std::fmt::Display for ResponseFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
/// Optional parameters for [`Images::generate`].
///
/// Every field is optional; a `None` field is left out of the JSON request body.
#[derive(Debug, Clone, Default)]
pub struct GenerateOptions {
    /// Model to use (`model` field).
    pub model: Option<ImageModel>,
    /// Number of images to request (`n` field).
    pub n: Option<u32>,
    /// Image quality (`quality` field).
    pub quality: Option<ImageQuality>,
    /// Format of the returned image data (`response_format` field).
    pub response_format: Option<ResponseFormat>,
    /// Output dimensions (`size` field).
    pub size: Option<ImageSize>,
    /// Image style (`style` field).
    pub style: Option<ImageStyle>,
    /// End-user identifier forwarded verbatim as the `user` field.
    pub user: Option<String>,
}
/// Optional parameters for [`Images::edit`].
///
/// A `None` field is simply not added to the multipart form.
#[derive(Debug, Clone, Default)]
pub struct EditOptions {
    /// Filesystem path of a mask image; its contents are uploaded as the `mask` part.
    pub mask: Option<String>,
    /// Model to use (`model` form field).
    pub model: Option<ImageModel>,
    /// Number of images to request (`n` form field).
    pub n: Option<u32>,
    /// Output dimensions (`size` form field).
    pub size: Option<ImageSize>,
    /// Format of the returned image data (`response_format` form field).
    pub response_format: Option<ResponseFormat>,
    /// End-user identifier forwarded verbatim as the `user` form field.
    pub user: Option<String>,
}
/// Optional parameters for [`Images::variation`].
///
/// A `None` field is simply not added to the multipart form.
#[derive(Debug, Clone, Default)]
pub struct VariationOptions {
    /// Model to use (`model` form field).
    pub model: Option<ImageModel>,
    /// Number of images to request (`n` form field).
    pub n: Option<u32>,
    /// Format of the returned image data (`response_format` form field).
    pub response_format: Option<ResponseFormat>,
    /// Output dimensions (`size` form field).
    pub size: Option<ImageSize>,
    /// End-user identifier forwarded verbatim as the `user` form field.
    pub user: Option<String>,
}
/// JSON body for the `generations` endpoint.
///
/// Enum-valued options are stored as their wire strings. Fields that are `None`
/// are omitted entirely from the serialized JSON rather than sent as `null`.
#[derive(Debug, Clone, Serialize)]
struct GenerateRequest {
    /// Text prompt; always present.
    prompt: String,
    /// Wire name of the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    /// Number of images.
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    /// Wire value of the quality setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    quality: Option<String>,
    /// Wire value of the response format.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<String>,
    /// `"WIDTHxHEIGHT"` size string.
    #[serde(skip_serializing_if = "Option::is_none")]
    size: Option<String>,
    /// Wire value of the style setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    style: Option<String>,
    /// End-user identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
}
/// Client for the images endpoints (`generations`, `edits`, `variations`).
///
/// Construct it with one of the `Images::new` / `azure` / `from_url` style constructors,
/// optionally set a timeout, then call the async request methods.
pub struct Images {
    /// Supplies request headers and resolves endpoint URLs.
    auth: AuthProvider,
    /// Timeout handed to `create_http_client` for each request; `None` until
    /// `Images::timeout` is called.
    timeout: Option<Duration>,
}
impl Images {
pub fn new() -> Result<Self> {
let auth = AuthProvider::openai_from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn with_auth(auth: AuthProvider) -> Self {
Self { auth, timeout: None }
}
pub fn azure() -> Result<Self> {
let auth = AuthProvider::azure_from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn detect_provider() -> Result<Self> {
let auth = AuthProvider::from_env()?;
Ok(Self { auth, timeout: None })
}
pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
let auth = AuthProvider::from_url_with_key(base_url, api_key);
Self { auth, timeout: None }
}
pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
let auth = AuthProvider::from_url(url)?;
Ok(Self { auth, timeout: None })
}
pub fn auth(&self) -> &AuthProvider {
&self.auth
}
pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
self.timeout = Some(timeout);
self
}
fn create_client(&self) -> Result<(request::Client, request::header::HeaderMap)> {
let client = create_http_client(self.timeout)?;
let mut headers = request::header::HeaderMap::new();
self.auth.apply_headers(&mut headers)?;
headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
Ok((client, headers))
}
pub async fn generate(&self, prompt: &str, options: GenerateOptions) -> Result<ImageResponse> {
let (client, mut headers) = self.create_client()?;
headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
let request_body = GenerateRequest {
prompt: prompt.to_string(),
model: options.model.map(|m| m.as_str().to_string()),
n: options.n,
quality: options.quality.map(|q| q.as_str().to_string()),
response_format: options.response_format.map(|f| f.as_str().to_string()),
size: options.size.map(|s| s.as_str().to_string()),
style: options.style.map(|s| s.as_str().to_string()),
user: options.user,
};
let body = serde_json::to_string(&request_body).map_err(OpenAIToolError::SerdeJsonError)?;
let url = format!("{}/generations", self.auth.endpoint(IMAGES_PATH));
let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
}
return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
}
serde_json::from_str::<ImageResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn edit(&self, image_path: &str, prompt: &str, options: EditOptions) -> Result<ImageResponse> {
let (client, headers) = self.create_client()?;
let image_content = tokio::fs::read(image_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read image: {}", e)))?;
let image_filename = Path::new(image_path).file_name().and_then(|n| n.to_str()).unwrap_or("image.png").to_string();
let image_part = Part::bytes(image_content)
.file_name(image_filename)
.mime_str("image/png")
.map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;
let mut form = Form::new().part("image", image_part).text("prompt", prompt.to_string());
if let Some(mask_path) = options.mask {
let mask_content = tokio::fs::read(&mask_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read mask: {}", e)))?;
let mask_filename = Path::new(&mask_path).file_name().and_then(|n| n.to_str()).unwrap_or("mask.png").to_string();
let mask_part = Part::bytes(mask_content)
.file_name(mask_filename)
.mime_str("image/png")
.map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;
form = form.part("mask", mask_part);
}
if let Some(model) = options.model {
form = form.text("model", model.as_str().to_string());
}
if let Some(n) = options.n {
form = form.text("n", n.to_string());
}
if let Some(size) = options.size {
form = form.text("size", size.as_str().to_string());
}
if let Some(response_format) = options.response_format {
form = form.text("response_format", response_format.as_str().to_string());
}
if let Some(user) = options.user {
form = form.text("user", user);
}
let url = format!("{}/edits", self.auth.endpoint(IMAGES_PATH));
let response = client.post(&url).headers(headers).multipart(form).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
}
return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
}
serde_json::from_str::<ImageResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
pub async fn variation(&self, image_path: &str, options: VariationOptions) -> Result<ImageResponse> {
let (client, headers) = self.create_client()?;
let image_content = tokio::fs::read(image_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read image: {}", e)))?;
let image_filename = Path::new(image_path).file_name().and_then(|n| n.to_str()).unwrap_or("image.png").to_string();
let image_part = Part::bytes(image_content)
.file_name(image_filename)
.mime_str("image/png")
.map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;
let mut form = Form::new().part("image", image_part);
if let Some(model) = options.model {
form = form.text("model", model.as_str().to_string());
}
if let Some(n) = options.n {
form = form.text("n", n.to_string());
}
if let Some(size) = options.size {
form = form.text("size", size.as_str().to_string());
}
if let Some(response_format) = options.response_format {
form = form.text("response_format", response_format.as_str().to_string());
}
if let Some(user) = options.user {
form = form.text("user", user);
}
let url = format!("{}/variations", self.auth.endpoint(IMAGES_PATH));
let response = client.post(&url).headers(headers).multipart(form).send().await.map_err(OpenAIToolError::RequestError)?;
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(test) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
}
return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
}
serde_json::from_str::<ImageResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
}