ai_sdk_provider/image_model/trait_def.rs

1use super::*;
2use crate::{Result, SharedHeaders, SharedProviderOptions};
3use async_trait::async_trait;
4use std::future::{Future, IntoFuture};
5use std::pin::Pin;
6
7/// Builder for image generation requests.
8pub struct ImageGenerateBuilder<'a, M: ImageModel + ?Sized> {
9    model: &'a M,
10    options: ImageGenerateOptions,
11}
12
13impl<'a, M: ImageModel + ?Sized> ImageGenerateBuilder<'a, M> {
14    pub fn new(model: &'a M, prompt: impl Into<String>) -> Self {
15        Self {
16            model,
17            options: ImageGenerateOptions {
18                prompt: prompt.into(),
19                n: None,
20                size: None,
21                aspect_ratio: None,
22                seed: None,
23                provider_options: None,
24                headers: None,
25            },
26        }
27    }
28
29    /// Number of images to generate
30    pub fn n(mut self, n: usize) -> Self {
31        self.options.n = Some(n);
32        self
33    }
34
35    /// Size of the images to generate (e.g. "1024x1024")
36    pub fn size(mut self, size: impl Into<String>) -> Self {
37        self.options.size = Some(size.into());
38        self
39    }
40
41    /// Aspect ratio of the images to generate (e.g. "16:9")
42    pub fn aspect_ratio(mut self, aspect_ratio: impl Into<String>) -> Self {
43        self.options.aspect_ratio = Some(aspect_ratio.into());
44        self
45    }
46
47    /// Seed for the image generation
48    pub fn seed(mut self, seed: i64) -> Self {
49        self.options.seed = Some(seed);
50        self
51    }
52
53    /// Provider-specific options
54    pub fn provider_options(mut self, provider_options: SharedProviderOptions) -> Self {
55        self.options.provider_options = Some(provider_options);
56        self
57    }
58
59    /// Custom headers
60    pub fn headers(mut self, headers: SharedHeaders) -> Self {
61        self.options.headers = Some(headers);
62        self
63    }
64
65    /// Send the image generation request
66    pub async fn send(self) -> Result<ImageGenerateResponse> {
67        self.model.do_generate(self.options).await
68    }
69}
70
71impl<'a, M: ImageModel + ?Sized> IntoFuture for ImageGenerateBuilder<'a, M> {
72    type Output = Result<ImageGenerateResponse>;
73    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
74
75    fn into_future(self) -> Self::IntoFuture {
76        Box::pin(async move { self.model.do_generate(self.options).await })
77    }
78}
79
80/// Image generation model specification version 3.
81///
82/// The image model must specify which image model interface
83/// version it implements. This will allow us to evolve the image
84/// model interface and retain backwards compatibility.
85#[async_trait]
86pub trait ImageModel: Send + Sync {
87    /// Specification version (always "v3")
88    fn specification_version(&self) -> &str {
89        "v3"
90    }
91
92    /// Name of the provider for logging purposes (e.g., "openai")
93    fn provider(&self) -> &str;
94
95    /// Provider-specific model ID for logging purposes (e.g., "dall-e-3")
96    fn model_id(&self) -> &str;
97
98    /// Limit of how many images can be generated in a single API call.
99    ///
100    /// Returns None for models that do not have a limit.
101    async fn max_images_per_call(&self) -> Option<usize>;
102
103    /// Creates a builder for an image generation request.
104    fn generate(&self, prompt: impl Into<String>) -> ImageGenerateBuilder<'_, Self>
105    where
106        Self: Sized,
107    {
108        ImageGenerateBuilder::new(self, prompt)
109    }
110
111    /// Generates an array of images.
112    ///
113    /// Naming: "do" prefix to prevent accidental direct usage of the method by the user.
114    async fn do_generate(&self, options: ImageGenerateOptions) -> Result<ImageGenerateResponse>;
115}
116
117#[async_trait]
118impl<T: ImageModel + ?Sized> ImageModel for Box<T> {
119    fn specification_version(&self) -> &str {
120        (**self).specification_version()
121    }
122
123    fn provider(&self) -> &str {
124        (**self).provider()
125    }
126
127    fn model_id(&self) -> &str {
128        (**self).model_id()
129    }
130
131    async fn max_images_per_call(&self) -> Option<usize> {
132        (**self).max_images_per_call().await
133    }
134
135    async fn do_generate(&self, options: ImageGenerateOptions) -> Result<ImageGenerateResponse> {
136        (**self).do_generate(options).await
137    }
138}
139
140#[async_trait]
141impl<T: ImageModel + ?Sized> ImageModel for std::sync::Arc<T> {
142    fn specification_version(&self) -> &str {
143        (**self).specification_version()
144    }
145
146    fn provider(&self) -> &str {
147        (**self).provider()
148    }
149
150    fn model_id(&self) -> &str {
151        (**self).model_id()
152    }
153
154    async fn max_images_per_call(&self) -> Option<usize> {
155        (**self).max_images_per_call().await
156    }
157
158    async fn do_generate(&self, options: ImageGenerateOptions) -> Result<ImageGenerateResponse> {
159        (**self).do_generate(options).await
160    }
161}
162
163#[async_trait]
164impl<T: ImageModel + ?Sized> ImageModel for &T {
165    fn specification_version(&self) -> &str {
166        (**self).specification_version()
167    }
168
169    fn provider(&self) -> &str {
170        (**self).provider()
171    }
172
173    fn model_id(&self) -> &str {
174        (**self).model_id()
175    }
176
177    async fn max_images_per_call(&self) -> Option<usize> {
178        (**self).max_images_per_call().await
179    }
180
181    async fn do_generate(&self, options: ImageGenerateOptions) -> Result<ImageGenerateResponse> {
182        (**self).do_generate(options).await
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory model that echoes the request back.
    ///
    /// It returns `options.n` images (one when `n` is unset), each carrying the prompt as its
    /// base64 payload. Tests can therefore check that builder options actually reach
    /// `do_generate`.
    struct DummyImageModel;

    #[async_trait]
    impl ImageModel for DummyImageModel {
        fn provider(&self) -> &str {
            "test"
        }

        fn model_id(&self) -> &str {
            "dummy"
        }

        async fn max_images_per_call(&self) -> Option<usize> {
            Some(1)
        }

        async fn do_generate(
            &self,
            options: ImageGenerateOptions,
        ) -> Result<ImageGenerateResponse> {
            let images = (0..options.n.unwrap_or(1))
                .map(|_| ImageData::Base64(options.prompt.clone()))
                .collect();
            Ok(ImageGenerateResponse {
                images,
                warnings: vec![],
                provider_metadata: None,
                response: ResponseInfo {
                    timestamp: std::time::SystemTime::now(),
                    model_id: "dummy".to_string(),
                    headers: None,
                },
            })
        }
    }

    /// Asserts that `images` holds exactly `count` base64 images whose payload is `prompt`.
    fn assert_echoed(images: &[ImageData], prompt: &str, count: usize) {
        assert_eq!(images.len(), count);
        assert!(images
            .iter()
            .all(|image| matches!(image, ImageData::Base64(data) if data == prompt)));
    }

    /// Runs the metadata and generation checks against any `ImageModel` value.
    ///
    /// Taking `M` by value means wrappers are exercised through their own trait impls.
    async fn check_model<M: ImageModel>(model: M, prompt: &str) {
        assert_eq!(model.provider(), "test");
        assert_eq!(model.model_id(), "dummy");
        assert_eq!(model.specification_version(), "v3");
        assert_eq!(model.max_images_per_call().await, Some(1));

        let Ok(res) = model.generate(prompt).n(3).await else {
            panic!("generation should succeed");
        };
        assert_echoed(&res.images, prompt, 3);
    }

    #[tokio::test]
    async fn test_image_model_trait() {
        let model = DummyImageModel;
        assert_eq!(model.provider(), "test");
        assert_eq!(model.model_id(), "dummy");
        assert_eq!(model.specification_version(), "v3");
        assert_eq!(model.max_images_per_call().await, Some(1));

        // Awaiting the builder directly goes through its `IntoFuture` impl.
        let Ok(res) = model.generate("test prompt").n(2).await else {
            panic!("awaiting the builder should succeed");
        };
        assert_echoed(&res.images, "test prompt", 2);
        assert_eq!(res.response.model_id, "dummy");
    }

    #[tokio::test]
    async fn test_builder_send() {
        let model = DummyImageModel;
        let Ok(res) = model
            .generate("a cat")
            .size("1024x1024")
            .aspect_ratio("16:9")
            .seed(42)
            .send()
            .await
        else {
            panic!("send should succeed");
        };
        // `n` was never set, so the dummy falls back to a single image.
        assert_echoed(&res.images, "a cat", 1);
    }

    #[tokio::test]
    async fn test_wrappers_forward_to_inner_model() {
        let boxed: Box<dyn ImageModel> = Box::new(DummyImageModel);
        check_model(boxed, "boxed").await;
        check_model(std::sync::Arc::new(DummyImageModel), "shared").await;
        check_model(&DummyImageModel, "borrowed").await;
    }
}