#[allow(deprecated)]
use crate::client::image_generation::ImageGenerationModelHandle;
use crate::http_client;
use futures::future::BoxFuture;
use serde_json::Value;
use std::sync::Arc;
use thiserror::Error;
/// Errors that can occur in the image generation API.
///
/// Variants marked `#[from]` get a `From` conversion generated by `thiserror`,
/// so `?` can lift those error types into this enum directly.
#[derive(Debug, Error)]
pub enum ImageGenerationError {
    /// An error from the crate's `http_client` layer.
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),
    /// A `serde_json` error.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),
    /// Catch-all for any other boxed, thread-safe error (presumably request-side
    /// failures — NOTE(review): confirm intended use with provider implementations).
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
    /// A response-related failure described by a message string (presumably a
    /// response that could not be interpreted — confirm with implementors).
    #[error("ResponseError: {0}")]
    ResponseError(String),
    /// A provider-related failure described by a message string (presumably an
    /// error reported by the upstream provider — confirm with implementors).
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
/// Something that can start an image generation request for model type `M`.
///
/// The single method asynchronously yields an [`ImageGenerationRequestBuilder`]
/// (or an [`ImageGenerationError`]); what the builder is pre-filled with is up
/// to the implementor.
pub trait ImageGeneration<M: ImageGenerationModel> {
    /// Prepares a request builder for `prompt` at the given `size`.
    ///
    /// `size` is presumably `(width, height)` — NOTE(review): confirm against
    /// implementors, nothing in this trait fixes the order.
    fn image_generation(
        &self,
        prompt: &str,
        size: &(u32, u32),
    ) -> impl std::future::Future<
        Output = Result<ImageGenerationRequestBuilder<M>, ImageGenerationError>,
    > + Send;
}
/// Successful outcome of an image generation call.
///
/// `T` is the model's provider-specific response type
/// (`ImageGenerationModel::Response`); the type-erased
/// `ImageGenerationModelDyn` path uses `()` instead.
#[derive(Debug)]
pub struct ImageGenerationResponse<T> {
    /// The generated image as raw bytes (the image format/encoding is chosen by
    /// the provider — not established in this module).
    pub image: Vec<u8>,
    /// The provider-specific response payload.
    pub response: T,
}
/// A concrete image generation model created from a provider client.
///
/// Implementors must be `Clone + Send + Sync`; the default
/// [`ImageGenerationModel::image_generation_request`] relies on `Clone` to hand
/// an owned copy of the model to the builder.
pub trait ImageGenerationModel: Clone + Send + Sync {
    /// Provider-specific payload returned alongside the image bytes.
    type Response: Send + Sync;

    /// Client type from which models are created via [`ImageGenerationModel::make`].
    type Client;

    /// Creates a model from `client`; `model` is anything convertible into a
    /// `String` (presumably the provider-side model name).
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Executes `request` against this model, returning the image bytes plus the
    /// provider-specific response, or an [`ImageGenerationError`].
    fn image_generation(
        &self,
        request: ImageGenerationRequest,
    ) -> impl std::future::Future<
        Output = Result<ImageGenerationResponse<Self::Response>, ImageGenerationError>,
    > + Send;

