use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use crate::error::{Error, Result};
use crate::provider::{Provider, ProviderConfig};
use crate::types::{
CompletionRequest, CompletionResponse, ContentBlock, ContentDelta, StopReason, StreamChunk,
StreamEventType, Usage,
};
use futures::Stream;
use std::pin::Pin;
/// Base URL of the RunwayML REST API (v1).
const RUNWAYML_API_URL: &str = "https://api.runwayml.com/v1";

/// Provider backed by the RunwayML video-generation API.
///
/// Requests are authenticated via a default `Authorization: Bearer …`
/// header installed on `client` at construction time (see `new`).
pub struct RunwayMLProvider {
    // Kept for future use; the key it carries is already baked into
    // `client`'s default headers, so nothing reads it after construction.
    #[allow(dead_code)]
    config: ProviderConfig,
    /// HTTP client with auth header and request timeout pre-applied.
    client: Client,
}
/// Body of a `POST /tasks` request that starts a generation task.
#[derive(Debug, Serialize)]
struct RunwayMLTaskRequest {
    /// Model identifier, e.g. "gen4_turbo".
    model: String,
    /// Text prompt describing the video to generate.
    prompt: String,
    /// Optional clip length; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    duration: Option<u32>,
    /// Optional aspect ratio; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    aspect_ratio: Option<String>,
}
/// Response returned by both task creation and task polling endpoints.
#[derive(Debug, Deserialize)]
struct RunwayMLTaskResponse {
    /// Server-assigned task id, used for subsequent polling.
    id: String,
    /// Task state; observed values: "PENDING", "RUNNING", "SUCCEEDED", "FAILED".
    status: String,
    /// Populated once the task has succeeded; absent while in flight.
    #[serde(default)]
    result: Option<RunwayMLTaskResult>,
    /// Human-readable failure reason, present when `status` is "FAILED".
    #[serde(default)]
    error: Option<String>,
}
/// Payload of a completed task.
#[derive(Debug, Deserialize)]
struct RunwayMLTaskResult {
    /// URLs of the generated asset(s); presumably one per output video —
    /// TODO confirm against the RunwayML API reference.
    #[serde(default)]
    output: Option<Vec<String>>,
}
impl RunwayMLProvider {
pub fn from_env() -> Result<Self> {
let api_key = std::env::var("RUNWAYML_API_SECRET").ok();
if api_key.is_none() {
return Err(Error::config(
"RUNWAYML_API_SECRET environment variable not set",
));
}
let config = ProviderConfig {
api_key,
..Default::default()
};
Self::new(config)
}
pub fn with_api_key(api_key: impl Into<String>) -> Result<Self> {
let config = ProviderConfig::new(api_key);
Self::new(config)
}
pub fn new(config: ProviderConfig) -> Result<Self> {
let mut headers = reqwest::header::HeaderMap::new();
if let Some(ref key) = config.api_key {
let bearer = format!("Bearer {}", key);
headers.insert(
"Authorization",
bearer
.parse()
.map_err(|_| Error::config("Invalid API key format"))?,
);
}
let client = Client::builder()
.timeout(config.timeout)
.default_headers(headers)
.build()?;
Ok(Self { config, client })
}
fn extract_prompt(&self, request: &CompletionRequest) -> String {
let mut prompts = Vec::new();
for message in &request.messages {
for content_block in &message.content {
if let ContentBlock::Text { text } = content_block {
prompts.push(text.clone());
}
}
}
prompts.join(" ")
}
async fn create_task(&self, request: &RunwayMLTaskRequest) -> Result<String> {
let response = self
.client
.post(format!("{}/tasks", RUNWAYML_API_URL))
.json(request)
.send()
.await?;
let status = response.status();
let body = response.text().await?;
if !status.is_success() {
return Err(self.handle_error_response(status, &body));
}
let task_response: RunwayMLTaskResponse = serde_json::from_str(&body)
.map_err(|e| Error::other(format!("Failed to parse response: {}", e)))?;
Ok(task_response.id)
}
async fn poll_task(&self, task_id: &str, timeout_secs: u64) -> Result<RunwayMLTaskResponse> {
let start = std::time::Instant::now();
let timeout = Duration::from_secs(timeout_secs);
let mut delay = Duration::from_millis(500);
loop {
if start.elapsed() > timeout {
return Err(Error::other("RunwayML task polling timeout"));
}
let response = self
.client
.get(format!("{}/tasks/{}", RUNWAYML_API_URL, task_id))
.send()
.await?;
let status = response.status();
let body = response.text().await?;
if !status.is_success() {
return Err(self.handle_error_response(status, &body));
}
let task_response: RunwayMLTaskResponse = serde_json::from_str(&body)
.map_err(|e| Error::other(format!("Failed to parse response: {}", e)))?;
match task_response.status.as_str() {
"SUCCEEDED" => return Ok(task_response),
"FAILED" => {
return Err(Error::other(format!(
"RunwayML task failed: {}",
task_response
.error
.unwrap_or_else(|| "Unknown error".to_string())
)))
}
"PENDING" | "RUNNING" => {
tokio::time::sleep(delay).await;
delay = std::cmp::min(delay.mul_f32(1.5), Duration::from_secs(10));
}
_ => {
return Err(Error::other(format!(
"Unknown task status: {}",
task_response.status
)))
}
}
}
}
fn handle_error_response(&self, status: reqwest::StatusCode, body: &str) -> Error {
match status.as_u16() {
401 => Error::auth(format!("RunwayML authentication failed: {}", body)),
429 => Error::rate_limited("RunwayML rate limited", None),
500..=599 => Error::server(status.as_u16(), body.to_string()),
_ => Error::other(format!("RunwayML error ({}): {}", status, body)),
}
}
}
#[async_trait]
impl Provider for RunwayMLProvider {
    fn name(&self) -> &str {
        "runwayml"
    }

