// llmsdk_provider/middleware/image_model.rs
use std::sync::Arc;

use async_trait::async_trait;

use crate::error::Result;
use crate::image_model::{ImageModel, ImageOptions, ImageResult};

16#[async_trait]
18pub trait ImageModelMiddleware: Send + Sync + std::fmt::Debug {
19 fn override_provider(&self, _inner: &dyn ImageModel) -> Option<String> {
21 None
22 }
23
24 fn override_model_id(&self, _inner: &dyn ImageModel) -> Option<String> {
26 None
27 }
28
29 async fn override_max_images_per_call(&self, _inner: &dyn ImageModel) -> Option<Option<u32>> {
31 None
32 }
33
34 async fn transform_params(
41 &self,
42 params: ImageOptions,
43 _inner: &dyn ImageModel,
44 ) -> Result<ImageOptions> {
45 Ok(params)
46 }
47
48 async fn wrap_generate(
57 &self,
58 next: &dyn ImageModel,
59 params: ImageOptions,
60 ) -> Result<ImageResult> {
61 next.do_generate(params).await
62 }
63}
64
65pub fn wrap_image_model<I>(model: Arc<dyn ImageModel>, middleware: I) -> Arc<dyn ImageModel>
71where
72 I: IntoIterator<Item = Arc<dyn ImageModelMiddleware>>,
73{
74 let mut layers: Vec<Arc<dyn ImageModelMiddleware>> = middleware.into_iter().collect();
75 layers.reverse();
76 layers
77 .into_iter()
78 .fold(model, |inner, mw| Arc::new(Wrapped::new(inner, mw)))
79}
80
81struct Wrapped {
83 inner: Arc<dyn ImageModel>,
84 middleware: Arc<dyn ImageModelMiddleware>,
85 provider: String,
86 model_id: String,
87}
88
89impl Wrapped {
90 fn new(inner: Arc<dyn ImageModel>, middleware: Arc<dyn ImageModelMiddleware>) -> Self {
91 let provider = middleware
92 .override_provider(inner.as_ref())
93 .unwrap_or_else(|| inner.provider().to_owned());
94 let model_id = middleware
95 .override_model_id(inner.as_ref())
96 .unwrap_or_else(|| inner.model_id().to_owned());
97 Self {
98 inner,
99 middleware,
100 provider,
101 model_id,
102 }
103 }
104}
105
106impl std::fmt::Debug for Wrapped {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 f.debug_struct("Wrapped")
109 .field("provider", &self.provider)
110 .field("model_id", &self.model_id)
111 .field("middleware", &self.middleware)
112 .field("inner", &self.inner)
113 .finish()
114 }
115}
116
117#[async_trait]
118impl ImageModel for Wrapped {
119 fn provider(&self) -> &str {
120 &self.provider
121 }
122
123 fn model_id(&self) -> &str {
124 &self.model_id
125 }
126
127 async fn max_images_per_call(&self) -> Option<u32> {
128 if let Some(custom) = self
129 .middleware
130 .override_max_images_per_call(self.inner.as_ref())
131 .await
132 {
133 return custom;
134 }
135 self.inner.max_images_per_call().await
136 }
137
138 async fn do_generate(&self, options: ImageOptions) -> Result<ImageResult> {
139 let transformed = self
140 .middleware
141 .transform_params(options, self.inner.as_ref())
142 .await?;
143 self.middleware
144 .wrap_generate(self.inner.as_ref(), transformed)
145 .await
146 }
147}
148
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// Base-model stub that counts `do_generate` calls and records the last prompt.
    #[derive(Debug, Default)]
    struct MockImage {
        provider: String,
        model_id: String,
        calls: AtomicUsize,
        last_prompt: Mutex<String>,
    }

    impl MockImage {
        fn new(provider: &str, model_id: &str) -> Self {
            Self {
                provider: provider.to_owned(),
                model_id: model_id.to_owned(),
                calls: AtomicUsize::new(0),
                last_prompt: Mutex::new(String::new()),
            }
        }
    }

    #[async_trait]
    impl ImageModel for MockImage {
        fn provider(&self) -> &str {
            &self.provider
        }
        fn model_id(&self) -> &str {
            &self.model_id
        }
        async fn do_generate(&self, options: ImageOptions) -> Result<ImageResult> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            *self.last_prompt.lock().expect("mutex") = options.prompt;
            Ok(ImageResult {
                images: vec![],
                warnings: vec![],
                usage: None,
                provider_metadata: None,
                request: None,
                response: None,
            })
        }
    }

    /// Overrides the model id and prefixes every prompt.
    #[derive(Debug)]
    struct OverrideAndPrefix;

    #[async_trait]
    impl ImageModelMiddleware for OverrideAndPrefix {
        fn override_model_id(&self, _: &dyn ImageModel) -> Option<String> {
            Some("wrapped-model".to_owned())
        }

        async fn transform_params(
            &self,
            mut params: ImageOptions,
            _inner: &dyn ImageModel,
        ) -> Result<ImageOptions> {
            params.prompt = format!("PREFIX: {}", params.prompt);
            Ok(params)
        }
    }

    /// Appends a fixed tag to the prompt, so the order in which layers ran is visible.
    #[derive(Debug)]
    struct AppendTag(&'static str);

    #[async_trait]
    impl ImageModelMiddleware for AppendTag {
        async fn transform_params(
            &self,
            mut params: ImageOptions,
            _inner: &dyn ImageModel,
        ) -> Result<ImageOptions> {
            params.prompt.push_str(self.0);
            Ok(params)
        }
    }

    /// Forces a specific `max_images_per_call` answer, including `None`.
    #[derive(Debug)]
    struct FixedMaxImages(Option<u32>);

    #[async_trait]
    impl ImageModelMiddleware for FixedMaxImages {
        async fn override_max_images_per_call(
            &self,
            _inner: &dyn ImageModel,
        ) -> Option<Option<u32>> {
            Some(self.0)
        }
    }

    /// Builds a request that sets only the prompt.
    fn request(prompt: &str) -> ImageOptions {
        ImageOptions {
            prompt: prompt.into(),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn empty_middleware_unchanged() {
        let model = Arc::new(MockImage::new("p", "m"));
        let wrapped: Arc<dyn ImageModel> = wrap_image_model(Arc::clone(&model) as _, Vec::new());
        assert_eq!(wrapped.provider(), "p");
        assert_eq!(wrapped.model_id(), "m");

        wrapped.do_generate(request("a cat")).await.expect("generate");

        assert_eq!(model.calls.load(Ordering::SeqCst), 1);
        assert_eq!(*model.last_prompt.lock().expect("mutex"), "a cat");
    }

    #[tokio::test]
    async fn overrides_and_prefix_run() {
        let model = Arc::new(MockImage::new("p", "m"));
        let wrapped = wrap_image_model(
            Arc::clone(&model) as _,
            [Arc::new(OverrideAndPrefix) as Arc<dyn ImageModelMiddleware>],
        );

        // Only the model id is overridden; the provider must pass through.
        assert_eq!(wrapped.model_id(), "wrapped-model");
        assert_eq!(wrapped.provider(), "p");

        wrapped.do_generate(request("a cat")).await.expect("generate");

        assert_eq!(model.calls.load(Ordering::SeqCst), 1);
        assert_eq!(*model.last_prompt.lock().expect("mutex"), "PREFIX: a cat");
    }

    #[tokio::test]
    async fn first_middleware_is_outermost() {
        let model = Arc::new(MockImage::new("p", "m"));
        let wrapped = wrap_image_model(
            Arc::clone(&model) as _,
            [
                Arc::new(AppendTag("-a")) as Arc<dyn ImageModelMiddleware>,
                Arc::new(AppendTag("-b")) as Arc<dyn ImageModelMiddleware>,
            ],
        );

        wrapped.do_generate(request("cat")).await.expect("generate");

        // The first listed layer transforms first, so its tag lands first.
        assert_eq!(model.calls.load(Ordering::SeqCst), 1);
        assert_eq!(*model.last_prompt.lock().expect("mutex"), "cat-a-b");
    }

    #[tokio::test]
    async fn max_images_per_call_override_and_passthrough() {
        let model = Arc::new(MockImage::new("p", "m"));
        let expected = model.max_images_per_call().await;

        let passthrough = wrap_image_model(
            Arc::clone(&model) as _,
            [Arc::new(OverrideAndPrefix) as Arc<dyn ImageModelMiddleware>],
        );
        assert_eq!(passthrough.max_images_per_call().await, expected);

        let capped = wrap_image_model(
            Arc::clone(&model) as _,
            [Arc::new(FixedMaxImages(Some(4))) as Arc<dyn ImageModelMiddleware>],
        );
        assert_eq!(capped.max_images_per_call().await, Some(4));

        let unset = wrap_image_model(
            Arc::clone(&model) as _,
            [Arc::new(FixedMaxImages(None)) as Arc<dyn ImageModelMiddleware>],
        );
        assert_eq!(unset.max_images_per_call().await, None);
    }
}