1use crate::{client::image_generation::ImageGenerationModelHandle, http_client};
4use futures::future::BoxFuture;
5use serde_json::Value;
6use std::sync::Arc;
7use thiserror::Error;
8
9#[derive(Debug, Error)]
10pub enum ImageGenerationError {
11 #[error("HttpError: {0}")]
13 HttpError(#[from] http_client::Error),
14
15 #[error("JsonError: {0}")]
17 JsonError(#[from] serde_json::Error),
18
19 #[error("RequestError: {0}")]
21 RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
22
23 #[error("ResponseError: {0}")]
25 ResponseError(String),
26
27 #[error("ProviderError: {0}")]
29 ProviderError(String),
30}
31pub trait ImageGeneration<M>
32where
33 M: ImageGenerationModel,
34{
35 fn image_generation(
44 &self,
45 prompt: &str,
46 size: &(u32, u32),
47 ) -> impl std::future::Future<
48 Output = Result<ImageGenerationRequestBuilder<M>, ImageGenerationError>,
49 > + Send;
50}
51
/// Outcome of an image-generation call: the generated image plus the raw provider response.
#[derive(Debug)]
pub struct ImageGenerationResponse<T> {
    /// Generated image as raw bytes. The encoding is not specified by this type.
    pub image: Vec<u8>,
    /// Provider-specific response value (`()` when erased by the dyn wrapper).
    pub response: T,
}
58
59pub trait ImageGenerationModel: Clone + Send + Sync {
60 type Response: Send + Sync;
61
62 fn image_generation(
63 &self,
64 request: ImageGenerationRequest,
65 ) -> impl std::future::Future<
66 Output = Result<ImageGenerationResponse<Self::Response>, ImageGenerationError>,
67 > + Send;
68
69 fn image_generation_request(&self) -> ImageGenerationRequestBuilder<Self> {
70 ImageGenerationRequestBuilder::new(self.clone())
71 }
72}
73
74pub trait ImageGenerationModelDyn: Send + Sync {
75 fn image_generation(
76 &self,
77 request: ImageGenerationRequest,
78 ) -> BoxFuture<'_, Result<ImageGenerationResponse<()>, ImageGenerationError>>;
79
80 fn image_generation_request(
81 &self,
82 ) -> ImageGenerationRequestBuilder<ImageGenerationModelHandle<'_>>;
83}
84
85impl<T> ImageGenerationModelDyn for T
86where
87 T: ImageGenerationModel,
88{
89 fn image_generation(
90 &self,
91 request: ImageGenerationRequest,
92 ) -> BoxFuture<'_, Result<ImageGenerationResponse<()>, ImageGenerationError>> {
93 Box::pin(async {
94 let resp = self.image_generation(request).await;
95 resp.map(|r| ImageGenerationResponse {
96 image: r.image,
97 response: (),
98 })
99 })
100 }
101
102 fn image_generation_request(
103 &self,
104 ) -> ImageGenerationRequestBuilder<ImageGenerationModelHandle<'_>> {
105 ImageGenerationRequestBuilder::new(ImageGenerationModelHandle {
106 inner: Arc::new(self.clone()),
107 })
108 }
109}
110
111#[non_exhaustive]
113pub struct ImageGenerationRequest {
114 pub prompt: String,
115 pub width: u32,
116 pub height: u32,
117 pub additional_params: Option<Value>,
118}
119
120#[non_exhaustive]
123pub struct ImageGenerationRequestBuilder<M>
124where
125 M: ImageGenerationModel,
126{
127 model: M,
128 prompt: String,
129 width: u32,
130 height: u32,
131 additional_params: Option<Value>,
132}
133
134impl<M> ImageGenerationRequestBuilder<M>
135where
136 M: ImageGenerationModel,
137{
138 pub fn new(model: M) -> Self {
139 Self {
140 model,
141 prompt: "".to_string(),
142 height: 256,
143 width: 256,
144 additional_params: None,
145 }
146 }
147
148 pub fn prompt(mut self, prompt: &str) -> Self {
150 self.prompt = prompt.to_string();
151 self
152 }
153
154 pub fn width(mut self, width: u32) -> Self {
156 self.width = width;
157 self
158 }
159
160 pub fn height(mut self, height: u32) -> Self {
162 self.height = height;
163 self
164 }
165
166 pub fn additional_params(mut self, params: Value) -> Self {
168 self.additional_params = Some(params);
169 self
170 }
171
172 pub fn build(self) -> ImageGenerationRequest {
173 ImageGenerationRequest {
174 prompt: self.prompt,
175 width: self.width,
176 height: self.height,
177 additional_params: self.additional_params,
178 }
179 }
180
181 pub async fn send(self) -> Result<ImageGenerationResponse<M::Response>, ImageGenerationError> {
182 let model = self.model.clone();
183
184 model.image_generation(self.build()).await
185 }
186}