gpt_rs/image/
mod.rs

1pub mod model;
2
3use super::client::Client;
4use super::entry_point::{EntryPoint, Function, Version};
5use anyhow::Result;
6use model::{CreateImageRequest, CreateImageResponse, CreateImageResponseDto, Url, B64};
7
8#[derive(Debug, Default, Clone, PartialEq)]
9pub struct CreateImageBuilder {
10    version: Option<Version>,
11    request: Option<CreateImageRequest>,
12}
13
14impl CreateImageBuilder {
15    pub fn version(mut self, value: Version) -> Self {
16        self.version = Some(value);
17        self
18    }
19    pub fn request(mut self, value: CreateImageRequest) -> Self {
20        self.request = Some(value);
21        self
22    }
23    pub fn build(&self) -> Result<CreateImage> {
24        let version = self.version.unwrap_or(Default::default());
25        let entry_point = EntryPoint::default()
26            .set_version(version)
27            .set_function(Function::CreateImage);
28        let request = match &self.request {
29            Some(val) => val.clone(),
30            _ => return Err(anyhow::anyhow!("Request must be set.")),
31        };
32        Ok(CreateImage {
33            entry_point,
34            request,
35        })
36    }
37}
38
39#[derive(Debug, Clone, PartialEq)]
40pub struct CreateImage {
41    entry_point: EntryPoint,
42    request: CreateImageRequest,
43}
44
45impl CreateImage {
46    pub fn new(request: CreateImageRequest) -> Self {
47        let entry_point = EntryPoint::default().set_function(Function::CreateImage);
48        Self {
49            entry_point,
50            request,
51        }
52    }
53    pub fn builder() -> CreateImageBuilder {
54        Default::default()
55    }
56    pub async fn execute(&self, client: &Client) -> Result<CreateImageResponse> {
57        let res = client
58            .post(&self.entry_point.path(), self.request.clone())
59            .await?;
60        let res = res.text().await?;
61        let res = match &self.request.response_format {
62            Some(val) if val == &model::Format::B64_Json => {
63                let res: CreateImageResponseDto<B64> = serde_json::from_str(res.as_str())?;
64                res.into()
65            }
66            _ => {
67                let res: CreateImageResponseDto<Url> = serde_json::from_str(res.as_str())?;
68                res.into()
69            }
70        };
71        Ok(res)
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::super::client::Config;
    use super::*;

    /// `build` fails until a request is supplied, then succeeds.
    #[test]
    fn builder() -> Result<()> {
        let builder = CreateImageBuilder::default();
        assert!(builder.clone().build().is_err());
        let builder = builder.request("".into());
        assert!(builder.build().is_ok());
        Ok(())
    }

    /// Round trip through `execute` for both response formats.
    ///
    /// This makes real calls through a `Client` built from `Config::from_env`,
    /// so it needs that configuration and a reachable service. It is therefore
    /// opt-in: run with `cargo test -- --ignored`. Errors are propagated with `?`
    /// rather than checked with `is_ok()`, so a failure reports its cause.
    #[tokio::test]
    #[ignore = "requires environment configuration and network access"]
    async fn post() -> Result<()> {
        let client = Client::new(Config::from_env()?)?;

        // Base64 JSON response format.
        CreateImage::builder()
            .request(CreateImageRequest {
                prompt: "doc".to_string(),
                n: 1,
                size: Default::default(),
                response_format: Some(model::Format::B64_Json),
                user: None,
            })
            .build()?
            .execute(&client)
            .await?;

        // Default (URL) response format.
        CreateImage::builder()
            .request("doc".into())
            .build()?
            .execute(&client)
            .await?;

        Ok(())
    }
}