use std::path::PathBuf;
use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
/// Parameters for one image-generation call.
///
/// Build with [`ImageGenerationRequest::new`], then chain the `with_*`
/// builder methods; every optional field starts as `None`, leaving the
/// choice to the provider.
#[derive(Debug, Clone)]
pub struct ImageGenerationRequest {
    /// Text description of the image to generate.
    pub prompt: String,
    /// Provider model identifier (e.g. "dall-e-3").
    pub model: String,
    /// Number of images to produce; providers cap this
    /// (see `ImageProvider::max_images_per_request`).
    pub n: Option<u8>,
    /// Requested output dimensions.
    pub size: Option<ImageSize>,
    /// Rendering quality tier.
    pub quality: Option<ImageQuality>,
    /// Rendering style.
    pub style: Option<ImageStyle>,
    /// How image data should be returned (URL vs. base64).
    pub response_format: Option<ImageFormat>,
    /// Content the image should avoid; provider support varies.
    pub negative_prompt: Option<String>,
    /// Seed for reproducible generation; provider support varies.
    pub seed: Option<u64>,
}
impl ImageGenerationRequest {
    /// Creates a request for `model` with `prompt`; every optional
    /// setting starts unset.
    pub fn new(model: impl Into<String>, prompt: impl Into<String>) -> Self {
        Self {
            prompt: prompt.into(),
            model: model.into(),
            n: None,
            size: None,
            quality: None,
            style: None,
            response_format: None,
            negative_prompt: None,
            seed: None,
        }
    }

    /// Sets how many images to generate.
    pub fn with_n(self, n: u8) -> Self {
        Self { n: Some(n), ..self }
    }

    /// Sets the requested output dimensions.
    pub fn with_size(self, size: ImageSize) -> Self {
        Self { size: Some(size), ..self }
    }

    /// Sets the rendering quality tier.
    pub fn with_quality(self, quality: ImageQuality) -> Self {
        Self { quality: Some(quality), ..self }
    }

    /// Sets the rendering style.
    pub fn with_style(self, style: ImageStyle) -> Self {
        Self { style: Some(style), ..self }
    }

    /// Sets how the provider should return image data.
    pub fn with_format(self, format: ImageFormat) -> Self {
        Self { response_format: Some(format), ..self }
    }

    /// Sets content the image should avoid.
    pub fn with_negative_prompt(self, negative_prompt: impl Into<String>) -> Self {
        Self { negative_prompt: Some(negative_prompt.into()), ..self }
    }

    /// Sets a seed for reproducible generation.
    pub fn with_seed(self, seed: u64) -> Self {
        Self { seed: Some(seed), ..self }
    }
}
/// Output dimensions for a generated image.
///
/// NOTE(review): unlike the sibling enums below, this one carries no
/// `#[serde(rename_all = ...)]`, so variants serialize under their
/// CamelCase names and `Custom` as an externally tagged map — confirm
/// that is the intended wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ImageSize {
    /// 256x256 pixels.
    Square256,
    /// 512x512 pixels.
    Square512,
    /// 1024x1024 pixels (the default).
    #[default]
    Square1024,
    /// 1024x1792 portrait.
    Portrait1024x1792,
    /// 1792x1024 landscape.
    Landscape1792x1024,
    /// Arbitrary pixel dimensions for providers that accept them.
    Custom { width: u32, height: u32 },
}
impl ImageSize {
pub fn dimensions(&self) -> (u32, u32) {
match self {
ImageSize::Square256 => (256, 256),
ImageSize::Square512 => (512, 512),
ImageSize::Square1024 => (1024, 1024),
ImageSize::Portrait1024x1792 => (1024, 1792),
ImageSize::Landscape1792x1024 => (1792, 1024),
ImageSize::Custom { width, height } => (*width, *height),
}
}
pub fn to_openai_string(self) -> String {
let (w, h) = self.dimensions();
format!("{}x{}", w, h)
}
}
/// Rendering quality tier; serialized lowercase ("standard" / "hd").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ImageQuality {
    /// Standard quality (the default).
    #[default]
    Standard,
    /// High definition.
    Hd,
}
/// Rendering style; serialized lowercase ("natural" / "vivid").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ImageStyle {
    /// More realistic, less stylized output (the default).
    #[default]
    Natural,
    /// More saturated, dramatic output.
    Vivid,
}
/// How the provider returns image data; serialized snake_case
/// ("url" / "b64_json").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ImageFormat {
    /// A hosted URL pointing at the image (the default).
    #[default]
    Url,
    /// Base64-encoded image bytes inlined in the response.
    B64Json,
}
/// Result of an image generation, edit, or variation call.
#[derive(Debug, Clone)]
pub struct ImageGenerationResponse {
    // Creation timestamp reported by the provider — presumably Unix
    // seconds; confirm per provider.
    pub created: u64,
    // Generated images in the order the provider returned them.
    pub images: Vec<GeneratedImage>,
}
impl ImageGenerationResponse {
    /// Convenience accessor for the first generated image, if any.
    pub fn first(&self) -> Option<&GeneratedImage> {
        self.images.first()
    }
}
/// One generated image; exactly one of `url` / `b64_json` is expected
/// to be populated, depending on the requested [`ImageFormat`].
#[derive(Debug, Clone)]
pub struct GeneratedImage {
    /// Hosted URL of the image, when delivered by reference.
    pub url: Option<String>,
    /// Base64-encoded image bytes, when delivered inline.
    pub b64_json: Option<String>,
    /// The provider's rewritten version of the prompt, when supplied.
    pub revised_prompt: Option<String>,
}
impl GeneratedImage {
    /// Blank image record shared by the constructors below.
    fn blank() -> Self {
        Self {
            url: None,
            b64_json: None,
            revised_prompt: None,
        }
    }

