rig/providers/openai/
image_generation.rs1use super::{Client, client::ApiResponse};
2use crate::http_client::HttpClientExt;
3use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
4use crate::json_utils::merge_inplace;
5use crate::{http_client, image_generation};
6use base64::Engine;
7use base64::prelude::BASE64_STANDARD;
8use serde::Deserialize;
9use serde_json::json;
10
/// Model identifier for OpenAI's DALL·E 2 image model.
pub const DALL_E_2: &str = "dall-e-2";
/// Model identifier for OpenAI's DALL·E 3 image model.
pub const DALL_E_3: &str = "dall-e-3";
/// Model identifier for OpenAI's `gpt-image-1` model.
pub const GPT_IMAGE_1: &str = "gpt-image-1";
17
18#[derive(Debug, Deserialize)]
19pub struct ImageGenerationData {
20 pub b64_json: String,
21}
22
23#[derive(Debug, Deserialize)]
24pub struct ImageGenerationResponse {
25 pub created: i32,
26 pub data: Vec<ImageGenerationData>,
27}
28
29impl TryFrom<ImageGenerationResponse>
30 for image_generation::ImageGenerationResponse<ImageGenerationResponse>
31{
32 type Error = ImageGenerationError;
33
34 fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
35 let b64_json = value.data[0].b64_json.clone();
36
37 let bytes = BASE64_STANDARD
38 .decode(&b64_json)
39 .expect("Failed to decode b64");
40
41 Ok(image_generation::ImageGenerationResponse {
42 image: bytes,
43 response: value,
44 })
45 }
46}
47
48#[derive(Clone)]
49pub struct ImageGenerationModel<T = reqwest::Client> {
50 client: Client<T>,
51 pub model: String,
53}
54
55impl<T> ImageGenerationModel<T> {
56 pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
57 Self {
58 client,
59 model: model.into(),
60 }
61 }
62}
63
64impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
65where
66 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
67{
68 type Response = ImageGenerationResponse;
69
70 type Client = Client<T>;
71
72 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
73 Self::new(client.clone(), model)
74 }
75
76 #[cfg_attr(feature = "worker", worker::send)]
77 async fn image_generation(
78 &self,
79 generation_request: ImageGenerationRequest,
80 ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
81 {
82 let mut request = json!({
83 "model": self.model,
84 "prompt": generation_request.prompt,
85 "size": format!("{}x{}", generation_request.width, generation_request.height),
86 });
87
88 if self.model.as_str() != GPT_IMAGE_1 {
89 merge_inplace(
90 &mut request,
91 json!({
92 "response_format": "b64_json"
93 }),
94 );
95 }
96
97 let body = serde_json::to_vec(&request)?;
98
99 let request = self
100 .client
101 .post("/images/generations")?
102 .body(body)
103 .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
104
105 let response = self.client.send(request).await?;
106
107 if !response.status().is_success() {
108 let status = response.status();
109 let text = http_client::text(response).await?;
110
111 return Err(ImageGenerationError::ProviderError(format!(
112 "{}: {}",
113 status, text,
114 )));
115 }
116
117 let text = http_client::text(response).await?;
118
119 match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&text)? {
120 ApiResponse::Ok(response) => response.try_into(),
121 ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
122 }
123 }
124}