Skip to main content

llmsdk_provider/middleware/
video_model.rs

1//! `VideoModelMiddleware` trait and the `wrap_video_model` combinator.
2//!
3//! Mirrors the v4 middleware pattern from `language-model-middleware`. Surface
4//! is intentionally identical to [`super::image_model`] —
5//! `do_generate(VideoOptions) -> VideoResult`.
6// Rust guideline compliant 2026-02-21
7
8use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::error::Result;
13use crate::video_model::{VideoModel, VideoOptions, VideoResult};
14
15/// Contract for middleware that decorates a [`VideoModel`].
16#[async_trait]
17pub trait VideoModelMiddleware: Send + Sync + std::fmt::Debug {
18    /// Override the provider id exposed by the wrapped model.
19    fn override_provider(&self, _inner: &dyn VideoModel) -> Option<String> {
20        None
21    }
22
23    /// Override the model id exposed by the wrapped model.
24    fn override_model_id(&self, _inner: &dyn VideoModel) -> Option<String> {
25        None
26    }
27
28    /// Override [`VideoModel::max_videos_per_call`].
29    async fn override_max_videos_per_call(&self, _inner: &dyn VideoModel) -> Option<Option<u32>> {
30        None
31    }
32
33    /// Transform the video options before they reach the inner model.
34    ///
35    /// # Errors
36    ///
37    /// Return a [`crate::ProviderError`] to fail the call without invoking
38    /// the model.
39    async fn transform_params(
40        &self,
41        params: VideoOptions,
42        _inner: &dyn VideoModel,
43    ) -> Result<VideoOptions> {
44        Ok(params)
45    }
46
47    /// Wrap a video-generation call.
48    ///
49    /// Default: forwards to `next.do_generate(params)`.
50    ///
51    /// # Errors
52    ///
53    /// Returns whatever error `next` returns, or a middleware-introduced
54    /// failure.
55    async fn wrap_generate(
56        &self,
57        next: &dyn VideoModel,
58        params: VideoOptions,
59    ) -> Result<VideoResult> {
60        next.do_generate(params).await
61    }
62}
63
64/// Compose a video model with one or more middlewares.
65///
66/// Outer-to-inner ordering on the way in, inner-to-outer on the way out
67/// (list head = outermost). Empty middleware iterator returns the model
68/// unchanged.
69pub fn wrap_video_model<I>(model: Arc<dyn VideoModel>, middleware: I) -> Arc<dyn VideoModel>
70where
71    I: IntoIterator<Item = Arc<dyn VideoModelMiddleware>>,
72{
73    let mut layers: Vec<Arc<dyn VideoModelMiddleware>> = middleware.into_iter().collect();
74    layers.reverse();
75    layers
76        .into_iter()
77        .fold(model, |inner, mw| Arc::new(Wrapped::new(inner, mw)))
78}
79
80/// Internal one-layer wrapper.
81struct Wrapped {
82    inner: Arc<dyn VideoModel>,
83    middleware: Arc<dyn VideoModelMiddleware>,
84    provider: String,
85    model_id: String,
86}
87
88impl Wrapped {
89    fn new(inner: Arc<dyn VideoModel>, middleware: Arc<dyn VideoModelMiddleware>) -> Self {
90        let provider = middleware
91            .override_provider(inner.as_ref())
92            .unwrap_or_else(|| inner.provider().to_owned());
93        let model_id = middleware
94            .override_model_id(inner.as_ref())
95            .unwrap_or_else(|| inner.model_id().to_owned());
96        Self {
97            inner,
98            middleware,
99            provider,
100            model_id,
101        }
102    }
103}
104
105impl std::fmt::Debug for Wrapped {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        f.debug_struct("Wrapped")
108            .field("provider", &self.provider)
109            .field("model_id", &self.model_id)
110            .field("middleware", &self.middleware)
111            .field("inner", &self.inner)
112            .finish()
113    }
114}
115
116#[async_trait]
117impl VideoModel for Wrapped {
118    fn provider(&self) -> &str {
119        &self.provider
120    }
121
122    fn model_id(&self) -> &str {
123        &self.model_id
124    }
125
126    async fn max_videos_per_call(&self) -> Option<u32> {
127        if let Some(custom) = self
128            .middleware
129            .override_max_videos_per_call(self.inner.as_ref())
130            .await
131        {
132            return custom;
133        }
134        self.inner.max_videos_per_call().await
135    }
136
137    async fn do_generate(&self, options: VideoOptions) -> Result<VideoResult> {
138        let transformed = self
139            .middleware
140            .transform_params(options, self.inner.as_ref())
141            .await?;
142        self.middleware
143            .wrap_generate(self.inner.as_ref(), transformed)
144            .await
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::video_model::VideoResponseInfo;

    /// Minimal inner model: fixed ids and an empty successful result.
    #[derive(Debug, Default)]
    struct MockVideo {
        provider: String,
        model_id: String,
    }

    impl MockVideo {
        fn new(provider: &str, model_id: &str) -> Self {
            Self {
                provider: provider.to_owned(),
                model_id: model_id.to_owned(),
            }
        }
    }

    #[async_trait]
    impl VideoModel for MockVideo {
        fn provider(&self) -> &str {
            &self.provider
        }
        fn model_id(&self) -> &str {
            &self.model_id
        }
        async fn do_generate(&self, _options: VideoOptions) -> Result<VideoResult> {
            Ok(VideoResult {
                videos: vec![],
                warnings: vec![],
                provider_metadata: None,
                response: VideoResponseInfo {
                    timestamp: "2026-05-25T00:00:00Z".into(),
                    model_id: "mock".into(),
                    headers: None,
                },
            })
        }
    }

    /// Replaces the model id outright.
    #[derive(Debug)]
    struct OverrideName;

    #[async_trait]
    impl VideoModelMiddleware for OverrideName {
        fn override_model_id(&self, _: &dyn VideoModel) -> Option<String> {
            Some("wrapped-video".into())
        }
    }

    /// Appends its tag to the model id of the layer below.
    ///
    /// This makes the layering order observable.
    #[derive(Debug)]
    struct Suffix(&'static str);

    #[async_trait]
    impl VideoModelMiddleware for Suffix {
        fn override_model_id(&self, inner: &dyn VideoModel) -> Option<String> {
            Some(format!("{}+{}", inner.model_id(), self.0))
        }
    }

    /// Forces a fixed per-call cap.
    #[derive(Debug)]
    struct CapVideos(u32);

    #[async_trait]
    impl VideoModelMiddleware for CapVideos {
        async fn override_max_videos_per_call(
            &self,
            _: &dyn VideoModel,
        ) -> Option<Option<u32>> {
            Some(Some(self.0))
        }
    }

    fn mock() -> Arc<dyn VideoModel> {
        Arc::new(MockVideo::new("xai", "v1"))
    }

    #[test]
    fn empty_middleware_unchanged() {
        let model = mock();
        let wrapped = wrap_video_model(Arc::clone(&model), Vec::new());
        // No layers: the very same model instance comes back.
        assert!(Arc::ptr_eq(&model, &wrapped));
        assert_eq!(wrapped.model_id(), "v1");
    }

    #[test]
    fn override_runs_at_construction() {
        let wrapped = wrap_video_model(
            mock(),
            [Arc::new(OverrideName) as Arc<dyn VideoModelMiddleware>],
        );
        assert_eq!(wrapped.model_id(), "wrapped-video");
        // Hooks that are not overridden fall through to the inner model.
        assert_eq!(wrapped.provider(), "xai");
    }

    #[test]
    fn list_head_is_outermost() {
        let wrapped = wrap_video_model(
            mock(),
            [
                Arc::new(Suffix("outer")) as Arc<dyn VideoModelMiddleware>,
                Arc::new(Suffix("inner")) as Arc<dyn VideoModelMiddleware>,
            ],
        );
        // The innermost layer is built first, so the list head's tag lands last.
        assert_eq!(wrapped.model_id(), "v1+inner+outer");
    }

    #[tokio::test]
    async fn max_videos_override_wins() {
        let wrapped = wrap_video_model(
            mock(),
            [Arc::new(CapVideos(2)) as Arc<dyn VideoModelMiddleware>],
        );
        assert_eq!(wrapped.max_videos_per_call().await, Some(2));
    }

    #[tokio::test]
    async fn max_videos_falls_through_without_override() {
        let model = mock();
        let wrapped = wrap_video_model(
            Arc::clone(&model),
            [Arc::new(OverrideName) as Arc<dyn VideoModelMiddleware>],
        );
        assert_eq!(
            wrapped.max_videos_per_call().await,
            model.max_videos_per_call().await
        );
    }
}