Skip to main content

rig/providers/openai/
image_generation.rs

1use super::{Client, client::ApiResponse};
2use crate::http_client::HttpClientExt;
3use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
4use crate::json_utils::merge_inplace;
5use crate::{http_client, image_generation};
6use base64::Engine;
7use base64::prelude::BASE64_STANDARD;
8use serde::Deserialize;
9use serde_json::json;
10
// ================================================================
// OpenAI Image Generation API
// ================================================================

/// Model identifier for DALL·E 2.
pub const DALL_E_2: &str = "dall-e-2";
/// Model identifier for DALL·E 3.
pub const DALL_E_3: &str = "dall-e-3";

/// Model identifier for GPT Image 1.
pub const GPT_IMAGE_1: &str = "gpt-image-1";
/// Model identifier for GPT Image 1.5.
pub const GPT_IMAGE_1_5: &str = "gpt-image-1.5";
/// Model identifier for GPT Image 2.
pub const GPT_IMAGE_2: &str = "gpt-image-2";
19
20#[derive(Debug, Deserialize)]
21pub struct ImageGenerationData {
22    pub b64_json: String,
23}
24
25#[derive(Debug, Deserialize)]
26pub struct ImageGenerationResponse {
27    pub created: i32,
28    pub data: Vec<ImageGenerationData>,
29}
30
31impl TryFrom<ImageGenerationResponse>
32    for image_generation::ImageGenerationResponse<ImageGenerationResponse>
33{
34    type Error = ImageGenerationError;
35
36    fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
37        let b64_json = value
38            .data
39            .first()
40            .ok_or_else(|| ImageGenerationError::ResponseError("missing image data".into()))?
41            .b64_json
42            .clone();
43
44        let bytes = BASE64_STANDARD
45            .decode(&b64_json)
46            .map_err(|err| ImageGenerationError::ResponseError(err.to_string()))?;
47
48        Ok(image_generation::ImageGenerationResponse {
49            image: bytes,
50            response: value,
51        })
52    }
53}
54
55#[derive(Clone)]
56pub struct ImageGenerationModel<T = reqwest::Client> {
57    client: Client<T>,
58    /// Name of the model (e.g.: dall-e-2)
59    pub model: String,
60}
61
62impl<T> ImageGenerationModel<T> {
63    pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
64        Self {
65            client,
66            model: model.into(),
67        }
68    }
69}
70
71impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
72where
73    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
74{
75    type Response = ImageGenerationResponse;
76
77    type Client = Client<T>;
78
79    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
80        Self::new(client.clone(), model)
81    }
82
83    async fn image_generation(
84        &self,
85        generation_request: ImageGenerationRequest,
86    ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
87    {
88        let mut request = json!({
89            "model": self.model,
90            "prompt": generation_request.prompt,
91            "size": format!("{}x{}", generation_request.width, generation_request.height),
92        });
93
94        if !matches!(
95            self.model.as_str(),
96            GPT_IMAGE_1 | GPT_IMAGE_1_5 | GPT_IMAGE_2
97        ) {
98            merge_inplace(
99                &mut request,
100                json!({
101                    "response_format": "b64_json"
102                }),
103            );
104        }
105
106        let body = serde_json::to_vec(&request)?;
107
108        let request = self
109            .client
110            .post("/images/generations")?
111            .body(body)
112            .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
113
114        let response = self.client.send(request).await?;
115
116        if !response.status().is_success() {
117            let status = response.status();
118            let text = http_client::text(response).await?;
119
120            return Err(ImageGenerationError::ProviderError(format!(
121                "{}: {}",
122                status, text,
123            )));
124        }
125
126        let text = http_client::text(response).await?;
127
128        match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&text)? {
129            ApiResponse::Ok(response) => response.try_into(),
130            ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
131        }
132    }
133}