rig/providers/huggingface/
image_generation.rs1use super::client::Client;
2use crate::http_client::HttpClientExt;
3use crate::image_generation;
4use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
5use serde_json::json;
6
/// Convenience identifiers for Hugging Face image-generation models.
///
/// These are plain repository IDs. The model constructor accepts any
/// `impl Into<String>`, so this list is not exhaustive. Every constant is
/// re-exported by the `pub use` that follows the module.
// The constants deliberately keep CamelCase names instead of
// SCREAMING_SNAKE_CASE: they are public API and callers already import them
// under these names, so renaming would be a breaking change.
#[allow(non_upper_case_globals)]
pub mod image_generation_models {
    /// `black-forest-labs/FLUX.1-dev`.
    pub const Flux1: &str = "black-forest-labs/FLUX.1-dev";
    /// `Kwai-Kolors/Kolors`.
    pub const Kolors: &str = "Kwai-Kolors/Kolors";
    /// `stabilityai/stable-diffusion-3-medium-diffusers` (Stable Diffusion 3 Medium,
    /// diffusers-format weights).
    pub const StableDiffusion3: &str = "stabilityai/stable-diffusion-3-medium-diffusers";
}
pub use image_generation_models::*;
14
/// Raw response returned by the Hugging Face image-generation endpoint.
///
/// `data` holds the response body bytes exactly as received. On success the
/// provider treats the body as the generated image.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImageGenerationResponse {
    data: Vec<u8>,
}
19
20impl TryFrom<ImageGenerationResponse>
21 for image_generation::ImageGenerationResponse<ImageGenerationResponse>
22{
23 type Error = ImageGenerationError;
24
25 fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
26 Ok(image_generation::ImageGenerationResponse {
27 image: value.data.clone(),
28 response: value,
29 })
30 }
31}
32
33#[derive(Clone)]
34pub struct ImageGenerationModel<T = reqwest::Client> {
35 client: Client<T>,
36 pub model: String,
37}
38
39impl<T> ImageGenerationModel<T> {
40 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
41 ImageGenerationModel {
42 client,
43 model: model.into(),
44 }
45 }
46}
47
48impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
49where
50 T: HttpClientExt + Send + Clone + 'static,
51{
52 type Response = ImageGenerationResponse;
53
54 type Client = Client<T>;
55
56 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
57 Self::new(client.clone(), model)
58 }
59
60 async fn image_generation(
61 &self,
62 request: ImageGenerationRequest,
63 ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
64 {
65 let request = json!({
66 "inputs": request.prompt,
67 "parameters": {
68 "width": request.width,
69 "height": request.height
70 }
71 });
72
73 let route = self
74 .client
75 .subprovider()
76 .image_generation_endpoint(&self.model)?;
77
78 let body = serde_json::to_vec(&request)?;
79
80 let req = self
81 .client
82 .post(&route)?
83 .header("Content-Type", "application/json")
84 .body(body)
85 .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
86
87 let response = self.client.send(req).await?;
88
89 if !response.status().is_success() {
90 let status = response.status();
91 let text: Vec<u8> = response.into_body().await?;
92 let text: String = String::from_utf8_lossy(&text).into();
93
94 return Err(ImageGenerationError::ProviderError(format!(
95 "{}: {}",
96 status, text
97 )));
98 }
99
100 let data: Vec<u8> = response.into_body().await?;
101
102 ImageGenerationResponse { data }.try_into()
103 }
104}