rig/providers/huggingface/
image_generation.rs1use super::Client;
2use crate::http_client::HttpClientExt;
3use crate::image_generation;
4use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
5use serde_json::json;
6
/// Model identifier string for `black-forest-labs/FLUX.1-dev`.
pub const FLUX_1: &str = "black-forest-labs/FLUX.1-dev";
/// Model identifier string for `Kwai-Kolors/Kolors`.
pub const KOLORS: &str = "Kwai-Kolors/Kolors";
/// Model identifier string for `stabilityai/stable-diffusion-3-medium-diffusers`.
pub const STABLE_DIFFUSION_3: &str = "stabilityai/stable-diffusion-3-medium-diffusers";
10
/// Provider-specific image generation response.
///
/// Wraps the bytes read from the body of a successful image generation
/// HTTP response.
#[derive(Debug)]
pub struct ImageGenerationResponse {
    /// Response body bytes; copied into the `image` field when this value is
    /// converted into the generic `image_generation::ImageGenerationResponse`.
    data: Vec<u8>,
}
15
16impl TryFrom<ImageGenerationResponse>
17 for image_generation::ImageGenerationResponse<ImageGenerationResponse>
18{
19 type Error = ImageGenerationError;
20
21 fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
22 Ok(image_generation::ImageGenerationResponse {
23 image: value.data.clone(),
24 response: value,
25 })
26 }
27}
28
/// Image generation model handle: a provider client paired with the name of
/// the model to query.
///
/// `T` is the type parameter forwarded to `Client<T>`; it defaults to
/// `reqwest::Client`.
#[derive(Clone)]
pub struct ImageGenerationModel<T = reqwest::Client> {
    /// Client used to build and send the image generation request.
    client: Client<T>,
    /// Model name, passed to the client's sub-provider to resolve the
    /// image generation endpoint.
    pub model: String,
}
34
35impl<T> ImageGenerationModel<T> {
36 pub fn new(client: Client<T>, model: &str) -> Self {
37 ImageGenerationModel {
38 client,
39 model: model.to_string(),
40 }
41 }
42}
43
44impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
45where
46 T: HttpClientExt + Send + Clone + 'static,
47{
48 type Response = ImageGenerationResponse;
49
50 #[cfg_attr(feature = "worker", worker::send)]
51 async fn image_generation(
52 &self,
53 request: ImageGenerationRequest,
54 ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
55 {
56 let request = json!({
57 "inputs": request.prompt,
58 "parameters": {
59 "width": request.width,
60 "height": request.height
61 }
62 });
63
64 let route = self
65 .client
66 .sub_provider
67 .image_generation_endpoint(&self.model)?;
68
69 let body = serde_json::to_vec(&request)?;
70
71 let req = self
72 .client
73 .post(&route)?
74 .header("Content-Type", "application/json")
75 .body(body)
76 .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
77
78 let response = self.client.send(req).await?;
79
80 if !response.status().is_success() {
81 let status = response.status();
82 let text: Vec<u8> = response.into_body().await?;
83 let text: String = String::from_utf8_lossy(&text).into();
84
85 return Err(ImageGenerationError::ProviderError(format!(
86 "{}: {}",
87 status, text
88 )));
89 }
90
91 let data: Vec<u8> = response.into_body().await?;
92
93 ImageGenerationResponse { data }.try_into()
94 }
95}