use std::path::PathBuf;
use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
/// Parameters for a single video-generation request.
///
/// Construct with [`VideoGenerationRequest::new`] and chain the `with_*` builder methods;
/// every field other than `model` and `prompt` is optional and starts as `None`.
#[derive(Debug, Clone)]
pub struct VideoGenerationRequest {
/// Model identifier (e.g. `"gen-3"`).
pub model: String,
/// Text prompt describing the desired video.
pub prompt: String,
/// Optional negative prompt.
pub negative_prompt: Option<String>,
/// Optional source image — presumably for image-to-video generation; TODO confirm per provider.
pub image: Option<VideoInput>,
/// Optional source video — presumably for video-to-video generation; TODO confirm per provider.
pub input_video: Option<VideoInput>,
/// Requested clip length; units are not enforced here (presumably seconds — TODO confirm).
pub duration: Option<u32>,
/// Aspect ratio as a string such as `"16:9"`.
pub aspect_ratio: Option<String>,
/// Requested output resolution.
pub resolution: Option<VideoResolution>,
/// Requested frame rate.
pub fps: Option<u32>,
/// Optional seed, presumably for reproducible output.
pub seed: Option<u64>,
/// Motion intensity; the `with_motion_amount` builder clamps it to `0.0..=1.0`.
pub motion_amount: Option<f32>,
/// Optional camera movement.
pub camera_motion: Option<CameraMotion>,
}
impl VideoGenerationRequest {
    /// Creates a request for `model` with the given text `prompt`.
    ///
    /// All optional settings start unset (`None`); use the `with_*` methods to fill them in.
    #[must_use]
    pub fn new(model: impl Into<String>, prompt: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            prompt: prompt.into(),
            negative_prompt: None,
            image: None,
            input_video: None,
            duration: None,
            aspect_ratio: None,
            resolution: None,
            fps: None,
            seed: None,
            motion_amount: None,
            camera_motion: None,
        }
    }

    /// Sets the negative prompt.
    #[must_use]
    pub fn with_negative_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.negative_prompt = Some(prompt.into());
        self
    }

    /// Sets the source image.
    #[must_use]
    pub fn with_image(mut self, image: VideoInput) -> Self {
        self.image = Some(image);
        self
    }

    /// Sets the source video.
    #[must_use]
    pub fn with_input_video(mut self, video: VideoInput) -> Self {
        self.input_video = Some(video);
        self
    }

    /// Sets the requested duration (units are not enforced here).
    #[must_use]
    pub fn with_duration(mut self, duration: u32) -> Self {
        self.duration = Some(duration);
        self
    }

    /// Sets the aspect ratio, e.g. `"16:9"`.
    #[must_use]
    pub fn with_aspect_ratio(mut self, aspect_ratio: impl Into<String>) -> Self {
        self.aspect_ratio = Some(aspect_ratio.into());
        self
    }

    /// Sets the output resolution.
    #[must_use]
    pub fn with_resolution(mut self, resolution: VideoResolution) -> Self {
        self.resolution = Some(resolution);
        self
    }

    /// Sets the frame rate.
    #[must_use]
    pub fn with_fps(mut self, fps: u32) -> Self {
        self.fps = Some(fps);
        self
    }

    /// Sets the seed.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Sets the motion amount, clamped to `0.0..=1.0`.
    ///
    /// A NaN `amount` has no meaningful clamped value, so it clears the setting
    /// (`None`) instead of storing NaN.
    #[must_use]
    pub fn with_motion_amount(mut self, amount: f32) -> Self {
        // `f32::clamp` propagates NaN, which would violate the 0..=1 invariant.
        self.motion_amount = if amount.is_nan() {
            None
        } else {
            Some(amount.clamp(0.0, 1.0))
        };
        self
    }

    /// Sets the camera motion.
    #[must_use]
    pub fn with_camera_motion(mut self, motion: CameraMotion) -> Self {
        self.camera_motion = Some(motion);
        self
    }
}
/// Source media for a request: a local file, in-memory bytes, a URL, or base64 data.
///
/// `Debug` is implemented by hand so that `Bytes`/`Base64` payloads, which may hold
/// entire media files, are summarized by size instead of being dumped into logs.
#[derive(Clone, PartialEq, Eq)]
pub enum VideoInput {
    /// Path to a local file.
    File(PathBuf),
    /// Raw bytes together with a filename and media type describing them.
    Bytes {
        data: Vec<u8>,
        filename: String,
        media_type: String,
    },
    /// Remote URL.
    Url(String),
    /// Base64-encoded data together with its media type.
    Base64 { data: String, media_type: String },
}

