use async_trait::async_trait;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::shared::{FileBytes, Headers, ProviderMetadata, ProviderOptions, Warning};
/// A provider-backed model that can generate videos.
///
/// Implementors must be `Send + Sync` and `Debug`. The async methods are expanded by
/// `#[async_trait]`.
#[async_trait]
pub trait VideoModel: Send + Sync + std::fmt::Debug {
    /// Name of the provider that serves this model.
    fn provider(&self) -> &str;

    /// Identifier of this model.
    fn model_id(&self) -> &str;

    /// Version string of the interface this implementation targets.
    ///
    /// The provided implementation returns `"v4"`.
    fn specification_version(&self) -> &'static str {
        "v4"
    }

    /// Upper bound on how many videos one [`VideoModel::do_generate`] call may return.
    ///
    /// The provided implementation returns `Some(1)`. What `None` means is left to callers —
    /// NOTE(review): presumably "no known limit"; confirm.
    async fn max_videos_per_call(&self) -> Option<u32> {
        Some(1)
    }

    /// Run one generation request described by `options`.
    ///
    /// # Errors
    ///
    /// Returns the crate error type when the implementation fails to produce a result.
    async fn do_generate(&self, options: VideoOptions) -> Result<VideoResult>;
}
/// Options for a single [`VideoModel::do_generate`] call.
///
/// Every field except `n` is optional and is left out of the serialized form when `None`.
/// Some fields use different wire names: `aspectRatio`, `duration` (for `duration_seconds`)
/// and `providerOptions`.
///
/// `n` defaults to 1 both when deserializing input that omits it and through [`Default`].
/// So `VideoOptions { prompt: Some(..), ..Default::default() }` asks for one video, not zero.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoOptions {
    /// Prompt text for the request, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    /// Number of videos requested. Defaults to 1.
    #[serde(default = "default_n")]
    pub n: u32,
    /// Aspect ratio as a free-form string such as `"16:9"`; serialized as `aspectRatio`.
    #[serde(
        default,
        rename = "aspectRatio",
        skip_serializing_if = "Option::is_none"
    )]
    pub aspect_ratio: Option<String>,
    /// Resolution as a free-form string such as `"1920x1080"`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resolution: Option<String>,
    /// Requested duration; serialized under the key `duration`.
    #[serde(default, rename = "duration", skip_serializing_if = "Option::is_none")]
    pub duration_seconds: Option<f64>,
    /// Requested frame rate.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub fps: Option<u32>,
    /// Seed value forwarded with the request (presumably for reproducibility — confirm per provider).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    /// Optional input image, given inline or by URL.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub image: Option<VideoFile>,
    /// Extra headers to send with the request (presumably HTTP headers — confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub headers: Option<Headers>,
    /// Provider-specific options; serialized as `providerOptions`.
    #[serde(
        default,
        rename = "providerOptions",
        skip_serializing_if = "Option::is_none"
    )]
    pub provider_options: Option<ProviderOptions>,
}

impl Default for VideoOptions {
    /// All optional fields `None`, and `n` set to the same default serde uses (1).
    fn default() -> Self {
        Self {
            prompt: None,
            n: default_n(),
            aspect_ratio: None,
            resolution: None,
            duration_seconds: None,
            fps: None,
            seed: None,
            image: None,
            headers: None,
            provider_options: None,
        }
    }
}

