//! Image generation for the Hugging Face provider.

use super::client::Client;
use crate::http_client::HttpClientExt;
use crate::image_generation;
use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
use serde_json::json;

/// Named identifiers for image-generation models that can be passed as a model name.
// The constants use CamelCase names (e.g. `Flux1`) rather than SCREAMING_SNAKE_CASE,
// so the naming lint is silenced for this module only.
#[allow(non_upper_case_globals)]
pub mod image_generation_models {
    /// Black Forest Labs' FLUX.1-dev.
    pub const Flux1: &str = "black-forest-labs/FLUX.1-dev";
    /// Kwai's Kolors.
    pub const Kolors: &str = "Kwai-Kolors/Kolors";
    /// Stability AI's Stable Diffusion 3 medium (diffusers format).
    pub const StableDiffusion3: &str = "stabilityai/stable-diffusion-3-medium-diffusers";
}

// Re-exported so the model constants are reachable directly from this module as well.
pub use image_generation_models::*;

/// Raw result of an image-generation call: the unparsed bytes of the reply body.
#[derive(Debug)]
pub struct ImageGenerationResponse {
    /// Body bytes exactly as received from the endpoint.
    data: Vec<u8>,
}

20impl TryFrom<ImageGenerationResponse>
21 for image_generation::ImageGenerationResponse<ImageGenerationResponse>
22{
23 type Error = ImageGenerationError;
24
25 fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
26 Ok(image_generation::ImageGenerationResponse {
27 image: value.data.clone(),
28 response: value,
29 })
30 }
31}
32
33#[derive(Clone)]
34pub struct ImageGenerationModel<T = reqwest::Client> {
35 client: Client<T>,
36 pub model: String,
37}
38
39impl<T> ImageGenerationModel<T> {
40 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
41 ImageGenerationModel {
42 client,
43 model: model.into(),
44 }
45 }
46}
47
48impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
49where
50 T: HttpClientExt + Send + Clone + 'static,
51{
52 type Response = ImageGenerationResponse;
53
54 type Client = Client<T>;
55
56 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
57 Self::new(client.clone(), model)
58 }
59
60 #[cfg_attr(feature = "worker", worker::send)]
61 async fn image_generation(
62 &self,
63 request: ImageGenerationRequest,
64 ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
65 {
66 let request = json!({
67 "inputs": request.prompt,
68 "parameters": {
69 "width": request.width,
70 "height": request.height
71 }
72 });
73
74 let route = self
75 .client
76 .subprovider()
77 .image_generation_endpoint(&self.model)?;
78
79 let body = serde_json::to_vec(&request)?;
80
81 let req = self
82 .client
83 .post(&route)?
84 .header("Content-Type", "application/json")
85 .body(body)
86 .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
87
88 let response = self.client.send(req).await?;
89
90 if !response.status().is_success() {
91 let status = response.status();
92 let text: Vec<u8> = response.into_body().await?;
93 let text: String = String::from_utf8_lossy(&text).into();
94
95 return Err(ImageGenerationError::ProviderError(format!(
96 "{}: {}",
97 status, text
98 )));
99 }
100
101 let data: Vec<u8> = response.into_body().await?;
102
103 ImageGenerationResponse { data }.try_into()
104 }
105}