impl std::fmt::Debug for VideoInput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            VideoInput::File(path) => f.debug_tuple("File").field(path).finish(),
            VideoInput::Bytes {
                data,
                filename,
                media_type,
            } => f
                .debug_struct("Bytes")
                // Report only the payload size; printing every byte is huge and useless.
                .field("data", &format_args!("<{} bytes>", data.len()))
                .field("filename", filename)
                .field("media_type", media_type)
                .finish(),
            VideoInput::Url(url) => f.debug_tuple("Url").field(url).finish(),
            VideoInput::Base64 { data, media_type } => f
                .debug_struct("Base64")
                .field("data", &format_args!("<{} base64 chars>", data.len()))
                .field("media_type", media_type)
                .finish(),
        }
    }
}
impl VideoInput {
pub fn file(path: impl Into<PathBuf>) -> Self {
VideoInput::File(path.into())
}
pub fn bytes(
data: Vec<u8>,
filename: impl Into<String>,
media_type: impl Into<String>,
) -> Self {
VideoInput::Bytes {
data,
filename: filename.into(),
media_type: media_type.into(),
}
}
pub fn url(url: impl Into<String>) -> Self {
VideoInput::Url(url.into())
}
pub fn base64(data: impl Into<String>, media_type: impl Into<String>) -> Self {
VideoInput::Base64 {
data: data.into(),
media_type: media_type.into(),
}
}
}
/// Output resolution: a named preset or explicit `Custom` dimensions.
///
/// Defaults to `Hd1080`. Serialized with snake_case variant names (e.g. `"hd1080"`,
/// `"uhd4k"`); `dimensions()` returns the `(width, height)` of each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum VideoResolution {
/// 854×480.
Sd480,
/// 1280×720.
Hd720,
/// 1920×1080 (the `Default`).
#[default]
Hd1080,
/// 3840×2160.
Uhd4k,
/// Caller-specified width and height.
Custom { width: u32, height: u32 },
}
impl VideoResolution {
pub fn dimensions(&self) -> (u32, u32) {
match self {
VideoResolution::Sd480 => (854, 480),
VideoResolution::Hd720 => (1280, 720),
VideoResolution::Hd1080 => (1920, 1080),
VideoResolution::Uhd4k => (3840, 2160),
VideoResolution::Custom { width, height } => (*width, *height),
}
}
}
/// Camera movement requested for a generated clip.
///
/// Serialized with snake_case variant names, e.g. `ZoomIn` ↔ `"zoom_in"`,
/// `PanLeft` ↔ `"pan_left"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CameraMotion {
Static,
ZoomIn,
ZoomOut,
PanLeft,
PanRight,
TiltUp,
TiltDown,
DollyIn,
DollyOut,
Orbit,
Dynamic,
}
/// Response to a generation request: the job id, its status, and an optional estimate.
#[derive(Debug, Clone)]
pub struct VideoGenerationResponse {
/// Identifier to pass to `VideoProvider::get_video_status` / `wait_for_video`.
pub job_id: String,
/// Status of the job at the time of the response.
pub status: VideoJobStatus,
/// Optional estimate — NOTE(review): unclear from here whether this is processing
/// time or clip length; confirm with provider implementations.
pub estimated_duration: Option<Duration>,
}
impl VideoGenerationResponse {
    /// Builds a response for `job_id` in the given `status`, with no duration estimate.
    pub fn new(job_id: impl Into<String>, status: VideoJobStatus) -> Self {
        let job_id: String = job_id.into();
        Self {
            job_id,
            status,
            estimated_duration: None,
        }
    }

