// ai_sdk_provider/image_model/
trait_def.rs1use super::*;
2use crate::{Result, SharedHeaders, SharedProviderOptions};
3use async_trait::async_trait;
4use std::future::{Future, IntoFuture};
5use std::pin::Pin;
6
7pub struct ImageGenerateBuilder<'a, M: ImageModel + ?Sized> {
9 model: &'a M,
10 options: ImageGenerateOptions,
11}
12
13impl<'a, M: ImageModel + ?Sized> ImageGenerateBuilder<'a, M> {
14 pub fn new(model: &'a M, prompt: impl Into<String>) -> Self {
15 Self {
16 model,
17 options: ImageGenerateOptions {
18 prompt: prompt.into(),
19 n: None,
20 size: None,
21 aspect_ratio: None,
22 seed: None,
23 provider_options: None,
24 headers: None,
25 },
26 }
27 }
28
29 pub fn n(mut self, n: usize) -> Self {
31 self.options.n = Some(n);
32 self
33 }
34
35 pub fn size(mut self, size: impl Into<String>) -> Self {
37 self.options.size = Some(size.into());
38 self
39 }
40
41 pub fn aspect_ratio(mut self, aspect_ratio: impl Into<String>) -> Self {
43 self.options.aspect_ratio = Some(aspect_ratio.into());
44 self
45 }
46
47 pub fn seed(mut self, seed: i64) -> Self {
49 self.options.seed = Some(seed);
50 self
51 }
52
53 pub fn provider_options(mut self, provider_options: SharedProviderOptions) -> Self {
55 self.options.provider_options = Some(provider_options);
56 self
57 }
58
59 pub fn headers(mut self, headers: SharedHeaders) -> Self {
61 self.options.headers = Some(headers);
62 self
63 }
64
65 pub async fn send(self) -> Result<ImageGenerateResponse> {
67 self.model.do_generate(self.options).await
68 }
69}
70
71impl<'a, M: ImageModel + ?Sized> IntoFuture for ImageGenerateBuilder<'a, M> {
72 type Output = Result<ImageGenerateResponse>;
73 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
74
75 fn into_future(self) -> Self::IntoFuture {
76 Box::pin(async move { self.model.do_generate(self.options).await })
77 }
78}
79
80#[async_trait]
86pub trait ImageModel: Send + Sync {
87 fn specification_version(&self) -> &str {
89 "v3"
90 }
91
92 fn provider(&self) -> &str;
94
95 fn model_id(&self) -> &str;
97
98 async fn max_images_per_call(&self) -> Option<usize>;
102
103 fn generate(&self, prompt: impl Into<String>) -> ImageGenerateBuilder<'_, Self>
105 where
106 Self: Sized,
107 {
108 ImageGenerateBuilder::new(self, prompt)
109 }
110
111 async fn do_generate(&self, options: ImageGenerateOptions) -> Result<ImageGenerateResponse>;
115}
116
117#[async_trait]
118impl<T: ImageModel + ?Sized> ImageModel for Box<T> {
119 fn specification_version(&self) -> &str {
120 (**self).specification_version()
121 }
122
123 fn provider(&self) -> &str {
124 (**self).provider()
125 }
126
127 fn model_id(&self) -> &str {
128 (**self).model_id()
129 }
130
131 async fn max_images_per_call(&self) -> Option<usize> {
132 (**self).max_images_per_call().await
133 }
134
135 async fn do_generate(&self, options: ImageGenerateOptions) -> Result<ImageGenerateResponse> {
136 (**self).do_generate(options).await
137 }
138}
139
140#[async_trait]
141impl<T: ImageModel + ?Sized> ImageModel for std::sync::Arc<T> {
142 fn specification_version(&self) -> &str {
143 (**self).specification_version()
144 }
145
146 fn provider(&self) -> &str {
147 (**self).provider()
148 }
149
150 fn model_id(&self) -> &str {
151 (**self).model_id()
152 }
153
154 async fn max_images_per_call(&self) -> Option<usize> {
155 (**self).max_images_per_call().await
156 }
157
158 async fn do_generate(&self, options: ImageGenerateOptions) -> Result<ImageGenerateResponse> {
159 (**self).do_generate(options).await
160 }
161}
162
163#[async_trait]
164impl<T: ImageModel + ?Sized> ImageModel for &T {
165 fn specification_version(&self) -> &str {
166 (**self).specification_version()
167 }
168
169 fn provider(&self) -> &str {
170 (**self).provider()
171 }
172
173 fn model_id(&self) -> &str {
174 (**self).model_id()
175 }
176
177 async fn max_images_per_call(&self) -> Option<usize> {
178 (**self).max_images_per_call().await
179 }
180
181 async fn do_generate(&self, options: ImageGenerateOptions) -> Result<ImageGenerateResponse> {
182 (**self).do_generate(options).await
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory model for tests: returns `options.n` images (1 when unset),
    /// each carrying the prompt as its base64 payload, so tests can observe
    /// which options actually reached `do_generate`. It deliberately ignores
    /// its own `max_images_per_call` limit.
    struct DummyImageModel;

    #[async_trait]
    impl ImageModel for DummyImageModel {
        fn provider(&self) -> &str {
            "test"
        }

        fn model_id(&self) -> &str {
            "dummy"
        }

        async fn max_images_per_call(&self) -> Option<usize> {
            Some(1)
        }

        async fn do_generate(
            &self,
            options: ImageGenerateOptions,
        ) -> Result<ImageGenerateResponse> {
            let count = options.n.unwrap_or(1);
            Ok(ImageGenerateResponse {
                images: (0..count)
                    .map(|_| ImageData::Base64(options.prompt.clone()))
                    .collect(),
                warnings: vec![],
                provider_metadata: None,
                response: ResponseInfo {
                    timestamp: std::time::SystemTime::now(),
                    model_id: "dummy".to_string(),
                    headers: None,
                },
            })
        }
    }

    /// Unwraps a generation result without requiring the error type to be `Debug`.
    fn expect_ok(res: Result<ImageGenerateResponse>) -> ImageGenerateResponse {
        match res {
            Ok(response) => response,
            Err(_) => panic!("image generation unexpectedly failed"),
        }
    }

    /// Counts the base64 images whose payload equals `prompt`.
    fn count_with_payload(response: &ImageGenerateResponse, prompt: &str) -> usize {
        response
            .images
            .iter()
            .filter(|img| matches!(img, ImageData::Base64(data) if data == prompt))
            .count()
    }

    /// Drives a model purely through the generic `ImageModel` interface,
    /// returning (provider, model_id, spec version, max per call, images produced).
    async fn run_generic<M: ImageModel>(
        model: M,
        n: usize,
    ) -> (String, String, String, Option<usize>, usize) {
        let max = model.max_images_per_call().await;
        let response = expect_ok(model.generate("generic").n(n).await);
        (
            model.provider().to_string(),
            model.model_id().to_string(),
            model.specification_version().to_string(),
            max,
            response.images.len(),
        )
    }

    #[tokio::test]
    async fn test_image_model_trait() {
        let model = DummyImageModel;
        assert_eq!(model.provider(), "test");
        assert_eq!(model.model_id(), "dummy");
        assert_eq!(model.specification_version(), "v3");
        assert_eq!(model.max_images_per_call().await, Some(1));

        let res = expect_ok(model.generate("test prompt").n(1).await);
        assert_eq!(res.images.len(), 1);
        assert_eq!(count_with_payload(&res, "test prompt"), 1);
        assert_eq!(res.response.model_id, "dummy");
    }

    #[tokio::test]
    async fn test_builder_forwards_options() {
        let model = DummyImageModel;

        // With `n` unset, the model falls back to its own default.
        let default = expect_ok(model.generate("plain").await);
        assert_eq!(default.images.len(), 1);
        assert_eq!(count_with_payload(&default, "plain"), 1);

        // `send()` and a direct `.await` must both deliver the prompt and `n`.
        let sent = expect_ok(model.generate("three").n(3).send().await);
        assert_eq!(count_with_payload(&sent, "three"), 3);

        let awaited = expect_ok(model.generate("three").n(3).await);
        assert_eq!(count_with_payload(&awaited, "three"), 3);
    }

    #[tokio::test]
    async fn test_forwarding_impls() {
        let expected = |n: usize| {
            (
                "test".to_string(),
                "dummy".to_string(),
                "v3".to_string(),
                Some(1),
                n,
            )
        };

        let model = DummyImageModel;
        // `&T` blanket impl.
        assert_eq!(run_generic(&model, 2).await, expected(2));
        // `Box<T>` blanket impl with a concrete model.
        assert_eq!(run_generic(Box::new(DummyImageModel), 3).await, expected(3));
        // `Box<T>` blanket impl through dynamic dispatch.
        let dynamic: Box<dyn ImageModel> = Box::new(DummyImageModel);
        assert_eq!(run_generic(dynamic, 4).await, expected(4));
        // `Arc<T>` blanket impl.
        assert_eq!(
            run_generic(std::sync::Arc::new(DummyImageModel), 5).await,
            expected(5)
        );
    }
}