rig/providers/openai/image_generation.rs
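//! OpenAI image generation support (`dall-e-2`, `dall-e-3`, `gpt-image-1`)
//! for rig's `image_generation` provider interface.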

use super::{Client, client::ApiResponse};
use crate::http_client::HttpClientExt;
use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
use crate::json_utils::merge_inplace;
use crate::{http_client, image_generation};
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use serde::Deserialize;
use serde_json::json;

// ================================================================
// OpenAI Image Generation API
// ================================================================
pub const DALL_E_2: &str = "dall-e-2";
pub const DALL_E_3: &str = "dall-e-3";
pub const GPT_IMAGE_1: &str = "gpt-image-1";

#[derive(Debug, Deserialize)]
pub struct ImageGenerationData {
    /// Base64-encoded image payload (the request below asks for `b64_json`,
    /// and `gpt-image-1` always responds in this format).
    pub b64_json: String,
}

#[derive(Debug, Deserialize)]
pub struct ImageGenerationResponse {
    /// Unix timestamp (in seconds) of when the image was created.
    pub created: i32,
    pub data: Vec<ImageGenerationData>,
}

impl TryFrom<ImageGenerationResponse>
    for image_generation::ImageGenerationResponse<ImageGenerationResponse>
{
    type Error = ImageGenerationError;

    fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
        // Fail gracefully instead of panicking: the response may contain no
        // images, and the base64 payload may be malformed.
        let b64_json = value
            .data
            .first()
            .ok_or_else(|| {
                ImageGenerationError::ProviderError("Response contained no image data".to_string())
            })?
            .b64_json
            .clone();

        let bytes = BASE64_STANDARD.decode(&b64_json).map_err(|e| {
            ImageGenerationError::ProviderError(format!("Failed to decode b64_json: {e}"))
        })?;

        Ok(image_generation::ImageGenerationResponse {
            image: bytes,
            response: value,
        })
    }
}

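/// Image generation model backed by the OpenAI images API.
///
/// A minimal usage sketch (not part of the original source). It assumes the
/// OpenAI `Client` has been constructed elsewhere and that
/// `ImageGenerationRequest` can be built as a plain struct literal exposing
/// `prompt`, `width`, and `height` (the only fields this module reads); the
/// real request type may differ.
///
/// ```ignore
/// use rig::image_generation::ImageGenerationModel as _;
///
/// let model = ImageGenerationModel::make(&client, DALL_E_3);
/// let response = model
///     .image_generation(ImageGenerationRequest {
///         prompt: "a watercolor lighthouse at dusk".to_string(),
///         width: 1024,
///         height: 1024,
///     })
///     .await?;
/// // `response.image` holds the decoded image bytes.
/// ```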
#[derive(Clone)]
pub struct ImageGenerationModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: dall-e-2)
    pub model: String,
}

impl<T> ImageGenerationModel<T> {
    pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = ImageGenerationResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn image_generation(
        &self,
        generation_request: ImageGenerationRequest,
    ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
    {
        let mut request = json!({
            "model": self.model,
            "prompt": generation_request.prompt,
            "size": format!("{}x{}", generation_request.width, generation_request.height),
        });

        // gpt-image-1 always returns base64-encoded images and rejects the
        // `response_format` parameter, so only request `b64_json` explicitly
        // for the DALL·E models.
        if self.model.as_str() != GPT_IMAGE_1 {
            merge_inplace(
                &mut request,
                json!({
                    "response_format": "b64_json"
                }),
            );
        }
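        // Illustrative final payload for a DALL·E model (values are examples):
        // {
        //     "model": "dall-e-3",
        //     "prompt": "a watercolor lighthouse at dusk",
        //     "size": "1024x1024",
        //     "response_format": "b64_json"
        // }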

        let body = serde_json::to_vec(&request)?;

        let request = self
            .client
            .post("/images/generations")?
            .body(body)
            .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

        let response = self.client.send(request).await?;

        if !response.status().is_success() {
            let status = response.status();
            let text = http_client::text(response).await?;

            return Err(ImageGenerationError::ProviderError(format!(
                "{}: {}",
                status, text,
            )));
        }

        let text = http_client::text(response).await?;

        match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&text)? {
            ApiResponse::Ok(response) => response.try_into(),
            ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
        }
    }
}