    /// Returns the response with `estimated_duration` set, replacing any previous value.
    pub fn with_estimated_duration(self, duration: Duration) -> Self {
        Self {
            estimated_duration: Some(duration),
            ..self
        }
    }
}
/// Lifecycle state of a video-generation job.
///
/// `Queued` and `Processing` are pending states; `Completed`, `Failed` and `Cancelled`
/// are terminal.
// `Eq` is not derivable because `Completed::duration_seconds` is an `f32`.
#[derive(Debug, Clone, PartialEq)]
pub enum VideoJobStatus {
    /// Waiting to start.
    Queued,
    /// In progress.
    Processing {
        /// Progress indicator — presumably a percentage (0–100); not enforced here.
        progress: Option<u8>,
        /// Optional stage label (e.g. `"rendering"`).
        stage: Option<String>,
    },
    /// Finished successfully.
    Completed {
        /// Location of the generated video.
        video_url: String,
        /// Duration in seconds, if reported.
        duration_seconds: Option<f32>,
        /// Thumbnail location, if reported.
        thumbnail_url: Option<String>,
    },
    /// Finished with an error.
    Failed {
        /// Error message.
        error: String,
        /// Optional error code.
        code: Option<String>,
    },
    /// Cancelled.
    Cancelled,
}
impl VideoJobStatus {
pub fn is_pending(&self) -> bool {
matches!(
self,
VideoJobStatus::Queued | VideoJobStatus::Processing { .. }
)
}
pub fn is_terminal(&self) -> bool {
matches!(
self,
VideoJobStatus::Completed { .. }
| VideoJobStatus::Failed { .. }
| VideoJobStatus::Cancelled
)
}
pub fn video_url(&self) -> Option<&str> {
match self {
VideoJobStatus::Completed { video_url, .. } => Some(video_url),
_ => None,
}
}
pub fn error(&self) -> Option<&str> {
match self {
VideoJobStatus::Failed { error, .. } => Some(error),
_ => None,
}
}
}
/// A backend that generates videos through asynchronous jobs.
///
/// [`generate_video`](Self::generate_video) starts a job and returns a response carrying
/// its id; [`get_video_status`](Self::get_video_status) reports the job's current
/// [`VideoJobStatus`]; [`wait_for_video`](Self::wait_for_video) polls until the job reaches
/// a terminal state. The capability methods have defaults that implementors may override.
#[async_trait]
pub trait VideoProvider: Send + Sync {
    /// The provider's name.
    fn name(&self) -> &str;

    /// Submits `request` and returns the resulting job's response.
    async fn generate_video(
        &self,
        request: VideoGenerationRequest,
    ) -> Result<VideoGenerationResponse>;

    /// Fetches the current status of the job identified by `job_id`.
    async fn get_video_status(&self, job_id: &str) -> Result<VideoJobStatus>;

    /// Cancels the job identified by `job_id`.
    ///
    /// # Errors
    ///
    /// The default implementation always returns `Error::not_supported`; providers that
    /// can cancel jobs should override it.
    async fn cancel_video(&self, _job_id: &str) -> Result<()> {
        Err(Error::not_supported("Video cancellation"))
    }

