use std::sync::Arc;
use async_trait::async_trait;
use crate::error::Result;
use crate::video_model::{VideoModel, VideoOptions, VideoResult};
/// A layer that can observe or alter how a [`VideoModel`] is called.
///
/// Every hook has a pass-through default, so an implementor only overrides the
/// hooks it cares about. Middleware is attached to a model with
/// [`wrap_video_model`].
#[async_trait]
pub trait VideoModelMiddleware: Send + Sync + std::fmt::Debug {
    /// Replacement for the wrapped model's `provider()` string.
    ///
    /// Consulted once, when the wrapping layer is constructed. Returning `None`
    /// (the default) keeps `inner.provider()`.
    fn override_provider(&self, _inner: &dyn VideoModel) -> Option<String> {
        None
    }

    /// Replacement for the wrapped model's `model_id()` string.
    ///
    /// Consulted once, when the wrapping layer is constructed. Returning `None`
    /// (the default) keeps `inner.model_id()`.
    fn override_model_id(&self, _inner: &dyn VideoModel) -> Option<String> {
        None
    }

    /// Replacement for the wrapped model's `max_videos_per_call()`.
    ///
    /// The outer `Option` decides whether to override at all:
    /// - `None` (the default) defers to `inner.max_videos_per_call()`.
    /// - `Some(limit)` reports `limit` verbatim. This includes `Some(None)`,
    ///   which replaces whatever the inner model reports with `None`.
    ///
    /// Consulted on every `max_videos_per_call()` call of the wrapping layer.
    async fn override_max_videos_per_call(&self, _inner: &dyn VideoModel) -> Option<Option<u32>> {
        None
    }

    /// Rewrites the request before it is handed to
    /// [`wrap_generate`](Self::wrap_generate).
    ///
    /// Runs on every `do_generate` call. The default returns `params`
    /// unchanged. Returning `Err` aborts the call before `wrap_generate` runs.
    async fn transform_params(
        &self,
        params: VideoOptions,
        _inner: &dyn VideoModel,
    ) -> Result<VideoOptions> {
        Ok(params)
    }

    /// Performs the generation, typically by delegating to `next`.
    ///
    /// `next` is the model one layer further in, and `params` has already been
    /// through [`transform_params`](Self::transform_params). The default
    /// forwards to `next.do_generate(params)`. Override it to inspect or
    /// post-process the result, or to decide whether to call `next` at all.
    async fn wrap_generate(
        &self,
        next: &dyn VideoModel,
        params: VideoOptions,
    ) -> Result<VideoResult> {
        next.do_generate(params).await
    }
}
/// Stacks every middleware from `middleware` around `model` and returns the
/// composed model.
///
/// The first middleware yielded becomes the outermost layer:
/// - Its `transform_params` sees the caller's options first.
/// - Its `wrap_generate` receives the inner result last.
/// - Its `override_*` answers are the ones the returned model reports.
///
/// An empty iterator hands back `model` itself (the same `Arc`).
pub fn wrap_video_model<I>(model: Arc<dyn VideoModel>, middleware: I) -> Arc<dyn VideoModel>
where
    I: IntoIterator<Item = Arc<dyn VideoModelMiddleware>>,
{
    // Materialise the list so it can be walked back-to-front: the innermost
    // layer (the last middleware) has to be built first.
    let stack: Vec<Arc<dyn VideoModelMiddleware>> = middleware.into_iter().collect();

    let mut current = model;
    for layer in stack.into_iter().rev() {
        current = Arc::new(Wrapped::new(current, layer));
    }
    current
}
/// One middleware layer stacked on top of an inner model.
struct Wrapped {
    /// The model one layer further in (possibly another `Wrapped`).
    inner: Arc<dyn VideoModel>,
    /// The hooks this layer applies.
    middleware: Arc<dyn VideoModelMiddleware>,
    /// Provider name this layer reports, resolved once at construction.
    provider: String,
    /// Model id this layer reports, resolved once at construction.
    model_id: String,
}
impl Wrapped {
    /// Builds one layer and resolves its reported `provider` and `model_id`
    /// up front. The middleware's override wins if it returns `Some`;
    /// otherwise the inner model's value is copied.
    fn new(inner: Arc<dyn VideoModel>, middleware: Arc<dyn VideoModelMiddleware>) -> Self {
        let base: &dyn VideoModel = &*inner;

        let provider = match middleware.override_provider(base) {
            Some(name) => name,
            None => base.provider().to_string(),
        };
        let model_id = match middleware.override_model_id(base) {
            Some(id) => id,
            None => base.model_id().to_string(),
        };

        Self {
            inner,
            middleware,
            provider,
            model_id,
        }
    }
}
impl std::fmt::Debug for Wrapped {
    /// Renders the layer's reported identity, then its middleware, then the
    /// inner model (which recursively renders any deeper layers).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            inner,
            middleware,
            provider,
            model_id,
        } = self;

        f.debug_struct("Wrapped")
            .field("provider", provider)
            .field("model_id", model_id)
            .field("middleware", middleware)
            .field("inner", inner)
            .finish()
    }
}
#[async_trait]
impl VideoModel for Wrapped {
    fn provider(&self) -> &str {
        self.provider.as_str()
    }

