use crate::core::providers::unified_provider::ProviderError;
use crate::core::types::image::ImageGenerationRequest;
use crate::core::types::responses::{ImageData, ImageGenerationResponse};
use serde::{Deserialize, Serialize};
/// Request body for a Stability text-to-image invocation on Bedrock.
///
/// Field names are serialized verbatim (snake_case): the Bedrock Stability
/// Diffusion request schema uses keys such as `text_prompts` and `cfg_scale`,
/// so no `rename_all` is applied here. Every optional field is left out of the
/// JSON when it is `None`.
#[derive(Debug, Serialize)]
pub struct StabilityTextToImageRequest {
    /// Prompts to generate from; serialized as `text_prompts`.
    pub text_prompts: Vec<TextPrompt>,
    /// Prompt-adherence scale; serialized as `cfg_scale`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cfg_scale: Option<f32>,
    /// Requested image height.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub height: Option<u32>,
    /// Requested image width.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub width: Option<u32>,
    /// Number of images to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub samples: Option<u32>,
    /// Generation seed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// Number of generation steps.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub steps: Option<u32>,
}
/// One prompt entry inside `StabilityTextToImageRequest::text_prompts`.
#[derive(Debug, Serialize)]
pub struct TextPrompt {
/// The prompt text.
pub text: String,
/// Optional prompt weight; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub weight: Option<f32>,
}
/// Response body deserialized from a Stability invocation.
#[derive(Debug, Deserialize)]
pub struct StabilityResponse {
/// Result string reported by the service. It is required for deserialization
/// but is not inspected by the code visible in this file.
pub result: String,
/// Generated images, one entry per artifact.
pub artifacts: Vec<StabilityArtifact>,
}
/// A single generated image inside a `StabilityResponse`.
///
/// Keys are read in camelCase, so `finish_reason` is deserialized from
/// `finishReason`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StabilityArtifact {
/// Base64 image payload; `execute_stability_generation` forwards it as
/// `ImageData::b64_json`.
pub base64: String,
/// Finish reason reported for this artifact (not inspected in this file).
pub finish_reason: String,
/// Seed reported for this artifact.
pub seed: i64,
}
/// Request body for a Nova Canvas invocation.
///
/// Keys are serialized in camelCase (`textToImageParams`, `taskType`).
///
/// NOTE(review): the published Nova Canvas schema appears to carry size, seed,
/// cfg scale and image count in a separate `imageGenerationConfig` object
/// rather than inside `textToImageParams` — confirm against the AWS docs.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct NovaCanvasRequest {
/// Text-to-image parameters; serialized as `textToImageParams`.
pub text_to_image_params: TextToImageParams,
/// Task selector; `execute_nova_generation` always sends `"TEXT_IMAGE"`.
pub task_type: String, }
/// Parameters nested under `NovaCanvasRequest::text_to_image_params`.
///
/// Keys are serialized in camelCase and every optional field is omitted from
/// the JSON when it is `None`.
///
/// NOTE(review): per the published Nova Canvas schema only `text` and
/// `negativeText` seem to belong in this object; the remaining fields look like
/// `imageGenerationConfig` members — confirm before relying on them.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TextToImageParams {
/// The prompt text.
pub text: String,
/// Optional negative prompt; serialized as `negativeText`.
#[serde(skip_serializing_if = "Option::is_none")]
pub negative_text: Option<String>,
/// Requested image height.
#[serde(skip_serializing_if = "Option::is_none")]
pub height: Option<u32>,
/// Requested image width.
#[serde(skip_serializing_if = "Option::is_none")]
pub width: Option<u32>,
/// Prompt-adherence scale; serialized as `cfgScale`.
#[serde(skip_serializing_if = "Option::is_none")]
pub cfg_scale: Option<f32>,
/// Generation seed.
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
/// Number of images to generate; serialized as `numberOfImages`.
#[serde(skip_serializing_if = "Option::is_none")]
pub number_of_images: Option<u32>,
}
/// Response body deserialized from a Nova Canvas invocation.
///
/// NOTE(review): the published Nova Canvas response appears to return `images`
/// as an array of plain base64 strings rather than objects — confirm that this
/// shape actually deserializes against a live response.
#[derive(Debug, Deserialize)]
pub struct NovaCanvasResponse {
/// Generated images.
pub images: Vec<NovaImage>,
}
/// A single generated image inside a `NovaCanvasResponse`.
#[derive(Debug, Deserialize)]
pub struct NovaImage {
/// Base64 image payload; `execute_nova_generation` forwards it as
/// `ImageData::b64_json`.
pub base64: String,
}
/// Routes an image generation request to the matching Bedrock model family.
///
/// The model id is matched by substring: ids containing `"stability"` go to the
/// Stability handler and ids containing `"nova-canvas"` go to the Nova Canvas
/// handler. When the request carries no model, the id falls back to
/// `"dall-e-3"`, which matches neither family and therefore ends in the
/// error branch.
///
/// # Errors
///
/// Returns `ProviderError::model_not_found` for any unmatched model id, and
/// propagates whatever error the selected handler returns.
pub async fn execute_image_generation(
    client: &crate::core::providers::bedrock::client::BedrockClient,
    request: &ImageGenerationRequest,
) -> Result<ImageGenerationResponse, ProviderError> {
    let model_id: &str = request.model.as_deref().unwrap_or("dall-e-3");

    if model_id.contains("stability") {
        return execute_stability_generation(client, request, model_id).await;
    }

    if model_id.contains("nova-canvas") {
        return execute_nova_generation(client, request, model_id).await;
    }

    Err(ProviderError::model_not_found(
        "bedrock",
        format!("Image generation model {} not supported", model_id),
    ))
}
/// Generates images with a Stability model through Bedrock.
///
/// Builds a `StabilityTextToImageRequest` from `request`, sends it to `model`
/// via the `"invoke"` operation, and converts every returned artifact into an
/// `ImageData` carrying the base64 payload.
///
/// Request mapping:
/// - `prompt` becomes a single text prompt with weight `1.0`.
/// - `size` is read as `"WIDTHxHEIGHT"` (the OpenAI convention). Both
///   dimensions are sent only when the whole string parses; otherwise neither
///   is sent.
/// - `n` becomes `samples`.
/// - `quality` `"standard"` maps to 30 steps and `"hd"` to 50; any other value
///   leaves `steps` unset.
///
/// # Errors
///
/// Propagates request serialization and transport errors, and returns
/// `ProviderError::response_parsing` when the response body cannot be decoded.
async fn execute_stability_generation(
    client: &crate::core::providers::bedrock::client::BedrockClient,
    request: &ImageGenerationRequest,
    model: &str,
) -> Result<ImageGenerationResponse, ProviderError> {
    // The first component of "WIDTHxHEIGHT" is the width. Parsing is
    // all-or-nothing so a malformed size never sends a single stray dimension.
    let dimensions: Option<(u32, u32)> = request
        .size
        .as_ref()
        .and_then(|s| s.split_once('x'))
        .and_then(|(w, h)| Some((w.parse().ok()?, h.parse().ok()?)));
    let (width, height) = match dimensions {
        Some((w, h)) => (Some(w), Some(h)),
        None => (None, None),
    };

    let stability_request = StabilityTextToImageRequest {
        text_prompts: vec![TextPrompt {
            text: request.prompt.clone(),
            weight: Some(1.0),
        }],
        cfg_scale: Some(7.0),
        height,
        width,
        samples: request.n,
        seed: None,
        steps: request.quality.as_ref().and_then(|q| match q.as_str() {
            "standard" => Some(30),
            "hd" => Some(50),
            _ => None,
        }),
    };

    let body = serde_json::to_value(stability_request)?;
    let response = client.send_request(model, "invoke", &body).await?;
    let stability_response: StabilityResponse = response
        .json()
        .await
        .map_err(|e| ProviderError::response_parsing("bedrock", e.to_string()))?;

    let data: Vec<ImageData> = stability_response
        .artifacts
        .into_iter()
        .map(|artifact| ImageData {
            url: None,
            b64_json: Some(artifact.base64),
            revised_prompt: None,
        })
        .collect();

    Ok(ImageGenerationResponse {
        created: chrono::Utc::now().timestamp() as u64,
        data,
    })
}
/// Generates images with a Nova Canvas model through Bedrock.
///
/// Builds a `NovaCanvasRequest` with task type `"TEXT_IMAGE"` from `request`,
/// sends it to `model` via the `"invoke"` operation, and converts every
/// returned image into an `ImageData` carrying the base64 payload.
///
/// Request mapping:
/// - `prompt` becomes the text-to-image `text`; no negative text is sent.
/// - `size` is read as `"WIDTHxHEIGHT"` (the OpenAI convention). Both
///   dimensions are sent only when the whole string parses; otherwise neither
///   is sent.
/// - `n` becomes `number_of_images`.
///
/// # Errors
///
/// Propagates request serialization and transport errors, and returns
/// `ProviderError::response_parsing` when the response body cannot be decoded.
async fn execute_nova_generation(
    client: &crate::core::providers::bedrock::client::BedrockClient,
    request: &ImageGenerationRequest,
    model: &str,
) -> Result<ImageGenerationResponse, ProviderError> {
    // The first component of "WIDTHxHEIGHT" is the width. Parsing is
    // all-or-nothing so a malformed size never sends a single stray dimension.
    let dimensions: Option<(u32, u32)> = request
        .size
        .as_ref()
        .and_then(|s| s.split_once('x'))
        .and_then(|(w, h)| Some((w.parse().ok()?, h.parse().ok()?)));
    let (width, height) = match dimensions {
        Some((w, h)) => (Some(w), Some(h)),
        None => (None, None),
    };

    let nova_request = NovaCanvasRequest {
        text_to_image_params: TextToImageParams {
            text: request.prompt.clone(),
            negative_text: None,
            height,
            width,
            cfg_scale: Some(7.0),
            seed: None,
            number_of_images: request.n,
        },
        task_type: "TEXT_IMAGE".to_string(),
    };

    let body = serde_json::to_value(nova_request)?;
    let response = client.send_request(model, "invoke", &body).await?;
    let nova_response: NovaCanvasResponse = response
        .json()
        .await
        .map_err(|e| ProviderError::response_parsing("bedrock", e.to_string()))?;

    let data: Vec<ImageData> = nova_response
        .images
        .into_iter()
        .map(|image| ImageData {
            url: None,
            b64_json: Some(image.base64),
            revised_prompt: None,
        })
        .collect();

    Ok(ImageGenerationResponse {
        created: chrono::Utc::now().timestamp() as u64,
        data,
    })
}