    /// Builds an image delivered by hosted URL.
    pub fn from_url(url: impl Into<String>) -> Self {
        Self {
            url: Some(url.into()),
            ..Self::blank()
        }
    }

    /// Builds an image delivered as inline base64 data.
    pub fn from_b64(data: impl Into<String>) -> Self {
        Self {
            b64_json: Some(data.into()),
            ..Self::blank()
        }
    }

    /// Attaches the provider's rewritten prompt.
    pub fn with_revised_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            revised_prompt: Some(prompt.into()),
            ..self
        }
    }
}
/// Parameters for editing an existing image
/// (see [`ImageProvider::edit_image`]).
#[derive(Debug, Clone)]
pub struct ImageEditRequest {
    /// Source image to edit.
    pub image: ImageInput,
    /// Description of the desired edit.
    pub prompt: String,
    /// Optional mask selecting the editable region.
    pub mask: Option<ImageInput>,
    /// Provider model identifier.
    pub model: String,
    /// Number of edited images to produce.
    pub n: Option<u8>,
    /// Requested output dimensions.
    pub size: Option<ImageSize>,
    /// How image data should be returned.
    pub response_format: Option<ImageFormat>,
}
/// Parameters for generating variations of an existing image
/// (see [`ImageProvider::create_variation`]).
#[derive(Debug, Clone)]
pub struct ImageVariationRequest {
    /// Source image to vary.
    pub image: ImageInput,
    /// Provider model identifier.
    pub model: String,
    /// Number of variations to produce.
    pub n: Option<u8>,
    /// Requested output dimensions.
    pub size: Option<ImageSize>,
    /// How image data should be returned.
    pub response_format: Option<ImageFormat>,
}
/// Source image supplied to edit/variation requests.
#[derive(Debug, Clone)]
pub enum ImageInput {
    /// Path to a local image file.
    File(PathBuf),
    /// Base64-encoded image bytes plus their MIME type
    /// (e.g. "image/png" — exact values are provider-dependent).
    Base64 { data: String, media_type: String },
    /// URL of a hosted image.
    Url(String),
}
/// A backend capable of generating images from text prompts.
///
/// Only `generate_image` is mandatory; editing and variations default
/// to an `Error::not_supported` result so providers opt in explicitly.
#[async_trait]
pub trait ImageProvider: Send + Sync {
    /// Provider name (used for identification/registry lookup —
    /// presumably matches `ImageModelInfo::provider`; confirm).
    fn name(&self) -> &str;
    /// Generates one or more images for the request.
    async fn generate_image(
        &self,
        request: ImageGenerationRequest,
    ) -> Result<ImageGenerationResponse>;
    /// Edits an existing image; default implementation reports
    /// "not supported".
    async fn edit_image(&self, _request: ImageEditRequest) -> Result<ImageGenerationResponse> {
        Err(Error::not_supported("Image editing"))
    }
    /// Creates variations of an existing image; default implementation
    /// reports "not supported".
    async fn create_variation(
        &self,
        _request: ImageVariationRequest,
    ) -> Result<ImageGenerationResponse> {
        Err(Error::not_supported("Image variations"))
    }
    /// Sizes this provider accepts.
    fn supported_sizes(&self) -> &[ImageSize];
    /// Upper bound on images per request (defaults to 4).
    fn max_images_per_request(&self) -> u8 {
        4
    }
    /// Model used when the caller does not pick one; `None` means no
    /// default.
    fn default_image_model(&self) -> Option<&str> {
        None
    }
    /// Image models this provider supports; `None` means the set is
    /// unknown or unrestricted.
    fn supported_image_models(&self) -> Option<&[&str]> {
        None
    }
}
/// Opaque identifier for an asynchronous image-generation job.
///
/// Wraps the provider-assigned id string; `Eq`/`Hash` allow use as a
/// map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct JobId(pub String);

