use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::LlmError;
use crate::traits::ImageGenerationCapability;
use crate::types::{
GeneratedImage, ImageEditRequest, ImageGenerationRequest, ImageGenerationResponse,
ImageVariationRequest,
};
use super::config::OpenAiConfig;
/// Request payload for the image generation endpoint.
///
/// Every optional field is left out of the serialized output when it is
/// `None`, so only explicitly provided parameters are sent.
#[derive(Debug, Clone, Serialize)]
struct OpenAiImageRequest {
    /// Text description of the image(s) to generate.
    prompt: String,

    /// Text describing what the image should avoid.
    #[serde(skip_serializing_if = "Option::is_none")]
    negative_prompt: Option<String>,

    /// Identifier of the model that should produce the image.
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,

    /// Number of images requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,

    /// Requested image size, written as `"<width>x<height>"` (e.g. `"1024x1024"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    size: Option<String>,

    /// Provider-specific quality setting, forwarded as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    quality: Option<String>,

    /// Provider-specific style setting, forwarded as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    style: Option<String>,

    /// How the image should be returned, e.g. `"url"` or `"b64_json"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<String>,

    /// Optional caller-supplied user value.
    // NOTE(review): presumably the end-user identifier of the OpenAI API — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
}
/// Response body decoded from the images endpoints.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiImageResponse {
    /// Numeric `created` value reported by the provider.
    // NOTE(review): presumably a Unix timestamp in seconds — confirm against the API docs.
    created: u64,

    /// One entry per returned image.
    data: Vec<OpenAiImageData>,
}
/// A single image entry inside an [`OpenAiImageResponse`].
///
/// This type is only ever deserialized, so the fields use `#[serde(default)]`
/// (absent field becomes `None`) rather than serialization-only attributes.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiImageData {
    /// Location of the image when the provider returns a URL.
    #[serde(default)]
    url: Option<String>,

    /// Image payload when the provider returns it inline as base64 text.
    #[serde(default)]
    b64_json: Option<String>,

    /// Prompt as rewritten by the provider, when it reports one.
    #[serde(default)]
    revised_prompt: Option<String>,
}
/// Image generation client for OpenAI-style `/images/*` endpoints.
///
/// A base URL containing `siliconflow.cn` switches the client to the
/// SiliconFlow request shape, model list and defaults.
#[derive(Debug, Clone)]
pub struct OpenAiImages {
    /// Provider configuration; supplies the base URL and the request headers.
    config: OpenAiConfig,

