use crate::client::AsyncForgeClient;
use crate::error::ForgeError;
use serde::{Deserialize, Serialize};
/// Parameters for an image-generation request, sent as the request body by [`generate_image`].
///
/// Build one with [`ImageRequest::new`] and the chained setters. Every field except `prompt`
/// has a serde default, so a deserialized request may omit them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageRequest {
    /// Text description of the image(s) to generate.
    pub prompt: String,
    /// Model name; falls back to `default_model()` when absent on deserialize.
    #[serde(default = "default_model")]
    pub model: String,
    /// Number of images requested; falls back to `default_n()` when absent.
    #[serde(default = "default_n")]
    pub n: u32,
    /// Output size; falls back to `default_size()` when absent.
    #[serde(default = "default_size")]
    pub size: ImageSize,
    /// How image data should be returned; falls back to `ResponseFormat::default()`.
    #[serde(default)]
    pub response_format: ResponseFormat,
    /// Quality tier; falls back to `ImageQuality::default()`.
    #[serde(default)]
    pub quality: ImageQuality,
    /// Optional style. Omitted from the serialized body when `None` instead of being sent
    /// as `null`, matching how `user` is handled.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub style: Option<ImageStyle>,
    /// Optional caller-supplied user identifier; omitted from the body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
/// Serde/constructor default for `ImageRequest::model`.
fn default_model() -> String {
    String::from("dall-e-3")
}
/// Serde/constructor default for the image count `n`: a single image.
const fn default_n() -> u32 {
    1_u32
}
/// Serde/constructor default for `size`: the variant marked `#[default]` on [`ImageSize`]
/// (currently `Size1024x1024`).
fn default_size() -> ImageSize {
    ImageSize::default()
}
/// Builder-style construction for [`ImageRequest`].
///
/// Every setter takes `self` by value and returns the updated request, so the result must be
/// used (or re-bound). They are `#[must_use]` so a stray `req.n(2);` is flagged instead of
/// silently discarding the change.
impl ImageRequest {
    /// Creates a request for `prompt` with every other field at its default
    /// (`default_model()`, `default_n()`, `default_size()`, and the `Default` impls of
    /// `ResponseFormat` and `ImageQuality`). `style` and `user` start unset.
    #[must_use]
    pub fn new(prompt: impl Into<String>) -> Self {
        Self {
            prompt: prompt.into(),
            model: default_model(),
            n: default_n(),
            size: default_size(),
            response_format: ResponseFormat::default(),
            quality: ImageQuality::default(),
            style: None,
            user: None,
        }
    }

    /// Sets the model name.
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Sets how many images to request. No range check is done here.
    #[must_use]
    pub fn n(mut self, n: u32) -> Self {
        self.n = n;
        self
    }

    /// Sets the output size.
    #[must_use]
    pub fn size(mut self, size: ImageSize) -> Self {
        self.size = size;
        self
    }

    /// Sets the response encoding.
    #[must_use]
    pub fn response_format(mut self, format: ResponseFormat) -> Self {
        self.response_format = format;
        self
    }

    /// Sets the quality tier.
    #[must_use]
    pub fn quality(mut self, quality: ImageQuality) -> Self {
        self.quality = quality;
        self
    }

    /// Sets the style (stored as `Some(style)`).
    #[must_use]
    pub fn style(mut self, style: ImageStyle) -> Self {
        self.style = Some(style);
        self
    }

