use core::fmt;
use reqwest::{Client as ReqwestClient, Method, RequestBuilder, Url};
use serde::{Deserialize, Serialize};
use crate::error::Error;
/// Client for the image-generation endpoint.
///
/// Pairs the base URL that endpoint paths are resolved against with the
/// `reqwest` client used to send the requests. Both fields are cloneable, so
/// the client itself can be cloned and handed to several tasks.
#[derive(Debug, Clone)]
pub struct ImageClient {
    /// Base URL that endpoint paths are joined onto.
    base_url: Url,
    /// HTTP client used to build and send requests.
    http_client: ReqwestClient,
}
impl ImageClient {
    /// Creates a client that resolves endpoint paths against `base_url` and
    /// sends requests through `http_client`.
    pub fn new(base_url: Url, http_client: ReqwestClient) -> Self {
        Self {
            base_url,
            http_client,
        }
    }

    /// Sends `payload` as a JSON `POST` to `images/generations` and decodes
    /// the JSON response body.
    ///
    /// # Errors
    ///
    /// Returns an error when the endpoint URL cannot be built, the request
    /// cannot be sent, the server answers with a 4xx/5xx status, or the
    /// response body is not a valid [`CreateImageResponse`].
    pub async fn create_image(&self, payload: CreateImage) -> Result<CreateImageResponse, Error> {
        let response = self
            .request(Method::POST, "images/generations")?
            .json(&payload)
            .send()
            .await?;

        // Reject error statuses before decoding: an error body does not have the
        // shape of `CreateImageResponse`, so without this check an HTTP failure
        // would be reported as a confusing JSON decoding error.
        let response = response.error_for_status()?;

        let image = response.json::<CreateImageResponse>().await?;
        Ok(image)
    }

    /// Starts a request with the given `method` for `path` resolved against
    /// the base URL.
    ///
    /// `Url::join` resolves `path` like a relative link: if the base URL does
    /// not end in `/`, its last path segment is replaced rather than extended.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UrlParse`] carrying the parser's message when `path`
    /// cannot be joined onto the base URL.
    fn request(&self, method: Method, path: &str) -> Result<RequestBuilder, Error> {
        let url = self
            .base_url
            .join(path)
            .map_err(|err| Error::UrlParse(err.to_string()))?;
        Ok(self.http_client.request(method, url))
    }
}
/// Request payload for the image-generation endpoint.
///
/// Every field except `user` is always written to the serialized payload.
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateImage {
/// Prompt text sent to the endpoint.
pub prompt: String,
/// Model that should handle the request.
pub model: ImageModel,
/// Value of the `n` request field; `CreateImage::new` starts it at 1.
/// NOTE(review): presumably the number of images to generate — confirm against the API docs.
pub n: u8,
/// Requested quality.
pub quality: ImageQuality,
/// Requested response format.
pub response_format: ImageResponseFormat,
/// Requested image dimensions.
pub size: ImageSize,
/// Requested style.
pub style: ImageStyle,
/// Optional `user` value; left out of the serialized payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}
impl CreateImage {
pub fn new(prompt: impl Into<String>, model: ImageModel) -> Self {
Self {
prompt: prompt.into(),
model,
n: 1,
quality: ImageQuality::default(),
response_format: ImageResponseFormat::default(),
size: ImageSize::default(),
style: ImageStyle::default(),
user: None,
}
}
pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
self.prompt = prompt.into();
self
}
pub fn with_model(mut self, model: ImageModel) -> Self {
self.model = model;
self
}
pub fn with_n(mut self, n: u8) -> Self {
self.n = if n > 10 { 10 } else { n };
self
}
pub fn with_quality(mut self, quality: ImageQuality) -> Self {
self.quality = quality;
self
}
pub fn with_response_format(mut self, response_format: ImageResponseFormat) -> Self {
self.response_format = response_format;
self
}
pub fn with_style(mut self, style: ImageStyle) -> Self {
self.style = style;
self
}
pub fn with_user(mut self, user: impl Into<String>) -> Self {
self.user = Some(user.into());
self
}
}
/// Model used to generate images.
///
/// Each variant (de)serializes as the model identifier shown on it. The type
/// is a plain tag, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImageModel {
    /// Serialized as `"dall-e-2"`.
    #[serde(rename = "dall-e-2")]
    DallE2,
    /// Serialized as `"dall-e-3"`.
    #[serde(rename = "dall-e-3")]
    DallE3,
}
impl fmt::Display for ImageModel {
    /// Writes the model identifier: `dall-e-2` or `dall-e-3`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let identifier = match self {
            Self::DallE2 => "dall-e-2",
            Self::DallE3 => "dall-e-3",
        };
        f.write_str(identifier)
    }
}
/// Quality setting for generated images.
///
/// Variants (de)serialize as their lowercase names. The type is a plain tag,
/// so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ImageQuality {
    /// Serialized as `"hd"`.
    HD,
    /// Serialized as `"standard"`.
    Standard,
}
impl Default for ImageQuality {
fn default() -> Self {
Self::Standard
}
}
/// Response body decoded by `ImageClient::create_image`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateImageResponse {
/// Value of the `created` response field.
/// NOTE(review): presumably a Unix timestamp in seconds — confirm against the API docs.
pub created: u64,
/// Entries of the `data` response field, one [`ImageData`] each.
pub data: Vec<ImageData>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageData {
pub url: String,
}
/// Format in which generated images are returned.
///
/// The type is a plain tag, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ImageResponseFormat {
    /// Serialized as `"url"`.
    Url,
    /// Serialized as `"b64_json"`, the spelling the API expects.
    ///
    /// The explicit rename is required: `rename_all = "snake_case"` alone
    /// turns `Base64Json` into `"base64_json"`. That older spelling is still
    /// accepted when deserializing so previously stored payloads keep loading.
    #[serde(rename = "b64_json", alias = "base64_json")]
    Base64Json,
}
impl Default for ImageResponseFormat {
fn default() -> Self {
Self::Url
}
}
/// Dimensions of the generated image.
///
/// Each variant (de)serializes as the `WIDTHxHEIGHT` string shown on it. The
/// type is a plain tag, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImageSize {
    /// Serialized as `"256x256"`.
    #[serde(rename = "256x256")]
    S256x256,
    /// Serialized as `"512x512"`.
    #[serde(rename = "512x512")]
    S512x512,
    /// Serialized as `"1024x1024"`.
    #[serde(rename = "1024x1024")]
    S1024x1024,
    /// Serialized as `"1792x1024"`.
    #[serde(rename = "1792x1024")]
    S1792x1024,
    /// Serialized as `"1024x1792"`.
    #[serde(rename = "1024x1792")]
    S1024x1792,
}
impl Default for ImageSize {
fn default() -> Self {
Self::S1024x1024
}
}
impl fmt::Display for ImageSize {
    /// Writes the size as `WIDTHxHEIGHT`, e.g. `1024x1024`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let dimensions = match self {
            Self::S256x256 => "256x256",
            Self::S512x512 => "512x512",
            Self::S1024x1024 => "1024x1024",
            Self::S1792x1024 => "1792x1024",
            Self::S1024x1792 => "1024x1792",
        };
        f.write_str(dimensions)
    }
}
/// Style setting for generated images.
///
/// Variants (de)serialize as their lowercase names. The type is a plain tag,
/// so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ImageStyle {
    /// Serialized as `"vivid"`.
    Vivid,
    /// Serialized as `"natural"`.
    Natural,
}
impl Default for ImageStyle {
fn default() -> Self {
Self::Vivid
}
}