rig/
image_generation.rs

//! Everything related to core image generation abstractions in Rig.
//! Rig can call any provider that supports image generation through the [ImageGenerationModel] trait.
3#[allow(deprecated)]
4use crate::client::image_generation::ImageGenerationModelHandle;
5use crate::http_client;
6use futures::future::BoxFuture;
7use serde_json::Value;
8use std::sync::Arc;
9use thiserror::Error;
10
11#[derive(Debug, Error)]
12pub enum ImageGenerationError {
13    /// Http error (e.g.: connection error, timeout, etc.)
14    #[error("HttpError: {0}")]
15    HttpError(#[from] http_client::Error),
16
17    /// Json error (e.g.: serialization, deserialization)
18    #[error("JsonError: {0}")]
19    JsonError(#[from] serde_json::Error),
20
21    /// Error building the transcription request
22    #[error("RequestError: {0}")]
23    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
24
25    /// Error parsing the transcription response
26    #[error("ResponseError: {0}")]
27    ResponseError(String),
28
29    /// Error returned by the transcription model provider
30    #[error("ProviderError: {0}")]
31    ProviderError(String),
32}
33pub trait ImageGeneration<M>
34where
35    M: ImageGenerationModel,
36{
37    /// Generates a transcription request builder for the given `file`.
38    /// This function is meant to be called by the user to further customize the
39    /// request at transcription time before sending it.
40    ///
41    /// ❗IMPORTANT: The type that implements this trait might have already
42    /// populated fields in the builder (the exact fields depend on the type).
43    /// For fields that have already been set by the model, calling the corresponding
44    /// method on the builder will overwrite the value set by the model.
45    fn image_generation(
46        &self,
47        prompt: &str,
48        size: &(u32, u32),
49    ) -> impl std::future::Future<
50        Output = Result<ImageGenerationRequestBuilder<M>, ImageGenerationError>,
51    > + Send;
52}
53
/// Provider-agnostic result of an image generation call.
///
/// Bundles the generated image bytes with the untouched, provider-specific
/// response value of type `T`.
#[derive(Debug)]
pub struct ImageGenerationResponse<T> {
    /// The generated image, as raw bytes.
    pub image: Vec<u8>,
    /// The raw response returned by the model provider.
    pub response: T,
}
60
61pub trait ImageGenerationModel: Clone + Send + Sync {
62    type Response: Send + Sync;
63
64    type Client;
65
66    fn make(client: &Self::Client, model: impl Into<String>) -> Self;
67
68    fn image_generation(
69        &self,
70        request: ImageGenerationRequest,
71    ) -> impl std::future::Future<
72        Output = Result<ImageGenerationResponse<Self::Response>, ImageGenerationError>,
73    > + Send;
74
75    fn image_generation_request(&self) -> ImageGenerationRequestBuilder<Self> {
76        ImageGenerationRequestBuilder::new(self.clone())
77    }
78}
79
80#[allow(deprecated)]
81#[deprecated(
82    since = "0.25.0",
83    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `ImageGenerationModel` instead."
84)]
85pub trait ImageGenerationModelDyn: Send + Sync {
86    fn image_generation(
87        &self,
88        request: ImageGenerationRequest,
89    ) -> BoxFuture<'_, Result<ImageGenerationResponse<()>, ImageGenerationError>>;
90
91    fn image_generation_request(
92        &self,
93    ) -> ImageGenerationRequestBuilder<ImageGenerationModelHandle<'_>>;
94}
95
96#[allow(deprecated)]
97impl<T> ImageGenerationModelDyn for T
98where
99    T: ImageGenerationModel,
100{
101    fn image_generation(
102        &self,
103        request: ImageGenerationRequest,
104    ) -> BoxFuture<'_, Result<ImageGenerationResponse<()>, ImageGenerationError>> {
105        Box::pin(async {
106            let resp = self.image_generation(request).await;
107            resp.map(|r| ImageGenerationResponse {
108                image: r.image,
109                response: (),
110            })
111        })
112    }
113
114    fn image_generation_request(
115        &self,
116    ) -> ImageGenerationRequestBuilder<ImageGenerationModelHandle<'_>> {
117        ImageGenerationRequestBuilder::new(ImageGenerationModelHandle {
118            inner: Arc::new(self.clone()),
119        })
120    }
121}
122
123/// An image generation request.
124#[non_exhaustive]
125pub struct ImageGenerationRequest {
126    pub prompt: String,
127    pub width: u32,
128    pub height: u32,
129    pub additional_params: Option<Value>,
130}
131
132/// A builder for `ImageGenerationRequest`.
133/// Can be sent to a model provider.
134#[non_exhaustive]
135pub struct ImageGenerationRequestBuilder<M>
136where
137    M: ImageGenerationModel,
138{
139    model: M,
140    prompt: String,
141    width: u32,
142    height: u32,
143    additional_params: Option<Value>,
144}
145
146impl<M> ImageGenerationRequestBuilder<M>
147where
148    M: ImageGenerationModel,
149{
150    pub fn new(model: M) -> Self {
151        Self {
152            model,
153            prompt: "".to_string(),
154            height: 256,
155            width: 256,
156            additional_params: None,
157        }
158    }
159
160    /// Sets the prompt for the image generation request
161    pub fn prompt(mut self, prompt: &str) -> Self {
162        self.prompt = prompt.to_string();
163        self
164    }
165
166    /// The width of the generated image
167    pub fn width(mut self, width: u32) -> Self {
168        self.width = width;
169        self
170    }
171
172    /// The height of the generated image
173    pub fn height(mut self, height: u32) -> Self {
174        self.height = height;
175        self
176    }
177
178    /// Adds additional parameters to the image generation request.
179    pub fn additional_params(mut self, params: Value) -> Self {
180        self.additional_params = Some(params);
181        self
182    }
183
184    pub fn build(self) -> ImageGenerationRequest {
185        ImageGenerationRequest {
186            prompt: self.prompt,
187            width: self.width,
188            height: self.height,
189            additional_params: self.additional_params,
190        }
191    }
192
193    pub async fn send(self) -> Result<ImageGenerationResponse<M::Response>, ImageGenerationError> {
194        let model = self.model.clone();
195
196        model.image_generation(self.build()).await
197    }
198}