    /// Sets the user identifier (stored as `Some(user)`).
    #[must_use]
    pub fn user(mut self, user: impl Into<String>) -> Self {
        self.user = Some(user.into());
        self
    }
}
/// Requested output size. Each variant is (de)serialized as the string in its
/// `#[serde(rename)]` attribute, e.g. `"1024x1024"`.
///
/// `Size1024x1024` is the `Default`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImageSize {
    /// Wire value `"256x256"`.
    #[serde(rename = "256x256")]
    Size256x256,
    /// Wire value `"512x512"`.
    #[serde(rename = "512x512")]
    Size512x512,
    /// Wire value `"1024x1024"`; the default size.
    #[default]
    #[serde(rename = "1024x1024")]
    Size1024x1024,
    /// Wire value `"1792x1024"`.
    #[serde(rename = "1792x1024")]
    Size1792x1024,
    /// Wire value `"1024x1792"`.
    #[serde(rename = "1024x1792")]
    Size1024x1792,
}
impl std::fmt::Display for ImageSize {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ImageSize::Size256x256 => write!(f, "256x256"),
ImageSize::Size512x512 => write!(f, "512x512"),
ImageSize::Size1024x1024 => write!(f, "1024x1024"),
ImageSize::Size1792x1024 => write!(f, "1792x1024"),
ImageSize::Size1024x1792 => write!(f, "1024x1792"),
}
}
}
/// Encoding requested for returned image data.
///
/// Variants are (de)serialized in snake_case: `"url"` and `"b64_json"`. `Url` is the `Default`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseFormat {
    /// Wire value `"url"`; the default.
    #[default]
    Url,
    /// Wire value `"b64_json"`.
    B64Json,
}
/// Quality tier for generated images.
///
/// Variants are (de)serialized in lowercase: `"standard"` and `"hd"`. `Standard` is the `Default`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ImageQuality {
    /// Wire value `"standard"`; the default.
    #[default]
    Standard,
    /// Wire value `"hd"`.
    Hd,
}
/// Visual style hint for generation.
///
/// Variants are (de)serialized in lowercase: `"natural"` and `"vivid"`. There is no `Default`;
/// requests hold it as an `Option`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ImageStyle {
    /// Wire value `"natural"`.
    Natural,
    /// Wire value `"vivid"`.
    Vivid,
}
/// Response body returned by the generation, edit and variation endpoints in this module.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ImageResponse {
    /// Creation time reported by the server (presumably Unix seconds — confirm with the API).
    pub created: u64,
    /// One entry per returned image, in server order.
    pub data: Vec<ImageData>,
}
/// Convenience accessors over [`ImageResponse::data`].
impl ImageResponse {
    /// URL of the first image, or `None` if `data` is empty or its first entry has no URL.
    pub fn url(&self) -> Option<&str> {
        self.data.first().and_then(|d| d.url.as_deref())
    }

    /// All URLs in `data`, in order, skipping entries that carry none.
    pub fn urls(&self) -> Vec<&str> {
        self.data.iter().filter_map(|d| d.url.as_deref()).collect()
    }

    /// Base64 payload of the first image, or `None` if `data` is empty or its first entry has
    /// no `b64_json`.
    pub fn b64_json(&self) -> Option<&str> {
        self.data.first().and_then(|d| d.b64_json.as_deref())
    }

    /// All base64 payloads in `data`, in order, skipping entries that carry none.
    ///
    /// Counterpart of [`ImageResponse::urls`] for multi-image base64 responses.
    pub fn b64_jsons(&self) -> Vec<&str> {
        self.data
            .iter()
            .filter_map(|d| d.b64_json.as_deref())
            .collect()
    }

    /// Revised prompt attached to the first image, or `None` if absent.
    pub fn revised_prompt(&self) -> Option<&str> {
        self.data.first().and_then(|d| d.revised_prompt.as_deref())
    }
}
/// A single image entry inside [`ImageResponse::data`].
///
/// Every field is optional and is left out when serializing if `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ImageData {
    /// Image URL, if the server returned one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url: Option<String>,
    /// Base64-encoded image data, if the server returned it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub b64_json: Option<String>,
    /// Rewritten prompt, if the server reported one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub revised_prompt: Option<String>,
}
/// Posts `request` to the `images/generations` endpoint and decodes the reply.
///
/// # Errors
///
/// Returns a [`ForgeError`] when `AsyncForgeClient::post` fails; its error is propagated
/// with `?`.
pub async fn generate_image(
    client: &AsyncForgeClient,
    request: ImageRequest,
) -> Result<ImageResponse, ForgeError> {
    let generated: ImageResponse = client.post("images/generations", &request).await?;
    Ok(generated)
}
/// Parameters for an image-edit request, sent as the request body by [`edit_image`].
///
/// NOTE(review): `image` and `mask` are plain strings serialized into the body (the tests use
/// base64-like placeholders). If the backend expects multipart file uploads for edits, this
/// shape will not match. Confirm against how `AsyncForgeClient::post` encodes bodies.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ImageEditRequest {
    /// Source image, as a string (encoding not established here — confirm).
    pub image: String,
    /// Text describing the desired edit.
    pub prompt: String,
    /// Optional mask image; omitted from the body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mask: Option<String>,
    /// Model name; falls back to `default_edit_model()` when absent on deserialize.
    #[serde(default = "default_edit_model")]
    pub model: String,
    /// Number of images requested; falls back to `default_n()`.
    #[serde(default = "default_n")]
    pub n: u32,
    /// Output size; falls back to `default_size()`.
    #[serde(default = "default_size")]
    pub size: ImageSize,
    /// Response encoding; falls back to `ResponseFormat::default()`.
    #[serde(default)]
    pub response_format: ResponseFormat,
    /// Optional caller-supplied user identifier; omitted from the body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
