Skip to main content

rig/providers/xai/
image_generation.rs

1use super::api::ApiResponse;
2use super::client::Client;
3use crate::http_client::HttpClientExt;
4use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
5use crate::json_utils::merge_inplace;
6use crate::{http_client, image_generation};
7use base64::Engine;
8use base64::prelude::BASE64_STANDARD;
9use serde::Deserialize;
10use serde_json::json;
11
// ================================================================
// xAI Image Generation API
// ================================================================

/// Model name `grok-imagine-image`, usable as [`ImageGenerationModel::model`].
pub const GROK_IMAGINE_IMAGE: &str = "grok-imagine-image";
/// Model name `grok-imagine-image-pro`, usable as [`ImageGenerationModel::model`].
pub const GROK_IMAGINE_IMAGE_PRO: &str = "grok-imagine-image-pro";
17
18#[derive(Debug, Deserialize)]
19pub struct ImageGenerationData {
20    pub b64_json: String,
21}
22
23#[derive(Debug, Deserialize)]
24pub struct ImageGenerationResponse {
25    pub data: Vec<ImageGenerationData>,
26}
27
28impl TryFrom<ImageGenerationResponse>
29    for image_generation::ImageGenerationResponse<ImageGenerationResponse>
30{
31    type Error = ImageGenerationError;
32
33    fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
34        let first = value
35            .data
36            .first()
37            .ok_or_else(|| ImageGenerationError::ResponseError("No image data returned".into()))?;
38
39        let bytes = BASE64_STANDARD.decode(&first.b64_json).map_err(|e| {
40            ImageGenerationError::ResponseError(format!("Base64 decode error: {e}"))
41        })?;
42
43        Ok(image_generation::ImageGenerationResponse {
44            image: bytes,
45            response: value,
46        })
47    }
48}
49
50#[derive(Clone)]
51pub struct ImageGenerationModel<T = reqwest::Client> {
52    client: Client<T>,
53    /// Name of the model (e.g.: grok-imagine-image)
54    pub model: String,
55}
56
57impl<T> ImageGenerationModel<T> {
58    pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
59        Self {
60            client,
61            model: model.into(),
62        }
63    }
64}
65
66impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
67where
68    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
69{
70    type Response = ImageGenerationResponse;
71
72    type Client = Client<T>;
73
74    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
75        Self::new(client.clone(), model)
76    }
77
78    async fn image_generation(
79        &self,
80        generation_request: ImageGenerationRequest,
81    ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
82    {
83        let mut request = json!({
84            "model": self.model,
85            "prompt": generation_request.prompt,
86            "response_format": "b64_json",
87            "aspect_ratio": "1:1",
88        });
89
90        if let Some(additional_params) = generation_request.additional_params {
91            merge_inplace(&mut request, additional_params);
92        }
93
94        let body = serde_json::to_vec(&request)?;
95
96        let request = self
97            .client
98            .post("/v1/images/generations")?
99            .body(body)
100            .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
101
102        let response = self.client.send(request).await?;
103
104        if !response.status().is_success() {
105            let status = response.status();
106            let text = http_client::text(response).await?;
107
108            return Err(ImageGenerationError::ProviderError(format!(
109                "{}: {}",
110                status, text,
111            )));
112        }
113
114        let text = http_client::text(response).await?;
115
116        match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&text)? {
117            ApiResponse::Ok(response) => response.try_into(),
118            ApiResponse::Error(err) => Err(ImageGenerationError::ProviderError(err.message())),
119        }
120    }
121}