    fn default_model(&self) -> Option<&str> {
        Some("gen4_turbo")
    }

    /// Generates a video: submits a task, polls it to completion (up to
    /// 300 s), and returns the resulting video URL wrapped in a text block.
    ///
    /// RunwayML does not report token usage, so `usage` is all zeros.
    ///
    /// # Errors
    /// Propagates task-creation and polling failures; returns `Error::other`
    /// when a succeeded task carries no output URL.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let prompt = self.extract_prompt(&request);
        // Fall back to the provider default instead of re-hardcoding it,
        // so the two stay in sync.
        let model = if request.model.is_empty() {
            self.default_model().unwrap_or("gen4_turbo").to_string()
        } else {
            request.model.clone()
        };
        let task_request = RunwayMLTaskRequest {
            model: model.clone(),
            prompt,
            duration: None,
            aspect_ratio: None,
        };
        let task_id = self.create_task(&task_request).await?;
        // Video generation is slow; allow up to five minutes of polling.
        let result = self.poll_task(&task_id, 300).await?;
        // NOTE(review): `pop` takes the LAST output URL when several are
        // returned — confirm against the API whether first would be better.
        let video_url = result
            .result
            .and_then(|r| r.output)
            .and_then(|mut output| output.pop())
            .ok_or_else(|| Error::other("No video URL in response"))?;
        Ok(CompletionResponse {
            id: task_id,
            model,
            content: vec![ContentBlock::Text {
                text: format!("Video generated: {}", video_url),
            }],
            stop_reason: StopReason::EndTurn,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            },
        })
    }

    /// Emulated streaming: runs `complete` to the end, then replays the
    /// result as a three-chunk stream (message start, one content delta,
    /// message stop).
    ///
    /// The delta carries the actual completion text — including the
    /// generated video URL — instead of a placeholder, so streaming callers
    /// receive the same information as non-streaming ones.
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>> {
        let response = self.complete(request).await?;
        // Flatten every text block of the response into the single delta.
        let text = response
            .content
            .iter()
            .filter_map(|block| match block {
                ContentBlock::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("");
        let chunks = vec![
            Ok(StreamChunk {
                event_type: StreamEventType::MessageStart,
                index: None,
                delta: None,
                stop_reason: None,
                usage: None,
            }),
            Ok(StreamChunk {
                event_type: StreamEventType::ContentBlockDelta,
                index: Some(0),
                delta: Some(ContentDelta::Text { text }),
                stop_reason: None,
                usage: None,
            }),
            Ok(StreamChunk {
                event_type: StreamEventType::MessageStop,
                index: None,
                delta: None,
                stop_reason: Some(response.stop_reason),
                usage: Some(response.usage),
            }),
        ];
        Ok(Box::pin(futures::stream::iter(chunks)))
    }

    fn supports_vision(&self) -> bool {
        false
    }

    fn supports_tools(&self) -> bool {
        false
    }

    // Streaming here is emulation over `complete`, not true server streaming.
    fn supports_streaming(&self) -> bool {
        false
    }
}