rig/providers/openai/image_generation.rs

use crate::http_client::HttpClientExt;
use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
use crate::json_utils::merge_inplace;
use crate::providers::openai::{ApiResponse, Client};
use crate::{http_client, image_generation};
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use serde::Deserialize;
use serde_json::json;

// ================================================================
// OpenAI Image Generation API
// ================================================================
pub const DALL_E_2: &str = "dall-e-2";
pub const DALL_E_3: &str = "dall-e-3";

pub const GPT_IMAGE_1: &str = "gpt-image-1";

/// A single generated image, returned by the API as a base64-encoded string.
#[derive(Debug, Deserialize)]
pub struct ImageGenerationData {
    pub b64_json: String,
}

/// Raw response body of the OpenAI Images API.
#[derive(Debug, Deserialize)]
pub struct ImageGenerationResponse {
    pub created: i32,
    pub data: Vec<ImageGenerationData>,
}
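
// For reference, the payload these structs deserialize is expected to look
// roughly like this (a sketch of the OpenAI Images API response when images
// are returned as base64; values are illustrative):
//
//     {
//       "created": 1713833628,
//       "data": [
//         { "b64_json": "<base64-encoded image bytes>" }
//       ]
//     }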

impl TryFrom<ImageGenerationResponse>
    for image_generation::ImageGenerationResponse<ImageGenerationResponse>
{
    type Error = ImageGenerationError;

    fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
        // The API may return an empty `data` array; surface that as an error
        // instead of panicking on an out-of-bounds index.
        let data = value.data.first().ok_or_else(|| {
            ImageGenerationError::ResponseError("Response contained no image data".to_owned())
        })?;

        // Propagate decode failures instead of panicking inside `try_from`.
        let bytes = BASE64_STANDARD.decode(&data.b64_json).map_err(|e| {
            ImageGenerationError::ResponseError(format!("Failed to decode b64_json: {e}"))
        })?;

        Ok(image_generation::ImageGenerationResponse {
            image: bytes,
            response: value,
        })
    }
}

#[derive(Clone)]
pub struct ImageGenerationModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: dall-e-2)
    pub model: String,
}

impl<T> ImageGenerationModel<T> {
    pub(crate) fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}

impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = ImageGenerationResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn image_generation(
        &self,
        generation_request: ImageGenerationRequest,
    ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
    {
        let mut request = json!({
            "model": self.model,
            "prompt": generation_request.prompt,
            "size": format!("{}x{}", generation_request.width, generation_request.height),
        });

        // gpt-image-1 always returns base64-encoded images and does not accept
        // the `response_format` parameter, so only request `b64_json`
        // explicitly for the other models.
        if self.model != GPT_IMAGE_1 {
            merge_inplace(
                &mut request,
                json!({
                    "response_format": "b64_json"
                }),
            );
        }

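        // As a concrete illustration (the prompt and size values are made up),
        // the JSON body serialized just below looks like this for a dall-e-3
        // request:
        //
        //     {
        //       "model": "dall-e-3",
        //       "prompt": "a watercolor painting of a lighthouse",
        //       "size": "1024x1024",
        //       "response_format": "b64_json"
        //     }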
        let body = serde_json::to_vec(&request)?;

        let request = self
            .client
            .post("/images/generations")?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

        let response = self.client.send(request).await?;

        if !response.status().is_success() {
            let status = response.status();
            let text = http_client::text(response).await?;

            return Err(ImageGenerationError::ProviderError(format!(
                "{}: {}",
                status, text,
            )));
        }

        let text = http_client::text(response).await?;

        match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&text)? {
            ApiResponse::Ok(response) => response.try_into(),
            ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
        }
    }
}
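
// Usage sketch (illustrative only, not part of this module): `new` is
// pub(crate), so an `ImageGenerationModel` is normally obtained through the
// provider `Client`. Assuming the client exposes an `image_generation_model`
// helper (as the other rig provider clients do), that `Client::from_env()`
// picks up the API key, and that `request` is an `ImageGenerationRequest`
// carrying the prompt plus the desired width and height, a call might look
// like:
//
//     let client = Client::from_env();
//     let dalle = client.image_generation_model(DALL_E_3);
//
//     let response = dalle.image_generation(request).await?;
//
//     // `response.image` holds the decoded image bytes; `response.response`
//     // keeps the raw provider payload.
//     std::fs::write("image.png", &response.image)?;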