    /// HTTP client used for every outgoing request.
    http_client: reqwest::Client,
}
impl OpenAiImages {
pub const fn new(config: OpenAiConfig, http_client: reqwest::Client) -> Self {
Self {
config,
http_client,
}
}
fn is_siliconflow(&self) -> bool {
self.config.base_url.contains("siliconflow.cn")
}
fn default_model(&self) -> String {
if self.is_siliconflow() {
"Kwai-Kolors/Kolors".to_string()
} else {
"dall-e-3".to_string()
}
}
fn get_supported_models(&self) -> Vec<String> {
if self.is_siliconflow() {
vec![
"Kwai-Kolors/Kolors".to_string(),
"black-forest-labs/FLUX.1-schnell".to_string(),
"stabilityai/stable-diffusion-3.5-large".to_string(),
]
} else {
vec![
"dall-e-2".to_string(),
"dall-e-3".to_string(),
"gpt-image-1".to_string(), ]
}
}
fn convert_request_for_provider(&self, request: &OpenAiImageRequest) -> serde_json::Value {
if self.is_siliconflow() {
let mut siliconflow_request = serde_json::json!({
"model": request.model.as_ref().unwrap_or(&self.default_model()),
"prompt": request.prompt,
"image_size": request.size.as_ref().unwrap_or(&"1024x1024".to_string()),
"batch_size": request.n.unwrap_or(1),
"num_inference_steps": 20,
"guidance_scale": 7.5
});
if let Some(negative_prompt) = &request.negative_prompt {
siliconflow_request["negative_prompt"] =
serde_json::Value::String(negative_prompt.clone());
}
siliconflow_request
} else {
serde_json::to_value(request).unwrap_or_default()
}
}
async fn make_request(
&self,
request: OpenAiImageRequest,
) -> Result<OpenAiImageResponse, LlmError> {
let url = format!("{}/images/generations", self.config.base_url);
let mut headers = reqwest::header::HeaderMap::new();
for (key, value) in self.config.get_headers() {
let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())
.map_err(|e| LlmError::HttpError(format!("Invalid header name: {e}")))?;
let header_value = reqwest::header::HeaderValue::from_str(&value)
.map_err(|e| LlmError::HttpError(format!("Invalid header value: {e}")))?;
headers.insert(header_name, header_value);
}
let request_body = self.convert_request_for_provider(&request);
let response = self
.http_client
.post(&url)
.headers(headers)
.json(&request_body)
.send()
.await
.map_err(|e| LlmError::HttpError(format!("Request failed: {e}")))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::ApiError {
code: status.as_u16(),
message: format!("OpenAI Images API error {status}: {error_text}"),
details: None,
});
}
let openai_response: OpenAiImageResponse = response
.json()
.await
.map_err(|e| LlmError::ParseError(format!("Failed to parse response: {e}")))?;
Ok(openai_response)
}
fn convert_response(&self, openai_response: OpenAiImageResponse) -> ImageGenerationResponse {
let images: Vec<GeneratedImage> = openai_response
.data
.into_iter()
.map(|img| GeneratedImage {
url: img.url,
b64_json: img.b64_json,
format: None, width: None, height: None, revised_prompt: img.revised_prompt,
metadata: HashMap::new(),
})
.collect();
let mut metadata = HashMap::new();
metadata.insert(
"created".to_string(),
serde_json::Value::Number(openai_response.created.into()),
);
ImageGenerationResponse { images, metadata }
}
fn get_supported_sizes(&self, model: &str) -> Vec<String> {
match model {
"dall-e-2" => vec![
"256x256".to_string(),
"512x512".to_string(),
"1024x1024".to_string(),
],
"dall-e-3" => vec![
"1024x1024".to_string(),
"1792x1024".to_string(),
"1024x1792".to_string(),
],
"gpt-image-1" => vec![
"1024x1024".to_string(),
"1792x1024".to_string(),
"1024x1792".to_string(),
"2048x2048".to_string(), ],
_ => vec!["1024x1024".to_string()], }
}
fn validate_request(&self, request: &ImageGenerationRequest) -> Result<(), LlmError> {
let model = request.model.as_deref().unwrap_or("dall-e-3");
if !self.get_supported_models().contains(&model.to_string()) {
return Err(LlmError::InvalidInput(format!(
"Unsupported model: {}. Supported models: {:?}",
model,
self.get_supported_models()
)));
}
match model {
"dall-e-2" => {
if request.count > 10 {
return Err(LlmError::InvalidInput(
"DALL-E 2 can generate at most 10 images".to_string(),
));
}
}
"dall-e-3" => {
if request.count > 1 {
return Err(LlmError::InvalidInput(
"DALL-E 3 can generate only 1 image at a time".to_string(),
));
}
}
"gpt-image-1" => {
if request.count > 4 {
return Err(LlmError::InvalidInput(
"GPT-Image-1 can generate at most 4 images".to_string(),
));
}
}
_ => {
return Err(LlmError::InvalidInput(format!(
"Unsupported model: {model}"
)));
}
}
if let Some(size) = &request.size {
let supported_sizes = self.get_supported_sizes(model);
if !supported_sizes.contains(size) {
return Err(LlmError::InvalidInput(format!(
"Unsupported size '{size}' for model '{model}'. Supported sizes: {supported_sizes:?}"
)));
}
}
Ok(())
}
}
#[async_trait]
impl ImageGenerationCapability for OpenAiImages {
    /// Validates `request`, sends it to the generation endpoint and converts
    /// the provider response.
    ///
    /// A `count` of zero is sent as one image. `response_format` is set to
    /// `"url"` except for `gpt-image-1`, for which the parameter is omitted
    /// because the OpenAI Images API does not accept it for that model.
    ///
    /// # Errors
    ///
    /// Propagates validation errors and any error from sending the request or
    /// decoding the response.
    async fn generate_images(
        &self,
        request: ImageGenerationRequest,
    ) -> Result<ImageGenerationResponse, LlmError> {
        self.validate_request(&request)?;

        let model = request.model.unwrap_or_else(|| self.default_model());
        let response_format = (model != "gpt-image-1").then(|| "url".to_string());

        let openai_request = OpenAiImageRequest {
            prompt: request.prompt,
            negative_prompt: request.negative_prompt,
            model: Some(model),
            n: Some(request.count.max(1)),
            size: request.size,
            quality: request.quality,
            style: request.style,
            response_format,
            user: None,
        };

        let openai_response = self.make_request(openai_request).await?;
        Ok(self.convert_response(openai_response))
    }

    /// Returns every size advertised for the configured provider.
    fn get_supported_sizes(&self) -> Vec<String> {
        let sizes: &[&str] = if self.is_siliconflow() {
            &["1024x1024", "960x1280", "768x1024", "720x1440", "720x1280"]
        } else {
            &[
                "256x256",
                "512x512",
                "1024x1024",
                "1792x1024",
                "1024x1792",
                "2048x2048",
            ]
        };
        sizes.iter().map(|size| (*size).to_string()).collect()
    }

    /// Returns the response formats advertised for the configured provider.
    fn get_supported_formats(&self) -> Vec<String> {
        let formats: &[&str] = if self.is_siliconflow() {
            &["url"]
        } else {
            &["url", "b64_json"]
        };
        formats.iter().map(|format| (*format).to_string()).collect()
    }

