1use crate::http_client;
4use crate::markers::{Missing, Provided};
5use serde_json::Value;
6use thiserror::Error;
7
8#[derive(Debug, Error)]
9pub enum ImageGenerationError {
10 #[error("HttpError: {0}")]
12 HttpError(#[from] http_client::Error),
13
14 #[error("JsonError: {0}")]
16 JsonError(#[from] serde_json::Error),
17
18 #[error("RequestError: {0}")]
20 RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
21
22 #[error("ResponseError: {0}")]
24 ResponseError(String),
25
26 #[error("ProviderError: {0}")]
28 ProviderError(String),
29}
30pub trait ImageGeneration<M>
31where
32 M: ImageGenerationModel,
33{
34 fn image_generation(
43 &self,
44 prompt: &str,
45 size: &(u32, u32),
46 ) -> impl std::future::Future<
47 Output = Result<ImageGenerationRequestBuilder<M, Provided<String>>, ImageGenerationError>,
48 > + Send;
49}
50
/// Successful output of an image-generation call.
///
/// Pairs the image bytes with the provider's own response value `T`.
/// The byte encoding (PNG, JPEG, …) is not visible from this module —
/// presumably provider-defined.
#[derive(Debug)]
pub struct ImageGenerationResponse<T> {
    /// Raw image payload.
    pub image: Vec<u8>,
    /// The provider-specific response object, kept alongside the bytes.
    pub response: T,
}
57
58pub trait ImageGenerationModel: Clone + Send + Sync {
59 type Response: Send + Sync;
60
61 type Client;
62
63 fn make(client: &Self::Client, model: impl Into<String>) -> Self;
64
65 fn image_generation(
66 &self,
67 request: ImageGenerationRequest,
68 ) -> impl std::future::Future<
69 Output = Result<ImageGenerationResponse<Self::Response>, ImageGenerationError>,
70 > + Send;
71
72 fn image_generation_request(&self) -> ImageGenerationRequestBuilder<Self, Missing> {
73 ImageGenerationRequestBuilder::new(self.clone())
74 }
75}
76#[non_exhaustive]
78pub struct ImageGenerationRequest {
79 pub prompt: String,
80 pub width: u32,
81 pub height: u32,
82 pub additional_params: Option<Value>,
83}
84
85#[non_exhaustive]
88pub struct ImageGenerationRequestBuilder<M, P = Missing>
89where
90 M: ImageGenerationModel,
91{
92 model: M,
93 prompt: P,
94 width: u32,
95 height: u32,
96 additional_params: Option<Value>,
97}
98
99impl<M> ImageGenerationRequestBuilder<M, Missing>
100where
101 M: ImageGenerationModel,
102{
103 pub fn new(model: M) -> Self {
104 Self {
105 model,
106 prompt: Missing,
107 height: 256,
108 width: 256,
109 additional_params: None,
110 }
111 }
112}
113
114impl<M, P> ImageGenerationRequestBuilder<M, P>
115where
116 M: ImageGenerationModel,
117{
118 pub fn prompt(self, prompt: &str) -> ImageGenerationRequestBuilder<M, Provided<String>> {
120 ImageGenerationRequestBuilder {
121 model: self.model,
122 prompt: Provided(prompt.to_string()),
123 width: self.width,
124 height: self.height,
125 additional_params: self.additional_params,
126 }
127 }
128
129 pub fn width(mut self, width: u32) -> Self {
131 self.width = width;
132 self
133 }
134
135 pub fn height(mut self, height: u32) -> Self {
137 self.height = height;
138 self
139 }
140
141 pub fn additional_params(mut self, params: Value) -> Self {
143 self.additional_params = Some(params);
144 self
145 }
146}
147
148impl<M> ImageGenerationRequestBuilder<M, Provided<String>>
149where
150 M: ImageGenerationModel,
151{
152 pub fn build(self) -> ImageGenerationRequest {
153 ImageGenerationRequest {
154 prompt: self.prompt.0,
155 width: self.width,
156 height: self.height,
157 additional_params: self.additional_params,
158 }
159 }
160
161 pub async fn send(self) -> Result<ImageGenerationResponse<M::Response>, ImageGenerationError> {
162 let model = self.model.clone();
163
164 model.image_generation(self.build()).await
165 }
166}