/// Serde/constructor default model for edit and variation requests.
fn default_edit_model() -> String {
    String::from("dall-e-2")
}
impl ImageEditRequest {
pub fn new(image: impl Into<String>, prompt: impl Into<String>) -> Self {
Self {
image: image.into(),
prompt: prompt.into(),
mask: None,
model: default_edit_model(),
n: default_n(),
size: default_size(),
response_format: ResponseFormat::default(),
user: None,
}
}
pub fn mask(mut self, mask: impl Into<String>) -> Self {
self.mask = Some(mask.into());
self
}
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn n(mut self, n: u32) -> Self {
self.n = n;
self
}
pub fn size(mut self, size: ImageSize) -> Self {
self.size = size;
self
}
}
/// Posts `request` to the `images/edits` endpoint and decodes the reply.
///
/// NOTE(review): the body goes through the same `client.post` used for JSON generation
/// requests. If the backend requires multipart/form-data for edits, this call will need a
/// different client method. Confirm.
///
/// # Errors
///
/// Returns a [`ForgeError`] when `AsyncForgeClient::post` fails; its error is propagated
/// with `?`.
pub async fn edit_image(
    client: &AsyncForgeClient,
    request: ImageEditRequest,
) -> Result<ImageResponse, ForgeError> {
    let edited: ImageResponse = client.post("images/edits", &request).await?;
    Ok(edited)
}
/// Parameters for an image-variation request, sent as the request body by
/// [`create_variations`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ImageVariationRequest {
    /// Source image, as a string (encoding not established here — confirm).
    pub image: String,
    /// Model name; falls back to `default_edit_model()` when absent on deserialize.
    #[serde(default = "default_edit_model")]
    pub model: String,
    /// Number of variations requested; falls back to `default_n()`.
    #[serde(default = "default_n")]
    pub n: u32,
    /// Output size; falls back to `default_size()`.
    #[serde(default = "default_size")]
    pub size: ImageSize,
    /// Response encoding; falls back to `ResponseFormat::default()`.
    #[serde(default)]
    pub response_format: ResponseFormat,
    /// Optional caller-supplied user identifier; omitted from the body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