    /// Image editing is advertised for every provider except SiliconFlow.
    fn supports_image_editing(&self) -> bool {
        !self.is_siliconflow()
    }

    /// Image variations are advertised for every provider except SiliconFlow.
    fn supports_image_variations(&self) -> bool {
        !self.is_siliconflow()
    }

    /// Uploads an image (and optional mask) with a prompt to
    /// `{base_url}/images/edits` as a multipart form.
    ///
    /// Both the image and the mask are uploaded as `image/png`. A `count` of
    /// zero or `None` leaves the `n` field out of the form.
    ///
    /// # Errors
    ///
    /// Returns the errors described on `post_multipart`.
    async fn edit_image(
        &self,
        request: ImageEditRequest,
    ) -> Result<ImageGenerationResponse, LlmError> {
        let mut form = reqwest::multipart::Form::new()
            .text("prompt", request.prompt)
            .part("image", Self::png_part(request.image, "image.png")?);

        if let Some(mask_data) = request.mask {
            form = form.part("mask", Self::png_part(mask_data, "mask.png")?);
        }
        if let Some(size) = request.size {
            form = form.text("size", size);
        }
        if let Some(count) = request.count.filter(|&count| count > 0) {
            form = form.text("n", count.to_string());
        }
        if let Some(response_format) = request.response_format {
            form = form.text("response_format", response_format);
        }

        self.post_multipart("images/edits", form).await
    }

    /// Uploads an image to `{base_url}/images/variations` as a multipart form.
    ///
    /// The image is uploaded as `image/png`. A `count` of zero or `None`
    /// leaves the `n` field out of the form.
    ///
    /// # Errors
    ///
    /// Returns the errors described on `post_multipart`.
    async fn create_variation(
        &self,
        request: ImageVariationRequest,
    ) -> Result<ImageGenerationResponse, LlmError> {
        let mut form = reqwest::multipart::Form::new()
            .part("image", Self::png_part(request.image, "image.png")?);

        if let Some(size) = request.size {
            form = form.text("size", size);
        }
        if let Some(count) = request.count.filter(|&count| count > 0) {
            form = form.text("n", count.to_string());
        }
        if let Some(response_format) = request.response_format {
            form = form.text("response_format", response_format);
        }

        self.post_multipart("images/variations", form).await
    }
}

/// Multipart plumbing shared by `edit_image` and `create_variation`.
impl OpenAiImages {
    /// Wraps raw image bytes in a multipart part labelled as `image/png`.
    fn png_part<T>(data: T, file_name: &'static str) -> Result<reqwest::multipart::Part, LlmError>
    where
        T: Into<std::borrow::Cow<'static, [u8]>>,
    {
        let part = reqwest::multipart::Part::bytes(data)
            .file_name(file_name)
            .mime_str("image/png")?;
        Ok(part)
    }

    /// Posts `form` to `{base_url}/{endpoint}` and converts the response.
    ///
    /// # Errors
    ///
    /// * [`LlmError::HttpError`] when a configured header is invalid or the
    ///   request cannot be sent.
    /// * [`LlmError::ApiError`] when the provider answers with a non-success
    ///   status; the response text is included in the message.
    /// * [`LlmError::ParseError`] when the body is not the expected JSON.
    async fn post_multipart(
        &self,
        endpoint: &str,
        form: reqwest::multipart::Form,
    ) -> Result<ImageGenerationResponse, LlmError> {
        // Tolerate a trailing slash on the base URL so the path never contains `//`.
        let url = format!(
            "{}/{}",
            self.config.base_url.trim_end_matches('/'),
            endpoint
        );

        let mut headers = reqwest::header::HeaderMap::new();
        for (key, value) in self.config.get_headers() {
            let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())
                .map_err(|e| LlmError::HttpError(format!("Invalid header name: {e}")))?;
            let header_value = reqwest::header::HeaderValue::from_str(&value)
                .map_err(|e| LlmError::HttpError(format!("Invalid header value: {e}")))?;
            headers.insert(header_name, header_value);
        }
        // NOTE(review): the configured headers may include a JSON `Content-Type`
        // (not visible from here). `multipart` appends its own
        // `multipart/form-data; boundary=…` header, so drop any configured one
        // to keep the two from conflicting. This is a no-op when none is set.
        headers.remove(reqwest::header::CONTENT_TYPE);

        let response = self
            .http_client
            .post(&url)
            .headers(headers)
            .multipart(form)
            .send()
            .await
            .map_err(|e| LlmError::HttpError(format!("Request failed: {e}")))?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(LlmError::ApiError {
                code: status.as_u16(),
                message: format!("OpenAI Images API error {status}: {error_text}"),
                details: None,
            });
        }

        let openai_response: OpenAiImageResponse = response
            .json()
            .await
            .map_err(|e| LlmError::ParseError(format!("Failed to parse response: {e}")))?;
        Ok(self.convert_response(openai_response))
    }
}