// rig/providers/openai/image_generation.rs
use super::{Client, client::ApiResponse};
use crate::http_client::HttpClientExt;
use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
use crate::json_utils::merge_inplace;
use crate::{http_client, image_generation};
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use serde::Deserialize;
use serde_json::json;

/// Model identifier for OpenAI's DALL·E 2 image model.
pub const DALL_E_2: &str = "dall-e-2";
/// Model identifier for OpenAI's DALL·E 3 image model.
pub const DALL_E_3: &str = "dall-e-3";
/// Model identifier for OpenAI's `gpt-image-1` image model.
pub const GPT_IMAGE_1: &str = "gpt-image-1";
/// Model identifier for OpenAI's `gpt-image-1.5` image model.
pub const GPT_IMAGE_1_5: &str = "gpt-image-1.5";
18
/// One generated image entry in the OpenAI response.
#[derive(Debug, Deserialize)]
pub struct ImageGenerationData {
    /// Base64-encoded image payload, as returned by the API.
    pub b64_json: String,
}
23
/// Raw response body of OpenAI's `/images/generations` endpoint.
#[derive(Debug, Deserialize)]
pub struct ImageGenerationResponse {
    // Creation timestamp as sent by the API (presumably a Unix timestamp —
    // TODO confirm; the value is not interpreted anywhere in this module).
    pub created: i32,
    /// Generated images; this module decodes only the first entry.
    pub data: Vec<ImageGenerationData>,
}
29
30impl TryFrom<ImageGenerationResponse>
31 for image_generation::ImageGenerationResponse<ImageGenerationResponse>
32{
33 type Error = ImageGenerationError;
34
35 fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
36 let b64_json = value.data[0].b64_json.clone();
37
38 let bytes = BASE64_STANDARD
39 .decode(&b64_json)
40 .expect("Failed to decode b64");
41
42 Ok(image_generation::ImageGenerationResponse {
43 image: bytes,
44 response: value,
45 })
46 }
47}
48
/// An OpenAI image-generation model bound to an API client.
#[derive(Clone)]
pub struct ImageGenerationModel<T = reqwest::Client> {
    // Client handle used to issue requests to the OpenAI API.
    client: Client<T>,
    /// Model identifier, e.g. [`DALL_E_3`] or [`GPT_IMAGE_1`].
    pub model: String,
}
55
56impl<T> ImageGenerationModel<T> {
57 pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
58 Self {
59 client,
60 model: model.into(),
61 }
62 }
63}
64
65impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
66where
67 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
68{
69 type Response = ImageGenerationResponse;
70
71 type Client = Client<T>;
72
73 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
74 Self::new(client.clone(), model)
75 }
76
77 async fn image_generation(
78 &self,
79 generation_request: ImageGenerationRequest,
80 ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
81 {
82 let mut request = json!({
83 "model": self.model,
84 "prompt": generation_request.prompt,
85 "size": format!("{}x{}", generation_request.width, generation_request.height),
86 });
87
88 if self.model.as_str() != GPT_IMAGE_1 && self.model.as_str() != GPT_IMAGE_1_5 {
89 merge_inplace(
90 &mut request,
91 json!({
92 "response_format": "b64_json"
93 }),
94 );
95 }
96
97 let body = serde_json::to_vec(&request)?;
98
99 let request = self
100 .client
101 .post("/images/generations")?
102 .body(body)
103 .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
104
105 let response = self.client.send(request).await?;
106
107 if !response.status().is_success() {
108 let status = response.status();
109 let text = http_client::text(response).await?;
110
111 return Err(ImageGenerationError::ProviderError(format!(
112 "{}: {}",
113 status, text,
114 )));
115 }
116
117 let text = http_client::text(response).await?;
118
119 match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&text)? {
120 ApiResponse::Ok(response) => response.try_into(),
121 ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
122 }
123 }
124}