Skip to main content

llmsdk_provider/middleware/
image_model.rs

1//! `ImageModelMiddleware` trait and the `wrap_image_model` combinator.
2//!
3//! Mirrors `image-model-v4-middleware.ts` (trait surface) and
4//! `wrap-image-model.ts` (combinator). Structurally identical to
5//! [`super::language_model`] / [`super::embedding_model`] — only the callable
6//! surface differs (`do_generate(ImageOptions) -> ImageResult`).
7// Rust guideline compliant 2026-02-21
8
9use std::sync::Arc;
10
11use async_trait::async_trait;
12
13use crate::error::Result;
14use crate::image_model::{ImageModel, ImageOptions, ImageResult};
15
16/// Contract for middleware that decorates an [`ImageModel`].
17#[async_trait]
18pub trait ImageModelMiddleware: Send + Sync + std::fmt::Debug {
19    /// Override the provider id exposed by the wrapped model.
20    fn override_provider(&self, _inner: &dyn ImageModel) -> Option<String> {
21        None
22    }
23
24    /// Override the model id exposed by the wrapped model.
25    fn override_model_id(&self, _inner: &dyn ImageModel) -> Option<String> {
26        None
27    }
28
29    /// Override [`ImageModel::max_images_per_call`].
30    async fn override_max_images_per_call(&self, _inner: &dyn ImageModel) -> Option<Option<u32>> {
31        None
32    }
33
34    /// Transform the image options before they reach the inner model.
35    ///
36    /// # Errors
37    ///
38    /// Return a [`crate::ProviderError`] to fail the call without invoking
39    /// the model.
40    async fn transform_params(
41        &self,
42        params: ImageOptions,
43        _inner: &dyn ImageModel,
44    ) -> Result<ImageOptions> {
45        Ok(params)
46    }
47
48    /// Wrap an image-generation call.
49    ///
50    /// Default: forwards to `next.do_generate(params)`.
51    ///
52    /// # Errors
53    ///
54    /// Returns whatever error `next` returns, or a middleware-introduced
55    /// failure.
56    async fn wrap_generate(
57        &self,
58        next: &dyn ImageModel,
59        params: ImageOptions,
60    ) -> Result<ImageResult> {
61        next.do_generate(params).await
62    }
63}
64
65/// Compose an image model with one or more middlewares.
66///
67/// Outer-to-inner ordering on the way in, inner-to-outer on the way out
68/// (list head = outermost). Empty middleware iterator returns the model
69/// unchanged.
70pub fn wrap_image_model<I>(model: Arc<dyn ImageModel>, middleware: I) -> Arc<dyn ImageModel>
71where
72    I: IntoIterator<Item = Arc<dyn ImageModelMiddleware>>,
73{
74    let mut layers: Vec<Arc<dyn ImageModelMiddleware>> = middleware.into_iter().collect();
75    layers.reverse();
76    layers
77        .into_iter()
78        .fold(model, |inner, mw| Arc::new(Wrapped::new(inner, mw)))
79}
80
81/// Internal one-layer wrapper.
82struct Wrapped {
83    inner: Arc<dyn ImageModel>,
84    middleware: Arc<dyn ImageModelMiddleware>,
85    provider: String,
86    model_id: String,
87}
88
89impl Wrapped {
90    fn new(inner: Arc<dyn ImageModel>, middleware: Arc<dyn ImageModelMiddleware>) -> Self {
91        let provider = middleware
92            .override_provider(inner.as_ref())
93            .unwrap_or_else(|| inner.provider().to_owned());
94        let model_id = middleware
95            .override_model_id(inner.as_ref())
96            .unwrap_or_else(|| inner.model_id().to_owned());
97        Self {
98            inner,
99            middleware,
100            provider,
101            model_id,
102        }
103    }
104}
105
106impl std::fmt::Debug for Wrapped {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        f.debug_struct("Wrapped")
109            .field("provider", &self.provider)
110            .field("model_id", &self.model_id)
111            .field("middleware", &self.middleware)
112            .field("inner", &self.inner)
113            .finish()
114    }
115}
116
117#[async_trait]
118impl ImageModel for Wrapped {
119    fn provider(&self) -> &str {
120        &self.provider
121    }
122
123    fn model_id(&self) -> &str {
124        &self.model_id
125    }
126
127    async fn max_images_per_call(&self) -> Option<u32> {
128        if let Some(custom) = self
129            .middleware
130            .override_max_images_per_call(self.inner.as_ref())
131            .await
132        {
133            return custom;
134        }
135        self.inner.max_images_per_call().await
136    }
137
138    async fn do_generate(&self, options: ImageOptions) -> Result<ImageResult> {
139        let transformed = self
140            .middleware
141            .transform_params(options, self.inner.as_ref())
142            .await?;
143        self.middleware
144            .wrap_generate(self.inner.as_ref(), transformed)
145            .await
146    }
147}
148
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// Recording stand-in for a provider: counts calls and remembers the last
    /// prompt it received.
    #[derive(Debug, Default)]
    struct MockImage {
        provider: String,
        model_id: String,
        calls: AtomicUsize,
        last_prompt: Mutex<String>,
    }

    impl MockImage {
        fn new(provider: &str, model_id: &str) -> Self {
            Self {
                provider: provider.to_owned(),
                model_id: model_id.to_owned(),
                calls: AtomicUsize::new(0),
                last_prompt: Mutex::new(String::new()),
            }
        }
    }

    #[async_trait]
    impl ImageModel for MockImage {
        fn provider(&self) -> &str {
            &self.provider
        }
        fn model_id(&self) -> &str {
            &self.model_id
        }
        async fn do_generate(&self, options: ImageOptions) -> Result<ImageResult> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            *self.last_prompt.lock().expect("mutex") = options.prompt;
            Ok(ImageResult {
                images: vec![],
                warnings: vec![],
                usage: None,
                provider_metadata: None,
                request: None,
                response: None,
            })
        }
    }

    /// Overrides only the model id and prefixes every prompt.
    #[derive(Debug)]
    struct OverrideAndPrefix;

    #[async_trait]
    impl ImageModelMiddleware for OverrideAndPrefix {
        fn override_model_id(&self, _: &dyn ImageModel) -> Option<String> {
            Some("wrapped-model".to_owned())
        }

        async fn transform_params(
            &self,
            mut params: ImageOptions,
            _inner: &dyn ImageModel,
        ) -> Result<ImageOptions> {
            params.prompt = format!("PREFIX: {}", params.prompt);
            Ok(params)
        }
    }

    /// Appends `|<tag>` to the prompt, so the final prompt records the order
    /// in which layers' transforms ran.
    #[derive(Debug)]
    struct Tag(&'static str);

    #[async_trait]
    impl ImageModelMiddleware for Tag {
        async fn transform_params(
            &self,
            mut params: ImageOptions,
            _inner: &dyn ImageModel,
        ) -> Result<ImageOptions> {
            params.prompt.push('|');
            params.prompt.push_str(self.0);
            Ok(params)
        }
    }

    /// Always overrides `max_images_per_call` with a fixed value.
    #[derive(Debug)]
    struct MaxImages(Option<u32>);

    #[async_trait]
    impl ImageModelMiddleware for MaxImages {
        async fn override_max_images_per_call(&self, _: &dyn ImageModel) -> Option<Option<u32>> {
            Some(self.0)
        }
    }

    /// Options carrying only `prompt`; every other field is defaulted.
    fn options(prompt: &str) -> ImageOptions {
        ImageOptions {
            prompt: prompt.into(),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn empty_middleware_unchanged() {
        let model = Arc::new(MockImage::new("p", "m"));
        let wrapped: Arc<dyn ImageModel> = wrap_image_model(Arc::clone(&model) as _, Vec::new());
        assert_eq!(wrapped.provider(), "p");
        assert_eq!(wrapped.model_id(), "m");

        wrapped.do_generate(options("a cat")).await.expect("generate");
        assert_eq!(model.calls.load(Ordering::SeqCst), 1);
        assert_eq!(*model.last_prompt.lock().expect("mutex"), "a cat");
    }

    #[tokio::test]
    async fn overrides_and_prefix_run() {
        let model = Arc::new(MockImage::new("p", "m"));
        let wrapped = wrap_image_model(
            Arc::clone(&model) as _,
            [Arc::new(OverrideAndPrefix) as Arc<dyn ImageModelMiddleware>],
        );

        assert_eq!(wrapped.model_id(), "wrapped-model");
        // Not overridden by the middleware, so the inner provider shows through.
        assert_eq!(wrapped.provider(), "p");

        wrapped.do_generate(options("a cat")).await.expect("generate");

        assert_eq!(model.calls.load(Ordering::SeqCst), 1);
        assert_eq!(*model.last_prompt.lock().expect("mutex"), "PREFIX: a cat");
    }

    #[tokio::test]
    async fn list_head_is_outermost() {
        let model = Arc::new(MockImage::new("p", "m"));
        let wrapped = wrap_image_model(
            Arc::clone(&model) as _,
            [
                Arc::new(Tag("outer")) as Arc<dyn ImageModelMiddleware>,
                Arc::new(Tag("inner")) as Arc<dyn ImageModelMiddleware>,
            ],
        );

        wrapped.do_generate(options("start")).await.expect("generate");

        assert_eq!(model.calls.load(Ordering::SeqCst), 1);
        assert_eq!(
            *model.last_prompt.lock().expect("mutex"),
            "start|outer|inner"
        );
    }

    #[tokio::test]
    async fn max_images_override_and_passthrough() {
        let base: Arc<dyn ImageModel> = Arc::new(MockImage::new("p", "m"));
        let inherited = base.max_images_per_call().await;

        let passthrough = wrap_image_model(
            Arc::clone(&base),
            [Arc::new(OverrideAndPrefix) as Arc<dyn ImageModelMiddleware>],
        );
        assert_eq!(passthrough.max_images_per_call().await, inherited);

        let capped = wrap_image_model(
            Arc::clone(&base),
            [Arc::new(MaxImages(Some(4))) as Arc<dyn ImageModelMiddleware>],
        );
        assert_eq!(capped.max_images_per_call().await, Some(4));

        let cleared = wrap_image_model(
            base,
            [Arc::new(MaxImages(None)) as Arc<dyn ImageModelMiddleware>],
        );
        assert_eq!(cleared.max_images_per_call().await, None);
    }
}