    /// Starts a request builder that owns a clone of this model.
    fn image_generation_request(&self) -> ImageGenerationRequestBuilder<Self> {
        let owned_model = Self::clone(self);
        ImageGenerationRequestBuilder::new(owned_model)
    }
}
/// Type-erased (deprecated) counterpart of [`ImageGenerationModel`].
///
/// Futures are boxed and the provider-specific response payload is replaced by
/// `()`. Presumably used behind `dyn` via `ImageGenerationModelHandle`.
#[allow(deprecated)]
#[deprecated(
    since = "0.25.0",
    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `ImageGenerationModel` instead."
)]
pub trait ImageGenerationModelDyn: Send + Sync {
    /// Starts a request builder whose model is a handle borrowing for `'a`.
    fn image_generation_request<'a>(
        &'a self,
    ) -> ImageGenerationRequestBuilder<ImageGenerationModelHandle<'a>>;

    /// Executes `request`, yielding only the image bytes (payload erased to `()`).
    fn image_generation<'a>(
        &'a self,
        request: ImageGenerationRequest,
    ) -> BoxFuture<'a, Result<ImageGenerationResponse<()>, ImageGenerationError>>;
}
/// Blanket adapter: every [`ImageGenerationModel`] is usable through the
/// deprecated type-erased [`ImageGenerationModelDyn`] interface.
#[allow(deprecated)]
impl<T> ImageGenerationModelDyn for T
where
    T: ImageGenerationModel,
{
    /// Boxes the typed model's future and drops its provider-specific response
    /// payload, keeping only the image bytes.
    fn image_generation(
        &self,
        request: ImageGenerationRequest,
    ) -> BoxFuture<'_, Result<ImageGenerationResponse<()>, ImageGenerationError>> {
        Box::pin(async move {
            // Both traits define `image_generation` for `T`; name the typed one
            // explicitly instead of relying on method-probe priority.
            <T as ImageGenerationModel>::image_generation(self, request)
                .await
                .map(|generated| ImageGenerationResponse {
                    image: generated.image,
                    response: (),
                })
        })
    }

    /// Wraps a clone of this model in an `Arc`-backed handle and starts a
    /// builder for that handle.
    fn image_generation_request(
        &self,
    ) -> ImageGenerationRequestBuilder<ImageGenerationModelHandle<'_>> {
        let handle = ImageGenerationModelHandle {
            inner: Arc::new(self.clone()),
        };
        ImageGenerationRequestBuilder::new(handle)
    }
}
/// A fully assembled image generation request.
///
/// Produced by [`ImageGenerationRequestBuilder::build`] and consumed by
/// [`ImageGenerationModel::image_generation`]. Because of `#[non_exhaustive]`,
/// code outside this crate can read the fields but cannot build the struct
/// with a literal. It derives `Debug` (for logging) and `Clone` (e.g. for
/// retrying a request); every field type supports both.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ImageGenerationRequest {
    /// Text prompt sent to the model.
    pub prompt: String,
    /// Requested image width (units not fixed here; presumably pixels).
    pub width: u32,
    /// Requested image height (units not fixed here; presumably pixels).
    pub height: u32,
    /// Optional extra parameters as an arbitrary JSON value; how they are used
    /// is up to the model implementation.
    pub additional_params: Option<Value>,
}
#[non_exhaustive]
pub struct ImageGenerationRequestBuilder<M>
where
M: ImageGenerationModel,
{
model: M,
prompt: String,
width: u32,
height: u32,
additional_params: Option<Value>,
}
impl<M> ImageGenerationRequestBuilder<M>
where
    M: ImageGenerationModel,
{
    /// Creates a builder for `model` with an empty prompt, width and height of
    /// 256, and no additional parameters.
    pub fn new(model: M) -> Self {
        Self {
            model,
            prompt: String::new(),
            height: 256,
            width: 256,
            additional_params: None,
        }
    }

    /// Sets the prompt, replacing any previous value.
    pub fn prompt(mut self, prompt: &str) -> Self {
        self.prompt = prompt.to_string();
        self
    }

    /// Sets the requested width, replacing any previous value.
    pub fn width(mut self, width: u32) -> Self {
        self.width = width;
        self
    }

    /// Sets the requested height, replacing any previous value.
    pub fn height(mut self, height: u32) -> Self {
        self.height = height;
        self
    }

    /// Sets the additional JSON parameters, replacing any previous value.
    pub fn additional_params(mut self, params: Value) -> Self {
        self.additional_params = Some(params);
        self
    }

    /// Consumes the builder and returns the assembled request; the model is
    /// dropped.
    pub fn build(self) -> ImageGenerationRequest {
        let (_model, request) = self.into_parts();
        request
    }

    /// Builds the request and sends it to the builder's model.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ImageGenerationError`] the model's
    /// `image_generation` implementation produces.
    pub async fn send(self) -> Result<ImageGenerationResponse<M::Response>, ImageGenerationError> {
        // Split the builder instead of cloning the model: `build` consumes
        // `self`, so the previous version had to copy the whole model first.
        let (model, request) = self.into_parts();
        // Fully qualified: `ImageGenerationModelDyn` is blanket-implemented for
        // `M` as well, so name the typed trait explicitly.
        <M as ImageGenerationModel>::image_generation(&model, request).await
    }

    /// Splits the builder into its model and the assembled request; the single
    /// place that maps builder fields onto `ImageGenerationRequest`.
    fn into_parts(self) -> (M, ImageGenerationRequest) {
        let Self {
            model,
            prompt,
            width,
            height,
            additional_params,
        } = self;
        let request = ImageGenerationRequest {
            prompt,
            width,
            height,
            additional_params,
        };
        (model, request)
    }
}