    /// Polls [`get_video_status`](Self::get_video_status) every `poll_interval` until the
    /// job reaches a terminal state ([`VideoJobStatus::is_terminal`]) or `timeout` elapses.
    ///
    /// The status is always polled at least once. Sleeps are shortened near the deadline
    /// so the final poll happens at the deadline rather than overshooting it by up to
    /// `poll_interval`. A `timeout` too large to represent as a deadline waits indefinitely.
    /// A zero `poll_interval` re-polls without sleeping.
    ///
    /// # Errors
    ///
    /// Returns `Error::Timeout` if the job is still not terminal once `timeout` has
    /// elapsed, and propagates any error returned by `get_video_status`.
    async fn wait_for_video(
        &self,
        job_id: &str,
        poll_interval: Duration,
        timeout: Duration,
    ) -> Result<VideoJobStatus> {
        // `None` when `now + timeout` overflows `Instant`: treat as "no deadline".
        let deadline = std::time::Instant::now().checked_add(timeout);
        loop {
            let status = self.get_video_status(job_id).await?;
            if status.is_terminal() {
                return Ok(status);
            }
            let pause = match deadline {
                Some(at) => {
                    let remaining = at.saturating_duration_since(std::time::Instant::now());
                    if remaining.is_zero() {
                        return Err(Error::Timeout);
                    }
                    // Never sleep past the deadline; the next iteration polls one last time.
                    poll_interval.min(remaining)
                }
                None => poll_interval,
            };
            tokio::time::sleep(pause).await;
        }
    }

    /// Aspect ratios the provider reports as supported (default: `16:9`, `9:16`, `1:1`).
    fn supported_aspect_ratios(&self) -> &[&str] {
        &["16:9", "9:16", "1:1"]
    }

    /// Durations the provider reports as supported (default: 4, 6 and 8).
    fn supported_durations(&self) -> &[u32] {
        &[4, 6, 8]
    }

    /// Longest duration the provider reports as supported (default: 10).
    fn max_duration(&self) -> u32 {
        10
    }

    /// The provider's default video model id, if any (default: `None`).
    fn default_video_model(&self) -> Option<&str> {
        None
    }

    /// Whether image inputs are supported (default: `true`).
    fn supports_image_to_video(&self) -> bool {
        true
    }

