// rig/providers/openai/image_generation.rs
use super::{Client, client::ApiResponse};
use crate::http_client::HttpClientExt;
use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
use crate::json_utils::merge_inplace;
use crate::{http_client, image_generation};
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use serde::Deserialize;
use serde_json::json;

/// Model identifier for OpenAI's DALL·E 2 image model.
pub const DALL_E_2: &str = "dall-e-2";
/// Model identifier for OpenAI's DALL·E 3 image model.
pub const DALL_E_3: &str = "dall-e-3";
/// Model identifier for OpenAI's GPT Image 1 model.
/// Unlike the DALL·E models, it does not accept a `response_format` field
/// (see the request construction in `image_generation` below).
pub const GPT_IMAGE_1: &str = "gpt-image-1";

/// A single generated image entry from the OpenAI images response.
#[derive(Debug, Deserialize)]
pub struct ImageGenerationData {
    // Base64-encoded image bytes; this crate requests `response_format:
    // "b64_json"` (or relies on gpt-image-1's default base64 output), so
    // only this field of the API's data object is deserialized.
    pub b64_json: String,
}

/// Raw response body of OpenAI's `POST /images/generations` endpoint.
#[derive(Debug, Deserialize)]
pub struct ImageGenerationResponse {
    // Creation timestamp as returned by the API — presumably Unix seconds;
    // NOTE(review): not interpreted anywhere in this file, confirm before use.
    pub created: i32,
    // One entry per generated image; only the first is decoded by the
    // `TryFrom` conversion below.
    pub data: Vec<ImageGenerationData>,
}

29impl TryFrom<ImageGenerationResponse>
30 for image_generation::ImageGenerationResponse<ImageGenerationResponse>
31{
32 type Error = ImageGenerationError;
33
34 fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
35 let b64_json = value.data[0].b64_json.clone();
36
37 let bytes = BASE64_STANDARD
38 .decode(&b64_json)
39 .expect("Failed to decode b64");
40
41 Ok(image_generation::ImageGenerationResponse {
42 image: bytes,
43 response: value,
44 })
45 }
46}
47
/// Handle for an OpenAI image generation model, generic over the HTTP
/// client implementation (defaults to `reqwest::Client`).
#[derive(Clone)]
pub struct ImageGenerationModel<T = reqwest::Client> {
    // Shared API client used to send requests.
    client: Client<T>,
    // Model identifier sent as the request's `model` field,
    // e.g. `DALL_E_3` or `GPT_IMAGE_1`.
    pub model: String,
}

55impl<T> ImageGenerationModel<T> {
56 pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
57 Self {
58 client,
59 model: model.into(),
60 }
61 }
62}
63
64impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
65where
66 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
67{
68 type Response = ImageGenerationResponse;
69
70 type Client = Client<T>;
71
72 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
73 Self::new(client.clone(), model)
74 }
75
76 async fn image_generation(
77 &self,
78 generation_request: ImageGenerationRequest,
79 ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
80 {
81 let mut request = json!({
82 "model": self.model,
83 "prompt": generation_request.prompt,
84 "size": format!("{}x{}", generation_request.width, generation_request.height),
85 });
86
87 if self.model.as_str() != GPT_IMAGE_1 {
88 merge_inplace(
89 &mut request,
90 json!({
91 "response_format": "b64_json"
92 }),
93 );
94 }
95
96 let body = serde_json::to_vec(&request)?;
97
98 let request = self
99 .client
100 .post("/images/generations")?
101 .body(body)
102 .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
103
104 let response = self.client.send(request).await?;
105
106 if !response.status().is_success() {
107 let status = response.status();
108 let text = http_client::text(response).await?;
109
110 return Err(ImageGenerationError::ProviderError(format!(
111 "{}: {}",
112 status, text,
113 )));
114 }
115
116 let text = http_client::text(response).await?;
117
118 match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&text)? {
119 ApiResponse::Ok(response) => response.try_into(),
120 ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
121 }
122 }
123}