/// Default video count, shared by the serde `default` attribute and the `Default` impl so the
/// two can never disagree.
fn default_n() -> u32 {
    1
}
/// An input file handed to a video model, supplied either inline or by URL.
///
/// Internally tagged by a `type` field whose value is the kebab-cased variant name
/// (`"file"` or `"url"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum VideoFile {
    /// File content carried inline together with its media type.
    File {
        /// Media type of `data`; serialized as `mediaType`.
        #[serde(rename = "mediaType")]
        media_type: String,
        /// The file content.
        data: FileBytes,
        /// Provider-specific options for this file; omitted when `None`.
        #[serde(rename = "providerOptions", default, skip_serializing_if = "Option::is_none")]
        provider_options: Option<ProviderOptions>,
    },
    /// A file referenced by URL.
    Url {
        /// Location of the file.
        url: String,
        /// Provider-specific options for this file; omitted when `None`.
        #[serde(rename = "providerOptions", default, skip_serializing_if = "Option::is_none")]
        provider_options: Option<ProviderOptions>,
    },
}
/// Outcome of a [`VideoModel::do_generate`] call.
#[derive(Debug, Clone)]
pub struct VideoResult {
    /// Generated videos.
    pub videos: Vec<VideoData>,
    /// Warnings reported while handling the request.
    pub warnings: Vec<Warning>,
    /// Provider-specific metadata, when the provider returns any.
    pub provider_metadata: Option<ProviderMetadata>,
    /// Details about the response itself (timestamp, model id, headers).
    pub response: VideoResponseInfo,
}
/// Response-level information attached to a [`VideoResult`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct VideoResponseInfo {
    /// When the response was produced, kept as a string (no format is enforced here).
    pub timestamp: String,
    /// Identifier of the model that answered; serialized as `modelId`.
    #[serde(rename = "modelId")]
    pub model_id: String,
    /// Response headers, if captured; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub headers: Option<Headers>,
}
/// One generated video, delivered as a URL, a base64 string, or raw bytes.
///
/// Internally tagged by a `type` field whose value is the kebab-cased variant name
/// (`"url"`, `"base64"`, `"binary"`). Each variant carries its media type as `mediaType`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum VideoData {
    /// Video reachable at a URL.
    Url {
        url: String,
        #[serde(rename = "mediaType")]
        media_type: String,
    },
    /// Video content as a base64-encoded string.
    Base64 {
        data: String,
        #[serde(rename = "mediaType")]
        media_type: String,
    },
    /// Video content as raw bytes, (de)serialized through the `binary_serde` adapter.
    Binary {
        #[serde(with = "binary_serde")]
        data: Bytes,
        #[serde(rename = "mediaType")]
        media_type: String,
    },
}
mod binary_serde {
use bytes::Bytes;
use serde::{Deserialize, Deserializer, Serializer};
pub fn serialize<S: Serializer>(b: &Bytes, s: S) -> Result<S::Ok, S::Error> {
s.serialize_bytes(b)
}
pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Bytes, D::Error> {
let v: Vec<u8> = Vec::deserialize(d)?;
Ok(Bytes::from(v))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn options_default_n_is_one() {
        let v: VideoOptions = serde_json::from_value(json!({})).unwrap();
        assert_eq!(v.n, 1);
    }

    #[test]
    fn options_omit_unset_fields() {
        // Every `None` field is skipped; only the always-present `n` remains.
        let v: VideoOptions = serde_json::from_value(json!({})).unwrap();
        assert_eq!(serde_json::to_value(&v).unwrap(), json!({ "n": 1 }));
    }

    #[test]
    fn options_roundtrip_camelcase() {
        let v = VideoOptions {
            prompt: Some("a cat".into()),
            n: 2,
            aspect_ratio: Some("16:9".into()),
            resolution: Some("1920x1080".into()),
            duration_seconds: Some(5.0),
            fps: Some(30),
            seed: Some(42),
            image: None,
            headers: None,
            provider_options: None,
        };
        let j = serde_json::to_value(&v).unwrap();
        assert_eq!(j["aspectRatio"], "16:9");
        assert_eq!(j["duration"], 5.0);
        let back: VideoOptions = serde_json::from_value(j).unwrap();
        assert_eq!(back.aspect_ratio.as_deref(), Some("16:9"));
        assert_eq!(back.fps, Some(30));
    }

    #[test]
    fn file_tagged_correctly() {
        let f = VideoFile::Url {
            url: "https://example.com/start.png".into(),
            provider_options: None,
        };
        let j = serde_json::to_value(&f).unwrap();
        assert_eq!(j["type"], "url");
    }

    #[test]
    fn data_tagged_correctly() {
        let d = VideoData::Url {
            url: "https://example.com/x.mp4".into(),
            media_type: "video/mp4".into(),
        };
        let j = serde_json::to_value(&d).unwrap();
        assert_eq!(j["type"], "url");
        assert_eq!(j["mediaType"], "video/mp4");
    }

    #[test]
    fn base64_data_roundtrip() {
        let d = VideoData::Base64 {
            data: "AAAA".into(),
            media_type: "video/webm".into(),
        };
        let j = serde_json::to_value(&d).unwrap();
        assert_eq!(j["type"], "base64");
        let back: VideoData = serde_json::from_value(j).unwrap();
        assert_eq!(back, d);
    }

    #[test]
    fn binary_data_roundtrip_json() {
        // Exercises the custom `binary_serde` adapter through the internally tagged enum.
        let d = VideoData::Binary {
            data: Bytes::from_static(&[0, 1, 255]),
            media_type: "video/mp4".into(),
        };
        let j = serde_json::to_value(&d).unwrap();
        assert_eq!(j["type"], "binary");
        assert_eq!(j["data"], json!([0, 1, 255]));
        let back: VideoData = serde_json::from_value(j).unwrap();
        assert_eq!(back, d);
    }

    #[test]
    fn response_info_uses_camelcase_model_id() {
        let r = VideoResponseInfo {
            timestamp: "2024-01-01T00:00:00Z".into(),
            model_id: "m".into(),
            headers: None,
        };
        let j = serde_json::to_value(&r).unwrap();
        assert_eq!(
            j,
            json!({ "timestamp": "2024-01-01T00:00:00Z", "modelId": "m" })
        );
        let back: VideoResponseInfo = serde_json::from_value(j).unwrap();
        assert_eq!(back, r);
    }
}