// ai/image.rs

use std::fs;
use std::path::{Path, PathBuf};

use clap::{Args, ValueEnum};
use derive_more::{From, TryInto};
use reqwest::Client;
use rustc_serialize::base64::FromBase64;
use serde::{Deserialize, Serialize};
use serde_json::json;

use crate::openai::OpenAIError;
10
11#[derive(Clone, Debug, Args)]
12pub struct ImageCommand {
13    /// Description of the image
14    #[arg(long, short)]
15    pub prompt: String,
16
17    /// Number of images generated
18    #[arg(long, short, default_value_t = ImageCommand::default().count)]
19    pub count: usize,
20
21    /// Generated image size
22    #[arg(value_enum, long, short, default_value_t = PictureSize::default())]
23    pub size: PictureSize,
24
25    /// Format of the response
26    #[arg(value_enum, long, short, default_value_t = PictureFormat::default())]
27    pub format: PictureFormat,
28
29    /// Directory to output files
30    #[arg(value_enum, long, short)]
31    pub out: Option<PathBuf>,
32}
33
34impl Default for ImageCommand {
35    fn default() -> Self {
36        Self {
37            prompt: String::new(),
38            count: 1,
39            size: PictureSize::default(),
40            format: PictureFormat::default(),
41            out: None
42        }
43    }
44}
45
46pub type ImageResult = Result<Vec<ImageData>, ImageError>;
47
48#[derive(Debug, From)]
49pub enum ImageError {
50    OpenAIError(OpenAIError),
51    DeserializeError(reqwest::Error)
52}
53
54impl ImageCommand {
55    pub async fn run(&self, client: &Client) -> ImageResult {
56        let request = client.post("https://api.openai.com/v1/images/generations")
57            .json(&json!({
58                "prompt": &self.prompt,
59                "n": self.count,
60                "size": match self.size {
61                    PictureSize::x256 => "256x256",
62                    PictureSize::x512 => "512x512",
63                    PictureSize::x1024 => "1024x1024",
64                },
65                "response_format": match &self.out {
66                    Some(_) => "b64_json",
67                    None => match self.format {
68                        PictureFormat::Url => "url",
69                        PictureFormat::Binary => "b64_json"
70                    }
71                }
72            }))
73            .send()
74            .await
75            .expect("Failed to send completion");
76
77        if !request.status().is_success() {
78            return Err(ImageError::OpenAIError(request.json().await?));
79        }
80
81        let response: OpenAIImageResponse = request.json().await?;
82
83        if let Some(out) = &self.out {
84            write_data_to_directory(out, &response);
85        }
86
87        Ok(response.data)
88    }
89}
90
91fn write_data_to_directory(out: &PathBuf, response: &OpenAIImageResponse) {
92    fs::create_dir_all(&out)
93        .expect(r#"Image "out" directory could not be created"#);
94
95    for (i, data) in response.data.iter().enumerate() {
96        match data {
97            ImageData::Url(_) => unreachable!(
98                "Response data should be in binary format"),
99
100            ImageData::Binary(data) => {
101                let content = data.b64_json.from_base64().unwrap();
102                let mut path = out.clone();
103                path.push(format!("{}.png", i));
104
105                fs::write(path, content).unwrap();
106            }
107        }
108    }
109}
110
111#[derive(Deserialize, Debug)]
112pub struct OpenAIImageResponse {
113    pub created: usize,
114    pub data: Vec<ImageData>
115}
116
117#[derive(Clone, From, TryInto, Serialize, Deserialize, Debug)]
118#[serde(untagged)]
119#[try_into(owned, ref, ref_mut)]
120pub enum ImageData {
121    Url(ImageUrl),
122    Binary(ImageBinary),
123}
124
125#[derive(Clone, Default, Serialize, Deserialize, Debug)]
126pub struct ImageUrl {
127    pub url: String
128}
129
130#[derive(Clone, Serialize, Deserialize, Debug)]
131pub struct ImageBinary {
132    pub b64_json: String
133}
134
135#[derive(Default, Copy, Clone, Serialize, Deserialize, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
136#[allow(non_camel_case_types)]
137pub enum PictureSize {
138    x256,
139    #[default]
140    x512,
141    x1024
142}
143
144#[derive(Default, Copy, Clone, Serialize, Deserialize, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
145pub enum PictureFormat {
146    #[default]
147    Url,
148    Binary
149}