Skip to main content

rig/
image_generation.rs

//! Everything related to the core image generation abstractions in Rig.
//! Rig can call any provider that supports image generation through the
//! [ImageGenerationModel] trait.
3use crate::http_client;
4use crate::markers::{Missing, Provided};
5use serde_json::Value;
6use thiserror::Error;
7
8#[derive(Debug, Error)]
9pub enum ImageGenerationError {
10    /// Http error (e.g.: connection error, timeout, etc.)
11    #[error("HttpError: {0}")]
12    HttpError(#[from] http_client::Error),
13
14    /// Json error (e.g.: serialization, deserialization)
15    #[error("JsonError: {0}")]
16    JsonError(#[from] serde_json::Error),
17
18    /// Error building the transcription request
19    #[error("RequestError: {0}")]
20    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
21
22    /// Error parsing the transcription response
23    #[error("ResponseError: {0}")]
24    ResponseError(String),
25
26    /// Error returned by the transcription model provider
27    #[error("ProviderError: {0}")]
28    ProviderError(String),
29}
30pub trait ImageGeneration<M>
31where
32    M: ImageGenerationModel,
33{
34    /// Generates a transcription request builder for the given `file`.
35    /// This function is meant to be called by the user to further customize the
36    /// request at transcription time before sending it.
37    ///
38    /// ❗IMPORTANT: The type that implements this trait might have already
39    /// populated fields in the builder (the exact fields depend on the type).
40    /// For fields that have already been set by the model, calling the corresponding
41    /// method on the builder will overwrite the value set by the model.
42    fn image_generation(
43        &self,
44        prompt: &str,
45        size: &(u32, u32),
46    ) -> impl std::future::Future<
47        Output = Result<ImageGenerationRequestBuilder<M, Provided<String>>, ImageGenerationError>,
48    > + Send;
49}
50
/// The result of a successful image generation call.
///
/// Bundles the generated image bytes with the provider's raw response value, so
/// callers keep access to whatever the provider returned.
#[derive(Debug)]
pub struct ImageGenerationResponse<T> {
    /// The generated image as raw bytes (the image format is not specified here).
    pub image: Vec<u8>,
    /// The raw, provider-specific response.
    pub response: T,
}
57
58pub trait ImageGenerationModel: Clone + Send + Sync {
59    type Response: Send + Sync;
60
61    type Client;
62
63    fn make(client: &Self::Client, model: impl Into<String>) -> Self;
64
65    fn image_generation(
66        &self,
67        request: ImageGenerationRequest,
68    ) -> impl std::future::Future<
69        Output = Result<ImageGenerationResponse<Self::Response>, ImageGenerationError>,
70    > + Send;
71
72    fn image_generation_request(&self) -> ImageGenerationRequestBuilder<Self, Missing> {
73        ImageGenerationRequestBuilder::new(self.clone())
74    }
75}
76/// An image generation request.
77#[non_exhaustive]
78pub struct ImageGenerationRequest {
79    pub prompt: String,
80    pub width: u32,
81    pub height: u32,
82    pub additional_params: Option<Value>,
83}
84
85/// A builder for `ImageGenerationRequest`.
86/// Can be sent to a model provider.
87#[non_exhaustive]
88pub struct ImageGenerationRequestBuilder<M, P = Missing>
89where
90    M: ImageGenerationModel,
91{
92    model: M,
93    prompt: P,
94    width: u32,
95    height: u32,
96    additional_params: Option<Value>,
97}
98
99impl<M> ImageGenerationRequestBuilder<M, Missing>
100where
101    M: ImageGenerationModel,
102{
103    pub fn new(model: M) -> Self {
104        Self {
105            model,
106            prompt: Missing,
107            height: 256,
108            width: 256,
109            additional_params: None,
110        }
111    }
112}
113
114impl<M, P> ImageGenerationRequestBuilder<M, P>
115where
116    M: ImageGenerationModel,
117{
118    /// Sets the prompt for the image generation request
119    pub fn prompt(self, prompt: &str) -> ImageGenerationRequestBuilder<M, Provided<String>> {
120        ImageGenerationRequestBuilder {
121            model: self.model,
122            prompt: Provided(prompt.to_string()),
123            width: self.width,
124            height: self.height,
125            additional_params: self.additional_params,
126        }
127    }
128
129    /// The width of the generated image
130    pub fn width(mut self, width: u32) -> Self {
131        self.width = width;
132        self
133    }
134
135    /// The height of the generated image
136    pub fn height(mut self, height: u32) -> Self {
137        self.height = height;
138        self
139    }
140
141    /// Adds additional parameters to the image generation request.
142    pub fn additional_params(mut self, params: Value) -> Self {
143        self.additional_params = Some(params);
144        self
145    }
146}
147
148impl<M> ImageGenerationRequestBuilder<M, Provided<String>>
149where
150    M: ImageGenerationModel,
151{
152    pub fn build(self) -> ImageGenerationRequest {
153        ImageGenerationRequest {
154            prompt: self.prompt.0,
155            width: self.width,
156            height: self.height,
157            additional_params: self.additional_params,
158        }
159    }
160
161    pub async fn send(self) -> Result<ImageGenerationResponse<M::Response>, ImageGenerationError> {
162        let model = self.model.clone();
163
164        model.image_generation(self.build()).await
165    }
166}