1#[allow(deprecated)]
4use crate::client::image_generation::ImageGenerationModelHandle;
5use crate::http_client;
6use futures::future::BoxFuture;
7use serde_json::Value;
8use std::sync::Arc;
9use thiserror::Error;
10
11#[derive(Debug, Error)]
12pub enum ImageGenerationError {
13 #[error("HttpError: {0}")]
15 HttpError(#[from] http_client::Error),
16
17 #[error("JsonError: {0}")]
19 JsonError(#[from] serde_json::Error),
20
21 #[error("RequestError: {0}")]
23 RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
24
25 #[error("ResponseError: {0}")]
27 ResponseError(String),
28
29 #[error("ProviderError: {0}")]
31 ProviderError(String),
32}
33pub trait ImageGeneration<M>
34where
35 M: ImageGenerationModel,
36{
37 fn image_generation(
46 &self,
47 prompt: &str,
48 size: &(u32, u32),
49 ) -> impl std::future::Future<
50 Output = Result<ImageGenerationRequestBuilder<M>, ImageGenerationError>,
51 > + Send;
52}
53
/// Result of one image generation call: the image bytes plus a
/// model-specific response value of type `T`.
#[derive(Debug)]
pub struct ImageGenerationResponse<T> {
    /// Image data as raw bytes. The encoding (PNG, JPEG, ...) is not fixed by
    /// this type — presumably provider-dependent.
    pub image: Vec<u8>,
    /// Model-specific response payload; the type-erased dynamic API uses `()`.
    pub response: T,
}
60
61pub trait ImageGenerationModel: Clone + Send + Sync {
62 type Response: Send + Sync;
63
64 type Client;
65
66 fn make(client: &Self::Client, model: impl Into<String>) -> Self;
67
68 fn image_generation(
69 &self,
70 request: ImageGenerationRequest,
71 ) -> impl std::future::Future<
72 Output = Result<ImageGenerationResponse<Self::Response>, ImageGenerationError>,
73 > + Send;
74
75 fn image_generation_request(&self) -> ImageGenerationRequestBuilder<Self> {
76 ImageGenerationRequestBuilder::new(self.clone())
77 }
78}
79
80#[allow(deprecated)]
81#[deprecated(
82 since = "0.25.0",
83 note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `ImageGenerationModel` instead."
84)]
85pub trait ImageGenerationModelDyn: Send + Sync {
86 fn image_generation(
87 &self,
88 request: ImageGenerationRequest,
89 ) -> BoxFuture<'_, Result<ImageGenerationResponse<()>, ImageGenerationError>>;
90
91 fn image_generation_request(
92 &self,
93 ) -> ImageGenerationRequestBuilder<ImageGenerationModelHandle<'_>>;
94}
95
96#[allow(deprecated)]
97impl<T> ImageGenerationModelDyn for T
98where
99 T: ImageGenerationModel,
100{
101 fn image_generation(
102 &self,
103 request: ImageGenerationRequest,
104 ) -> BoxFuture<'_, Result<ImageGenerationResponse<()>, ImageGenerationError>> {
105 Box::pin(async {
106 let resp = self.image_generation(request).await;
107 resp.map(|r| ImageGenerationResponse {
108 image: r.image,
109 response: (),
110 })
111 })
112 }
113
114 fn image_generation_request(
115 &self,
116 ) -> ImageGenerationRequestBuilder<ImageGenerationModelHandle<'_>> {
117 ImageGenerationRequestBuilder::new(ImageGenerationModelHandle {
118 inner: Arc::new(self.clone()),
119 })
120 }
121}
122
123#[non_exhaustive]
125pub struct ImageGenerationRequest {
126 pub prompt: String,
127 pub width: u32,
128 pub height: u32,
129 pub additional_params: Option<Value>,
130}
131
132#[non_exhaustive]
135pub struct ImageGenerationRequestBuilder<M>
136where
137 M: ImageGenerationModel,
138{
139 model: M,
140 prompt: String,
141 width: u32,
142 height: u32,
143 additional_params: Option<Value>,
144}
145
146impl<M> ImageGenerationRequestBuilder<M>
147where
148 M: ImageGenerationModel,
149{
150 pub fn new(model: M) -> Self {
151 Self {
152 model,
153 prompt: "".to_string(),
154 height: 256,
155 width: 256,
156 additional_params: None,
157 }
158 }
159
160 pub fn prompt(mut self, prompt: &str) -> Self {
162 self.prompt = prompt.to_string();
163 self
164 }
165
166 pub fn width(mut self, width: u32) -> Self {
168 self.width = width;
169 self
170 }
171
172 pub fn height(mut self, height: u32) -> Self {
174 self.height = height;
175 self
176 }
177
178 pub fn additional_params(mut self, params: Value) -> Self {
180 self.additional_params = Some(params);
181 self
182 }
183
184 pub fn build(self) -> ImageGenerationRequest {
185 ImageGenerationRequest {
186 prompt: self.prompt,
187 width: self.width,
188 height: self.height,
189 additional_params: self.additional_params,
190 }
191 }
192
193 pub async fn send(self) -> Result<ImageGenerationResponse<M::Response>, ImageGenerationError> {
194 let model = self.model.clone();
195
196 model.image_generation(self.build()).await
197 }
198}