use crate::IntoRequest;
use derive_builder::Builder;
use reqwest::{Client, RequestBuilder};
use serde::{Deserialize, Serialize};
/// JSON body sent to the `images/generations` endpoint.
///
/// `prompt` is the only field without a default. Every `Option` field is left
/// out of the serialized JSON while it is `None`. Build one with
/// [`CreateImageRequest::new`] or the generated [`CreateImageRequestBuilder`].
#[derive(Builder, Clone, Debug, Serialize)]
#[builder(pattern = "mutable")]
pub struct CreateImageRequest {
    /// Text prompt describing the image to generate.
    #[builder(setter(into))]
    prompt: String,
    /// Model to use; falls back to [`ImageModel::default`].
    #[builder(default)]
    model: ImageModel,
    /// Number of images to request (`n` in the JSON body).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(setter(strip_option), default)]
    n: Option<usize>,
    /// Requested image quality.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(setter(strip_option), default)]
    quality: Option<ImageQuality>,
    /// How the generated image should be returned (URL or base64 payload).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(setter(strip_option), default)]
    response_format: Option<ImageResponseFormat>,
    /// Output dimensions.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(setter(strip_option), default)]
    size: Option<ImageSize>,
    /// Rendering style.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(setter(strip_option), default)]
    style: Option<ImageStyle>,
    /// Optional end-user identifier forwarded with the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(setter(into, strip_option), default)]
    user: Option<String>,
}
/// Image generation model, serialized as the model's API identifier.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize)]
pub enum ImageModel {
    /// Serialized as `"dall-e-3"`; this is the default model.
    #[default]
    #[serde(rename = "dall-e-3")]
    DallE3,
}
/// Requested image quality level.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize)]
pub enum ImageQuality {
    /// Serialized as `"standard"`; the default.
    #[default]
    #[serde(rename = "standard")]
    Standard,
    /// Serialized as `"hd"`.
    #[serde(rename = "hd")]
    Hd,
}
/// How the generated image is delivered in the response.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize)]
pub enum ImageResponseFormat {
    /// Serialized as `"url"`; the default.
    #[default]
    #[serde(rename = "url")]
    Url,
    /// Serialized as `"b64_json"`.
    #[serde(rename = "b64_json")]
    B64Json,
}
/// Output image dimensions, serialized as `"<width>x<height>"` strings.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize)]
pub enum ImageSize {
    /// Square `"1024x1024"`; the default.
    #[default]
    #[serde(rename = "1024x1024")]
    Large,
    /// Landscape `"1792x1024"`.
    #[serde(rename = "1792x1024")]
    LargeWide,
    /// Portrait `"1024x1792"`.
    #[serde(rename = "1024x1792")]
    LargeTall,
}
/// Visual style requested for the generated image.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize)]
pub enum ImageStyle {
    /// Serialized as `"vivid"`; the default.
    #[default]
    #[serde(rename = "vivid")]
    Vivid,
    /// Serialized as `"natural"`.
    #[serde(rename = "natural")]
    Natural,
}
/// Response body returned by the `images/generations` endpoint.
#[derive(Clone, Debug, Deserialize)]
pub struct CreateImageResponse {
    /// `created` value reported by the API (presumably a Unix timestamp in seconds — confirm).
    pub created: u64,
    /// The generated images.
    pub data: Vec<ImageObject>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ImageObject {
pub b64_json: Option<String>,
pub url: Option<String>,
pub revised_prompt: String,
}
impl IntoRequest for CreateImageRequest {
    /// Builds a `POST {base_url}/images/generations` request with `self` as the JSON body.
    fn into_request(self, base_url: &str, client: Client) -> RequestBuilder {
        client
            .post(format!("{base_url}/images/generations"))
            .json(&self)
    }
}
impl CreateImageRequest {
    /// Creates a request for `prompt`, leaving every other field at its builder default.
    ///
    /// # Panics
    ///
    /// Does not panic in practice. `prompt` is the only field without a
    /// `#[builder(default)]` and it is always set here, so `build()` has no
    /// missing field to report.
    pub fn new(prompt: impl Into<String>) -> Self {
        CreateImageRequestBuilder::default()
            .prompt(prompt)
            .build()
            .expect("prompt is set and every other field has a builder default")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SDK;
    use anyhow::Result;
    use serde_json::json;
    use std::fs;
    use std::path::Path;

    /// A request built with only a prompt serializes just `prompt` and the default `model`.
    #[test]
    fn create_image_request_should_serialize() -> Result<()> {
        let req = CreateImageRequest::new("draw a cute caterpillar");
        assert_eq!(
            serde_json::to_value(req)?,
            json!({
                "prompt": "draw a cute caterpillar",
                "model": "dall-e-3",
            })
        );
        Ok(())
    }

    /// Optional fields set through the builder appear with their wire names.
    #[test]
    fn create_image_request_custom_should_serialize() -> Result<()> {
        let req = CreateImageRequestBuilder::default()
            .prompt("draw a cute caterpillar")
            .style(ImageStyle::Natural)
            .quality(ImageQuality::Hd)
            .build()?;
        assert_eq!(
            serde_json::to_value(req)?,
            json!({
                "prompt": "draw a cute caterpillar",
                "model": "dall-e-3",
                "style": "natural",
                "quality": "hd",
            })
        );
        Ok(())
    }

    /// Live call against the API. Downloads the generated image into `/tmp/llm-sdk`.
    #[tokio::test]
    #[ignore]
    async fn create_image_should_work() -> Result<()> {
        let req = CreateImageRequest::new("draw a cute caterpillar");
        let res = SDK.create_image(req).await?;
        assert_eq!(res.data.len(), 1);
        let image = &res.data[0];
        assert!(image.b64_json.is_none());
        let url = image
            .url
            .as_deref()
            .expect("default response format should return an image URL");
        println!("image: {:?}", image);
        let bytes = reqwest::get(url).await?.bytes().await?;
        // `fs::write` does not create parent directories, so make sure the
        // output directory exists before writing into it.
        let out_dir = Path::new("/tmp/llm-sdk");
        fs::create_dir_all(out_dir)?;
        fs::write(out_dir.join("caterpillar.png"), bytes)?;
        Ok(())
    }
}