rig/
image_generation.rs

//! Everything related to the core image generation abstractions in Rig.
//! Rig lets you call any provider that supports image generation through the
//! [`ImageGenerationModel`] trait.
3use crate::{client::image_generation::ImageGenerationModelHandle, http_client};
4use futures::future::BoxFuture;
5use serde_json::Value;
6use std::sync::Arc;
7use thiserror::Error;
8
9#[derive(Debug, Error)]
10pub enum ImageGenerationError {
11    /// Http error (e.g.: connection error, timeout, etc.)
12    #[error("HttpError: {0}")]
13    HttpError(#[from] http_client::Error),
14
15    /// Json error (e.g.: serialization, deserialization)
16    #[error("JsonError: {0}")]
17    JsonError(#[from] serde_json::Error),
18
19    /// Error building the transcription request
20    #[error("RequestError: {0}")]
21    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
22
23    /// Error parsing the transcription response
24    #[error("ResponseError: {0}")]
25    ResponseError(String),
26
27    /// Error returned by the transcription model provider
28    #[error("ProviderError: {0}")]
29    ProviderError(String),
30}
31pub trait ImageGeneration<M>
32where
33    M: ImageGenerationModel,
34{
35    /// Generates a transcription request builder for the given `file`.
36    /// This function is meant to be called by the user to further customize the
37    /// request at transcription time before sending it.
38    ///
39    /// ❗IMPORTANT: The type that implements this trait might have already
40    /// populated fields in the builder (the exact fields depend on the type).
41    /// For fields that have already been set by the model, calling the corresponding
42    /// method on the builder will overwrite the value set by the model.
43    fn image_generation(
44        &self,
45        prompt: &str,
46        size: &(u32, u32),
47    ) -> impl std::future::Future<
48        Output = Result<ImageGenerationRequestBuilder<M>, ImageGenerationError>,
49    > + Send;
50}
51
/// The result of an image generation call, regardless of which model produced it.
///
/// Carries the generated image together with the model's raw response, so callers can
/// inspect provider-specific data when they need it.
#[derive(Debug)]
pub struct ImageGenerationResponse<R> {
    /// The generated image, as raw bytes.
    pub image: Vec<u8>,
    /// The raw response returned by the model.
    pub response: R,
}
58
59pub trait ImageGenerationModel: Clone + Send + Sync {
60    type Response: Send + Sync;
61
62    fn image_generation(
63        &self,
64        request: ImageGenerationRequest,
65    ) -> impl std::future::Future<
66        Output = Result<ImageGenerationResponse<Self::Response>, ImageGenerationError>,
67    > + Send;
68
69    fn image_generation_request(&self) -> ImageGenerationRequestBuilder<Self> {
70        ImageGenerationRequestBuilder::new(self.clone())
71    }
72}
73
74pub trait ImageGenerationModelDyn: Send + Sync {
75    fn image_generation(
76        &self,
77        request: ImageGenerationRequest,
78    ) -> BoxFuture<'_, Result<ImageGenerationResponse<()>, ImageGenerationError>>;
79
80    fn image_generation_request(
81        &self,
82    ) -> ImageGenerationRequestBuilder<ImageGenerationModelHandle<'_>>;
83}
84
85impl<T> ImageGenerationModelDyn for T
86where
87    T: ImageGenerationModel,
88{
89    fn image_generation(
90        &self,
91        request: ImageGenerationRequest,
92    ) -> BoxFuture<'_, Result<ImageGenerationResponse<()>, ImageGenerationError>> {
93        Box::pin(async {
94            let resp = self.image_generation(request).await;
95            resp.map(|r| ImageGenerationResponse {
96                image: r.image,
97                response: (),
98            })
99        })
100    }
101
102    fn image_generation_request(
103        &self,
104    ) -> ImageGenerationRequestBuilder<ImageGenerationModelHandle<'_>> {
105        ImageGenerationRequestBuilder::new(ImageGenerationModelHandle {
106            inner: Arc::new(self.clone()),
107        })
108    }
109}
110
111/// An image generation request.
112#[non_exhaustive]
113pub struct ImageGenerationRequest {
114    pub prompt: String,
115    pub width: u32,
116    pub height: u32,
117    pub additional_params: Option<Value>,
118}
119
120/// A builder for `ImageGenerationRequest`.
121/// Can be sent to a model provider.
122#[non_exhaustive]
123pub struct ImageGenerationRequestBuilder<M>
124where
125    M: ImageGenerationModel,
126{
127    model: M,
128    prompt: String,
129    width: u32,
130    height: u32,
131    additional_params: Option<Value>,
132}
133
134impl<M> ImageGenerationRequestBuilder<M>
135where
136    M: ImageGenerationModel,
137{
138    pub fn new(model: M) -> Self {
139        Self {
140            model,
141            prompt: "".to_string(),
142            height: 256,
143            width: 256,
144            additional_params: None,
145        }
146    }
147
148    /// Sets the prompt for the image generation request
149    pub fn prompt(mut self, prompt: &str) -> Self {
150        self.prompt = prompt.to_string();
151        self
152    }
153
154    /// The width of the generated image
155    pub fn width(mut self, width: u32) -> Self {
156        self.width = width;
157        self
158    }
159
160    /// The height of the generated image
161    pub fn height(mut self, height: u32) -> Self {
162        self.height = height;
163        self
164    }
165
166    /// Adds additional parameters to the image generation request.
167    pub fn additional_params(mut self, params: Value) -> Self {
168        self.additional_params = Some(params);
169        self
170    }
171
172    pub fn build(self) -> ImageGenerationRequest {
173        ImageGenerationRequest {
174            prompt: self.prompt,
175            width: self.width,
176            height: self.height,
177            additional_params: self.additional_params,
178        }
179    }
180
181    pub async fn send(self) -> Result<ImageGenerationResponse<M::Response>, ImageGenerationError> {
182        let model = self.model.clone();
183
184        model.image_generation(self.build()).await
185    }
186}