1use std::fs;
2use std::path::PathBuf;
3use clap::{Args,ValueEnum};
4use reqwest::Client;
5use serde::{Deserialize,Serialize};
6use serde_json::json;
7use rustc_serialize::base64::FromBase64;
8use derive_more::{From,TryInto};
9use crate::openai::OpenAIError;
10
11#[derive(Clone, Debug, Args)]
12pub struct ImageCommand {
13 #[arg(long, short)]
15 pub prompt: String,
16
17 #[arg(long, short, default_value_t = ImageCommand::default().count)]
19 pub count: usize,
20
21 #[arg(value_enum, long, short, default_value_t = PictureSize::default())]
23 pub size: PictureSize,
24
25 #[arg(value_enum, long, short, default_value_t = PictureFormat::default())]
27 pub format: PictureFormat,
28
29 #[arg(value_enum, long, short)]
31 pub out: Option<PathBuf>,
32}
33
34impl Default for ImageCommand {
35 fn default() -> Self {
36 Self {
37 prompt: String::new(),
38 count: 1,
39 size: PictureSize::default(),
40 format: PictureFormat::default(),
41 out: None
42 }
43 }
44}
45
46pub type ImageResult = Result<Vec<ImageData>, ImageError>;
47
48#[derive(Debug, From)]
49pub enum ImageError {
50 OpenAIError(OpenAIError),
51 DeserializeError(reqwest::Error)
52}
53
54impl ImageCommand {
55 pub async fn run(&self, client: &Client) -> ImageResult {
56 let request = client.post("https://api.openai.com/v1/images/generations")
57 .json(&json!({
58 "prompt": &self.prompt,
59 "n": self.count,
60 "size": match self.size {
61 PictureSize::x256 => "256x256",
62 PictureSize::x512 => "512x512",
63 PictureSize::x1024 => "1024x1024",
64 },
65 "response_format": match &self.out {
66 Some(_) => "b64_json",
67 None => match self.format {
68 PictureFormat::Url => "url",
69 PictureFormat::Binary => "b64_json"
70 }
71 }
72 }))
73 .send()
74 .await
75 .expect("Failed to send completion");
76
77 if !request.status().is_success() {
78 return Err(ImageError::OpenAIError(request.json().await?));
79 }
80
81 let response: OpenAIImageResponse = request.json().await?;
82
83 if let Some(out) = &self.out {
84 write_data_to_directory(out, &response);
85 }
86
87 Ok(response.data)
88 }
89}
90
91fn write_data_to_directory(out: &PathBuf, response: &OpenAIImageResponse) {
92 fs::create_dir_all(&out)
93 .expect(r#"Image "out" directory could not be created"#);
94
95 for (i, data) in response.data.iter().enumerate() {
96 match data {
97 ImageData::Url(_) => unreachable!(
98 "Response data should be in binary format"),
99
100 ImageData::Binary(data) => {
101 let content = data.b64_json.from_base64().unwrap();
102 let mut path = out.clone();
103 path.push(format!("{}.png", i));
104
105 fs::write(path, content).unwrap();
106 }
107 }
108 }
109}
110
111#[derive(Deserialize, Debug)]
112pub struct OpenAIImageResponse {
113 pub created: usize,
114 pub data: Vec<ImageData>
115}
116
/// A single generated image, in whichever representation the API returned.
///
/// `#[serde(untagged)]` makes serde try the variants in declaration order:
/// an object with a `url` field deserializes as `Url`; failing that, one with
/// a `b64_json` field deserializes as `Binary`. Keep `Url` first so that
/// order is preserved.
///
/// `From` is derived for both payload types, and `TryInto` conversions are
/// generated for owned values as well as shared and mutable references.
#[derive(Clone, From, TryInto, Serialize, Deserialize, Debug)]
#[serde(untagged)]
#[try_into(owned, ref, ref_mut)]
pub enum ImageData {
    /// The image is referenced by URL; see [`ImageUrl`].
    Url(ImageUrl),
    /// The image is embedded as base64; see [`ImageBinary`].
    Binary(ImageBinary),
}
124
125#[derive(Clone, Default, Serialize, Deserialize, Debug)]
126pub struct ImageUrl {
127 pub url: String
128}
129
130#[derive(Clone, Serialize, Deserialize, Debug)]
131pub struct ImageBinary {
132 pub b64_json: String
133}
134
135#[derive(Default, Copy, Clone, Serialize, Deserialize, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
136#[allow(non_camel_case_types)]
137pub enum PictureSize {
138 x256,
139 #[default]
140 x512,
141 x1024
142}
143
144#[derive(Default, Copy, Clone, Serialize, Deserialize, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
145pub enum PictureFormat {
146 #[default]
147 Url,
148 Binary
149}