llmsdk_provider/middleware/
video_model.rs1use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::error::Result;
13use crate::video_model::{VideoModel, VideoOptions, VideoResult};
14
15#[async_trait]
17pub trait VideoModelMiddleware: Send + Sync + std::fmt::Debug {
18 fn override_provider(&self, _inner: &dyn VideoModel) -> Option<String> {
20 None
21 }
22
23 fn override_model_id(&self, _inner: &dyn VideoModel) -> Option<String> {
25 None
26 }
27
28 async fn override_max_videos_per_call(&self, _inner: &dyn VideoModel) -> Option<Option<u32>> {
30 None
31 }
32
33 async fn transform_params(
40 &self,
41 params: VideoOptions,
42 _inner: &dyn VideoModel,
43 ) -> Result<VideoOptions> {
44 Ok(params)
45 }
46
47 async fn wrap_generate(
56 &self,
57 next: &dyn VideoModel,
58 params: VideoOptions,
59 ) -> Result<VideoResult> {
60 next.do_generate(params).await
61 }
62}
63
64pub fn wrap_video_model<I>(model: Arc<dyn VideoModel>, middleware: I) -> Arc<dyn VideoModel>
70where
71 I: IntoIterator<Item = Arc<dyn VideoModelMiddleware>>,
72{
73 let mut layers: Vec<Arc<dyn VideoModelMiddleware>> = middleware.into_iter().collect();
74 layers.reverse();
75 layers
76 .into_iter()
77 .fold(model, |inner, mw| Arc::new(Wrapped::new(inner, mw)))
78}
79
80struct Wrapped {
82 inner: Arc<dyn VideoModel>,
83 middleware: Arc<dyn VideoModelMiddleware>,
84 provider: String,
85 model_id: String,
86}
87
88impl Wrapped {
89 fn new(inner: Arc<dyn VideoModel>, middleware: Arc<dyn VideoModelMiddleware>) -> Self {
90 let provider = middleware
91 .override_provider(inner.as_ref())
92 .unwrap_or_else(|| inner.provider().to_owned());
93 let model_id = middleware
94 .override_model_id(inner.as_ref())
95 .unwrap_or_else(|| inner.model_id().to_owned());
96 Self {
97 inner,
98 middleware,
99 provider,
100 model_id,
101 }
102 }
103}
104
105impl std::fmt::Debug for Wrapped {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 f.debug_struct("Wrapped")
108 .field("provider", &self.provider)
109 .field("model_id", &self.model_id)
110 .field("middleware", &self.middleware)
111 .field("inner", &self.inner)
112 .finish()
113 }
114}
115
116#[async_trait]
117impl VideoModel for Wrapped {
118 fn provider(&self) -> &str {
119 &self.provider
120 }
121
122 fn model_id(&self) -> &str {
123 &self.model_id
124 }
125
126 async fn max_videos_per_call(&self) -> Option<u32> {
127 if let Some(custom) = self
128 .middleware
129 .override_max_videos_per_call(self.inner.as_ref())
130 .await
131 {
132 return custom;
133 }
134 self.inner.max_videos_per_call().await
135 }
136
137 async fn do_generate(&self, options: VideoOptions) -> Result<VideoResult> {
138 let transformed = self
139 .middleware
140 .transform_params(options, self.inner.as_ref())
141 .await?;
142 self.middleware
143 .wrap_generate(self.inner.as_ref(), transformed)
144 .await
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::video_model::VideoResponseInfo;

    /// Minimal model: fixed provider/model id, and `do_generate` returns an empty result.
    #[derive(Debug, Default)]
    struct MockVideo {
        provider: String,
        model_id: String,
    }

    impl MockVideo {
        fn new(provider: &str, model_id: &str) -> Self {
            Self {
                provider: provider.to_owned(),
                model_id: model_id.to_owned(),
            }
        }
    }

    #[async_trait]
    impl VideoModel for MockVideo {
        fn provider(&self) -> &str {
            &self.provider
        }
        fn model_id(&self) -> &str {
            &self.model_id
        }
        async fn do_generate(&self, _options: VideoOptions) -> Result<VideoResult> {
            Ok(VideoResult {
                videos: vec![],
                warnings: vec![],
                provider_metadata: None,
                response: VideoResponseInfo {
                    timestamp: "2026-05-25T00:00:00Z".into(),
                    model_id: "mock".into(),
                    headers: None,
                },
            })
        }
    }

    /// Replaces the model id with a fixed value.
    #[derive(Debug)]
    struct OverrideName;

    #[async_trait]
    impl VideoModelMiddleware for OverrideName {
        fn override_model_id(&self, _: &dyn VideoModel) -> Option<String> {
            Some("wrapped-video".into())
        }
    }

    /// Appends `-<tag>` to the model id reported by the layer it wraps, so the
    /// final id records the order in which layers were applied.
    #[derive(Debug)]
    struct Suffix(&'static str);

    #[async_trait]
    impl VideoModelMiddleware for Suffix {
        fn override_model_id(&self, inner: &dyn VideoModel) -> Option<String> {
            Some(format!("{}-{}", inner.model_id(), self.0))
        }
    }

    /// Forces `max_videos_per_call` to the stored value (which may be `None`).
    #[derive(Debug)]
    struct Cap(Option<u32>);

    #[async_trait]
    impl VideoModelMiddleware for Cap {
        async fn override_max_videos_per_call(
            &self,
            _inner: &dyn VideoModel,
        ) -> Option<Option<u32>> {
            Some(self.0)
        }
    }

    /// Erases a concrete middleware into the trait object `wrap_video_model` expects.
    fn layer<M: VideoModelMiddleware + 'static>(mw: M) -> Arc<dyn VideoModelMiddleware> {
        Arc::new(mw)
    }

    #[tokio::test]
    async fn empty_middleware_unchanged() {
        let model = Arc::new(MockVideo::new("xai", "v1"));
        let expected_max = model.max_videos_per_call().await;
        let wrapped = wrap_video_model(model as _, Vec::new());
        assert_eq!(wrapped.provider(), "xai");
        assert_eq!(wrapped.model_id(), "v1");
        assert_eq!(wrapped.max_videos_per_call().await, expected_max);
    }

    #[tokio::test]
    async fn override_runs_at_construction() {
        let model = Arc::new(MockVideo::new("xai", "v1"));
        let wrapped = wrap_video_model(
            model as _,
            [Arc::new(OverrideName) as Arc<dyn VideoModelMiddleware>],
        );
        assert_eq!(wrapped.model_id(), "wrapped-video");
        // Only the model id is overridden; the provider passes through.
        assert_eq!(wrapped.provider(), "xai");
    }

    #[tokio::test]
    async fn first_middleware_is_outermost() {
        let model = Arc::new(MockVideo::new("xai", "v1"));
        let wrapped = wrap_video_model(
            model as _,
            [layer(Suffix("outer")), layer(Suffix("inner"))],
        );
        // The last-listed layer wraps the model directly; the first wraps everything.
        assert_eq!(wrapped.model_id(), "v1-inner-outer");
    }

    #[tokio::test]
    async fn max_videos_passthrough_without_override() {
        let model = Arc::new(MockVideo::new("xai", "v1"));
        let expected_max = model.max_videos_per_call().await;
        let wrapped = wrap_video_model(model as _, [layer(OverrideName)]);
        assert_eq!(wrapped.max_videos_per_call().await, expected_max);
    }

    #[tokio::test]
    async fn max_videos_override_is_reported_verbatim() {
        let limited = wrap_video_model(
            Arc::new(MockVideo::new("xai", "v1")) as _,
            [layer(Cap(Some(4)))],
        );
        assert_eq!(limited.max_videos_per_call().await, Some(4));

        // `Some(None)` is an explicit override, not a fall-through.
        let cleared = wrap_video_model(
            Arc::new(MockVideo::new("xai", "v1")) as _,
            [layer(Cap(None))],
        );
        assert_eq!(cleared.max_videos_per_call().await, None);
    }

    #[tokio::test]
    async fn outer_max_videos_override_wins() {
        let wrapped = wrap_video_model(
            Arc::new(MockVideo::new("xai", "v1")) as _,
            [layer(Cap(Some(2))), layer(Cap(Some(8)))],
        );
        assert_eq!(wrapped.max_videos_per_call().await, Some(2));
    }
}