rig/providers/xai/
image_generation.rs1use super::api::ApiResponse;
2use super::client::Client;
3use crate::http_client::HttpClientExt;
4use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
5use crate::json_utils::merge_inplace;
6use crate::{http_client, image_generation};
7use base64::Engine;
8use base64::prelude::BASE64_STANDARD;
9use serde::Deserialize;
10use serde_json::json;
11
/// Model identifier string for the `grok-imagine-image` image-generation model.
pub const GROK_IMAGINE_IMAGE: &str = "grok-imagine-image";

/// Model identifier string for the `grok-imagine-image-pro` image-generation model.
pub const GROK_IMAGINE_IMAGE_PRO: &str = "grok-imagine-image-pro";
17
18#[derive(Debug, Deserialize)]
19pub struct ImageGenerationData {
20 pub b64_json: String,
21}
22
23#[derive(Debug, Deserialize)]
24pub struct ImageGenerationResponse {
25 pub data: Vec<ImageGenerationData>,
26}
27
28impl TryFrom<ImageGenerationResponse>
29 for image_generation::ImageGenerationResponse<ImageGenerationResponse>
30{
31 type Error = ImageGenerationError;
32
33 fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
34 let first = value
35 .data
36 .first()
37 .ok_or_else(|| ImageGenerationError::ResponseError("No image data returned".into()))?;
38
39 let bytes = BASE64_STANDARD.decode(&first.b64_json).map_err(|e| {
40 ImageGenerationError::ResponseError(format!("Base64 decode error: {e}"))
41 })?;
42
43 Ok(image_generation::ImageGenerationResponse {
44 image: bytes,
45 response: value,
46 })
47 }
48}
49
50#[derive(Clone)]
51pub struct ImageGenerationModel<T = reqwest::Client> {
52 client: Client<T>,
53 pub model: String,
55}
56
57impl<T> ImageGenerationModel<T> {
58 pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
59 Self {
60 client,
61 model: model.into(),
62 }
63 }
64}
65
66impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
67where
68 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
69{
70 type Response = ImageGenerationResponse;
71
72 type Client = Client<T>;
73
74 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
75 Self::new(client.clone(), model)
76 }
77
78 async fn image_generation(
79 &self,
80 generation_request: ImageGenerationRequest,
81 ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
82 {
83 let mut request = json!({
84 "model": self.model,
85 "prompt": generation_request.prompt,
86 "response_format": "b64_json",
87 "aspect_ratio": "1:1",
88 });
89
90 if let Some(additional_params) = generation_request.additional_params {
91 merge_inplace(&mut request, additional_params);
92 }
93
94 let body = serde_json::to_vec(&request)?;
95
96 let request = self
97 .client
98 .post("/v1/images/generations")?
99 .body(body)
100 .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
101
102 let response = self.client.send(request).await?;
103
104 if !response.status().is_success() {
105 let status = response.status();
106 let text = http_client::text(response).await?;
107
108 return Err(ImageGenerationError::ProviderError(format!(
109 "{}: {}",
110 status, text,
111 )));
112 }
113
114 let text = http_client::text(response).await?;
115
116 match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&text)? {
117 ApiResponse::Ok(response) => response.try_into(),
118 ApiResponse::Error(err) => Err(ImageGenerationError::ProviderError(err.message())),
119 }
120 }
121}