use derive_builder::Builder;
use reqwest::{
get,
multipart::{Form, Part},
};
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use std::{
borrow::Cow,
fs,
io::{copy, Cursor},
path::Path,
str::FromStr,
};
use crate::{api_resources::TokenUsage, Client, Result};
#[derive(Clone, Debug, Default, Deserialize, PartialEq)]
pub enum ImageSize {
S256x256,
S512x512,
#[default]
S1024x1024,
}
impl std::fmt::Display for ImageSize {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ImageSize::S256x256 => write!(f, "256x256"),
ImageSize::S512x512 => write!(f, "512x512"),
ImageSize::S1024x1024 => write!(f, "1024x1024"),
}
}
}
impl FromStr for ImageSize {
type Err = String;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"256x256" => Ok(ImageSize::S256x256),
"512x512" => Ok(ImageSize::S512x512),
"1024x1024" => Ok(ImageSize::S1024x1024),
_ => Err(format!("Invalid ImageSize: {}", s)),
}
}
}
impl Serialize for ImageSize {
    /// Serializes the size as a string in its `Display` form,
    /// e.g. `"1024x1024"`.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text = self.to_string();
        serializer.serialize_str(text.as_str())
    }
}
/// Request body for image generation; `Client::generate_image` posts it to
/// `images/generations`.
///
/// Fields that are `None` are omitted from the serialized request
/// (`#[skip_serializing_none]`). A `GenerateImageParamBuilder` is derived;
/// because of `#[builder(default)]`, fields left unset on the builder take
/// default values when built.
#[skip_serializing_none]
#[derive(Builder, Debug, Default, Deserialize, Serialize)]
#[builder(default, setter(into, strip_option))]
pub struct GenerateImageParam {
/// Text prompt describing the image(s) to generate.
prompt: String,
/// Number of images to request; omitted from the request when `None`.
n: Option<u8>,
/// Output size; omitted from the request when `None`.
size: Option<ImageSize>,
/// Caller-supplied user identifier; omitted from the request when `None`.
user: Option<String>,
}
impl GenerateImageParamBuilder {
    /// Creates a builder with `prompt` already set and every other field
    /// left unset.
    pub fn new(prompt: impl Into<String>) -> Self {
        let prompt = Some(prompt.into());
        Self {
            prompt,
            ..Self::default()
        }
    }
}
/// Response returned by the image generation, edit and variation calls.
#[derive(Debug, Deserialize, Serialize)]
pub struct Image {
/// Creation time reported by the server; presumably a Unix timestamp in
/// seconds (the test fixture uses 1589478378) — TODO confirm.
pub created: Option<u64>,
/// Links to the produced images.
pub data: Option<Links>,
/// Token usage reported by the server, when present.
pub token_usage: Option<TokenUsage>,
}
impl Image {
    /// Downloads every image linked in `data` and writes it into the
    /// directory `path`.
    ///
    /// Each file is named after the last path segment of the final
    /// (post-redirect) response URL. If that segment is empty (URL ending
    /// in `/`) or is `.`/`..`, the name `image_{i}.png` is used instead,
    /// where `i` is the link's index. An existing file with the same name
    /// is overwritten. Nothing is downloaded when `data` is `None`.
    ///
    /// # Errors
    ///
    /// Returns an error if a download fails, the server answers with a
    /// non-success HTTP status, or a file cannot be created or written.
    pub async fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let dir = path.as_ref();
        if let Some(links) = &self.data {
            for (i, link) in links.iter().enumerate() {
                // Fail instead of writing an HTTP error body to disk as an image.
                let resp = get(&link.url).await?.error_for_status()?;
                let fallback = format!("image_{i}.png");
                let fname = resp
                    .url()
                    .path_segments()
                    .and_then(|mut segments| segments.next_back())
                    // An empty or dot name would make the target `dir` itself
                    // (or its parent) rather than a file inside it.
                    .filter(|name| !name.is_empty() && *name != "." && *name != "..")
                    .unwrap_or(fallback.as_str());
                let full_path = dir.join(fname);
                // NOTE(review): std::fs blocks the async executor thread while
                // writing; consider an async file API if many images are saved.
                let mut file = fs::File::create(full_path)?;
                let mut content = Cursor::new(resp.bytes().await?);
                copy(&mut content, &mut file)?;
            }
        }
        Ok(())
    }
}
/// A single produced image, referenced by URL.
#[derive(Debug, Deserialize, Serialize)]
pub struct Link {
/// Location of the image; `Image::save` downloads from this URL.
pub url: String,
}
/// List of image links carried in `Image::data`.
type Links = Vec<Link>;
/// Parameters for an image edit request (`images/edits`).
///
/// Defaults come from the `Default` impl for this type (one 1024x1024
/// image, empty prompt and user). A builder (`EditImageParamBuilder`) is
/// derived. Fields left unset on it take default values when built.
/// NOTE(review): no field is an `Option`, so `#[skip_serializing_none]` has
/// no effect here.
#[skip_serializing_none]
#[derive(Builder, Debug, Deserialize, Serialize)]
#[builder(default, setter(into, strip_option))]
pub struct EditImageParam {
/// Text prompt describing the desired edit.
prompt: String,
/// Number of images to request.
n: u8,
/// Output size.
size: ImageSize,
/// Caller-supplied user identifier; empty by default.
user: String,
}
impl Default for EditImageParam {
fn default() -> Self {
Self {
prompt: String::new(),
n: 1,
size: ImageSize::S1024x1024,
user: String::new(),
}
}
}
impl EditImageParamBuilder {
    /// Creates a builder with `prompt` already set and every other field
    /// left unset.
    pub fn new(prompt: impl Into<String>) -> Self {
        let prompt = Some(prompt.into());
        Self {
            prompt,
            ..Self::default()
        }
    }
}
/// Parameters for an image variation request (`images/variations`).
///
/// Defaults come from the `Default` impl for this type (one 1024x1024
/// image, empty user). A builder (`VariateImageParamBuilder`) is derived.
/// Fields left unset on it take default values when built.
/// NOTE(review): no field is an `Option`, so `#[skip_serializing_none]` has
/// no effect here.
#[skip_serializing_none]
#[derive(Builder, Debug, Deserialize, Serialize)]
#[builder(default, setter(into, strip_option))]
pub struct VariateImageParam {
/// Number of variations to request.
n: u8,
/// Output size.
size: ImageSize,
/// Caller-supplied user identifier; empty by default.
user: String,
}
impl Default for VariateImageParam {
fn default() -> Self {
Self {
n: 1,
size: ImageSize::S1024x1024,
user: String::new(),
}
}
}
impl VariateImageParamBuilder {
    /// Creates an empty builder; every field is unset until a setter is called.
    pub fn new() -> Self {
        Default::default()
    }
}
pub async fn generate(client: &Client, param: &GenerateImageParam) -> Result<Image> {
client.generate_image(param).await
}
pub async fn edit<P>(client: &Client, image: P, param: &EditImageParam) -> Result<Image>
where
P: AsRef<Path> + Into<Cow<'static, str>> + Copy,
{
client.edit_image(image, param).await
}
pub async fn variate<P>(client: &Client, image: P, param: &VariateImageParam) -> Result<Image>
where
P: AsRef<Path> + Into<Cow<'static, str>> + Copy,
{
client.variate_image(image, param).await
}
impl Client {
async fn generate_image(&self, param: &GenerateImageParam) -> Result<Image> {
self.post::<GenerateImageParam, Image>("images/generations", Some(param))
.await
}
async fn edit_image<P>(&self, image: P, param: &EditImageParam) -> Result<Image>
where
P: AsRef<Path> + Into<Cow<'static, str>> + Copy,
{
let data = fs::read(image)?;
let part = Part::bytes(data).file_name(image);
let form = Form::new()
.part("image", part)
.text("prompt", "22")
.text("n", param.n.to_string())
.text("size", param.size.to_string())
.text("user", param.user.to_string());
self.post_data::<Image>("images/edits", form).await
}
async fn variate_image<P>(&self, image: P, param: &VariateImageParam) -> Result<Image>
where
P: AsRef<Path> + Into<Cow<'static, str>> + Copy,
{
let data = fs::read(image)?;
let part = Part::bytes(data).file_name(image);
let form = Form::new()
.part("image", part)
.text("n", param.n.to_string())
.text("size", param.size.to_string())
.text("user", param.user.to_string());
self.post_data::<Image>("images/variations", form).await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A generation request and a generation response both deserialize
    /// into the expected values.
    #[test]
    fn test_parse_image_response() {
        let request_json = r#"{
"prompt": "A cute baby sea otter",
"size": "S256x256",
"n": 1
}"#;
        let response_json = r#"
{
"created": 1589478378,
"data": [
{
"url": "https://..."
},
{
"url": "https://..."
}
]
}
"#;

        let request: GenerateImageParam = serde_json::from_str(request_json).unwrap();
        assert_eq!(request.prompt, "A cute baby sea otter");
        assert_eq!(request.size, Some(ImageSize::S256x256));
        assert!(request.user.is_none());

        let response: Image = serde_json::from_str(response_json).unwrap();
        let links = response.data.unwrap();
        assert_eq!(links.len(), 2);
    }
}