llm_sdk/api/
create_image.rs

1use crate::IntoRequest;
2use derive_builder::Builder;
3use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Builder)]
7#[builder(pattern = "mutable")]
8pub struct CreateImageRequest {
9    /// A text description of the desired image(s). The maximum length is 4000 characters for dall-e-3.
10    #[builder(setter(into))]
11    prompt: String,
12    /// The model to use for image generation. Only support Dall-e-3
13    #[builder(default)]
14    model: ImageModel,
15    /// The number of images to generate. Must be between 1 and 10. For dall-e-3, only n=1 is supported.
16    #[builder(default, setter(strip_option))]
17    #[serde(skip_serializing_if = "Option::is_none")]
18    n: Option<usize>,
19    /// The quality of the image that will be generated. hd creates images with finer details and greater consistency across the image. This param is only supported for dall-e-3.
20    #[builder(default, setter(strip_option))]
21    #[serde(skip_serializing_if = "Option::is_none")]
22    quality: Option<ImageQuality>,
23    /// The format in which the generated images are returned. Must be one of url or b64_json.
24    #[builder(default, setter(strip_option))]
25    #[serde(skip_serializing_if = "Option::is_none")]
26    response_format: Option<ImageResponseFormat>,
27    /// The size of the generated images. Must be one of 1024x1024, 1792x1024, or 1024x1792 for dall-e-3 models.
28    #[builder(default, setter(strip_option))]
29    #[serde(skip_serializing_if = "Option::is_none")]
30    size: Option<ImageSize>,
31    /// The style of the generated images. Must be one of vivid or natural. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for dall-e-3.
32    #[builder(default, setter(strip_option))]
33    #[serde(skip_serializing_if = "Option::is_none")]
34    style: Option<ImageStyle>,
35    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
36    #[builder(default, setter(strip_option, into))]
37    #[serde(skip_serializing_if = "Option::is_none")]
38    user: Option<String>,
39}
40
41#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
42pub enum ImageModel {
43    #[serde(rename = "dall-e-3")]
44    #[default]
45    DallE3,
46}
47
48#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
49#[serde(rename_all = "snake_case")]
50pub enum ImageQuality {
51    #[default]
52    Standard,
53    Hd,
54}
55
56#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
57#[serde(rename_all = "snake_case")]
58pub enum ImageResponseFormat {
59    #[default]
60    Url,
61    B64Json,
62}
63
64#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
65pub enum ImageSize {
66    #[serde(rename = "1024x1024")]
67    #[default]
68    Large,
69    #[serde(rename = "1792x1024")]
70    LargeWide,
71    #[serde(rename = "1024x1792")]
72    LargeTall,
73}
74
75#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
76#[serde(rename_all = "snake_case")]
77pub enum ImageStyle {
78    #[default]
79    Vivid,
80    Natural,
81}
82
83#[derive(Debug, Clone, Deserialize)]
84pub struct CreateImageResponse {
85    pub created: u64,
86    pub data: Vec<ImageObject>,
87}
88
89#[derive(Debug, Clone, Deserialize)]
90pub struct ImageObject {
91    /// The base64-encoded JSON of the generated image, if response_format is b64_json
92    pub b64_json: Option<String>,
93    /// The URL of the generated image, if response_format is url (default).
94    pub url: Option<String>,
95    /// The prompt that was used to generate the image, if there was any revision to the prompt.
96    pub revised_prompt: String,
97}
98
99impl IntoRequest for CreateImageRequest {
100    fn into_request(self, base_url: &str, client: ClientWithMiddleware) -> RequestBuilder {
101        let url = format!("{}/images/generations", base_url);
102        client.post(url).json(&self)
103    }
104}
105
106impl CreateImageRequest {
107    pub fn new(prompt: impl Into<String>) -> Self {
108        CreateImageRequestBuilder::default()
109            .prompt(prompt)
110            .build()
111            .unwrap()
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SDK;
    use anyhow::Result;
    use serde_json::json;

    /// Prompt shared by every test in this module.
    const PROMPT: &str = "draw a cute caterpillar";

    #[test]
    fn create_image_request_should_serialize() -> Result<()> {
        let request = CreateImageRequest::new(PROMPT);
        let expected = json!({
            "prompt": PROMPT,
            "model": "dall-e-3",
        });
        assert_eq!(serde_json::to_value(request)?, expected);
        Ok(())
    }

    #[test]
    fn create_image_request_custom_should_serialize() -> Result<()> {
        let mut builder = CreateImageRequestBuilder::default();
        builder
            .prompt(PROMPT)
            .style(ImageStyle::Natural)
            .quality(ImageQuality::Hd);
        let request = builder.build()?;
        let expected = json!({
            "prompt": PROMPT,
            "model": "dall-e-3",
            "style": "natural",
            "quality": "hd",
        });
        assert_eq!(serde_json::to_value(request)?, expected);
        Ok(())
    }

    // Hits the real API and costs money per run, so it is ignored by default (e.g. in CI).
    #[tokio::test]
    #[ignore]
    async fn create_image_should_work() -> Result<()> {
        let request = CreateImageRequest::new(PROMPT);
        let response = SDK.create_image(request).await?;
        assert_eq!(response.data.len(), 1);

        let image = &response.data[0];
        assert!(image.url.is_some());
        assert!(image.b64_json.is_none());
        println!("image: {:?}", image);

        Ok(())
    }
}