    /// Whether video inputs are supported (default: `false`).
    fn supports_video_to_video(&self) -> bool {
        false
    }
}
/// Static metadata describing a known video model; the built-in entries live in
/// [`VIDEO_MODELS`].
#[derive(Debug, Clone)]
pub struct VideoModelInfo {
/// Model identifier; `get_video_model_info` matches it exactly.
pub id: &'static str,
/// Provider name; `get_video_models_by_provider` matches it ASCII-case-insensitively.
pub provider: &'static str,
/// Maximum clip duration (presumably seconds — TODO confirm).
pub max_duration: u32,
/// Supported aspect ratios such as `"16:9"`.
pub aspect_ratios: &'static [&'static str],
/// Whether image-to-video generation is supported.
pub supports_i2v: bool,
/// Whether video-to-video generation is supported.
pub supports_v2v: bool,
/// Price per second — NOTE(review): currency and basis (per output second?) are not
/// stated here; presumably USD, verify.
pub price_per_second: f64,
}
/// Built-in registry of known video models, searched by [`get_video_model_info`] and
/// [`get_video_models_by_provider`].
///
/// NOTE(review): limits and prices are hard-coded snapshots; confirm them against each
/// provider's current documentation before relying on them.
pub static VIDEO_MODELS: &[VideoModelInfo] = &[
VideoModelInfo {
id: "gen-3",
provider: "runway",
max_duration: 10,
aspect_ratios: &["16:9", "9:16", "1:1"],
supports_i2v: true,
supports_v2v: false,
price_per_second: 0.05,
},
VideoModelInfo {
id: "gen-4",
provider: "runway",
max_duration: 10,
aspect_ratios: &["16:9", "9:16", "1:1", "4:5"],
supports_i2v: true,
supports_v2v: true,
price_per_second: 0.10,
},
VideoModelInfo {
id: "pika-1.0",
provider: "pika",
max_duration: 4,
aspect_ratios: &["16:9", "9:16", "1:1"],
supports_i2v: true,
supports_v2v: true,
price_per_second: 0.04,
},
VideoModelInfo {
id: "dream-machine",
provider: "luma",
max_duration: 5,
aspect_ratios: &["16:9", "9:16", "1:1"],
supports_i2v: true,
supports_v2v: false,
price_per_second: 0.05,
},
VideoModelInfo {
id: "kling-2.0",
provider: "kling",
max_duration: 10,
aspect_ratios: &["16:9", "9:16", "1:1"],
supports_i2v: true,
supports_v2v: false,
price_per_second: 0.03,
},
VideoModelInfo {
id: "hailuo-video",
provider: "minimax",
max_duration: 6,
aspect_ratios: &["16:9", "9:16"],
supports_i2v: true,
supports_v2v: false,
price_per_second: 0.025,
},
];
pub fn get_video_model_info(model_id: &str) -> Option<&'static VideoModelInfo> {
VIDEO_MODELS.iter().find(|m| m.id == model_id)
}
/// Returns every model in [`VIDEO_MODELS`] whose provider matches `provider`,
/// ignoring ASCII case, in registry order. The result is empty when nothing matches.
pub fn get_video_models_by_provider(provider: &str) -> Vec<&'static VideoModelInfo> {
    let mut matching = Vec::new();
    for info in VIDEO_MODELS {
        if info.provider.eq_ignore_ascii_case(provider) {
            matching.push(info);
        }
    }
    matching
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_video_request_builder() {
        let req = VideoGenerationRequest::new("gen-3", "A cat playing")
            .with_duration(6)
            .with_aspect_ratio("16:9")
            .with_camera_motion(CameraMotion::ZoomIn)
            .with_motion_amount(0.5);

        assert_eq!(req.model, "gen-3");
        assert_eq!(req.prompt, "A cat playing");
        assert_eq!(req.duration, Some(6));
        assert_eq!(req.aspect_ratio.as_deref(), Some("16:9"));
        assert_eq!(req.camera_motion, Some(CameraMotion::ZoomIn));
        assert_eq!(req.motion_amount, Some(0.5));
    }

    #[test]
    fn test_motion_amount_clamping() {
        let clamped = |amount: f32| {
            VideoGenerationRequest::new("gen-3", "test")
                .with_motion_amount(amount)
                .motion_amount
        };
        assert_eq!(clamped(2.0), Some(1.0));
        assert_eq!(clamped(-0.5), Some(0.0));
    }

    #[test]
    fn test_video_resolution() {
        let cases = [
            (VideoResolution::Hd1080, (1920, 1080)),
            (VideoResolution::Uhd4k, (3840, 2160)),
            (
                VideoResolution::Custom {
                    width: 1280,
                    height: 720,
                },
                (1280, 720),
            ),
        ];
        for (resolution, expected) in cases.iter() {
            assert_eq!(resolution.dimensions(), *expected);
        }
    }

    #[test]
    fn test_video_input() {
        assert!(matches!(VideoInput::file("video.mp4"), VideoInput::File(_)));
        assert!(matches!(
            VideoInput::url("https://example.com/video.mp4"),
            VideoInput::Url(_)
        ));
    }

    #[test]
    fn test_job_status() {
        let pending = [
            VideoJobStatus::Queued,
            VideoJobStatus::Processing {
                progress: Some(50),
                stage: Some("rendering".to_string()),
            },
        ];
        for status in pending.iter() {
            assert!(status.is_pending());
            assert!(!status.is_terminal());
        }

        let completed = VideoJobStatus::Completed {
            video_url: "https://example.com/video.mp4".to_string(),
            duration_seconds: Some(6.0),
            thumbnail_url: None,
        };
        assert!(!completed.is_pending());
        assert!(completed.is_terminal());
        assert_eq!(completed.video_url(), Some("https://example.com/video.mp4"));

        let failed = VideoJobStatus::Failed {
            error: "Generation failed".to_string(),
            code: None,
        };
        assert!(!failed.is_pending());
        assert!(failed.is_terminal());
        assert_eq!(failed.error(), Some("Generation failed"));
    }

    #[test]
    fn test_video_model_registry() {
        let model = get_video_model_info("gen-3").expect("gen-3 should be registered");
        assert_eq!(model.provider, "runway");
        assert!(model.supports_i2v);
    }

    #[test]
    fn test_get_models_by_provider() {
        let runway = get_video_models_by_provider("runway");
        assert!(!runway.is_empty());
        assert!(runway.iter().all(|info| info.provider == "runway"));
    }
}