impl ImageVariationRequest {
pub fn new(image: impl Into<String>) -> Self {
Self {
image: image.into(),
model: default_edit_model(),
n: default_n(),
size: default_size(),
response_format: ResponseFormat::default(),
user: None,
}
}
pub fn n(mut self, n: u32) -> Self {
self.n = n;
self
}
pub fn size(mut self, size: ImageSize) -> Self {
self.size = size;
self
}
}
/// Posts `request` to the `images/variations` endpoint and decodes the reply.
///
/// NOTE(review): like `edit_image`, this sends the body through `client.post`. Confirm the
/// backend accepts that encoding for variations.
///
/// # Errors
///
/// Returns a [`ForgeError`] when `AsyncForgeClient::post` fails; its error is propagated
/// with `?`.
pub async fn create_variations(
    client: &AsyncForgeClient,
    request: ImageVariationRequest,
) -> Result<ImageResponse, ForgeError> {
    let variations: ImageResponse = client.post("images/variations", &request).await?;
    Ok(variations)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_image_request_builder() {
        let request = ImageRequest::new("A beautiful sunset")
            .model("dall-e-3")
            .n(2)
            .size(ImageSize::Size1792x1024)
            .quality(ImageQuality::Hd)
            .style(ImageStyle::Vivid)
            .user("user123");
        assert_eq!(request.prompt, "A beautiful sunset");
        assert_eq!(request.model, "dall-e-3");
        assert_eq!(request.n, 2);
        assert_eq!(request.size, ImageSize::Size1792x1024);
        assert_eq!(request.quality, ImageQuality::Hd);
        assert_eq!(request.style, Some(ImageStyle::Vivid));
        assert_eq!(request.user, Some("user123".to_string()));
    }

    #[test]
    fn test_image_size_display() {
        assert_eq!(ImageSize::Size256x256.to_string(), "256x256");
        assert_eq!(ImageSize::Size512x512.to_string(), "512x512");
        assert_eq!(ImageSize::Size1024x1024.to_string(), "1024x1024");
        assert_eq!(ImageSize::Size1792x1024.to_string(), "1792x1024");
        assert_eq!(ImageSize::Size1024x1792.to_string(), "1024x1792");
    }

    #[test]
    fn test_image_response_helpers() {
        let response = ImageResponse {
            created: 1234567890,
            data: vec![
                ImageData {
                    url: Some("https://example.com/image1.png".to_string()),
                    b64_json: None,
                    revised_prompt: Some("A revised prompt".to_string()),
                },
                ImageData {
                    url: Some("https://example.com/image2.png".to_string()),
                    b64_json: None,
                    revised_prompt: None,
                },
            ],
        };
        assert_eq!(response.url(), Some("https://example.com/image1.png"));
        assert_eq!(
            response.urls(),
            vec![
                "https://example.com/image1.png",
                "https://example.com/image2.png"
            ]
        );
        assert_eq!(response.b64_json(), None);
        assert_eq!(response.revised_prompt(), Some("A revised prompt"));
    }

    #[test]
    fn test_image_response_b64_helper() {
        let response = ImageResponse {
            created: 0,
            data: vec![ImageData {
                url: None,
                b64_json: Some("aGVsbG8=".to_string()),
                revised_prompt: None,
            }],
        };
        assert_eq!(response.b64_json(), Some("aGVsbG8="));
        assert_eq!(response.url(), None);
        assert!(response.urls().is_empty());
    }

    #[test]
    fn test_image_response_empty_data() {
        let response = ImageResponse {
            created: 0,
            data: Vec::new(),
        };
        assert_eq!(response.url(), None);
        assert!(response.urls().is_empty());
        assert_eq!(response.b64_json(), None);
        assert_eq!(response.revised_prompt(), None);
    }

    #[test]
    fn test_image_edit_request() {
        let request = ImageEditRequest::new("base64data", "Make it blue")
            .mask("maskdata")
            .n(2)
            .size(ImageSize::Size512x512);
        assert_eq!(request.image, "base64data");
        assert_eq!(request.prompt, "Make it blue");
        assert_eq!(request.mask, Some("maskdata".to_string()));
        assert_eq!(request.n, 2);
        // The original test set the size but never checked it.
        assert_eq!(request.size, ImageSize::Size512x512);
    }

    #[test]
    fn test_image_edit_defaults() {
        let request = ImageEditRequest::new("base64data", "prompt");
        assert_eq!(request.model, "dall-e-2");
        assert_eq!(request.n, 1);
        assert_eq!(request.size, ImageSize::Size1024x1024);
        assert_eq!(request.response_format, ResponseFormat::Url);
        assert_eq!(request.mask, None);
        assert_eq!(request.user, None);
    }

    #[test]
    fn test_image_variation_request() {
        let request = ImageVariationRequest::new("base64data")
            .n(3)
            .size(ImageSize::Size256x256);
        assert_eq!(request.image, "base64data");
        assert_eq!(request.n, 3);
        assert_eq!(request.size, ImageSize::Size256x256);
    }

    #[test]
    fn test_image_variation_defaults() {
        let request = ImageVariationRequest::new("base64data");
        assert_eq!(request.model, "dall-e-2");
        assert_eq!(request.n, 1);
        assert_eq!(request.size, ImageSize::Size1024x1024);
        assert_eq!(request.response_format, ResponseFormat::Url);
        assert_eq!(request.user, None);
    }

    #[test]
    fn test_defaults() {
        let request = ImageRequest::new("test");
        assert_eq!(request.model, "dall-e-3");
        assert_eq!(request.n, 1);
        assert_eq!(request.size, ImageSize::Size1024x1024);
        assert_eq!(request.response_format, ResponseFormat::Url);
        assert_eq!(request.quality, ImageQuality::Standard);
        assert_eq!(request.style, None);
        assert_eq!(request.user, None);
    }
}