rig/providers/huggingface/image_generation.rs

1use super::client::Client;
2use crate::http_client::HttpClientExt;
3use crate::image_generation;
4use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
5use serde_json::json;
6
/// Model identifiers (`owner/name` form) for the Hugging Face image-generation provider.
///
/// The constants deliberately use CamelCase names that mirror the model names
/// (e.g. `Flux1`) rather than SCREAMING_SNAKE_CASE, hence the lint allowance.
#[allow(non_upper_case_globals)] // CamelCase constant names are intentional; see docs above.
pub mod image_generation_models {
    /// `black-forest-labs/FLUX.1-dev`.
    pub const Flux1: &str = "black-forest-labs/FLUX.1-dev";
    /// `Kwai-Kolors/Kolors`.
    pub const Kolors: &str = "Kwai-Kolors/Kolors";
    /// `stabilityai/stable-diffusion-3-medium-diffusers`.
    pub const StableDiffusion3: &str = "stabilityai/stable-diffusion-3-medium-diffusers";
}

// Re-export the identifiers so they are reachable directly from this module.
pub use self::image_generation_models::*;
14
/// Raw response from the Hugging Face image-generation endpoint.
///
/// Holds the body bytes of a successful HTTP response, unmodified.
#[derive(Debug)]
pub struct ImageGenerationResponse {
    /// Body bytes returned by the endpoint on success.
    data: Vec<u8>,
}
19
20impl TryFrom<ImageGenerationResponse>
21    for image_generation::ImageGenerationResponse<ImageGenerationResponse>
22{
23    type Error = ImageGenerationError;
24
25    fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
26        Ok(image_generation::ImageGenerationResponse {
27            image: value.data.clone(),
28            response: value,
29        })
30    }
31}
32
33#[derive(Clone)]
34pub struct ImageGenerationModel<T = reqwest::Client> {
35    client: Client<T>,
36    pub model: String,
37}
38
39impl<T> ImageGenerationModel<T> {
40    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
41        ImageGenerationModel {
42            client,
43            model: model.into(),
44        }
45    }
46}
47
48impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
49where
50    T: HttpClientExt + Send + Clone + 'static,
51{
52    type Response = ImageGenerationResponse;
53
54    type Client = Client<T>;
55
56    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
57        Self::new(client.clone(), model)
58    }
59
60    async fn image_generation(
61        &self,
62        request: ImageGenerationRequest,
63    ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
64    {
65        let request = json!({
66            "inputs": request.prompt,
67            "parameters": {
68                "width": request.width,
69                "height": request.height
70            }
71        });
72
73        let route = self
74            .client
75            .subprovider()
76            .image_generation_endpoint(&self.model)?;
77
78        let body = serde_json::to_vec(&request)?;
79
80        let req = self
81            .client
82            .post(&route)?
83            .header("Content-Type", "application/json")
84            .body(body)
85            .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
86
87        let response = self.client.send(req).await?;
88
89        if !response.status().is_success() {
90            let status = response.status();
91            let text: Vec<u8> = response.into_body().await?;
92            let text: String = String::from_utf8_lossy(&text).into();
93
94            return Err(ImageGenerationError::ProviderError(format!(
95                "{}: {}",
96                status, text
97            )));
98        }
99
100        let data: Vec<u8> = response.into_body().await?;
101
102        ImageGenerationResponse { data }.try_into()
103    }
104}