    fn model_id(&self) -> &str {
        self.model_id.as_str()
    }

    /// Asks the middleware first. The inner model is queried only when the
    /// middleware declines to override by returning `None`.
    async fn max_videos_per_call(&self) -> Option<u32> {
        match self
            .middleware
            .override_max_videos_per_call(&*self.inner)
            .await
        {
            Some(overridden) => overridden,
            None => self.inner.max_videos_per_call().await,
        }
    }

    /// Passes `options` through the middleware's `transform_params`, then
    /// hands the result and the inner model to its `wrap_generate`. An error
    /// from `transform_params` is returned without calling `wrap_generate`.
    async fn do_generate(&self, options: VideoOptions) -> Result<VideoResult> {
        let next: &dyn VideoModel = &*self.inner;
        let params = self.middleware.transform_params(options, next).await?;
        self.middleware.wrap_generate(next, params).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::video_model::VideoResponseInfo;

    /// Minimal inner model: reports fixed names and returns an empty result.
    #[derive(Debug, Default)]
    struct MockVideo {
        provider: String,
        model_id: String,
    }

    impl MockVideo {
        fn new(provider: &str, model_id: &str) -> Self {
            Self {
                provider: provider.to_owned(),
                model_id: model_id.to_owned(),
            }
        }
    }

    #[async_trait]
    impl VideoModel for MockVideo {
        fn provider(&self) -> &str {
            &self.provider
        }

        fn model_id(&self) -> &str {
            &self.model_id
        }

        async fn do_generate(&self, _options: VideoOptions) -> Result<VideoResult> {
            Ok(VideoResult {
                videos: vec![],
                warnings: vec![],
                provider_metadata: None,
                response: VideoResponseInfo {
                    timestamp: "2026-05-25T00:00:00Z".into(),
                    model_id: "mock".into(),
                    headers: None,
                },
            })
        }
    }

    /// Replaces the model id with a constant.
    #[derive(Debug)]
    struct OverrideName;

    #[async_trait]
    impl VideoModelMiddleware for OverrideName {
        fn override_model_id(&self, _: &dyn VideoModel) -> Option<String> {
            Some("wrapped-video".into())
        }
    }

    /// Appends `+<tag>` to the id reported by the layer below, so the final id
    /// records the order in which layers were stacked.
    #[derive(Debug)]
    struct SuffixId(&'static str);

    #[async_trait]
    impl VideoModelMiddleware for SuffixId {
        fn override_model_id(&self, inner: &dyn VideoModel) -> Option<String> {
            Some(format!("{}+{}", inner.model_id(), self.0))
        }
    }

    /// Forces `max_videos_per_call` to a fixed value (including `None`).
    #[derive(Debug)]
    struct CapVideos(Option<u32>);

    #[async_trait]
    impl VideoModelMiddleware for CapVideos {
        async fn override_max_videos_per_call(&self, _: &dyn VideoModel) -> Option<Option<u32>> {
            Some(self.0)
        }
    }

    #[tokio::test]
    async fn empty_middleware_unchanged() {
        let model = Arc::new(MockVideo::new("xai", "v1"));
        let wrapped = wrap_video_model(model as _, Vec::new());
        assert_eq!(wrapped.model_id(), "v1");
    }

    #[tokio::test]
    async fn override_runs_at_construction() {
        let model = Arc::new(MockVideo::new("xai", "v1"));
        let wrapped = wrap_video_model(
            model as _,
            [Arc::new(OverrideName) as Arc<dyn VideoModelMiddleware>],
        );
        assert_eq!(wrapped.model_id(), "wrapped-video");
        // A model-id-only middleware must not touch the provider.
        assert_eq!(wrapped.provider(), "xai");
    }

    #[tokio::test]
    async fn first_middleware_is_outermost() {
        let model = Arc::new(MockVideo::new("xai", "v1"));
        let wrapped = wrap_video_model(
            model as _,
            [
                Arc::new(SuffixId("a")) as Arc<dyn VideoModelMiddleware>,
                Arc::new(SuffixId("b")) as Arc<dyn VideoModelMiddleware>,
            ],
        );
        // "b" wraps the base model first, then "a" wraps that layer.
        assert_eq!(wrapped.model_id(), "v1+b+a");
    }

    #[tokio::test]
    async fn max_videos_override_and_passthrough() {
        let model = Arc::new(MockVideo::new("xai", "v1"));
        let baseline = model.max_videos_per_call().await;

        // A middleware that does not override the limit defers to the inner model.
        let untouched = wrap_video_model(
            Arc::clone(&model) as _,
            [Arc::new(OverrideName) as Arc<dyn VideoModelMiddleware>],
        );
        assert_eq!(untouched.max_videos_per_call().await, baseline);

        let capped = wrap_video_model(
            Arc::clone(&model) as _,
            [Arc::new(CapVideos(Some(3))) as Arc<dyn VideoModelMiddleware>],
        );
        assert_eq!(capped.max_videos_per_call().await, Some(3));

        // `Some(None)` is a real override, not a fall-through.
        let uncapped = wrap_video_model(
            model as _,
            [Arc::new(CapVideos(None)) as Arc<dyn VideoModelMiddleware>],
        );
        assert_eq!(uncapped.max_videos_per_call().await, None);
    }
}