// rig/providers/huggingface/image_generation.rs

1use super::client::Client;
2use crate::http_client::HttpClientExt;
3use crate::image_generation;
4use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
5use serde_json::json;
6
/// Model repository IDs (`owner/name`) for Hugging Face image-generation models.
///
/// Any of these can be passed as the `model` argument of the image-generation
/// model constructor, which accepts `impl Into<String>`.
// The constants are CamelCase instead of SCREAMING_SNAKE_CASE, presumably so they
// read like model names at call sites (NOTE(review): confirm this matches the other
// providers' model modules). That naming is why the lint is silenced here.
#[allow(non_upper_case_globals)]
pub mod image_generation_models {
    /// `black-forest-labs/FLUX.1-dev`
    pub const Flux1: &str = "black-forest-labs/FLUX.1-dev";
    /// `Kwai-Kolors/Kolors`
    pub const Kolors: &str = "Kwai-Kolors/Kolors";
    /// `stabilityai/stable-diffusion-3-medium-diffusers`
    pub const StableDiffusion3: &str = "stabilityai/stable-diffusion-3-medium-diffusers";
}

// Re-export every model ID at this module's level so callers need not name the
// nested `image_generation_models` module.
pub use self::image_generation_models::*;
14
/// Raw provider response for a Hugging Face image-generation request.
///
/// Within this module it is only built from the body of a successful (2xx)
/// response. That body is then exposed as the generic response's `image` bytes.
#[derive(Debug)]
pub struct ImageGenerationResponse {
    /// Undecoded response body bytes as returned by the endpoint.
    data: Vec<u8>,
}
19
20impl TryFrom<ImageGenerationResponse>
21    for image_generation::ImageGenerationResponse<ImageGenerationResponse>
22{
23    type Error = ImageGenerationError;
24
25    fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
26        Ok(image_generation::ImageGenerationResponse {
27            image: value.data.clone(),
28            response: value,
29        })
30    }
31}
32
33#[derive(Clone)]
34pub struct ImageGenerationModel<T = reqwest::Client> {
35    client: Client<T>,
36    pub model: String,
37}
38
39impl<T> ImageGenerationModel<T> {
40    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
41        ImageGenerationModel {
42            client,
43            model: model.into(),
44        }
45    }
46}
47
48impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
49where
50    T: HttpClientExt + Send + Clone + 'static,
51{
52    type Response = ImageGenerationResponse;
53
54    type Client = Client<T>;
55
56    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
57        Self::new(client.clone(), model)
58    }
59
60    #[cfg_attr(feature = "worker", worker::send)]
61    async fn image_generation(
62        &self,
63        request: ImageGenerationRequest,
64    ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
65    {
66        let request = json!({
67            "inputs": request.prompt,
68            "parameters": {
69                "width": request.width,
70                "height": request.height
71            }
72        });
73
74        let route = self
75            .client
76            .subprovider()
77            .image_generation_endpoint(&self.model)?;
78
79        let body = serde_json::to_vec(&request)?;
80
81        let req = self
82            .client
83            .post(&route)?
84            .header("Content-Type", "application/json")
85            .body(body)
86            .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
87
88        let response = self.client.send(req).await?;
89
90        if !response.status().is_success() {
91            let status = response.status();
92            let text: Vec<u8> = response.into_body().await?;
93            let text: String = String::from_utf8_lossy(&text).into();
94
95            return Err(ImageGenerationError::ProviderError(format!(
96                "{}: {}",
97                status, text
98            )));
99        }
100
101        let data: Vec<u8> = response.into_body().await?;
102
103        ImageGenerationResponse { data }.try_into()
104    }
105}