async_openai/
image.rs

1use crate::{
2    config::Config,
3    error::OpenAIError,
4    types::images::{
5        CreateImageEditRequest, CreateImageRequest, CreateImageVariationRequest, ImagesResponse,
6    },
7    Client, RequestOptions,
8};
9
10#[cfg(not(target_family = "wasm"))]
11use crate::types::images::{ImageEditStream, ImageGenStream};
12
13/// Given a prompt and/or an input image, the model will generate a new image.
14///
15/// Related guide: [Image generation](https://platform.openai.com/docs/guides/images)
16pub struct Images<'c, C: Config> {
17    client: &'c Client<C>,
18    pub(crate) request_options: RequestOptions,
19}
20
21impl<'c, C: Config> Images<'c, C> {
22    pub fn new(client: &'c Client<C>) -> Self {
23        Self {
24            client,
25            request_options: RequestOptions::new(),
26        }
27    }
28
29    /// Creates an image given a prompt.
30    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
31    pub async fn generate(
32        &self,
33        request: CreateImageRequest,
34    ) -> Result<ImagesResponse, OpenAIError> {
35        self.client
36            .post("/images/generations", request, &self.request_options)
37            .await
38    }
39
40    /// Creates an image given a prompt.
41    #[cfg(not(target_family = "wasm"))]
42    #[crate::byot(
43        T0 = serde::Serialize,
44        R = serde::de::DeserializeOwned,
45        stream = "true",
46        where_clause = "R: std::marker::Send + 'static"
47    )]
48    #[allow(unused_mut)]
49    pub async fn generate_stream(
50        &self,
51        mut request: CreateImageRequest,
52    ) -> Result<ImageGenStream, OpenAIError> {
53        #[cfg(not(feature = "byot"))]
54        {
55            if request.stream.is_some() && !request.stream.unwrap() {
56                return Err(OpenAIError::InvalidArgument(
57                    "When stream is false, use Image::generate".into(),
58                ));
59            }
60
61            request.stream = Some(true);
62        }
63
64        Ok(self
65            .client
66            .post_stream("/images/generations", request, &self.request_options)
67            .await)
68    }
69
70    /// Creates an edited or extended image given one or more source images and a prompt.
71    /// This endpoint only supports gpt-image-1 and dall-e-2.
72    #[crate::byot(
73        T0 = Clone,
74        R = serde::de::DeserializeOwned,
75        where_clause =  "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
76    )]
77    pub async fn edit(
78        &self,
79        request: CreateImageEditRequest,
80    ) -> Result<ImagesResponse, OpenAIError> {
81        self.client
82            .post_form("/images/edits", request, &self.request_options)
83            .await
84    }
85
86    /// Creates an edited or extended image given one or more source images and a prompt.
87    /// This endpoint only supports gpt-image-1 and dall-e-2.
88    #[cfg(not(target_family = "wasm"))]
89    #[crate::byot(
90        T0 = Clone,
91        R = serde::de::DeserializeOwned,
92        stream = "true",
93        where_clause = "R: std::marker::Send + 'static, reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>"
94    )]
95    #[allow(unused_mut)]
96    pub async fn edit_stream(
97        &self,
98        mut request: CreateImageEditRequest,
99    ) -> Result<ImageEditStream, OpenAIError> {
100        #[cfg(not(feature = "byot"))]
101        {
102            if let Some(stream) = request.stream {
103                if !stream {
104                    return Err(OpenAIError::InvalidArgument(
105                        "When stream is false, use Image::edit".into(),
106                    ));
107                }
108            }
109            request.stream = Some(true);
110        }
111        self.client
112            .post_form_stream("/images/edits", request, &self.request_options)
113            .await
114    }
115
116    /// Creates a variation of a given image. This endpoint only supports dall-e-2.
117    #[crate::byot(
118        T0 = Clone,
119        R = serde::de::DeserializeOwned,
120        where_clause =  "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
121    )]
122    pub async fn create_variation(
123        &self,
124        request: CreateImageVariationRequest,
125    ) -> Result<ImagesResponse, OpenAIError> {
126        self.client
127            .post_form("/images/variations", request, &self.request_options)
128            .await
129    }
130}