// rig/providers/huggingface/image_generation.rs

1use super::Client;
2use crate::http_client::HttpClientExt;
3use crate::image_generation;
4use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
5use serde_json::json;
6
/// Model identifier for Stable Diffusion 3 medium (`stabilityai/stable-diffusion-3-medium-diffusers`).
pub const STABLE_DIFFUSION_3: &str = "stabilityai/stable-diffusion-3-medium-diffusers";
/// Model identifier for FLUX.1 dev (`black-forest-labs/FLUX.1-dev`).
pub const FLUX_1: &str = "black-forest-labs/FLUX.1-dev";
/// Model identifier for Kolors (`Kwai-Kolors/Kolors`).
pub const KOLORS: &str = "Kwai-Kolors/Kolors";
10
11#[derive(Debug)]
12pub struct ImageGenerationResponse {
13    data: Vec<u8>,
14}
15
16impl TryFrom<ImageGenerationResponse>
17    for image_generation::ImageGenerationResponse<ImageGenerationResponse>
18{
19    type Error = ImageGenerationError;
20
21    fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
22        Ok(image_generation::ImageGenerationResponse {
23            image: value.data.clone(),
24            response: value,
25        })
26    }
27}
28
29#[derive(Clone)]
30pub struct ImageGenerationModel<T = reqwest::Client> {
31    client: Client<T>,
32    pub model: String,
33}
34
35impl<T> ImageGenerationModel<T> {
36    pub fn new(client: Client<T>, model: &str) -> Self {
37        ImageGenerationModel {
38            client,
39            model: model.to_string(),
40        }
41    }
42}
43
44impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
45where
46    T: HttpClientExt + Send + Clone + 'static,
47{
48    type Response = ImageGenerationResponse;
49
50    #[cfg_attr(feature = "worker", worker::send)]
51    async fn image_generation(
52        &self,
53        request: ImageGenerationRequest,
54    ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
55    {
56        let request = json!({
57            "inputs": request.prompt,
58            "parameters": {
59                "width": request.width,
60                "height": request.height
61            }
62        });
63
64        let route = self
65            .client
66            .sub_provider
67            .image_generation_endpoint(&self.model)?;
68
69        let body = serde_json::to_vec(&request)?;
70
71        let req = self
72            .client
73            .post(&route)?
74            .header("Content-Type", "application/json")
75            .body(body)
76            .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
77
78        let response = self.client.send(req).await?;
79
80        if !response.status().is_success() {
81            let status = response.status();
82            let text: Vec<u8> = response.into_body().await?;
83            let text: String = String::from_utf8_lossy(&text).into();
84
85            return Err(ImageGenerationError::ProviderError(format!(
86                "{}: {}",
87                status, text
88            )));
89        }
90
91        let data: Vec<u8> = response.into_body().await?;
92
93        ImageGenerationResponse { data }.try_into()
94    }
95}