impl JobId {
    /// Wraps any string-like id.
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }
}

// Standard conversions so generic code taking `impl Into<JobId>` works
// and callers don't have to go through `new`.
impl From<String> for JobId {
    fn from(id: String) -> Self {
        Self(id)
    }
}

impl From<&str> for JobId {
    fn from(id: &str) -> Self {
        Self(id.to_owned())
    }
}

impl std::fmt::Display for JobId {
    /// Prints the raw id string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
/// Lifecycle state of an asynchronous image-generation job.
///
/// `PartialEq`/`Eq` are derived so callers can compare states directly
/// (e.g. `status == JobStatus::Completed`) instead of matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JobStatus {
    /// Accepted by the provider but not yet started.
    Queued,
    /// Generation is in progress.
    Running,
    /// Finished successfully; fetch output via `get_result`.
    Completed,
    /// Failed, with a provider-supplied error message.
    Failed(String),
    /// Cancelled before completion.
    Cancelled,
}
/// Extension of [`ImageProvider`] for backends that run generation as a
/// background job which must be polled for completion.
#[async_trait]
pub trait AsyncImageProvider: ImageProvider {
    /// Submits the request and returns a handle for polling.
    async fn start_generation(&self, request: ImageGenerationRequest) -> Result<JobId>;
    /// Fetches the current status of a previously started job.
    async fn poll_status(&self, job_id: &JobId) -> Result<JobStatus>;
    /// Retrieves the finished result; behavior for unfinished jobs is
    /// provider-defined.
    async fn get_result(&self, job_id: &JobId) -> Result<ImageGenerationResponse>;
    /// Polls `poll_status` every `poll_interval` until the job reaches a
    /// terminal state, or fails with `Error::Timeout` once `timeout` has
    /// elapsed.
    ///
    /// The timeout is only checked between polls, so the total wait can
    /// exceed `timeout` by up to one poll round-trip plus one
    /// `poll_interval`.
    async fn wait_for_completion(
        &self,
        job_id: &JobId,
        poll_interval: Duration,
        timeout: Duration,
    ) -> Result<ImageGenerationResponse> {
        let start = std::time::Instant::now();
        loop {
            // Give up before issuing another poll once the budget is spent.
            if start.elapsed() > timeout {
                return Err(Error::Timeout);
            }
            match self.poll_status(job_id).await? {
                JobStatus::Completed => return self.get_result(job_id).await,
                JobStatus::Failed(msg) => return Err(Error::other(msg)),
                JobStatus::Cancelled => return Err(Error::other("Image generation was cancelled")),
                // Still pending — async sleep so the executor isn't blocked.
                JobStatus::Queued | JobStatus::Running => {
                    tokio::time::sleep(poll_interval).await;
                }
            }
        }
    }
}
/// Static capability metadata for a known image model.
#[derive(Debug, Clone)]
pub struct ImageModelInfo {
    /// Model identifier as sent to the provider API.
    pub id: &'static str,
    // Owning provider key — presumably matches `ImageProvider::name`;
    // confirm against provider implementations.
    pub provider: &'static str,
    /// Sizes the model accepts.
    pub sizes: &'static [ImageSize],
    /// Maximum images per request.
    pub max_images: u8,
    /// Whether the model supports `edit_image`.
    pub supports_editing: bool,
    /// Whether the model supports `create_variation`.
    pub supports_variations: bool,
    // Per-image price — units/currency not stated here, presumably USD;
    // confirm against provider pricing pages.
    pub price_per_image: f64,
}
/// Registry of known image models and their capabilities.
///
/// Looked up by [`get_image_model_info`]; extend this table when new
/// models are supported.
pub static IMAGE_MODELS: &[ImageModelInfo] = &[
    ImageModelInfo {
        id: "dall-e-3",
        provider: "openai",
        sizes: &[
            ImageSize::Square1024,
            ImageSize::Portrait1024x1792,
            ImageSize::Landscape1792x1024,
        ],
        // DALL-E 3 generates one image at a time.
        max_images: 1,
        supports_editing: false,
        supports_variations: false,
        price_per_image: 0.04,
    },
    ImageModelInfo {
        id: "dall-e-2",
        provider: "openai",
        sizes: &[
            ImageSize::Square256,
            ImageSize::Square512,
            ImageSize::Square1024,
        ],
        max_images: 10,
        supports_editing: true,
        supports_variations: true,
        price_per_image: 0.02,
    },
];
/// Looks up a model's static metadata by id; `None` if the model is not
/// in [`IMAGE_MODELS`].
pub fn get_image_model_info(model_id: &str) -> Option<&'static ImageModelInfo> {
    for info in IMAGE_MODELS {
        if info.id == model_id {
            return Some(info);
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed variants and custom sizes both report pixel dimensions.
    #[test]
    fn test_image_size_dimensions() {
        assert_eq!(ImageSize::Square1024.dimensions(), (1024, 1024));
        assert_eq!(ImageSize::Portrait1024x1792.dimensions(), (1024, 1792));
        assert_eq!(
            ImageSize::Custom {
                width: 800,
                height: 600
            }
            .dimensions(),
            (800, 600)
        );
    }

    // OpenAI wire format is "WIDTHxHEIGHT".
    #[test]
    fn test_image_size_to_openai_string() {
        assert_eq!(ImageSize::Square1024.to_openai_string(), "1024x1024");
        assert_eq!(
            ImageSize::Landscape1792x1024.to_openai_string(),
            "1792x1024"
        );
    }

    // The with_* builders should populate exactly the fields they name.
    #[test]
    fn test_image_request_builder() {
        let request = ImageGenerationRequest::new("dall-e-3", "A cat")
            .with_size(ImageSize::Square1024)
            .with_quality(ImageQuality::Hd)
            .with_style(ImageStyle::Vivid)
            .with_n(2);
        assert_eq!(request.model, "dall-e-3");
        assert_eq!(request.prompt, "A cat");
        assert_eq!(request.size, Some(ImageSize::Square1024));
        assert_eq!(request.quality, Some(ImageQuality::Hd));
        assert_eq!(request.style, Some(ImageStyle::Vivid));
        assert_eq!(request.n, Some(2));
    }

    // from_url leaves the base64 payload empty; with_revised_prompt adds
    // the provider's rewritten prompt.
    #[test]
    fn test_generated_image() {
        let img = GeneratedImage::from_url("https://example.com/image.png")
            .with_revised_prompt("A cute cat sleeping on a couch");
        assert_eq!(img.url, Some("https://example.com/image.png".to_string()));
        assert!(img.b64_json.is_none());
        assert!(img.revised_prompt.is_some());
    }

    // Display renders the raw id string.
    #[test]
    fn test_job_id() {
        let job_id = JobId::new("job-123");
        assert_eq!(job_id.to_string(), "job-123");
    }

    // Registry lookup returns the matching static entry.
    #[test]
    fn test_image_model_registry() {
        let model = get_image_model_info("dall-e-3");
        assert!(model.is_some());
        let model = model.unwrap();
        assert_eq!(model.provider, "openai");
        assert_eq!(model.max_images, 1);
    }
}